comparison simple_model_fit.py @ 3:0a1812986bc3 draft

planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit 9981e25b00de29ed881b2229a173a8c812ded9bb
author bgruening
date Wed, 09 Aug 2023 11:10:37 +0000
parents 38c4f8a98038
children
comparison
equal deleted inserted replaced
2:38c4f8a98038 3:0a1812986bc3
1 import argparse 1 import argparse
2 import json 2 import json
3
3 import pandas as pd 4 import pandas as pd
4 import pickle 5 from galaxy_ml.model_persist import dump_model_to_h5, load_model_from_h5
5 6 from galaxy_ml.utils import read_columns
6 from galaxy_ml.utils import load_model, read_columns 7 from scipy.io import mmread
7 from sklearn.pipeline import Pipeline 8 from sklearn.pipeline import Pipeline
8 9
9 10 N_JOBS = int(__import__("os").environ.get("GALAXY_SLOTS", 1))
10 N_JOBS = int(__import__('os').environ.get('GALAXY_SLOTS', 1))
11 11
12 12
13 # TODO import from galaxy_ml.utils in future versions 13 # TODO import from galaxy_ml.utils in future versions
14 def clean_params(estimator, n_jobs=None): 14 def clean_params(estimator, n_jobs=None):
15 """clean unwanted hyperparameter settings 15 """clean unwanted hyperparameter settings
18 18
19 Return 19 Return
20 ------ 20 ------
21 Cleaned estimator object 21 Cleaned estimator object
22 """ 22 """
23 ALLOWED_CALLBACKS = ('EarlyStopping', 'TerminateOnNaN', 23 ALLOWED_CALLBACKS = (
24 'ReduceLROnPlateau', 'CSVLogger', 'None') 24 "EarlyStopping",
25 "TerminateOnNaN",
26 "ReduceLROnPlateau",
27 "CSVLogger",
28 "None",
29 )
25 30
26 estimator_params = estimator.get_params() 31 estimator_params = estimator.get_params()
27 32
28 for name, p in estimator_params.items(): 33 for name, p in estimator_params.items():
29 # all potential unauthorized file write 34 # all potential unauthorized file write
30 if name == 'memory' or name.endswith('__memory') \ 35 if name == "memory" or name.endswith("__memory") or name.endswith("_path"):
31 or name.endswith('_path'):
32 new_p = {name: None} 36 new_p = {name: None}
33 estimator.set_params(**new_p) 37 estimator.set_params(**new_p)
34 elif n_jobs is not None and (name == 'n_jobs' or 38 elif n_jobs is not None and (name == "n_jobs" or name.endswith("__n_jobs")):
35 name.endswith('__n_jobs')):
36 new_p = {name: n_jobs} 39 new_p = {name: n_jobs}
37 estimator.set_params(**new_p) 40 estimator.set_params(**new_p)
38 elif name.endswith('callbacks'): 41 elif name.endswith("callbacks"):
39 for cb in p: 42 for cb in p:
40 cb_type = cb['callback_selection']['callback_type'] 43 cb_type = cb["callback_selection"]["callback_type"]
41 if cb_type not in ALLOWED_CALLBACKS: 44 if cb_type not in ALLOWED_CALLBACKS:
42 raise ValueError( 45 raise ValueError("Prohibited callback type: %s!" % cb_type)
43 "Prohibited callback type: %s!" % cb_type)
44 46
45 return estimator 47 return estimator
46 48
47 49
48 def _get_X_y(params, infile1, infile2): 50 def _get_X_y(params, infile1, infile2):
49 """ read from inputs and output X and y 51 """read from inputs and output X and y
50 52
51 Parameters 53 Parameters
52 ---------- 54 ----------
53 params : dict 55 params : dict
54 Tool inputs parameter 56 Tool inputs parameter
59 61
60 """ 62 """
61 # store read dataframe object 63 # store read dataframe object
62 loaded_df = {} 64 loaded_df = {}
63 65
64 input_type = params['input_options']['selected_input'] 66 input_type = params["input_options"]["selected_input"]
65 # tabular input 67 # tabular input
66 if input_type == 'tabular': 68 if input_type == "tabular":
67 header = 'infer' if params['input_options']['header1'] else None 69 header = "infer" if params["input_options"]["header1"] else None
68 column_option = (params['input_options']['column_selector_options_1'] 70 column_option = params["input_options"]["column_selector_options_1"][
69 ['selected_column_selector_option']) 71 "selected_column_selector_option"
70 if column_option in ['by_index_number', 'all_but_by_index_number', 72 ]
71 'by_header_name', 'all_but_by_header_name']: 73 if column_option in [
72 c = params['input_options']['column_selector_options_1']['col1'] 74 "by_index_number",
75 "all_but_by_index_number",
76 "by_header_name",
77 "all_but_by_header_name",
78 ]:
79 c = params["input_options"]["column_selector_options_1"]["col1"]
73 else: 80 else:
74 c = None 81 c = None
75 82
76 df_key = infile1 + repr(header) 83 df_key = infile1 + repr(header)
77 df = pd.read_csv(infile1, sep='\t', header=header, 84 df = pd.read_csv(infile1, sep="\t", header=header, parse_dates=True)
78 parse_dates=True)
79 loaded_df[df_key] = df 85 loaded_df[df_key] = df
80 86
81 X = read_columns(df, c=c, c_option=column_option).astype(float) 87 X = read_columns(df, c=c, c_option=column_option).astype(float)
82 # sparse input 88 # sparse input
83 elif input_type == 'sparse': 89 elif input_type == "sparse":
84 X = mmread(open(infile1, 'r')) 90 X = mmread(open(infile1, "r"))
85 91
86 # Get target y 92 # Get target y
87 header = 'infer' if params['input_options']['header2'] else None 93 header = "infer" if params["input_options"]["header2"] else None
88 column_option = (params['input_options']['column_selector_options_2'] 94 column_option = params["input_options"]["column_selector_options_2"][
89 ['selected_column_selector_option2']) 95 "selected_column_selector_option2"
90 if column_option in ['by_index_number', 'all_but_by_index_number', 96 ]
91 'by_header_name', 'all_but_by_header_name']: 97 if column_option in [
92 c = params['input_options']['column_selector_options_2']['col2'] 98 "by_index_number",
99 "all_but_by_index_number",
100 "by_header_name",
101 "all_but_by_header_name",
102 ]:
103 c = params["input_options"]["column_selector_options_2"]["col2"]
93 else: 104 else:
94 c = None 105 c = None
95 106
96 df_key = infile2 + repr(header) 107 df_key = infile2 + repr(header)
97 if df_key in loaded_df: 108 if df_key in loaded_df:
98 infile2 = loaded_df[df_key] 109 infile2 = loaded_df[df_key]
99 else: 110 else:
100 infile2 = pd.read_csv(infile2, sep='\t', 111 infile2 = pd.read_csv(infile2, sep="\t", header=header, parse_dates=True)
101 header=header, parse_dates=True)
102 loaded_df[df_key] = infile2 112 loaded_df[df_key] = infile2
103 113
104 y = read_columns( 114 y = read_columns(
105 infile2, 115 infile2, c=c, c_option=column_option, sep="\t", header=header, parse_dates=True
106 c=c, 116 )
107 c_option=column_option,
108 sep='\t',
109 header=header,
110 parse_dates=True)
111 if len(y.shape) == 2 and y.shape[1] == 1: 117 if len(y.shape) == 2 and y.shape[1] == 1:
112 y = y.ravel() 118 y = y.ravel()
113 119
114 return X, y 120 return X, y
115 121
116 122
117 def main(inputs, infile_estimator, infile1, infile2, out_object, 123 def main(inputs, infile_estimator, infile1, infile2, out_object, out_weights=None):
118 out_weights=None): 124 """main
119 """ main
120 125
121 Parameters 126 Parameters
122 ---------- 127 ----------
123 inputs : str 128 inputs : str
124 File path to galaxy tool parameter 129 File path to galaxy tool parameter
137 142
138 out_weights : str 143 out_weights : str
139 File path for output of weights 144 File path for output of weights
140 145
141 """ 146 """
142 with open(inputs, 'r') as param_handler: 147 with open(inputs, "r") as param_handler:
143 params = json.load(param_handler) 148 params = json.load(param_handler)
144 149
145 # load model 150 # load model
146 with open(infile_estimator, 'rb') as est_handler: 151 estimator = load_model_from_h5(infile_estimator)
147 estimator = load_model(est_handler) 152
148 estimator = clean_params(estimator, n_jobs=N_JOBS) 153 estimator = clean_params(estimator)
149 154
150 X_train, y_train = _get_X_y(params, infile1, infile2) 155 X_train, y_train = _get_X_y(params, infile1, infile2)
151 156
152 estimator.fit(X_train, y_train) 157 estimator.fit(X_train, y_train)
153 158
154 main_est = estimator 159 main_est = estimator
155 if isinstance(main_est, Pipeline): 160 if isinstance(main_est, Pipeline):
156 main_est = main_est.steps[-1][-1] 161 main_est = main_est.steps[-1][-1]
157 if hasattr(main_est, 'model_') \ 162 if hasattr(main_est, "model_") and hasattr(main_est, "save_weights"):
158 and hasattr(main_est, 'save_weights'):
159 if out_weights: 163 if out_weights:
160 main_est.save_weights(out_weights) 164 main_est.save_weights(out_weights)
161 del main_est.model_ 165 del main_est.model_
162 del main_est.fit_params 166 del main_est.fit_params
163 del main_est.model_class_ 167 del main_est.model_class_
164 del main_est.validation_data 168 if getattr(main_est, "validation_data", None):
165 if getattr(main_est, 'data_generator_', None): 169 del main_est.validation_data
170 if getattr(main_est, "data_generator_", None):
166 del main_est.data_generator_ 171 del main_est.data_generator_
167 172
168 with open(out_object, 'wb') as output_handler: 173 dump_model_to_h5(estimator, out_object)
169 pickle.dump(estimator, output_handler,
170 pickle.HIGHEST_PROTOCOL)
171 174
172 175
173 if __name__ == '__main__': 176 if __name__ == "__main__":
174 aparser = argparse.ArgumentParser() 177 aparser = argparse.ArgumentParser()
175 aparser.add_argument("-i", "--inputs", dest="inputs", required=True) 178 aparser.add_argument("-i", "--inputs", dest="inputs", required=True)
176 aparser.add_argument("-X", "--infile_estimator", dest="infile_estimator") 179 aparser.add_argument("-X", "--infile_estimator", dest="infile_estimator")
177 aparser.add_argument("-y", "--infile1", dest="infile1") 180 aparser.add_argument("-y", "--infile1", dest="infile1")
178 aparser.add_argument("-g", "--infile2", dest="infile2") 181 aparser.add_argument("-g", "--infile2", dest="infile2")
179 aparser.add_argument("-o", "--out_object", dest="out_object") 182 aparser.add_argument("-o", "--out_object", dest="out_object")
180 aparser.add_argument("-t", "--out_weights", dest="out_weights") 183 aparser.add_argument("-t", "--out_weights", dest="out_weights")
181 args = aparser.parse_args() 184 args = aparser.parse_args()
182 185
183 main(args.inputs, args.infile_estimator, args.infile1, 186 main(
184 args.infile2, args.out_object, args.out_weights) 187 args.inputs,
188 args.infile_estimator,
189 args.infile1,
190 args.infile2,
191 args.out_object,
192 args.out_weights,
193 )