comparison ml_visualization_ex.py @ 3:0a1812986bc3 draft

planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit 9981e25b00de29ed881b2229a173a8c812ded9bb
author bgruening
date Wed, 09 Aug 2023 11:10:37 +0000
parents 38c4f8a98038
children ba7fb6b33cd0
comparison
equal deleted inserted replaced
2:38c4f8a98038 3:0a1812986bc3
1 import argparse 1 import argparse
2 import json 2 import json
3 import os
4 import warnings
5
3 import matplotlib 6 import matplotlib
4 import matplotlib.pyplot as plt 7 import matplotlib.pyplot as plt
5 import numpy as np 8 import numpy as np
6 import os
7 import pandas as pd 9 import pandas as pd
8 import plotly 10 import plotly
9 import plotly.graph_objs as go 11 import plotly.graph_objs as go
10 import warnings 12 from galaxy_ml.model_persist import load_model_from_h5
11 13 from galaxy_ml.utils import read_columns, SafeEval
12 from keras.models import model_from_json 14 from sklearn.feature_selection._base import SelectorMixin
13 from keras.utils import plot_model 15 from sklearn.metrics import (
14 from sklearn.feature_selection.base import SelectorMixin 16 auc,
15 from sklearn.metrics import precision_recall_curve, average_precision_score 17 average_precision_score,
16 from sklearn.metrics import roc_curve, auc 18 precision_recall_curve,
19 roc_curve,
20 )
17 from sklearn.pipeline import Pipeline 21 from sklearn.pipeline import Pipeline
18 from galaxy_ml.utils import load_model, read_columns, SafeEval 22 from tensorflow.keras.models import model_from_json
19 23 from tensorflow.keras.utils import plot_model
20 24
21 safe_eval = SafeEval() 25 safe_eval = SafeEval()
22 26
23 # plotly default colors 27 # plotly default colors
24 default_colors = [ 28 default_colors = [
25 '#1f77b4', # muted blue 29 "#1f77b4", # muted blue
26 '#ff7f0e', # safety orange 30 "#ff7f0e", # safety orange
27 '#2ca02c', # cooked asparagus green 31 "#2ca02c", # cooked asparagus green
28 '#d62728', # brick red 32 "#d62728", # brick red
29 '#9467bd', # muted purple 33 "#9467bd", # muted purple
30 '#8c564b', # chestnut brown 34 "#8c564b", # chestnut brown
31 '#e377c2', # raspberry yogurt pink 35 "#e377c2", # raspberry yogurt pink
32 '#7f7f7f', # middle gray 36 "#7f7f7f", # middle gray
33 '#bcbd22', # curry yellow-green 37 "#bcbd22", # curry yellow-green
34 '#17becf' # blue-teal 38 "#17becf", # blue-teal
35 ] 39 ]
36 40
37 41
38 def visualize_pr_curve_plotly(df1, df2, pos_label, title=None): 42 def visualize_pr_curve_plotly(df1, df2, pos_label, title=None):
39 """output pr-curve in html using plotly 43 """output pr-curve in html using plotly
51 for idx in range(df1.shape[1]): 55 for idx in range(df1.shape[1]):
52 y_true = df1.iloc[:, idx].values 56 y_true = df1.iloc[:, idx].values
53 y_score = df2.iloc[:, idx].values 57 y_score = df2.iloc[:, idx].values
54 58
55 precision, recall, _ = precision_recall_curve( 59 precision, recall, _ = precision_recall_curve(
56 y_true, y_score, pos_label=pos_label) 60 y_true, y_score, pos_label=pos_label
57 ap = average_precision_score( 61 )
58 y_true, y_score, pos_label=pos_label or 1) 62 ap = average_precision_score(y_true, y_score, pos_label=pos_label or 1)
59 63
60 trace = go.Scatter( 64 trace = go.Scatter(
61 x=recall, 65 x=recall,
62 y=precision, 66 y=precision,
63 mode='lines', 67 mode="lines",
64 marker=dict( 68 marker=dict(color=default_colors[idx % len(default_colors)]),
65 color=default_colors[idx % len(default_colors)] 69 name="%s (area = %.3f)" % (idx, ap),
66 ),
67 name='%s (area = %.3f)' % (idx, ap)
68 ) 70 )
69 data.append(trace) 71 data.append(trace)
70 72
71 layout = go.Layout( 73 layout = go.Layout(
72 xaxis=dict( 74 xaxis=dict(title="Recall", linecolor="lightslategray", linewidth=1),
73 title='Recall', 75 yaxis=dict(title="Precision", linecolor="lightslategray", linewidth=1),
74 linecolor='lightslategray',
75 linewidth=1
76 ),
77 yaxis=dict(
78 title='Precision',
79 linecolor='lightslategray',
80 linewidth=1
81 ),
82 title=dict( 76 title=dict(
83 text=title or 'Precision-Recall Curve', 77 text=title or "Precision-Recall Curve",
84 x=0.5, 78 x=0.5,
85 y=0.92, 79 y=0.92,
86 xanchor='center', 80 xanchor="center",
87 yanchor='top' 81 yanchor="top",
88 ), 82 ),
89 font=dict( 83 font=dict(family="sans-serif", size=11),
90 family="sans-serif",
91 size=11
92 ),
93 # control backgroud colors 84 # control backgroud colors
94 plot_bgcolor='rgba(255,255,255,0)' 85 plot_bgcolor="rgba(255,255,255,0)",
95 ) 86 )
96 """ 87 """
97 legend=dict( 88 legend=dict(
98 x=0.95, 89 x=0.95,
99 y=0, 90 y=0,
110 101
111 fig = go.Figure(data=data, layout=layout) 102 fig = go.Figure(data=data, layout=layout)
112 103
113 plotly.offline.plot(fig, filename="output.html", auto_open=False) 104 plotly.offline.plot(fig, filename="output.html", auto_open=False)
114 # to be discovered by `from_work_dir` 105 # to be discovered by `from_work_dir`
115 os.rename('output.html', 'output') 106 os.rename("output.html", "output")
116 107
117 108
118 def visualize_pr_curve_matplotlib(df1, df2, pos_label, title=None): 109 def visualize_pr_curve_matplotlib(df1, df2, pos_label, title=None):
119 """visualize pr-curve using matplotlib and output svg image 110 """visualize pr-curve using matplotlib and output svg image"""
120 """
121 backend = matplotlib.get_backend() 111 backend = matplotlib.get_backend()
122 if "inline" not in backend: 112 if "inline" not in backend:
123 matplotlib.use("SVG") 113 matplotlib.use("SVG")
124 plt.style.use('seaborn-colorblind') 114 plt.style.use("seaborn-colorblind")
125 plt.figure() 115 plt.figure()
126 116
127 for idx in range(df1.shape[1]): 117 for idx in range(df1.shape[1]):
128 y_true = df1.iloc[:, idx].values 118 y_true = df1.iloc[:, idx].values
129 y_score = df2.iloc[:, idx].values 119 y_score = df2.iloc[:, idx].values
130 120
131 precision, recall, _ = precision_recall_curve( 121 precision, recall, _ = precision_recall_curve(
132 y_true, y_score, pos_label=pos_label) 122 y_true, y_score, pos_label=pos_label
133 ap = average_precision_score( 123 )
134 y_true, y_score, pos_label=pos_label or 1) 124 ap = average_precision_score(y_true, y_score, pos_label=pos_label or 1)
135 125
136 plt.step(recall, precision, 'r-', color="black", alpha=0.3, 126 plt.step(
137 lw=1, where="post", label='%s (area = %.3f)' % (idx, ap)) 127 recall,
128 precision,
129 "r-",
130 color="black",
131 alpha=0.3,
132 lw=1,
133 where="post",
134 label="%s (area = %.3f)" % (idx, ap),
135 )
138 136
139 plt.xlim([0.0, 1.0]) 137 plt.xlim([0.0, 1.0])
140 plt.ylim([0.0, 1.05]) 138 plt.ylim([0.0, 1.05])
141 plt.xlabel('Recall') 139 plt.xlabel("Recall")
142 plt.ylabel('Precision') 140 plt.ylabel("Precision")
143 title = title or 'Precision-Recall Curve' 141 title = title or "Precision-Recall Curve"
144 plt.title(title) 142 plt.title(title)
145 folder = os.getcwd() 143 folder = os.getcwd()
146 plt.savefig(os.path.join(folder, "output.svg"), format="svg") 144 plt.savefig(os.path.join(folder, "output.svg"), format="svg")
147 os.rename(os.path.join(folder, "output.svg"), 145 os.rename(os.path.join(folder, "output.svg"), os.path.join(folder, "output"))
148 os.path.join(folder, "output")) 146
149 147
150 148 def visualize_roc_curve_plotly(df1, df2, pos_label, drop_intermediate=True, title=None):
151 def visualize_roc_curve_plotly(df1, df2, pos_label,
152 drop_intermediate=True,
153 title=None):
154 """output roc-curve in html using plotly 149 """output roc-curve in html using plotly
155 150
156 df1 : pandas.DataFrame 151 df1 : pandas.DataFrame
157 Containing y_true 152 Containing y_true
158 df2 : pandas.DataFrame 153 df2 : pandas.DataFrame
167 data = [] 162 data = []
168 for idx in range(df1.shape[1]): 163 for idx in range(df1.shape[1]):
169 y_true = df1.iloc[:, idx].values 164 y_true = df1.iloc[:, idx].values
170 y_score = df2.iloc[:, idx].values 165 y_score = df2.iloc[:, idx].values
171 166
172 fpr, tpr, _ = roc_curve(y_true, y_score, pos_label=pos_label, 167 fpr, tpr, _ = roc_curve(
173 drop_intermediate=drop_intermediate) 168 y_true, y_score, pos_label=pos_label, drop_intermediate=drop_intermediate
169 )
174 roc_auc = auc(fpr, tpr) 170 roc_auc = auc(fpr, tpr)
175 171
176 trace = go.Scatter( 172 trace = go.Scatter(
177 x=fpr, 173 x=fpr,
178 y=tpr, 174 y=tpr,
179 mode='lines', 175 mode="lines",
180 marker=dict( 176 marker=dict(color=default_colors[idx % len(default_colors)]),
181 color=default_colors[idx % len(default_colors)] 177 name="%s (area = %.3f)" % (idx, roc_auc),
182 ),
183 name='%s (area = %.3f)' % (idx, roc_auc)
184 ) 178 )
185 data.append(trace) 179 data.append(trace)
186 180
187 layout = go.Layout( 181 layout = go.Layout(
188 xaxis=dict( 182 xaxis=dict(
189 title='False Positive Rate', 183 title="False Positive Rate", linecolor="lightslategray", linewidth=1
190 linecolor='lightslategray',
191 linewidth=1
192 ), 184 ),
193 yaxis=dict( 185 yaxis=dict(title="True Positive Rate", linecolor="lightslategray", linewidth=1),
194 title='True Positive Rate',
195 linecolor='lightslategray',
196 linewidth=1
197 ),
198 title=dict( 186 title=dict(
199 text=title or 'Receiver Operating Characteristic (ROC) Curve', 187 text=title or "Receiver Operating Characteristic (ROC) Curve",
200 x=0.5, 188 x=0.5,
201 y=0.92, 189 y=0.92,
202 xanchor='center', 190 xanchor="center",
203 yanchor='top' 191 yanchor="top",
204 ), 192 ),
205 font=dict( 193 font=dict(family="sans-serif", size=11),
206 family="sans-serif",
207 size=11
208 ),
209 # control backgroud colors 194 # control backgroud colors
210 plot_bgcolor='rgba(255,255,255,0)' 195 plot_bgcolor="rgba(255,255,255,0)",
211 ) 196 )
212 """ 197 """
213 # legend=dict( 198 # legend=dict(
214 # x=0.95, 199 # x=0.95,
215 # y=0, 200 # y=0,
227 212
228 fig = go.Figure(data=data, layout=layout) 213 fig = go.Figure(data=data, layout=layout)
229 214
230 plotly.offline.plot(fig, filename="output.html", auto_open=False) 215 plotly.offline.plot(fig, filename="output.html", auto_open=False)
231 # to be discovered by `from_work_dir` 216 # to be discovered by `from_work_dir`
232 os.rename('output.html', 'output') 217 os.rename("output.html", "output")
233 218
234 219
235 def visualize_roc_curve_matplotlib(df1, df2, pos_label, 220 def visualize_roc_curve_matplotlib(
236 drop_intermediate=True, 221 df1, df2, pos_label, drop_intermediate=True, title=None
237 title=None): 222 ):
238 """visualize roc-curve using matplotlib and output svg image 223 """visualize roc-curve using matplotlib and output svg image"""
239 """
240 backend = matplotlib.get_backend() 224 backend = matplotlib.get_backend()
241 if "inline" not in backend: 225 if "inline" not in backend:
242 matplotlib.use("SVG") 226 matplotlib.use("SVG")
243 plt.style.use('seaborn-colorblind') 227 plt.style.use("seaborn-colorblind")
244 plt.figure() 228 plt.figure()
245 229
246 for idx in range(df1.shape[1]): 230 for idx in range(df1.shape[1]):
247 y_true = df1.iloc[:, idx].values 231 y_true = df1.iloc[:, idx].values
248 y_score = df2.iloc[:, idx].values 232 y_score = df2.iloc[:, idx].values
249 233
250 fpr, tpr, _ = roc_curve(y_true, y_score, pos_label=pos_label, 234 fpr, tpr, _ = roc_curve(
251 drop_intermediate=drop_intermediate) 235 y_true, y_score, pos_label=pos_label, drop_intermediate=drop_intermediate
236 )
252 roc_auc = auc(fpr, tpr) 237 roc_auc = auc(fpr, tpr)
253 238
254 plt.step(fpr, tpr, 'r-', color="black", alpha=0.3, lw=1, 239 plt.step(
255 where="post", label='%s (area = %.3f)' % (idx, roc_auc)) 240 fpr,
241 tpr,
242 "r-",
243 color="black",
244 alpha=0.3,
245 lw=1,
246 where="post",
247 label="%s (area = %.3f)" % (idx, roc_auc),
248 )
256 249
257 plt.xlim([0.0, 1.0]) 250 plt.xlim([0.0, 1.0])
258 plt.ylim([0.0, 1.05]) 251 plt.ylim([0.0, 1.05])
259 plt.xlabel('False Positive Rate') 252 plt.xlabel("False Positive Rate")
260 plt.ylabel('True Positive Rate') 253 plt.ylabel("True Positive Rate")
261 title = title or 'Receiver Operating Characteristic (ROC) Curve' 254 title = title or "Receiver Operating Characteristic (ROC) Curve"
262 plt.title(title) 255 plt.title(title)
263 folder = os.getcwd() 256 folder = os.getcwd()
264 plt.savefig(os.path.join(folder, "output.svg"), format="svg") 257 plt.savefig(os.path.join(folder, "output.svg"), format="svg")
265 os.rename(os.path.join(folder, "output.svg"), 258 os.rename(os.path.join(folder, "output.svg"), os.path.join(folder, "output"))
266 os.path.join(folder, "output")) 259
267 260
268 261 def main(
269 def main(inputs, infile_estimator=None, infile1=None, 262 inputs,
270 infile2=None, outfile_result=None, 263 infile_estimator=None,
271 outfile_object=None, groups=None, 264 infile1=None,
272 ref_seq=None, intervals=None, 265 infile2=None,
273 targets=None, fasta_path=None, 266 outfile_result=None,
274 model_config=None): 267 outfile_object=None,
268 groups=None,
269 ref_seq=None,
270 intervals=None,
271 targets=None,
272 fasta_path=None,
273 model_config=None,
274 ):
275 """ 275 """
276 Parameter 276 Parameter
277 --------- 277 ---------
278 inputs : str 278 inputs : str
279 File path to galaxy tool parameter 279 File path to galaxy tool parameter
310 File path to dataset containing fasta file 310 File path to dataset containing fasta file
311 311
312 model_config : str, default is None 312 model_config : str, default is None
313 File path to dataset containing JSON config for neural networks 313 File path to dataset containing JSON config for neural networks
314 """ 314 """
315 warnings.simplefilter('ignore') 315 warnings.simplefilter("ignore")
316 316
317 with open(inputs, 'r') as param_handler: 317 with open(inputs, "r") as param_handler:
318 params = json.load(param_handler) 318 params = json.load(param_handler)
319 319
320 title = params['plotting_selection']['title'].strip() 320 title = params["plotting_selection"]["title"].strip()
321 plot_type = params['plotting_selection']['plot_type'] 321 plot_type = params["plotting_selection"]["plot_type"]
322 plot_format = params['plotting_selection']['plot_format'] 322 plot_format = params["plotting_selection"]["plot_format"]
323 323
324 if plot_type == 'feature_importances': 324 if plot_type == "feature_importances":
325 with open(infile_estimator, 'rb') as estimator_handler: 325 estimator = load_model_from_h5(infile_estimator)
326 estimator = load_model(estimator_handler) 326
327 327 column_option = params["plotting_selection"]["column_selector_options"][
328 column_option = (params['plotting_selection'] 328 "selected_column_selector_option"
329 ['column_selector_options'] 329 ]
330 ['selected_column_selector_option']) 330 if column_option in [
331 if column_option in ['by_index_number', 'all_but_by_index_number', 331 "by_index_number",
332 'by_header_name', 'all_but_by_header_name']: 332 "all_but_by_index_number",
333 c = (params['plotting_selection'] 333 "by_header_name",
334 ['column_selector_options']['col1']) 334 "all_but_by_header_name",
335 ]:
336 c = params["plotting_selection"]["column_selector_options"]["col1"]
335 else: 337 else:
336 c = None 338 c = None
337 339
338 _, input_df = read_columns(infile1, c=c, 340 _, input_df = read_columns(
339 c_option=column_option, 341 infile1,
340 return_df=True, 342 c=c,
341 sep='\t', header='infer', 343 c_option=column_option,
342 parse_dates=True) 344 return_df=True,
345 sep="\t",
346 header="infer",
347 parse_dates=True,
348 )
343 349
344 feature_names = input_df.columns.values 350 feature_names = input_df.columns.values
345 351
346 if isinstance(estimator, Pipeline): 352 if isinstance(estimator, Pipeline):
347 for st in estimator.steps[:-1]: 353 for st in estimator.steps[:-1]:
348 if isinstance(st[-1], SelectorMixin): 354 if isinstance(st[-1], SelectorMixin):
349 mask = st[-1].get_support() 355 mask = st[-1].get_support()
350 feature_names = feature_names[mask] 356 feature_names = feature_names[mask]
351 estimator = estimator.steps[-1][-1] 357 estimator = estimator.steps[-1][-1]
352 358
353 if hasattr(estimator, 'coef_'): 359 if hasattr(estimator, "coef_"):
354 coefs = estimator.coef_ 360 coefs = estimator.coef_
355 else: 361 else:
356 coefs = getattr(estimator, 'feature_importances_', None) 362 coefs = getattr(estimator, "feature_importances_", None)
357 if coefs is None: 363 if coefs is None:
358 raise RuntimeError('The classifier does not expose ' 364 raise RuntimeError(
359 '"coef_" or "feature_importances_" ' 365 "The classifier does not expose "
360 'attributes') 366 '"coef_" or "feature_importances_" '
361 367 "attributes"
362 threshold = params['plotting_selection']['threshold'] 368 )
369
370 threshold = params["plotting_selection"]["threshold"]
363 if threshold is not None: 371 if threshold is not None:
364 mask = (coefs > threshold) | (coefs < -threshold) 372 mask = (coefs > threshold) | (coefs < -threshold)
365 coefs = coefs[mask] 373 coefs = coefs[mask]
366 feature_names = feature_names[mask] 374 feature_names = feature_names[mask]
367 375
368 # sort 376 # sort
369 indices = np.argsort(coefs)[::-1] 377 indices = np.argsort(coefs)[::-1]
370 378
371 trace = go.Bar(x=feature_names[indices], 379 trace = go.Bar(x=feature_names[indices], y=coefs[indices])
372 y=coefs[indices])
373 layout = go.Layout(title=title or "Feature Importances") 380 layout = go.Layout(title=title or "Feature Importances")
374 fig = go.Figure(data=[trace], layout=layout) 381 fig = go.Figure(data=[trace], layout=layout)
375 382
376 plotly.offline.plot(fig, filename="output.html", 383 plotly.offline.plot(fig, filename="output.html", auto_open=False)
377 auto_open=False)
378 # to be discovered by `from_work_dir` 384 # to be discovered by `from_work_dir`
379 os.rename('output.html', 'output') 385 os.rename("output.html", "output")
380 386
381 return 0 387 return 0
382 388
383 elif plot_type in ('pr_curve', 'roc_curve'): 389 elif plot_type in ("pr_curve", "roc_curve"):
384 df1 = pd.read_csv(infile1, sep='\t', header='infer') 390 df1 = pd.read_csv(infile1, sep="\t", header="infer")
385 df2 = pd.read_csv(infile2, sep='\t', header='infer').astype(np.float32) 391 df2 = pd.read_csv(infile2, sep="\t", header="infer").astype(np.float32)
386 392
387 minimum = params['plotting_selection']['report_minimum_n_positives'] 393 minimum = params["plotting_selection"]["report_minimum_n_positives"]
388 # filter out columns whose n_positives is beblow the threhold 394 # filter out columns whose n_positives is beblow the threhold
389 if minimum: 395 if minimum:
390 mask = df1.sum(axis=0) >= minimum 396 mask = df1.sum(axis=0) >= minimum
391 df1 = df1.loc[:, mask] 397 df1 = df1.loc[:, mask]
392 df2 = df2.loc[:, mask] 398 df2 = df2.loc[:, mask]
393 399
394 pos_label = params['plotting_selection']['pos_label'].strip() \ 400 pos_label = params["plotting_selection"]["pos_label"].strip() or None
395 or None 401
396 402 if plot_type == "pr_curve":
397 if plot_type == 'pr_curve': 403 if plot_format == "plotly_html":
398 if plot_format == 'plotly_html':
399 visualize_pr_curve_plotly(df1, df2, pos_label, title=title) 404 visualize_pr_curve_plotly(df1, df2, pos_label, title=title)
400 else: 405 else:
401 visualize_pr_curve_matplotlib(df1, df2, pos_label, title) 406 visualize_pr_curve_matplotlib(df1, df2, pos_label, title)
402 else: # 'roc_curve' 407 else: # 'roc_curve'
403 drop_intermediate = (params['plotting_selection'] 408 drop_intermediate = params["plotting_selection"]["drop_intermediate"]
404 ['drop_intermediate']) 409 if plot_format == "plotly_html":
405 if plot_format == 'plotly_html': 410 visualize_roc_curve_plotly(
406 visualize_roc_curve_plotly(df1, df2, pos_label, 411 df1,
407 drop_intermediate=drop_intermediate, 412 df2,
408 title=title) 413 pos_label,
414 drop_intermediate=drop_intermediate,
415 title=title,
416 )
409 else: 417 else:
410 visualize_roc_curve_matplotlib( 418 visualize_roc_curve_matplotlib(
411 df1, df2, pos_label, 419 df1,
420 df2,
421 pos_label,
412 drop_intermediate=drop_intermediate, 422 drop_intermediate=drop_intermediate,
413 title=title) 423 title=title,
424 )
414 425
415 return 0 426 return 0
416 427
417 elif plot_type == 'rfecv_gridscores': 428 elif plot_type == "rfecv_gridscores":
418 input_df = pd.read_csv(infile1, sep='\t', header='infer') 429 input_df = pd.read_csv(infile1, sep="\t", header="infer")
419 scores = input_df.iloc[:, 0] 430 scores = input_df.iloc[:, 0]
420 steps = params['plotting_selection']['steps'].strip() 431 steps = params["plotting_selection"]["steps"].strip()
421 steps = safe_eval(steps) 432 steps = safe_eval(steps)
422 433
423 data = go.Scatter( 434 data = go.Scatter(
424 x=list(range(len(scores))), 435 x=list(range(len(scores))),
425 y=scores, 436 y=scores,
426 text=[str(_) for _ in steps] if steps else None, 437 text=[str(_) for _ in steps] if steps else None,
427 mode='lines' 438 mode="lines",
428 ) 439 )
429 layout = go.Layout( 440 layout = go.Layout(
430 xaxis=dict(title="Number of features selected"), 441 xaxis=dict(title="Number of features selected"),
431 yaxis=dict(title="Cross validation score"), 442 yaxis=dict(title="Cross validation score"),
432 title=dict( 443 title=dict(
433 text=title or None, 444 text=title or None, x=0.5, y=0.92, xanchor="center", yanchor="top"
434 x=0.5,
435 y=0.92,
436 xanchor='center',
437 yanchor='top'
438 ), 445 ),
439 font=dict( 446 font=dict(family="sans-serif", size=11),
440 family="sans-serif",
441 size=11
442 ),
443 # control backgroud colors 447 # control backgroud colors
444 plot_bgcolor='rgba(255,255,255,0)' 448 plot_bgcolor="rgba(255,255,255,0)",
445 ) 449 )
446 """ 450 """
447 # legend=dict( 451 # legend=dict(
448 # x=0.95, 452 # x=0.95,
449 # y=0, 453 # y=0,
458 # borderwidth=2 462 # borderwidth=2
459 # ), 463 # ),
460 """ 464 """
461 465
462 fig = go.Figure(data=[data], layout=layout) 466 fig = go.Figure(data=[data], layout=layout)
463 plotly.offline.plot(fig, filename="output.html", 467 plotly.offline.plot(fig, filename="output.html", auto_open=False)
464 auto_open=False)
465 # to be discovered by `from_work_dir` 468 # to be discovered by `from_work_dir`
466 os.rename('output.html', 'output') 469 os.rename("output.html", "output")
467 470
468 return 0 471 return 0
469 472
470 elif plot_type == 'learning_curve': 473 elif plot_type == "learning_curve":
471 input_df = pd.read_csv(infile1, sep='\t', header='infer') 474 input_df = pd.read_csv(infile1, sep="\t", header="infer")
472 plot_std_err = params['plotting_selection']['plot_std_err'] 475 plot_std_err = params["plotting_selection"]["plot_std_err"]
473 data1 = go.Scatter( 476 data1 = go.Scatter(
474 x=input_df['train_sizes_abs'], 477 x=input_df["train_sizes_abs"],
475 y=input_df['mean_train_scores'], 478 y=input_df["mean_train_scores"],
476 error_y=dict( 479 error_y=dict(array=input_df["std_train_scores"]) if plot_std_err else None,
477 array=input_df['std_train_scores'] 480 mode="lines",
478 ) if plot_std_err else None,
479 mode='lines',
480 name="Train Scores", 481 name="Train Scores",
481 ) 482 )
482 data2 = go.Scatter( 483 data2 = go.Scatter(
483 x=input_df['train_sizes_abs'], 484 x=input_df["train_sizes_abs"],
484 y=input_df['mean_test_scores'], 485 y=input_df["mean_test_scores"],
485 error_y=dict( 486 error_y=dict(array=input_df["std_test_scores"]) if plot_std_err else None,
486 array=input_df['std_test_scores'] 487 mode="lines",
487 ) if plot_std_err else None,
488 mode='lines',
489 name="Test Scores", 488 name="Test Scores",
490 ) 489 )
491 layout = dict( 490 layout = dict(
492 xaxis=dict( 491 xaxis=dict(title="No. of samples"),
493 title='No. of samples' 492 yaxis=dict(title="Performance Score"),
494 ),
495 yaxis=dict(
496 title='Performance Score'
497 ),
498 # modify these configurations to customize image 493 # modify these configurations to customize image
499 title=dict( 494 title=dict(
500 text=title or 'Learning Curve', 495 text=title or "Learning Curve",
501 x=0.5, 496 x=0.5,
502 y=0.92, 497 y=0.92,
503 xanchor='center', 498 xanchor="center",
504 yanchor='top' 499 yanchor="top",
505 ), 500 ),
506 font=dict( 501 font=dict(family="sans-serif", size=11),
507 family="sans-serif",
508 size=11
509 ),
510 # control backgroud colors 502 # control backgroud colors
511 plot_bgcolor='rgba(255,255,255,0)' 503 plot_bgcolor="rgba(255,255,255,0)",
512 ) 504 )
513 """ 505 """
514 # legend=dict( 506 # legend=dict(
515 # x=0.95, 507 # x=0.95,
516 # y=0, 508 # y=0,
525 # borderwidth=2 517 # borderwidth=2
526 # ), 518 # ),
527 """ 519 """
528 520
529 fig = go.Figure(data=[data1, data2], layout=layout) 521 fig = go.Figure(data=[data1, data2], layout=layout)
530 plotly.offline.plot(fig, filename="output.html", 522 plotly.offline.plot(fig, filename="output.html", auto_open=False)
531 auto_open=False)
532 # to be discovered by `from_work_dir` 523 # to be discovered by `from_work_dir`
533 os.rename('output.html', 'output') 524 os.rename("output.html", "output")
534 525
535 return 0 526 return 0
536 527
537 elif plot_type == 'keras_plot_model': 528 elif plot_type == "keras_plot_model":
538 with open(model_config, 'r') as f: 529 with open(model_config, "r") as f:
539 model_str = f.read() 530 model_str = f.read()
540 model = model_from_json(model_str) 531 model = model_from_json(model_str)
541 plot_model(model, to_file="output.png") 532 plot_model(model, to_file="output.png")
542 os.rename('output.png', 'output') 533 os.rename("output.png", "output")
543 534
544 return 0 535 return 0
545 536
546 # save pdf file to disk 537 # save pdf file to disk
547 # fig.write_image("image.pdf", format='pdf') 538 # fig.write_image("image.pdf", format='pdf')
548 # fig.write_image("image.pdf", format='pdf', width=340*2, height=226*2) 539 # fig.write_image("image.pdf", format='pdf', width=340*2, height=226*2)
549 540
550 541
551 if __name__ == '__main__': 542 if __name__ == "__main__":
552 aparser = argparse.ArgumentParser() 543 aparser = argparse.ArgumentParser()
553 aparser.add_argument("-i", "--inputs", dest="inputs", required=True) 544 aparser.add_argument("-i", "--inputs", dest="inputs", required=True)
554 aparser.add_argument("-e", "--estimator", dest="infile_estimator") 545 aparser.add_argument("-e", "--estimator", dest="infile_estimator")
555 aparser.add_argument("-X", "--infile1", dest="infile1") 546 aparser.add_argument("-X", "--infile1", dest="infile1")
556 aparser.add_argument("-y", "--infile2", dest="infile2") 547 aparser.add_argument("-y", "--infile2", dest="infile2")
562 aparser.add_argument("-t", "--targets", dest="targets") 553 aparser.add_argument("-t", "--targets", dest="targets")
563 aparser.add_argument("-f", "--fasta_path", dest="fasta_path") 554 aparser.add_argument("-f", "--fasta_path", dest="fasta_path")
564 aparser.add_argument("-c", "--model_config", dest="model_config") 555 aparser.add_argument("-c", "--model_config", dest="model_config")
565 args = aparser.parse_args() 556 args = aparser.parse_args()
566 557
567 main(args.inputs, args.infile_estimator, args.infile1, args.infile2, 558 main(
568 args.outfile_result, outfile_object=args.outfile_object, 559 args.inputs,
569 groups=args.groups, ref_seq=args.ref_seq, intervals=args.intervals, 560 args.infile_estimator,
570 targets=args.targets, fasta_path=args.fasta_path, 561 args.infile1,
571 model_config=args.model_config) 562 args.infile2,
563 args.outfile_result,
564 outfile_object=args.outfile_object,
565 groups=args.groups,
566 ref_seq=args.ref_seq,
567 intervals=args.intervals,
568 targets=args.targets,
569 fasta_path=args.fasta_path,
570 model_config=args.model_config,
571 )