Source code for niwidgets.streamlines

from __future__ import print_function

import os

import ipyvolume as ipv
import nibabel as nib
import numpy as np
from ipywidgets import fixed, interact, widgets


def length(x):
    """Return the polyline length of the point sequence ``x``.

    ``x`` is indexed as a 2-D array with one point per row. The result is
    the sum of the Euclidean distances between consecutive rows, which is
    0.0 when ``x`` has fewer than two rows.
    """
    # Displacement between each pair of neighbouring points.
    steps = x[1:, :] - x[:-1, :]
    segment_lengths = np.sqrt((steps * steps).sum(axis=1))
    return segment_lengths.sum()


def color(x):
    """Return an RGB colour for a streamline based on its endpoints.

    The vector between the first and last points (``x[0] - x[-1]``) is
    normalised to unit length. Its absolute value is then taken, so the x, y
    and z components drive red, green and blue regardless of the order in
    which the streamline's points are stored. Without ``abs``, negative
    channels would be clamped to 0 when rendered, and half of the lines
    would appear dark.

    Args
    ----
    x : 2-D array, one point per row

    Returns
    -------
    A float array with one non-negative entry per coordinate axis. All
    zeros if the two endpoints coincide, where no direction is defined and
    the previous code produced NaN.
    """
    dirvec = np.asarray(x[0, :] - x[-1, :], dtype=float)
    norm = np.sqrt(np.dot(dirvec, dirvec))
    if norm == 0:
        return np.zeros_like(dirvec)
    return np.abs(dirvec / norm)


class StreamlineWidget:
    """
    Turns nibabel track files into interactive plots using ipyvolume.

    Color for each line is rendered as a function of the endpoints of the
    streamline. Currently, this resolves to red for left-right, green for
    anterior-posterior, and blue for inferior-superior.

    Args
    ----
    filename : str, pathlib.Path
        The path to your ``.trk`` file. Can be a string, or a ``PosixPath``
        from python3's pathlib.
    streamlines : a nibabel streamline object
        The ``streamlines`` attribute of an object loaded by
        ``nibabel.streamlines.load``.
    """

    def __init__(self, filename=None, streamlines=None):
        if filename:
            filename = str(filename)
            if not os.path.isfile(filename):
                # Python 2 has no FileNotFoundError; on Python 3 IOError is an
                # alias of OSError, so callers catching either still work.
                raise IOError("file {} not found".format(filename))
            if not nib.streamlines.is_supported(filename):
                raise ValueError(
                    (
                        "File {0} is not a streamline file supported"
                        " by nibabel"
                    ).format(filename)
                )
            # load data in advance
            self.streamlines = nib.streamlines.load(filename).streamlines
        elif streamlines is not None:
            # Compare against None instead of using truthiness: the truth
            # value of a multi-element numpy array is ambiguous and raises.
            if len(streamlines) == 0:
                raise ValueError("streamlines must not be empty")
            self.streamlines = streamlines
        else:
            raise ValueError(
                "One of filename or streamlines must be specified"
            )
        # Subset of streamlines actually drawn; chosen by plot().
        self.lines2use_ = None

    def plot(self, display_fraction=0.1, **kwargs):
        """
        This is the main method for this widget.

        Args
        ----
        display_fraction : float or None
            The fraction of streamlines to show, in (0, 1]. ``None`` shows
            all streamlines. At least one streamline is always shown.
        percentile : int
            The initial number of streamlines to show using a percentile of
            length distribution (default 80)
        width : int
            The width of the figure (default 600)
        height : int
            The height of the figure (default 600)
        grayscale : bool
            Draw every line in mid-gray instead of direction colours
        style : dict
            ipyvolume figure style replacing the default white background

        Raises
        ------
        ValueError
            If ``display_fraction`` is outside (0, 1] and not None.
        """
        if display_fraction is None:
            display_fraction = 1.0
        elif not 0 < display_fraction <= 1:
            raise ValueError(
                "display_fraction is a float between 0 and 1"
                " (0 excluded) or None"
            )
        n_total = len(self.streamlines)
        # Keep at least one line: an empty selection would make
        # np.concatenate in _create_mesh fail.
        num_streamlines = max(1, int(display_fraction * n_total))
        indices = np.random.permutation(n_total)[:num_streamlines]
        self.lines2use_ = self.streamlines[indices]
        self._default_plotter(**kwargs)

    def _create_mesh(self, indices2use=None):
        """
        Build the vertex, segment-index and colour arrays for the line mesh.

        Args
        ----
        indices2use : array-like of int, optional
            Subset of ``self.lines2use_`` to include; all lines when None.

        Returns
        -------
        x, y, z : vertex coordinates of all selected lines, concatenated
        indices : uint32 array of vertex index pairs, one pair per segment
        colors : float32 array with one RGB row per vertex
        line_pointers : list of ``[line_offset, line_length, line_indices]``
            per line. ``_plot_lines`` uses it to blank or restore that line's
            slots in ``indices``.
        """
        if indices2use is None:
            lines2use = self.lines2use_
            local_colors = self.colors
        else:
            lines2use = self.lines2use_[indices2use]
            local_colors = self.colors[indices2use]
        x, y, z = np.concatenate(lines2use).T
        # Builtin sum: np.sum over a generator is deprecated in NumPy.
        # max(..., 0) keeps a zero-point line from shrinking the total.
        n_indices = sum(2 * max(len(line) - 1, 0) for line in lines2use)
        # will contain indices to the vertices, [0, 1, 1, 2, 2, 3, 3, 4, ...]
        indices = np.zeros(n_indices, dtype=np.uint32)
        colors = np.zeros((len(x), 3), dtype=np.float32)
        vertex_offset = 0
        line_offset = 0
        line_pointers = []
        # A line of 4 vertices contributes offset + [0, 1, 1, 2, 2, 3]:
        # every vertex except the first and last is repeated, so there are
        # roughly twice as many indices as vertices.
        for idx, line in enumerate(lines2use):
            line_length = len(line)
            line_indices = np.repeat(
                np.arange(
                    vertex_offset,
                    vertex_offset + line_length,
                    dtype=indices.dtype,
                ),
                2,
            )[1:-1]
            n_slots = len(line_indices)
            indices[line_offset : line_offset + n_slots] = line_indices
            line_pointers.append([line_offset, line_length, line_indices])
            colors[vertex_offset : vertex_offset + line_length] = local_colors[
                idx
            ]
            line_offset += n_slots
            vertex_offset += line_length
        return x, y, z, indices, colors, line_pointers

    def _default_plotter(self, **kwargs):
        """
        Basic plot function to be used if no custom function is specified.

        This is called by plot, you shouldn't call it directly. It draws
        ``self.lines2use_`` as one ipyvolume line mesh and attaches a slider
        that hides streamlines whose length is at or below the threshold.
        """
        self.lengths = np.array([length(x) for x in self.lines2use_])
        if kwargs.get("grayscale"):
            self.colors = np.zeros((len(self.lines2use_), 3), dtype=np.float16)
            self.colors[:] = [0.5, 0.5, 0.5]
        else:
            self.colors = np.array([color(x) for x in self.lines2use_])
        self.state = {"threshold": 0, "indices": []}

        width = kwargs.get("width", 600)
        height = kwargs.get("height", 600)
        perc = kwargs.get("percentile", 80)

        ipv.clear()
        fig = ipv.figure(width=width, height=height)
        self.state["fig"] = fig
        with fig.hold_sync():
            x, y, z, indices, colors, self.line_pointers = self._create_mesh()
            # The same [min, max] range is passed for x, y and z below
            # (presumably to keep one common scale on all axes).
            limits = np.array(
                [
                    min(x.min(), y.min(), z.min()),
                    max(x.max(), y.max(), z.max()),
                ]
            )
            mesh = ipv.Mesh(x=x, y=y, z=z, lines=indices, color=colors)
            fig.meshes = [mesh]
            if "style" not in kwargs:
                fig.style = {
                    "axes": {
                        "color": "black",
                        "label": {"color": "black"},
                        "ticklabel": {"color": "black"},
                        "visible": False,
                    },
                    "background-color": "white",
                    "box": {"visible": False},
                }
            else:
                fig.style = kwargs["style"]
            ipv.pylab._grow_limits(limits, limits, limits)
            fig.camera_fov = 1
        ipv.show()
        # Lines with length > threshold stay visible; max is one below the
        # longest length, so the longest streamline can never be hidden.
        interact(
            self._plot_lines,
            state=fixed(self.state),
            threshold=widgets.FloatSlider(
                value=np.percentile(self.lengths, perc),
                min=self.lengths.min() - 1,
                max=self.lengths.max() - 1,
                continuous_update=False,
            ),
        )

    def _plot_lines(self, state, threshold):
        """
        Show only the streamlines longer than ``threshold``.

        This function is called by the slider created in _default_plotter.
        Hidden lines have their slots in the mesh's ``lines`` array set to 0.
        Lowering the threshold copies the saved indices from
        ``self.line_pointers`` back in.
        """
        with state["fig"].hold_sync():
            mesh = state["fig"].meshes[0]
            lines = mesh.lines.copy()
            if threshold < state["threshold"]:
                # threshold lowered: restore every line that is now visible
                for idx in np.where(self.lengths > threshold)[0]:
                    line_offset, _, line_indices = self.line_pointers[idx]
                    n_slots = len(line_indices)
                    lines[line_offset : line_offset + n_slots] = line_indices
            else:
                # threshold raised (or unchanged): blank every hidden line
                for idx in np.where(self.lengths <= threshold)[0]:
                    line_offset, _, line_indices = self.line_pointers[idx]
                    n_slots = len(line_indices)
                    lines[line_offset : line_offset + n_slots] = 0
            mesh.lines = lines
            mesh.send_state("lines")
        state["indices"] = np.where(self.lengths > threshold)[0]
        state["threshold"] = threshold