comparison: train_test_eval.py @ 3:0a1812986bc3 (draft)

planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn
commit 9981e25b00de29ed881b2229a173a8c812ded9bb

author    bgruening
date      Wed, 09 Aug 2023 11:10:37 +0000
parents   2:38c4f8a98038
children  none
compared  2:38c4f8a98038 -> 3:0a1812986bc3
@@ imports and module constants @@
 import argparse
-import joblib
 import json
-import numpy as np
 import os
-import pandas as pd
-import pickle
 import warnings
 from itertools import chain
+
+import joblib
+import numpy as np
+import pandas as pd
+from galaxy_ml.model_persist import dump_model_to_h5, load_model_from_h5
+from galaxy_ml.model_validations import train_test_split
+from galaxy_ml.utils import (
+    clean_params,
+    get_module,
+    get_scoring,
+    read_columns,
+    SafeEval,
+    try_get_attr
+)
 from scipy.io import mmread
-from sklearn.base import clone
-from sklearn import (cluster, compose, decomposition, ensemble,
-                     feature_extraction, feature_selection,
-                     gaussian_process, kernel_approximation, metrics,
-                     model_selection, naive_bayes, neighbors,
-                     pipeline, preprocessing, svm, linear_model,
-                     tree, discriminant_analysis)
-from sklearn.exceptions import FitFailedWarning
-from sklearn.metrics.scorer import _check_multimetric_scoring
-from sklearn.model_selection._validation import _score, cross_validate
+from sklearn import pipeline
 from sklearn.model_selection import _search, _validation
-from sklearn.utils import indexable, safe_indexing
-
-from galaxy_ml.model_validations import train_test_split
-from galaxy_ml.utils import (SafeEval, get_scoring, load_model,
-                             read_columns, try_get_attr, get_module)
+from sklearn.model_selection._validation import _score
+from sklearn.utils import _safe_indexing, indexable

-_fit_and_score = try_get_attr('galaxy_ml.model_validations', '_fit_and_score')
-setattr(_search, '_fit_and_score', _fit_and_score)
-setattr(_validation, '_fit_and_score', _fit_and_score)
+_fit_and_score = try_get_attr("galaxy_ml.model_validations", "_fit_and_score")
+setattr(_search, "_fit_and_score", _fit_and_score)
+setattr(_validation, "_fit_and_score", _fit_and_score)

-N_JOBS = int(os.environ.get('GALAXY_SLOTS', 1))
-CACHE_DIR = os.path.join(os.getcwd(), 'cached')
+N_JOBS = int(os.environ.get("GALAXY_SLOTS", 1))
+CACHE_DIR = os.path.join(os.getcwd(), "cached")
 del os
-NON_SEARCHABLE = ('n_jobs', 'pre_dispatch', 'memory', '_path',
-                  'nthread', 'callbacks')
-ALLOWED_CALLBACKS = ('EarlyStopping', 'TerminateOnNaN', 'ReduceLROnPlateau',
-                     'CSVLogger', 'None')
+NON_SEARCHABLE = ("n_jobs", "pre_dispatch", "memory", "_path", "nthread", "callbacks")
+ALLOWED_CALLBACKS = (
+    "EarlyStopping",
+    "TerminateOnNaN",
+    "ReduceLROnPlateau",
+    "CSVLogger",
+    "None",
+)
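The setattr calls above patch galaxy_ml's extended `_fit_and_score` into sklearn's private `_search` and `_validation` modules, so any search or validation run from this script picks up the extended implementation. A minimal sketch of the pattern, with illustrative names only:

    # Sketch: swap a helper inside an already-imported module so every
    # caller within that module resolves to the replacement.
    import importlib

    def patch(module_path, attr_name, replacement):
        module = importlib.import_module(module_path)
        setattr(module, attr_name, replacement)  # later lookups hit the patch

    # e.g. patch("sklearn.model_selection._validation", "_fit_and_score", my_fn)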
@@ def _eval_swap_params @@
 def _eval_swap_params(params_builder):
     swap_params = {}

-    for p in params_builder['param_set']:
-        swap_value = p['sp_value'].strip()
-        if swap_value == '':
+    for p in params_builder["param_set"]:
+        swap_value = p["sp_value"].strip()
+        if swap_value == "":
             continue

-        param_name = p['sp_name']
+        param_name = p["sp_name"]
         if param_name.lower().endswith(NON_SEARCHABLE):
-            warnings.warn("Warning: `%s` is not eligible for search and was "
-                          "omitted!" % param_name)
+            warnings.warn(
+                "Warning: `%s` is not eligible for search and was "
+                "omitted!" % param_name
+            )
             continue

-        if not swap_value.startswith(':'):
+        if not swap_value.startswith(":"):
             safe_eval = SafeEval(load_scipy=True, load_numpy=True)
             ev = safe_eval(swap_value)
         else:
-            # Have `:` before search list, asks for estimator evaluation
+            # Have `:` before search list, asks for estimator evaluation
             safe_eval_es = SafeEval(load_estimators=True)
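`SafeEval` (from galaxy_ml.utils) evaluates the user-typed search-space expression in a restricted namespace; a leading `:` additionally enables estimator construction. For the plain-literal case only, the standard library's `ast.literal_eval` behaves similarly — a rough stand-in, not the tool's actual evaluator:

    import ast

    # a "C: [0.1, 1.0, 10.0]" style search value reduces to a literal list
    ev = ast.literal_eval("[0.1, 1.0, 10.0]")
    assert ev == [0.1, 1.0, 10.0]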
@@ def train_test_split_none @@
         if arr is None:
             nones.append(idx)
         else:
             new_arrays.append(arr)

-    if kwargs['shuffle'] == 'None':
-        kwargs['shuffle'] = None
+    if kwargs["shuffle"] == "None":
+        kwargs["shuffle"] = None

-    group_names = kwargs.pop('group_names', None)
+    group_names = kwargs.pop("group_names", None)

     if group_names is not None and group_names.strip():
-        group_names = [name.strip() for name in
-                       group_names.split(',')]
+        group_names = [name.strip() for name in group_names.split(",")]
         new_arrays = indexable(*new_arrays)
-        groups = kwargs['labels']
+        groups = kwargs["labels"]
         n_samples = new_arrays[0].shape[0]
         index_arr = np.arange(n_samples)
         test = index_arr[np.isin(groups, group_names)]
         train = index_arr[~np.isin(groups, group_names)]
-        rval = list(chain.from_iterable(
-            (safe_indexing(a, train),
-             safe_indexing(a, test)) for a in new_arrays))
+        rval = list(
+            chain.from_iterable(
+                (_safe_indexing(a, train), _safe_indexing(a, test)) for a in new_arrays
+            )
+        )
     else:
         rval = train_test_split(*new_arrays, **kwargs)

     for pos in nones:
         rval[pos * 2: 2] = [None, None]

     return rval

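`train_test_split_none` drops `None` arrays before splitting, routes `shuffle == "group"` requests through an `np.isin` mask over the group labels, and re-inserts `[None, None]` pairs afterwards. (Note that `rval[pos * 2: 2]` only behaves as an insertion for `pos >= 1`; `X` is never `None` in this script, so the `pos == 0` case is not reached.) A self-contained sketch of the group-mask split:

    import numpy as np

    X = np.arange(10).reshape(5, 2)
    groups = np.array(["a", "a", "b", "b", "c"])

    # hold out every sample whose group name is in the test list
    test_mask = np.isin(groups, ["b"])
    idx = np.arange(X.shape[0])
    train, test = idx[~test_mask], idx[test_mask]
    X_train, X_test = X[train], X[test]  # plain indexing in place of _safe_indexing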
@@ def main @@
-def main(inputs, infile_estimator, infile1, infile2,
-         outfile_result, outfile_object=None,
-         outfile_weights=None, groups=None,
-         ref_seq=None, intervals=None, targets=None,
-         fasta_path=None):
+def main(
+    inputs,
+    infile_estimator,
+    infile1,
+    infile2,
+    outfile_result,
+    outfile_object=None,
+    outfile_weights=None,
+    groups=None,
+    ref_seq=None,
+    intervals=None,
+    targets=None,
+    fasta_path=None,
+):
     """
     Parameter
     ---------
     inputs : str
         File path to galaxy tool parameter

@@ main docstring, continued @@
         File path to dataset compressed target bed file

     fasta_path : str
         File path to dataset containing fasta file
     """
-    warnings.simplefilter('ignore')
+    warnings.simplefilter("ignore")

-    with open(inputs, 'r') as param_handler:
+    with open(inputs, "r") as param_handler:
         params = json.load(param_handler)

     # load estimator
-    with open(infile_estimator, 'rb') as estimator_handler:
-        estimator = load_model(estimator_handler)
+    estimator = load_model_from_h5(infile_estimator)
+    estimator = clean_params(estimator)

     # swap hyperparameter
-    swapping = params['experiment_schemes']['hyperparams_swapping']
+    swapping = params["experiment_schemes"]["hyperparams_swapping"]
     swap_params = _eval_swap_params(swapping)
     estimator.set_params(**swap_params)

     estimator_params = estimator.get_params()

     # store read dataframe object
     loaded_df = {}

-    input_type = params['input_options']['selected_input']
+    input_type = params["input_options"]["selected_input"]
     # tabular input
-    if input_type == 'tabular':
-        header = 'infer' if params['input_options']['header1'] else None
-        column_option = (params['input_options']['column_selector_options_1']
-                         ['selected_column_selector_option'])
-        if column_option in ['by_index_number', 'all_but_by_index_number',
-                             'by_header_name', 'all_but_by_header_name']:
-            c = params['input_options']['column_selector_options_1']['col1']
+    if input_type == "tabular":
+        header = "infer" if params["input_options"]["header1"] else None
+        column_option = params["input_options"]["column_selector_options_1"][
+            "selected_column_selector_option"
+        ]
+        if column_option in [
+            "by_index_number",
+            "all_but_by_index_number",
+            "by_header_name",
+            "all_but_by_header_name",
+        ]:
+            c = params["input_options"]["column_selector_options_1"]["col1"]
         else:
             c = None

         df_key = infile1 + repr(header)
-        df = pd.read_csv(infile1, sep='\t', header=header,
-                         parse_dates=True)
+        df = pd.read_csv(infile1, sep="\t", header=header, parse_dates=True)
         loaded_df[df_key] = df

         X = read_columns(df, c=c, c_option=column_option).astype(float)
     # sparse input
-    elif input_type == 'sparse':
-        X = mmread(open(infile1, 'r'))
+    elif input_type == "sparse":
+        X = mmread(open(infile1, "r"))

     # fasta_file input
-    elif input_type == 'seq_fasta':
-        pyfaidx = get_module('pyfaidx')
+    elif input_type == "seq_fasta":
+        pyfaidx = get_module("pyfaidx")
         sequences = pyfaidx.Fasta(fasta_path)
         n_seqs = len(sequences.keys())
         X = np.arange(n_seqs)[:, np.newaxis]
         for param in estimator_params.keys():
-            if param.endswith('fasta_path'):
-                estimator.set_params(
-                    **{param: fasta_path})
+            if param.endswith("fasta_path"):
+                estimator.set_params(**{param: fasta_path})
                 break
         else:
             raise ValueError(
                 "The selected estimator doesn't support "
                 "fasta file input! Please consider using "
                 "KerasGBatchClassifier with "
                 "FastaDNABatchGenerator/FastaProteinBatchGenerator "
                 "or having GenomeOneHotEncoder/ProteinOneHotEncoder "
-                "in pipeline!")
-
-    elif input_type == 'refseq_and_interval':
+                "in pipeline!"
+            )
+
+    elif input_type == "refseq_and_interval":
         path_params = {
-            'data_batch_generator__ref_genome_path': ref_seq,
-            'data_batch_generator__intervals_path': intervals,
-            'data_batch_generator__target_path': targets
+            "data_batch_generator__ref_genome_path": ref_seq,
+            "data_batch_generator__intervals_path": intervals,
+            "data_batch_generator__target_path": targets,
         }
         estimator.set_params(**path_params)
         n_intervals = sum(1 for line in open(intervals))
         X = np.arange(n_intervals)[:, np.newaxis]

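The `data_batch_generator__*` keys rely on sklearn's double-underscore convention for reaching parameters of nested components via `set_params`. A minimal illustration with a stock pipeline:

    from sklearn.pipeline import Pipeline
    from sklearn.preprocessing import StandardScaler
    from sklearn.svm import SVC

    pipe = Pipeline([("scale", StandardScaler()), ("clf", SVC())])
    pipe.set_params(clf__C=10.0)  # "<step>__<param>" reaches the nested object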
     # Get target y
-    header = 'infer' if params['input_options']['header2'] else None
-    column_option = (params['input_options']['column_selector_options_2']
-                     ['selected_column_selector_option2'])
-    if column_option in ['by_index_number', 'all_but_by_index_number',
-                         'by_header_name', 'all_but_by_header_name']:
-        c = params['input_options']['column_selector_options_2']['col2']
+    header = "infer" if params["input_options"]["header2"] else None
+    column_option = params["input_options"]["column_selector_options_2"][
+        "selected_column_selector_option2"
+    ]
+    if column_option in [
+        "by_index_number",
+        "all_but_by_index_number",
+        "by_header_name",
+        "all_but_by_header_name",
+    ]:
+        c = params["input_options"]["column_selector_options_2"]["col2"]
     else:
         c = None

     df_key = infile2 + repr(header)
     if df_key in loaded_df:
         infile2 = loaded_df[df_key]
     else:
-        infile2 = pd.read_csv(infile2, sep='\t',
-                              header=header, parse_dates=True)
+        infile2 = pd.read_csv(infile2, sep="\t", header=header, parse_dates=True)
         loaded_df[df_key] = infile2

     y = read_columns(
-        infile2,
-        c=c,
-        c_option=column_option,
-        sep='\t',
-        header=header,
-        parse_dates=True)
+        infile2, c=c, c_option=column_option, sep="\t", header=header, parse_dates=True
+    )
     if len(y.shape) == 2 and y.shape[1] == 1:
         y = y.ravel()
-    if input_type == 'refseq_and_interval':
-        estimator.set_params(
-            data_batch_generator__features=y.ravel().tolist())
+    if input_type == "refseq_and_interval":
+        estimator.set_params(data_batch_generator__features=y.ravel().tolist())
         y = None
     # end y

     # load groups
     if groups:
-        groups_selector = (params['experiment_schemes']['test_split']
-                           ['split_algos']).pop('groups_selector')
-
-        header = 'infer' if groups_selector['header_g'] else None
-        column_option = \
-            (groups_selector['column_selector_options_g']
-             ['selected_column_selector_option_g'])
-        if column_option in ['by_index_number', 'all_but_by_index_number',
-                             'by_header_name', 'all_but_by_header_name']:
-            c = groups_selector['column_selector_options_g']['col_g']
+        groups_selector = (
+            params["experiment_schemes"]["test_split"]["split_algos"]
+        ).pop("groups_selector")
+
+        header = "infer" if groups_selector["header_g"] else None
+        column_option = groups_selector["column_selector_options_g"][
+            "selected_column_selector_option_g"
+        ]
+        if column_option in [
+            "by_index_number",
+            "all_but_by_index_number",
+            "by_header_name",
+            "all_but_by_header_name",
+        ]:
+            c = groups_selector["column_selector_options_g"]["col_g"]
         else:
             c = None

         df_key = groups + repr(header)
         if df_key in loaded_df:
             groups = loaded_df[df_key]

         groups = read_columns(
             groups,
             c=c,
             c_option=column_option,
-            sep='\t',
+            sep="\t",
             header=header,
-            parse_dates=True)
+            parse_dates=True,
+        )
         groups = groups.ravel()

     # del loaded_df
     del loaded_df

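`loaded_df` keys each parsed dataframe by file path plus the header setting, so a file shared between X, y, and groups is parsed only once. The pattern, sketched standalone:

    import pandas as pd

    cache = {}

    def read_tsv_once(path, header):
        key = path + repr(header)  # same key scheme as df_key above
        if key not in cache:
            cache[key] = pd.read_csv(path, sep="\t", header=header, parse_dates=True)
        return cache[key]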
     # handle memory
     memory = joblib.Memory(location=CACHE_DIR, verbose=0)
-    # cache iraps_core fits could increase search speed significantly
-    if estimator.__class__.__name__ == 'IRAPSClassifier':
+    # caching iraps_core fits can increase search speed significantly
+    if estimator.__class__.__name__ == "IRAPSClassifier":
         estimator.set_params(memory=memory)
     else:
         # For iraps buried in pipeline
         new_params = {}
         for p, v in estimator_params.items():
-            if p.endswith('memory'):
+            if p.endswith("memory"):
                 # for case of `__irapsclassifier__memory`
-                if len(p) > 8 and p[:-8].endswith('irapsclassifier'):
-                    # cache iraps_core fits could increase search
+                if len(p) > 8 and p[:-8].endswith("irapsclassifier"):
+                    # caching iraps_core fits can increase search
                     # speed significantly
                     new_params[p] = memory
-                # security reason, we don't want memory being
+                # for security, we don't want memory being
                 # modified unexpectedly
                 elif v:
                     new_params[p] = None
             # handle n_jobs
-            elif p.endswith('n_jobs'):
-                # For now, 1 CPU is suggested for iprasclassifier
-                if len(p) > 8 and p[:-8].endswith('irapsclassifier'):
+            elif p.endswith("n_jobs"):
+                # For now, 1 CPU is suggested for irapsclassifier
+                if len(p) > 8 and p[:-8].endswith("irapsclassifier"):
                     new_params[p] = 1
                 else:
                     new_params[p] = N_JOBS
-            # for security reason, types of callback are limited
-            elif p.endswith('callbacks'):
+            # for security, the allowed callback types are limited
+            elif p.endswith("callbacks"):
                 for cb in v:
-                    cb_type = cb['callback_selection']['callback_type']
+                    cb_type = cb["callback_selection"]["callback_type"]
                     if cb_type not in ALLOWED_CALLBACKS:
-                        raise ValueError(
-                            "Prohibited callback type: %s!" % cb_type)
+                        raise ValueError("Prohibited callback type: %s!" % cb_type)

         estimator.set_params(**new_params)

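Attaching a `joblib.Memory` to a `memory` parameter lets repeated fits reuse cached results; sklearn pipelines expose the same hook for caching fitted transformers. A small, generic example:

    import joblib
    from sklearn.linear_model import LogisticRegression
    from sklearn.pipeline import Pipeline
    from sklearn.preprocessing import StandardScaler

    memory = joblib.Memory(location="cached", verbose=0)
    pipe = Pipeline(
        [("scale", StandardScaler()), ("clf", LogisticRegression())],
        memory=memory,  # fitted transformers are cached across refits
    )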
     # handle scorer, convert to scorer dict
-    scoring = params['experiment_schemes']['metrics']['scoring']
+    # Check if scoring is specified
+    scoring = params["experiment_schemes"]["metrics"].get("scoring", None)
+    if scoring is not None:
+        # get_scoring() expects secondary_scoring to be a comma separated
+        # string (not a list); convert the list if one is specified
+        secondary_scoring = scoring.get("secondary_scoring", None)
+        if secondary_scoring is not None:
+            scoring["secondary_scoring"] = ",".join(scoring["secondary_scoring"])
     scorer = get_scoring(scoring)
-    scorer, _ = _check_multimetric_scoring(estimator, scoring=scorer)

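With a params dict of the shape this tool receives, the conversion amounts to the following (field names per the JSON above; values illustrative):

    scoring = {
        "primary_scoring": "accuracy",
        "secondary_scoring": ["f1_macro", "precision_macro"],
    }
    if scoring.get("secondary_scoring") is not None:
        scoring["secondary_scoring"] = ",".join(scoring["secondary_scoring"])
    # scoring["secondary_scoring"] == "f1_macro,precision_macro"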
     # handle test (first) split
-    test_split_options = (params['experiment_schemes']
-                          ['test_split']['split_algos'])
-
-    if test_split_options['shuffle'] == 'group':
-        test_split_options['labels'] = groups
-    if test_split_options['shuffle'] == 'stratified':
+    test_split_options = params["experiment_schemes"]["test_split"]["split_algos"]
+
+    if test_split_options["shuffle"] == "group":
+        test_split_options["labels"] = groups
+    if test_split_options["shuffle"] == "stratified":
         if y is not None:
-            test_split_options['labels'] = y
+            test_split_options["labels"] = y
         else:
-            raise ValueError("Stratified shuffle split is not "
-                             "applicable on empty target values!")
-
-    X_train, X_test, y_train, y_test, groups_train, groups_test = \
-        train_test_split_none(X, y, groups, **test_split_options)
-
-    exp_scheme = params['experiment_schemes']['selected_exp_scheme']
+            raise ValueError(
+                "Stratified shuffle split is not "
+                "applicable on empty target values!"
+            )
+
+    (
+        X_train,
+        X_test,
+        y_train,
+        y_test,
+        groups_train,
+        _groups_test,
+    ) = train_test_split_none(X, y, groups, **test_split_options)
+
+    exp_scheme = params["experiment_schemes"]["selected_exp_scheme"]

     # handle validation (second) split
-    if exp_scheme == 'train_val_test':
-        val_split_options = (params['experiment_schemes']
-                             ['val_split']['split_algos'])
-
-        if val_split_options['shuffle'] == 'group':
-            val_split_options['labels'] = groups_train
-        if val_split_options['shuffle'] == 'stratified':
+    if exp_scheme == "train_val_test":
+        val_split_options = params["experiment_schemes"]["val_split"]["split_algos"]
+
+        if val_split_options["shuffle"] == "group":
+            val_split_options["labels"] = groups_train
+        if val_split_options["shuffle"] == "stratified":
             if y_train is not None:
-                val_split_options['labels'] = y_train
+                val_split_options["labels"] = y_train
             else:
-                raise ValueError("Stratified shuffle split is not "
-                                 "applicable on empty target values!")
-
-        X_train, X_val, y_train, y_val, groups_train, groups_val = \
-            train_test_split_none(X_train, y_train, groups_train,
-                                  **val_split_options)
+                raise ValueError(
+                    "Stratified shuffle split is not "
+                    "applicable on empty target values!"
+                )
+
+        (
+            X_train,
+            X_val,
+            y_train,
+            y_val,
+            groups_train,
+            _groups_val,
+        ) = train_test_split_none(X_train, y_train, groups_train, **val_split_options)

     # train and eval
-    if hasattr(estimator, 'validation_data'):
-        if exp_scheme == 'train_val_test':
-            estimator.fit(X_train, y_train,
-                          validation_data=(X_val, y_val))
+    if hasattr(estimator, "validation_data"):
+        if exp_scheme == "train_val_test":
+            estimator.fit(X_train, y_train, validation_data=(X_val, y_val))
         else:
-            estimator.fit(X_train, y_train,
-                          validation_data=(X_test, y_test))
+            estimator.fit(X_train, y_train, validation_data=(X_test, y_test))
     else:
         estimator.fit(X_train, y_train)

-    if hasattr(estimator, 'evaluate'):
-        scores = estimator.evaluate(X_test, y_test=y_test,
-                                    scorer=scorer,
-                                    is_multimetric=True)
+    if hasattr(estimator, "evaluate"):
+        scores = estimator.evaluate(
+            X_test, y_test=y_test, scorer=scorer, is_multimetric=True
+        )
     else:
-        scores = _score(estimator, X_test, y_test, scorer,
-                        is_multimetric=True)
+        scores = _score(estimator, X_test, y_test, scorer)
     # handle output
     for name, score in scores.items():
         scores[name] = [score]
     df = pd.DataFrame(scores)
     df = df[sorted(df.columns)]
-    df.to_csv(path_or_buf=outfile_result, sep='\t',
-              header=True, index=False)
+    df.to_csv(path_or_buf=outfile_result, sep="\t", header=True, index=False)

     memory.clear(warn=False)

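The scores mapping is wrapped into one-element lists so pandas builds a single-row table, written out tab-separated. Reproduced standalone with illustrative values:

    import pandas as pd

    scores = {"accuracy": 0.91, "f1_macro": 0.88}  # illustrative values
    df = pd.DataFrame({name: [score] for name, score in scores.items()})
    df = df[sorted(df.columns)]
    df.to_csv("results.tsv", sep="\t", header=True, index=False)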
     if outfile_object:
         main_est = estimator
         if isinstance(estimator, pipeline.Pipeline):
             main_est = estimator.steps[-1][-1]

-        if hasattr(main_est, 'model_') \
-                and hasattr(main_est, 'save_weights'):
+        if hasattr(main_est, "model_") and hasattr(main_est, "save_weights"):
             if outfile_weights:
                 main_est.save_weights(outfile_weights)
-            del main_est.model_
-            del main_est.fit_params
-            del main_est.model_class_
-            del main_est.validation_data
-            if getattr(main_est, 'data_generator_', None):
+            if getattr(main_est, "model_", None):
+                del main_est.model_
+            if getattr(main_est, "fit_params", None):
+                del main_est.fit_params
+            if getattr(main_est, "model_class_", None):
+                del main_est.model_class_
+            if getattr(main_est, "validation_data", None):
+                del main_est.validation_data
+            if getattr(main_est, "data_generator_", None):
                 del main_est.data_generator_

-        with open(outfile_object, 'wb') as output_handler:
-            pickle.dump(estimator, output_handler,
-                        pickle.HIGHEST_PROTOCOL)
+        dump_model_to_h5(estimator, outfile_object)

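The chain of `getattr` guards above drops transient Keras state (compiled model, fit params, generators) before serialization; an equivalent, more compact form would be a loop (a sketch, not the tool's code):

    for attr in ("model_", "fit_params", "model_class_",
                 "validation_data", "data_generator_"):
        if getattr(main_est, attr, None):
            delattr(main_est, attr)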
@@ __main__ entry point @@
-if __name__ == '__main__':
+if __name__ == "__main__":
     aparser = argparse.ArgumentParser()
     aparser.add_argument("-i", "--inputs", dest="inputs", required=True)
     aparser.add_argument("-e", "--estimator", dest="infile_estimator")
     aparser.add_argument("-X", "--infile1", dest="infile1")
     aparser.add_argument("-y", "--infile2", dest="infile2")
@@ argparse arguments, continued @@
     aparser.add_argument("-b", "--intervals", dest="intervals")
     aparser.add_argument("-t", "--targets", dest="targets")
     aparser.add_argument("-f", "--fasta_path", dest="fasta_path")
     args = aparser.parse_args()

-    main(args.inputs, args.infile_estimator, args.infile1, args.infile2,
-         args.outfile_result, outfile_object=args.outfile_object,
-         outfile_weights=args.outfile_weights, groups=args.groups,
-         ref_seq=args.ref_seq, intervals=args.intervals,
-         targets=args.targets, fasta_path=args.fasta_path)
+    main(
+        args.inputs,
+        args.infile_estimator,
+        args.infile1,
+        args.infile2,
+        args.outfile_result,
+        outfile_object=args.outfile_object,
+        outfile_weights=args.outfile_weights,
+        groups=args.groups,
+        ref_seq=args.ref_seq,
+        intervals=args.intervals,
+        targets=args.targets,
+        fasta_path=args.fasta_path,
+    )
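For a tabular run, the argparse wiring above is equivalent to calling `main` directly; the file names below are placeholders:

    main(
        "inputs.json",      # Galaxy tool parameters (JSON)
        "estimator.h5",     # model saved via dump_model_to_h5
        "train.tsv",        # X (infile1)
        "labels.tsv",       # y (infile2)
        "results.tsv",      # scores table (outfile_result)
        outfile_object="fitted_model.h5",
    )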