Mercurial > repos > bgruening > sklearn_generalized_linear
comparison keras_train_and_eval.py @ 41:5fd0565c5323 draft default tip
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit 80417bf0158a9b596e485dd66408f738f405145a
| author | bgruening |
|---|---|
| date | Mon, 02 Oct 2023 08:49:18 +0000 |
| parents | a8771df897b2 |
| children | |
comparison — legend: equal, deleted, inserted, replaced
| 40:a8771df897b2 | 41:5fd0565c5323 |
|---|---|
| 186 inputs, | 186 inputs, |
| 187 infile_estimator, | 187 infile_estimator, |
| 188 infile1, | 188 infile1, |
| 189 infile2, | 189 infile2, |
| 190 outfile_result, | 190 outfile_result, |
| | 191 outfile_history=None, |
| 191 outfile_object=None, | 192 outfile_object=None, |
| 192 outfile_y_true=None, | 193 outfile_y_true=None, |
| 193 outfile_y_preds=None, | 194 outfile_y_preds=None, |
| 194 groups=None, | 195 groups=None, |
| 195 ref_seq=None, | 196 ref_seq=None, |
| 213 File path to dataset containing target values. | 214 File path to dataset containing target values. |
| 214 | 215 |
| 215 outfile_result : str | 216 outfile_result : str |
| 216 File path to save the results, either cv_results or test result. | 217 File path to save the results, either cv_results or test result. |
| 217 | 218 |
| | 219 outfile_history : str, optional |
| | 220 File path to save the training history. |
| | 221 |
| 218 outfile_object : str, optional | 222 outfile_object : str, optional |
| 219 File path to save searchCV object. | 223 File path to save searchCV object. |
| 220 | 224 |
| 221 outfile_y_true : str, optional | 225 outfile_y_true : str, optional |
| 222 File path to target values for prediction. | 226 File path to target values for prediction. |
| 251 | 255 |
| 252 # swap hyperparameter | 256 # swap hyperparameter |
| 253 swapping = params["experiment_schemes"]["hyperparams_swapping"] | 257 swapping = params["experiment_schemes"]["hyperparams_swapping"] |
| 254 swap_params = _eval_swap_params(swapping) | 258 swap_params = _eval_swap_params(swapping) |
| 255 estimator.set_params(**swap_params) | 259 estimator.set_params(**swap_params) |
| 256 | |
| 257 estimator_params = estimator.get_params() | 260 estimator_params = estimator.get_params() |
| 258 | |
| 259 # store read dataframe object | 261 # store read dataframe object |
| 260 loaded_df = {} | 262 loaded_df = {} |
| 261 | 263 |
| 262 input_type = params["input_options"]["selected_input"] | 264 input_type = params["input_options"]["selected_input"] |
| 263 # tabular input | 265 # tabular input |
| 446 ) = train_test_split_none(X_train, y_train, groups_train, **val_split_options) | 448 ) = train_test_split_none(X_train, y_train, groups_train, **val_split_options) |
| 447 | 449 |
| 448 # train and eval | 450 # train and eval |
| 449 if hasattr(estimator, "config") and hasattr(estimator, "model_type"): | 451 if hasattr(estimator, "config") and hasattr(estimator, "model_type"): |
| 450 if exp_scheme == "train_val_test": | 452 if exp_scheme == "train_val_test": |
| 451 estimator.fit(X_train, y_train, validation_data=(X_val, y_val)) | 453 history = estimator.fit(X_train, y_train, validation_data=(X_val, y_val)) |
| 452 else: | 454 else: |
| 453 estimator.fit(X_train, y_train, validation_data=(X_test, y_test)) | 455 history = estimator.fit(X_train, y_train, validation_data=(X_test, y_test)) |
| 454 else: | 456 else: |
| 455 estimator.fit(X_train, y_train) | 457 history = estimator.fit(X_train, y_train) |
| 456 | 458 if "callbacks" in estimator_params: |
| | 459 for cb in estimator_params["callbacks"]: |
| | 460 if cb["callback_selection"]["callback_type"] == "CSVLogger": |
| | 461 hist_df = pd.DataFrame(history.history) |
| | 462 hist_df["epoch"] = np.arange(1, estimator_params["epochs"] + 1) |
| | 463 epo_col = hist_df.pop('epoch') |
| | 464 hist_df.insert(0, 'epoch', epo_col) |
| | 465 hist_df.to_csv(path_or_buf=outfile_history, sep="\t", header=True, index=False) |
| | 466 break |
| 457 if isinstance(estimator, KerasGBatchClassifier): | 467 if isinstance(estimator, KerasGBatchClassifier): |
| 458 scores = {} | 468 scores = {} |
| 459 steps = estimator.prediction_steps | 469 steps = estimator.prediction_steps |
| 460 batch_size = estimator.batch_size | 470 batch_size = estimator.batch_size |
| 461 data_generator = estimator.data_generator_ | 471 data_generator = estimator.data_generator_ |
| 524 aparser.add_argument("-i", "--inputs", dest="inputs", required=True) | 534 aparser.add_argument("-i", "--inputs", dest="inputs", required=True) |
| 525 aparser.add_argument("-e", "--estimator", dest="infile_estimator") | 535 aparser.add_argument("-e", "--estimator", dest="infile_estimator") |
| 526 aparser.add_argument("-X", "--infile1", dest="infile1") | 536 aparser.add_argument("-X", "--infile1", dest="infile1") |
| 527 aparser.add_argument("-y", "--infile2", dest="infile2") | 537 aparser.add_argument("-y", "--infile2", dest="infile2") |
| 528 aparser.add_argument("-O", "--outfile_result", dest="outfile_result") | 538 aparser.add_argument("-O", "--outfile_result", dest="outfile_result") |
| | 539 aparser.add_argument("-hi", "--outfile_history", dest="outfile_history") |
| 529 aparser.add_argument("-o", "--outfile_object", dest="outfile_object") | 540 aparser.add_argument("-o", "--outfile_object", dest="outfile_object") |
| 530 aparser.add_argument("-l", "--outfile_y_true", dest="outfile_y_true") | 541 aparser.add_argument("-l", "--outfile_y_true", dest="outfile_y_true") |
| 531 aparser.add_argument("-p", "--outfile_y_preds", dest="outfile_y_preds") | 542 aparser.add_argument("-p", "--outfile_y_preds", dest="outfile_y_preds") |
| 532 aparser.add_argument("-g", "--groups", dest="groups") | 543 aparser.add_argument("-g", "--groups", dest="groups") |
| 533 aparser.add_argument("-r", "--ref_seq", dest="ref_seq") | 544 aparser.add_argument("-r", "--ref_seq", dest="ref_seq") |
| 540 args.inputs, | 551 args.inputs, |
| 541 args.infile_estimator, | 552 args.infile_estimator, |
| 542 args.infile1, | 553 args.infile1, |
| 543 args.infile2, | 554 args.infile2, |
| 544 args.outfile_result, | 555 args.outfile_result, |
| | 556 outfile_history=args.outfile_history, |
| 545 outfile_object=args.outfile_object, | 557 outfile_object=args.outfile_object, |
| 546 outfile_y_true=args.outfile_y_true, | 558 outfile_y_true=args.outfile_y_true, |
| 547 outfile_y_preds=args.outfile_y_preds, | 559 outfile_y_preds=args.outfile_y_preds, |
| 548 groups=args.groups, | 560 groups=args.groups, |
| 549 ref_seq=args.ref_seq, | 561 ref_seq=args.ref_seq, |
