Note
Go to the end to download the full example code, or to run this example in your browser via Binder.
3D adaptive histogram equalization
Adaptive histogram equalization (AHE) can be used to improve the local contrast of an image [1]. Specifically, AHE can be useful for normalizing intensities across images. This example compares the results of applying global histogram equalization and AHE to a 3D image and a synthetically degraded version of it.
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from matplotlib import cm, colors
from mpl_toolkits.mplot3d import Axes3D
import numpy as np
from skimage import exposure, util
# Prepare data and apply histogram equalization
from skimage.data import cells3d
# Keep only the nuclei channel (index 1) of the 3D cell image, as floats.
im_orig = util.img_as_float(cells3d()[:, 1, :, :])
# The image is indexed (z, y, x); transpose it so it is indexed (x, y, z).
im_orig = im_orig.transpose()
# Clip to the 5th-95th percentile range, then min-max rescale to [0, 1].
low_cut, high_cut = np.percentile(im_orig, [5, 95])
im_orig = np.clip(im_orig, low_cut, high_cut)
im_lo, im_hi = im_orig.min(), im_orig.max()
im_orig = (im_orig - im_lo) / (im_hi - im_lo)

# Degrade the image with an exponential intensity decay along x: the slice at
# x index i is scaled by exp(-3 * i / (N - 1)), where N = im_orig.shape[0].
# The profile keeps the name ``sigmoid`` (although it is exponential) because
# the plotting code below refers to it by that name.
sigmoid = np.exp(-3 * np.linspace(0, 1, im_orig.shape[0]))
im_degraded = im_orig * sigmoid[:, np.newaxis, np.newaxis]

# AHE parameters: the kernel spans 1/5 of the image along x and y and 1/2
# along z.
kernel_size = np.array(
    [im_orig.shape[0] // 5, im_orig.shape[1] // 5, im_orig.shape[2] // 2]
)
clip_limit = 0.9

# Global histogram equalization of both inputs.
im_orig_he = exposure.equalize_hist(im_orig)
im_degraded_he = exposure.equalize_hist(im_degraded)
# Adaptive (kernel-local) histogram equalization of both inputs.
im_orig_ahe = exposure.equalize_adapthist(
    im_orig, kernel_size=kernel_size, clip_limit=clip_limit
)
im_degraded_ahe = exposure.equalize_adapthist(
    im_degraded, kernel_size=kernel_size, clip_limit=clip_limit
)
# Define functions to help plot the data
def scalars_to_rgba(scalars, cmap, vmin=0.0, vmax=1.0, alpha=0.2):
    """
    Convert scalars into RGBA colors that all share one fixed alpha.

    Parameters
    ----------
    scalars : float or array_like
        Value(s) to color. Scalars and arrays of any dimensionality are
        accepted, as long as they are not themselves RGB(A) images.
    cmap : str or Colormap
        Colormap used for the mapping.
    vmin, vmax : float
        Values mapped linearly onto the bottom and top of the colormap.
    alpha : float
        Opacity written into the alpha channel of every output color.

    Returns
    -------
    ndarray
        Float array whose trailing axis has length 4 (R, G, B, A).
    """
    norm = colors.Normalize(vmin=vmin, vmax=vmax)
    scalar_map = cm.ScalarMappable(norm=norm, cmap=cmap)
    # to_rgba returns a tuple for a scalar input; coerce to an array so the
    # alpha channel can be assigned uniformly.
    rgbas = np.asarray(scalar_map.to_rgba(scalars))
    # Address the channel axis with ``...``: ``rgbas[:, 3]`` would overwrite
    # the wrong axis for 2D inputs (and fail outright for scalars).
    rgbas[..., 3] = alpha
    return rgbas
def plt_render_volume(vol, fig_ax, cmap, vmin=0, vmax=1, bin_widths=None, n_levels=20):
    """
    Render a volume in a 3D matplotlib scatter plot.

    The volume is clipped to ``[vmin, vmax]``, subsampled with a stride of
    ``bin_widths`` along each axis, and its voxels are grouped into
    ``n_levels - 1`` intensity bins. Each non-empty bin is drawn with one
    ``scatter`` call whose alpha grows steeply with intensity, so dim voxels
    are nearly transparent. Better would be to use napari.

    Parameters
    ----------
    vol : ndarray
        3D array of intensities.
    fig_ax : Axes3D
        Axes the scatter plots are drawn into (called as
        ``scatter(xs, ys, zs, ...)``).
    cmap : str or Colormap
        Colormap used to color the voxels.
    vmin, vmax : float
        Intensity range used for clipping, colormapping and binning.
    bin_widths : sequence of 3 numbers, optional
        Subsampling stride along each axis. ``None`` (the default) means
        ``(1, 1, 1)``, i.e. every voxel is drawn.
    n_levels : int
        Number of bin edges between ``vmin`` and ``vmax``.
    """
    if bin_widths is None:
        # The None default used to fail when indexed below; treat it as
        # "no subsampling".
        bin_widths = (1, 1, 1)
    vol = np.clip(vol, vmin, vmax)
    xs, ys, zs = np.mgrid[
        0 : vol.shape[0] : bin_widths[0],
        0 : vol.shape[1] : bin_widths[1],
        0 : vol.shape[2] : bin_widths[2],
    ]
    # Flatten once up front; the same coordinates are masked for every bin.
    xs, ys, zs = xs.ravel(), ys.ravel(), zs.ravel()
    values = vol[:: bin_widths[0], :: bin_widths[1], :: bin_widths[2]].ravel()
    marker_size = 0.5 * np.mean(bin_widths)

    # Alpha transfer function: a steep power ramp (the 0.7 scale cancels in
    # the min-max normalization), rescaled to [0, 0.8].
    levels = np.linspace(vmin, vmax, n_levels)
    alphas = np.linspace(0, 0.7, n_levels) ** 11
    alphas = (alphas - alphas.min()) / (alphas.max() - alphas.min())
    alphas *= 0.8

    # Group voxels by intensity and plot each group separately, as 3D scatter
    # does not accept arrays of alpha values. Bins are half-open [lo, hi) so
    # a voxel lying exactly on an interior edge is drawn once, not twice; the
    # last bin is closed so values equal to vmax are still drawn.
    last = len(levels) - 1
    for il in range(1, len(levels)):
        lo, hi = levels[il - 1], levels[il]
        below_hi = values <= hi if il == last else values < hi
        sel = (values >= lo) & below_hi
        if not sel.any():
            continue
        c = scalars_to_rgba(
            values[sel], cmap, vmin=vmin, vmax=vmax, alpha=alphas[il - 1]
        )
        fig_ax.scatter(
            xs[sel],
            ys[sel],
            zs[sel],
            c=c,
            s=marker_size,
            marker='o',
            linewidth=0,
        )
# Build a 2 x 3 grid of transparent 3D axes. Subplots are numbered row-major,
# so the top row shows the original volume and the bottom row the degraded
# one; the columns are: input, global HE, AHE (matching the order of ``ims``).
cmap = 'Blues'
fig = plt.figure(figsize=(10, 6))
axs = [
    fig.add_subplot(2, 3, pos, projection=Axes3D.name, facecolor="none")
    for pos in range(1, 7)
]
ims = [
    im_orig,
    im_orig_he,
    im_orig_ahe,
    im_degraded,
    im_degraded_he,
    im_degraded_ahe,
]

# Corners of the unit cube, ordered with x varying slowest and z fastest.
unit_coords = (0, 1)
verts = np.array(
    [[x, y, z] for x in unit_coords for y in unit_coords for z in unit_coords],
    dtype=np.float32,
)
# A cube edge joins two corners at distance 1 (they differ in exactly one
# coordinate); each edge appears twice, once per direction.
lines = [
    np.array([start, end])
    for start in verts
    for end in verts
    if np.allclose(np.linalg.norm(start - end), 1)
]
# "Render" each volume as a scatter plot and frame it with a gray box spanning
# -1 to shape + 1 along every axis.
frame_size = np.array(im_orig.shape) + 2
for ax, volume in zip(axs, ims):
    plt_render_volume(volume, ax, cmap, 0, 1, [2, 2, 2], 20)
    for edge in lines:
        corners = edge * frame_size - 1
        ax.plot(
            corners[:, 0],
            corners[:, 1],
            corners[:, 2],
            linewidth=1,
            color='gray',
        )

# Crimson boxes: the whole volume on the global-HE panel (axs[1]) and four
# kernel-sized boxes on the original-AHE panel (axs[2]). ``offset`` is the box
# position measured in units of ``size``.
ns = np.array(im_orig.shape) // kernel_size - 1
kernel_boxes = [
    (1, [0, 0, 0], np.array(im_orig.shape)),
    (2, [ns[0] - 1, ns[1], ns[2] - 1], kernel_size),
    (2, [ns[0], ns[1] - 1, ns[2] - 1], kernel_size),
    (2, [ns[0], ns[1], ns[2] - 1], kernel_size),
    (2, [ns[0], ns[1], ns[2]], kernel_size),
]
for panel, offset, size in kernel_boxes:
    for edge in lines:
        corners = (edge + offset) * size
        axs[panel].plot(
            corners[:, 0],
            corners[:, 1],
            corners[:, 2],
            linewidth=1.2,
            color='crimson',
        )
# Draw the decay profile on the degraded-input panel (axs[3]): one point per
# x index, placed at y = im_orig.shape[1] with height sigmoid * im_orig.shape[2]
# and colored (opaque RGB) like the corresponding intensity.
n_x = len(sigmoid)
profile_rgb = scalars_to_rgba(sigmoid, cmap=cmap, vmin=0, vmax=1, alpha=1.0)[:, :3]
axs[3].scatter(
    xs=np.arange(n_x),
    ys=np.full(n_x, im_orig.shape[1], dtype=float),
    zs=im_orig.shape[2] * sigmoid,
    s=5,
    c=profile_rgb,
)
# Subplot aesthetics: strip 3D decorations (panes, axis lines, ticks, grid)
# and give every panel one shared x/y range (half of it for z).
transparent = (1.0, 1.0, 1.0, 0.0)
for ax in axs:
    for axis in (ax.xaxis, ax.yaxis, ax.zaxis):
        axis.set_pane_color(transparent)
        axis.line.set_color(transparent)
    # Common 3D limits, see https://github.com/
    # matplotlib/matplotlib/issues/17172#issuecomment-617546105
    bounds = np.array([ax.get_xlim3d(), ax.get_ylim3d(), ax.get_zlim3d()])
    common = np.asarray([bounds[:, 0].min(), bounds[:, 1].max()])
    ax.set_xlim3d(common)
    ax.set_ylim3d(common)
    ax.set_zlim3d(common * 0.5)
    try:
        ax.set_aspect('equal')
    except NotImplementedError:
        # Some matplotlib versions reject 'equal' on 3D axes; keep the default.
        pass
    ax.set_xlabel('x', labelpad=-20)
    ax.set_ylabel('y', labelpad=-20)
    # The z label is placed by hand as 2D text in axes coordinates.
    ax.text2D(0.63, 0.2, "z", transform=ax.transAxes)
    for set_ticks in (ax.set_xticks, ax.set_yticks, ax.set_zticks):
        set_ticks([])
    ax.grid(False)
    ax.elev = 30
plt.subplots_adjust(
    left=0.05, bottom=-0.1, right=1.01, top=1.1, wspace=-0.1, hspace=-0.45
)
# Invisible full-figure 2D axes used to place annotations in figure
# coordinates (0 to 1 in both directions).
rect_ax = fig.add_axes([0, 0, 1, 1], facecolor='none')
rect_ax.set_axis_off()
# Dashed frame around the right-hand column to highlight the AHE results.
rect = patches.Rectangle(
    (0.68, 0.01),
    0.315,
    0.98,
    edgecolor='gray',
    facecolor='none',
    linewidth=2,
    linestyle='--',
)
rect_ax.add_patch(rect)

# Degradation formula, colored with the same colormap as the volumes (use
# ``cmap`` rather than a hard-coded name so both stay in sync).
# NOTE(review): the label is schematic -- the decay applied to the data is
# exp(-3 * t) with t = x / (N_x - 1); confirm whether the label should say so.
rect_ax.text(
    0.19,
    0.34,
    '$I_{degr}(x,y,z) = e^{-x}I_{orig}(x,y,z)$',
    fontsize=9,
    rotation=-15,
    color=scalars_to_rgba([0.8], cmap=cmap, alpha=1.0)[0],
)

fc = {'size': 14}
# Row titles, rotated along the left edge (top row, then bottom row).
for y_pos, row_title in (
    (0.58, r'$\it{Original}$' + '\ninput image'),
    (0.16, r'$\it{Degraded}$' + '\ninput image'),
):
    rect_ax.text(
        0.03,
        y_pos,
        row_title,
        rotation=90,
        fontdict=fc,
        horizontalalignment='center',
    )
# Column titles along the top edge.
rect_ax.text(0.13, 0.91, 'Input volume:\n3D cell image', fontdict=fc)
for x_pos, col_title in (
    (0.51, r'$\it{Global}$' + '\nhistogram equalization'),
    (0.84, r'$\it{Adaptive}$' + '\nhistogram equalization (AHE)'),
):
    rect_ax.text(
        x_pos,
        0.91,
        col_title,
        fontdict=fc,
        horizontalalignment='center',
    )
# Captions for the crimson boxes: whole volume (global HE) vs. one kernel (AHE).
rect_ax.text(0.58, 0.82, 'non-local', fontsize=12, color='crimson')
rect_ax.text(0.87, 0.82, 'local kernel', fontsize=12, color='crimson')

# Minimal vertical colorbar: a 256-step gradient image with the highest value
# in the top row, labeled 1.0 (top) and 0.0 (bottom).
cbar_ax = fig.add_axes([0.12, 0.43, 0.008, 0.08])
cbar_ax.imshow(np.arange(256).reshape(256, 1)[::-1], cmap=cmap, aspect="auto")
cbar_ax.set_xticks([])
cbar_ax.set_yticks([0, 255])
cbar_ax.set_xticklabels([])
cbar_ax.set_yticklabels([1.0, 0.0])
plt.show()
Total running time of the script: (0 minutes 13.695 seconds)