Mercurial > repos > bgruening > sklearn_generalized_linear
comparison train_test_split.py @ 40:a8771df897b2 draft
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit 9981e25b00de29ed881b2229a173a8c812ded9bb
| author | bgruening |
|---|---|
| date | Wed, 09 Aug 2023 11:13:19 +0000 |
| parents | 34f295eb5782 |
| children | |
comparison
equal
deleted
inserted
replaced
| 39:1a72afcb0752 | 40:a8771df897b2 |
|---|---|
| 1 import argparse | 1 import argparse |
| 2 import json | 2 import json |
| 3 import warnings | 3 import warnings |
| | 4 from distutils.version import LooseVersion as Version |
| 4 | 5 |
| 5 import pandas as pd | 6 import pandas as pd |
| | 7 from galaxy_ml import __version__ as galaxy_ml_version |
| 6 from galaxy_ml.model_validations import train_test_split | 8 from galaxy_ml.model_validations import train_test_split |
| 7 from galaxy_ml.utils import get_cv, read_columns | 9 from galaxy_ml.utils import get_cv, read_columns |
| 8 | 10 |
| 9 | 11 |
| 10 def _get_single_cv_split(params, array, infile_labels=None, infile_groups=None): | 12 def _get_single_cv_split(params, array, infile_labels=None, infile_groups=None): |
| 67 col_index = target_input["col"][0] - 1 | 69 col_index = target_input["col"][0] - 1 |
| 68 df = pd.read_csv(infile_labels, sep="\t", header=header, parse_dates=True) | 70 df = pd.read_csv(infile_labels, sep="\t", header=header, parse_dates=True) |
| 69 y = df.iloc[:, col_index].values | 71 y = df.iloc[:, col_index].values |
| 70 | 72 |
| 71 # construct the cv splitter object | 73 # construct the cv splitter object |
| 72 splitter, groups = get_cv(params["mode_selection"]["cv_selector"]) | 74 cv_selector = params["mode_selection"]["cv_selector"] |
| | 75 if Version(galaxy_ml_version) < Version("0.8.3"): |
| | 76 cv_selector.pop("n_stratification_bins", None) |
| | 77 splitter, groups = get_cv(cv_selector) |
| 73 | 78 |
| 74 total_n_splits = splitter.get_n_splits(array.values, y=y, groups=groups) | 79 total_n_splits = splitter.get_n_splits(array.values, y=y, groups=groups) |
| 75 if nth_split > total_n_splits: | 80 if nth_split > total_n_splits: |
| 76 raise ValueError( | 81 raise ValueError( |
| 77 "Total number of splits is {}, but got `nth_split` " | 82 "Total number of splits is {}, but got `nth_split` " |
