Mercurial > repos > bgruening > sklearn_stacking_ensemble_models
comparison ml_visualization_ex.py @ 9:b8c92e94ac1d draft
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit e2a5eade6d0e5ddf3a47630381a0ad90d80e8a04"
| author | bgruening |
|---|---|
| date | Tue, 13 Apr 2021 15:49:42 +0000 |
| parents | 6430b9b00d2f |
| children | 2d890789ac48 |
comparison
equal
deleted
inserted
replaced
| 8:6430b9b00d2f | 9:b8c92e94ac1d |
|---|---|
| 20 | 20 |
| 21 safe_eval = SafeEval() | 21 safe_eval = SafeEval() |
| 22 | 22 |
| 23 # plotly default colors | 23 # plotly default colors |
| 24 default_colors = [ | 24 default_colors = [ |
| 25 '#1f77b4', # muted blue | 25 "#1f77b4", # muted blue |
| 26 '#ff7f0e', # safety orange | 26 "#ff7f0e", # safety orange |
| 27 '#2ca02c', # cooked asparagus green | 27 "#2ca02c", # cooked asparagus green |
| 28 '#d62728', # brick red | 28 "#d62728", # brick red |
| 29 '#9467bd', # muted purple | 29 "#9467bd", # muted purple |
| 30 '#8c564b', # chestnut brown | 30 "#8c564b", # chestnut brown |
| 31 '#e377c2', # raspberry yogurt pink | 31 "#e377c2", # raspberry yogurt pink |
| 32 '#7f7f7f', # middle gray | 32 "#7f7f7f", # middle gray |
| 33 '#bcbd22', # curry yellow-green | 33 "#bcbd22", # curry yellow-green |
| 34 '#17becf' # blue-teal | 34 "#17becf", # blue-teal |
| 35 ] | 35 ] |
| 36 | 36 |
| 37 | 37 |
| 38 def visualize_pr_curve_plotly(df1, df2, pos_label, title=None): | 38 def visualize_pr_curve_plotly(df1, df2, pos_label, title=None): |
| 39 """output pr-curve in html using plotly | 39 """output pr-curve in html using plotly |
| 50 data = [] | 50 data = [] |
| 51 for idx in range(df1.shape[1]): | 51 for idx in range(df1.shape[1]): |
| 52 y_true = df1.iloc[:, idx].values | 52 y_true = df1.iloc[:, idx].values |
| 53 y_score = df2.iloc[:, idx].values | 53 y_score = df2.iloc[:, idx].values |
| 54 | 54 |
| 55 precision, recall, _ = precision_recall_curve( | 55 precision, recall, _ = precision_recall_curve(y_true, y_score, pos_label=pos_label) |
| 56 y_true, y_score, pos_label=pos_label) | 56 ap = average_precision_score(y_true, y_score, pos_label=pos_label or 1) |
| 57 ap = average_precision_score( | |
| 58 y_true, y_score, pos_label=pos_label or 1) | |
| 59 | 57 |
| 60 trace = go.Scatter( | 58 trace = go.Scatter( |
| 61 x=recall, | 59 x=recall, |
| 62 y=precision, | 60 y=precision, |
| 63 mode='lines', | 61 mode="lines", |
| 64 marker=dict( | 62 marker=dict(color=default_colors[idx % len(default_colors)]), |
| 65 color=default_colors[idx % len(default_colors)] | 63 name="%s (area = %.3f)" % (idx, ap), |
| 66 ), | |
| 67 name='%s (area = %.3f)' % (idx, ap) | |
| 68 ) | 64 ) |
| 69 data.append(trace) | 65 data.append(trace) |
| 70 | 66 |
| 71 layout = go.Layout( | 67 layout = go.Layout( |
| 72 xaxis=dict( | 68 xaxis=dict(title="Recall", linecolor="lightslategray", linewidth=1), |
| 73 title='Recall', | 69 yaxis=dict(title="Precision", linecolor="lightslategray", linewidth=1), |
| 74 linecolor='lightslategray', | |
| 75 linewidth=1 | |
| 76 ), | |
| 77 yaxis=dict( | |
| 78 title='Precision', | |
| 79 linecolor='lightslategray', | |
| 80 linewidth=1 | |
| 81 ), | |
| 82 title=dict( | 70 title=dict( |
| 83 text=title or 'Precision-Recall Curve', | 71 text=title or "Precision-Recall Curve", |
| 84 x=0.5, | 72 x=0.5, |
| 85 y=0.92, | 73 y=0.92, |
| 86 xanchor='center', | 74 xanchor="center", |
| 87 yanchor='top' | 75 yanchor="top", |
| 88 ), | 76 ), |
| 89 font=dict( | 77 font=dict(family="sans-serif", size=11), |
| 90 family="sans-serif", | |
| 91 size=11 | |
| 92 ), | |
| 93 # control background colors | 78 # control background colors |
| 94 plot_bgcolor='rgba(255,255,255,0)' | 79 plot_bgcolor="rgba(255,255,255,0)", |
| 95 ) | 80 ) |
| 96 """ | 81 """ |
| 97 legend=dict( | 82 legend=dict( |
| 98 x=0.95, | 83 x=0.95, |
| 99 y=0, | 84 y=0, |
| 110 | 95 |
| 111 fig = go.Figure(data=data, layout=layout) | 96 fig = go.Figure(data=data, layout=layout) |
| 112 | 97 |
| 113 plotly.offline.plot(fig, filename="output.html", auto_open=False) | 98 plotly.offline.plot(fig, filename="output.html", auto_open=False) |
| 114 # to be discovered by `from_work_dir` | 99 # to be discovered by `from_work_dir` |
| 115 os.rename('output.html', 'output') | 100 os.rename("output.html", "output") |
| 116 | 101 |
| 117 | 102 |
| 118 def visualize_pr_curve_matplotlib(df1, df2, pos_label, title=None): | 103 def visualize_pr_curve_matplotlib(df1, df2, pos_label, title=None): |
| 119 """visualize pr-curve using matplotlib and output svg image | 104 """visualize pr-curve using matplotlib and output svg image""" |
| 120 """ | |
| 121 backend = matplotlib.get_backend() | 105 backend = matplotlib.get_backend() |
| 122 if "inline" not in backend: | 106 if "inline" not in backend: |
| 123 matplotlib.use("SVG") | 107 matplotlib.use("SVG") |
| 124 plt.style.use('seaborn-colorblind') | 108 plt.style.use("seaborn-colorblind") |
| 125 plt.figure() | 109 plt.figure() |
| 126 | 110 |
| 127 for idx in range(df1.shape[1]): | 111 for idx in range(df1.shape[1]): |
| 128 y_true = df1.iloc[:, idx].values | 112 y_true = df1.iloc[:, idx].values |
| 129 y_score = df2.iloc[:, idx].values | 113 y_score = df2.iloc[:, idx].values |
| 130 | 114 |
| 131 precision, recall, _ = precision_recall_curve( | 115 precision, recall, _ = precision_recall_curve(y_true, y_score, pos_label=pos_label) |
| 132 y_true, y_score, pos_label=pos_label) | 116 ap = average_precision_score(y_true, y_score, pos_label=pos_label or 1) |
| 133 ap = average_precision_score( | 117 |
| 134 y_true, y_score, pos_label=pos_label or 1) | 118 plt.step( |
| 135 | 119 recall, |
| 136 plt.step(recall, precision, 'r-', color="black", alpha=0.3, | 120 precision, |
| 137 lw=1, where="post", label='%s (area = %.3f)' % (idx, ap)) | 121 "r-", |
| 122 color="black", | |
| 123 alpha=0.3, | |
| 124 lw=1, | |
| 125 where="post", | |
| 126 label="%s (area = %.3f)" % (idx, ap), | |
| 127 ) | |
| 138 | 128 |
| 139 plt.xlim([0.0, 1.0]) | 129 plt.xlim([0.0, 1.0]) |
| 140 plt.ylim([0.0, 1.05]) | 130 plt.ylim([0.0, 1.05]) |
| 141 plt.xlabel('Recall') | 131 plt.xlabel("Recall") |
| 142 plt.ylabel('Precision') | 132 plt.ylabel("Precision") |
| 143 title = title or 'Precision-Recall Curve' | 133 title = title or "Precision-Recall Curve" |
| 144 plt.title(title) | 134 plt.title(title) |
| 145 folder = os.getcwd() | 135 folder = os.getcwd() |
| 146 plt.savefig(os.path.join(folder, "output.svg"), format="svg") | 136 plt.savefig(os.path.join(folder, "output.svg"), format="svg") |
| 147 os.rename(os.path.join(folder, "output.svg"), | 137 os.rename(os.path.join(folder, "output.svg"), os.path.join(folder, "output")) |
| 148 os.path.join(folder, "output")) | 138 |
| 149 | 139 |
| 150 | 140 def visualize_roc_curve_plotly(df1, df2, pos_label, drop_intermediate=True, title=None): |
| 151 def visualize_roc_curve_plotly(df1, df2, pos_label, | |
| 152 drop_intermediate=True, | |
| 153 title=None): | |
| 154 """output roc-curve in html using plotly | 141 """output roc-curve in html using plotly |
| 155 | 142 |
| 156 df1 : pandas.DataFrame | 143 df1 : pandas.DataFrame |
| 157 Containing y_true | 144 Containing y_true |
| 158 df2 : pandas.DataFrame | 145 df2 : pandas.DataFrame |
| 167 data = [] | 154 data = [] |
| 168 for idx in range(df1.shape[1]): | 155 for idx in range(df1.shape[1]): |
| 169 y_true = df1.iloc[:, idx].values | 156 y_true = df1.iloc[:, idx].values |
| 170 y_score = df2.iloc[:, idx].values | 157 y_score = df2.iloc[:, idx].values |
| 171 | 158 |
| 172 fpr, tpr, _ = roc_curve(y_true, y_score, pos_label=pos_label, | 159 fpr, tpr, _ = roc_curve(y_true, y_score, pos_label=pos_label, drop_intermediate=drop_intermediate) |
| 173 drop_intermediate=drop_intermediate) | |
| 174 roc_auc = auc(fpr, tpr) | 160 roc_auc = auc(fpr, tpr) |
| 175 | 161 |
| 176 trace = go.Scatter( | 162 trace = go.Scatter( |
| 177 x=fpr, | 163 x=fpr, |
| 178 y=tpr, | 164 y=tpr, |
| 179 mode='lines', | 165 mode="lines", |
| 180 marker=dict( | 166 marker=dict(color=default_colors[idx % len(default_colors)]), |
| 181 color=default_colors[idx % len(default_colors)] | 167 name="%s (area = %.3f)" % (idx, roc_auc), |
| 182 ), | |
| 183 name='%s (area = %.3f)' % (idx, roc_auc) | |
| 184 ) | 168 ) |
| 185 data.append(trace) | 169 data.append(trace) |
| 186 | 170 |
| 187 layout = go.Layout( | 171 layout = go.Layout( |
| 188 xaxis=dict( | 172 xaxis=dict(title="False Positive Rate", linecolor="lightslategray", linewidth=1), |
| 189 title='False Positive Rate', | 173 yaxis=dict(title="True Positive Rate", linecolor="lightslategray", linewidth=1), |
| 190 linecolor='lightslategray', | |
| 191 linewidth=1 | |
| 192 ), | |
| 193 yaxis=dict( | |
| 194 title='True Positive Rate', | |
| 195 linecolor='lightslategray', | |
| 196 linewidth=1 | |
| 197 ), | |
| 198 title=dict( | 174 title=dict( |
| 199 text=title or 'Receiver Operating Characteristic (ROC) Curve', | 175 text=title or "Receiver Operating Characteristic (ROC) Curve", |
| 200 x=0.5, | 176 x=0.5, |
| 201 y=0.92, | 177 y=0.92, |
| 202 xanchor='center', | 178 xanchor="center", |
| 203 yanchor='top' | 179 yanchor="top", |
| 204 ), | 180 ), |
| 205 font=dict( | 181 font=dict(family="sans-serif", size=11), |
| 206 family="sans-serif", | |
| 207 size=11 | |
| 208 ), | |
| 209 # control background colors | 182 # control background colors |
| 210 plot_bgcolor='rgba(255,255,255,0)' | 183 plot_bgcolor="rgba(255,255,255,0)", |
| 211 ) | 184 ) |
| 212 """ | 185 """ |
| 213 # legend=dict( | 186 # legend=dict( |
| 214 # x=0.95, | 187 # x=0.95, |
| 215 # y=0, | 188 # y=0, |
| 227 | 200 |
| 228 fig = go.Figure(data=data, layout=layout) | 201 fig = go.Figure(data=data, layout=layout) |
| 229 | 202 |
| 230 plotly.offline.plot(fig, filename="output.html", auto_open=False) | 203 plotly.offline.plot(fig, filename="output.html", auto_open=False) |
| 231 # to be discovered by `from_work_dir` | 204 # to be discovered by `from_work_dir` |
| 232 os.rename('output.html', 'output') | 205 os.rename("output.html", "output") |
| 233 | 206 |
| 234 | 207 |
| 235 def visualize_roc_curve_matplotlib(df1, df2, pos_label, | 208 def visualize_roc_curve_matplotlib(df1, df2, pos_label, drop_intermediate=True, title=None): |
| 236 drop_intermediate=True, | 209 """visualize roc-curve using matplotlib and output svg image""" |
| 237 title=None): | |
| 238 """visualize roc-curve using matplotlib and output svg image | |
| 239 """ | |
| 240 backend = matplotlib.get_backend() | 210 backend = matplotlib.get_backend() |
| 241 if "inline" not in backend: | 211 if "inline" not in backend: |
| 242 matplotlib.use("SVG") | 212 matplotlib.use("SVG") |
| 243 plt.style.use('seaborn-colorblind') | 213 plt.style.use("seaborn-colorblind") |
| 244 plt.figure() | 214 plt.figure() |
| 245 | 215 |
| 246 for idx in range(df1.shape[1]): | 216 for idx in range(df1.shape[1]): |
| 247 y_true = df1.iloc[:, idx].values | 217 y_true = df1.iloc[:, idx].values |
| 248 y_score = df2.iloc[:, idx].values | 218 y_score = df2.iloc[:, idx].values |
| 249 | 219 |
| 250 fpr, tpr, _ = roc_curve(y_true, y_score, pos_label=pos_label, | 220 fpr, tpr, _ = roc_curve(y_true, y_score, pos_label=pos_label, drop_intermediate=drop_intermediate) |
| 251 drop_intermediate=drop_intermediate) | |
| 252 roc_auc = auc(fpr, tpr) | 221 roc_auc = auc(fpr, tpr) |
| 253 | 222 |
| 254 plt.step(fpr, tpr, 'r-', color="black", alpha=0.3, lw=1, | 223 plt.step( |
| 255 where="post", label='%s (area = %.3f)' % (idx, roc_auc)) | 224 fpr, |
| 225 tpr, | |
| 226 "r-", | |
| 227 color="black", | |
| 228 alpha=0.3, | |
| 229 lw=1, | |
| 230 where="post", | |
| 231 label="%s (area = %.3f)" % (idx, roc_auc), | |
| 232 ) | |
| 256 | 233 |
| 257 plt.xlim([0.0, 1.0]) | 234 plt.xlim([0.0, 1.0]) |
| 258 plt.ylim([0.0, 1.05]) | 235 plt.ylim([0.0, 1.05]) |
| 259 plt.xlabel('False Positive Rate') | 236 plt.xlabel("False Positive Rate") |
| 260 plt.ylabel('True Positive Rate') | 237 plt.ylabel("True Positive Rate") |
| 261 title = title or 'Receiver Operating Characteristic (ROC) Curve' | 238 title = title or "Receiver Operating Characteristic (ROC) Curve" |
| 262 plt.title(title) | 239 plt.title(title) |
| 263 folder = os.getcwd() | 240 folder = os.getcwd() |
| 264 plt.savefig(os.path.join(folder, "output.svg"), format="svg") | 241 plt.savefig(os.path.join(folder, "output.svg"), format="svg") |
| 265 os.rename(os.path.join(folder, "output.svg"), | 242 os.rename(os.path.join(folder, "output.svg"), os.path.join(folder, "output")) |
| 266 os.path.join(folder, "output")) | |
| 267 | 243 |
| 268 | 244 |
| 269 def get_dataframe(file_path, plot_selection, header_name, column_name): | 245 def get_dataframe(file_path, plot_selection, header_name, column_name): |
| 270 header = 'infer' if plot_selection[header_name] else None | 246 header = "infer" if plot_selection[header_name] else None |
| 271 column_option = plot_selection[column_name]["selected_column_selector_option"] | 247 column_option = plot_selection[column_name]["selected_column_selector_option"] |
| 272 if column_option in ["by_index_number", "all_but_by_index_number", "by_header_name", "all_but_by_header_name"]: | 248 if column_option in [ |
| 249 "by_index_number", | |
| 250 "all_but_by_index_number", | |
| 251 "by_header_name", | |
| 252 "all_but_by_header_name", | |
| 253 ]: | |
| 273 col = plot_selection[column_name]["col1"] | 254 col = plot_selection[column_name]["col1"] |
| 274 else: | 255 else: |
| 275 col = None | 256 col = None |
| 276 _, input_df = read_columns(file_path, c=col, | 257 _, input_df = read_columns(file_path, c=col, |
| 277 c_option=column_option, | 258 c_option=column_option, |
| 278 return_df=True, | 259 return_df=True, |
| 279 sep='\t', header=header, | 260 sep='\t', header=header, |
| 280 parse_dates=True) | 261 parse_dates=True) |
| 281 return input_df | 262 return input_df |
| 282 | 263 |
| 283 | 264 |
| 284 def main(inputs, infile_estimator=None, infile1=None, | 265 def main( |
| 285 infile2=None, outfile_result=None, | 266 inputs, |
| 286 outfile_object=None, groups=None, | 267 infile_estimator=None, |
| 287 ref_seq=None, intervals=None, | 268 infile1=None, |
| 288 targets=None, fasta_path=None, | 269 infile2=None, |
| 289 model_config=None, true_labels=None, | 270 outfile_result=None, |
| 290 predicted_labels=None, plot_color=None, | 271 outfile_object=None, |
| 291 title=None): | 272 groups=None, |
| 273 ref_seq=None, | |
| 274 intervals=None, | |
| 275 targets=None, | |
| 276 fasta_path=None, | |
| 277 model_config=None, | |
| 278 true_labels=None, | |
| 279 predicted_labels=None, | |
| 280 plot_color=None, | |
| 281 title=None, | |
| 282 ): | |
| 292 """ | 283 """ |
| 293 Parameter | 284 Parameter |
| 294 --------- | 285 --------- |
| 295 inputs : str | 286 inputs : str |
| 296 File path to galaxy tool parameter | 287 File path to galaxy tool parameter |
| 339 Color of the confusion matrix heatmap | 330 Color of the confusion matrix heatmap |
| 340 | 331 |
| 341 title : str, default is None | 332 title : str, default is None |
| 342 Title of the confusion matrix heatmap | 333 Title of the confusion matrix heatmap |
| 343 """ | 334 """ |
| 344 warnings.simplefilter('ignore') | 335 warnings.simplefilter("ignore") |
| 345 | 336 |
| 346 with open(inputs, 'r') as param_handler: | 337 with open(inputs, "r") as param_handler: |
| 347 params = json.load(param_handler) | 338 params = json.load(param_handler) |
| 348 | 339 |
| 349 title = params['plotting_selection']['title'].strip() | 340 title = params["plotting_selection"]["title"].strip() |
| 350 plot_type = params['plotting_selection']['plot_type'] | 341 plot_type = params["plotting_selection"]["plot_type"] |
| 351 plot_format = params['plotting_selection']['plot_format'] | 342 plot_format = params["plotting_selection"]["plot_format"] |
| 352 | 343 |
| 353 if plot_type == 'feature_importances': | 344 if plot_type == "feature_importances": |
| 354 with open(infile_estimator, 'rb') as estimator_handler: | 345 with open(infile_estimator, "rb") as estimator_handler: |
| 355 estimator = load_model(estimator_handler) | 346 estimator = load_model(estimator_handler) |
| 356 | 347 |
| 357 column_option = (params['plotting_selection'] | 348 column_option = params["plotting_selection"]["column_selector_options"]["selected_column_selector_option"] |
| 358 ['column_selector_options'] | 349 if column_option in [ |
| 359 ['selected_column_selector_option']) | 350 "by_index_number", |
| 360 if column_option in ['by_index_number', 'all_but_by_index_number', | 351 "all_but_by_index_number", |
| 361 'by_header_name', 'all_but_by_header_name']: | 352 "by_header_name", |
| 362 c = (params['plotting_selection'] | 353 "all_but_by_header_name", |
| 363 ['column_selector_options']['col1']) | 354 ]: |
| 355 c = params["plotting_selection"]["column_selector_options"]["col1"] | |
| 364 else: | 356 else: |
| 365 c = None | 357 c = None |
| 366 | 358 |
| 367 _, input_df = read_columns(infile1, c=c, | 359 _, input_df = read_columns( |
| 368 c_option=column_option, | 360 infile1, |
| 369 return_df=True, | 361 c=c, |
| 370 sep='\t', header='infer', | 362 c_option=column_option, |
| 371 parse_dates=True) | 363 return_df=True, |
| 364 sep="\t", | |
| 365 header="infer", | |
| 366 parse_dates=True, | |
| 367 ) | |
| 372 | 368 |
| 373 feature_names = input_df.columns.values | 369 feature_names = input_df.columns.values |
| 374 | 370 |
| 375 if isinstance(estimator, Pipeline): | 371 if isinstance(estimator, Pipeline): |
| 376 for st in estimator.steps[:-1]: | 372 for st in estimator.steps[:-1]: |
| 377 if isinstance(st[-1], SelectorMixin): | 373 if isinstance(st[-1], SelectorMixin): |
| 378 mask = st[-1].get_support() | 374 mask = st[-1].get_support() |
| 379 feature_names = feature_names[mask] | 375 feature_names = feature_names[mask] |
| 380 estimator = estimator.steps[-1][-1] | 376 estimator = estimator.steps[-1][-1] |
| 381 | 377 |
| 382 if hasattr(estimator, 'coef_'): | 378 if hasattr(estimator, "coef_"): |
| 383 coefs = estimator.coef_ | 379 coefs = estimator.coef_ |
| 384 else: | 380 else: |
| 385 coefs = getattr(estimator, 'feature_importances_', None) | 381 coefs = getattr(estimator, "feature_importances_", None) |
| 386 if coefs is None: | 382 if coefs is None: |
| 387 raise RuntimeError('The classifier does not expose ' | 383 raise RuntimeError("The classifier does not expose " '"coef_" or "feature_importances_" ' "attributes") |
| 388 '"coef_" or "feature_importances_" ' | 384 |
| 389 'attributes') | 385 threshold = params["plotting_selection"]["threshold"] |
| 390 | |
| 391 threshold = params['plotting_selection']['threshold'] | |
| 392 if threshold is not None: | 386 if threshold is not None: |
| 393 mask = (coefs > threshold) | (coefs < -threshold) | 387 mask = (coefs > threshold) | (coefs < -threshold) |
| 394 coefs = coefs[mask] | 388 coefs = coefs[mask] |
| 395 feature_names = feature_names[mask] | 389 feature_names = feature_names[mask] |
| 396 | 390 |
| 397 # sort | 391 # sort |
| 398 indices = np.argsort(coefs)[::-1] | 392 indices = np.argsort(coefs)[::-1] |
| 399 | 393 |
| 400 trace = go.Bar(x=feature_names[indices], | 394 trace = go.Bar(x=feature_names[indices], y=coefs[indices]) |
| 401 y=coefs[indices]) | |
| 402 layout = go.Layout(title=title or "Feature Importances") | 395 layout = go.Layout(title=title or "Feature Importances") |
| 403 fig = go.Figure(data=[trace], layout=layout) | 396 fig = go.Figure(data=[trace], layout=layout) |
| 404 | 397 |
| 405 plotly.offline.plot(fig, filename="output.html", | 398 plotly.offline.plot(fig, filename="output.html", auto_open=False) |
| 406 auto_open=False) | |
| 407 # to be discovered by `from_work_dir` | 399 # to be discovered by `from_work_dir` |
| 408 os.rename('output.html', 'output') | 400 os.rename("output.html", "output") |
| 409 | 401 |
| 410 return 0 | 402 return 0 |
| 411 | 403 |
| 412 elif plot_type in ('pr_curve', 'roc_curve'): | 404 elif plot_type in ("pr_curve", "roc_curve"): |
| 413 df1 = pd.read_csv(infile1, sep='\t', header='infer') | 405 df1 = pd.read_csv(infile1, sep="\t", header="infer") |
| 414 df2 = pd.read_csv(infile2, sep='\t', header='infer').astype(np.float32) | 406 df2 = pd.read_csv(infile2, sep="\t", header="infer").astype(np.float32) |
| 415 | 407 |
| 416 minimum = params['plotting_selection']['report_minimum_n_positives'] | 408 minimum = params["plotting_selection"]["report_minimum_n_positives"] |
| 417 # filter out columns whose n_positives is below the threshold | 409 # filter out columns whose n_positives is below the threshold |
| 418 if minimum: | 410 if minimum: |
| 419 mask = df1.sum(axis=0) >= minimum | 411 mask = df1.sum(axis=0) >= minimum |
| 420 df1 = df1.loc[:, mask] | 412 df1 = df1.loc[:, mask] |
| 421 df2 = df2.loc[:, mask] | 413 df2 = df2.loc[:, mask] |
| 422 | 414 |
| 423 pos_label = params['plotting_selection']['pos_label'].strip() \ | 415 pos_label = params["plotting_selection"]["pos_label"].strip() or None |
| 424 or None | 416 |
| 425 | 417 if plot_type == "pr_curve": |
| 426 if plot_type == 'pr_curve': | 418 if plot_format == "plotly_html": |
| 427 if plot_format == 'plotly_html': | |
| 428 visualize_pr_curve_plotly(df1, df2, pos_label, title=title) | 419 visualize_pr_curve_plotly(df1, df2, pos_label, title=title) |
| 429 else: | 420 else: |
| 430 visualize_pr_curve_matplotlib(df1, df2, pos_label, title) | 421 visualize_pr_curve_matplotlib(df1, df2, pos_label, title) |
| 431 else: # 'roc_curve' | 422 else: # 'roc_curve' |
| 432 drop_intermediate = (params['plotting_selection'] | 423 drop_intermediate = params["plotting_selection"]["drop_intermediate"] |
| 433 ['drop_intermediate']) | 424 if plot_format == "plotly_html": |
| 434 if plot_format == 'plotly_html': | 425 visualize_roc_curve_plotly( |
| 435 visualize_roc_curve_plotly(df1, df2, pos_label, | 426 df1, |
| 436 drop_intermediate=drop_intermediate, | 427 df2, |
| 437 title=title) | 428 pos_label, |
| 429 drop_intermediate=drop_intermediate, | |
| 430 title=title, | |
| 431 ) | |
| 438 else: | 432 else: |
| 439 visualize_roc_curve_matplotlib( | 433 visualize_roc_curve_matplotlib( |
| 440 df1, df2, pos_label, | 434 df1, |
| 435 df2, | |
| 436 pos_label, | |
| 441 drop_intermediate=drop_intermediate, | 437 drop_intermediate=drop_intermediate, |
| 442 title=title) | 438 title=title, |
| 439 ) | |
| 443 | 440 |
| 444 return 0 | 441 return 0 |
| 445 | 442 |
| 446 elif plot_type == 'rfecv_gridscores': | 443 elif plot_type == "rfecv_gridscores": |
| 447 input_df = pd.read_csv(infile1, sep='\t', header='infer') | 444 input_df = pd.read_csv(infile1, sep="\t", header="infer") |
| 448 scores = input_df.iloc[:, 0] | 445 scores = input_df.iloc[:, 0] |
| 449 steps = params['plotting_selection']['steps'].strip() | 446 steps = params["plotting_selection"]["steps"].strip() |
| 450 steps = safe_eval(steps) | 447 steps = safe_eval(steps) |
| 451 | 448 |
| 452 data = go.Scatter( | 449 data = go.Scatter( |
| 453 x=list(range(len(scores))), | 450 x=list(range(len(scores))), |
| 454 y=scores, | 451 y=scores, |
| 455 text=[str(_) for _ in steps] if steps else None, | 452 text=[str(_) for _ in steps] if steps else None, |
| 456 mode='lines' | 453 mode="lines", |
| 457 ) | 454 ) |
| 458 layout = go.Layout( | 455 layout = go.Layout( |
| 459 xaxis=dict(title="Number of features selected"), | 456 xaxis=dict(title="Number of features selected"), |
| 460 yaxis=dict(title="Cross validation score"), | 457 yaxis=dict(title="Cross validation score"), |
| 461 title=dict( | 458 title=dict(text=title or None, x=0.5, y=0.92, xanchor="center", yanchor="top"), |
| 462 text=title or None, | 459 font=dict(family="sans-serif", size=11), |
| 463 x=0.5, | |
| 464 y=0.92, | |
| 465 xanchor='center', | |
| 466 yanchor='top' | |
| 467 ), | |
| 468 font=dict( | |
| 469 family="sans-serif", | |
| 470 size=11 | |
| 471 ), | |
| 472 # control background colors | 460 # control background colors |
| 473 plot_bgcolor='rgba(255,255,255,0)' | 461 plot_bgcolor="rgba(255,255,255,0)", |
| 474 ) | 462 ) |
| 475 """ | 463 """ |
| 476 # legend=dict( | 464 # legend=dict( |
| 477 # x=0.95, | 465 # x=0.95, |
| 478 # y=0, | 466 # y=0, |
| 487 # borderwidth=2 | 475 # borderwidth=2 |
| 488 # ), | 476 # ), |
| 489 """ | 477 """ |
| 490 | 478 |
| 491 fig = go.Figure(data=[data], layout=layout) | 479 fig = go.Figure(data=[data], layout=layout) |
| 492 plotly.offline.plot(fig, filename="output.html", | 480 plotly.offline.plot(fig, filename="output.html", auto_open=False) |
| 493 auto_open=False) | |
| 494 # to be discovered by `from_work_dir` | 481 # to be discovered by `from_work_dir` |
| 495 os.rename('output.html', 'output') | 482 os.rename("output.html", "output") |
| 496 | 483 |
| 497 return 0 | 484 return 0 |
| 498 | 485 |
| 499 elif plot_type == 'learning_curve': | 486 elif plot_type == "learning_curve": |
| 500 input_df = pd.read_csv(infile1, sep='\t', header='infer') | 487 input_df = pd.read_csv(infile1, sep="\t", header="infer") |
| 501 plot_std_err = params['plotting_selection']['plot_std_err'] | 488 plot_std_err = params["plotting_selection"]["plot_std_err"] |
| 502 data1 = go.Scatter( | 489 data1 = go.Scatter( |
| 503 x=input_df['train_sizes_abs'], | 490 x=input_df["train_sizes_abs"], |
| 504 y=input_df['mean_train_scores'], | 491 y=input_df["mean_train_scores"], |
| 505 error_y=dict( | 492 error_y=dict(array=input_df["std_train_scores"]) if plot_std_err else None, |
| 506 array=input_df['std_train_scores'] | 493 mode="lines", |
| 507 ) if plot_std_err else None, | |
| 508 mode='lines', | |
| 509 name="Train Scores", | 494 name="Train Scores", |
| 510 ) | 495 ) |
| 511 data2 = go.Scatter( | 496 data2 = go.Scatter( |
| 512 x=input_df['train_sizes_abs'], | 497 x=input_df["train_sizes_abs"], |
| 513 y=input_df['mean_test_scores'], | 498 y=input_df["mean_test_scores"], |
| 514 error_y=dict( | 499 error_y=dict(array=input_df["std_test_scores"]) if plot_std_err else None, |
| 515 array=input_df['std_test_scores'] | 500 mode="lines", |
| 516 ) if plot_std_err else None, | |
| 517 mode='lines', | |
| 518 name="Test Scores", | 501 name="Test Scores", |
| 519 ) | 502 ) |
| 520 layout = dict( | 503 layout = dict( |
| 521 xaxis=dict( | 504 xaxis=dict(title="No. of samples"), |
| 522 title='No. of samples' | 505 yaxis=dict(title="Performance Score"), |
| 523 ), | |
| 524 yaxis=dict( | |
| 525 title='Performance Score' | |
| 526 ), | |
| 527 # modify these configurations to customize image | 506 # modify these configurations to customize image |
| 528 title=dict( | 507 title=dict( |
| 529 text=title or 'Learning Curve', | 508 text=title or "Learning Curve", |
| 530 x=0.5, | 509 x=0.5, |
| 531 y=0.92, | 510 y=0.92, |
| 532 xanchor='center', | 511 xanchor="center", |
| 533 yanchor='top' | 512 yanchor="top", |
| 534 ), | 513 ), |
| 535 font=dict( | 514 font=dict(family="sans-serif", size=11), |
| 536 family="sans-serif", | |
| 537 size=11 | |
| 538 ), | |
| 539 # control background colors | 515 # control background colors |
| 540 plot_bgcolor='rgba(255,255,255,0)' | 516 plot_bgcolor="rgba(255,255,255,0)", |
| 541 ) | 517 ) |
| 542 """ | 518 """ |
| 543 # legend=dict( | 519 # legend=dict( |
| 544 # x=0.95, | 520 # x=0.95, |
| 545 # y=0, | 521 # y=0, |
| 554 # borderwidth=2 | 530 # borderwidth=2 |
| 555 # ), | 531 # ), |
| 556 """ | 532 """ |
| 557 | 533 |
| 558 fig = go.Figure(data=[data1, data2], layout=layout) | 534 fig = go.Figure(data=[data1, data2], layout=layout) |
| 559 plotly.offline.plot(fig, filename="output.html", | 535 plotly.offline.plot(fig, filename="output.html", auto_open=False) |
| 560 auto_open=False) | |
| 561 # to be discovered by `from_work_dir` | 536 # to be discovered by `from_work_dir` |
| 562 os.rename('output.html', 'output') | 537 os.rename("output.html", "output") |
| 563 | 538 |
| 564 return 0 | 539 return 0 |
| 565 | 540 |
| 566 elif plot_type == 'keras_plot_model': | 541 elif plot_type == "keras_plot_model": |
| 567 with open(model_config, 'r') as f: | 542 with open(model_config, "r") as f: |
| 568 model_str = f.read() | 543 model_str = f.read() |
| 569 model = model_from_json(model_str) | 544 model = model_from_json(model_str) |
| 570 plot_model(model, to_file="output.png") | 545 plot_model(model, to_file="output.png") |
| 571 os.rename('output.png', 'output') | 546 os.rename("output.png", "output") |
| 572 | 547 |
| 573 return 0 | 548 return 0 |
| 574 | 549 |
| 575 elif plot_type == 'classification_confusion_matrix': | 550 elif plot_type == "classification_confusion_matrix": |
| 576 plot_selection = params["plotting_selection"] | 551 plot_selection = params["plotting_selection"] |
| 577 input_true = get_dataframe(true_labels, plot_selection, "header_true", "column_selector_options_true") | 552 input_true = get_dataframe(true_labels, plot_selection, "header_true", "column_selector_options_true") |
| 578 header_predicted = 'infer' if plot_selection["header_predicted"] else None | 553 header_predicted = "infer" if plot_selection["header_predicted"] else None |
| 579 input_predicted = pd.read_csv(predicted_labels, sep='\t', parse_dates=True, header=header_predicted) | 554 input_predicted = pd.read_csv(predicted_labels, sep="\t", parse_dates=True, header=header_predicted) |
| 580 true_classes = input_true.iloc[:, -1].copy() | 555 true_classes = input_true.iloc[:, -1].copy() |
| 581 predicted_classes = input_predicted.iloc[:, -1].copy() | 556 predicted_classes = input_predicted.iloc[:, -1].copy() |
| 582 axis_labels = list(set(true_classes)) | 557 axis_labels = list(set(true_classes)) |
| 583 c_matrix = confusion_matrix(true_classes, predicted_classes) | 558 c_matrix = confusion_matrix(true_classes, predicted_classes) |
| 584 fig, ax = plt.subplots(figsize=(7, 7)) | 559 fig, ax = plt.subplots(figsize=(7, 7)) |
| 585 im = plt.imshow(c_matrix, cmap=plot_color) | 560 im = plt.imshow(c_matrix, cmap=plot_color) |
| 586 for i in range(len(c_matrix)): | 561 for i in range(len(c_matrix)): |
| 587 for j in range(len(c_matrix)): | 562 for j in range(len(c_matrix)): |
| 588 ax.text(j, i, c_matrix[i, j], ha="center", va="center", color="k") | 563 ax.text(j, i, c_matrix[i, j], ha="center", va="center", color="k") |
| 589 ax.set_ylabel('True class labels') | 564 ax.set_ylabel("True class labels") |
| 590 ax.set_xlabel('Predicted class labels') | 565 ax.set_xlabel("Predicted class labels") |
| 591 ax.set_title(title) | 566 ax.set_title(title) |
| 592 ax.set_xticks(axis_labels) | 567 ax.set_xticks(axis_labels) |
| 593 ax.set_yticks(axis_labels) | 568 ax.set_yticks(axis_labels) |
| 594 fig.colorbar(im, ax=ax) | 569 fig.colorbar(im, ax=ax) |
| 595 fig.tight_layout() | 570 fig.tight_layout() |
| 596 plt.savefig("output.png", dpi=125) | 571 plt.savefig("output.png", dpi=125) |
| 597 os.rename('output.png', 'output') | 572 os.rename("output.png", "output") |
| 598 | 573 |
| 599 return 0 | 574 return 0 |
| 600 | 575 |
| 601 # save pdf file to disk | 576 # save pdf file to disk |
| 602 # fig.write_image("image.pdf", format='pdf') | 577 # fig.write_image("image.pdf", format='pdf') |
| 603 # fig.write_image("image.pdf", format='pdf', width=340*2, height=226*2) | 578 # fig.write_image("image.pdf", format='pdf', width=340*2, height=226*2) |
| 604 | 579 |
| 605 | 580 |
| 606 if __name__ == '__main__': | 581 if __name__ == "__main__": |
| 607 aparser = argparse.ArgumentParser() | 582 aparser = argparse.ArgumentParser() |
| 608 aparser.add_argument("-i", "--inputs", dest="inputs", required=True) | 583 aparser.add_argument("-i", "--inputs", dest="inputs", required=True) |
| 609 aparser.add_argument("-e", "--estimator", dest="infile_estimator") | 584 aparser.add_argument("-e", "--estimator", dest="infile_estimator") |
| 610 aparser.add_argument("-X", "--infile1", dest="infile1") | 585 aparser.add_argument("-X", "--infile1", dest="infile1") |
| 611 aparser.add_argument("-y", "--infile2", dest="infile2") | 586 aparser.add_argument("-y", "--infile2", dest="infile2") |
| 621 aparser.add_argument("-pl", "--predicted_labels", dest="predicted_labels") | 596 aparser.add_argument("-pl", "--predicted_labels", dest="predicted_labels") |
| 622 aparser.add_argument("-pc", "--plot_color", dest="plot_color") | 597 aparser.add_argument("-pc", "--plot_color", dest="plot_color") |
| 623 aparser.add_argument("-pt", "--title", dest="title") | 598 aparser.add_argument("-pt", "--title", dest="title") |
| 624 args = aparser.parse_args() | 599 args = aparser.parse_args() |
| 625 | 600 |
| 626 main(args.inputs, args.infile_estimator, args.infile1, args.infile2, | 601 main( |
| 627 args.outfile_result, outfile_object=args.outfile_object, | 602 args.inputs, |
| 628 groups=args.groups, ref_seq=args.ref_seq, intervals=args.intervals, | 603 args.infile_estimator, |
| 629 targets=args.targets, fasta_path=args.fasta_path, | 604 args.infile1, |
| 630 model_config=args.model_config, true_labels=args.true_labels, | 605 args.infile2, |
| 631 predicted_labels=args.predicted_labels, | 606 args.outfile_result, |
| 632 plot_color=args.plot_color, | 607 outfile_object=args.outfile_object, |
| 633 title=args.title) | 608 groups=args.groups, |
| 609 ref_seq=args.ref_seq, | |
| 610 intervals=args.intervals, | |
| 611 targets=args.targets, | |
| 612 fasta_path=args.fasta_path, | |
| 613 model_config=args.model_config, | |
| 614 true_labels=args.true_labels, | |
| 615 predicted_labels=args.predicted_labels, | |
| 616 plot_color=args.plot_color, | |
| 617 title=args.title, | |
| 618 ) |
