#!/usr/bin/env python

import argparse
import os.path as op

import dipy.core.optimize as opt
import dipy.tracking.life as life
import matplotlib
import matplotlib.pyplot as plt
import nibabel as nib
import numpy as np
from dipy.viz import fvtk
from dipy.viz.colormap import line_colors
from mpl_toolkits.axes_grid1 import AxesGrid

# Command-line interface. Every option is a file path. argparse derives each
# destination from the option name (e.g. --candidates -> args.candidates).
# --candidates is mandatory: the script calls op.exists(args.candidates)
# immediately, which raises TypeError when the option is omitted (None), so
# argparse now rejects that case up front with a usage error instead.
parser = argparse.ArgumentParser(
    description='Fit a LiFE model to candidate streamlines and plot the '
                'fiber weights and model prediction errors.'
)
parser.add_argument('--candidates', required=True,
                    help='Candidates selection')
parser.add_argument('--output_life_candidates',
                    help='Output life candidates')
parser.add_argument('--output_life_optimized',
                    help='Output life optimized streamlines')
parser.add_argument('--output_beta_histogram',
                    help='Output beta histogram')
parser.add_argument('--output_error_histograms',
                    help='Output error histograms')
parser.add_argument('--output_spatial_errors',
                    help='Output spatial errors')

args = parser.parse_args()

# Make the inputs used by the rest of the script available at module level:
# diffusion image/array and gradient table (hardi_img, data, gtab), the label
# map and the corpus-callosum mask derived from it (labels_img, labels,
# cc_slice), and the T1 anatomy (t1, t1_data).
if op.exists(args.candidates):
    from dipy.data import (
        fetch_stanford_t1,
        read_stanford_labels,
        read_stanford_t1,
    )

    # Diffusion image, gradient table and label map. Label value 2 is the
    # region later rendered as the corpus-callosum ROI.
    hardi_img, gtab, labels_img = read_stanford_labels()
    labels = labels_img.get_data()
    cc_slice = labels == 2

    # Fetch the Stanford T1 image, then load it.
    fetch_stanford_t1()
    t1 = read_stanford_t1()
    t1_data = t1.get_data()

    data = hardi_img.get_data()
else:
    # NOTE(review): wildcard import from a project-local module not visible
    # here. It is presumably expected to define the same names as the branch
    # above (and perhaps to create the candidates file) -- confirm, because
    # the candidates are read from args.candidates right after this.
    from streamline_tools import *  # noqa: F401,F403

# Load the candidate streamlines in voxel space. Element [0] of the value
# returned by trackvis.read is the list of stream records, and the first
# field of every record holds that streamline's points.
candidate_records = nib.trackvis.read(args.candidates, points_space='voxel')[0]
candidate_sl = [record[0] for record in candidate_records]

# Render the initial candidate group in 3D relative to this brain's anatomy:
# the streamlines, the corpus-callosum ROI contour and two T1 slices.
candidate_streamlines_actor = fvtk.streamtube(candidate_sl,
                                              line_colors(candidate_sl))
cc_ROI_actor = fvtk.contour(cc_slice, levels=[1], colors=[(1., 1., 0.)],
                            opacities=[1.])
vol_actor = fvtk.slicer(t1_data)
vol_actor.display(40, None, None)
# The copy is made after the first display() call, then set to show a slice
# along the third axis instead.
vol_actor2 = vol_actor.copy()
vol_actor2.display(None, None, 35)

ren = fvtk.ren()
for scene_actor in (candidate_streamlines_actor, cc_ROI_actor,
                    vol_actor, vol_actor2):
    fvtk.add(ren, scene_actor)
fvtk.record(ren, n_frames=1, out_path=args.output_life_candidates,
            size=(800, 800))

# Build a LiFE model from the gradient table and fit it to the diffusion data
# along the candidate streamlines. An identity affine is passed, i.e. the
# streamline coordinates are taken to be voxel indices already. The returned
# FiberFit instance stores the data as well as the fitting results (e.g. the
# per-streamline weights in `beta`).
fiber_model = life.FiberModel(gtab)
voxel_affine = np.eye(4)
fiber_fit = fiber_model.fit(data, candidate_sl, affine=voxel_affine)

# Histogram of the fitted streamline weights.
beta_fig, beta_ax = plt.subplots(1)
beta_ax.hist(fiber_fit.beta, bins=100, histtype='step')
beta_ax.set(xlabel='Fiber weights', ylabel='# fibers')
beta_fig.savefig(args.output_beta_histogram)

# Filter out the redundant streamlines (weight not > 0) and render the
# optimized group over the ROI contour and the first T1 slice actor.
# The selection pairs each streamline with its weight via zip instead of
# fancy-indexing np.array(candidate_sl): streamlines have different numbers
# of points, and building an array from such ragged sequences raises
# ValueError on NumPy >= 1.24 (NEP 34). The kept streamlines and their order
# are the same as before.
optimized_sl = [streamline
                for streamline, weight in zip(candidate_sl, fiber_fit.beta)
                if weight > 0]

ren = fvtk.ren()
fvtk.add(ren, fvtk.streamtube(optimized_sl, line_colors(optimized_sl)))
fvtk.add(ren, cc_ROI_actor)
fvtk.add(ren, vol_actor)
fvtk.record(ren, n_frames=1, out_path=args.output_life_optimized,
            size=(800, 800))

# Per-voxel root-mean-squared prediction error of two models of the ROI:
# the fitted LiFE model and a naive "mean signal" model.
#
# Both RMSEs are computed over the diffusion-weighted volumes only, selected
# with the gradient table's b0 mask. This replaces a hard-coded [:, 10:]
# slice, which assumed exactly the first ten volumes are b0s, and applies
# the same volume selection to both models so their errors are comparable.
dw_mask = ~gtab.b0s_mask

# LiFE model: error of its prediction of the measured signal.
model_predict = fiber_fit.predict()
model_error = model_predict - fiber_fit.data
model_rmse = np.sqrt(np.mean(model_error[:, dw_mask] ** 2, -1))

# Baseline: assume every streamline weight is zero. This produces the naive
# prediction of the mean of the signal in each voxel.
beta_baseline = np.zeros(fiber_fit.beta.shape[0])
n_vox = fiber_fit.vox_coords.shape[0]
pred_weighted = np.reshape(opt.spdot(fiber_fit.life_matrix, beta_baseline),
                           (n_vox, np.sum(dw_mask)))
mean_pred = np.empty((n_vox, gtab.bvals.shape[0]))
S0 = fiber_fit.b0_signal
# Since the fitting is done in the demeaned S/S0 domain, add back the mean
# and then multiply by S0 in every voxel; b0 volumes are predicted as S0.
mean_pred[..., gtab.b0s_mask] = S0[:, None]
mean_pred[..., dw_mask] = (pred_weighted
                           + fiber_fit.mean_signal[:, None]) * S0[:, None]
mean_error = mean_pred - fiber_fit.data
mean_rmse = np.sqrt(np.mean(mean_error[:, dw_mask] ** 2, -1))

# Compare the overall distribution of errors between the two models of the
# ROI. The histogram shows, per voxel, mean-model RMSE minus LiFE RMSE, so
# positive values are voxels where LiFE has the lower error. The x label
# names that difference; the previous label 'RMS Error' described neither
# model's RMSE on its own.
fig, ax = plt.subplots(1)
ax.hist(mean_rmse - model_rmse, bins=100, histtype='step')
text_kwargs = dict(horizontalalignment='left', verticalalignment='center',
                   transform=ax.transAxes)
ax.text(0.2, 0.9, 'Median RMSE, mean model: %.2f' % np.median(mean_rmse),
        **text_kwargs)
ax.text(0.2, 0.8, 'Median RMSE, LiFE: %.2f' % np.median(model_rmse),
        **text_kwargs)
ax.set_xlabel('RMSE improvement (mean model - LiFE)')
ax.set_ylabel('# voxels')
fig.savefig(args.output_error_histograms)

# Show the spatial distribution of the two error terms, and of the
# improvement with the model fit, on one slice over the T1 anatomy.


def _roi_volume(values):
    """Return a NaN volume shaped like data.shape[:3] holding `values` at the ROI.

    `values` holds one entry per row of fiber_fit.vox_coords. Voxels outside
    the ROI stay NaN, which matplotlib draws with the colormap's "bad" colour
    (transparent by default), so the T1 underlay shows through.
    """
    volume = np.full(data.shape[:3], np.nan)
    volume[tuple(fiber_fit.vox_coords.T)] = values
    return volume


vol_model = _roi_volume(model_rmse)
vol_mean = _roi_volume(mean_rmse)
vol_improve = _roi_volume(mean_rmse - model_rmse)

# Index along the first array axis of the slice that is displayed.
sl_idx = 49

# Centre the diverging RdBu map on zero improvement with symmetric limits;
# fall back to automatic scaling if the slice holds no ROI voxels.
improve_slice = vol_improve[sl_idx, :, :]
finite = np.isfinite(improve_slice)
improve_lim = np.abs(improve_slice[finite]).max() if finite.any() else None
improve_norm = ({} if improve_lim is None
                else {'vmin': -improve_lim, 'vmax': improve_lim})

fig = plt.figure()
fig.subplots_adjust(left=0.05, right=0.95)
grid = AxesGrid(fig, 111, nrows_ncols=(1, 3), label_mode="1", share_all=True,
                cbar_location="top", cbar_mode="each", cbar_size="10%",
                cbar_pad="5%")
panels = [
    (vol_model, matplotlib.cm.hot, {}),
    (vol_mean, matplotlib.cm.hot, {}),
    (vol_improve, matplotlib.cm.RdBu, improve_norm),
]
for panel_ax, cbar_ax, (volume, cmap, norm_kwargs) in zip(grid, grid.cbar_axes,
                                                          panels):
    panel_ax.matshow(np.rot90(t1_data[sl_idx, :, :]), cmap=matplotlib.cm.bone)
    im = panel_ax.matshow(np.rot90(volume[sl_idx, :, :]), cmap=cmap,
                          **norm_kwargs)
    cbar_ax.colorbar(im)
# Ticks are cleared only after every panel is drawn: the axes are shared
# (share_all=True) and each matshow call resets the tick locators.
for panel_ax in grid:
    panel_ax.set_xticks([])
    panel_ax.set_yticks([])
fig.savefig(args.output_spatial_errors)