Mercurial > repos > bgruening > stacking_ensemble_models
comparison ml_visualization_ex.py @ 3:0a1812986bc3 draft
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit 9981e25b00de29ed881b2229a173a8c812ded9bb
author | bgruening |
---|---|
date | Wed, 09 Aug 2023 11:10:37 +0000 |
parents | 38c4f8a98038 |
children | ba7fb6b33cd0 |
comparison
equal
deleted
inserted
replaced
2:38c4f8a98038 | 3:0a1812986bc3 |
---|---|
1 import argparse | 1 import argparse |
2 import json | 2 import json |
3 import os | |
4 import warnings | |
5 | |
3 import matplotlib | 6 import matplotlib |
4 import matplotlib.pyplot as plt | 7 import matplotlib.pyplot as plt |
5 import numpy as np | 8 import numpy as np |
6 import os | |
7 import pandas as pd | 9 import pandas as pd |
8 import plotly | 10 import plotly |
9 import plotly.graph_objs as go | 11 import plotly.graph_objs as go |
10 import warnings | 12 from galaxy_ml.model_persist import load_model_from_h5 |
11 | 13 from galaxy_ml.utils import read_columns, SafeEval |
12 from keras.models import model_from_json | 14 from sklearn.feature_selection._base import SelectorMixin |
13 from keras.utils import plot_model | 15 from sklearn.metrics import ( |
14 from sklearn.feature_selection.base import SelectorMixin | 16 auc, |
15 from sklearn.metrics import precision_recall_curve, average_precision_score | 17 average_precision_score, |
16 from sklearn.metrics import roc_curve, auc | 18 precision_recall_curve, |
19 roc_curve, | |
20 ) | |
17 from sklearn.pipeline import Pipeline | 21 from sklearn.pipeline import Pipeline |
18 from galaxy_ml.utils import load_model, read_columns, SafeEval | 22 from tensorflow.keras.models import model_from_json |
19 | 23 from tensorflow.keras.utils import plot_model |
20 | 24 |
21 safe_eval = SafeEval() | 25 safe_eval = SafeEval() |
22 | 26 |
23 # plotly default colors | 27 # plotly default colors |
24 default_colors = [ | 28 default_colors = [ |
25 '#1f77b4', # muted blue | 29 "#1f77b4", # muted blue |
26 '#ff7f0e', # safety orange | 30 "#ff7f0e", # safety orange |
27 '#2ca02c', # cooked asparagus green | 31 "#2ca02c", # cooked asparagus green |
28 '#d62728', # brick red | 32 "#d62728", # brick red |
29 '#9467bd', # muted purple | 33 "#9467bd", # muted purple |
30 '#8c564b', # chestnut brown | 34 "#8c564b", # chestnut brown |
31 '#e377c2', # raspberry yogurt pink | 35 "#e377c2", # raspberry yogurt pink |
32 '#7f7f7f', # middle gray | 36 "#7f7f7f", # middle gray |
33 '#bcbd22', # curry yellow-green | 37 "#bcbd22", # curry yellow-green |
34 '#17becf' # blue-teal | 38 "#17becf", # blue-teal |
35 ] | 39 ] |
36 | 40 |
37 | 41 |
38 def visualize_pr_curve_plotly(df1, df2, pos_label, title=None): | 42 def visualize_pr_curve_plotly(df1, df2, pos_label, title=None): |
39 """output pr-curve in html using plotly | 43 """output pr-curve in html using plotly |
51 for idx in range(df1.shape[1]): | 55 for idx in range(df1.shape[1]): |
52 y_true = df1.iloc[:, idx].values | 56 y_true = df1.iloc[:, idx].values |
53 y_score = df2.iloc[:, idx].values | 57 y_score = df2.iloc[:, idx].values |
54 | 58 |
55 precision, recall, _ = precision_recall_curve( | 59 precision, recall, _ = precision_recall_curve( |
56 y_true, y_score, pos_label=pos_label) | 60 y_true, y_score, pos_label=pos_label |
57 ap = average_precision_score( | 61 ) |
58 y_true, y_score, pos_label=pos_label or 1) | 62 ap = average_precision_score(y_true, y_score, pos_label=pos_label or 1) |
59 | 63 |
60 trace = go.Scatter( | 64 trace = go.Scatter( |
61 x=recall, | 65 x=recall, |
62 y=precision, | 66 y=precision, |
63 mode='lines', | 67 mode="lines", |
64 marker=dict( | 68 marker=dict(color=default_colors[idx % len(default_colors)]), |
65 color=default_colors[idx % len(default_colors)] | 69 name="%s (area = %.3f)" % (idx, ap), |
66 ), | |
67 name='%s (area = %.3f)' % (idx, ap) | |
68 ) | 70 ) |
69 data.append(trace) | 71 data.append(trace) |
70 | 72 |
71 layout = go.Layout( | 73 layout = go.Layout( |
72 xaxis=dict( | 74 xaxis=dict(title="Recall", linecolor="lightslategray", linewidth=1), |
73 title='Recall', | 75 yaxis=dict(title="Precision", linecolor="lightslategray", linewidth=1), |
74 linecolor='lightslategray', | |
75 linewidth=1 | |
76 ), | |
77 yaxis=dict( | |
78 title='Precision', | |
79 linecolor='lightslategray', | |
80 linewidth=1 | |
81 ), | |
82 title=dict( | 76 title=dict( |
83 text=title or 'Precision-Recall Curve', | 77 text=title or "Precision-Recall Curve", |
84 x=0.5, | 78 x=0.5, |
85 y=0.92, | 79 y=0.92, |
86 xanchor='center', | 80 xanchor="center", |
87 yanchor='top' | 81 yanchor="top", |
88 ), | 82 ), |
89 font=dict( | 83 font=dict(family="sans-serif", size=11), |
90 family="sans-serif", | |
91 size=11 | |
92 ), | |
93 # control background colors | 84 # control background colors |
94 plot_bgcolor='rgba(255,255,255,0)' | 85 plot_bgcolor="rgba(255,255,255,0)", |
95 ) | 86 ) |
96 """ | 87 """ |
97 legend=dict( | 88 legend=dict( |
98 x=0.95, | 89 x=0.95, |
99 y=0, | 90 y=0, |
110 | 101 |
111 fig = go.Figure(data=data, layout=layout) | 102 fig = go.Figure(data=data, layout=layout) |
112 | 103 |
113 plotly.offline.plot(fig, filename="output.html", auto_open=False) | 104 plotly.offline.plot(fig, filename="output.html", auto_open=False) |
114 # to be discovered by `from_work_dir` | 105 # to be discovered by `from_work_dir` |
115 os.rename('output.html', 'output') | 106 os.rename("output.html", "output") |
116 | 107 |
117 | 108 |
118 def visualize_pr_curve_matplotlib(df1, df2, pos_label, title=None): | 109 def visualize_pr_curve_matplotlib(df1, df2, pos_label, title=None): |
119 """visualize pr-curve using matplotlib and output svg image | 110 """visualize pr-curve using matplotlib and output svg image""" |
120 """ | |
121 backend = matplotlib.get_backend() | 111 backend = matplotlib.get_backend() |
122 if "inline" not in backend: | 112 if "inline" not in backend: |
123 matplotlib.use("SVG") | 113 matplotlib.use("SVG") |
124 plt.style.use('seaborn-colorblind') | 114 plt.style.use("seaborn-colorblind") |
125 plt.figure() | 115 plt.figure() |
126 | 116 |
127 for idx in range(df1.shape[1]): | 117 for idx in range(df1.shape[1]): |
128 y_true = df1.iloc[:, idx].values | 118 y_true = df1.iloc[:, idx].values |
129 y_score = df2.iloc[:, idx].values | 119 y_score = df2.iloc[:, idx].values |
130 | 120 |
131 precision, recall, _ = precision_recall_curve( | 121 precision, recall, _ = precision_recall_curve( |
132 y_true, y_score, pos_label=pos_label) | 122 y_true, y_score, pos_label=pos_label |
133 ap = average_precision_score( | 123 ) |
134 y_true, y_score, pos_label=pos_label or 1) | 124 ap = average_precision_score(y_true, y_score, pos_label=pos_label or 1) |
135 | 125 |
136 plt.step(recall, precision, 'r-', color="black", alpha=0.3, | 126 plt.step( |
137 lw=1, where="post", label='%s (area = %.3f)' % (idx, ap)) | 127 recall, |
128 precision, | |
129 "r-", | |
130 color="black", | |
131 alpha=0.3, | |
132 lw=1, | |
133 where="post", | |
134 label="%s (area = %.3f)" % (idx, ap), | |
135 ) | |
138 | 136 |
139 plt.xlim([0.0, 1.0]) | 137 plt.xlim([0.0, 1.0]) |
140 plt.ylim([0.0, 1.05]) | 138 plt.ylim([0.0, 1.05]) |
141 plt.xlabel('Recall') | 139 plt.xlabel("Recall") |
142 plt.ylabel('Precision') | 140 plt.ylabel("Precision") |
143 title = title or 'Precision-Recall Curve' | 141 title = title or "Precision-Recall Curve" |
144 plt.title(title) | 142 plt.title(title) |
145 folder = os.getcwd() | 143 folder = os.getcwd() |
146 plt.savefig(os.path.join(folder, "output.svg"), format="svg") | 144 plt.savefig(os.path.join(folder, "output.svg"), format="svg") |
147 os.rename(os.path.join(folder, "output.svg"), | 145 os.rename(os.path.join(folder, "output.svg"), os.path.join(folder, "output")) |
148 os.path.join(folder, "output")) | 146 |
149 | 147 |
150 | 148 def visualize_roc_curve_plotly(df1, df2, pos_label, drop_intermediate=True, title=None): |
151 def visualize_roc_curve_plotly(df1, df2, pos_label, | |
152 drop_intermediate=True, | |
153 title=None): | |
154 """output roc-curve in html using plotly | 149 """output roc-curve in html using plotly |
155 | 150 |
156 df1 : pandas.DataFrame | 151 df1 : pandas.DataFrame |
157 Containing y_true | 152 Containing y_true |
158 df2 : pandas.DataFrame | 153 df2 : pandas.DataFrame |
167 data = [] | 162 data = [] |
168 for idx in range(df1.shape[1]): | 163 for idx in range(df1.shape[1]): |
169 y_true = df1.iloc[:, idx].values | 164 y_true = df1.iloc[:, idx].values |
170 y_score = df2.iloc[:, idx].values | 165 y_score = df2.iloc[:, idx].values |
171 | 166 |
172 fpr, tpr, _ = roc_curve(y_true, y_score, pos_label=pos_label, | 167 fpr, tpr, _ = roc_curve( |
173 drop_intermediate=drop_intermediate) | 168 y_true, y_score, pos_label=pos_label, drop_intermediate=drop_intermediate |
169 ) | |
174 roc_auc = auc(fpr, tpr) | 170 roc_auc = auc(fpr, tpr) |
175 | 171 |
176 trace = go.Scatter( | 172 trace = go.Scatter( |
177 x=fpr, | 173 x=fpr, |
178 y=tpr, | 174 y=tpr, |
179 mode='lines', | 175 mode="lines", |
180 marker=dict( | 176 marker=dict(color=default_colors[idx % len(default_colors)]), |
181 color=default_colors[idx % len(default_colors)] | 177 name="%s (area = %.3f)" % (idx, roc_auc), |
182 ), | |
183 name='%s (area = %.3f)' % (idx, roc_auc) | |
184 ) | 178 ) |
185 data.append(trace) | 179 data.append(trace) |
186 | 180 |
187 layout = go.Layout( | 181 layout = go.Layout( |
188 xaxis=dict( | 182 xaxis=dict( |
189 title='False Positive Rate', | 183 title="False Positive Rate", linecolor="lightslategray", linewidth=1 |
190 linecolor='lightslategray', | |
191 linewidth=1 | |
192 ), | 184 ), |
193 yaxis=dict( | 185 yaxis=dict(title="True Positive Rate", linecolor="lightslategray", linewidth=1), |
194 title='True Positive Rate', | |
195 linecolor='lightslategray', | |
196 linewidth=1 | |
197 ), | |
198 title=dict( | 186 title=dict( |
199 text=title or 'Receiver Operating Characteristic (ROC) Curve', | 187 text=title or "Receiver Operating Characteristic (ROC) Curve", |
200 x=0.5, | 188 x=0.5, |
201 y=0.92, | 189 y=0.92, |
202 xanchor='center', | 190 xanchor="center", |
203 yanchor='top' | 191 yanchor="top", |
204 ), | 192 ), |
205 font=dict( | 193 font=dict(family="sans-serif", size=11), |
206 family="sans-serif", | |
207 size=11 | |
208 ), | |
209 # control background colors | 194 # control background colors |
210 plot_bgcolor='rgba(255,255,255,0)' | 195 plot_bgcolor="rgba(255,255,255,0)", |
211 ) | 196 ) |
212 """ | 197 """ |
213 # legend=dict( | 198 # legend=dict( |
214 # x=0.95, | 199 # x=0.95, |
215 # y=0, | 200 # y=0, |
227 | 212 |
228 fig = go.Figure(data=data, layout=layout) | 213 fig = go.Figure(data=data, layout=layout) |
229 | 214 |
230 plotly.offline.plot(fig, filename="output.html", auto_open=False) | 215 plotly.offline.plot(fig, filename="output.html", auto_open=False) |
231 # to be discovered by `from_work_dir` | 216 # to be discovered by `from_work_dir` |
232 os.rename('output.html', 'output') | 217 os.rename("output.html", "output") |
233 | 218 |
234 | 219 |
235 def visualize_roc_curve_matplotlib(df1, df2, pos_label, | 220 def visualize_roc_curve_matplotlib( |
236 drop_intermediate=True, | 221 df1, df2, pos_label, drop_intermediate=True, title=None |
237 title=None): | 222 ): |
238 """visualize roc-curve using matplotlib and output svg image | 223 """visualize roc-curve using matplotlib and output svg image""" |
239 """ | |
240 backend = matplotlib.get_backend() | 224 backend = matplotlib.get_backend() |
241 if "inline" not in backend: | 225 if "inline" not in backend: |
242 matplotlib.use("SVG") | 226 matplotlib.use("SVG") |
243 plt.style.use('seaborn-colorblind') | 227 plt.style.use("seaborn-colorblind") |
244 plt.figure() | 228 plt.figure() |
245 | 229 |
246 for idx in range(df1.shape[1]): | 230 for idx in range(df1.shape[1]): |
247 y_true = df1.iloc[:, idx].values | 231 y_true = df1.iloc[:, idx].values |
248 y_score = df2.iloc[:, idx].values | 232 y_score = df2.iloc[:, idx].values |
249 | 233 |
250 fpr, tpr, _ = roc_curve(y_true, y_score, pos_label=pos_label, | 234 fpr, tpr, _ = roc_curve( |
251 drop_intermediate=drop_intermediate) | 235 y_true, y_score, pos_label=pos_label, drop_intermediate=drop_intermediate |
236 ) | |
252 roc_auc = auc(fpr, tpr) | 237 roc_auc = auc(fpr, tpr) |
253 | 238 |
254 plt.step(fpr, tpr, 'r-', color="black", alpha=0.3, lw=1, | 239 plt.step( |
255 where="post", label='%s (area = %.3f)' % (idx, roc_auc)) | 240 fpr, |
241 tpr, | |
242 "r-", | |
243 color="black", | |
244 alpha=0.3, | |
245 lw=1, | |
246 where="post", | |
247 label="%s (area = %.3f)" % (idx, roc_auc), | |
248 ) | |
256 | 249 |
257 plt.xlim([0.0, 1.0]) | 250 plt.xlim([0.0, 1.0]) |
258 plt.ylim([0.0, 1.05]) | 251 plt.ylim([0.0, 1.05]) |
259 plt.xlabel('False Positive Rate') | 252 plt.xlabel("False Positive Rate") |
260 plt.ylabel('True Positive Rate') | 253 plt.ylabel("True Positive Rate") |
261 title = title or 'Receiver Operating Characteristic (ROC) Curve' | 254 title = title or "Receiver Operating Characteristic (ROC) Curve" |
262 plt.title(title) | 255 plt.title(title) |
263 folder = os.getcwd() | 256 folder = os.getcwd() |
264 plt.savefig(os.path.join(folder, "output.svg"), format="svg") | 257 plt.savefig(os.path.join(folder, "output.svg"), format="svg") |
265 os.rename(os.path.join(folder, "output.svg"), | 258 os.rename(os.path.join(folder, "output.svg"), os.path.join(folder, "output")) |
266 os.path.join(folder, "output")) | 259 |
267 | 260 |
268 | 261 def main( |
269 def main(inputs, infile_estimator=None, infile1=None, | 262 inputs, |
270 infile2=None, outfile_result=None, | 263 infile_estimator=None, |
271 outfile_object=None, groups=None, | 264 infile1=None, |
272 ref_seq=None, intervals=None, | 265 infile2=None, |
273 targets=None, fasta_path=None, | 266 outfile_result=None, |
274 model_config=None): | 267 outfile_object=None, |
268 groups=None, | |
269 ref_seq=None, | |
270 intervals=None, | |
271 targets=None, | |
272 fasta_path=None, | |
273 model_config=None, | |
274 ): | |
275 """ | 275 """ |
276 Parameter | 276 Parameter |
277 --------- | 277 --------- |
278 inputs : str | 278 inputs : str |
279 File path to galaxy tool parameter | 279 File path to galaxy tool parameter |
310 File path to dataset containing fasta file | 310 File path to dataset containing fasta file |
311 | 311 |
312 model_config : str, default is None | 312 model_config : str, default is None |
313 File path to dataset containing JSON config for neural networks | 313 File path to dataset containing JSON config for neural networks |
314 """ | 314 """ |
315 warnings.simplefilter('ignore') | 315 warnings.simplefilter("ignore") |
316 | 316 |
317 with open(inputs, 'r') as param_handler: | 317 with open(inputs, "r") as param_handler: |
318 params = json.load(param_handler) | 318 params = json.load(param_handler) |
319 | 319 |
320 title = params['plotting_selection']['title'].strip() | 320 title = params["plotting_selection"]["title"].strip() |
321 plot_type = params['plotting_selection']['plot_type'] | 321 plot_type = params["plotting_selection"]["plot_type"] |
322 plot_format = params['plotting_selection']['plot_format'] | 322 plot_format = params["plotting_selection"]["plot_format"] |
323 | 323 |
324 if plot_type == 'feature_importances': | 324 if plot_type == "feature_importances": |
325 with open(infile_estimator, 'rb') as estimator_handler: | 325 estimator = load_model_from_h5(infile_estimator) |
326 estimator = load_model(estimator_handler) | 326 |
327 | 327 column_option = params["plotting_selection"]["column_selector_options"][ |
328 column_option = (params['plotting_selection'] | 328 "selected_column_selector_option" |
329 ['column_selector_options'] | 329 ] |
330 ['selected_column_selector_option']) | 330 if column_option in [ |
331 if column_option in ['by_index_number', 'all_but_by_index_number', | 331 "by_index_number", |
332 'by_header_name', 'all_but_by_header_name']: | 332 "all_but_by_index_number", |
333 c = (params['plotting_selection'] | 333 "by_header_name", |
334 ['column_selector_options']['col1']) | 334 "all_but_by_header_name", |
335 ]: | |
336 c = params["plotting_selection"]["column_selector_options"]["col1"] | |
335 else: | 337 else: |
336 c = None | 338 c = None |
337 | 339 |
338 _, input_df = read_columns(infile1, c=c, | 340 _, input_df = read_columns( |
339 c_option=column_option, | 341 infile1, |
340 return_df=True, | 342 c=c, |
341 sep='\t', header='infer', | 343 c_option=column_option, |
342 parse_dates=True) | 344 return_df=True, |
345 sep="\t", | |
346 header="infer", | |
347 parse_dates=True, | |
348 ) | |
343 | 349 |
344 feature_names = input_df.columns.values | 350 feature_names = input_df.columns.values |
345 | 351 |
346 if isinstance(estimator, Pipeline): | 352 if isinstance(estimator, Pipeline): |
347 for st in estimator.steps[:-1]: | 353 for st in estimator.steps[:-1]: |
348 if isinstance(st[-1], SelectorMixin): | 354 if isinstance(st[-1], SelectorMixin): |
349 mask = st[-1].get_support() | 355 mask = st[-1].get_support() |
350 feature_names = feature_names[mask] | 356 feature_names = feature_names[mask] |
351 estimator = estimator.steps[-1][-1] | 357 estimator = estimator.steps[-1][-1] |
352 | 358 |
353 if hasattr(estimator, 'coef_'): | 359 if hasattr(estimator, "coef_"): |
354 coefs = estimator.coef_ | 360 coefs = estimator.coef_ |
355 else: | 361 else: |
356 coefs = getattr(estimator, 'feature_importances_', None) | 362 coefs = getattr(estimator, "feature_importances_", None) |
357 if coefs is None: | 363 if coefs is None: |
358 raise RuntimeError('The classifier does not expose ' | 364 raise RuntimeError( |
359 '"coef_" or "feature_importances_" ' | 365 "The classifier does not expose " |
360 'attributes') | 366 '"coef_" or "feature_importances_" ' |
361 | 367 "attributes" |
362 threshold = params['plotting_selection']['threshold'] | 368 ) |
369 | |
370 threshold = params["plotting_selection"]["threshold"] | |
363 if threshold is not None: | 371 if threshold is not None: |
364 mask = (coefs > threshold) | (coefs < -threshold) | 372 mask = (coefs > threshold) | (coefs < -threshold) |
365 coefs = coefs[mask] | 373 coefs = coefs[mask] |
366 feature_names = feature_names[mask] | 374 feature_names = feature_names[mask] |
367 | 375 |
368 # sort | 376 # sort |
369 indices = np.argsort(coefs)[::-1] | 377 indices = np.argsort(coefs)[::-1] |
370 | 378 |
371 trace = go.Bar(x=feature_names[indices], | 379 trace = go.Bar(x=feature_names[indices], y=coefs[indices]) |
372 y=coefs[indices]) | |
373 layout = go.Layout(title=title or "Feature Importances") | 380 layout = go.Layout(title=title or "Feature Importances") |
374 fig = go.Figure(data=[trace], layout=layout) | 381 fig = go.Figure(data=[trace], layout=layout) |
375 | 382 |
376 plotly.offline.plot(fig, filename="output.html", | 383 plotly.offline.plot(fig, filename="output.html", auto_open=False) |
377 auto_open=False) | |
378 # to be discovered by `from_work_dir` | 384 # to be discovered by `from_work_dir` |
379 os.rename('output.html', 'output') | 385 os.rename("output.html", "output") |
380 | 386 |
381 return 0 | 387 return 0 |
382 | 388 |
383 elif plot_type in ('pr_curve', 'roc_curve'): | 389 elif plot_type in ("pr_curve", "roc_curve"): |
384 df1 = pd.read_csv(infile1, sep='\t', header='infer') | 390 df1 = pd.read_csv(infile1, sep="\t", header="infer") |
385 df2 = pd.read_csv(infile2, sep='\t', header='infer').astype(np.float32) | 391 df2 = pd.read_csv(infile2, sep="\t", header="infer").astype(np.float32) |
386 | 392 |
387 minimum = params['plotting_selection']['report_minimum_n_positives'] | 393 minimum = params["plotting_selection"]["report_minimum_n_positives"] |
388 # filter out columns whose n_positives is below the threshold | 394 # filter out columns whose n_positives is below the threshold |
389 if minimum: | 395 if minimum: |
390 mask = df1.sum(axis=0) >= minimum | 396 mask = df1.sum(axis=0) >= minimum |
391 df1 = df1.loc[:, mask] | 397 df1 = df1.loc[:, mask] |
392 df2 = df2.loc[:, mask] | 398 df2 = df2.loc[:, mask] |
393 | 399 |
394 pos_label = params['plotting_selection']['pos_label'].strip() \ | 400 pos_label = params["plotting_selection"]["pos_label"].strip() or None |
395 or None | 401 |
396 | 402 if plot_type == "pr_curve": |
397 if plot_type == 'pr_curve': | 403 if plot_format == "plotly_html": |
398 if plot_format == 'plotly_html': | |
399 visualize_pr_curve_plotly(df1, df2, pos_label, title=title) | 404 visualize_pr_curve_plotly(df1, df2, pos_label, title=title) |
400 else: | 405 else: |
401 visualize_pr_curve_matplotlib(df1, df2, pos_label, title) | 406 visualize_pr_curve_matplotlib(df1, df2, pos_label, title) |
402 else: # 'roc_curve' | 407 else: # 'roc_curve' |
403 drop_intermediate = (params['plotting_selection'] | 408 drop_intermediate = params["plotting_selection"]["drop_intermediate"] |
404 ['drop_intermediate']) | 409 if plot_format == "plotly_html": |
405 if plot_format == 'plotly_html': | 410 visualize_roc_curve_plotly( |
406 visualize_roc_curve_plotly(df1, df2, pos_label, | 411 df1, |
407 drop_intermediate=drop_intermediate, | 412 df2, |
408 title=title) | 413 pos_label, |
414 drop_intermediate=drop_intermediate, | |
415 title=title, | |
416 ) | |
409 else: | 417 else: |
410 visualize_roc_curve_matplotlib( | 418 visualize_roc_curve_matplotlib( |
411 df1, df2, pos_label, | 419 df1, |
420 df2, | |
421 pos_label, | |
412 drop_intermediate=drop_intermediate, | 422 drop_intermediate=drop_intermediate, |
413 title=title) | 423 title=title, |
424 ) | |
414 | 425 |
415 return 0 | 426 return 0 |
416 | 427 |
417 elif plot_type == 'rfecv_gridscores': | 428 elif plot_type == "rfecv_gridscores": |
418 input_df = pd.read_csv(infile1, sep='\t', header='infer') | 429 input_df = pd.read_csv(infile1, sep="\t", header="infer") |
419 scores = input_df.iloc[:, 0] | 430 scores = input_df.iloc[:, 0] |
420 steps = params['plotting_selection']['steps'].strip() | 431 steps = params["plotting_selection"]["steps"].strip() |
421 steps = safe_eval(steps) | 432 steps = safe_eval(steps) |
422 | 433 |
423 data = go.Scatter( | 434 data = go.Scatter( |
424 x=list(range(len(scores))), | 435 x=list(range(len(scores))), |
425 y=scores, | 436 y=scores, |
426 text=[str(_) for _ in steps] if steps else None, | 437 text=[str(_) for _ in steps] if steps else None, |
427 mode='lines' | 438 mode="lines", |
428 ) | 439 ) |
429 layout = go.Layout( | 440 layout = go.Layout( |
430 xaxis=dict(title="Number of features selected"), | 441 xaxis=dict(title="Number of features selected"), |
431 yaxis=dict(title="Cross validation score"), | 442 yaxis=dict(title="Cross validation score"), |
432 title=dict( | 443 title=dict( |
433 text=title or None, | 444 text=title or None, x=0.5, y=0.92, xanchor="center", yanchor="top" |
434 x=0.5, | |
435 y=0.92, | |
436 xanchor='center', | |
437 yanchor='top' | |
438 ), | 445 ), |
439 font=dict( | 446 font=dict(family="sans-serif", size=11), |
440 family="sans-serif", | |
441 size=11 | |
442 ), | |
443 # control background colors | 447 # control background colors |
444 plot_bgcolor='rgba(255,255,255,0)' | 448 plot_bgcolor="rgba(255,255,255,0)", |
445 ) | 449 ) |
446 """ | 450 """ |
447 # legend=dict( | 451 # legend=dict( |
448 # x=0.95, | 452 # x=0.95, |
449 # y=0, | 453 # y=0, |
458 # borderwidth=2 | 462 # borderwidth=2 |
459 # ), | 463 # ), |
460 """ | 464 """ |
461 | 465 |
462 fig = go.Figure(data=[data], layout=layout) | 466 fig = go.Figure(data=[data], layout=layout) |
463 plotly.offline.plot(fig, filename="output.html", | 467 plotly.offline.plot(fig, filename="output.html", auto_open=False) |
464 auto_open=False) | |
465 # to be discovered by `from_work_dir` | 468 # to be discovered by `from_work_dir` |
466 os.rename('output.html', 'output') | 469 os.rename("output.html", "output") |
467 | 470 |
468 return 0 | 471 return 0 |
469 | 472 |
470 elif plot_type == 'learning_curve': | 473 elif plot_type == "learning_curve": |
471 input_df = pd.read_csv(infile1, sep='\t', header='infer') | 474 input_df = pd.read_csv(infile1, sep="\t", header="infer") |
472 plot_std_err = params['plotting_selection']['plot_std_err'] | 475 plot_std_err = params["plotting_selection"]["plot_std_err"] |
473 data1 = go.Scatter( | 476 data1 = go.Scatter( |
474 x=input_df['train_sizes_abs'], | 477 x=input_df["train_sizes_abs"], |
475 y=input_df['mean_train_scores'], | 478 y=input_df["mean_train_scores"], |
476 error_y=dict( | 479 error_y=dict(array=input_df["std_train_scores"]) if plot_std_err else None, |
477 array=input_df['std_train_scores'] | 480 mode="lines", |
478 ) if plot_std_err else None, | |
479 mode='lines', | |
480 name="Train Scores", | 481 name="Train Scores", |
481 ) | 482 ) |
482 data2 = go.Scatter( | 483 data2 = go.Scatter( |
483 x=input_df['train_sizes_abs'], | 484 x=input_df["train_sizes_abs"], |
484 y=input_df['mean_test_scores'], | 485 y=input_df["mean_test_scores"], |
485 error_y=dict( | 486 error_y=dict(array=input_df["std_test_scores"]) if plot_std_err else None, |
486 array=input_df['std_test_scores'] | 487 mode="lines", |
487 ) if plot_std_err else None, | |
488 mode='lines', | |
489 name="Test Scores", | 488 name="Test Scores", |
490 ) | 489 ) |
491 layout = dict( | 490 layout = dict( |
492 xaxis=dict( | 491 xaxis=dict(title="No. of samples"), |
493 title='No. of samples' | 492 yaxis=dict(title="Performance Score"), |
494 ), | |
495 yaxis=dict( | |
496 title='Performance Score' | |
497 ), | |
498 # modify these configurations to customize image | 493 # modify these configurations to customize image |
499 title=dict( | 494 title=dict( |
500 text=title or 'Learning Curve', | 495 text=title or "Learning Curve", |
501 x=0.5, | 496 x=0.5, |
502 y=0.92, | 497 y=0.92, |
503 xanchor='center', | 498 xanchor="center", |
504 yanchor='top' | 499 yanchor="top", |
505 ), | 500 ), |
506 font=dict( | 501 font=dict(family="sans-serif", size=11), |
507 family="sans-serif", | |
508 size=11 | |
509 ), | |
510 # control background colors | 502 # control background colors |
511 plot_bgcolor='rgba(255,255,255,0)' | 503 plot_bgcolor="rgba(255,255,255,0)", |
512 ) | 504 ) |
513 """ | 505 """ |
514 # legend=dict( | 506 # legend=dict( |
515 # x=0.95, | 507 # x=0.95, |
516 # y=0, | 508 # y=0, |
525 # borderwidth=2 | 517 # borderwidth=2 |
526 # ), | 518 # ), |
527 """ | 519 """ |
528 | 520 |
529 fig = go.Figure(data=[data1, data2], layout=layout) | 521 fig = go.Figure(data=[data1, data2], layout=layout) |
530 plotly.offline.plot(fig, filename="output.html", | 522 plotly.offline.plot(fig, filename="output.html", auto_open=False) |
531 auto_open=False) | |
532 # to be discovered by `from_work_dir` | 523 # to be discovered by `from_work_dir` |
533 os.rename('output.html', 'output') | 524 os.rename("output.html", "output") |
534 | 525 |
535 return 0 | 526 return 0 |
536 | 527 |
537 elif plot_type == 'keras_plot_model': | 528 elif plot_type == "keras_plot_model": |
538 with open(model_config, 'r') as f: | 529 with open(model_config, "r") as f: |
539 model_str = f.read() | 530 model_str = f.read() |
540 model = model_from_json(model_str) | 531 model = model_from_json(model_str) |
541 plot_model(model, to_file="output.png") | 532 plot_model(model, to_file="output.png") |
542 os.rename('output.png', 'output') | 533 os.rename("output.png", "output") |
543 | 534 |
544 return 0 | 535 return 0 |
545 | 536 |
546 # save pdf file to disk | 537 # save pdf file to disk |
547 # fig.write_image("image.pdf", format='pdf') | 538 # fig.write_image("image.pdf", format='pdf') |
548 # fig.write_image("image.pdf", format='pdf', width=340*2, height=226*2) | 539 # fig.write_image("image.pdf", format='pdf', width=340*2, height=226*2) |
549 | 540 |
550 | 541 |
551 if __name__ == '__main__': | 542 if __name__ == "__main__": |
552 aparser = argparse.ArgumentParser() | 543 aparser = argparse.ArgumentParser() |
553 aparser.add_argument("-i", "--inputs", dest="inputs", required=True) | 544 aparser.add_argument("-i", "--inputs", dest="inputs", required=True) |
554 aparser.add_argument("-e", "--estimator", dest="infile_estimator") | 545 aparser.add_argument("-e", "--estimator", dest="infile_estimator") |
555 aparser.add_argument("-X", "--infile1", dest="infile1") | 546 aparser.add_argument("-X", "--infile1", dest="infile1") |
556 aparser.add_argument("-y", "--infile2", dest="infile2") | 547 aparser.add_argument("-y", "--infile2", dest="infile2") |
562 aparser.add_argument("-t", "--targets", dest="targets") | 553 aparser.add_argument("-t", "--targets", dest="targets") |
563 aparser.add_argument("-f", "--fasta_path", dest="fasta_path") | 554 aparser.add_argument("-f", "--fasta_path", dest="fasta_path") |
564 aparser.add_argument("-c", "--model_config", dest="model_config") | 555 aparser.add_argument("-c", "--model_config", dest="model_config") |
565 args = aparser.parse_args() | 556 args = aparser.parse_args() |
566 | 557 |
567 main(args.inputs, args.infile_estimator, args.infile1, args.infile2, | 558 main( |
568 args.outfile_result, outfile_object=args.outfile_object, | 559 args.inputs, |
569 groups=args.groups, ref_seq=args.ref_seq, intervals=args.intervals, | 560 args.infile_estimator, |
570 targets=args.targets, fasta_path=args.fasta_path, | 561 args.infile1, |
571 model_config=args.model_config) | 562 args.infile2, |
563 args.outfile_result, | |
564 outfile_object=args.outfile_object, | |
565 groups=args.groups, | |
566 ref_seq=args.ref_seq, | |
567 intervals=args.intervals, | |
568 targets=args.targets, | |
569 fasta_path=args.fasta_path, | |
570 model_config=args.model_config, | |
571 ) |