Mercurial > repos > bgruening > sklearn_nn_classifier
comparison nn_classifier.xml @ 0:edf1078e21bb draft
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit 38fa34f4b3f7298582644add88f9fe63554f32bb
| author | bgruening |
|---|---|
| date | Sat, 04 Aug 2018 13:25:22 -0400 |
| parents | |
| children | 9081adc10423 |
comparison
equal
deleted
inserted
replaced
| -1:000000000000 | 0:edf1078e21bb |
|---|---|
<tool id="sklearn_nn_classifier" name="Nearest Neighbors Classification" version="@VERSION@">
    <description></description>
    <macros>
        <import>main_macros.xml</import>
    </macros>
    <expand macro="python_requirements"/>
    <expand macro="macro_stdio"/>
    <version_command>echo "@VERSION@"</version_command>
    <!-- The tool runs the generated nnc_script, passing it the path of the JSON dump of all tool parameters. -->
    <command><![CDATA[
python "$nnc_script" '$inputs'
]]>
    </command>
    <configfiles>
        <inputs name="inputs"/>
        <!--
            nnc_script either fits a sklearn.neighbors classifier and pickles it
            (selected_task "train"), or loads a pickled classifier and appends its
            predictions to a tabular dataset (selected_task "load").
        -->
        <configfile name="nnc_script">
<![CDATA[
import sys
import json
import numpy as np
import sklearn.neighbors
import pandas
import pickle

@COLUMNS_FUNCTION@
@GET_X_y_FUNCTION@

## The tool parameters, serialized to JSON by the <inputs> configfile,
## arrive as the first command-line argument.
input_json_path = sys.argv[1]
with open(input_json_path, "r") as param_handler:
    params = json.load(param_handler)

#if $selected_tasks.selected_task == "load":

## Prediction mode: restore a fitted classifier and predict labels for new rows.
## NOTE(review): pickle.load runs arbitrary code embedded in the file; only
## models produced by this tool should ever be supplied here.
with open("$infile_model", 'rb') as model_handler:
    classifier_object = pickle.load(model_handler)

header = 'infer' if params["selected_tasks"]["header"] else None
data = pandas.read_csv("$selected_tasks.infile_data", sep='\t', header=header, index_col=None, parse_dates=True, encoding=None)
prediction = classifier_object.predict(data)
prediction_df = pandas.DataFrame(prediction)
## Output = the input table with the predicted label appended as the last column.
res = pandas.concat([data, prediction_df], axis=1)
res.to_csv(path_or_buf = "$outfile_predict", sep="\t", index=False)

#else:

## Training mode: build the selected estimator from the form options, fit, and pickle it.
X, y = get_X_y(params, "$selected_tasks.selected_algorithms.input_options.infile1" ,"$selected_tasks.selected_algorithms.input_options.infile2")

selected_algorithm = params["selected_tasks"]["selected_algorithms"]["selected_algorithm"]

if selected_algorithm == "nneighbors":
    ## sampling_method is the sklearn.neighbors class name chosen in the form
    ## (KNeighborsClassifier or RadiusNeighborsClassifier).
    classifier = params["selected_tasks"]["selected_algorithms"]["sampling_methods"]["sampling_method"]
    options = params["selected_tasks"]["selected_algorithms"]["sampling_methods"]["options"]
elif selected_algorithm == "ncentroid":
    options = params["selected_tasks"]["selected_algorithms"]["options"]
    classifier = "NearestCentroid"
else:
    ## Fail loudly instead of hitting a NameError on the unbound classifier/options below.
    raise ValueError("Unsupported classifier type: %s" % selected_algorithm)

my_class = getattr(sklearn.neighbors, classifier)
classifier_object = my_class(**options)
classifier_object.fit(X, y)

with open("$outfile_fit", 'wb') as out_handler:
    pickle.dump(classifier_object, out_handler)

#end if

]]>
        </configfile>
    </configfiles>
    <inputs>
        <expand macro="sl_Conditional" model="zip"><!-- TODO: add sparse support to targets -->
            <param name="selected_algorithm" type="select" label="Classifier type">
                <option value="nneighbors">Nearest Neighbors</option>
                <option value="ncentroid">Nearest Centroid</option>
            </param>
            <when value="nneighbors">
                <expand macro="sl_mixed_input"/>
                <!-- sampling_method values are sklearn.neighbors class names, resolved with getattr in nnc_script. -->
                <conditional name="sampling_methods">
                    <param name="sampling_method" type="select" label="Neighbor selection method">
                        <option value="KNeighborsClassifier" selected="true">K-nearest neighbors</option>
                        <option value="RadiusNeighborsClassifier">Radius-based</option>
                    </param>
                    <when value="KNeighborsClassifier">
                        <expand macro="nn_advanced_options">
                            <param argument="n_neighbors" type="integer" optional="true" value="5" label="Number of neighbors"
                                help="Number of neighbors to use by default for kneighbors queries."/>
                        </expand>
                    </when>
                    <when value="RadiusNeighborsClassifier">
                        <expand macro="nn_advanced_options">
                            <param argument="radius" type="float" optional="true" value="1.0" label="Radius"
                                help="Range of parameter space to use by default for radius_neighbors queries."/>
                        </expand>
                    </when>
                </conditional>
            </when>
            <when value="ncentroid">
                <expand macro="sl_mixed_input"/>
                <!-- Parameters in this section are forwarded verbatim as NearestCentroid keyword arguments. -->
                <section name="options" title="Advanced Options" expanded="False">
                    <param argument="metric" type="text" optional="true" value="euclidean" label="Metric"
                        help="The metric to use when calculating distance between instances in a feature array."/>
                    <param argument="shrink_threshold" type="float" optional="true" value="" label="Shrink threshold"
                        help="Floating point number for shrinking centroids to remove features."/>
                </section>
            </when>
        </expand>
    </inputs>

    <expand macro="output"/>

    <tests>
        <test>
            <param name="infile1" value="train_set.tabular" ftype="tabular"/>
            <param name="infile2" value="train_set.tabular" ftype="tabular"/>
            <param name="header1" value="True"/>
            <param name="header2" value="True"/>
            <param name="col1" value="1,2,3,4"/>
            <param name="col2" value="5"/>
            <param name="selected_task" value="train"/>
            <param name="selected_algorithm" value="nneighbors"/>
            <param name="sampling_method" value="KNeighborsClassifier" />
            <param name="algorithm" value="brute" />
            <output name="outfile_fit" file="nn_model01.txt"/>
        </test>
        <test>
            <param name="infile1" value="train_set.tabular" ftype="tabular"/>
            <param name="infile2" value="train_set.tabular" ftype="tabular"/>
            <param name="header1" value="True"/>
            <param name="header2" value="True"/>
            <param name="col1" value="1,2,3,4"/>
            <param name="col2" value="5"/>
            <param name="selected_task" value="train"/>
            <param name="selected_algorithm" value="nneighbors"/>
            <param name="sampling_method" value="RadiusNeighborsClassifier" />
            <output name="outfile_fit" file="nn_model02.txt"/>
        </test>
        <test>
            <param name="infile1" value="train_set.tabular" ftype="tabular"/>
            <param name="infile2" value="train_set.tabular" ftype="tabular"/>
            <param name="header1" value="True"/>
            <param name="header2" value="True"/>
            <param name="col1" value="1,2,3,4"/>
            <param name="col2" value="5"/>
            <param name="selected_task" value="train"/>
            <param name="selected_algorithm" value="ncentroid"/>
            <output name="outfile_fit" file="nn_model03.txt"/>
        </test>
        <!-- The prediction tests below consume the models produced by the three training tests above. -->
        <test>
            <param name="infile_model" value="nn_model01.txt" ftype="txt"/>
            <param name="infile_data" value="test_set.tabular" ftype="tabular"/>
            <param name="header" value="True"/>
            <param name="selected_task" value="load"/>
            <output name="outfile_predict" file="nn_prediction_result01.tabular"/>
        </test>
        <test>
            <param name="infile_model" value="nn_model02.txt" ftype="txt"/>
            <param name="infile_data" value="test_set.tabular" ftype="tabular"/>
            <param name="header" value="True"/>
            <param name="selected_task" value="load"/>
            <output name="outfile_predict" file="nn_prediction_result02.tabular"/>
        </test>
        <test>
            <param name="infile_model" value="nn_model03.txt" ftype="txt"/>
            <param name="infile_data" value="test_set.tabular" ftype="tabular"/>
            <param name="header" value="True"/>
            <param name="selected_task" value="load"/>
            <output name="outfile_predict" file="nn_prediction_result03.tabular"/>
        </test>
    </tests>
    <help><![CDATA[
**What it does**

This module implements nearest neighbors classification: k-nearest neighbors, radius-based neighbors, and nearest centroid classifiers.
For more information check http://scikit-learn.org/stable/modules/neighbors.html
]]></help>
    <expand macro="sklearn_citation"/>
</tool>
