bgruening / sklearn_generalized_linear: comparison of search_model_validation.py @ 40:a8771df897b2 (draft)
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit 9981e25b00de29ed881b2229a173a8c812ded9bb
| author | bgruening |
|---|---|
| date | Wed, 09 Aug 2023 11:13:19 +0000 |
| parents | 34f295eb5782 |
| children | |
| 39:1a72afcb0752 | 40:a8771df897b2 |
|---|---|
| 1 import argparse | 1 import argparse |
| 2 import collections | |
| 3 import json | 2 import json |
| 4 import os | 3 import os |
| 5 import pickle | |
| 6 import sys | 4 import sys |
| 7 import warnings | 5 import warnings |
| 6 from distutils.version import LooseVersion as Version | |
| 8 | 7 |
| 9 import imblearn | 8 import imblearn |
| 10 import joblib | 9 import joblib |
| 11 import numpy as np | 10 import numpy as np |
| 12 import pandas as pd | 11 import pandas as pd |
| 13 import skrebate | 12 import skrebate |
| 14 from galaxy_ml.utils import (clean_params, get_cv, | 13 from galaxy_ml import __version__ as galaxy_ml_version |
| 15 get_main_estimator, get_module, get_scoring, | 14 from galaxy_ml.binarize_target import IRAPSClassifier |
| 16 load_model, read_columns, SafeEval, try_get_attr) | 15 from galaxy_ml.model_persist import dump_model_to_h5, load_model_from_h5 |
| 16 from galaxy_ml.utils import ( | |
| 17 clean_params, | |
| 18 get_cv, | |
| 19 get_main_estimator, | |
| 20 get_module, | |
| 21 get_scoring, | |
| 22 read_columns, | |
| 23 SafeEval, | |
| 24 try_get_attr | |
| 25 ) | |
| 17 from scipy.io import mmread | 26 from scipy.io import mmread |
| 18 from sklearn import (cluster, decomposition, feature_selection, | 27 from sklearn import ( |
| 19 kernel_approximation, model_selection, preprocessing) | 28 cluster, |
| 29 decomposition, | |
| 30 feature_selection, | |
| 31 kernel_approximation, | |
| 32 model_selection, | |
| 33 preprocessing, | |
| 34 ) | |
| 20 from sklearn.exceptions import FitFailedWarning | 35 from sklearn.exceptions import FitFailedWarning |
| 21 from sklearn.model_selection import _search, _validation | 36 from sklearn.model_selection import _search, _validation |
| 22 from sklearn.model_selection._validation import _score, cross_validate | 37 from sklearn.model_selection._validation import _score, cross_validate |
| 23 | 38 from sklearn.preprocessing import LabelEncoder |
| 24 _fit_and_score = try_get_attr("galaxy_ml.model_validations", "_fit_and_score") | 39 from skopt import BayesSearchCV |
| 25 setattr(_search, "_fit_and_score", _fit_and_score) | |
| 26 setattr(_validation, "_fit_and_score", _fit_and_score) | |
| 27 | 40 |
| 28 N_JOBS = int(os.environ.get("GALAXY_SLOTS", 1)) | 41 N_JOBS = int(os.environ.get("GALAXY_SLOTS", 1)) |
| 29 # handle disk cache | 42 # handle disk cache |
| 30 CACHE_DIR = os.path.join(os.getcwd(), "cached") | 43 CACHE_DIR = os.path.join(os.getcwd(), "cached") |
| 31 del os | 44 NON_SEARCHABLE = ( |
| 32 NON_SEARCHABLE = ("n_jobs", "pre_dispatch", "memory", "_path", "nthread", "callbacks") | 45 "n_jobs", |
| 46 "pre_dispatch", | |
| 47 "memory", | |
| 48 "_path", | |
| 49 "_dir", | |
| 50 "nthread", | |
| 51 "callbacks", | |
| 52 ) | |
| 33 | 53 |
| 34 | 54 |
| 35 def _eval_search_params(params_builder): | 55 def _eval_search_params(params_builder): |
| 36 search_params = {} | 56 search_params = {} |
| 37 | 57 |
| 98 skrebate.MultiSURFstar(n_jobs=N_JOBS), | 118 skrebate.MultiSURFstar(n_jobs=N_JOBS), |
| 99 imblearn.under_sampling.ClusterCentroids(random_state=0, n_jobs=N_JOBS), | 119 imblearn.under_sampling.ClusterCentroids(random_state=0, n_jobs=N_JOBS), |
| 100 imblearn.under_sampling.CondensedNearestNeighbour( | 120 imblearn.under_sampling.CondensedNearestNeighbour( |
| 101 random_state=0, n_jobs=N_JOBS | 121 random_state=0, n_jobs=N_JOBS |
| 102 ), | 122 ), |
| 103 imblearn.under_sampling.EditedNearestNeighbours( | 123 imblearn.under_sampling.EditedNearestNeighbours(n_jobs=N_JOBS), |
| 104 random_state=0, n_jobs=N_JOBS | 124 imblearn.under_sampling.RepeatedEditedNearestNeighbours(n_jobs=N_JOBS), |
| 105 ), | 125 imblearn.under_sampling.AllKNN(n_jobs=N_JOBS), |
| 106 imblearn.under_sampling.RepeatedEditedNearestNeighbours( | |
| 107 random_state=0, n_jobs=N_JOBS | |
| 108 ), | |
| 109 imblearn.under_sampling.AllKNN(random_state=0, n_jobs=N_JOBS), | |
| 110 imblearn.under_sampling.InstanceHardnessThreshold( | 126 imblearn.under_sampling.InstanceHardnessThreshold( |
| 111 random_state=0, n_jobs=N_JOBS | 127 random_state=0, n_jobs=N_JOBS |
| 112 ), | 128 ), |
| 113 imblearn.under_sampling.NearMiss(random_state=0, n_jobs=N_JOBS), | 129 imblearn.under_sampling.NearMiss(n_jobs=N_JOBS), |
| 114 imblearn.under_sampling.NeighbourhoodCleaningRule( | 130 imblearn.under_sampling.NeighbourhoodCleaningRule(n_jobs=N_JOBS), |
| 115 random_state=0, n_jobs=N_JOBS | |
| 116 ), | |
| 117 imblearn.under_sampling.OneSidedSelection( | 131 imblearn.under_sampling.OneSidedSelection( |
| 118 random_state=0, n_jobs=N_JOBS | 132 random_state=0, n_jobs=N_JOBS |
| 119 ), | 133 ), |
| 120 imblearn.under_sampling.RandomUnderSampler(random_state=0), | 134 imblearn.under_sampling.RandomUnderSampler(random_state=0), |
| 121 imblearn.under_sampling.TomekLinks(random_state=0, n_jobs=N_JOBS), | 135 imblearn.under_sampling.TomekLinks(n_jobs=N_JOBS), |
| 122 imblearn.over_sampling.ADASYN(random_state=0, n_jobs=N_JOBS), | 136 imblearn.over_sampling.ADASYN(random_state=0, n_jobs=N_JOBS), |
| 137 imblearn.over_sampling.BorderlineSMOTE(random_state=0, n_jobs=N_JOBS), | |
| 138 imblearn.over_sampling.KMeansSMOTE(random_state=0, n_jobs=N_JOBS), | |
| 123 imblearn.over_sampling.RandomOverSampler(random_state=0), | 139 imblearn.over_sampling.RandomOverSampler(random_state=0), |
| 124 imblearn.over_sampling.SMOTE(random_state=0, n_jobs=N_JOBS), | 140 imblearn.over_sampling.SMOTE(random_state=0, n_jobs=N_JOBS), |
| 125 imblearn.over_sampling.SVMSMOTE(random_state=0, n_jobs=N_JOBS), | 141 imblearn.over_sampling.SMOTEN(random_state=0, n_jobs=N_JOBS), |
| 126 imblearn.over_sampling.BorderlineSMOTE(random_state=0, n_jobs=N_JOBS), | |
| 127 imblearn.over_sampling.SMOTENC( | 142 imblearn.over_sampling.SMOTENC( |
| 128 categorical_features=[], random_state=0, n_jobs=N_JOBS | 143 categorical_features=[], random_state=0, n_jobs=N_JOBS |
| 129 ), | 144 ), |
| 145 imblearn.over_sampling.SVMSMOTE(random_state=0, n_jobs=N_JOBS), | |
| 130 imblearn.combine.SMOTEENN(random_state=0), | 146 imblearn.combine.SMOTEENN(random_state=0), |
| 131 imblearn.combine.SMOTETomek(random_state=0), | 147 imblearn.combine.SMOTETomek(random_state=0), |
| 132 ) | 148 ) |
| 133 newlist = [] | 149 newlist = [] |
| 134 for obj in ev: | 150 for obj in ev: |
| 286 else: | 302 else: |
| 287 infile2 = pd.read_csv(infile2, sep="\t", header=header, parse_dates=True) | 303 infile2 = pd.read_csv(infile2, sep="\t", header=header, parse_dates=True) |
| 288 loaded_df[df_key] = infile2 | 304 loaded_df[df_key] = infile2 |
| 289 | 305 |
| 290 y = read_columns( | 306 y = read_columns( |
| 291 infile2, c=c, c_option=column_option, sep="\t", header=header, parse_dates=True | 307 infile2, |
| 308 c=c, | |
| 309 c_option=column_option, | |
| 310 sep="\t", | |
| 311 header=header, | |
| 312 parse_dates=True, | |
| 292 ) | 313 ) |
| 293 if len(y.shape) == 2 and y.shape[1] == 1: | 314 if len(y.shape) == 2 and y.shape[1] == 1: |
| 294 y = y.ravel() | 315 y = y.ravel() |
| 295 if input_type == "refseq_and_interval": | 316 if input_type == "refseq_and_interval": |
| 296 estimator.set_params(data_batch_generator__features=y.ravel().tolist()) | 317 estimator.set_params(data_batch_generator__features=y.ravel().tolist()) |
| 414 pass | 435 pass |
| 415 for warning in w: | 436 for warning in w: |
| 416 print(repr(warning.message)) | 437 print(repr(warning.message)) |
| 417 | 438 |
| 418 scorer_ = searcher.scorer_ | 439 scorer_ = searcher.scorer_ |
| 419 if isinstance(scorer_, collections.Mapping): | |
| 420 is_multimetric = True | |
| 421 else: | |
| 422 is_multimetric = False | |
| 423 | 440 |
| 424 best_estimator_ = getattr(searcher, "best_estimator_") | 441 best_estimator_ = getattr(searcher, "best_estimator_") |
| 425 | 442 |
| 426 # TODO Solve deep learning models in pipeline | 443 # TODO Solve deep learning models in pipeline |
| 427 if best_estimator_.__class__.__name__ == "KerasGBatchClassifier": | 444 if best_estimator_.__class__.__name__ == "KerasGBatchClassifier": |
| 428 test_score = best_estimator_.evaluate( | 445 test_score = best_estimator_.evaluate( |
| 429 X_test, scorer=scorer_, is_multimetric=is_multimetric | 446 X_test, |
| 447 scorer=scorer_, | |
| 430 ) | 448 ) |
| 431 else: | 449 else: |
| 432 test_score = _score( | 450 test_score = _score(best_estimator_, X_test, y_test, scorer_) |
| 433 best_estimator_, X_test, y_test, scorer_, is_multimetric=is_multimetric | 451 |
| 434 ) | 452 if not isinstance(scorer_, dict): |
| 435 | |
| 436 if not is_multimetric: | |
| 437 test_score = {primary_scoring: test_score} | 453 test_score = {primary_scoring: test_score} |
| 438 for key, value in test_score.items(): | 454 for key, value in test_score.items(): |
| 439 test_score[key] = [value] | 455 test_score[key] = [value] |
| 440 result_df = pd.DataFrame(test_score) | 456 result_df = pd.DataFrame(test_score) |
| 441 result_df.to_csv(path_or_buf=outfile, sep="\t", header=True, index=False) | 457 result_df.to_csv(path_or_buf=outfile, sep="\t", header=True, index=False) |
| 442 | 458 |
| 443 return searcher | 459 return searcher |
| 460 | |
| 461 | |
| 462 def _set_memory(estimator, memory): | |
| 463 """set memory cache | |
| 464 | |
| 465 Parameters | |
| 466 ---------- | |
| 467 estimator : python object | |
| 468 memory : joblib.Memory object | |
| 469 | |
| 470 Returns | |
| 471 ------- | |
| 472 estimator : estimator object after setting new attributes | |
| 473 """ | |
| 474 if isinstance(estimator, IRAPSClassifier): | |
| 475 estimator.set_params(memory=memory) | |
| 476 return estimator | |
| 477 | |
| 478 estimator_params = estimator.get_params() | |
| 479 | |
| 480 new_params = {} | |
| 481 for k in estimator_params.keys(): | |
| 482 if k.endswith("irapsclassifier__memory"): | |
| 483 new_params[k] = memory | |
| 484 | |
| 485 estimator.set_params(**new_params) | |
| 486 | |
| 487 return estimator | |
| 444 | 488 |
| 445 | 489 |
| 446 def main( | 490 def main( |
| 447 inputs, | 491 inputs, |
| 448 infile_estimator, | 492 infile_estimator, |
| 449 infile1, | 493 infile1, |
| 450 infile2, | 494 infile2, |
| 451 outfile_result, | 495 outfile_result, |
| 452 outfile_object=None, | 496 outfile_object=None, |
| 453 outfile_weights=None, | |
| 454 groups=None, | 497 groups=None, |
| 455 ref_seq=None, | 498 ref_seq=None, |
| 456 intervals=None, | 499 intervals=None, |
| 457 targets=None, | 500 targets=None, |
| 458 fasta_path=None, | 501 fasta_path=None, |
| 459 ): | 502 ): |
| 460 """ | 503 """ |
| 461 Parameters | 504 Parameters |
| 462 ---------- | 505 ---------- |
| 463 inputs : str | 506 inputs : str |
| 464 File path to galaxy tool parameter | 507 File path to galaxy tool parameter. |
| 465 | 508 |
| 466 infile_estimator : str | 509 infile_estimator : str |
| 467 File path to estimator | 510 File path to estimator. |
| 468 | 511 |
| 469 infile1 : str | 512 infile1 : str |
| 470 File path to dataset containing features | 513 File path to dataset containing features |
| 471 | 514 |
| 472 infile2 : str | 515 infile2 : str |
| 475 outfile_result : str | 518 outfile_result : str |
| 476 File path to save the results, either cv_results or test result | 519 File path to save the results, either cv_results or test result |
| 477 | 520 |
| 478 outfile_object : str, optional | 521 outfile_object : str, optional |
| 479 File path to save searchCV object | 522 File path to save searchCV object |
| 480 | |
| 481 outfile_weights : str, optional | |
| 482 File path to save model weights | |
| 483 | 523 |
| 484 groups : str | 524 groups : str |
| 485 File path to dataset containing groups labels | 525 File path to dataset containing groups labels |
| 486 | 526 |
| 487 ref_seq : str | 527 ref_seq : str |
| 503 | 543 |
| 504 with open(inputs, "r") as param_handler: | 544 with open(inputs, "r") as param_handler: |
| 505 params = json.load(param_handler) | 545 params = json.load(param_handler) |
| 506 | 546 |
| 507 # Override the refit parameter | 547 # Override the refit parameter |
| 508 params["search_schemes"]["options"]["refit"] = ( | 548 params["options"]["refit"] = ( |
| 509 True if params["save"] != "nope" else False | 549 True |
| 550 if ( | |
| 551 params["save"] != "nope" | |
| 552 or params["outer_split"]["split_mode"] == "nested_cv" | |
| 553 ) | |
| 554 else False | |
| 510 ) | 555 ) |
| 511 | 556 |
| 512 with open(infile_estimator, "rb") as estimator_handler: | 557 estimator = load_model_from_h5(infile_estimator) |
| 513 estimator = load_model(estimator_handler) | 558 |
| 514 | 559 estimator = clean_params(estimator) |
| 515 optimizer = params["search_schemes"]["selected_search_scheme"] | 560 |
| 516 optimizer = getattr(model_selection, optimizer) | 561 if estimator.__class__.__name__ == "KerasGBatchClassifier": |
| 562 _fit_and_score = try_get_attr( | |
| 563 "galaxy_ml.model_validations", | |
| 564 "_fit_and_score", | |
| 565 ) | |
| 566 | |
| 567 setattr(_search, "_fit_and_score", _fit_and_score) | |
| 568 setattr(_validation, "_fit_and_score", _fit_and_score) | |
| 569 | |
| 570 search_algos_and_options = params["search_algos"] | |
| 571 optimizer = search_algos_and_options.pop("selected_search_algo") | |
| 572 if optimizer == "skopt.BayesSearchCV": | |
| 573 optimizer = BayesSearchCV | |
| 574 else: | |
| 575 optimizer = getattr(model_selection, optimizer) | |
| 517 | 576 |
| 518 # handle gridsearchcv options | 577 # handle gridsearchcv options |
| 519 options = params["search_schemes"]["options"] | 578 options = params["options"] |
| 579 options.update(search_algos_and_options) | |
| 520 | 580 |
| 521 if groups: | 581 if groups: |
| 522 header = ( | 582 header = ( |
| 523 "infer" if (options["cv_selector"]["groups_selector"]["header_g"]) else None | 583 "infer" if (options["cv_selector"]["groups_selector"]["header_g"]) else None |
| 524 ) | 584 ) |
| 551 parse_dates=True, | 611 parse_dates=True, |
| 552 ) | 612 ) |
| 553 groups = groups.ravel() | 613 groups = groups.ravel() |
| 554 options["cv_selector"]["groups_selector"] = groups | 614 options["cv_selector"]["groups_selector"] = groups |
| 555 | 615 |
| 556 splitter, groups = get_cv(options.pop("cv_selector")) | 616 cv_selector = options.pop("cv_selector") |
| 617 if Version(galaxy_ml_version) < Version("0.8.3"): | |
| 618 cv_selector.pop("n_stratification_bins", None) | |
| 619 splitter, groups = get_cv(cv_selector) | |
| 557 options["cv"] = splitter | 620 options["cv"] = splitter |
| 558 primary_scoring = options["scoring"]["primary_scoring"] | 621 primary_scoring = options["scoring"]["primary_scoring"] |
| 559 # get_scoring() expects secondary_scoring to be a comma separated string (not a list) | 622 options["scoring"] = get_scoring(options["scoring"]) |
| 560 # Check if secondary_scoring is specified | 623 # TODO make BayesSearchCV support multiple scoring |
| 561 secondary_scoring = options["scoring"].get("secondary_scoring", None) | 624 if optimizer == "skopt.BayesSearchCV" and isinstance(options["scoring"], dict): |
| 562 if secondary_scoring is not None: | 625 options["scoring"] = options["scoring"][primary_scoring] |
| 563 # If secondary_scoring is specified, convert the list into comma separated string | 626 warnings.warn( |
| 564 options["scoring"]["secondary_scoring"] = ",".join( | 627 "BayesSearchCV doesn't support multiple " |
| 565 options["scoring"]["secondary_scoring"] | 628 "scorings! Primary scoring is used." |
| 566 ) | 629 ) |
| 567 options["scoring"] = get_scoring(options["scoring"]) | |
| 568 if options["error_score"]: | 630 if options["error_score"]: |
| 569 options["error_score"] = "raise" | 631 options["error_score"] = "raise" |
| 570 else: | 632 else: |
| 571 options["error_score"] = np.nan | 633 options["error_score"] = np.NaN |
| 572 if options["refit"] and isinstance(options["scoring"], dict): | 634 if options["refit"] and isinstance(options["scoring"], dict): |
| 573 options["refit"] = primary_scoring | 635 options["refit"] = primary_scoring |
| 574 if "pre_dispatch" in options and options["pre_dispatch"] == "": | 636 if "pre_dispatch" in options and options["pre_dispatch"] == "": |
| 575 options["pre_dispatch"] = None | 637 options["pre_dispatch"] = None |
| 576 | 638 |
| 577 params_builder = params["search_schemes"]["search_params_builder"] | 639 params_builder = params["search_params_builder"] |
| 578 param_grid = _eval_search_params(params_builder) | 640 param_grid = _eval_search_params(params_builder) |
| 579 | |
| 580 estimator = clean_params(estimator) | |
| 581 | 641 |
| 582 # save the SearchCV object without fit | 642 # save the SearchCV object without fit |
| 583 if params["save"] == "save_no_fit": | 643 if params["save"] == "save_no_fit": |
| 584 searcher = optimizer(estimator, param_grid, **options) | 644 searcher = optimizer(estimator, param_grid, **options) |
| 585 print(searcher) | 645 dump_model_to_h5(searcher, outfile_object) |
| 586 with open(outfile_object, "wb") as output_handler: | |
| 587 pickle.dump(searcher, output_handler, pickle.HIGHEST_PROTOCOL) | |
| 588 return 0 | 646 return 0 |
| 589 | 647 |
| 590 # read inputs and load new attributes, like paths | 648 |
| 591 estimator, X, y = _handle_X_y( | 649 estimator, X, y = _handle_X_y( |
| 592 estimator, | 650 estimator, |
| 598 intervals=intervals, | 656 intervals=intervals, |
| 599 targets=targets, | 657 targets=targets, |
| 600 fasta_path=fasta_path, | 658 fasta_path=fasta_path, |
| 601 ) | 659 ) |
| 602 | 660 |
| 661 label_encoder = LabelEncoder() | |
| 662 if get_main_estimator(estimator).__class__.__name__ == "XGBClassifier": | |
| 663 y = label_encoder.fit_transform(y) | |
| 664 | |
| 603 # caching iraps_core fits could increase search speed significantly | 665 # caching iraps_core fits could increase search speed significantly |
| 604 memory = joblib.Memory(location=CACHE_DIR, verbose=0) | 666 memory = joblib.Memory(location=CACHE_DIR, verbose=0) |
| 605 main_est = get_main_estimator(estimator) | 667 estimator = _set_memory(estimator, memory) |
| 606 if main_est.__class__.__name__ == "IRAPSClassifier": | |
| 607 main_est.set_params(memory=memory) | |
| 608 | 668 |
| 609 searcher = optimizer(estimator, param_grid, **options) | 669 searcher = optimizer(estimator, param_grid, **options) |
| 610 | 670 |
| 611 split_mode = params["outer_split"].pop("split_mode") | 671 split_mode = params["outer_split"].pop("split_mode") |
| 612 | 672 |
| 673 # Nested CV | |
| 613 if split_mode == "nested_cv": | 674 if split_mode == "nested_cv": |
| 614 # make sure refit is chosen | 675 cv_selector = params["outer_split"]["cv_selector"] |
| 615 # this could be True for sklearn models, but not the case for | 676 if Version(galaxy_ml_version) < Version("0.8.3"): |
| 616 # deep learning models | 677 cv_selector.pop("n_stratification_bins", None) |
| 617 if not options["refit"] and not all( | 678 outer_cv, _ = get_cv(cv_selector) |
| 618 hasattr(estimator, attr) for attr in ("config", "model_type") | |
| 619 ): | |
| 620 warnings.warn("Refit is changed to `True` for nested validation!") | |
| 621 setattr(searcher, "refit", True) | |
| 622 | |
| 623 outer_cv, _ = get_cv(params["outer_split"]["cv_selector"]) | |
| 624 # nested CV, outer cv using cross_validate | 679 # nested CV, outer cv using cross_validate |
| 625 if options["error_score"] == "raise": | 680 if options["error_score"] == "raise": |
| 626 rval = cross_validate( | 681 rval = cross_validate( |
| 627 searcher, | 682 searcher, |
| 628 X, | 683 X, |
| 629 y, | 684 y, |
| 685 groups=groups, | |
| 630 scoring=options["scoring"], | 686 scoring=options["scoring"], |
| 631 cv=outer_cv, | 687 cv=outer_cv, |
| 632 n_jobs=N_JOBS, | 688 n_jobs=N_JOBS, |
| 633 verbose=options["verbose"], | 689 verbose=options["verbose"], |
| 690 fit_params={"groups": groups}, | |
| 634 return_estimator=(params["save"] == "save_estimator"), | 691 return_estimator=(params["save"] == "save_estimator"), |
| 635 error_score=options["error_score"], | 692 error_score=options["error_score"], |
| 636 return_train_score=True, | 693 return_train_score=True, |
| 637 ) | 694 ) |
| 638 else: | 695 else: |
| 641 try: | 698 try: |
| 642 rval = cross_validate( | 699 rval = cross_validate( |
| 643 searcher, | 700 searcher, |
| 644 X, | 701 X, |
| 645 y, | 702 y, |
| 703 groups=groups, | |
| 646 scoring=options["scoring"], | 704 scoring=options["scoring"], |
| 647 cv=outer_cv, | 705 cv=outer_cv, |
| 648 n_jobs=N_JOBS, | 706 n_jobs=N_JOBS, |
| 649 verbose=options["verbose"], | 707 verbose=options["verbose"], |
| 708 fit_params={"groups": groups}, | |
| 650 return_estimator=(params["save"] == "save_estimator"), | 709 return_estimator=(params["save"] == "save_estimator"), |
| 651 error_score=options["error_score"], | 710 error_score=options["error_score"], |
| 652 return_train_score=True, | 711 return_train_score=True, |
| 653 ) | 712 ) |
| 654 except ValueError: | 713 except ValueError: |
| 674 cv_results_ = pd.DataFrame(cv_results_) | 733 cv_results_ = pd.DataFrame(cv_results_) |
| 675 cv_results_ = cv_results_[sorted(cv_results_.columns)] | 734 cv_results_ = cv_results_[sorted(cv_results_.columns)] |
| 676 cv_results_.to_csv(target_path, sep="\t", header=True, index=False) | 735 cv_results_.to_csv(target_path, sep="\t", header=True, index=False) |
| 677 except Exception as e: | 736 except Exception as e: |
| 678 print(e) | 737 print(e) |
| 679 finally: | |
| 680 del os | |
| 681 | 738 |
| 682 keys = list(rval.keys()) | 739 keys = list(rval.keys()) |
| 683 for k in keys: | 740 for k in keys: |
| 684 if k.startswith("test"): | 741 if k.startswith("test"): |
| 685 rval["mean_" + k] = np.mean(rval[k]) | 742 rval["mean_" + k] = np.mean(rval[k]) |
| 687 if k.endswith("time"): | 744 if k.endswith("time"): |
| 688 rval.pop(k) | 745 rval.pop(k) |
| 689 rval = pd.DataFrame(rval) | 746 rval = pd.DataFrame(rval) |
| 690 rval = rval[sorted(rval.columns)] | 747 rval = rval[sorted(rval.columns)] |
| 691 rval.to_csv(path_or_buf=outfile_result, sep="\t", header=True, index=False) | 748 rval.to_csv(path_or_buf=outfile_result, sep="\t", header=True, index=False) |
| 749 | |
| 750 return 0 | |
| 751 | |
| 692 # deprecated train test split mode | 752 # deprecated train test split mode |
| 693 """searcher = _do_train_test_split_val( | 753 """searcher = _do_train_test_split_val( |
| 694 searcher, X, y, params, | 754 searcher, X, y, params, |
| 695 primary_scoring=primary_scoring, | 755 primary_scoring=primary_scoring, |
| 696 error_score=options['error_score'], | 756 error_score=options['error_score'], |
| 697 groups=groups, | 757 groups=groups, |
| 698 outfile=outfile_result)""" | 758 outfile=outfile_result)""" |
| 699 return 0 | |
| 700 | 759 |
| 701 # no outer split | 760 # no outer split |
| 702 else: | 761 else: |
| 703 searcher.set_params(n_jobs=N_JOBS) | 762 searcher.set_params(n_jobs=N_JOBS) |
| 704 if options["error_score"] == "raise": | 763 if options["error_score"] == "raise": |
| 730 "'best_estimator_', because either it's " | 789 "'best_estimator_', because either it's " |
| 731 "nested gridsearch or `refit` is False!" | 790 "nested gridsearch or `refit` is False!" |
| 732 ) | 791 ) |
| 733 return | 792 return |
| 734 | 793 |
| 735 # clean params | 794 dump_model_to_h5(best_estimator_, outfile_object) |
| 736 best_estimator_ = clean_params(best_estimator_) | |
| 737 | |
| 738 main_est = get_main_estimator(best_estimator_) | |
| 739 | |
| 740 if hasattr(main_est, "model_") and hasattr(main_est, "save_weights"): | |
| 741 if outfile_weights: | |
| 742 main_est.save_weights(outfile_weights) | |
| 743 del main_est.model_ | |
| 744 del main_est.fit_params | |
| 745 del main_est.model_class_ | |
| 746 del main_est.validation_data | |
| 747 if getattr(main_est, "data_generator_", None): | |
| 748 del main_est.data_generator_ | |
| 749 | |
| 750 with open(outfile_object, "wb") as output_handler: | |
| 751 print("Best estimator is saved: %s " % repr(best_estimator_)) | |
| 752 pickle.dump(best_estimator_, output_handler, pickle.HIGHEST_PROTOCOL) | |
| 753 | 795 |
| 754 | 796 |
| 755 if __name__ == "__main__": | 797 if __name__ == "__main__": |
| 756 aparser = argparse.ArgumentParser() | 798 aparser = argparse.ArgumentParser() |
| 757 aparser.add_argument("-i", "--inputs", dest="inputs", required=True) | 799 aparser.add_argument("-i", "--inputs", dest="inputs", required=True) |
| 758 aparser.add_argument("-e", "--estimator", dest="infile_estimator") | 800 aparser.add_argument("-e", "--estimator", dest="infile_estimator") |
| 759 aparser.add_argument("-X", "--infile1", dest="infile1") | 801 aparser.add_argument("-X", "--infile1", dest="infile1") |
| 760 aparser.add_argument("-y", "--infile2", dest="infile2") | 802 aparser.add_argument("-y", "--infile2", dest="infile2") |
| 761 aparser.add_argument("-O", "--outfile_result", dest="outfile_result") | 803 aparser.add_argument("-O", "--outfile_result", dest="outfile_result") |
| 762 aparser.add_argument("-o", "--outfile_object", dest="outfile_object") | 804 aparser.add_argument("-o", "--outfile_object", dest="outfile_object") |
| 763 aparser.add_argument("-w", "--outfile_weights", dest="outfile_weights") | |
| 764 aparser.add_argument("-g", "--groups", dest="groups") | 805 aparser.add_argument("-g", "--groups", dest="groups") |
| 765 aparser.add_argument("-r", "--ref_seq", dest="ref_seq") | 806 aparser.add_argument("-r", "--ref_seq", dest="ref_seq") |
| 766 aparser.add_argument("-b", "--intervals", dest="intervals") | 807 aparser.add_argument("-b", "--intervals", dest="intervals") |
| 767 aparser.add_argument("-t", "--targets", dest="targets") | 808 aparser.add_argument("-t", "--targets", dest="targets") |
| 768 aparser.add_argument("-f", "--fasta_path", dest="fasta_path") | 809 aparser.add_argument("-f", "--fasta_path", dest="fasta_path") |
| 769 args = aparser.parse_args() | 810 args = aparser.parse_args() |
| 770 | 811 |
| 771 main( | 812 main(**vars(args)) |
| 772 args.inputs, | |
| 773 args.infile_estimator, | |
| 774 args.infile1, | |
| 775 args.infile2, | |
| 776 args.outfile_result, | |
| 777 outfile_object=args.outfile_object, | |
| 778 outfile_weights=args.outfile_weights, | |
| 779 groups=args.groups, | |
| 780 ref_seq=args.ref_seq, | |
| 781 intervals=args.intervals, | |
| 782 targets=args.targets, | |
| 783 fasta_path=args.fasta_path, | |
| 784 ) |
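The most consequential change in this revision is model persistence: the raw `pickle` round trip (and the separate Keras weights file) is replaced by Galaxy-ML's HDF5 helpers. A minimal round-trip sketch, assuming a `galaxy_ml` installation that exposes `dump_model_to_h5` and `load_model_from_h5` with the call signatures used in the diff (object plus file path, and file path, respectively); the estimator and grid here are illustrative, not taken from the tool:

```python
from galaxy_ml.model_persist import dump_model_to_h5, load_model_from_h5
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import GridSearchCV

# An illustrative SearchCV object; the tool builds its own from the
# Galaxy parameter JSON.
searcher = GridSearchCV(
    LogisticRegression(max_iter=200),
    param_grid={"C": [0.1, 1.0, 10.0]},
)

# Mirrors the "save_no_fit" branch: persist the unfitted SearchCV object.
dump_model_to_h5(searcher, "searcher.h5")

# Mirrors the new loading path at the top of main(), which replaces the
# pickle-based load_model(estimator_handler).
restored = load_model_from_h5("searcher.h5")
```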
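The revision also adds `skopt.BayesSearchCV` as a selectable search algorithm: the selected name is popped from the tool parameters and resolved either to skopt's class or to an attribute of `sklearn.model_selection`. A sketch of that dispatch in isolation; `resolve_optimizer` is a hypothetical helper name, since the tool inlines this logic:

```python
from sklearn import model_selection
from skopt import BayesSearchCV


def resolve_optimizer(name):
    """Hypothetical helper mirroring the diff's dispatch: skopt's
    BayesSearchCV is special-cased, everything else is looked up on
    sklearn.model_selection."""
    if name == "skopt.BayesSearchCV":
        return BayesSearchCV
    return getattr(model_selection, name)


grid_cls = resolve_optimizer("GridSearchCV")          # sklearn class
bayes_cls = resolve_optimizer("skopt.BayesSearchCV")  # skopt class
```

Because `BayesSearchCV` accepts only a single scorer, the diff falls back to the primary scoring (with a warning) when a multi-metric scoring dict is configured.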
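Finally, the new `LabelEncoder` branch applies only when the main estimator is an `XGBClassifier`, presumably because recent xgboost releases expect class labels encoded as consecutive integers starting at 0. A self-contained sketch with made-up labels:

```python
from sklearn.preprocessing import LabelEncoder

# Illustrative string labels; in the tool, y is read from infile2.
y = ["tumor", "normal", "tumor", "normal"]

label_encoder = LabelEncoder()
y_encoded = label_encoder.fit_transform(y)   # array([1, 0, 1, 0])

# The fitted encoder can map predictions back to the original labels.
labels = label_encoder.inverse_transform(y_encoded)
```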
