Mercurial > repos > greg > linear_fascile_evaluation
changeset 0:cbfa8c336751 draft
Uploaded
author | greg |
---|---|
date | Tue, 28 Nov 2017 13:18:32 -0500 |
parents | |
children | 84a2e30b5404 |
files | linear_fascile_evaluation.py linear_fascile_evaluation.xml |
diffstat | 2 files changed, 169 insertions(+), 0 deletions(-) [+] |
line wrap: on
line diff
--- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/linear_fascile_evaluation.py Tue Nov 28 13:18:32 2017 -0500 @@ -0,0 +1,125 @@ +#!/usr/bin/env python +import argparse +import numpy as np +import os.path as op +import nibabel as nib +import dipy.core.optimize as opt +import dipy.tracking.life as life +import matplotlib.pyplot as plt +import matplotlib + +from dipy.viz.colormap import line_colors +from dipy.viz import fvtk +from mpl_toolkits.axes_grid1 import AxesGrid + +parser = argparse.ArgumentParser() +parser.add_argument('--candidates', dest='candidates', help='Candidates selection') +parser.add_argument('--output_life_candidates', dest='output_life_candidates', help='Output life candidates') +parser.add_argument('--output_life_optimized', dest='output_life_optimized', help='Output life optimized streamlines') +parser.add_argument('--output_beta_histogram', dest='output_beta_histogram', help='Output beta histogram') +parser.add_argument('--output_error_histograms', dest='output_error_histograms', help='Output error histograms') +parser.add_argument('--output_spatial_errors', dest='output_spatial_errors', help='Output spatial errors') + +args = parser.parse_args() + +if not op.exists(args.candidates): + from streamline_tools import * +else: + # We'll need to know where the corpus callosum is from these variables: + from dipy.data import (read_stanford_labels, fetch_stanford_t1, read_stanford_t1) + hardi_img, gtab, labels_img = read_stanford_labels() + labels = labels_img.get_data() + cc_slice = labels == 2 + fetch_stanford_t1() + t1 = read_stanford_t1() + t1_data = t1.get_data() + data = hardi_img.get_data() + +# Read the candidates from file in voxel space: +candidate_sl = [s[0] for s in nib.trackvis.read(args.candidates, points_space='voxel')[0]] +# Visualize the initial candidate group of streamlines +# in 3D, relative to the anatomical structure of this brain. +candidate_streamlines_actor = fvtk.streamtube(candidate_sl, line_colors(candidate_sl)) +cc_ROI_actor = fvtk.contour(cc_slice, levels=[1], colors=[(1., 1., 0.)], opacities=[1.]) +vol_actor = fvtk.slicer(t1_data) +vol_actor.display(40, None, None) +vol_actor2 = vol_actor.copy() +vol_actor2.display(None, None, 35) +# Add display objects to canvas. +ren = fvtk.ren() +fvtk.add(ren, candidate_streamlines_actor) +fvtk.add(ren, cc_ROI_actor) +fvtk.add(ren, vol_actor) +fvtk.add(ren, vol_actor2) +fvtk.record(ren, n_frames=1, out_path=args.output_life_candidates, size=(800, 800)) +# Initialize a LiFE model. +fiber_model = life.FiberModel(gtab) +# Fit the model, producing a FiberFit class instance, +# that stores the data, as well as the results of the +# fitting procedure. +fiber_fit = fiber_model.fit(data, candidate_sl, affine=np.eye(4)) +fig, ax = plt.subplots(1) +ax.hist(fiber_fit.beta, bins=100, histtype='step') +ax.set_xlabel('Fiber weights') +ax.set_ylabel('# fibers') +fig.savefig(args.output_beta_histogram) +# Filter out these redundant streamlines and +# generate an optimized group of streamlines. +optimized_sl = list(np.array(candidate_sl)[np.where(fiber_fit.beta>0)[0]]) +ren = fvtk.ren() +fvtk.add(ren, fvtk.streamtube(optimized_sl, line_colors(optimized_sl))) +fvtk.add(ren, cc_ROI_actor) +fvtk.add(ren, vol_actor) +fvtk.record(ren, n_frames=1, out_path=args.output_life_optimized, size=(800, 800)) +model_predict = fiber_fit.predict() +# Focus on the error in prediction of the diffusion-weighted +# data, and calculate the root of the mean squared error. +model_error = model_predict - fiber_fit.data +model_rmse = np.sqrt(np.mean(model_error[:, 10:] ** 2, -1)) +# Calculate another error term by assuming that the weight for each streamline +# is equal to zero. This produces the naive prediction of the mean of the +# signal in each voxel. +beta_baseline = np.zeros(fiber_fit.beta.shape[0]) +pred_weighted = np.reshape(opt.spdot(fiber_fit.life_matrix, beta_baseline), (fiber_fit.vox_coords.shape[0], np.sum(~gtab.b0s_mask))) +mean_pred = np.empty((fiber_fit.vox_coords.shape[0], gtab.bvals.shape[0])) +S0 = fiber_fit.b0_signal +# Since the fitting is done in the demeaned S/S0 domain, +# add back the mean and then multiply by S0 in every voxel: +mean_pred[..., gtab.b0s_mask] = S0[:, None] +mean_pred[..., ~gtab.b0s_mask] = (pred_weighted + fiber_fit.mean_signal[:, None]) * S0[:, None] +mean_error = mean_pred - fiber_fit.data +mean_rmse = np.sqrt(np.mean(mean_error ** 2, -1)) +# Compare the overall distribution of errors +# between these two alternative models of the ROI. +fig, ax = plt.subplots(1) +ax.hist(mean_rmse - model_rmse, bins=100, histtype='step') +ax.text(0.2, 0.9,'Median RMSE, mean model: %.2f' % np.median(mean_rmse), horizontalalignment='left', verticalalignment='center', transform=ax.transAxes) +ax.text(0.2, 0.8,'Median RMSE, LiFE: %.2f' % np.median(model_rmse), horizontalalignment='left', verticalalignment='center', transform=ax.transAxes) +ax.set_xlabel('RMS Error') +ax.set_ylabel('# voxels') +fig.savefig(args.output_error_histograms) +# Show the spatial distribution of the two error terms, +# and of the improvement with the model fit: +vol_model = np.ones(data.shape[:3]) * np.nan +vol_model[fiber_fit.vox_coords[:, 0], fiber_fit.vox_coords[:, 1], fiber_fit.vox_coords[:, 2]] = model_rmse +vol_mean = np.ones(data.shape[:3]) * np.nan +vol_mean[fiber_fit.vox_coords[:, 0], fiber_fit.vox_coords[:, 1], fiber_fit.vox_coords[:, 2]] = mean_rmse +vol_improve = np.ones(data.shape[:3]) * np.nan +vol_improve[fiber_fit.vox_coords[:, 0], fiber_fit.vox_coords[:, 1], fiber_fit.vox_coords[:, 2]] = mean_rmse - model_rmse +sl_idx = 49 +fig = plt.figure() +fig.subplots_adjust(left=0.05, right=0.95) +ax = AxesGrid(fig, 111, nrows_ncols = (1, 3), label_mode = "1", share_all = True, cbar_location="top", cbar_mode="each", cbar_size="10%", cbar_pad="5%") +ax[0].matshow(np.rot90(t1_data[sl_idx, :, :]), cmap=matplotlib.cm.bone) +im = ax[0].matshow(np.rot90(vol_model[sl_idx, :, :]), cmap=matplotlib.cm.hot) +ax.cbar_axes[0].colorbar(im) +ax[1].matshow(np.rot90(t1_data[sl_idx, :, :]), cmap=matplotlib.cm.bone) +im = ax[1].matshow(np.rot90(vol_mean[sl_idx, :, :]), cmap=matplotlib.cm.hot) +ax.cbar_axes[1].colorbar(im) +ax[2].matshow(np.rot90(t1_data[sl_idx, :, :]), cmap=matplotlib.cm.bone) +im = ax[2].matshow(np.rot90(vol_improve[sl_idx, :, :]), cmap=matplotlib.cm.RdBu) +ax.cbar_axes[2].colorbar(im) +for lax in ax: + lax.set_xticks([]) + lax.set_yticks([]) +fig.savefig(args.output_spatial_errors)
--- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/linear_fascile_evaluation.xml Tue Nov 28 13:18:32 2017 -0500 @@ -0,0 +1,44 @@ +<tool id="linear_fascile_evaluation" name="Linear fascicle evaluation" version="0.13.0"> + <description>(LiFE) for tractography results</description> + <requirements> + <requirement type="package" version="0.13.0">dipy</requirement> + </requirements> + <command detect_errors="exit_code"><![CDATA[ +python '$__tool_directory__/linear_fascile_evaluation.py' +--candidates '$candidates' +--output_life_candidates '$output_life_candidates' +--output_life_optimized '$output_life_optimized' +--output_error_histograms '$output_error_histograms' +--output_beta_histogram '$output_beta_histogram' +--output_spatial_errors 'output_spatial_errors' + ]]></command> + <inputs> + <param name="candidates" type="select" label="Candidates"> + <option value="lr-superiorfrontal.trk" selected="true">lr-superiorfrontal.trk</option> + </param> + </inputs> + <outputs> + <data name="output_spatial_errors" format="png" label="${tool.name}: Spatial Errors" /> + <data name="output_beta_histogram" format="png" label="${tool.name}: Beta Histogram" /> + <data name="output_error_histograms" format="png" label="${tool.name}: Error Histograms" /> + <data name="output_life_optimized" format="png" label="${tool.name}: LiFE Optimized Streamlines" /> + <data name="output_life_candidates" format="png" label="${tool.name}: LiFE Candidates" /> + </outputs> + <tests> + <test> + </test> + </tests> + <help> +**What it does** + +Uses a forward model that predicts the signal from each of a set of streamlines, and fits a +linear model to these simultaneous predictions for evaluation of tractography results. + +----- + +**Options** + + </help> + <citations> + </citations> +</tool>