# Source code for antspynet.utilities.regression_match_image
from sklearn.preprocessing import PolynomialFeatures
from sklearn.linear_model import LinearRegression
import numpy as np
import ants
def regression_match_image(source_image,
                           reference_image,
                           mask=None,
                           poly_order=1,
                           truncate=True):
    """
    Image intensity normalization using linear regression.

    A polynomial (in the source intensities) least-squares model is fit to
    predict the reference intensities voxel-by-voxel.  The fitted model is
    then applied to every voxel of the source image.

    Arguments
    ---------
    source_image : ANTsImage
        Image whose intensities are matched to the reference
        image.

    reference_image : ANTsImage
        Defines the reference intensity function.  Must have the same
        shape as source_image.

    mask : ANTsImage, optional
        Defines voxels for regression modeling.  Only voxels where the mask
        is nonzero are used to fit the model; the fitted model is still
        applied to the whole source image.  If None, all voxels are used.

    poly_order : integer
        Polynomial order of fit. Default is 1 (linear fit).

    truncate : boolean
        Turns on/off the clipping of intensities.  When on, the output is
        clipped to the [min, max] range of the reference intensities used
        for fitting (i.e. within the mask, if one is given).

    Returns
    -------
    ANTs image (i.e., source_image) matched to the (reference_image), with
    the header information (spacing, origin, direction) of source_image.

    Raises
    ------
    ValueError
        If the images, or the mask, do not have the same dimension, or if
        the mask contains no nonzero voxels.

    Example
    -------
    >>> import ants
    >>> source_image = ants.image_read(ants.get_ants_data('r16'))
    >>> reference_image = ants.image_read(ants.get_ants_data('r64'))
    >>> matched_image = regression_match_image(source_image, reference_image)
    """
    if source_image.shape != reference_image.shape:
        raise ValueError("Images do not have the same dimension.")

    # Column vectors of shape (n_voxels, 1), as expected by sklearn.
    all_source_intensities = np.expand_dims(source_image.numpy().flatten(), axis=1)
    all_reference_intensities = np.expand_dims(reference_image.numpy().flatten(), axis=1)

    if mask is None:
        source_intensities = all_source_intensities
        reference_intensities = all_reference_intensities
    else:
        if mask.shape != source_image.shape:
            raise ValueError("Mask does not have the same dimension as the images.")
        in_mask = mask.numpy().flatten() != 0
        if not np.any(in_mask):
            raise ValueError("Mask does not contain any nonzero voxels.")
        # Both intensity sets must come from the same voxels of their
        # respective images so the regression pairs them correctly.
        source_intensities = all_source_intensities[in_mask]
        reference_intensities = all_reference_intensities[in_mask]

    poly_features = PolynomialFeatures(degree=poly_order)
    source_intensities_poly = poly_features.fit_transform(source_intensities)

    model = LinearRegression()
    model.fit(source_intensities_poly, reference_intensities)

    # Apply the model to every source voxel (not only the masked ones) so the
    # prediction has exactly one value per voxel and can be reshaped into an
    # image of source_image.shape.
    matched_source_intensities = model.predict(
        poly_features.transform(all_source_intensities))

    if truncate:
        matched_source_intensities = np.clip(matched_source_intensities,
                                             reference_intensities.min(),
                                             reference_intensities.max())

    matched_source_image = ants.make_image(source_image.shape, matched_source_intensities)
    matched_source_image = ants.copy_image_info(source_image, matched_source_image)
    return matched_source_image