import ants
import numpy as np
import random
def reconstruct_image_from_patches(patches,
                                   domain_image,
                                   stride_length=1,
                                   domain_image_is_mask=False):
    """
    Reconstruct image from a list of patches.

    Patches are consumed sequentially while the patch start index sweeps
    the image in raster order (last image axis varying fastest).  Each
    patch is added into the window it covers and, except in the masked
    case described below, every voxel is then divided by the number of
    patches that covered it, so overlapping patches are averaged.

    Arguments
    ---------
    patches : list or array of patches
        List or array of patches defining an image.  Patches are assumed
        to have the same format as returned by extract_image_patches:
        either a list with one array per patch, or a single array with the
        patches stacked along axis 0.  An optional trailing axis holds the
        image components.

    domain_image : ANTs image
        Image or mask to define the geometric information of the
        reconstructed image.  If this is a mask image, the reconstruction
        will only use patches in the mask.

    stride_length : integer or n-D tuple
        Step between the start indices of consecutive patches along each
        image axis.  Can be an image-dimensional vector or a scalar.

    domain_image_is_mask : boolean
        Boolean specifying whether the domain image is a mask used to limit
        the region of reconstruction from the patches.  The mask is only
        consulted when every stride is 1: a patch position is then skipped
        (no patch is consumed for it) when the mask is zero at the patch's
        mid index, and the accumulated sum is returned without averaging.

    Returns
    -------
    An ANTs image.

    Raises
    ------
    ValueError
        If the domain image is not 2-D or 3-D, if the patches have fewer
        axes than the domain image, or if stride_length is neither a
        positive integer nor a vector of positive integers of length
        dimensionality.

    Example
    -------
    >>> import ants
    >>> image = ants.image_read(ants.get_ants_data('r16'))
    >>> image_patches = extract_image_patches(image, patch_size=(16, 16), stride_length=4)
    >>> reconstructed_image = reconstruct_image_from_patches(image_patches, image, stride_length=4)
    """

    image_size = tuple(domain_image.shape)
    dimensionality = domain_image.dimension

    if dimensionality not in (2, 3):
        raise ValueError("Unsupported dimensionality.")

    # A list holds one array per patch; a single array stacks the patches
    # along axis 0, so that leading axis is dropped to get one patch's shape.
    if isinstance(patches, list):
        patch_dimension = tuple(patches[0].shape)
    else:
        patch_dimension = tuple(patches.shape[1:])

    patch_size = patch_dimension[:dimensionality]
    if len(patch_size) != dimensionality:
        raise ValueError("Patches must have at least as many axes as the domain image.")

    number_of_image_components = 1
    if len(patch_dimension) > dimensionality:
        number_of_image_components = patch_dimension[dimensionality]

    # Index, within a patch window, of the voxel tested against the mask.
    mid_patch_index = tuple(np.int_(np.subtract(np.round(0.5 * np.array(patch_size)), 1)))

    if isinstance(stride_length, (int, np.integer)):
        if stride_length < 1:
            raise ValueError("stride_length elements must be positive integers.")
        stride_length_tuple = (int(stride_length),) * dimensionality
    elif len(stride_length) != dimensionality:
        raise ValueError("stride_length is not a scalar or vector of length dimensionality.")
    elif np.any(np.less(stride_length, 1)):
        raise ValueError("stride_length elements must be positive integers.")
    else:
        stride_length_tuple = tuple(stride_length)

    # The mask only filters patches for unit stride; for any other stride
    # every patch position is used, exactly as before this rewrite.
    use_mask = bool(domain_image_is_mask) and all(s == 1 for s in stride_length_tuple)
    mask_array = None
    if use_mask:
        mask_array = domain_image.numpy() != 0

    image_array = np.zeros((*image_size, number_of_image_components))
    # Number of patches that contributed to each voxel.
    count_array = np.zeros(image_size)

    # Admissible patch start indices along each axis (empty when the patch
    # is larger than the image along that axis).
    start_ranges = [range(0, extent - size + 1, stride)
                    for extent, size, stride
                    in zip(image_size, patch_size, stride_length_tuple)]

    count = 0
    # np.ndindex iterates with the last axis fastest, which reproduces the
    # order of the nested i / j (/ k) loops this replaces.
    for grid_index in np.ndindex(*(len(r) for r in start_ranges)):
        start_index = tuple(r[g] for r, g in zip(start_ranges, grid_index))
        region = tuple(slice(start, start + size)
                       for start, size in zip(start_index, patch_size))

        if use_mask and not mask_array[region][mid_patch_index]:
            continue

        patch = np.asarray(patches[count])
        if patch.ndim == dimensionality:
            # Scalar patches carry no component axis; add one so the patch
            # lines up with the (..., components) accumulator.
            patch = patch[..., np.newaxis]

        # Accumulate over the *whole* window covered by the patch.
        image_array[region] += patch
        count_array[region] += 1
        count += 1

    # NOTE(review): the masked path has never averaged overlapping patches
    # (the original guarded its normalisation with
    # `if not domain_image_is_mask`); that is kept here -- confirm intended.
    if not use_mask:
        # Voxels no patch reached stay zero instead of dividing by zero.
        count_array[count_array == 0] = 1
        image_array /= count_array[..., np.newaxis]

    if number_of_image_components == 1:
        # Drop only the component axis; singleton spatial axes must survive.
        image_array = image_array[..., 0]
    else:
        # Components go first, matching the layout the original produced.
        image_array = np.moveaxis(image_array, -1, 0)

    reconstructed_image = ants.from_numpy(image_array,
                                          origin=domain_image.origin,
                                          spacing=domain_image.spacing,
                                          direction=domain_image.direction,
                                          has_components=(number_of_image_components != 1))

    return reconstructed_image