comparison fitted_model_eval.py @ 3:0a1812986bc3 draft

planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit 9981e25b00de29ed881b2229a173a8c812ded9bb
author bgruening
date Wed, 09 Aug 2023 11:10:37 +0000
parents 38c4f8a98038
children
comparison
equal deleted inserted replaced
2:38c4f8a98038 3:0a1812986bc3
1 import argparse 1 import argparse
2 import json 2 import json
3 import pandas as pd
4 import warnings 3 import warnings
5 4
5 import pandas as pd
6 from galaxy_ml.model_persist import load_model_from_h5
7 from galaxy_ml.utils import clean_params, get_scoring, read_columns
6 from scipy.io import mmread 8 from scipy.io import mmread
7 from sklearn.pipeline import Pipeline 9 from sklearn.metrics._scorer import _check_multimetric_scoring
8 from sklearn.metrics.scorer import _check_multimetric_scoring
9 from sklearn.model_selection._validation import _score 10 from sklearn.model_selection._validation import _score
10 from galaxy_ml.utils import get_scoring, load_model, read_columns
11 11
12 12
def _get_X_y(params, infile1, infile2):
    """Read the input datasets and return the feature matrix X and target y.

    Parameters
    ----------
    params : dict
        Galaxy tool input parameters (parsed JSON); the
        ``input_options`` section selects the input type and columns.
    infile1 : str
        File path to the dataset containing features: a tabular file or
        a matrix-market sparse file, depending on
        ``params['input_options']['selected_input']``.
    infile2 : str
        File path to the tabular dataset containing target values.

    Returns
    -------
    X : array-like
        Feature matrix.
    y : array-like
        Target values; a single-column 2-D target is flattened to 1-D.

    Raises
    ------
    ValueError
        If ``params['input_options']['selected_input']`` is not a
        supported input type.
    """
    # Cache of parsed dataframes keyed by path + header setting, so the
    # same file is not parsed twice when features and target share it.
    loaded_df = {}

    input_type = params["input_options"]["selected_input"]
    # tabular input
    if input_type == "tabular":
        header = "infer" if params["input_options"]["header1"] else None
        column_option = params["input_options"]["column_selector_options_1"][
            "selected_column_selector_option"
        ]
        if column_option in [
            "by_index_number",
            "all_but_by_index_number",
            "by_header_name",
            "all_but_by_header_name",
        ]:
            c = params["input_options"]["column_selector_options_1"]["col1"]
        else:
            c = None

        df_key = infile1 + repr(header)
        df = pd.read_csv(infile1, sep="\t", header=header, parse_dates=True)
        loaded_df[df_key] = df

        X = read_columns(df, c=c, c_option=column_option).astype(float)
    # sparse input
    elif input_type == "sparse":
        # Use a context manager so the file handle is closed; the
        # original `mmread(open(infile1, "r"))` leaked it.
        with open(infile1, "r") as mm_handle:
            X = mmread(mm_handle)
    else:
        # Fail loudly instead of hitting an UnboundLocalError on X below.
        raise ValueError("Unsupported input type: %r" % input_type)

    # Get target y
    header = "infer" if params["input_options"]["header2"] else None
    column_option = params["input_options"]["column_selector_options_2"][
        "selected_column_selector_option2"
    ]
    if column_option in [
        "by_index_number",
        "all_but_by_index_number",
        "by_header_name",
        "all_but_by_header_name",
    ]:
        c = params["input_options"]["column_selector_options_2"]["col2"]
    else:
        c = None

    df_key = infile2 + repr(header)
    if df_key in loaded_df:
        # Reuse the dataframe already parsed for the features.
        infile2 = loaded_df[df_key]
    else:
        infile2 = pd.read_csv(infile2, sep="\t", header=header, parse_dates=True)
        loaded_df[df_key] = infile2

    y = read_columns(
        infile2,
        c=c,
        c_option=column_option,
        sep="\t",
        header=header,
        parse_dates=True,
    )
    # Flatten a single-column 2-D target into a 1-D vector.
    if len(y.shape) == 2 and y.shape[1] == 1:
        y = y.ravel()

    return X, y
80 89
81 90
def main(inputs, infile_estimator, outfile_eval, infile1=None, infile2=None):
    """Evaluate a fitted estimator on a test dataset and write the scores.

    Parameter
    ---------
    inputs : str
        File path to galaxy tool parameter (JSON).

    infile_estimator : str
        File path to trained estimator input (HDF5).

    outfile_eval : str
        File path to save the evaluation results, tabular.

    infile1 : str
        File path to dataset containing features.

    infile2 : str
        File path to dataset containing target values.
    """
    warnings.filterwarnings("ignore")

    with open(inputs, "r") as param_handler:
        params = json.load(param_handler)

    X_test, y_test = _get_X_y(params, infile1, infile2)

    # Load the persisted model and normalize its parameters.
    estimator = clean_params(load_model_from_h5(infile_estimator))

    # Handle scorer: normalize the tool's scoring selection into the
    # multimetric scorer dict expected by sklearn's _score.
    scoring = params["scoring"]
    scorer = get_scoring(scoring)
    if not isinstance(scorer, (dict, list)):
        scorer = [scoring["primary_scoring"]]
    scorer = _check_multimetric_scoring(estimator, scoring=scorer)

    # Prefer the estimator's own evaluate() hook when it provides one.
    if hasattr(estimator, "evaluate"):
        scores = estimator.evaluate(X_test, y_test=y_test, scorer=scorer)
    else:
        scores = _score(estimator, X_test, y_test, scorer)

    # Emit a one-row table with metric columns sorted by name.
    report = pd.DataFrame({name: [value] for name, value in scores.items()})
    report = report[sorted(report.columns)]
    report.to_csv(path_or_buf=outfile_eval, sep="\t", header=True, index=False)
146 139
147 140
if __name__ == "__main__":
    # Command-line entry point: wire Galaxy tool arguments into main().
    parser = argparse.ArgumentParser()
    parser.add_argument("-i", "--inputs", dest="inputs", required=True)
    parser.add_argument("-e", "--infile_estimator", dest="infile_estimator")
    parser.add_argument("-X", "--infile1", dest="infile1")
    parser.add_argument("-y", "--infile2", dest="infile2")
    parser.add_argument("-O", "--outfile_eval", dest="outfile_eval")
    args = parser.parse_args()

    main(
        args.inputs,
        args.infile_estimator,
        args.outfile_eval,
        infile1=args.infile1,
        infile2=args.infile2,
    )