实现水体含量反演程序。
This commit is contained in:
169
retrieval.py
Normal file
169
retrieval.py
Normal file
@ -0,0 +1,169 @@
|
||||
import numpy as np
|
||||
import os
|
||||
from osgeo import gdal
|
||||
from util import *
|
||||
from type_define import *
|
||||
import math
|
||||
from pyproj import CRS
|
||||
from pyproj import Transformer
|
||||
import argparse
|
||||
|
||||
|
||||
def find_index(wavelength, array):
|
||||
differences = np.abs(array - wavelength)
|
||||
min_position = np.argmin(differences)
|
||||
return min_position
|
||||
|
||||
|
||||
def get_mean_value(index, array, window):
|
||||
window = int(window)
|
||||
result = array[1:, index - window:index + window + 1].mean(axis=1)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def calculate(x1, x2, coefficients):
|
||||
A = np.column_stack((x1, x2, np.ones((x2.shape[0], 1))))
|
||||
y_pred = A.dot(coefficients)
|
||||
|
||||
return y_pred
|
||||
|
||||
|
||||
def retrieval_chl_a(model_info_path, coor_spectral_path, output_path, window=5):
|
||||
model_type, model_info, accuracy_ = load_numpy_dict_from_json(model_info_path)
|
||||
coor_spectral = np.loadtxt(coor_spectral_path)
|
||||
|
||||
wave1 = 651
|
||||
index_wave1 = find_index(wave1, coor_spectral[0, :])
|
||||
band_651 = get_mean_value(index_wave1, coor_spectral, window)
|
||||
|
||||
wave2 = 707
|
||||
index_wave2 = find_index(wave2, coor_spectral[0, :])
|
||||
band_707 = get_mean_value(index_wave2, coor_spectral, window)
|
||||
|
||||
wave3 = 670
|
||||
index_wave3 = find_index(wave3, coor_spectral[0, :])
|
||||
band_670 = get_mean_value(index_wave3, coor_spectral, window)
|
||||
|
||||
x = (band_651 - band_707) / (band_707 - band_670)
|
||||
retrieval_result = np.polyval(model_info, list(x))
|
||||
|
||||
position_content = np.hstack((coor_spectral[1:, 0:2], retrieval_result.reshape((retrieval_result.shape[0], 1))))
|
||||
np.savetxt(output_path, position_content, fmt='%.4f', delimiter="\t")
|
||||
|
||||
return position_content
|
||||
|
||||
|
||||
def retrieval_nh3(model_info_path, coor_spectral_path, output_path=None, window=5):
|
||||
model_type, model_info, accuracy_ = load_numpy_dict_from_json(model_info_path)
|
||||
coor_spectral = np.loadtxt(coor_spectral_path)
|
||||
|
||||
wave1 = 600
|
||||
index_wave1 = find_index(wave1, coor_spectral[0, :])
|
||||
band_600 = get_mean_value(index_wave1, coor_spectral, window)
|
||||
|
||||
wave2 = 500
|
||||
index_wave2 = find_index(wave2, coor_spectral[0, :])
|
||||
band_500 = get_mean_value(index_wave2, coor_spectral, window)
|
||||
|
||||
wave3 = 850
|
||||
index_wave3 = find_index(wave3, coor_spectral[0, :])
|
||||
band_850 = get_mean_value(index_wave3, coor_spectral, window)
|
||||
|
||||
x13 = np.log(band_500 / band_850)
|
||||
x23 = np.exp(band_600 / band_500)
|
||||
retrieval_result = calculate(x13, x23, model_info)
|
||||
|
||||
position_content = np.hstack((coor_spectral[1:, 0:2], retrieval_result.reshape((retrieval_result.shape[0], 1))))
|
||||
if output_path is not None:
|
||||
np.savetxt(output_path, position_content, fmt='%.4f', delimiter="\t")
|
||||
|
||||
return position_content
|
||||
|
||||
|
||||
def retrieval_tss(model_info_path, coor_spectral_path, output_path, window=5):
|
||||
position_content = retrieval_nh3(model_info_path, coor_spectral_path, window=window)
|
||||
|
||||
tmp = np.exp(position_content[:, -1])
|
||||
position_content[:, -1] = tmp
|
||||
|
||||
np.savetxt(output_path, position_content, fmt='%.4f', delimiter="\t")
|
||||
|
||||
return position_content
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="此程序使用修正后的模型进行反演。")
|
||||
|
||||
# parser.add_argument("--global_arg", type=str, help="A global argument for all modes", required=True)
|
||||
|
||||
# 创建子命令解析器
|
||||
subparsers = parser.add_subparsers(dest="algorithm", required=True, help="Choose a mode")
|
||||
|
||||
chl_a_ = subparsers.add_parser("chl_a", help="叶绿素")
|
||||
chl_a_.add_argument('-i1', '--model_info_path', type=str, required=True, help='输入模型信息文件的路径')
|
||||
chl_a_.add_argument('-i2', '--coor_spectral_path', type=str, required=True,
|
||||
help='输入坐标-光谱文件的路径')
|
||||
chl_a_.add_argument('-i3', '--wave_radius', type=float, default=5.0, help='输入波长平均半径')
|
||||
chl_a_.add_argument('-o', '--outpath', required=True, type=str, help='输出文件的路径')
|
||||
chl_a_.set_defaults(func=retrieval_chl_a)
|
||||
|
||||
nh3_ = subparsers.add_parser("nh3", help="氨氮")
|
||||
nh3_.add_argument('-i1', '--model_info_path', type=str, required=True, help='输入模型信息文件的路径')
|
||||
nh3_.add_argument('-i2', '--coor_spectral_path', type=str, required=True,
|
||||
help='输入坐标-光谱文件的路径')
|
||||
nh3_.add_argument('-i3', '--wave_radius', type=float, default=5.0, help='输入波长平均半径')
|
||||
nh3_.add_argument('-o', '--outpath', required=True, type=str, help='输出文件的路径')
|
||||
nh3_.set_defaults(func=retrieval_nh3)
|
||||
|
||||
mno4_ = subparsers.add_parser("mno4", help="高猛酸盐")
|
||||
mno4_.add_argument('-i1', '--model_info_path', type=str, required=True, help='输入模型信息文件的路径')
|
||||
mno4_.add_argument('-i2', '--coor_spectral_path', type=str, required=True,
|
||||
help='输入坐标-光谱文件的路径')
|
||||
mno4_.add_argument('-i3', '--wave_radius', type=float, default=5.0, help='输入波长平均半径')
|
||||
mno4_.add_argument('-o', '--outpath', required=True, type=str, help='输出文件的路径')
|
||||
mno4_.set_defaults(func=retrieval_nh3)
|
||||
|
||||
tn_ = subparsers.add_parser("tn", help="总氮")
|
||||
tn_.add_argument('-i1', '--model_info_path', type=str, required=True, help='输入模型信息文件的路径')
|
||||
tn_.add_argument('-i2', '--coor_spectral_path', type=str, required=True,
|
||||
help='输入坐标-光谱文件的路径')
|
||||
tn_.add_argument('-i3', '--wave_radius', type=float, default=5.0, help='输入波长平均半径')
|
||||
tn_.add_argument('-o', '--outpath', required=True, type=str, help='输出文件的路径')
|
||||
tn_.set_defaults(func=retrieval_nh3)
|
||||
|
||||
tp_ = subparsers.add_parser("tp", help="总磷")
|
||||
tp_.add_argument('-i1', '--model_info_path', type=str, required=True, help='输入模型信息文件的路径')
|
||||
tp_.add_argument('-i2', '--coor_spectral_path', type=str, required=True,
|
||||
help='输入坐标-光谱文件的路径')
|
||||
tp_.add_argument('-i3', '--wave_radius', type=float, default=5.0, help='输入波长平均半径')
|
||||
tp_.add_argument('-o', '--outpath', required=True, type=str, help='输出文件的路径')
|
||||
tp_.set_defaults(func=retrieval_nh3)
|
||||
|
||||
tss_ = subparsers.add_parser("tss", help="总悬浮物")
|
||||
tss_.add_argument('-i1', '--model_info_path', type=str, required=True, help='输入模型信息文件的路径')
|
||||
tss_.add_argument('-i2', '--coor_spectral_path', type=str, required=True,
|
||||
help='输入坐标-光谱文件的路径')
|
||||
tss_.add_argument('-i3', '--wave_radius', type=float, default=5.0, help='输入波长平均半径')
|
||||
tss_.add_argument('-o', '--outpath', required=True, type=str, help='输出文件的路径')
|
||||
tss_.set_defaults(func=retrieval_tss)
|
||||
|
||||
# 解析参数
|
||||
args = parser.parse_args()
|
||||
if args.algorithm == "chl_a":
|
||||
args.func(args.model_info_path, args.coor_spectral_path, args.outpath, args.wave_radius)
|
||||
elif args.algorithm == "nh3":
|
||||
args.func(args.model_info_path, args.coor_spectral_path, args.outpath, args.wave_radius)
|
||||
elif args.algorithm == "mno4":
|
||||
args.func(args.model_info_path, args.coor_spectral_path, args.outpath, args.wave_radius)
|
||||
elif args.algorithm == "tn":
|
||||
args.func(args.model_info_path, args.coor_spectral_path, args.outpath, args.wave_radius)
|
||||
elif args.algorithm == "tp":
|
||||
args.func(args.model_info_path, args.coor_spectral_path, args.outpath, args.wave_radius)
|
||||
elif args.algorithm == "tss":
|
||||
args.func(args.model_info_path, args.coor_spectral_path, args.outpath, args.wave_radius)
|
||||
|
||||
|
||||
# Press the green button in the gutter to run the script.
|
||||
if __name__ == '__main__':
|
||||
main()
|
Reference in New Issue
Block a user