Mercurial > repos > bgruening > stacking_ensemble_models
comparison train_test_split.py @ 3:0a1812986bc3 draft
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit 9981e25b00de29ed881b2229a173a8c812ded9bb
| author | bgruening |
|---|---|
| date | Wed, 09 Aug 2023 11:10:37 +0000 |
| parents | 38c4f8a98038 |
| children |
comparison
equal
deleted
inserted
replaced
| 2:38c4f8a98038 | 3:0a1812986bc3 |
|---|---|
| 1 import argparse | 1 import argparse |
| 2 import json | 2 import json |
| 3 import warnings | |
| 4 from distutils.version import LooseVersion as Version | |
| 5 | |
| 3 import pandas as pd | 6 import pandas as pd |
| 4 import warnings | 7 from galaxy_ml import __version__ as galaxy_ml_version |
| 5 | |
| 6 from galaxy_ml.model_validations import train_test_split | 8 from galaxy_ml.model_validations import train_test_split |
| 7 from galaxy_ml.utils import get_cv, read_columns | 9 from galaxy_ml.utils import get_cv, read_columns |
| 8 | 10 |
| 9 | 11 |
| 10 def _get_single_cv_split(params, array, infile_labels=None, | 12 def _get_single_cv_split(params, array, infile_labels=None, infile_groups=None): |
| 11 infile_groups=None): | 13 """output (train, test) subset from a cv splitter |
| 12 """ output (train, test) subset from a cv splitter | |
| 13 | 14 |
| 14 Parameters | 15 Parameters |
| 15 ---------- | 16 ---------- |
| 16 params : dict | 17 params : dict |
| 17 Galaxy tool inputs | 18 Galaxy tool inputs |
| 23 File path to dataset containing group values | 24 File path to dataset containing group values |
| 24 """ | 25 """ |
| 25 y = None | 26 y = None |
| 26 groups = None | 27 groups = None |
| 27 | 28 |
| 28 nth_split = params['mode_selection']['nth_split'] | 29 nth_split = params["mode_selection"]["nth_split"] |
| 29 | 30 |
| 30 # read groups | 31 # read groups |
| 31 if infile_groups: | 32 if infile_groups: |
| 32 header = 'infer' if (params['mode_selection']['cv_selector'] | 33 header = ( |
| 33 ['groups_selector']['header_g']) else None | 34 "infer" |
| 34 column_option = (params['mode_selection']['cv_selector'] | 35 if (params["mode_selection"]["cv_selector"]["groups_selector"]["header_g"]) |
| 35 ['groups_selector']['column_selector_options_g'] | 36 else None |
| 36 ['selected_column_selector_option_g']) | 37 ) |
| 37 if column_option in ['by_index_number', 'all_but_by_index_number', | 38 column_option = params["mode_selection"]["cv_selector"]["groups_selector"][ |
| 38 'by_header_name', 'all_but_by_header_name']: | 39 "column_selector_options_g" |
| 39 c = (params['mode_selection']['cv_selector']['groups_selector'] | 40 ]["selected_column_selector_option_g"] |
| 40 ['column_selector_options_g']['col_g']) | 41 if column_option in [ |
| 42 "by_index_number", | |
| 43 "all_but_by_index_number", | |
| 44 "by_header_name", | |
| 45 "all_but_by_header_name", | |
| 46 ]: | |
| 47 c = params["mode_selection"]["cv_selector"]["groups_selector"][ | |
| 48 "column_selector_options_g" | |
| 49 ]["col_g"] | |
| 41 else: | 50 else: |
| 42 c = None | 51 c = None |
| 43 | 52 |
| 44 groups = read_columns(infile_groups, c=c, c_option=column_option, | 53 groups = read_columns( |
| 45 sep='\t', header=header, parse_dates=True) | 54 infile_groups, |
| 55 c=c, | |
| 56 c_option=column_option, | |
| 57 sep="\t", | |
| 58 header=header, | |
| 59 parse_dates=True, | |
| 60 ) | |
| 46 groups = groups.ravel() | 61 groups = groups.ravel() |
| 47 | 62 |
| 48 params['mode_selection']['cv_selector']['groups_selector'] = groups | 63 params["mode_selection"]["cv_selector"]["groups_selector"] = groups |
| 49 | 64 |
| 50 # read labels | 65 # read labels |
| 51 if infile_labels: | 66 if infile_labels: |
| 52 target_input = (params['mode_selection'] | 67 target_input = params["mode_selection"]["cv_selector"].pop("target_input") |
| 53 ['cv_selector'].pop('target_input')) | 68 header = "infer" if target_input["header1"] else None |
| 54 header = 'infer' if target_input['header1'] else None | 69 col_index = target_input["col"][0] - 1 |
| 55 col_index = target_input['col'][0] - 1 | 70 df = pd.read_csv(infile_labels, sep="\t", header=header, parse_dates=True) |
| 56 df = pd.read_csv(infile_labels, sep='\t', header=header, | |
| 57 parse_dates=True) | |
| 58 y = df.iloc[:, col_index].values | 71 y = df.iloc[:, col_index].values |
| 59 | 72 |
| 60 # construct the cv splitter object | 73 # construct the cv splitter object |
| 61 splitter, groups = get_cv(params['mode_selection']['cv_selector']) | 74 cv_selector = params["mode_selection"]["cv_selector"] |
| 75 if Version(galaxy_ml_version) < Version("0.8.3"): | |
| 76 cv_selector.pop("n_stratification_bins", None) | |
| 77 splitter, groups = get_cv(cv_selector) | |
| 62 | 78 |
| 63 total_n_splits = splitter.get_n_splits(array.values, y=y, groups=groups) | 79 total_n_splits = splitter.get_n_splits(array.values, y=y, groups=groups) |
| 64 if nth_split > total_n_splits: | 80 if nth_split > total_n_splits: |
| 65 raise ValueError("Total number of splits is {}, but got `nth_split` " | 81 raise ValueError( |
| 66 "= {}".format(total_n_splits, nth_split)) | 82 "Total number of splits is {}, but got `nth_split` " |
| 83 "= {}".format(total_n_splits, nth_split) | |
| 84 ) | |
| 67 | 85 |
| 68 i = 1 | 86 i = 1 |
| 69 for train_index, test_index in splitter.split(array.values, y=y, groups=groups): | 87 for train_index, test_index in splitter.split(array.values, y=y, groups=groups): |
| 70 # suppose nth_split >= 1 | 88 # suppose nth_split >= 1 |
| 71 if i == nth_split: | 89 if i == nth_split: |
| 77 test = array.iloc[test_index, :] | 95 test = array.iloc[test_index, :] |
| 78 | 96 |
| 79 return train, test | 97 return train, test |
| 80 | 98 |
| 81 | 99 |
| 82 def main(inputs, infile_array, outfile_train, outfile_test, | 100 def main( |
| 83 infile_labels=None, infile_groups=None): | 101 inputs, |
| 102 infile_array, | |
| 103 outfile_train, | |
| 104 outfile_test, | |
| 105 infile_labels=None, | |
| 106 infile_groups=None, | |
| 107 ): | |
| 84 """ | 108 """ |
| 85 Parameter | 109 Parameter |
| 86 --------- | 110 --------- |
| 87 inputs : str | 111 inputs : str |
| 88 File path to galaxy tool parameter | 112 File path to galaxy tool parameter |
| 100 File path to dataset containing train split | 124 File path to dataset containing train split |
| 101 | 125 |
| 102 outfile_test : str | 126 outfile_test : str |
| 103 File path to dataset containing test split | 127 File path to dataset containing test split |
| 104 """ | 128 """ |
| 105 warnings.simplefilter('ignore') | 129 warnings.simplefilter("ignore") |
| 106 | 130 |
| 107 with open(inputs, 'r') as param_handler: | 131 with open(inputs, "r") as param_handler: |
| 108 params = json.load(param_handler) | 132 params = json.load(param_handler) |
| 109 | 133 |
| 110 input_header = params['header0'] | 134 input_header = params["header0"] |
| 111 header = 'infer' if input_header else None | 135 header = "infer" if input_header else None |
| 112 array = pd.read_csv(infile_array, sep='\t', header=header, | 136 array = pd.read_csv(infile_array, sep="\t", header=header, parse_dates=True) |
| 113 parse_dates=True) | |
| 114 | 137 |
| 115 # train test split | 138 # train test split |
| 116 if params['mode_selection']['selected_mode'] == 'train_test_split': | 139 if params["mode_selection"]["selected_mode"] == "train_test_split": |
| 117 options = params['mode_selection']['options'] | 140 options = params["mode_selection"]["options"] |
| 118 shuffle_selection = options.pop('shuffle_selection') | 141 shuffle_selection = options.pop("shuffle_selection") |
| 119 options['shuffle'] = shuffle_selection['shuffle'] | 142 options["shuffle"] = shuffle_selection["shuffle"] |
| 120 if infile_labels: | 143 if infile_labels: |
| 121 header = 'infer' if shuffle_selection['header1'] else None | 144 header = "infer" if shuffle_selection["header1"] else None |
| 122 col_index = shuffle_selection['col'][0] - 1 | 145 col_index = shuffle_selection["col"][0] - 1 |
| 123 df = pd.read_csv(infile_labels, sep='\t', header=header, | 146 df = pd.read_csv(infile_labels, sep="\t", header=header, parse_dates=True) |
| 124 parse_dates=True) | |
| 125 labels = df.iloc[:, col_index].values | 147 labels = df.iloc[:, col_index].values |
| 126 options['labels'] = labels | 148 options["labels"] = labels |
| 127 | 149 |
| 128 train, test = train_test_split(array, **options) | 150 train, test = train_test_split(array, **options) |
| 129 | 151 |
| 130 # cv splitter | 152 # cv splitter |
| 131 else: | 153 else: |
| 132 train, test = _get_single_cv_split(params, array, | 154 train, test = _get_single_cv_split( |
| 133 infile_labels=infile_labels, | 155 params, array, infile_labels=infile_labels, infile_groups=infile_groups |
| 134 infile_groups=infile_groups) | 156 ) |
| 135 | 157 |
| 136 print("Input shape: %s" % repr(array.shape)) | 158 print("Input shape: %s" % repr(array.shape)) |
| 137 print("Train shape: %s" % repr(train.shape)) | 159 print("Train shape: %s" % repr(train.shape)) |
| 138 print("Test shape: %s" % repr(test.shape)) | 160 print("Test shape: %s" % repr(test.shape)) |
| 139 train.to_csv(outfile_train, sep='\t', header=input_header, index=False) | 161 train.to_csv(outfile_train, sep="\t", header=input_header, index=False) |
| 140 test.to_csv(outfile_test, sep='\t', header=input_header, index=False) | 162 test.to_csv(outfile_test, sep="\t", header=input_header, index=False) |
| 141 | 163 |
| 142 | 164 |
| 143 if __name__ == '__main__': | 165 if __name__ == "__main__": |
| 144 aparser = argparse.ArgumentParser() | 166 aparser = argparse.ArgumentParser() |
| 145 aparser.add_argument("-i", "--inputs", dest="inputs", required=True) | 167 aparser.add_argument("-i", "--inputs", dest="inputs", required=True) |
| 146 aparser.add_argument("-X", "--infile_array", dest="infile_array") | 168 aparser.add_argument("-X", "--infile_array", dest="infile_array") |
| 147 aparser.add_argument("-y", "--infile_labels", dest="infile_labels") | 169 aparser.add_argument("-y", "--infile_labels", dest="infile_labels") |
| 148 aparser.add_argument("-g", "--infile_groups", dest="infile_groups") | 170 aparser.add_argument("-g", "--infile_groups", dest="infile_groups") |
| 149 aparser.add_argument("-o", "--outfile_train", dest="outfile_train") | 171 aparser.add_argument("-o", "--outfile_train", dest="outfile_train") |
| 150 aparser.add_argument("-t", "--outfile_test", dest="outfile_test") | 172 aparser.add_argument("-t", "--outfile_test", dest="outfile_test") |
| 151 args = aparser.parse_args() | 173 args = aparser.parse_args() |
| 152 | 174 |
| 153 main(args.inputs, args.infile_array, args.outfile_train, | 175 main( |
| 154 args.outfile_test, args.infile_labels, args.infile_groups) | 176 args.inputs, |
| 177 args.infile_array, | |
| 178 args.outfile_train, | |
| 179 args.outfile_test, | |
| 180 args.infile_labels, | |
| 181 args.infile_groups, | |
| 182 ) |
