From a6b1df50f2ad8f6fe5437827a161a7810b3bb881 Mon Sep 17 00:00:00 2001 From: Jonathan Weine <jweine@ethz.ch> Date: Mon, 31 Jul 2023 16:20:53 +0200 Subject: [PATCH 1/3] Added jit-compilations to contrast modules --- .../analytic/contrast/coil_sensitivities.py | 5 ++--- .../analytic/contrast/diffusion_weighting.py | 7 ++++--- cmrsim/analytic/contrast/offresonance.py | 14 +++++++------ cmrsim/analytic/contrast/phase_tracking.py | 20 ++++++++++++------- cmrsim/analytic/contrast/sequences.py | 11 +++++----- cmrsim/analytic/contrast/t2_star.py | 9 +++++---- 6 files changed, 38 insertions(+), 28 deletions(-) diff --git a/cmrsim/analytic/contrast/coil_sensitivities.py b/cmrsim/analytic/contrast/coil_sensitivities.py index 162bf1f..5bc3757 100644 --- a/cmrsim/analytic/contrast/coil_sensitivities.py +++ b/cmrsim/analytic/contrast/coil_sensitivities.py @@ -33,7 +33,6 @@ class CoilSensitivity(LookUpTableModule): """ required_quantities = ('r_vectors', ) - _device2: str #: Spatial extend of the look up as: (3, 2) -> [(xlow, xhigh), (ylow, yhigh), (zlow, zhigh)] map_dimensions: tf.Variable # float32 @@ -61,7 +60,6 @@ class CoilSensitivity(LookUpTableModule): name="coil_sensitivities", device=device, expand_repetitions=True) - self._device2 = device self.expansion_factor = coils.shape[-1] // 2 def update(self): @@ -69,6 +67,7 @@ class CoilSensitivity(LookUpTableModule): self.expansion_factor = self.look_up_map3d.read_value().shape[-1] // 2 # pylint: disable=signature-differs, too-many-locals + @tf.function(jit_compile=True) def __call__(self, signal_tensor: tf.Tensor, r_vectors: tf.Tensor, **kwargs) -> tf.Tensor: """Evaluation of coil-sensitivity weighting: @@ -80,7 +79,7 @@ class CoilSensitivity(LookUpTableModule): :param r_vectors: (#voxels, #repetitions, #samples, 3) :return: """ - with tf.device(self._device2): + with tf.device(self.device): input_shape = tf.shape(signal_tensor) tf.Assert(tf.shape(r_vectors)[1] == 1 or tf.shape(r_vectors)[1] == input_shape[1], diff --git a/cmrsim/analytic/contrast/diffusion_weighting.py b/cmrsim/analytic/contrast/diffusion_weighting.py index 0ce105b..16c4455 100644 --- a/cmrsim/analytic/contrast/diffusion_weighting.py +++ b/cmrsim/analytic/contrast/diffusion_weighting.py @@ -33,10 +33,11 @@ class GaussianDiffusionTensor(BaseSignalModel): :param expand_repetitions: See BaseSignalModel for explanation """ super().__init__(name="dti_gauss", expand_repetitions=expand_repetitions, **kwargs) - self.b_vectors = tf.Variable(b_vectors, shape=(None, 3), dtype=tf.float32, name='b_vectors') - self.update() + with tf.device(self.device): + self.b_vectors = tf.Variable(b_vectors, shape=(None, 3), dtype=tf.float32, name='b_vectors') + self.update() - @tf.Module.with_name_scope + @tf.function(jit_compile=True) def __call__(self, signal_tensor: tf.Tensor, diffusion_tensor: tf.Tensor, **kwargs)\ -> tf.Tensor: """ Evaluates the diffusion operator. diff --git a/cmrsim/analytic/contrast/offresonance.py b/cmrsim/analytic/contrast/offresonance.py index 2c8d9d4..67cd5a4 100644 --- a/cmrsim/analytic/contrast/offresonance.py +++ b/cmrsim/analytic/contrast/offresonance.py @@ -15,7 +15,7 @@ class OffResonanceProperty(BaseSignalModel): .. math:: - M_{xy}^{i, \\prime} = M_{xy}^{i} exp(-\gamma \\Delta B \\delta t), + M_{xy}^{i, \\prime} = M_{xy}^{i} exp(-j\gamma \\Delta B \\delta t), where :math:`\\delta t` is the duration from the RF-excitation center for GRE sequences while it is the temporal distance to the echo-center for SE sequences. @@ -53,17 +53,19 @@ class OffResonanceProperty(BaseSignalModel): f"(#repetitions, #k-space-points), but is: {sampling_times.shape}") if len(sampling_times.shape) == 1: sampling_times = tf.reshape(sampling_times, [1, -1]) - - self.gamma = tf.constant(gamma, dtype=tf.float32) - self.sampling_times = tf.Variable(tf.constant(sampling_times, tf.float32), - shape=(None, None), dtype=tf.float32) - self.expansion_factor = sampling_times.shape[0] + + with tf.device(self.device): + self.gamma = tf.constant(gamma, dtype=tf.float32) + self.sampling_times = tf.Variable(tf.constant(sampling_times, tf.float32), + shape=(None, None), dtype=tf.float32) + self.expansion_factor = sampling_times.shape[0] def update(self): super().update() self.expansion_factor == tf.shape(self.sampling_times.read_value())[0] # pylint: disable=signature-differs + @tf.function(jit_compile=True) def __call__(self, signal_tensor: tf.Tensor, deltaB0: tf.Tensor, **kwargs): """ Evaluates the off-resonance operator under the assumption diff --git a/cmrsim/analytic/contrast/phase_tracking.py b/cmrsim/analytic/contrast/phase_tracking.py index 81b704a..118aec3 100644 --- a/cmrsim/analytic/contrast/phase_tracking.py +++ b/cmrsim/analytic/contrast/phase_tracking.py @@ -22,6 +22,10 @@ class PhaseTracking(BaseSignalModel): :param gamma: Gyromagnetic ratio in rad/ms/mT """ required_quantities = ('trajectory', ) + #: (#repetitions, #time-steps, 4[t, x, y, z]) + wave_form: tf.Variable + #: gyromagnetic ratio in rad/mT/ms + gamma: tf.Tensor def __init__(self, expand_repetitions: bool, wave_form: tf.Tensor, gamma: float, **kwargs): @@ -31,13 +35,15 @@ class PhaseTracking(BaseSignalModel): :param gamma: Gyromagnetic ratio in rad/ms/mT """ super().__init__(expand_repetitions=expand_repetitions, name="phase_tracking", **kwargs) - self.wave_form = tf.Variable(wave_form, shape=(None, None, 4), - dtype=tf.float32, name='wave_form') - self.n_steps = tf.Variable(0, shape=(), dtype=tf.int64, name="n_steps") - self.gamma = tf.constant(gamma, dtype=tf.float32) - self.expansion_factor = tf.shape(self.wave_form)[0] - self.update() - + with tf.device(self.device): + self.wave_form = tf.Variable(wave_form, shape=(None, None, 4), + dtype=tf.float32, name='wave_form') + self.n_steps = tf.Variable(0, shape=(), dtype=tf.int64, name="n_steps") + self.gamma = tf.constant(gamma, dtype=tf.float32) + self.expansion_factor = tf.shape(self.wave_form)[0] + self.update() + + # @tf.function(jit_compile=True) def __call__(self, signal_tensor: tf.Tensor, trajectory: tf.Tensor, **kwargs) -> tf.Tensor: """ Evaluates the phase integration for given trajectories diff --git a/cmrsim/analytic/contrast/sequences.py b/cmrsim/analytic/contrast/sequences.py index 3720180..3c0620b 100644 --- a/cmrsim/analytic/contrast/sequences.py +++ b/cmrsim/analytic/contrast/sequences.py @@ -50,12 +50,13 @@ class SpinEcho(BaseSignalModel): f"of elements. Got {echo_time.shape} and {repetition_time.shape}") super().__init__(expand_repetitions=expand_repetitions, name="spin_echo", **kwargs) - self.TE = tf.Variable(echo_time, shape=(None, ), dtype=tf.float32, name='TE') - self.TR = tf.Variable(repetition_time, shape=(None, ), dtype=tf.float32, name='TR') - self.expansion_factor = self.TE.read_value().shape[0] - self.update() + with tf.device(self.device): + self.TE = tf.Variable(echo_time, shape=(None, ), dtype=tf.float32, name='TE') + self.TR = tf.Variable(repetition_time, shape=(None, ), dtype=tf.float32, name='TR') + self.expansion_factor = self.TE.read_value().shape[0] + self.update() - @tf.Module.with_name_scope + @tf.function(jit_compile=True) def __call__(self, signal_tensor: tf.Tensor, T1: tf.Tensor, T2: tf.Tensor, **kwargs): # noqa """ Evaluates spin echo operator diff --git a/cmrsim/analytic/contrast/t2_star.py b/cmrsim/analytic/contrast/t2_star.py index 72d5507..ebfa223 100644 --- a/cmrsim/analytic/contrast/t2_star.py +++ b/cmrsim/analytic/contrast/t2_star.py @@ -62,16 +62,17 @@ class StaticT2starDecay(BaseSignalModel): f"(#repetitions, #k-space-points), but is: {sampling_times.shape}") if len(sampling_times.shape) == 1: sampling_times = tf.reshape(sampling_times, [1, -1]) - - self.sampling_times = tf.Variable(tf.constant(sampling_times, tf.float32), - shape=(None, None), dtype=tf.float32) - self.expansion_factor = sampling_times.shape[0] + with tf.device(self.device): + self.sampling_times = tf.Variable(tf.constant(sampling_times, tf.float32), + shape=(None, None), dtype=tf.float32) + self.expansion_factor = sampling_times.shape[0] def update(self): super().update() self.expansion_factor == tf.shape(self.sampling_times.read_value())[0] # pylint: disable=signature-differs + @tf.function(jit_compile=True) def __call__(self, signal_tensor: tf.Tensor, T2star: tf.Tensor, **kwargs): """ Evaluates the T2* operator -- GitLab From e2e188b1d9ad0dd62caa4efd4ab029a498db0db7 Mon Sep 17 00:00:00 2001 From: Jonathan Weine <jweine@ethz.ch> Date: Mon, 31 Jul 2023 16:21:24 +0200 Subject: [PATCH 2/3] Moved simulation k-segment loop over to base implementation of encoding module --- cmrsim/analytic/encoding/base.py | 64 +++++++++++++++++++------- cmrsim/analytic/simulation.py | 77 ++++++++++++-------------------- 2 files changed, 76 insertions(+), 65 deletions(-) diff --git a/cmrsim/analytic/encoding/base.py b/cmrsim/analytic/encoding/base.py index ea25c18..2eadff6 100644 --- a/cmrsim/analytic/encoding/base.py +++ b/cmrsim/analytic/encoding/base.py @@ -68,18 +68,19 @@ class BaseSampling(tf.Module): if isinstance(absolute_noise_std, float): absolute_noise_std = [absolute_noise_std, ] - self.absolute_noise_std = tf.Variable(tf.constant(absolute_noise_std), dtype=tf.float32, - trainable=False, shape=[None, ], - name='absolute_noise_std') - self.k_space_segments = tf.Variable(tf.constant(k_space_segments), dtype=tf.int32, - trainable=False) - self.k_space_vectors = tf.Variable(tf.zeros((1, 3), tf.float32), shape=[None, 3], - dtype=tf.float32, trainable=False, - name='k_space_vectors') - self.sampling_times = tf.Variable(tf.zeros((1,), tf.float32), shape=[None, ], - dtype=tf.float32, trainable=False, name='sampling_times') - self.update() - self.ori_matrix = orientation_matrix + with tf.device(self.device): + self.absolute_noise_std = tf.Variable(tf.constant(absolute_noise_std), dtype=tf.float32, + trainable=False, shape=[None, ], + name='absolute_noise_std') + self.k_space_segments = tf.Variable(tf.constant(k_space_segments), dtype=tf.int32, + trainable=False) + self.k_space_vectors = tf.Variable(tf.zeros((1, 3), tf.float32), shape=[None, 3], + dtype=tf.float32, trainable=False, + name='k_space_vectors') + self.sampling_times = tf.Variable(tf.zeros((1,), tf.float32), shape=[None, ], + dtype=tf.float32, trainable=False, name='sampling_times') + self.update() + self.ori_matrix = orientation_matrix def update(self): """ Assigns values to trajectory vectors. Should be used after modifying the parameters of @@ -107,8 +108,39 @@ class BaseSampling(tf.Module): self._segmented_k_vectors = tf.RaggedTensor.from_row_lengths( self.k_space_vectors.read_value(), row_lengths) - def __call__(self, transverse_magnetization: tf.Tensor, r_vectors: tf.Tensor, - segment_index: Union[int, tf.Tensor] = 0, **kwargs): + @tf.function(jit_compile=False) + def __call__(self, m_transverse: tf.Tensor, r_vectors: tf.Tensor): + """ Fourier transforms the handed batch of object points / isochromates, + + :param m_transverse: + :param r_vectors: + :return: tf.Tensor of shape (#repetitions, #k-space-samples) + """ + with tf.device(self.device): + # Allocate k-space-Tensor + # s_of_k_segments = tf.TensorArray(dtype=tf.complex64, size=n_segments, dynamic_size=False) + samples_per_segment = tf.cast(self.number_of_samples / self.k_space_segments, tf.int32) + accumulator_shape = tf.stack([self.k_space_segments, tf.shape(m_transverse)[1], samples_per_segment], axis=0) + s_of_k_segments = tf.zeros(accumulator_shape, dtype=tf.complex64) + max_mtrans_index = tf.math.reduce_min(tf.stack([tf.shape(m_transverse)[-1], self.k_space_segments])) + max_rvectrans_index = tf.math.reduce_min(tf.stack([tf.shape(r_vectors)[-2], self.k_space_segments])) + + # Loop over k-space segments, defined by k-space trajectory in self.encoding_module + for segment_index in tf.range(self.k_space_segments): + mseg_idx = tf.reduce_min(tf.stack([segment_index + 1, max_mtrans_index])) + rseg_idx = tf.reduce_min(tf.stack([segment_index + 1, max_rvectrans_index])) + m_segment = m_transverse[..., samples_per_segment * (mseg_idx - 1): samples_per_segment * mseg_idx] + r_segment = r_vectors[..., samples_per_segment * (rseg_idx - 1):samples_per_segment*rseg_idx, :] + k_space_segment = self.call_single_segment(m_segment, r_segment, segment_index=segment_index) + s_of_k_segments = tf.tensor_scatter_nd_update(s_of_k_segments, [[segment_index], ], + k_space_segment[tf.newaxis]) + + # Concat all segments and transpose axes to match needed format + return tf.reshape(tf.transpose(s_of_k_segments, [1, 0, 2]), [accumulator_shape[1], -1]) + + @tf.function(jit_compile=False) + def call_single_segment(self, transverse_magnetization: tf.Tensor, r_vectors: tf.Tensor, + segment_index: Union[int, tf.Tensor] = 0, **kwargs) -> tf.Tensor: """Calculates fourier phases for given object-representation at r-vectors. For multiple different contrasts #repetitions is the representing axis. @@ -122,10 +154,10 @@ class BaseSampling(tf.Module): :param r_vectors: (#voxel, #repetitions, #k-space-samples, 3), axis #repetitions and #k-space-samples can be 1 to broadcast for coordinate reuse. :param segment_index: int - :param kwargs: - noise_seed: if not None, sets seed for sampling the noise vector. - :return: tf.Tensor + :return: tf.Tensor (..., k-space-points in segment) """ with tf.device(self.device): + n_repetitions = tf.shape(transverse_magnetization)[1] fourier_factors = self._calculate_fourier_phases(r_vectors=r_vectors, segment_index=segment_index) diff --git a/cmrsim/analytic/simulation.py b/cmrsim/analytic/simulation.py index fa1b08a..0ee93e1 100644 --- a/cmrsim/analytic/simulation.py +++ b/cmrsim/analytic/simulation.py @@ -15,6 +15,7 @@ if TYPE_CHECKING: from cmrsim.utils.display import SimulationProgressBarII +from time import perf_counter class AnalyticSimulation(tf.Module): """ This module provides the entry point to build and call the simulations defined within the @@ -56,7 +57,7 @@ class AnalyticSimulation(tf.Module): # Initialize progress bar n_k_space_segments = self.encoding_module.k_space_segments.read_value() self.progress_bar = SimulationProgressBarII(total_voxels=1, prefix='Run Simulation: ', - total_segments=n_k_space_segments) + total_segments=n_k_space_segments) @abstractmethod def build(self, **kwargs) -> ('CompositeSignalModel', 'BaseSampling', 'BaseRecon'): @@ -122,7 +123,7 @@ class AnalyticSimulation(tf.Module): f'{map_names} does not contain all entries of' f' {self.forward_model.required_quantities}') - def _simulation_loop(self, dataset: tf.data.Dataset, k_space_shape: tf.Tensor, + def _simulation_loop(self, dataset: tf.data.Dataset, k_space_shape: tf.Tensor, trajectory_module=None, trajectory_signatures: dict = None, additional_kwargs: dict = None): """ Consumes all object-configuration data (images) and simulates the MR images. The @@ -137,64 +138,39 @@ class AnalyticSimulation(tf.Module): (#repetitions, #samples). The returned images are only 3D if the input data, encoding and recon module are 3D as well. """ + if additional_kwargs is None: + additional_kwargs = {} + # Allocate tensor array to store the simulated images s_of_k_temp = tf.zeros(k_space_shape, tf.complex64) self.progress_bar.reset_voxel() - s_of_k_temp *= (0. + 0.j) # Loop over material points in dataset for batch_idx, batch_dict in dataset.enumerate(): if trajectory_module is not None: batch_dict = self._map_trajectories(batch_idx, batch_dict, trajectory_module, trajectory_signatures, additional_kwargs) - # print([f"{k}: {v.shape}" for k, v in batch_dict.items()]) - s_of_k_temp += self._segment_loop(**batch_dict) + m_transverse = self.forward_model(batch_dict["M0"], segment_index=0, **batch_dict) + s_of_k_temp += self.encoding_module(m_transverse, batch_dict["r_vectors"]) self.progress_bar.update(add_voxels=tf.shape(batch_dict['M0'])[0]) self.progress_bar.print() - self.progress_bar.print_final() return s_of_k_temp - - @tf.function(jit_compile=False, reduce_retracing=True) - def _segment_loop(self, **batch_dict): - """ Fourier transforms the handed batch of object points / isochromates, - - :param batch_dict: dictionary of type {'property_name': tf.Tensor(...)} - e.g. {'M0': tf.Tensor{}} - :return: tf.Tensor of shape (1, #repetitions, #k-space-samples) - """ - # Allocate k-space-Tensor - n_segments = self.encoding_module.k_space_segments.read_value() - s_of_k_segments = tf.TensorArray(dtype=tf.complex64, size=n_segments, dynamic_size=False) - self.progress_bar.reset_segments() - - m_transverse = self.forward_model(batch_dict["M0"], segment_index=0, **batch_dict) - - samples_per_segment = tf.cast(self.encoding_module.number_of_samples / - self.encoding_module.k_space_segments, tf.int32) - max_mtrans_index = tf.math.reduce_min(tf.stack([tf.shape(m_transverse)[-1], n_segments])) - max_rvectrans_index = tf.math.reduce_min(tf.stack([tf.shape(batch_dict['r_vectors'])[-2], n_segments])) - - # Loop over k-space segments, defined by k-space trajectory in self.encoding_module - for segment_index in tf.range(n_segments): - mseg_idx = tf.reduce_min(tf.stack([segment_index + 1, max_mtrans_index])) - rseg_idx = tf.reduce_min(tf.stack([segment_index + 1, max_rvectrans_index])) - m_segment = m_transverse[..., samples_per_segment * (mseg_idx - 1): samples_per_segment * mseg_idx] - r_segment = batch_dict['r_vectors'][..., samples_per_segment * (rseg_idx - 1):samples_per_segment*rseg_idx, :] - k_space_segment = self.encoding_module(m_segment, r_segment, segment_index=segment_index) - k_space_segment = tf.transpose(k_space_segment, [1, 0]) - s_of_k_segments = s_of_k_segments.write(segment_index, k_space_segment) - - self.progress_bar.update(add_segment=1) - if tf.math.floormod(segment_index, 5) == 0: - self.progress_bar.print() - - # Concat all segments and transpose axes to match needed format - return tf.transpose(s_of_k_segments.concat(), [1, 0]) @staticmethod def _map_trajectories(idx: int, batch: dict, trajectory_module, trajectory_signatures: dict, additional_kwargs: dict): + """ + + :warning: Only the last entry of the trajectory signatures dict + + :param idx: + :param batch: + :param trajectory_module: + :param trajectory_signatures: + :param additional_kwargs: + :return: + """ init_r = batch.pop("initial_positions") for k, t in trajectory_signatures.items(): new_shape = tf.concat([tf.shape(init_r)[0:1], tf.shape(t), [3, ]], axis=0) @@ -202,17 +178,19 @@ class AnalyticSimulation(tf.Module): # using the flattened time-samples pos, add_lookups = [], [] for t_ in t: - p, alup = trajectory_module(initial_positions=tf.reshape(init_r, (-1, 3)), timing=t_, + p, alup = trajectory_module(initial_positions=tf.reshape(init_r, (-1, 3)), + timing=tf.reshape(t_, (-1, )), **additional_kwargs, batch_index=tf.cast(idx, tf.int32)) pos.append(p) add_lookups.append(alup) pos = tf.stack(pos, axis=1) pos = tf.reshape(pos, new_shape) - add_lookups = {k: tf.reshape(tf.stack([alup[k] for alup in add_lookups]), - tf.concat([new_shape[:-1], tf.shape(add_lookups[0][k])])) - for k in add_lookups[0].keys()} - batch.update({k: pos}) + + lookup_shapes = [tf.concat([new_shape[:-1], tf.shape(add_lookups[-1][k])[2:]], 0) + for k in add_lookups[-1].keys()] + add_lookups = {k: tf.reshape(tf.stack([alup[k] for alup in add_lookups]), lus) + for k, lus in zip(add_lookups[-1].keys(), lookup_shapes)} batch.update(add_lookups) return batch @@ -294,7 +272,8 @@ class AnalyticSimulation(tf.Module): print("Graph Tracing activated...") tf.summary.trace_on(graph=True) batch_dict, = [_ for _ in dataset(100).take(1)] - _ = self._segment_loop(**batch_dict) + m_transverse = self.forward_model(batch_dict["M0"], segment_index=0, **batch_dict) + _ = self.encoding_module(m_transverse, batch_dict["r_vectors"]) with writer.as_default(): # pylint: disable not-context-manager tf.summary.trace_export(name='graph_def', step=0) -- GitLab From f4bbc398141ee667a5ba109a7fc2fb22b1c8183e Mon Sep 17 00:00:00 2001 From: Jonathan Weine <jweine@ethz.ch> Date: Mon, 31 Jul 2023 17:11:00 +0200 Subject: [PATCH 3/3] Removed unused and undocumented modules --- cmrsim/analytic/encoding/base.py | 87 +------------------------------- tests/test_encoding_base.py | 28 +--------- 2 files changed, 3 insertions(+), 112 deletions(-) diff --git a/cmrsim/analytic/encoding/base.py b/cmrsim/analytic/encoding/base.py index 2eadff6..c4ac13b 100644 --- a/cmrsim/analytic/encoding/base.py +++ b/cmrsim/analytic/encoding/base.py @@ -1,5 +1,5 @@ """ This module contains the base implementation of the encoding mechanism of the simulation""" -__all__ = ["BaseSampling", "InternalAccumulator", "BlochOperatorInput"] +__all__ = ["BaseSampling",] from typing import Union, Iterable, Tuple import abc @@ -278,87 +278,4 @@ class BaseSampling(tf.Module): self.ori_matrix = tf.Variable(trafo_mat.astype(np.float32), shape=(3, 4), dtype=tf.float32) else: - self.ori_matrix.assign(trafo_mat.astype(np.float32)) - - -# pylint: disable=abstract-method -class InternalAccumulator(tf.Module): - """Wraps a BaseSampling object to enable internal k-space accumulation""" - #: Encoding module that is wrapped - wrapped_module: BaseSampling - _kspace_acc: Tuple[tf.Variable] - - def __init__(self, encoding_module: BaseSampling, n_repetitions: int): - """ - :param encoding_module: Instance of an encoding module - :param n_repetitions: number of repetitions to be expected for - given transversal magnetization - """ - self.wrapped_module = encoding_module - super().__init__(name=encoding_module.name) - self.update(n_repetitions) - - def update(self, n_repetitions: int): - self.wrapped_module.update() - # pylint: disable=protected-access - samples_per_segment = [tf.shape(_)[0] for _ in - self.wrapped_module._segmented_sampling_times] - - with tf.device("CPU:0"): - # pylint: disable=consider-using-generator - self._kspace_acc = tuple([tf.Variable(tf.zeros((n_repetitions, s), - dtype=tf.complex64), - dtype=tf.complex64, shape=(None, s)) - for s in samples_per_segment]) - - def __getattr__(self, attr): - """ This is only called if attr is not in self.__dict__""" - if attr != "wrapped_module": - return getattr(self.wrapped_module, attr) - - def __call__(self, transverse_magnetization: tf.Tensor, r_vectors: tf.Tensor, - segment_index: Union[int, tf.Tensor] = 0, **kwargs): - samples_batch = self.wrapped_module(transverse_magnetization, r_vectors, - segment_index, **kwargs) - with tf.device("CPU:0"): - self._kspace_acc[segment_index].assign_add(samples_batch) - return transverse_magnetization - - def reset(self): - """ Sets all k-space-line accumulators to zero """ - for i, accumulator in enumerate(self._kspace_acc): - self._kspace_acc[i].assign(tf.zeros_like(accumulator)) - - def get_kspace(self): - """ Returns k-space data in shape (1, -1, #samples)""" - return tf.concat(self._kspace_acc, -1) - - -# pylint: disable=abstract-method -class BlochOperatorInput(tf.Module): - """Wraps a BaseSampling object to deal with the difference in input shape for bloch operators - and solution operators """ - - def __init__(self, encoding_module: Union[BaseSampling, InternalAccumulator]): - """ - :param encoding_module: Instance of an encoding module - """ - super().__init__(name=encoding_module.name) - self.wrapped_module = encoding_module - - def __getattr__(self, attr): - """ This is only called if attr is not in self.__dict__""" - if attr != "wrapped_module": - return getattr(self.wrapped_module, attr) - - def __call__(self, magnetization: tf.Tensor, trajectories: tf.Tensor, - segment_index: Union[int, tf.Tensor] = 0, **kwargs): - """ - - :param magnetization: (#batch, 3) - :param trajectories: (#batch, #k-space-points, 3) - """ - m_transverse = tf.reshape(magnetization[:, 0], [-1, 1, 1]) - rvec = tf.reshape(trajectories, [magnetization.shape[0], 1, -1, 3]) - return self.wrapped_module(transverse_magnetization=m_transverse, r_vectors=rvec, - segment_index=segment_index, **kwargs) + self.ori_matrix.assign(trafo_mat.astype(np.float32)) \ No newline at end of file diff --git a/tests/test_encoding_base.py b/tests/test_encoding_base.py index e5fc449..6e65000 100644 --- a/tests/test_encoding_base.py +++ b/tests/test_encoding_base.py @@ -5,7 +5,7 @@ import tensorflow as tf import numpy as np from pint import Quantity -from cmrsim.analytic.encoding import BaseSampling, InternalAccumulator, BlochOperatorInput, EPI +from cmrsim.analytic.encoding import BaseSampling, EPI class TestInputShapes(unittest.TestCase): @@ -50,32 +50,6 @@ class TestInputShapes(unittest.TestCase): k = self.module(m_transverse, r_vectors) self.assertTrue(all(i == j for i, j in zip(target_shape, tuple(k.shape)))) - def test_internal_accumulator(self): - actual_number_of_contrasts = 10 - n_voxels = 1000 - shape_list = itertools.product([1, actual_number_of_contrasts], [1, self.n_sampling_points]) - module = InternalAccumulator(self.module, n_repetitions=actual_number_of_contrasts) - - for reps, timings in shape_list: - with self.subTest(): - target_shape = (actual_number_of_contrasts, self.n_sampling_points) - m_transverse = tf.ones((n_voxels, actual_number_of_contrasts, timings), - dtype=tf.complex64) - r_vectors = tf.random.normal((n_voxels, reps, timings, 3)) - module(m_transverse, r_vectors, 0) - k = module.get_kspace() - self.assertTrue(all(i == j for i, j in zip(target_shape, tuple(k.shape)))) - - def test_bloch_input_handler(self): - module = BlochOperatorInput(self.module) - m_transverse = tf.ones((100, 3), dtype=tf.complex64) - with self.subTest(): - r_vectors = tf.random.normal((100, 1, 3)) - k = module(m_transverse, r_vectors) - with self.subTest(): - r_vectors = tf.random.normal((100, module._segmented_sampling_times[0].shape[0], 3)) - k = module(m_transverse, r_vectors) - class TestValues(unittest.TestCase): def test_dc_input(self): -- GitLab