Mercurial > repos > bgruening > sklearn_ensemble
comparison ensemble.xml @ 6:4c2fae2db5d1 draft
planemo upload for repository https://github.com/bgruening/galaxytools/tools/sklearn commit 641ac64ded23fbb6fe85d5f13926da12dcce4e76
| author | bgruening |
|---|---|
| date | Tue, 13 Mar 2018 04:51:40 -0400 |
| parents | 3bc536788043 |
| children | ea8b1c89c20b |
comparison
equal
deleted
inserted
replaced
| 5:1059756bb41b | 6:4c2fae2db5d1 |
|---|---|
| 23 from scipy.io import mmread | 23 from scipy.io import mmread |
| 24 | 24 |
| 25 input_json_path = sys.argv[1] | 25 input_json_path = sys.argv[1] |
| 26 params = json.load(open(input_json_path, "r")) | 26 params = json.load(open(input_json_path, "r")) |
| 27 | 27 |
| | 28 @COLUMNS_FUNCTION@ |
| | 29 |
| 28 #if $selected_tasks.selected_task == "train": | 30 #if $selected_tasks.selected_task == "train": |
| 29 | 31 |
| 30 algorithm = params["selected_tasks"]["selected_algorithms"]["selected_algorithm"] | 32 algorithm = params["selected_tasks"]["selected_algorithms"]["selected_algorithm"] |
| 31 options = params["selected_tasks"]["selected_algorithms"]["options"] | 33 options = params["selected_tasks"]["selected_algorithms"]["options"] |
| 32 input_type = params["selected_tasks"]["selected_algorithms"]["input_options"]["selected_input"] | 34 input_type = params["selected_tasks"]["selected_algorithms"]["input_options"]["selected_input"] |
| 33 if input_type=="tabular": | 35 if input_type=="tabular": |
| 34 col1 = params["selected_tasks"]["selected_algorithms"]["input_options"]["col1"] | 36 X = read_columns( |
| 35 col1 = list(map(lambda x: x - 1, col1)) | 37 "$selected_tasks.selected_algorithms.input_options.infile1", |
| 36 f1 = pandas.read_csv("$selected_tasks.selected_algorithms.input_options.infile1", sep='\t', header=None, index_col=None, parse_dates=True, encoding=None, tupleize_cols=False ) | 38 "$selected_tasks.selected_algorithms.input_options.col1", |
| 37 X = f1.iloc[:,col1].values | 39 sep='\t', |
| | 40 header=None, |
| | 41 parse_dates=True |
| | 42 ) |
| 38 else: | 43 else: |
| 39 X = mmread(open("$selected_tasks.selected_algorithms.input_options.infile1", 'r')) | 44 X = mmread(open("$selected_tasks.selected_algorithms.input_options.infile1", 'r')) |
| 40 | 45 |
| 41 col2 = params["selected_tasks"]["selected_algorithms"]["input_options"]["col2"] | 46 y = read_columns( |
| 42 col2 = list(map(lambda x: x - 1, col2)) | 47 "$selected_tasks.selected_algorithms.input_options.infile2", |
| 43 f2 = pandas.read_csv("$selected_tasks.selected_algorithms.input_options.infile2", sep='\t', header=None, index_col=None, parse_dates=True, encoding=None, tupleize_cols=False ) | 48 "$selected_tasks.selected_algorithms.input_options.col2", |
| 44 y = f2.iloc[:,col2].values | 49 sep='\t', |
| | 50 header=None, |
| | 51 parse_dates=True |
| | 52 ) |
| 45 | 53 |
| 46 my_class = getattr(sklearn.ensemble, algorithm) | 54 my_class = getattr(sklearn.ensemble, algorithm) |
| 47 estimator = my_class(**options) | 55 estimator = my_class(**options) |
| 48 estimator.fit(X,y) | 56 estimator.fit(X,y) |
| 49 pickle.dump(estimator,open("$outfile_fit", 'w+'), pickle.HIGHEST_PROTOCOL) | 57 pickle.dump(estimator,open("$outfile_fit", 'w+'), pickle.HIGHEST_PROTOCOL) |
| 50 | 58 |
| 51 #else: | 59 #else: |
| 52 classifier_object = pickle.load(open("$selected_tasks.infile_model", 'r')) | 60 classifier_object = pickle.load(open("$selected_tasks.infile_model", 'r')) |
| 53 data = pandas.read_csv("$selected_tasks.infile_data", sep='\t', header=0, index_col=None, parse_dates=True, encoding=None, tupleize_cols=False ) | 61 data = pandas.read_csv("$selected_tasks.infile_data", sep='\t', header=0, index_col=None, parse_dates=True, encoding=None, tupleize_cols=False) |
| 54 prediction = classifier_object.predict(data) | 62 prediction = classifier_object.predict(data) |
| 55 prediction_df = pandas.DataFrame(prediction) | 63 prediction_df = pandas.DataFrame(prediction) |
| 56 res = pandas.concat([data, prediction_df], axis=1) | 64 res = pandas.concat([data, prediction_df], axis=1) |
| 57 res.to_csv(path_or_buf = "$outfile_predict", sep="\t", index=False) | 65 res.to_csv(path_or_buf = "$outfile_predict", sep="\t", index=False) |
| 58 #end if | 66 #end if |
