comparison linear_fascile_evaluation.py @ 3:0ddfcb3b5ce6 draft

Uploaded
author greg
date Wed, 29 Nov 2017 09:51:57 -0500
parents 84a2e30b5404
children eb03934e044f
comparison
equal deleted inserted replaced
2:5daa2541c7fa 3:0ddfcb3b5ce6
1 #!/usr/bin/env python 1 #!/usr/bin/env python
2 import argparse 2 import argparse
3 import numpy as np 3 import shutil
4 import os.path as op 4
5 import nibabel as nib
6 import dipy.core.optimize as opt 5 import dipy.core.optimize as opt
7 import dipy.tracking.life as life 6 import dipy.tracking.life as life
7 from dipy.data import fetch_stanford_t1, read_stanford_labels, read_stanford_t1
8 from dipy.viz import fvtk
9 from dipy.viz.colormap import line_colors
10
11 import matplotlib
8 import matplotlib.pyplot as plt 12 import matplotlib.pyplot as plt
9 import matplotlib
10 13
11 from dipy.viz.colormap import line_colors
12 from dipy.viz import fvtk
13 from mpl_toolkits.axes_grid1 import AxesGrid 14 from mpl_toolkits.axes_grid1 import AxesGrid
14 from dipy.data import read_stanford_labels, fetch_stanford_t1, read_stanford_t1 15
16 import nibabel as nib
17
18 import numpy as np
15 19
16 parser = argparse.ArgumentParser() 20 parser = argparse.ArgumentParser()
17 parser.add_argument('--candidates', dest='candidates', help='Candidates selection') 21 parser.add_argument('--input', dest='input', help='Track Visualization Header dataset')
18 parser.add_argument('--output_life_candidates', dest='output_life_candidates', help='Output life candidates') 22 parser.add_argument('--output_life_candidates', dest='output_life_candidates', help='Output life candidates')
19 parser.add_argument('--output_life_optimized', dest='output_life_optimized', help='Output life optimized streamlines') 23 parser.add_argument('--output_life_optimized', dest='output_life_optimized', help='Output life optimized streamlines')
20 parser.add_argument('--output_beta_histogram', dest='output_beta_histogram', help='Output beta histogram') 24 parser.add_argument('--output_beta_histogram', dest='output_beta_histogram', help='Output beta histogram')
21 parser.add_argument('--output_error_histograms', dest='output_error_histograms', help='Output error histograms') 25 parser.add_argument('--output_error_histograms', dest='output_error_histograms', help='Output error histograms')
22 parser.add_argument('--output_spatial_errors', dest='output_spatial_errors', help='Output spatial errors') 26 parser.add_argument('--output_spatial_errors', dest='output_spatial_errors', help='Output spatial errors')
31 t1 = read_stanford_t1() 35 t1 = read_stanford_t1()
32 t1_data = t1.get_data() 36 t1_data = t1.get_data()
33 data = hardi_img.get_data() 37 data = hardi_img.get_data()
34 38
35 # Read the candidates from file in voxel space: 39 # Read the candidates from file in voxel space:
36 candidate_sl = [s[0] for s in nib.trackvis.read(args.candidates, points_space='voxel')[0]] 40 candidate_sl = [s[0] for s in nib.trackvis.read(args.input, points_space='voxel')[0]]
37 # Visualize the initial candidate group of streamlines 41 # Visualize the initial candidate group of streamlines
38 # in 3D, relative to the anatomical structure of this brain. 42 # in 3D, relative to the anatomical structure of this brain.
39 candidate_streamlines_actor = fvtk.streamtube(candidate_sl, line_colors(candidate_sl)) 43 candidate_streamlines_actor = fvtk.streamtube(candidate_sl, line_colors(candidate_sl))
40 cc_ROI_actor = fvtk.contour(cc_slice, levels=[1], colors=[(1., 1., 0.)], opacities=[1.]) 44 cc_ROI_actor = fvtk.contour(cc_slice, levels=[1], colors=[(1., 1., 0.)], opacities=[1.])
41 vol_actor = fvtk.slicer(t1_data) 45 vol_actor = fvtk.slicer(t1_data)
46 ren = fvtk.ren() 50 ren = fvtk.ren()
47 fvtk.add(ren, candidate_streamlines_actor) 51 fvtk.add(ren, candidate_streamlines_actor)
48 fvtk.add(ren, cc_ROI_actor) 52 fvtk.add(ren, cc_ROI_actor)
49 fvtk.add(ren, vol_actor) 53 fvtk.add(ren, vol_actor)
50 fvtk.add(ren, vol_actor2) 54 fvtk.add(ren, vol_actor2)
51 fvtk.record(ren, n_frames=1, out_path=args.output_life_candidates, size=(800, 800)) 55 fvtk.record(ren, n_frames=1, out_path="life_candidates.png", size=(800, 800))
56 shutil.move("life_candidates.png", args.output_life_candidates)
52 # Initialize a LiFE model. 57 # Initialize a LiFE model.
53 fiber_model = life.FiberModel(gtab) 58 fiber_model = life.FiberModel(gtab)
54 # Fit the model, producing a FiberFit class instance, 59 # Fit the model, producing a FiberFit class instance,
55 # that stores the data, as well as the results of the 60 # that stores the data, as well as the results of the
56 # fitting procedure. 61 # fitting procedure.
57 fiber_fit = fiber_model.fit(data, candidate_sl, affine=np.eye(4)) 62 fiber_fit = fiber_model.fit(data, candidate_sl, affine=np.eye(4))
58 fig, ax = plt.subplots(1) 63 fig, ax = plt.subplots(1)
59 ax.hist(fiber_fit.beta, bins=100, histtype='step') 64 ax.hist(fiber_fit.beta, bins=100, histtype='step')
60 ax.set_xlabel('Fiber weights') 65 ax.set_xlabel('Fiber weights')
61 ax.set_ylabel('# fibers') 66 ax.set_ylabel('# fibers')
62 fig.savefig(args.output_beta_histogram) 67 fig.savefig("beta_histogram.png")
68 shutil.move("beta_histogram.png", args.output_beta_histogram)
63 # Filter out these redundant streamlines and 69 # Filter out these redundant streamlines and
64 # generate an optimized group of streamlines. 70 # generate an optimized group of streamlines.
65 optimized_sl = list(np.array(candidate_sl)[np.where(fiber_fit.beta>0)[0]]) 71 optimized_sl = list(np.array(candidate_sl)[np.where(fiber_fit.beta > 0)[0]])
66 ren = fvtk.ren() 72 ren = fvtk.ren()
67 fvtk.add(ren, fvtk.streamtube(optimized_sl, line_colors(optimized_sl))) 73 fvtk.add(ren, fvtk.streamtube(optimized_sl, line_colors(optimized_sl)))
68 fvtk.add(ren, cc_ROI_actor) 74 fvtk.add(ren, cc_ROI_actor)
69 fvtk.add(ren, vol_actor) 75 fvtk.add(ren, vol_actor)
70 fvtk.record(ren, n_frames=1, out_path=args.output_life_optimized, size=(800, 800)) 76 fvtk.record(ren, n_frames=1, out_path="optimized.png", size=(800, 800))
77 shutil.move("optimized.png", args.output_life_optimized)
71 model_predict = fiber_fit.predict() 78 model_predict = fiber_fit.predict()
72 # Focus on the error in prediction of the diffusion-weighted 79 # Focus on the error in prediction of the diffusion-weighted
73 # data, and calculate the root of the mean squared error. 80 # data, and calculate the root of the mean squared error.
74 model_error = model_predict - fiber_fit.data 81 model_error = model_predict - fiber_fit.data
75 model_rmse = np.sqrt(np.mean(model_error[:, 10:] ** 2, -1)) 82 model_rmse = np.sqrt(np.mean(model_error[:, 10:] ** 2, -1))
88 mean_rmse = np.sqrt(np.mean(mean_error ** 2, -1)) 95 mean_rmse = np.sqrt(np.mean(mean_error ** 2, -1))
89 # Compare the overall distribution of errors 96 # Compare the overall distribution of errors
90 # between these two alternative models of the ROI. 97 # between these two alternative models of the ROI.
91 fig, ax = plt.subplots(1) 98 fig, ax = plt.subplots(1)
92 ax.hist(mean_rmse - model_rmse, bins=100, histtype='step') 99 ax.hist(mean_rmse - model_rmse, bins=100, histtype='step')
93 ax.text(0.2, 0.9,'Median RMSE, mean model: %.2f' % np.median(mean_rmse), horizontalalignment='left', verticalalignment='center', transform=ax.transAxes) 100 ax.text(0.2, 0.9, 'Median RMSE, mean model: %.2f' % np.median(mean_rmse), horizontalalignment='left', verticalalignment='center', transform=ax.transAxes)
94 ax.text(0.2, 0.8,'Median RMSE, LiFE: %.2f' % np.median(model_rmse), horizontalalignment='left', verticalalignment='center', transform=ax.transAxes) 101 ax.text(0.2, 0.8, 'Median RMSE, LiFE: %.2f' % np.median(model_rmse), horizontalalignment='left', verticalalignment='center', transform=ax.transAxes)
95 ax.set_xlabel('RMS Error') 102 ax.set_xlabel('RMS Error')
96 ax.set_ylabel('# voxels') 103 ax.set_ylabel('# voxels')
97 fig.savefig(args.output_error_histograms) 104 fig.savefig("error_histograms.png")
105 shutil.move("error_histograms.png", args.output_error_histograms)
98 # Show the spatial distribution of the two error terms, 106 # Show the spatial distribution of the two error terms,
99 # and of the improvement with the model fit: 107 # and of the improvement with the model fit:
100 vol_model = np.ones(data.shape[:3]) * np.nan 108 vol_model = np.ones(data.shape[:3]) * np.nan
101 vol_model[fiber_fit.vox_coords[:, 0], fiber_fit.vox_coords[:, 1], fiber_fit.vox_coords[:, 2]] = model_rmse 109 vol_model[fiber_fit.vox_coords[:, 0], fiber_fit.vox_coords[:, 1], fiber_fit.vox_coords[:, 2]] = model_rmse
102 vol_mean = np.ones(data.shape[:3]) * np.nan 110 vol_mean = np.ones(data.shape[:3]) * np.nan
104 vol_improve = np.ones(data.shape[:3]) * np.nan 112 vol_improve = np.ones(data.shape[:3]) * np.nan
105 vol_improve[fiber_fit.vox_coords[:, 0], fiber_fit.vox_coords[:, 1], fiber_fit.vox_coords[:, 2]] = mean_rmse - model_rmse 113 vol_improve[fiber_fit.vox_coords[:, 0], fiber_fit.vox_coords[:, 1], fiber_fit.vox_coords[:, 2]] = mean_rmse - model_rmse
106 sl_idx = 49 114 sl_idx = 49
107 fig = plt.figure() 115 fig = plt.figure()
108 fig.subplots_adjust(left=0.05, right=0.95) 116 fig.subplots_adjust(left=0.05, right=0.95)
109 ax = AxesGrid(fig, 111, nrows_ncols = (1, 3), label_mode = "1", share_all = True, cbar_location="top", cbar_mode="each", cbar_size="10%", cbar_pad="5%") 117 ax = AxesGrid(fig, 111, nrows_ncols=(1, 3), label_mode="1", share_all=True, cbar_location="top", cbar_mode="each", cbar_size="10%", cbar_pad="5%")
110 ax[0].matshow(np.rot90(t1_data[sl_idx, :, :]), cmap=matplotlib.cm.bone) 118 ax[0].matshow(np.rot90(t1_data[sl_idx, :, :]), cmap=matplotlib.cm.bone)
111 im = ax[0].matshow(np.rot90(vol_model[sl_idx, :, :]), cmap=matplotlib.cm.hot) 119 im = ax[0].matshow(np.rot90(vol_model[sl_idx, :, :]), cmap=matplotlib.cm.hot)
112 ax.cbar_axes[0].colorbar(im) 120 ax.cbar_axes[0].colorbar(im)
113 ax[1].matshow(np.rot90(t1_data[sl_idx, :, :]), cmap=matplotlib.cm.bone) 121 ax[1].matshow(np.rot90(t1_data[sl_idx, :, :]), cmap=matplotlib.cm.bone)
114 im = ax[1].matshow(np.rot90(vol_mean[sl_idx, :, :]), cmap=matplotlib.cm.hot) 122 im = ax[1].matshow(np.rot90(vol_mean[sl_idx, :, :]), cmap=matplotlib.cm.hot)
117 im = ax[2].matshow(np.rot90(vol_improve[sl_idx, :, :]), cmap=matplotlib.cm.RdBu) 125 im = ax[2].matshow(np.rot90(vol_improve[sl_idx, :, :]), cmap=matplotlib.cm.RdBu)
118 ax.cbar_axes[2].colorbar(im) 126 ax.cbar_axes[2].colorbar(im)
119 for lax in ax: 127 for lax in ax:
120 lax.set_xticks([]) 128 lax.set_xticks([])
121 lax.set_yticks([]) 129 lax.set_yticks([])
122 fig.savefig(args.output_spatial_errors) 130 fig.savefig("spatial_errors.png")
131 shutil.move("spatial_errors.png", args.output_spatial_errors)