Source code for mdt.lib.masking

import logging
import numpy as np
from scipy.ndimage import binary_dilation, generate_binary_structure, binary_fill_holes
from mdt.utils import load_brain_mask
from mdt.protocols import load_protocol
from mdt.lib.nifti import load_nifti, write_nifti
import mot.configuration
from scipy.ndimage.filters import median_filter

__author__ = 'Robbert Harms'
__date__ = "2015-07-20"
__maintainer__ = "Robbert Harms"
__email__ = "robbert@xkls.nl"


def create_median_otsu_brain_mask(dwi_info, protocol, mask_threshold=0, fill_holes=True, **kwargs):
    """Create a brain mask using the given volume.

    Args:
        dwi_info (string or tuple or image): The information about the volume, either:

            - the filename of the input file
            - or a tuple with as first index a ndarray with the DWI and as second index the header
            - or only the image as an ndarray
        protocol (string or :class:`~mdt.protocols.Protocol`): The filename of the protocol file or a Protocol object
        mask_threshold (float): if non-zero, only voxels inside the median otsu mask whose signal is above this
            value are kept. For a 4d input the signal is the mean over the weighted volumes of the protocol,
            for any other input it is the volume itself.
        fill_holes (boolean): if we will fill holes after the median otsu algorithm and before the thresholding
        **kwargs: the additional arguments for median_otsu.

    Returns:
        ndarray: The created brain mask
    """
    logger = logging.getLogger(__name__)
    logger.info('Starting calculating a brain mask')

    if isinstance(dwi_info, str):
        signal_img = load_nifti(dwi_info)
        dwi = signal_img.get_data()
    elif isinstance(dwi_info, (tuple, list)):
        # (data, header) pair, only the data is needed here
        dwi = dwi_info[0]
    else:
        dwi = dwi_info

    if isinstance(protocol, str):
        protocol = load_protocol(protocol)

    is_4d = len(dwi.shape) == 4

    if is_4d:
        unweighted_ind = protocol.get_unweighted_indices()
        if len(unweighted_ind):
            unweighted = np.mean(dwi[..., unweighted_ind], axis=3)
        else:
            # no unweighted volumes in the protocol, fall back to the mean over all volumes
            unweighted = np.mean(dwi, axis=3)
    else:
        unweighted = dwi.copy()

    brain_mask = median_otsu(unweighted, **kwargs) > 0

    if fill_holes:
        brain_mask = binary_fill_holes(brain_mask)

    if mask_threshold:
        if is_4d:
            signal = np.mean(dwi[..., protocol.get_weighted_indices()], axis=3)
        else:
            # there is no volume axis to average over, averaging over axis 3 would raise an error
            signal = dwi
        brain_mask = (signal * brain_mask) > mask_threshold

    logger.info('Finished calculating a brain mask')

    return brain_mask
def generate_simple_wm_mask(scalar_map, whole_brain_mask, threshold=0.3, median_radius=1, nmr_filter_passes=2):
    """Generate a simple white matter mask by thresholding the given map and smoothing it using a median filter.

    Everything below the given threshold will be masked (not used). It also applies the regular brain mask to
    only retain values inside the brain.

    Args:
        scalar_map (str or ndarray): the path to the FA file, or the map itself
        whole_brain_mask (str or ndarray): the general brain mask used in the FA model fitting
        threshold (double): the FA threshold. Everything below this threshold is masked (set to 0). To be precise:
            where fa_data < fa_threshold set the value to 0.
        median_radius (int): the radius of the median filter
        nmr_filter_passes (int): the number of passes we apply the median filter

    Returns:
        ndarray: the white matter mask. If ``nmr_filter_passes`` is 0 this is the thresholded boolean mask,
            else it is the output of the last median filter pass.
    """
    # Cross shaped footprint: only the three axis-aligned lines through the center voxel are used.
    filter_footprint = np.zeros((1 + 2 * median_radius,) * 3)
    filter_footprint[:, median_radius, median_radius] = 1
    filter_footprint[median_radius, :, median_radius] = 1
    filter_footprint[median_radius, median_radius, :] = 1

    if isinstance(scalar_map, str):
        map_data = load_nifti(scalar_map).get_data()
    else:
        # work on a copy, the thresholding below happens in-place
        map_data = np.copy(scalar_map)

    map_data[map_data < threshold] = 0
    # the builtin bool is used here, the ``np.bool`` alias no longer exists in recent NumPy versions
    wm_mask = map_data.astype(bool)

    if len(wm_mask.shape) > 3:
        wm_mask = wm_mask[:, :, :, 0]

    brain_mask = load_brain_mask(whole_brain_mask)
    wm_mask[np.logical_not(brain_mask)] = 0

    if nmr_filter_passes == 0:
        return wm_mask

    # NOTE(review): in numpy.ma a True mask entry marks a value as invalid, so this marks the voxels inside
    # the brain as masked. Confirm this is intended and whether median_filter takes the mask into account.
    wm_mask_masked = np.ma.masked_array(wm_mask, mask=brain_mask)

    for _ in range(nmr_filter_passes):
        wm_mask_masked = median_filter(wm_mask_masked, footprint=filter_footprint, mode='constant')

    return wm_mask_masked
def create_write_median_otsu_brain_mask(dwi_info, protocol, output_fname, **kwargs):
    """Write a brain mask using the given volume and output as the given volume.

    Args:
        dwi_info (string or tuple or ndarray): the filename of the input file or a tuple with as
            first index a ndarray with the DWI and as second index the header or only the image.
        protocol (string or :class:`~mdt.protocols.Protocol`): The filename of the protocol file or a Protocol object
        output_fname (string): the filename of the output file (the extracted brain mask)
            If None, no output is written.
            If ``dwi_info`` is an ndarray also no file is written (we don't have the header).
        **kwargs: passed on to :func:`create_median_otsu_brain_mask`.

    Returns:
        ndarray: The created brain mask
    """
    write_output = output_fname is not None

    if isinstance(dwi_info, str):
        signal_img = load_nifti(dwi_info)
        dwi = signal_img.get_data()
        header = signal_img.header
    elif isinstance(dwi_info, np.ndarray):
        # Only the image was given. Indexing it as a (data, header) pair would slice the volume,
        # and without a header there is nothing to write the mask with.
        dwi = dwi_info
        header = None
        write_output = False
    else:
        dwi = dwi_info[0]
        header = dwi_info[1]

    mask = create_median_otsu_brain_mask(dwi, protocol, **kwargs)

    if write_output:
        write_nifti(mask, output_fname, header)

    return mask
def median_otsu(unweighted_volume, median_radius=4, numpass=4, dilate=1):
    """Simple brain extraction tool for dMRI data.

    This function is inspired from the ``median_otsu`` function from ``dipy`` and is copied here to remove a
    dependency.

    It uses a median filter smoothing of the ``unweighted_volume`` followed by an automatic histogram Otsu
    thresholding technique, hence the name *median_otsu*.

    This function is inspired from Mrtrix's bet which has default values ``median_radius=3``, ``numpass=2``.
    However, from tests on multiple 1.5T and 3T data from GE, Philips and Siemens, the most robust choice is
    ``median_radius=4``, ``numpass=4``.

    Args:
        unweighted_volume (ndarray): ndarray of the unweighted brain volumes
        median_radius (int): passed as the ``size`` argument (the window edge length in voxels) of the
            applied median filter (default 4)
        numpass (int): Number of passes of the median filter (default 4)
        dilate (None or int): optional number of iterations for binary dilation. With None, zero or a
            negative number no dilation is applied.

    Returns:
        ndarray: a 3D ndarray with the binary brain mask
    """
    b0vol = unweighted_volume

    logger = logging.getLogger(__name__)
    # NOTE(review): these two messages look stale, the filtering below only goes through scipy's median_filter
    # and does not use the listed devices. Confirm before removing.
    logger.info('We will use a single precision float type for the calculations.')
    for env in mot.configuration.get_cl_environments():
        logger.info('Using device \'{}\'.'.format(str(env)))

    for _ in range(numpass):
        b0vol = median_filter(b0vol, size=median_radius, mode='mirror')

    thresh = _otsu(b0vol)
    mask = b0vol > thresh

    # binary_dilation treats ``iterations < 1`` as "repeat until nothing changes", which would flood
    # the mask. Hence we only dilate for a positive number of iterations.
    if dilate is not None and dilate > 0:
        cross = generate_binary_structure(3, 1)
        mask = binary_dilation(mask, cross, iterations=dilate)

    return mask
def _otsu(image, nbins=256):
    """Return threshold value based on Otsu's method.

    Copied from scikit-image to remove dependency. Note that this version works on the bin edges returned by
    ``np.histogram``: the class means are computed with the upper edge of every bin and the returned threshold
    is the lower edge of the selected bin.

    Args:
        image (ndarray): Input image.
        nbins (int): Number of bins used to calculate the histogram.

    Returns:
        float: Threshold value.
    """
    hist, bin_edges = np.histogram(image, nbins)
    # the builtin float is used here, the ``np.float`` alias no longer exists in recent NumPy versions
    hist = hist.astype(float)

    # class probabilities for all possible thresholds
    weight1 = np.cumsum(hist)
    weight2 = np.cumsum(hist[::-1])[::-1]

    # class means for all possible thresholds
    weighted_hist = hist * bin_edges[1:]
    mean1 = np.cumsum(weighted_hist) / weight1
    mean2 = (np.cumsum(weighted_hist[::-1]) / weight2[::-1])[::-1]

    # Clip ends to align class 1 and class 2 variables:
    # The last value of `weight1`/`mean1` should pair with zero values in
    # `weight2`/`mean2`, which do not exist.
    variance12 = weight1[:-1] * weight2[1:] * (mean1[:-1] - mean2[1:]) ** 2

    idx = np.argmax(variance12)
    return bin_edges[:-1][idx]