Mercurial > repos > greg > linear_fascile_evaluation
comparison linear_fascile_evaluation.py @ 3:0ddfcb3b5ce6 draft
Uploaded
author | greg |
---|---|
date | Wed, 29 Nov 2017 09:51:57 -0500 |
parents | 84a2e30b5404 |
children | eb03934e044f |
comparison
legend: equal | deleted | inserted | replaced
2:5daa2541c7fa | 3:0ddfcb3b5ce6 |
---|---|
1 #!/usr/bin/env python | 1 #!/usr/bin/env python |
2 import argparse | 2 import argparse |
3 import numpy as np | 3 import shutil |
4 import os.path as op | 4 |
5 import nibabel as nib | |
6 import dipy.core.optimize as opt | 5 import dipy.core.optimize as opt |
7 import dipy.tracking.life as life | 6 import dipy.tracking.life as life |
7 from dipy.data import fetch_stanford_t1, read_stanford_labels, read_stanford_t1 | |
8 from dipy.viz import fvtk | |
9 from dipy.viz.colormap import line_colors | |
10 | |
11 import matplotlib | |
8 import matplotlib.pyplot as plt | 12 import matplotlib.pyplot as plt |
9 import matplotlib | |
10 | 13 |
11 from dipy.viz.colormap import line_colors | |
12 from dipy.viz import fvtk | |
13 from mpl_toolkits.axes_grid1 import AxesGrid | 14 from mpl_toolkits.axes_grid1 import AxesGrid |
14 from dipy.data import read_stanford_labels, fetch_stanford_t1, read_stanford_t1 | 15 |
16 import nibabel as nib | |
17 | |
18 import numpy as np | |
15 | 19 |
16 parser = argparse.ArgumentParser() | 20 parser = argparse.ArgumentParser() |
17 parser.add_argument('--candidates', dest='candidates', help='Candidates selection') | 21 parser.add_argument('--input', dest='input', help='Track Visualization Header dataset') |
18 parser.add_argument('--output_life_candidates', dest='output_life_candidates', help='Output life candidates') | 22 parser.add_argument('--output_life_candidates', dest='output_life_candidates', help='Output life candidates') |
19 parser.add_argument('--output_life_optimized', dest='output_life_optimized', help='Output life optimized streamlines') | 23 parser.add_argument('--output_life_optimized', dest='output_life_optimized', help='Output life optimized streamlines') |
20 parser.add_argument('--output_beta_histogram', dest='output_beta_histogram', help='Output beta histogram') | 24 parser.add_argument('--output_beta_histogram', dest='output_beta_histogram', help='Output beta histogram') |
21 parser.add_argument('--output_error_histograms', dest='output_error_histograms', help='Output error histograms') | 25 parser.add_argument('--output_error_histograms', dest='output_error_histograms', help='Output error histograms') |
22 parser.add_argument('--output_spatial_errors', dest='output_spatial_errors', help='Output spatial errors') | 26 parser.add_argument('--output_spatial_errors', dest='output_spatial_errors', help='Output spatial errors') |
31 t1 = read_stanford_t1() | 35 t1 = read_stanford_t1() |
32 t1_data = t1.get_data() | 36 t1_data = t1.get_data() |
33 data = hardi_img.get_data() | 37 data = hardi_img.get_data() |
34 | 38 |
35 # Read the candidates from file in voxel space: | 39 # Read the candidates from file in voxel space: |
36 candidate_sl = [s[0] for s in nib.trackvis.read(args.candidates, points_space='voxel')[0]] | 40 candidate_sl = [s[0] for s in nib.trackvis.read(args.input, points_space='voxel')[0]] |
37 # Visualize the initial candidate group of streamlines | 41 # Visualize the initial candidate group of streamlines |
38 # in 3D, relative to the anatomical structure of this brain. | 42 # in 3D, relative to the anatomical structure of this brain. |
39 candidate_streamlines_actor = fvtk.streamtube(candidate_sl, line_colors(candidate_sl)) | 43 candidate_streamlines_actor = fvtk.streamtube(candidate_sl, line_colors(candidate_sl)) |
40 cc_ROI_actor = fvtk.contour(cc_slice, levels=[1], colors=[(1., 1., 0.)], opacities=[1.]) | 44 cc_ROI_actor = fvtk.contour(cc_slice, levels=[1], colors=[(1., 1., 0.)], opacities=[1.]) |
41 vol_actor = fvtk.slicer(t1_data) | 45 vol_actor = fvtk.slicer(t1_data) |
46 ren = fvtk.ren() | 50 ren = fvtk.ren() |
47 fvtk.add(ren, candidate_streamlines_actor) | 51 fvtk.add(ren, candidate_streamlines_actor) |
48 fvtk.add(ren, cc_ROI_actor) | 52 fvtk.add(ren, cc_ROI_actor) |
49 fvtk.add(ren, vol_actor) | 53 fvtk.add(ren, vol_actor) |
50 fvtk.add(ren, vol_actor2) | 54 fvtk.add(ren, vol_actor2) |
51 fvtk.record(ren, n_frames=1, out_path=args.output_life_candidates, size=(800, 800)) | 55 fvtk.record(ren, n_frames=1, out_path="life_candidates.png", size=(800, 800)) |
56 shutil.move("life_candidates.png", args.output_life_candidates) | |
52 # Initialize a LiFE model. | 57 # Initialize a LiFE model. |
53 fiber_model = life.FiberModel(gtab) | 58 fiber_model = life.FiberModel(gtab) |
54 # Fit the model, producing a FiberFit class instance, | 59 # Fit the model, producing a FiberFit class instance, |
55 # that stores the data, as well as the results of the | 60 # that stores the data, as well as the results of the |
56 # fitting procedure. | 61 # fitting procedure. |
57 fiber_fit = fiber_model.fit(data, candidate_sl, affine=np.eye(4)) | 62 fiber_fit = fiber_model.fit(data, candidate_sl, affine=np.eye(4)) |
58 fig, ax = plt.subplots(1) | 63 fig, ax = plt.subplots(1) |
59 ax.hist(fiber_fit.beta, bins=100, histtype='step') | 64 ax.hist(fiber_fit.beta, bins=100, histtype='step') |
60 ax.set_xlabel('Fiber weights') | 65 ax.set_xlabel('Fiber weights') |
61 ax.set_ylabel('# fibers') | 66 ax.set_ylabel('# fibers') |
62 fig.savefig(args.output_beta_histogram) | 67 fig.savefig("beta_histogram.png") |
68 shutil.move("beta_histogram.png", args.output_beta_histogram) | |
63 # Filter out these redundant streamlines and | 69 # Filter out these redundant streamlines and |
64 # generate an optimized group of streamlines. | 70 # generate an optimized group of streamlines. |
65 optimized_sl = list(np.array(candidate_sl)[np.where(fiber_fit.beta>0)[0]]) | 71 optimized_sl = list(np.array(candidate_sl)[np.where(fiber_fit.beta > 0)[0]]) |
66 ren = fvtk.ren() | 72 ren = fvtk.ren() |
67 fvtk.add(ren, fvtk.streamtube(optimized_sl, line_colors(optimized_sl))) | 73 fvtk.add(ren, fvtk.streamtube(optimized_sl, line_colors(optimized_sl))) |
68 fvtk.add(ren, cc_ROI_actor) | 74 fvtk.add(ren, cc_ROI_actor) |
69 fvtk.add(ren, vol_actor) | 75 fvtk.add(ren, vol_actor) |
70 fvtk.record(ren, n_frames=1, out_path=args.output_life_optimized, size=(800, 800)) | 76 fvtk.record(ren, n_frames=1, out_path="optimized.png", size=(800, 800)) |
77 shutil.move("optimized.png", args.output_life_optimized) | |
71 model_predict = fiber_fit.predict() | 78 model_predict = fiber_fit.predict() |
72 # Focus on the error in prediction of the diffusion-weighted | 79 # Focus on the error in prediction of the diffusion-weighted |
73 # data, and calculate the root of the mean squared error. | 80 # data, and calculate the root of the mean squared error. |
74 model_error = model_predict - fiber_fit.data | 81 model_error = model_predict - fiber_fit.data |
75 model_rmse = np.sqrt(np.mean(model_error[:, 10:] ** 2, -1)) | 82 model_rmse = np.sqrt(np.mean(model_error[:, 10:] ** 2, -1)) |
88 mean_rmse = np.sqrt(np.mean(mean_error ** 2, -1)) | 95 mean_rmse = np.sqrt(np.mean(mean_error ** 2, -1)) |
89 # Compare the overall distribution of errors | 96 # Compare the overall distribution of errors |
90 # between these two alternative models of the ROI. | 97 # between these two alternative models of the ROI. |
91 fig, ax = plt.subplots(1) | 98 fig, ax = plt.subplots(1) |
92 ax.hist(mean_rmse - model_rmse, bins=100, histtype='step') | 99 ax.hist(mean_rmse - model_rmse, bins=100, histtype='step') |
93 ax.text(0.2, 0.9,'Median RMSE, mean model: %.2f' % np.median(mean_rmse), horizontalalignment='left', verticalalignment='center', transform=ax.transAxes) | 100 ax.text(0.2, 0.9, 'Median RMSE, mean model: %.2f' % np.median(mean_rmse), horizontalalignment='left', verticalalignment='center', transform=ax.transAxes) |
94 ax.text(0.2, 0.8,'Median RMSE, LiFE: %.2f' % np.median(model_rmse), horizontalalignment='left', verticalalignment='center', transform=ax.transAxes) | 101 ax.text(0.2, 0.8, 'Median RMSE, LiFE: %.2f' % np.median(model_rmse), horizontalalignment='left', verticalalignment='center', transform=ax.transAxes) |
95 ax.set_xlabel('RMS Error') | 102 ax.set_xlabel('RMS Error') |
96 ax.set_ylabel('# voxels') | 103 ax.set_ylabel('# voxels') |
97 fig.savefig(args.output_error_histograms) | 104 fig.savefig("error_histograms.png") |
105 shutil.move("error_histograms.png", args.output_error_histograms) | |
98 # Show the spatial distribution of the two error terms, | 106 # Show the spatial distribution of the two error terms, |
99 # and of the improvement with the model fit: | 107 # and of the improvement with the model fit: |
100 vol_model = np.ones(data.shape[:3]) * np.nan | 108 vol_model = np.ones(data.shape[:3]) * np.nan |
101 vol_model[fiber_fit.vox_coords[:, 0], fiber_fit.vox_coords[:, 1], fiber_fit.vox_coords[:, 2]] = model_rmse | 109 vol_model[fiber_fit.vox_coords[:, 0], fiber_fit.vox_coords[:, 1], fiber_fit.vox_coords[:, 2]] = model_rmse |
102 vol_mean = np.ones(data.shape[:3]) * np.nan | 110 vol_mean = np.ones(data.shape[:3]) * np.nan |
104 vol_improve = np.ones(data.shape[:3]) * np.nan | 112 vol_improve = np.ones(data.shape[:3]) * np.nan |
105 vol_improve[fiber_fit.vox_coords[:, 0], fiber_fit.vox_coords[:, 1], fiber_fit.vox_coords[:, 2]] = mean_rmse - model_rmse | 113 vol_improve[fiber_fit.vox_coords[:, 0], fiber_fit.vox_coords[:, 1], fiber_fit.vox_coords[:, 2]] = mean_rmse - model_rmse |
106 sl_idx = 49 | 114 sl_idx = 49 |
107 fig = plt.figure() | 115 fig = plt.figure() |
108 fig.subplots_adjust(left=0.05, right=0.95) | 116 fig.subplots_adjust(left=0.05, right=0.95) |
109 ax = AxesGrid(fig, 111, nrows_ncols = (1, 3), label_mode = "1", share_all = True, cbar_location="top", cbar_mode="each", cbar_size="10%", cbar_pad="5%") | 117 ax = AxesGrid(fig, 111, nrows_ncols=(1, 3), label_mode="1", share_all=True, cbar_location="top", cbar_mode="each", cbar_size="10%", cbar_pad="5%") |
110 ax[0].matshow(np.rot90(t1_data[sl_idx, :, :]), cmap=matplotlib.cm.bone) | 118 ax[0].matshow(np.rot90(t1_data[sl_idx, :, :]), cmap=matplotlib.cm.bone) |
111 im = ax[0].matshow(np.rot90(vol_model[sl_idx, :, :]), cmap=matplotlib.cm.hot) | 119 im = ax[0].matshow(np.rot90(vol_model[sl_idx, :, :]), cmap=matplotlib.cm.hot) |
112 ax.cbar_axes[0].colorbar(im) | 120 ax.cbar_axes[0].colorbar(im) |
113 ax[1].matshow(np.rot90(t1_data[sl_idx, :, :]), cmap=matplotlib.cm.bone) | 121 ax[1].matshow(np.rot90(t1_data[sl_idx, :, :]), cmap=matplotlib.cm.bone) |
114 im = ax[1].matshow(np.rot90(vol_mean[sl_idx, :, :]), cmap=matplotlib.cm.hot) | 122 im = ax[1].matshow(np.rot90(vol_mean[sl_idx, :, :]), cmap=matplotlib.cm.hot) |
117 im = ax[2].matshow(np.rot90(vol_improve[sl_idx, :, :]), cmap=matplotlib.cm.RdBu) | 125 im = ax[2].matshow(np.rot90(vol_improve[sl_idx, :, :]), cmap=matplotlib.cm.RdBu) |
118 ax.cbar_axes[2].colorbar(im) | 126 ax.cbar_axes[2].colorbar(im) |
119 for lax in ax: | 127 for lax in ax: |
120 lax.set_xticks([]) | 128 lax.set_xticks([]) |
121 lax.set_yticks([]) | 129 lax.set_yticks([]) |
122 fig.savefig(args.output_spatial_errors) | 130 fig.savefig("spatial_errors.png") |
131 shutil.move("spatial_errors.png", args.output_spatial_errors) |