Skip to content
Snippets Groups Projects

Compare revisions

Changes are shown as if the source revision was being merged into the target revision. Learn more about comparing revisions.

Source

Select target project
No results found

Target

Select target project
  • ibt-cmr/modeling/cmr-random-diffmaps
1 result
Show changes
Showing
with 642 additions and 22 deletions
*.npy filter=lfs diff=lfs merge=lfs -text
*.vti filter=lfs diff=lfs merge=lfs -text
*.vtk filter=lfs diff=lfs merge=lfs -text
......@@ -3,3 +3,4 @@ import cmr_rnd_diffusion.diff3d
import cmr_rnd_diffusion.mesh_utils
import cmr_rnd_diffusion.sampling
import cmr_rnd_diffusion.vtk_utils
from cmr_rnd_diffusion._cardiac_mesh_dataset import CardiacMeshDataset, SliceDataset
import os
from typing import List, Union, Tuple, Sequence, Iterable
import pyvista
import numpy as np
from tqdm.notebook import tqdm
from pint import Quantity
import cdtipy
import cmr_rnd_diffusion
import cmr_rnd_diffusion.mesh_utils
import cmr_rnd_diffusion.vtk_utils
import cmr_rnd_diffusion.diff3d
import vtk
class BaseDataset(object):
""" """
#: Mesh containing the displacements for each time-point
mesh: Union[pyvista.UnstructuredGrid, pyvista.DataSet]
#: Time points corresponding to the mesh states given by the list of files in ms
timing: np.ndarray
def __init__(self, mesh: Union[pyvista.UnstructuredGrid, pyvista.DataSet],
timing_ms: np.ndarray):
"""
:param mesh: pyvista mesh containing the field-array 'displacement_{time}ms' for each
time specified in timing_ms
:param timing_ms: array containing the times corresponding to the stored displacements
"""
self.mesh = mesh
self.timing = timing_ms
def add_data_per_time(self, scalars: np.array, name: str):
"""
:param scalars: (t, #p, ...)
:param name:
:return:
"""
for timing, value in zip(self.timing, scalars):
self.mesh[f"{name}_{timing}ms"] = value
def transform_vectors(self, time_stamps: np.ndarray,
vector_names: List[str],
reference_time: Quantity = None,
rotation_only: bool = True):
""" Evalutes the deformation matrix for ref-timing -> time_stamps and applies the
transformation to the vector arrays of specified names in the mesh. Results are stored
as vector field in self.mesh.
:param time_stamps:
:param vector_names: Names of properties that are transformed
:param reference_time: Time that correspond to the definition of the vectors
:param rotation_only: If true the transformation includes a normalizatios, thus results in a
rotation only.
:return:
"""
reference_positions = self.get_positions_at_time(reference_time)
deformation_mesh = pyvista.UnstructuredGrid(
{vtk.VTK_TETRA: self.mesh.cell_connectivity.reshape(-1, 4)},
reference_positions)
# Apply transformation for all time steps
for ts in tqdm(time_stamps, desc="Computing deformation matrices", leave=False):
# Compute the deformation matrix on the mesh serving as transformation
for idx, direction in enumerate("xyz"):
deformation_mesh[f"disp{direction}"] = self.mesh[f"displacements_{ts}ms"][:, idx]
deformation_mesh = deformation_mesh.compute_derivative(scalars=f"disp{direction}",
gradient=f"d{direction}_dxyz")
deformation_mesh[f"d{direction}_dxyz"][
np.logical_not(np.isfinite(deformation_mesh[f"d{direction}_dxyz"]))] = 0
nabla_u = np.stack([deformation_mesh["dx_dxyz"], deformation_mesh["dy_dxyz"],
deformation_mesh["dz_dxyz"]], axis=1)
F = nabla_u + np.eye(3, 3).reshape(1, 3, 3)
# Apply all transformations and store result in self.mesh
for vec_name in vector_names:
self.mesh[f"{vec_name}_{ts}ms"] = np.einsum("nji, nj -> ni",
F, self.mesh[vec_name])
if rotation_only:
norm = np.linalg.norm(self.mesh[f"{vec_name}_{ts}ms"], axis=-1,
keepdims=True)
self.mesh[f"{vec_name}_{ts}ms"] /= norm
def render_input_animation(self, filename: str, scalar: str = None,
vector: str = None, start: Quantity = Quantity(0., "ms"),
end: Quantity = None, mesh_kwargs: dict = None,
vector_kwargs: dict = {"mag": 1e-3}):
""" Saves a gif animation using the original input mesh and displacements.
:param filename:
:param scalar:
:param vector:
:param start:
:param end:
:param mesh_kwargs:
:param vector_kwargs:
"""
plotting_mesh = self.mesh.copy()
index_start = np.searchsorted(self.timing, start.m_as("ms"))
reference_position = np.array(plotting_mesh.points)
initial_position = (reference_position +
plotting_mesh[f"displacements_{self.timing[index_start]}ms"])
if scalar is not None:
plotting_mesh.set_active_scalars(f"{scalar}_{self.timing[index_start]}ms")
if mesh_kwargs is None:
mesh_kwargs = dict()
if end is None:
index_end = -1
else:
index_end = np.searchsorted(self.timing, end.m_as("ms"), side="right")
plotter = pyvista.Plotter(off_screen=True)
plotting_mesh.points = initial_position
plotter.add_mesh(plotting_mesh, **mesh_kwargs)
plotter.open_gif(f"{filename}.gif")
pbar = tqdm(self.timing[index_start:index_end],
desc=f"Rendering Mesh-Animation from {self.timing[index_start]:1.3f}"
f" ms to {self.timing[index_end]:1.3f} ms")
for timing in pbar:
text_actor = plotter.add_text(f"{timing: 1.2f} ms", position='upper_right',
color='white', shadow=False, font_size=26)
plotter.update_coordinates(reference_position +
plotting_mesh[f"displacements_{timing}ms"],
mesh=plotting_mesh, render=False)
if scalar is not None:
pbar.set_postfix_str(f"Update scalar with: {scalar}_{self.timing[index_start]}ms")
plotter.update_scalars(plotting_mesh[f"{scalar}_{timing}ms"],
mesh=plotting_mesh,
render=False)
if vector is not None:
arrow_actor = plotter.add_arrows(
cent=reference_position + plotting_mesh[f"displacements_{timing}ms"],
direction=plotting_mesh[f"{vector}_{timing}ms"],
**vector_kwargs
)
plotter.render()
plotter.write_frame()
plotter.remove_actor(text_actor)
if vector is not None:
plotter.remove_actor(arrow_actor)
plotter.close()
def get_positions_at_time(self, time: Quantity):
""" Returns mesh nodes at the closest match to time argument
:param time: Quantity
:return:
"""
if time is not None:
index_start = np.searchsorted(self.timing, time.m_as("ms"))
reference_position = (self.mesh.points +
self.mesh[f"displacements_{self.timing[index_start]}ms"])
else:
reference_position = np.array(self.mesh.points)
return reference_position
def probe_reference(self, reference_mesh: pyvista.UnstructuredGrid,
field_names: List[str], reference_time: Quantity = None):
""" Probes the reference mesh with the points of self.mesh at reference time and stores
the probed fields matching the given names into self.mesh.
:param reference_mesh:
:param field_names:
:param reference_time:
:return:
"""
reference_positions = self.get_positions_at_time(reference_time)
probing_mesh = pyvista.UnstructuredGrid(
{vtk.VTK_TETRA: self.mesh.cell_connectivity.reshape(-1, 4)},
reference_positions)
probing_mesh = reference_mesh.probe(probing_mesh)
for name in field_names:
self.mesh[name] = probing_mesh[name]
def get_field(self, field_name: str, timing_indices: Iterable[int] = None) \
-> (Quantity, np.ndarray):
""" Gets the specified field at all given time-indices and returns a stacked array
:param field_name:
:param timing_indices:
:return:
"""
if timing_indices is not None:
times = Quantity([self.timing[t] for t in timing_indices], "ms")
else:
times = Quantity(self.timing, "ms")
field_arr = np.stack([self.mesh[f"{field_name}_{t}ms"] for t in times.m], axis=0)
return times, field_arr
class CardiacMeshDataset(BaseDataset):
""" Cardiac mesh model assuming the continuous node-ordering:
[apex, n_points_per_ring(1st slice, 1st shell), n_points_per_ring(2st slice, 1st shell),
..., n_points_per_ring(1st slice, 2nd shell)]
shell --> radial
slice --> longitudinal
"""
#: (n_shells, n_points_per_ring, n_slices)
mesh_numbers: Tuple[int, int, int]
def __init__(self, mesh: Union[pyvista.UnstructuredGrid, pyvista.DataSet],
mesh_numbers: (int, int, int),
timing_ms: np.ndarray):
"""
:param mesh: pyvista mesh containing the field-array 'displacement_{time}ms' for each
time specified in timing_ms
:param mesh_numbers: (shells, circ, long)
:param timing_ms: array containing the times corresponding to the stored displacements
"""
super().__init__(mesh, timing_ms)
self.mesh_numbers = mesh_numbers
@classmethod
def from_list_of_arrays(cls, initial_mesh: pyvista.DataSet, mesh_numbers: (int, int, int),
list_of_files: List, timing: Quantity, time_precision_ms: int = 1):
""" Loads a mesh and its displacements from a list of numpy files each containing the
positional vectors for all points in initial_mesh.
:raises: FileNotFoundError - if not all files in given list of files exist
:param initial_mesh: Initial mesh configuration
:param mesh_numbers: (shells, circ, long)
:param list_of_files: List of .npy containing the position of all mesh nodes per time
:param timing:
:param time_precision_ms: number of decimals used in the field name and timepoints
(e.g. displacements_0.01ms)
"""
# Make sure all file exist before starting to load all of them
files_exist = [os.path.exists(f) for f in list_of_files]
if not all(files_exist):
n_not_found = len(list_of_files) - sum(files_exist)
non_existent_files = "\n\t".join([f for f, exists in zip(list_of_files, files_exist)
if not exists])
raise FileNotFoundError(f"Could not find {n_not_found} files:\n\t {non_existent_files}")
timing_ms = np.around(timing.m_as("ms"), decimals=time_precision_ms)
initial_mesh[f"displacements_{timing_ms[0]}ms"] = np.zeros_like(initial_mesh.points)
for file, time in tqdm(zip(list_of_files[1:], timing_ms[1:]), desc="Loading Files... ",
total=len(list_of_files)):
temp_rvecs = np.load(file)
initial_mesh[f"displacements_{time}ms"] = temp_rvecs - initial_mesh.points
return cls(mesh, timing_ms=timing_ms, mesh_numbers=mesh_numbers)
@classmethod
def from_list_of_meshes(cls, list_of_files: List, timing: Quantity, mesh_numbers: (int, int, int),
field_key: str = "displacement", time_precision_ms: int = 1):
""" Loads
:raises: FileNotFoundError - if not all files in given list of files exist
:param list_of_files: List of .vtk or .vti files each containing a mesh in which the
displacement vectors for each node is stored under the name of
specified by the argument field_key
:param timing:
:param mesh_numbers: (shells, circ, long)
:param field_key: key to extract
:param time_precision_ms: number of decimals used in the field name and timepoints
(e.g. displacements_0.01ms)
"""
# Make sure all file exist before starting to load all of them
files_exist = [os.path.exists(f) for f in list_of_files]
if not all(files_exist):
n_not_found = len(list_of_files) - sum(files_exist)
non_existent_files = "\n\t".join([f for f, exists in zip(list_of_files, files_exist)
if not exists])
raise FileNotFoundError(f"Could not find {n_not_found} files:\n\t {non_existent_files}")
# Load initial vtk file and add all displacements to the mesh as field array
mesh = pyvista.read(list_of_files[0])
timing_ms = np.around(timing.m_as("ms"), decimals=time_precision_ms)
mesh[f"displacements_{timing_ms[0]}ms"] = np.zeros_like(mesh.points)
for file, time in tqdm(zip(list_of_files[1:], timing_ms[1:]), desc="Loading Files... ",
total=len(list_of_files)):
temp_mesh = pyvista.read(file)
mesh[f"displacements_{time}ms"] = temp_mesh[field_key]
return cls(mesh, timing_ms=timing_ms, mesh_numbers=mesh_numbers)
@classmethod
def from_single_vtk(cls, file_name: str, mesh_numbers: (int, int, int),
time_precision_ms: int = 3):
""" Assumes, the displacements to be saved as defined in the other class methods
:param file_name:
:param mesh_numbers: (shells, circ, long)
:param time_precision_ms: number of decimals used in the field name and timepoints
(e.g. displacements_0.01ms)
:return:
"""
mesh = pyvista.read(file_name)
timings = [float(t.replace("displacements_", "").replace("ms", "")) for t
in mesh.array_names if "displacements" in t]
timings.sort()
timings = np.around(np.array(timings), decimals=time_precision_ms)
return cls(mesh, timing_ms=timings, mesh_numbers=mesh_numbers)
def refine(self, longitudinal_factor: int, circumferential_factor: int, radial_factor: int)\
-> 'CardiacMeshDataset':
""" Refines the cardiac mesh model assuming the continous node-ordering:
[apex, n_points_per_ring(1st slice, 1st shell), n_points_per_ring(2st slice, 1st shell),
..., n_points_per_ring(1st slice, 2nd shell)]
shell --> radial
slice --> longitudinal
:param longitudinal_factor:
:param circumferential_factor:
:param radial_factor:
:return:
"""
n_shells, n_points_per_ring, n_slices = self.mesh_numbers
original_ref_points = self.mesh.points + self.mesh[f"displacements_{self.timing[0]}ms"]
refined_ref_positions, new_dimensions = cmr_rnd_diffusion.mesh_utils.interpolate_mesh(
original_ref_points, points_per_ring=n_points_per_ring,
n_shells=n_shells, n_slices=n_slices,
interp_factor_c=circumferential_factor,
interp_factor_l=longitudinal_factor,
interp_factor_r=radial_factor)
connectivity = cmr_rnd_diffusion.mesh_utils.evaluate_connectivity(*new_dimensions)
new_mesh = pyvista.UnstructuredGrid({vtk.VTK_TETRA: connectivity}, refined_ref_positions)
refined_ref_positions = np.array(new_mesh.points)
for idx, timing in tqdm(enumerate(self.timing), total=len(self.timing),
desc=f"Refining mesh by factors: l={longitudinal_factor}, "
f"c={circumferential_factor}, r={radial_factor}"):
original_points = self.mesh.points + self.mesh[f"displacements_{self.timing[idx]}ms"]
refined_positions, new_mesh_numbers = cmr_rnd_diffusion.mesh_utils.interpolate_mesh(
original_points, points_per_ring=n_points_per_ring,
n_shells=n_shells, n_slices=n_slices,
interp_factor_c=circumferential_factor,
interp_factor_l=longitudinal_factor,
interp_factor_r=radial_factor)
new_mesh[f"displacements_{timing}ms"] = refined_positions - refined_ref_positions
new_dataset = CardiacMeshDataset(new_mesh, new_mesh_numbers, self.timing)
return new_dataset
def select_slice(self, slice_position: Quantity, slice_normal: np.ndarray,
slice_thickness: Quantity, reference_time: Quantity) \
-> pyvista.UnstructuredGrid:
"""
:param slice_position: (3, ) vector denoting the slice position
:param slice_normal: (3, ) vector denoting the normal vector of the slice
(is normalized internally)
:param slice_thickness: Thickness of the selected slice (along the slice_normal)
:param reference_time:
:return: mesh containing only the cell being inside the specified slice
"""
time_idx = np.searchsorted(self.timing, reference_time.m_as("ms"), side="right")
all_points = self.mesh.cell_centers().points + \
self.mesh.ptc()[f"displacements_{self.timing[time_idx]}ms"]
slice_normal = slice_normal / np.linalg.norm(slice_normal)
slice_position = slice_position.m_as("m")
d1 = np.einsum("ni, i", all_points - slice_position[np.newaxis, :], slice_normal)
in_slice = np.abs(d1) < slice_thickness.m_as("m") / 2
mesh_copy = self.mesh.copy()
current_slice = mesh_copy.remove_cells(np.where(np.logical_not(in_slice)),
inplace=True)
current_slice.reference_time = reference_time
return current_slice
class SliceDataset(BaseDataset):
#:
position: Quantity
#:
normal: np.ndarray
#:
thickness: Quantity
def __init__(self, mesh: pyvista.UnstructuredGrid, timing: np.ndarray, position: Quantity,
normal: np.ndarray, thickness: Quantity):
super().__init__(mesh, timing)
self.position = position
self.thickness = thickness
self.normal = normal
def get_trajectories(self, start: Quantity, end: Quantity) -> Tuple[Quantity, Quantity]:
""" Returns available times and positions between specified start-end
:param start:
:param end:
:return: time, positions (N, t, 3)
"""
time_idx_start = np.searchsorted(self.timing, start.m_as("ms"), side="right")
time_idx_end = np.searchsorted(self.timing, start.m_as("ms"), side="right")
time = Quantity(self.timing[time_idx_start:time_idx_end], "ms")
positions = Quantity(np.stack([self.mesh.points + self.mesh[f"displacements_{t}ms"]
for t in time.m], axis=1), "m")
return time, positions
%% Cell type:markdown id:254cb070-c731-4686-9ce4-137ab161cb78 tags:
# Demonstration of Input mesh handling
## Import
%% Cell type:code id:2d64a34a-af5c-45bb-aac8-1e0dce39c05c tags:
``` python
import tensorflow as tf
gpu = tf.config.get_visible_devices("GPU")[0]
tf.config.set_visible_devices(gpu, device_type="GPU")
tf.config.experimental.set_memory_growth(gpu, True)
import pyvista
import numpy as np
from tqdm.notebook import tqdm
import matplotlib.pyplot as plt
import vtk
import os
import math
from typing import List, Union, Tuple, Sequence
from pint import Quantity
from scipy.optimize import curve_fit
import cdtipy
import sys
sys.path.insert(0, "..")
import cmr_rnd_diffusion
```
%% Output
C:\Users\Jonat\anaconda3\envs\tf29_pyvista\lib\site-packages\cdtipy\reconstruction.py:11: UserWarning: No matlab installation found!
warnings.warn("No matlab installation found!")
C:\Users\Jonat\anaconda3\envs\tf29_pyvista\lib\site-packages\cdtipy\registration.py:10: UserWarning: No matlab installation found!
warnings.warn("No matlab installation found!")
%% Cell type:markdown id:3aea212a-c69a-4c4c-b95d-4105aefc4638 tags:
## Refine mesh
Load data from list of vtk files (1 per time-stamp) and instantiate a CardiacMeshDataset
%% Cell type:code id:0984ac04-591b-4402-9f3f-05935aae7c45 tags:
``` python
files = [f'example_data/example_mesh/Displacements/Displ{i:03}.vtk' for i in range(0, 73, 1)]
timing = Quantity(np.loadtxt("example_data/example_mesh/PV_loop_reordered.txt")[:, 1], "ms")
original_mesh = cmr_rnd_diffusion.CardiacMeshDataset.from_list_of_meshes(files, timing[:len(files)],
mesh_numbers=(4, 40, 30),
time_precision_ms=3)
```
%% Cell type:markdown id:84763570-d71f-4a88-b24f-6bd15fde8932 tags:
call refinement function
%% Cell type:code id:63c186b1-c8fd-48fb-8a6a-a94b63a79950 tags:
``` python
refined_mesh = original_mesh.refine(longitudinal_factor=2, circumferential_factor=2, radial_factor=3)
refined_mesh.mesh.save(f"refined_mesh_({refined_mesh.mesh_numbers}).vtk")
```
%% Cell type:markdown id:c547d934-e265-4894-b666-614ff26b54ca tags:
Render animations of moving particle using the refined mesh starting at 80 ms (start systole)
%% Cell type:code id:273941ac-6761-41be-9b73-9011a6a742f0 tags:
``` python
# pyvista.start_xvfb() # If on Linux start Buffer manual to prevent memory corruption
refined_mesh.render_input_animation("refined_positions", scalar=None,
start=Quantity(80., "ms"),
end=Quantity(280., "ms"),
cbar_args=dict(show_edges=True, opacity=0.3))
```
%% Cell type:markdown id:4f716711-fdc0-41ce-8598-344a87cb0c67 tags:
## Select a slice at given time:
%% Cell type:code id:03858326-90d5-4287-b97f-c56d4ef30e34 tags:
``` python
refined_mesh = cmr_rnd_diffusion.CardiacMeshDataset.from_single_vtk("refined_mesh_((81, 60, 10)).vtk",
mesh_numbers=(81, 60, 10))
slice_dict = dict(slice_thickness=Quantity(10, "mm"),
slice_normal=np.array((0., 0., 1.)),
slice_position=Quantity([0, 0, -3], "cm"),
reference_time = Quantity(105, "ms"))
temp = refined_mesh.select_slice(**slice_dict)
slice_mesh = cmr_rnd_diffusion.SliceDataset(temp, refined_mesh.timing,
position=slice_dict["slice_position"],
normal=slice_dict["slice_normal"],
thickness=slice_dict["slice_thickness"])
slice_mesh.render_input_animation("slice_trajectory", scalar=None, mesh_kwargs=dict(show_edges=True, opacity=0.3))
```
%% Cell type:markdown id:b0cc54c5-6f10-43c2-b0c6-597a35eeceb1 tags:
## Load reference mesh that contains the fibre-orientations
%% Cell type:code id:f6097e3d-24b9-4eb2-89e4-b64b34b4cfac tags:
``` python
reference_mesh = pyvista.read("example_data/example_mesh/ED_reference_mesh.vtk")
slice_mesh.probe_reference(reference_mesh, ["fibers", "sheets"], reference_time=None)
slice_mesh.transform_vectors(slice_mesh.timing, ["fibers", "sheets"], reference_time=None, rotation_only=True)
slice_mesh.render_input_animation("vectors_slice_fiber", start=Quantity(0, "ms"), end=None, vector="fibers",
mesh_kwargs=dict(opacity=0.4), vector_kwargs=dict(mag=5e-3, color="red"))
slice_mesh.render_input_animation("vectors_slice_sheet", start=Quantity(0, "ms"), end=None, vector="sheets",
mesh_kwargs=dict(opacity=0.4), vector_kwargs=dict(mag=5e-3, color="blue"))
```
%% Output
%% Cell type:code id:cfe2836a-6ce1-4bb5-8c99-177d439ff410 tags:
``` python
reference_mesh = pyvista.read("example_data/example_mesh/ED_reference_mesh.vtk")
refined_mesh.probe_reference(reference_mesh, ["fibers", "sheets"], reference_time=None)
refined_mesh.transform_vectors(refined_mesh.timing, ["fibers"], reference_time=None, rotation_only=True)
refined_mesh.render_input_animation("vectors", start=Quantity(0, "ms"), end=None, vector="fibers",
mesh_kwargs=dict(opacity=0.4), vector_kwargs=dict(mag=5e-3, color="red"))
```
%% Cell type:code id:2559500d-f63f-4642-bc10-12cfa2dd150c tags:
``` python
t, fibers_over_time = slice_mesh.get_field("fibers")
print(fibers_over_time.shape)
```
%% Output
(73, 5756, 3)
%% Cell type:markdown id:boxed-compiler tags:
### Import Modules
%% Cell type:code id:differential-discipline tags:
``` python
import sys
sys.path.append("../")
sys.path.append("../../cmr-interpolation/")
import ipywidgets as widgets
from IPython.display import display
%matplotlib widget
import numpy as np
import tensorflow as tf
import pyvista
import cdtipy
from cmr_rnd_diffusion import diff3d
from cmr_rnd_diffusion import mesh_utils
```
%% Output
C:\Users\Jonat\anaconda3\envs\tf27\lib\site-packages\cdtipy\reconstruction.py:11: UserWarning: No matlab installation found!
warnings.warn("No matlab installation found!")
C:\Users\Jonat\anaconda3\envs\tf27\lib\site-packages\cdtipy\registration.py:10: UserWarning: No matlab installation found!
warnings.warn("No matlab installation found!")
%% Cell type:markdown id:floral-reasoning tags:
### Load Mesh & Calculate Coordinates
%% Cell type:code id:colored-picture tags:
``` python
positional_vectors = np.load(f"./example_data/r_vectors_000.npy")
cell_volumes = np.load(f"./example_data/cell_volumes_000.npy")
cell_volumes = cell_volumes / np.mean(cell_volumes)
# Compute relative cylinder coordinates and local coordinate system (radial, circumferential, longitudinal)
cylinder_coordinates = mesh_utils.evaluate_cylinder_coordinates(positional_vectors, 16, 1+202*150)
local_basis = diff3d.compute_local_basis(positional_vectors, homogeneous_handedness=True)
```
%% Cell type:markdown id:clean-capability tags:
#### Display cylinder coordinates projected onto mesh
%% Cell type:code id:usual-basketball tags:
``` python
acc = widgets.Accordion(children=[widgets.Output() for _ in range(1)])
acc.set_title(0, "Cylinder Coordinates")
display(acc)
with acc.children[0]:
hbox = widgets.HBox([widgets.Output() for _ in range(3)])
display(hbox)
colormaps = ["fire", "jet", "summer"]
colormaps = ["Accent", "jet", "Accent"]
for i in range(3):
with hbox.children[i]:
pl = pyvista.Plotter(notebook=True)
pmesh = pyvista.PolyData(positional_vectors[:, :].reshape(-1, 3))
pl.add_mesh(pmesh, scalars=cylinder_coordinates[..., i], cmap=colormaps[i])
pl.add_title(["Radius", "Angle", "Z"][i])
pl.show(True)
```
%% Output
%% Cell type:markdown id:continuing-pharmaceutical tags:
### Generate Random Tensor-Map
%% Cell type:markdown id:computational-damage tags:
#### Random E2A and Linear HA for selected slice
%% Cell type:code id:sharing-class tags:
``` python
normal = np.array([0., 0., 1.])
slice_position = np.array([0., 0., -0.4e-2])
slice_thickness = 2e-3
slice_thickness = 0.5e-2
slice_coords_cart, in_slice = mesh_utils.select_slice(positional_vectors, normal/np.linalg.norm(normal),
slice_position=slice_position, slice_thickness=slice_thickness)
slice_coordinates = cylinder_coordinates[np.where(in_slice)]
e2a = diff3d.generate_e2a(slice_coordinates,
down_sampling_factor=5**2.,
seed_fraction=0.2,
down_sampling_factor=50.,
seed_fraction=0.3,
prob_high=0.6,
high_angle_value=90.,
low_angle_value=1.,
kernel_width=0.005,
kernel_width=0.15,
distance_weights=(1., 1.5, 15.),
verbose=True)
ha = diff3d.generate_ha(slice_coordinates[..., 0], angle_range=(65, -65))
```
%% Output
Scalar interpolation running: 100 %
%% Cell type:markdown id:czech-specification tags:
#### Interpolate randomly scaled reference tensors
%% Cell type:code id:democratic-installation tags:
``` python
from cmr_rnd_diffusion import sampling
# samplers = [sampling.MCRejectionSampler.from_savepath(f"./sampler_diffusion_{s}.npz",
# sampling_method=sampling.exponentially_sample_simplex)
# for s in ("healthy_exp", )]
# eigen_value_pool = samplers[0].get_samples(10000)
samplers = [sampling.MCRejectionSampler.from_savepath(f"./sampler_diffusion_{s}.npz",
sampling_method=sampling.exponentially_sample_simplex)
for s in ("healthy_exp", )]
with tf.device("CPU"):
eigen_value_pool = samplers[0].get_samples(10000)
slice_local_basis = local_basis[np.where(in_slice)]
flat_eigen_basis = diff3d.create_eigenbasis(slice_local_basis, ha, e2a)
tensors = diff3d.generate_tensors(eigen_value_pool, flat_eigen_basis, slice_coordinates,
seed_fraction=0.2, distance_weights=(1., 1.5, 15.), kernel_variance=0.025)
seed_fraction=0.2, distance_weights=(1., 1.5, 15.), kernel_variance=0.05)
```
%% Output
Tensor-Interpolation running: 100 % 10000
%% Cell type:code id:cognitive-struggle tags:
``` python
from tqdm.notebook import tqdm
def _plot(m, s, cmap, title):
pl = pyvista.Plotter()
pl.add_title(title)
pl.add_mesh(m, scalars=s, cmap=cmap)
pl.show()
acc = widgets.Accordion(children=[widgets.Output() for _ in range(4)])
[acc.set_title(i, t) for i, t in enumerate(["Angle Maps", "Metrics", "Histograms", "Ellipsoid Glyphs"])]
display(acc)
with acc.children[0]:
hbox = widgets.HBox([widgets.Output(), widgets.Output()])
display(hbox)
with hbox.children[0]:
_plot(pyvista.PolyData(slice_coords_cart), np.rad2deg(e2a), "seismic", "E2A")
with hbox.children[1]:
_plot(pyvista.PolyData(slice_coords_cart), np.rad2deg(ha), "jet", "HA")
with acc.children[1]:
metrics = cdtipy.tensor_metrics.metric_set(tensors)
metrics["HA"] = cdtipy.angulation.helix_angle(metrics["evecs"][0, ..., 0], slice_local_basis)
metrics["E2A"] = np.abs(cdtipy.angulation.sheetlet_angle(metrics["evecs"][0, ..., 0], metrics["evecs"][0, ..., 1], slice_local_basis))
hbox2 = widgets.HBox([widgets.Output() for _ in range(4)])
display(hbox2)
for idx, k, cmap in zip(range(4), ["MD", "FA", "HA", "E2A"], ["viridis", "fire", "jet", "seismic"]):
for idx, k, cmap in zip(range(4), ["MD", "FA", "HA", "E2A"], ["viridis", "winter", "jet", "seismic"]):
with hbox2.children[idx]:
_plot(pyvista.PolyData(slice_coords_cart), metrics[k], cmap, k)
with acc.children[2]:
cdtipy.plotting.scalars.hist_summary(metrics)
# with acc.children[3]:
# plotter = pyvista.Plotter()
# for n in tqdm(range(0, 3500, 5)):
# eig_val, eig_vec = metrics["evals"][0][n] * 4e-4, metrics["evecs"][0][n:n+1, ..., 0]
# g = pyvista.ParametricEllipsoid(*eig_val)
# m = pyvista.PolyData(slice_coords_cart[n:n+1])
# m["evec"] = eig_vec
# glyph = m.glyph(orient="evec", geom=g, )
# plotter.add_mesh(glyph, scalars=np.tile(metrics["FA"][0][n:n+1], [10000, 1]), cmap="jet")
# plotter.show()
with acc.children[3]:
plotter = pyvista.Plotter()
for n in tqdm(range(0, slice_coords_cart.shape[0], 20)):
eig_val, eig_vec = metrics["evals"][0][n] * 4e-4, metrics["evecs"][0][n:n+1, ..., 0]
g = pyvista.ParametricEllipsoid(*eig_val)
m = pyvista.PolyData(slice_coords_cart[n:n+1])
m["evec"] = eig_vec
glyph = m.glyph(orient="evec", geom=g, )
plotter.add_mesh(glyph, scalars=np.tile(metrics["HA"][n], [len(glyph.points),]), cmap="jet")
plotter.show()
```
%% Cell type:code id:2d58913a-1119-4645-bbda-5e450a4a4345 tags:
%% Cell type:code id:59540b46-9639-4cc7-8a15-16a8ef0306ea tags:
``` python
```
......
No preview for this file type
File added
File added
File added
File added
File added
File added
File added
File added
File added
File added
File added
File added
File added
File added