Mercurial > repos > bgruening > sklearn_train_test_split
comparison train_test_eval.py @ 6:81ab4951f2a3 draft
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit ca87db9c038a6fcf96aa39da50f384865fd932ff"
| author | bgruening |
|---|---|
| date | Tue, 20 Apr 2021 17:09:29 +0000 |
| parents | c0ed68e280a7 |
| children | 82f89e379413 |
Comparison legend: equal, deleted, inserted, replaced
| 5:c0ed68e280a7 | 6:81ab4951f2a3 |
|---|---|
| 7 | 7 |
| 8 import joblib | 8 import joblib |
| 9 import numpy as np | 9 import numpy as np |
| 10 import pandas as pd | 10 import pandas as pd |
| 11 from galaxy_ml.model_validations import train_test_split | 11 from galaxy_ml.model_validations import train_test_split |
| 12 from galaxy_ml.utils import ( | 12 from galaxy_ml.utils import (get_module, get_scoring, load_model, |
| 13 get_module, | 13 read_columns, SafeEval, try_get_attr) |
| 14 get_scoring, | |
| 15 load_model, | |
| 16 read_columns, | |
| 17 SafeEval, | |
| 18 try_get_attr, | |
| 19 ) | |
| 20 from scipy.io import mmread | 14 from scipy.io import mmread |
| 21 from sklearn import pipeline | 15 from sklearn import pipeline |
| 22 from sklearn.metrics.scorer import _check_multimetric_scoring | 16 from sklearn.metrics.scorer import _check_multimetric_scoring |
| 23 from sklearn.model_selection import _search, _validation | 17 from sklearn.model_selection import _search, _validation |
| 24 from sklearn.model_selection._validation import _score | 18 from sklearn.model_selection._validation import _score |
| 25 from sklearn.utils import indexable, safe_indexing | 19 from sklearn.utils import indexable, safe_indexing |
| 26 | |
| 27 | 20 |
| 28 _fit_and_score = try_get_attr("galaxy_ml.model_validations", "_fit_and_score") | 21 _fit_and_score = try_get_attr("galaxy_ml.model_validations", "_fit_and_score") |
| 29 setattr(_search, "_fit_and_score", _fit_and_score) | 22 setattr(_search, "_fit_and_score", _fit_and_score) |
| 30 setattr(_validation, "_fit_and_score", _fit_and_score) | 23 setattr(_validation, "_fit_and_score", _fit_and_score) |
| 31 | 24 |
| 260 infile2 = loaded_df[df_key] | 253 infile2 = loaded_df[df_key] |
| 261 else: | 254 else: |
| 262 infile2 = pd.read_csv(infile2, sep="\t", header=header, parse_dates=True) | 255 infile2 = pd.read_csv(infile2, sep="\t", header=header, parse_dates=True) |
| 263 loaded_df[df_key] = infile2 | 256 loaded_df[df_key] = infile2 |
| 264 | 257 |
| 265 y = read_columns(infile2, | 258 y = read_columns( |
| 266 c=c, | 259 infile2, c=c, c_option=column_option, sep="\t", header=header, parse_dates=True |
| 267 c_option=column_option, | 260 ) |
| 268 sep='\t', | |
| 269 header=header, | |
| 270 parse_dates=True) | |
| 271 if len(y.shape) == 2 and y.shape[1] == 1: | 261 if len(y.shape) == 2 and y.shape[1] == 1: |
| 272 y = y.ravel() | 262 y = y.ravel() |
| 273 if input_type == "refseq_and_interval": | 263 if input_type == "refseq_and_interval": |
| 274 estimator.set_params(data_batch_generator__features=y.ravel().tolist()) | 264 estimator.set_params(data_batch_generator__features=y.ravel().tolist()) |
| 275 y = None | 265 y = None |
| 297 | 287 |
| 298 df_key = groups + repr(header) | 288 df_key = groups + repr(header) |
| 299 if df_key in loaded_df: | 289 if df_key in loaded_df: |
| 300 groups = loaded_df[df_key] | 290 groups = loaded_df[df_key] |
| 301 | 291 |
| 302 groups = read_columns(groups, | 292 groups = read_columns( |
| 303 c=c, | 293 groups, |
| 304 c_option=column_option, | 294 c=c, |
| 305 sep='\t', | 295 c_option=column_option, |
| 306 header=header, | 296 sep="\t", |
| 307 parse_dates=True) | 297 header=header, |
| 298 parse_dates=True, | |
| 299 ) | |
| 308 groups = groups.ravel() | 300 groups = groups.ravel() |
| 309 | 301 |
| 310 # del loaded_df | 302 # del loaded_df |
| 311 del loaded_df | 303 del loaded_df |
| 312 | 304 |
| 369 else: | 361 else: |
| 370 raise ValueError( | 362 raise ValueError( |
| 371 "Stratified shuffle split is not " "applicable on empty target values!" | 363 "Stratified shuffle split is not " "applicable on empty target values!" |
| 372 ) | 364 ) |
| 373 | 365 |
| 374 X_train, X_test, y_train, y_test, groups_train, _groups_test = train_test_split_none( | 366 ( |
| 375 X, y, groups, **test_split_options | 367 X_train, |
| 376 ) | 368 X_test, |
| 369 y_train, | |
| 370 y_test, | |
| 371 groups_train, | |
| 372 _groups_test, | |
| 373 ) = train_test_split_none(X, y, groups, **test_split_options) | |
| 377 | 374 |
| 378 exp_scheme = params["experiment_schemes"]["selected_exp_scheme"] | 375 exp_scheme = params["experiment_schemes"]["selected_exp_scheme"] |
| 379 | 376 |
| 380 # handle validation (second) split | 377 # handle validation (second) split |
| 381 if exp_scheme == "train_val_test": | 378 if exp_scheme == "train_val_test": |
