Mercurial > repos > greg > linear_fascile_evaluation
comparison linear_fascile_evaluation.py @ 7:eb03934e044f draft
Uploaded
| author | greg |
|---|---|
| date | Wed, 29 Nov 2017 16:40:08 -0500 |
| parents | 0ddfcb3b5ce6 |
| children | 2de70534993d |
comparison
equal
deleted
inserted
replaced
| 6:8dba8c7c1f53 | 7:eb03934e044f |
|---|---|
| 18 import numpy as np | 18 import numpy as np |
| 19 | 19 |
| 20 parser = argparse.ArgumentParser() | 20 parser = argparse.ArgumentParser() |
| 21 parser.add_argument('--input', dest='input', help='Track Visualization Header dataset') | 21 parser.add_argument('--input', dest='input', help='Track Visualization Header dataset') |
| 22 parser.add_argument('--output_life_candidates', dest='output_life_candidates', help='Output life candidates') | 22 parser.add_argument('--output_life_candidates', dest='output_life_candidates', help='Output life candidates') |
| 23 parser.add_argument('--output_life_optimized', dest='output_life_optimized', help='Output life optimized streamlines') | |
| 24 parser.add_argument('--output_beta_histogram', dest='output_beta_histogram', help='Output beta histogram') | |
| 25 parser.add_argument('--output_error_histograms', dest='output_error_histograms', help='Output error histograms') | |
| 26 parser.add_argument('--output_spatial_errors', dest='output_spatial_errors', help='Output spatial errors') | |
| 27 | 23 |
| 28 args = parser.parse_args() | 24 args = parser.parse_args() |
| 29 | 25 |
| 30 # We'll need to know where the corpus callosum is from these variables. | 26 # We'll need to know where the corpus callosum is from these variables. |
| 31 hardi_img, gtab, labels_img = read_stanford_labels() | 27 hardi_img, gtab, labels_img = read_stanford_labels() |
| 52 fvtk.add(ren, cc_ROI_actor) | 48 fvtk.add(ren, cc_ROI_actor) |
| 53 fvtk.add(ren, vol_actor) | 49 fvtk.add(ren, vol_actor) |
| 54 fvtk.add(ren, vol_actor2) | 50 fvtk.add(ren, vol_actor2) |
| 55 fvtk.record(ren, n_frames=1, out_path="life_candidates.png", size=(800, 800)) | 51 fvtk.record(ren, n_frames=1, out_path="life_candidates.png", size=(800, 800)) |
| 56 shutil.move("life_candidates.png", args.output_life_candidates) | 52 shutil.move("life_candidates.png", args.output_life_candidates) |
| 57 # Initialize a LiFE model. | |
| 58 fiber_model = life.FiberModel(gtab) | |
| 59 # Fit the model, producing a FiberFit class instance, | |
| 60 # that stores the data, as well as the results of the | |
| 61 # fitting procedure. | |
| 62 fiber_fit = fiber_model.fit(data, candidate_sl, affine=np.eye(4)) | |
| 63 fig, ax = plt.subplots(1) | |
| 64 ax.hist(fiber_fit.beta, bins=100, histtype='step') | |
| 65 ax.set_xlabel('Fiber weights') | |
| 66 ax.set_ylabel('# fibers') | |
| 67 fig.savefig("beta_histogram.png") | |
| 68 shutil.move("beta_histogram.png", args.output_beta_histogram) | |
| 69 # Filter out these redundant streamlines and | |
| 70 # generate an optimized group of streamlines. | |
| 71 optimized_sl = list(np.array(candidate_sl)[np.where(fiber_fit.beta > 0)[0]]) | |
| 72 ren = fvtk.ren() | |
| 73 fvtk.add(ren, fvtk.streamtube(optimized_sl, line_colors(optimized_sl))) | |
| 74 fvtk.add(ren, cc_ROI_actor) | |
| 75 fvtk.add(ren, vol_actor) | |
| 76 fvtk.record(ren, n_frames=1, out_path="optimized.png", size=(800, 800)) | |
| 77 shutil.move("optimized.png", args.output_life_optimized) | |
| 78 model_predict = fiber_fit.predict() | |
| 79 # Focus on the error in prediction of the diffusion-weighted | |
| 80 # data, and calculate the root of the mean squared error. | |
| 81 model_error = model_predict - fiber_fit.data | |
| 82 model_rmse = np.sqrt(np.mean(model_error[:, 10:] ** 2, -1)) | |
| 83 # Calculate another error term by assuming that the weight for each streamline | |
| 84 # is equal to zero. This produces the naive prediction of the mean of the | |
| 85 # signal in each voxel. | |
| 86 beta_baseline = np.zeros(fiber_fit.beta.shape[0]) | |
| 87 pred_weighted = np.reshape(opt.spdot(fiber_fit.life_matrix, beta_baseline), (fiber_fit.vox_coords.shape[0], np.sum(~gtab.b0s_mask))) | |
| 88 mean_pred = np.empty((fiber_fit.vox_coords.shape[0], gtab.bvals.shape[0])) | |
| 89 S0 = fiber_fit.b0_signal | |
| 90 # Since the fitting is done in the demeaned S/S0 domain, | |
| 91 # add back the mean and then multiply by S0 in every voxel: | |
| 92 mean_pred[..., gtab.b0s_mask] = S0[:, None] | |
| 93 mean_pred[..., ~gtab.b0s_mask] = (pred_weighted + fiber_fit.mean_signal[:, None]) * S0[:, None] | |
| 94 mean_error = mean_pred - fiber_fit.data | |
| 95 mean_rmse = np.sqrt(np.mean(mean_error ** 2, -1)) | |
| 96 # Compare the overall distribution of errors | |
| 97 # between these two alternative models of the ROI. | |
| 98 fig, ax = plt.subplots(1) | |
| 99 ax.hist(mean_rmse - model_rmse, bins=100, histtype='step') | |
| 100 ax.text(0.2, 0.9, 'Median RMSE, mean model: %.2f' % np.median(mean_rmse), horizontalalignment='left', verticalalignment='center', transform=ax.transAxes) | |
| 101 ax.text(0.2, 0.8, 'Median RMSE, LiFE: %.2f' % np.median(model_rmse), horizontalalignment='left', verticalalignment='center', transform=ax.transAxes) | |
| 102 ax.set_xlabel('RMS Error') | |
| 103 ax.set_ylabel('# voxels') | |
| 104 fig.savefig("error_histograms.png") | |
| 105 shutil.move("error_histograms.png", args.output_error_histograms) | |
| 106 # Show the spatial distribution of the two error terms, | |
| 107 # and of the improvement with the model fit: | |
| 108 vol_model = np.ones(data.shape[:3]) * np.nan | |
| 109 vol_model[fiber_fit.vox_coords[:, 0], fiber_fit.vox_coords[:, 1], fiber_fit.vox_coords[:, 2]] = model_rmse | |
| 110 vol_mean = np.ones(data.shape[:3]) * np.nan | |
| 111 vol_mean[fiber_fit.vox_coords[:, 0], fiber_fit.vox_coords[:, 1], fiber_fit.vox_coords[:, 2]] = mean_rmse | |
| 112 vol_improve = np.ones(data.shape[:3]) * np.nan | |
| 113 vol_improve[fiber_fit.vox_coords[:, 0], fiber_fit.vox_coords[:, 1], fiber_fit.vox_coords[:, 2]] = mean_rmse - model_rmse | |
| 114 sl_idx = 49 | |
| 115 fig = plt.figure() | |
| 116 fig.subplots_adjust(left=0.05, right=0.95) | |
| 117 ax = AxesGrid(fig, 111, nrows_ncols=(1, 3), label_mode="1", share_all=True, cbar_location="top", cbar_mode="each", cbar_size="10%", cbar_pad="5%") | |
| 118 ax[0].matshow(np.rot90(t1_data[sl_idx, :, :]), cmap=matplotlib.cm.bone) | |
| 119 im = ax[0].matshow(np.rot90(vol_model[sl_idx, :, :]), cmap=matplotlib.cm.hot) | |
| 120 ax.cbar_axes[0].colorbar(im) | |
| 121 ax[1].matshow(np.rot90(t1_data[sl_idx, :, :]), cmap=matplotlib.cm.bone) | |
| 122 im = ax[1].matshow(np.rot90(vol_mean[sl_idx, :, :]), cmap=matplotlib.cm.hot) | |
| 123 ax.cbar_axes[1].colorbar(im) | |
| 124 ax[2].matshow(np.rot90(t1_data[sl_idx, :, :]), cmap=matplotlib.cm.bone) | |
| 125 im = ax[2].matshow(np.rot90(vol_improve[sl_idx, :, :]), cmap=matplotlib.cm.RdBu) | |
| 126 ax.cbar_axes[2].colorbar(im) | |
| 127 for lax in ax: | |
| 128 lax.set_xticks([]) | |
| 129 lax.set_yticks([]) | |
| 130 fig.savefig("spatial_errors.png") | |
| 131 shutil.move("spatial_errors.png", args.output_spatial_errors) |
