Mercurial > repos > greg > linear_fascile_evaluation
comparison linear_fascile_evaluation.py @ 3:0ddfcb3b5ce6 draft
Uploaded
| author | greg |
|---|---|
| date | Wed, 29 Nov 2017 09:51:57 -0500 |
| parents | 84a2e30b5404 |
| children | eb03934e044f |
comparison
equal
deleted
inserted
replaced
| 2:5daa2541c7fa | 3:0ddfcb3b5ce6 |
|---|---|
| 1 #!/usr/bin/env python | 1 #!/usr/bin/env python |
| 2 import argparse | 2 import argparse |
| 3 import numpy as np | 3 import shutil |
| 4 import os.path as op | 4 |
| 5 import nibabel as nib | |
| 6 import dipy.core.optimize as opt | 5 import dipy.core.optimize as opt |
| 7 import dipy.tracking.life as life | 6 import dipy.tracking.life as life |
| 7 from dipy.data import fetch_stanford_t1, read_stanford_labels, read_stanford_t1 | |
| 8 from dipy.viz import fvtk | |
| 9 from dipy.viz.colormap import line_colors | |
| 10 | |
| 11 import matplotlib | |
| 8 import matplotlib.pyplot as plt | 12 import matplotlib.pyplot as plt |
| 9 import matplotlib | |
| 10 | 13 |
| 11 from dipy.viz.colormap import line_colors | |
| 12 from dipy.viz import fvtk | |
| 13 from mpl_toolkits.axes_grid1 import AxesGrid | 14 from mpl_toolkits.axes_grid1 import AxesGrid |
| 14 from dipy.data import read_stanford_labels, fetch_stanford_t1, read_stanford_t1 | 15 |
| 16 import nibabel as nib | |
| 17 | |
| 18 import numpy as np | |
| 15 | 19 |
| 16 parser = argparse.ArgumentParser() | 20 parser = argparse.ArgumentParser() |
| 17 parser.add_argument('--candidates', dest='candidates', help='Candidates selection') | 21 parser.add_argument('--input', dest='input', help='Track Visualization Header dataset') |
| 18 parser.add_argument('--output_life_candidates', dest='output_life_candidates', help='Output life candidates') | 22 parser.add_argument('--output_life_candidates', dest='output_life_candidates', help='Output life candidates') |
| 19 parser.add_argument('--output_life_optimized', dest='output_life_optimized', help='Output life optimized streamlines') | 23 parser.add_argument('--output_life_optimized', dest='output_life_optimized', help='Output life optimized streamlines') |
| 20 parser.add_argument('--output_beta_histogram', dest='output_beta_histogram', help='Output beta histogram') | 24 parser.add_argument('--output_beta_histogram', dest='output_beta_histogram', help='Output beta histogram') |
| 21 parser.add_argument('--output_error_histograms', dest='output_error_histograms', help='Output error histograms') | 25 parser.add_argument('--output_error_histograms', dest='output_error_histograms', help='Output error histograms') |
| 22 parser.add_argument('--output_spatial_errors', dest='output_spatial_errors', help='Output spatial errors') | 26 parser.add_argument('--output_spatial_errors', dest='output_spatial_errors', help='Output spatial errors') |
| 31 t1 = read_stanford_t1() | 35 t1 = read_stanford_t1() |
| 32 t1_data = t1.get_data() | 36 t1_data = t1.get_data() |
| 33 data = hardi_img.get_data() | 37 data = hardi_img.get_data() |
| 34 | 38 |
| 35 # Read the candidates from file in voxel space: | 39 # Read the candidates from file in voxel space: |
| 36 candidate_sl = [s[0] for s in nib.trackvis.read(args.candidates, points_space='voxel')[0]] | 40 candidate_sl = [s[0] for s in nib.trackvis.read(args.input, points_space='voxel')[0]] |
| 37 # Visualize the initial candidate group of streamlines | 41 # Visualize the initial candidate group of streamlines |
| 38 # in 3D, relative to the anatomical structure of this brain. | 42 # in 3D, relative to the anatomical structure of this brain. |
| 39 candidate_streamlines_actor = fvtk.streamtube(candidate_sl, line_colors(candidate_sl)) | 43 candidate_streamlines_actor = fvtk.streamtube(candidate_sl, line_colors(candidate_sl)) |
| 40 cc_ROI_actor = fvtk.contour(cc_slice, levels=[1], colors=[(1., 1., 0.)], opacities=[1.]) | 44 cc_ROI_actor = fvtk.contour(cc_slice, levels=[1], colors=[(1., 1., 0.)], opacities=[1.]) |
| 41 vol_actor = fvtk.slicer(t1_data) | 45 vol_actor = fvtk.slicer(t1_data) |
| 46 ren = fvtk.ren() | 50 ren = fvtk.ren() |
| 47 fvtk.add(ren, candidate_streamlines_actor) | 51 fvtk.add(ren, candidate_streamlines_actor) |
| 48 fvtk.add(ren, cc_ROI_actor) | 52 fvtk.add(ren, cc_ROI_actor) |
| 49 fvtk.add(ren, vol_actor) | 53 fvtk.add(ren, vol_actor) |
| 50 fvtk.add(ren, vol_actor2) | 54 fvtk.add(ren, vol_actor2) |
| 51 fvtk.record(ren, n_frames=1, out_path=args.output_life_candidates, size=(800, 800)) | 55 fvtk.record(ren, n_frames=1, out_path="life_candidates.png", size=(800, 800)) |
| 56 shutil.move("life_candidates.png", args.output_life_candidates) | |
| 52 # Initialize a LiFE model. | 57 # Initialize a LiFE model. |
| 53 fiber_model = life.FiberModel(gtab) | 58 fiber_model = life.FiberModel(gtab) |
| 54 # Fit the model, producing a FiberFit class instance, | 59 # Fit the model, producing a FiberFit class instance, |
| 55 # that stores the data, as well as the results of the | 60 # that stores the data, as well as the results of the |
| 56 # fitting procedure. | 61 # fitting procedure. |
| 57 fiber_fit = fiber_model.fit(data, candidate_sl, affine=np.eye(4)) | 62 fiber_fit = fiber_model.fit(data, candidate_sl, affine=np.eye(4)) |
| 58 fig, ax = plt.subplots(1) | 63 fig, ax = plt.subplots(1) |
| 59 ax.hist(fiber_fit.beta, bins=100, histtype='step') | 64 ax.hist(fiber_fit.beta, bins=100, histtype='step') |
| 60 ax.set_xlabel('Fiber weights') | 65 ax.set_xlabel('Fiber weights') |
| 61 ax.set_ylabel('# fibers') | 66 ax.set_ylabel('# fibers') |
| 62 fig.savefig(args.output_beta_histogram) | 67 fig.savefig("beta_histogram.png") |
| 68 shutil.move("beta_histogram.png", args.output_beta_histogram) | |
| 63 # Filter out these redundant streamlines and | 69 # Filter out these redundant streamlines and |
| 64 # generate an optimized group of streamlines. | 70 # generate an optimized group of streamlines. |
| 65 optimized_sl = list(np.array(candidate_sl)[np.where(fiber_fit.beta>0)[0]]) | 71 optimized_sl = list(np.array(candidate_sl)[np.where(fiber_fit.beta > 0)[0]]) |
| 66 ren = fvtk.ren() | 72 ren = fvtk.ren() |
| 67 fvtk.add(ren, fvtk.streamtube(optimized_sl, line_colors(optimized_sl))) | 73 fvtk.add(ren, fvtk.streamtube(optimized_sl, line_colors(optimized_sl))) |
| 68 fvtk.add(ren, cc_ROI_actor) | 74 fvtk.add(ren, cc_ROI_actor) |
| 69 fvtk.add(ren, vol_actor) | 75 fvtk.add(ren, vol_actor) |
| 70 fvtk.record(ren, n_frames=1, out_path=args.output_life_optimized, size=(800, 800)) | 76 fvtk.record(ren, n_frames=1, out_path="optimized.png", size=(800, 800)) |
| 77 shutil.move("optimized.png", args.output_life_optimized) | |
| 71 model_predict = fiber_fit.predict() | 78 model_predict = fiber_fit.predict() |
| 72 # Focus on the error in prediction of the diffusion-weighted | 79 # Focus on the error in prediction of the diffusion-weighted |
| 73 # data, and calculate the root of the mean squared error. | 80 # data, and calculate the root of the mean squared error. |
| 74 model_error = model_predict - fiber_fit.data | 81 model_error = model_predict - fiber_fit.data |
| 75 model_rmse = np.sqrt(np.mean(model_error[:, 10:] ** 2, -1)) | 82 model_rmse = np.sqrt(np.mean(model_error[:, 10:] ** 2, -1)) |
| 88 mean_rmse = np.sqrt(np.mean(mean_error ** 2, -1)) | 95 mean_rmse = np.sqrt(np.mean(mean_error ** 2, -1)) |
| 89 # Compare the overall distribution of errors | 96 # Compare the overall distribution of errors |
| 90 # between these two alternative models of the ROI. | 97 # between these two alternative models of the ROI. |
| 91 fig, ax = plt.subplots(1) | 98 fig, ax = plt.subplots(1) |
| 92 ax.hist(mean_rmse - model_rmse, bins=100, histtype='step') | 99 ax.hist(mean_rmse - model_rmse, bins=100, histtype='step') |
| 93 ax.text(0.2, 0.9,'Median RMSE, mean model: %.2f' % np.median(mean_rmse), horizontalalignment='left', verticalalignment='center', transform=ax.transAxes) | 100 ax.text(0.2, 0.9, 'Median RMSE, mean model: %.2f' % np.median(mean_rmse), horizontalalignment='left', verticalalignment='center', transform=ax.transAxes) |
| 94 ax.text(0.2, 0.8,'Median RMSE, LiFE: %.2f' % np.median(model_rmse), horizontalalignment='left', verticalalignment='center', transform=ax.transAxes) | 101 ax.text(0.2, 0.8, 'Median RMSE, LiFE: %.2f' % np.median(model_rmse), horizontalalignment='left', verticalalignment='center', transform=ax.transAxes) |
| 95 ax.set_xlabel('RMS Error') | 102 ax.set_xlabel('RMS Error') |
| 96 ax.set_ylabel('# voxels') | 103 ax.set_ylabel('# voxels') |
| 97 fig.savefig(args.output_error_histograms) | 104 fig.savefig("error_histograms.png") |
| 105 shutil.move("error_histograms.png", args.output_error_histograms) | |
| 98 # Show the spatial distribution of the two error terms, | 106 # Show the spatial distribution of the two error terms, |
| 99 # and of the improvement with the model fit: | 107 # and of the improvement with the model fit: |
| 100 vol_model = np.ones(data.shape[:3]) * np.nan | 108 vol_model = np.ones(data.shape[:3]) * np.nan |
| 101 vol_model[fiber_fit.vox_coords[:, 0], fiber_fit.vox_coords[:, 1], fiber_fit.vox_coords[:, 2]] = model_rmse | 109 vol_model[fiber_fit.vox_coords[:, 0], fiber_fit.vox_coords[:, 1], fiber_fit.vox_coords[:, 2]] = model_rmse |
| 102 vol_mean = np.ones(data.shape[:3]) * np.nan | 110 vol_mean = np.ones(data.shape[:3]) * np.nan |
| 104 vol_improve = np.ones(data.shape[:3]) * np.nan | 112 vol_improve = np.ones(data.shape[:3]) * np.nan |
| 105 vol_improve[fiber_fit.vox_coords[:, 0], fiber_fit.vox_coords[:, 1], fiber_fit.vox_coords[:, 2]] = mean_rmse - model_rmse | 113 vol_improve[fiber_fit.vox_coords[:, 0], fiber_fit.vox_coords[:, 1], fiber_fit.vox_coords[:, 2]] = mean_rmse - model_rmse |
| 106 sl_idx = 49 | 114 sl_idx = 49 |
| 107 fig = plt.figure() | 115 fig = plt.figure() |
| 108 fig.subplots_adjust(left=0.05, right=0.95) | 116 fig.subplots_adjust(left=0.05, right=0.95) |
| 109 ax = AxesGrid(fig, 111, nrows_ncols = (1, 3), label_mode = "1", share_all = True, cbar_location="top", cbar_mode="each", cbar_size="10%", cbar_pad="5%") | 117 ax = AxesGrid(fig, 111, nrows_ncols=(1, 3), label_mode="1", share_all=True, cbar_location="top", cbar_mode="each", cbar_size="10%", cbar_pad="5%") |
| 110 ax[0].matshow(np.rot90(t1_data[sl_idx, :, :]), cmap=matplotlib.cm.bone) | 118 ax[0].matshow(np.rot90(t1_data[sl_idx, :, :]), cmap=matplotlib.cm.bone) |
| 111 im = ax[0].matshow(np.rot90(vol_model[sl_idx, :, :]), cmap=matplotlib.cm.hot) | 119 im = ax[0].matshow(np.rot90(vol_model[sl_idx, :, :]), cmap=matplotlib.cm.hot) |
| 112 ax.cbar_axes[0].colorbar(im) | 120 ax.cbar_axes[0].colorbar(im) |
| 113 ax[1].matshow(np.rot90(t1_data[sl_idx, :, :]), cmap=matplotlib.cm.bone) | 121 ax[1].matshow(np.rot90(t1_data[sl_idx, :, :]), cmap=matplotlib.cm.bone) |
| 114 im = ax[1].matshow(np.rot90(vol_mean[sl_idx, :, :]), cmap=matplotlib.cm.hot) | 122 im = ax[1].matshow(np.rot90(vol_mean[sl_idx, :, :]), cmap=matplotlib.cm.hot) |
| 117 im = ax[2].matshow(np.rot90(vol_improve[sl_idx, :, :]), cmap=matplotlib.cm.RdBu) | 125 im = ax[2].matshow(np.rot90(vol_improve[sl_idx, :, :]), cmap=matplotlib.cm.RdBu) |
| 118 ax.cbar_axes[2].colorbar(im) | 126 ax.cbar_axes[2].colorbar(im) |
| 119 for lax in ax: | 127 for lax in ax: |
| 120 lax.set_xticks([]) | 128 lax.set_xticks([]) |
| 121 lax.set_yticks([]) | 129 lax.set_yticks([]) |
| 122 fig.savefig(args.output_spatial_errors) | 130 fig.savefig("spatial_errors.png") |
| 131 shutil.move("spatial_errors.png", args.output_spatial_errors) |
