# Source code for niwidgets.streamlines
from __future__ import print_function
import os
import ipyvolume as ipv
import nibabel as nib
import numpy as np
from ipywidgets import fixed, interact, widgets
def length(x):
    """Return the total path length of a polyline.

    ``x`` is indexed as a 2-D array with one point per row; the result is
    the sum of the euclidean distances between consecutive rows.
    """
    # Difference between each point and its successor.
    steps = x[:-1, :] - x[1:, :]
    squared_norms = (steps * steps).sum(axis=1)
    return np.sqrt(squared_norms).sum()
def color(x):
    """Return an approximate color for a line based on its endpoints.

    The color is the unit vector pointing from the last point of ``x`` to
    its first point (``x[0] - x[-1]``), so the three components are signed
    and lie in ``[-1, 1]``.

    Args
    ----
    x : numpy.ndarray
        2-D array with one point per row; the final ``dot`` with a 3x3
        identity requires three coordinates per point.

    Returns
    -------
    numpy.ndarray
        Array of three floats. If the first and last points coincide the
        direction is undefined and a zero vector is returned instead of
        NaNs.
    """
    dirvec = x[0, :] - x[-1, :]
    norm = np.sqrt(np.sum(dirvec * dirvec, axis=-1))
    if norm == 0:
        # Coinciding endpoints: avoid 0/0, which would produce NaN colors.
        unit = np.zeros(dirvec.shape)
    else:
        unit = dirvec / norm
    # Identity mapping of the x/y/z components onto the three color channels.
    return unit.dot(np.eye(3))
class StreamlineWidget:
    """
    Turns nibabel track files into interactive plots using ipyvolume.

    Color for each line is rendered as a function of the endpoints of the
    streamline. Currently, this resolves to red for left-right, green for
    anterior-posterior, and blue for inferior-superior.

    Args
    ----
        filename : str, pathlib.Path
                The path to your ``.trk`` file. Can be a string, or a
                ``PosixPath`` from python3's pathlib.
        streamlines : a nibabel streamline object
                A streamlines attribute of an object loaded by
                nibabel.streamlines.load

    Raises
    ------
        IOError
                If ``filename`` is given but is not an existing file.
        ValueError
                If ``filename`` is not a format supported by nibabel, or if
                neither ``filename`` nor ``streamlines`` is given.
    """

    def __init__(self, filename=None, streamlines=None):
        if filename:
            filename = str(filename)
            if not os.path.isfile(filename):
                # for Python3 should have FileNotFoundError here
                raise IOError("file {} not found".format(filename))
            if not nib.streamlines.is_supported(filename):
                raise ValueError(
                    (
                        "File {0} is not a streamline file supported"
                        " by nibabel"
                    ).format(filename)
                )
            # load data in advance
            self.streamlines = nib.streamlines.load(filename).streamlines
        elif streamlines is not None:
            # Compare against None instead of relying on truthiness, which
            # is ambiguous for numpy arrays.
            self.streamlines = streamlines
        else:
            raise ValueError(
                "One of filename or streamlines must be specified"
            )
        # Subset of streamlines selected by the most recent call to plot().
        self.lines2use_ = None

    def plot(self, display_fraction=0.1, **kwargs):
        """
        This is the main method for this widget.

        Args
        ----
            display_fraction : float or None
                    The fraction of streamlines to show, in ``(0, 1]``.
                    ``None`` shows all streamlines.
            percentile : int
                    The initial number of streamlines to show using a
                    percentile of length distribution
            width : int
                    The width of the figure
            height : int
                    The height of the figure
            grayscale : bool
                    If truthy, draw every line in mid-gray instead of
                    coloring it by its endpoints
            style : dict
                    Assigned to the figure's ``style`` in place of the
                    default style

        Raises
        ------
            ValueError
                    If ``display_fraction`` is outside ``(0, 1]``, or if
                    there are no streamlines to display.
        """
        if display_fraction is not None and (
            display_fraction > 1 or display_fraction <= 0
        ):
            raise ValueError(
                "display_fraction is a float between 0 and 1"
                " (0 excluded) or None"
            )
        N = len(self.streamlines)
        if N == 0:
            raise ValueError("there are no streamlines to display")
        if display_fraction is None:
            num_streamlines = N
        else:
            # Keep at least one line so that small inputs still render.
            num_streamlines = max(1, int(display_fraction * N))
        # Random subset (without replacement) of the available streamlines.
        indices = np.random.permutation(N)[:num_streamlines]
        self.lines2use_ = self.streamlines[indices]
        self._default_plotter(**kwargs)

    def _create_mesh(self, indices2use=None):
        """
        Build the vertex, line-index and color arrays for the line mesh.

        Args
        ----
            indices2use : index or None
                    Optional index applied to both ``self.lines2use_`` and
                    ``self.colors`` to restrict the mesh to a subset.

        Returns
        -------
            tuple
                    ``(x, y, z, indices, colors, line_pointers)`` where
                    ``x``, ``y``, ``z`` are the concatenated vertex
                    coordinates, ``indices`` holds the vertex index pairs
                    of every segment, ``colors`` holds one color per
                    vertex, and ``line_pointers`` holds, per line,
                    ``[offset into indices, number of vertices, that
                    line's index pairs]``.
        """
        if indices2use is None:
            lines2use = self.lines2use_
            local_colors = self.colors
        else:
            lines2use = self.lines2use_[indices2use]
            local_colors = self.colors[indices2use]
        x, y, z = np.concatenate(lines2use).T
        # will contain indices to the vertices, [0, 1, 1, 2, 2, 3, 3, 4, ..]
        # (builtin sum: passing a generator to np.sum is deprecated)
        indices = np.zeros(
            sum((len(line) - 1) * 2 for line in lines2use), dtype=np.uint32
        )
        colors = np.zeros((len(x), 3), dtype=np.float32)
        vertex_offset = 0
        line_offset = 0
        line_pointers = []
        # if we have a line of 4 vertices, we need to add the indices:
        # offset + [0, 1, 1, 2, 2, 3]
        # so we have approx 2x the number of indices compared to vertices
        for idx, line in enumerate(lines2use):
            line_length = len(line)
            num_indices = line_length * 2 - 2
            # repeat all but the start and end vertex
            line_indices = np.repeat(
                np.arange(
                    vertex_offset,
                    vertex_offset + line_length,
                    dtype=indices.dtype,
                ),
                2,
            )[1:-1]
            indices[line_offset : line_offset + num_indices] = line_indices
            line_pointers.append([line_offset, line_length, line_indices])
            # every vertex of a line shares that line's color
            colors[vertex_offset : vertex_offset + line_length] = local_colors[
                idx
            ]
            line_offset += num_indices
            vertex_offset += line_length
        return x, y, z, indices, colors, line_pointers

    def _default_plotter(self, **kwargs):
        """
        Basic plot function to be used if no custom function is specified.

        This is called by plot, you shouldn't call it directly.
        Recognized keyword arguments: ``grayscale``, ``width``, ``height``,
        ``percentile`` and ``style``.
        """
        self.lengths = np.array([length(x) for x in self.lines2use_])
        if kwargs.get("grayscale"):
            self.colors = np.full(
                (len(self.lines2use_), 3), 0.5, dtype=np.float16
            )
        else:
            self.colors = np.array([color(x) for x in self.lines2use_])

        # Shared with _plot_lines through ipywidgets' ``fixed``.
        self.state = {"threshold": 0, "indices": []}

        width = kwargs.get("width", 600)
        height = kwargs.get("height", 600)
        perc = kwargs.get("percentile", 80)

        ipv.clear()
        fig = ipv.figure(width=width, height=height)
        self.state["fig"] = fig

        with fig.hold_sync():
            x, y, z, indices, colors, self.line_pointers = self._create_mesh()
            # One common [min, max] range, used for all three axes.
            limits = np.array(
                [
                    min([x.min(), y.min(), z.min()]),
                    max([x.max(), y.max(), z.max()]),
                ]
            )
            mesh = ipv.Mesh(x=x, y=y, z=z, lines=indices, color=colors)
            fig.meshes = [mesh]
            if "style" not in kwargs:
                fig.style = {
                    "axes": {
                        "color": "black",
                        "label": {"color": "black"},
                        "ticklabel": {"color": "black"},
                        "visible": False,
                    },
                    "background-color": "white",
                    "box": {"visible": False},
                }
            else:
                fig.style = kwargs["style"]
            # NOTE(review): relies on a private ipyvolume helper.
            ipv.pylab._grow_limits(limits, limits, limits)
            fig.camera_fov = 1
        ipv.show()

        # The slider tops out at ``max - 1``; since _plot_lines keeps lines
        # with length > threshold, the longest line can never be hidden.
        # NOTE(review): presumably intentional -- confirm.
        interact(
            self._plot_lines,
            state=fixed(self.state),
            threshold=widgets.FloatSlider(
                value=np.percentile(self.lengths, perc),
                min=self.lengths.min() - 1,
                max=self.lengths.max() - 1,
                continuous_update=False,
            ),
        )

    def _plot_lines(self, state, threshold):
        """
        Plots streamlines

        This function is called by _default_plotter. Lines whose length is
        at most ``threshold`` are hidden by zeroing their entries in the
        mesh's line-index array; lowering the threshold restores the stored
        entries of every line longer than it.
        """
        visible = np.where(self.lengths > threshold)[0]
        with state["fig"].hold_sync():
            mesh = state["fig"].meshes[0]
            lines = mesh.lines.copy()
            if threshold < state["threshold"]:
                # when threshold is reduced, increase the number of lines
                for idx in visible:
                    line_offset, line_length, line_indices = (
                        self.line_pointers[idx]
                    )
                    lines[
                        line_offset : line_offset + line_length * 2 - 2
                    ] = line_indices
            else:
                # when threshold is increased, decrease the number of lines
                for idx in np.where(self.lengths <= threshold)[0]:
                    line_offset, line_length, _ = self.line_pointers[idx]
                    lines[line_offset : line_offset + line_length * 2 - 2] = 0
            mesh.lines = lines
            mesh.send_state("lines")
        state["indices"] = visible
        state["threshold"] = threshold