Mercurial > repos > bgruening > sklearn_fitted_model_eval
comparison ml_visualization_ex.py @ 0:4fc9e02801f9 draft
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit eb703290e2589561ea215c84aa9f71bcfe1712c6"
| author | bgruening |
|---|---|
| date | Fri, 01 Nov 2019 16:39:35 -0400 |
| parents | |
| children | 153c7a85f117 |
comparison
equal
deleted
inserted
replaced
| -1:000000000000 | 0:4fc9e02801f9 |
|---|---|
| 1 import argparse | |
| 2 import json | |
| 3 import numpy as np | |
| 4 import pandas as pd | |
| 5 import plotly | |
| 6 import plotly.graph_objs as go | |
| 7 import warnings | |
| 8 | |
| 9 from keras.models import model_from_json | |
| 10 from keras.utils import plot_model | |
| 11 from sklearn.feature_selection.base import SelectorMixin | |
| 12 from sklearn.metrics import precision_recall_curve, average_precision_score | |
| 13 from sklearn.metrics import roc_curve, auc | |
| 14 from sklearn.pipeline import Pipeline | |
| 15 from galaxy_ml.utils import load_model, read_columns, SafeEval | |
| 16 | |
| 17 | |
# Module-level evaluator created once at import time and reused by `main`,
# where it turns the user-typed RFECV `steps` text into a Python value.
# NOTE(review): presumably SafeEval restricts evaluation to literal-like
# expressions rather than full `eval` -- confirm against galaxy_ml.utils.
safe_eval = SafeEval()
| 19 | |
| 20 | |
def _as_per_feature_scores(coefs):
    """Return ``coefs`` as a 1-D numpy array holding one value per feature.

    Many linear models expose ``coef_`` as a 2-D array with one row per
    target/class (binary classifiers typically use a single row). A single
    row is used as-is, so coefficient signs are kept. Several rows are
    collapsed to the L1 norm across rows; this is the same reduction
    scikit-learn's ``SelectFromModel`` applies by default, and it is always
    non-negative.
    """
    if hasattr(coefs, 'toarray'):
        # sparse coef_, e.g. after ``estimator.sparsify()``
        coefs = coefs.toarray()
    coefs = np.asarray(coefs)
    if coefs.ndim == 2:
        if coefs.shape[0] == 1:
            coefs = coefs[0]
        else:
            coefs = np.linalg.norm(coefs, ord=1, axis=0)
    return coefs


def _coerce_pos_label(pos_label, y_true):
    """Convert the textual ``pos_label`` to the dtype of ``y_true``.

    Tool parameters arrive as strings, but ``pd.read_csv`` parses numeric
    label columns as int/float, and ``'1'`` never equals ``1``. The label is
    returned unchanged when it is None, when the labels are not numeric, or
    when the text cannot be converted.
    """
    if pos_label is None:
        return None
    kind = np.asarray(y_true).dtype.kind
    try:
        if kind in 'iu':
            return int(pos_label)
        if kind == 'f':
            return float(pos_label)
    except ValueError:
        pass
    return pos_label


def _read_label_score_tables(infile1, infile2):
    """Read the header-less TSVs of true labels and scores for PR/ROC.

    Column ``i`` of ``infile2`` holds the scores for the labels in column
    ``i`` of ``infile1``. Extra score columns are dropped so the
    micro-average sees aligned data.

    Raises
    ------
    ValueError
        If the row counts differ or the score table has fewer columns.
    """
    df1 = pd.read_csv(infile1, sep='\t', header=None)
    df2 = pd.read_csv(infile2, sep='\t', header=None)
    if len(df1) != len(df2):
        raise ValueError('True labels (%d rows) and scores (%d rows) must '
                         'have the same number of rows'
                         % (len(df1), len(df2)))
    if df2.shape[1] < df1.shape[1]:
        raise ValueError('The score table has %d column(s) but the true '
                         'label table has %d'
                         % (df2.shape[1], df1.shape[1]))
    return df1, df2[df1.columns]


def _curve_name(key, area):
    """Legend label for one PR/ROC trace (a column index or 'micro')."""
    if key == 'micro':
        return '%s (area = %.2f)' % (key, area)
    return 'column %s (area = %.2f)' % (key, area)


def _plot_feature_importances(selection, title, infile_estimator, infile1):
    """Bar chart of the final estimator's coefficients or importances."""
    with open(infile_estimator, 'rb') as estimator_handler:
        estimator = load_model(estimator_handler)

    column_options = selection['column_selector_options']
    column_option = column_options['selected_column_selector_option']
    if column_option in ['by_index_number', 'all_but_by_index_number',
                         'by_header_name', 'all_but_by_header_name']:
        c = column_options['col1']
    else:
        c = None

    _, input_df = read_columns(infile1, c=c,
                               c_option=column_option,
                               return_df=True,
                               sep='\t', header='infer',
                               parse_dates=True)

    feature_names = input_df.columns.values

    if isinstance(estimator, Pipeline):
        # Apply the support mask of every feature-selection step so the
        # names line up with what the final estimator actually saw.
        for _, step in estimator.steps[:-1]:
            if isinstance(step, SelectorMixin):
                feature_names = feature_names[step.get_support()]
        estimator = estimator.steps[-1][-1]

    if hasattr(estimator, 'coef_'):
        coefs = estimator.coef_
    else:
        coefs = getattr(estimator, 'feature_importances_', None)
    if coefs is None:
        raise RuntimeError('The classifier does not expose '
                           '"coef_" or "feature_importances_" '
                           'attributes')

    coefs = _as_per_feature_scores(coefs)
    if len(feature_names) != len(coefs):
        # e.g. a non-selector transformer changed the feature space; the
        # names would otherwise be silently mismatched to the bars.
        raise ValueError('Got %d feature names for %d coefficients; the '
                         'selected columns do not match the features the '
                         'estimator was fitted on'
                         % (len(feature_names), len(coefs)))

    threshold = selection['threshold']
    if threshold is not None:
        mask = (coefs > threshold) | (coefs < -threshold)
        coefs = coefs[mask]
        feature_names = feature_names[mask]

    # sort descending
    indices = np.argsort(coefs)[::-1]

    trace = go.Bar(x=feature_names[indices],
                   y=coefs[indices])
    layout = go.Layout(title=title or "Feature Importances")
    return go.Figure(data=[trace], layout=layout)


def _plot_pr_curve(selection, title, infile1, infile2):
    """Precision-recall curve per label column, plus micro-average."""
    df1, df2 = _read_label_score_tables(infile1, infile2)
    pos_label = selection['pos_label'].strip() or None

    precision, recall, ap = {}, {}, {}
    for col in df1.columns:
        y_true = df1[col].values
        y_score = df2[col].values
        col_pos_label = _coerce_pos_label(pos_label, y_true)

        precision[col], recall[col], _ = precision_recall_curve(
            y_true, y_score, pos_label=col_pos_label)
        # explicit None check: a legitimate label 0 must not become 1
        ap[col] = average_precision_score(
            y_true, y_score,
            pos_label=1 if col_pos_label is None else col_pos_label)

    if len(df1.columns) > 1:
        micro_pos_label = _coerce_pos_label(pos_label, df1.values)
        precision['micro'], recall['micro'], _ = precision_recall_curve(
            df1.values.ravel(), df2.values.ravel(),
            pos_label=micro_pos_label)
        ap['micro'] = average_precision_score(
            df1.values, df2.values, average='micro',
            pos_label=1 if micro_pos_label is None else micro_pos_label)

    data = [go.Scatter(x=recall[key], y=precision[key], mode='lines',
                       name=_curve_name(key, ap[key]))
            for key in precision]

    layout = go.Layout(
        title=title or "Precision-Recall curve",
        xaxis=dict(title='Recall'),
        yaxis=dict(title='Precision')
    )
    return go.Figure(data=data, layout=layout)


def _plot_roc_curve(selection, title, infile1, infile2):
    """ROC curve per label column, plus micro-average and chance line."""
    df1, df2 = _read_label_score_tables(infile1, infile2)
    pos_label = selection['pos_label'].strip() or None

    fpr, tpr, roc_auc = {}, {}, {}
    for col in df1.columns:
        y_true = df1[col].values
        y_score = df2[col].values

        fpr[col], tpr[col], _ = roc_curve(
            y_true, y_score,
            pos_label=_coerce_pos_label(pos_label, y_true))
        roc_auc[col] = auc(fpr[col], tpr[col])

    if len(df1.columns) > 1:
        fpr['micro'], tpr['micro'], _ = roc_curve(
            df1.values.ravel(), df2.values.ravel(),
            pos_label=_coerce_pos_label(pos_label, df1.values))
        roc_auc['micro'] = auc(fpr['micro'], tpr['micro'])

    data = [go.Scatter(x=fpr[key], y=tpr[key], mode='lines',
                       name=_curve_name(key, roc_auc[key]))
            for key in fpr]
    # diagonal reference line of a random classifier
    data.append(go.Scatter(x=[0, 1], y=[0, 1],
                           mode='lines',
                           line=dict(color='black', dash='dash'),
                           showlegend=False))

    layout = go.Layout(
        title=title or "Receiver operating characteristic curve",
        xaxis=dict(title='False Positive Rate'),
        yaxis=dict(title='True Positive Rate')
    )
    return go.Figure(data=data, layout=layout)


def _plot_rfecv_gridscores(selection, title, infile1):
    """Line plot of the first column of an RFECV grid-scores table."""
    input_df = pd.read_csv(infile1, sep='\t', header='infer')
    scores = input_df.iloc[:, 0]
    steps = safe_eval(selection['steps'].strip())

    data = go.Scatter(
        x=list(range(len(scores))),
        y=scores,
        text=[str(step) for step in steps] if steps else None,
        mode='lines'
    )
    layout = go.Layout(
        xaxis=dict(title="Number of features selected"),
        yaxis=dict(title="Cross validation score"),
        title=title or None
    )
    return go.Figure(data=[data], layout=layout)


def _plot_learning_curve(selection, title, infile1):
    """Train/test score curves, optionally with std-dev error bars."""
    input_df = pd.read_csv(infile1, sep='\t', header='infer')
    plot_std_err = selection['plot_std_err']

    def _scores_trace(kind, name):
        # kind is 'train' or 'test', selecting the matching table columns
        return go.Scatter(
            x=input_df['train_sizes_abs'],
            y=input_df['mean_%s_scores' % kind],
            error_y=dict(
                array=input_df['std_%s_scores' % kind]
            ) if plot_std_err else None,
            mode='lines',
            name=name,
        )

    layout = go.Layout(
        xaxis=dict(title='No. of samples'),
        yaxis=dict(title='Performance Score'),
        title=title or 'Learning Curve'
    )
    return go.Figure(data=[_scores_trace('train', "Train Scores"),
                           _scores_trace('test', "Test Scores")],
                     layout=layout)


def _plot_keras_model(model_config):
    """Render a Keras model architecture from its JSON config to output.png."""
    with open(model_config, 'r') as f:
        model_str = f.read()
    model = model_from_json(model_str)
    plot_model(model, to_file="output.png")


def main(inputs, infile_estimator=None, infile1=None,
         infile2=None, outfile_result=None,
         outfile_object=None, groups=None,
         ref_seq=None, intervals=None,
         targets=None, fasta_path=None,
         model_config=None):
    """
    Render the plot selected in the Galaxy tool parameters and save it as
    a file named ``output`` in the working directory (HTML for plotly
    plots, PNG for ``keras_plot_model``).

    Parameters
    ----------
    inputs : str
        File path to galaxy tool parameter (JSON).

    infile_estimator : str, default is None
        File path to estimator.

    infile1 : str, default is None
        File path to dataset containing features or true labels.

    infile2 : str, default is None
        File path to dataset containing target values or predicted
        probabilities.

    outfile_result : str, default is None
        File path to save the results, either cv_results or test result.
        Accepted for CLI compatibility; not used here.

    outfile_object : str, default is None
        File path to save searchCV object. Not used here.

    groups : str, default is None
        File path to dataset containing groups labels. Not used here.

    ref_seq : str, default is None
        File path to dataset containing genome sequence file. Not used here.

    intervals : str, default is None
        File path to dataset containing interval file. Not used here.

    targets : str, default is None
        File path to dataset containing compressed target bed file.
        Not used here.

    fasta_path : str, default is None
        File path to dataset containing fasta file. Not used here.

    model_config : str, default is None
        File path to dataset containing JSON config for neural networks.

    Returns
    -------
    int or None
        0 for ``keras_plot_model``; None for every other plot type.

    Raises
    ------
    ValueError
        If ``plot_type`` is unsupported, if the PR/ROC label and score
        tables do not line up, or if the feature names do not match the
        estimator's coefficients.
    RuntimeError
        If the estimator exposes neither ``coef_`` nor
        ``feature_importances_``.
    """
    import os

    warnings.simplefilter('ignore')

    with open(inputs, 'r') as param_handler:
        params = json.load(param_handler)

    selection = params['plotting_selection']
    title = selection['title'].strip()
    plot_type = selection['plot_type']

    if plot_type == 'keras_plot_model':
        _plot_keras_model(model_config)
        # to be discovered by `from_work_dir`
        os.rename('output.png', 'output')
        return 0

    if plot_type == 'feature_importances':
        fig = _plot_feature_importances(selection, title,
                                        infile_estimator, infile1)
    elif plot_type == 'pr_curve':
        fig = _plot_pr_curve(selection, title, infile1, infile2)
    elif plot_type == 'roc_curve':
        fig = _plot_roc_curve(selection, title, infile1, infile2)
    elif plot_type == 'rfecv_gridscores':
        fig = _plot_rfecv_gridscores(selection, title, infile1)
    elif plot_type == 'learning_curve':
        fig = _plot_learning_curve(selection, title, infile1)
    else:
        raise ValueError('Unsupported plot_type: %r' % (plot_type,))

    plotly.offline.plot(fig, filename="output.html",
                        auto_open=False)
    # to be discovered by `from_work_dir`
    os.rename('output.html', 'output')
| 284 | |
| 285 | |
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument("-i", "--inputs", dest="inputs", required=True)

    # Optional file-path options as (short flag, long flag, destination).
    # Each destination is spelled exactly like a keyword parameter of
    # `main`, so the parsed namespace can be forwarded wholesale.
    optional_options = (
        ("-e", "--estimator", "infile_estimator"),
        ("-X", "--infile1", "infile1"),
        ("-y", "--infile2", "infile2"),
        ("-O", "--outfile_result", "outfile_result"),
        ("-o", "--outfile_object", "outfile_object"),
        ("-g", "--groups", "groups"),
        ("-r", "--ref_seq", "ref_seq"),
        ("-b", "--intervals", "intervals"),
        ("-t", "--targets", "targets"),
        ("-f", "--fasta_path", "fasta_path"),
        ("-c", "--model_config", "model_config"),
    )
    for short_flag, long_flag, destination in optional_options:
        parser.add_argument(short_flag, long_flag, dest=destination)

    cli_args = parser.parse_args()
    main(**vars(cli_args))
