Mercurial > repos > bgruening > sklearn_stacking_ensemble_models
comparison train_test_split.py @ 11:0380f10c4e04 draft
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit ea12f973df4b97a2691d9e4ce6bf6fae59d57717"
| author | bgruening | 
|---|---|
| date | Fri, 30 Apr 2021 23:23:56 +0000 | 
| parents | 2d890789ac48 | 
| children | 
Comparison legend: equal · deleted · inserted · replaced
| 10:2d890789ac48 | 11:0380f10c4e04 | 
|---|---|
| 26 | 26 | 
| 27 nth_split = params["mode_selection"]["nth_split"] | 27 nth_split = params["mode_selection"]["nth_split"] | 
| 28 | 28 | 
| 29 # read groups | 29 # read groups | 
| 30 if infile_groups: | 30 if infile_groups: | 
| 31 header = "infer" if (params["mode_selection"]["cv_selector"]["groups_selector"]["header_g"]) else None | 31 header = ( | 
| 32 column_option = params["mode_selection"]["cv_selector"]["groups_selector"]["column_selector_options_g"][ | 32 "infer" | 
| 33 "selected_column_selector_option_g" | 33 if (params["mode_selection"]["cv_selector"]["groups_selector"]["header_g"]) | 
| 34 ] | 34 else None | 
| | 35 ) |
| | 36 column_option = params["mode_selection"]["cv_selector"]["groups_selector"][ |
| | 37 "column_selector_options_g" |
| | 38 ]["selected_column_selector_option_g"] |
| 35 if column_option in [ | 39 if column_option in [ | 
| 36 "by_index_number", | 40 "by_index_number", | 
| 37 "all_but_by_index_number", | 41 "all_but_by_index_number", | 
| 38 "by_header_name", | 42 "by_header_name", | 
| 39 "all_but_by_header_name", | 43 "all_but_by_header_name", | 
| 40 ]: | 44 ]: | 
| 41 c = params["mode_selection"]["cv_selector"]["groups_selector"]["column_selector_options_g"]["col_g"] | 45 c = params["mode_selection"]["cv_selector"]["groups_selector"][ | 
| | 46 "column_selector_options_g" |
| | 47 ]["col_g"] |
| 42 else: | 48 else: | 
| 43 c = None | 49 c = None | 
| 44 | 50 | 
| 45 groups = read_columns( | 51 groups = read_columns( | 
| 46 infile_groups, | 52 infile_groups, | 
| 65 # construct the cv splitter object | 71 # construct the cv splitter object | 
| 66 splitter, groups = get_cv(params["mode_selection"]["cv_selector"]) | 72 splitter, groups = get_cv(params["mode_selection"]["cv_selector"]) | 
| 67 | 73 | 
| 68 total_n_splits = splitter.get_n_splits(array.values, y=y, groups=groups) | 74 total_n_splits = splitter.get_n_splits(array.values, y=y, groups=groups) | 
| 69 if nth_split > total_n_splits: | 75 if nth_split > total_n_splits: | 
| 70 raise ValueError("Total number of splits is {}, but got `nth_split` " "= {}".format(total_n_splits, nth_split)) | 76 raise ValueError( | 
| | 77 "Total number of splits is {}, but got `nth_split` " |
| | 78 "= {}".format(total_n_splits, nth_split) |
| | 79 ) |
| 71 | 80 | 
| 72 i = 1 | 81 i = 1 | 
| 73 for train_index, test_index in splitter.split(array.values, y=y, groups=groups): | 82 for train_index, test_index in splitter.split(array.values, y=y, groups=groups): | 
| 74 # suppose nth_split >= 1 | 83 # suppose nth_split >= 1 | 
| 75 if i == nth_split: | 84 if i == nth_split: | 
| 135 | 144 | 
| 136 train, test = train_test_split(array, **options) | 145 train, test = train_test_split(array, **options) | 
| 137 | 146 | 
| 138 # cv splitter | 147 # cv splitter | 
| 139 else: | 148 else: | 
| 140 train, test = _get_single_cv_split(params, array, infile_labels=infile_labels, infile_groups=infile_groups) | 149 train, test = _get_single_cv_split( | 
| | 150 params, array, infile_labels=infile_labels, infile_groups=infile_groups |
| | 151 ) |
| 141 | 152 | 
| 142 print("Input shape: %s" % repr(array.shape)) | 153 print("Input shape: %s" % repr(array.shape)) | 
| 143 print("Train shape: %s" % repr(train.shape)) | 154 print("Train shape: %s" % repr(train.shape)) | 
| 144 print("Test shape: %s" % repr(test.shape)) | 155 print("Test shape: %s" % repr(test.shape)) | 
| 145 train.to_csv(outfile_train, sep="\t", header=input_header, index=False) | 156 train.to_csv(outfile_train, sep="\t", header=input_header, index=False) | 
