# Source code for epdfsuite.utilities

from scipy.ndimage import gaussian_filter1d
from scipy.optimize import curve_fit
from scipy.special import erf
import numpy as np
from matplotlib import pyplot as plt
from .filereader import load_data
from scipy.ndimage import rotate, binary_erosion
from skimage.transform import hough_line, hough_line_peaks
from skimage.feature import canny
import fabio
import os
import sys
import shutil


def draw_mask(dm4_image):
    """
    Launch the pyFAI-drawmask GUI to interactively draw a pixel mask.

    The input DM4 image is temporarily exported as an EDF file, passed to the
    ``pyFAI-drawmask`` tool, then the EDF file is deleted (even if the tool
    fails to start). The mask produced by the GUI is saved alongside the
    image by pyFAI.

    Parameters
    ----------
    dm4_image : str
        Path to the DM4 image file.

    Raises
    ------
    FileNotFoundError
        If the ``pyFAI-drawmask`` executable cannot be found on ``PATH``.
    """
    import subprocess

    # Fail early, before anything is written to disk, if the tool is missing.
    tool_path = shutil.which("pyFAI-drawmask")
    if tool_path is None:
        raise FileNotFoundError(
            "pyFAI-drawmask executable not found on PATH; is pyFAI installed?"
        )

    # Load data and metadata
    detector_info, raw_image = load_data(dm4_image)

    # Derive the EDF name from the real extension only. A plain
    # str.replace('.dm4', '.edf') would leave the path untouched for e.g.
    # '.DM4' inputs, and the input image itself would then be overwritten
    # and deleted below.
    root, _ = os.path.splitext(dm4_image)
    edffile = root + '.edf'
    if os.path.abspath(edffile) == os.path.abspath(dm4_image):
        # Never reuse the input path as the temporary file.
        edffile = root + '_drawmask.edf'

    # Create EDF image and save
    edf_image = fabio.edfimage.EdfImage(data=raw_image, header=detector_info)
    edf_image.write(edffile)

    try:
        # Run the tool with the same python executable as the current
        # environment (important for pyFAI-drawmask to find the right fabio
        # installation). An argument list avoids shell quoting problems with
        # paths that contain spaces or special characters.
        subprocess.run([sys.executable, tool_path, edffile], check=False)
    finally:
        os.remove(edffile)
def detect_edge_angle_hough(edge_data, sigma=1, erosion_px=10, num_peaks=5, plot=False):
    """
    Detect the dominant straight edge in an image using the Hough transform.

    The pipeline is: normalise → erode NaN mask → Canny edge detection →
    standard Hough transform (0.05° angular resolution) → extract dominant
    peak.

    Parameters
    ----------
    edge_data : ndarray
        2D image, possibly with NaN pixels marking invalid regions.
    sigma : float, optional
        Gaussian smoothing sigma passed to the Canny detector. Default is 1.
        Use 1–2 for quasi-binary (beamstop/background) images.
    erosion_px : int, optional
        Number of pixels to erode from the border of the valid mask before
        running Canny, to avoid false edges at mask boundaries. Default is 10.
    num_peaks : int, optional
        Maximum number of peaks to extract from the Hough accumulator. Only
        the strongest peak is used. Default is 5.
    plot : bool, optional
        If ``True``, display diagnostic plots of the masked image, Canny
        edges, and the Hough accumulator. Default is ``False``.

    Returns
    -------
    line_angle_rad : float
        Angle of the detected edge line with respect to the horizontal, in
        radians, wrapped to [-π/2, π/2).
    line_angle_deg : float
        Same angle in degrees.
    edge_point : tuple of float or None
        ``(x, y)`` coordinates of the point on the line at mid-image height.
    edge_line : tuple or None
        ``(theta, rho, line_angle_deg)`` — Hough normal angle (rad), signed
        distance from origin (px), and line angle (deg).

    Notes
    -----
    When no Hough peak is found a warning is printed and
    ``(0.0, 0.0, None, None)`` is returned; callers must check for ``None``.
    """
    arr = edge_data.astype(float)
    valid = ~np.isnan(arr)

    # Normalise to [0, 1]
    vmin, vmax = np.nanmin(arr), np.nanmax(arr)
    arr_norm = (arr - vmin) / (vmax - vmin + 1e-12)

    # Erosion: remove border pixels of the NaN mask
    valid_eroded = binary_erosion(valid, iterations=erosion_px)

    # Apply mask: pixels outside eroded region → 0
    arr_masked = np.where(valid_eroded, arr_norm, 0.0)

    # Canny edge detection (normalised, masked image)
    # low_threshold / high_threshold: adjust to SNR
    edge_map = canny(arr_masked, sigma=sigma,
                     low_threshold=0.1, high_threshold=0.3,
                     mask=valid_eroded)

    # Standard Hough transform
    # tested_angles: angular resolution — 3600 pts = 0.05° precision
    tested_angles = np.linspace(-np.pi / 2, np.pi / 2, 3600, endpoint=False)
    # Keep the accumulator axes under their own names: the scalar peak
    # parameters below must not overwrite them, because the diagnostic plot
    # still needs the full axes for its extent.
    accumulator, theta_axis, rho_axis = hough_line(edge_map, theta=tested_angles)

    # Extract peaks
    _, peak_angles, peak_dists = hough_line_peaks(
        accumulator, theta_axis, rho_axis,
        num_peaks=num_peaks,
        threshold=0.3 * accumulator.max()  # ignore weak peaks
    )

    if len(peak_angles) == 0:
        print("[WARN] No Hough peak detected.")
        return 0.0, 0.0, None, None

    theta = peak_angles[0]  # normal to the line
    rho = peak_dists[0]     # signed distance from origin to line

    # Line angle (convention: angle w.r.t. horizontal)
    line_angle_rad = theta + np.pi / 2
    line_angle_rad = (line_angle_rad + np.pi / 2) % np.pi - np.pi / 2
    line_angle_deg = np.degrees(line_angle_rad)

    # ----------------------------------------------------------------
    # Geometric reconstruction of the line from (theta, rho)
    # Equation: x*cos(theta) + y*sin(theta) = rho
    # ----------------------------------------------------------------
    ny, nx = edge_data.shape

    # Point on the line at y = image centre
    # → solve: x*cos(theta) + y_mid*sin(theta) = rho
    y_mid = ny / 2.0
    if np.abs(np.cos(theta)) > 1e-6:
        x_at_ymid = (rho - y_mid * np.sin(theta)) / np.cos(theta)
    else:
        # Near-horizontal line: it does not cross y_mid in general, so this
        # value is only a placeholder.
        x_at_ymid = rho / (np.cos(theta) + 1e-12)

    # Returned parameters summary
    edge_point = (x_at_ymid, y_mid)           # a point on the line
    edge_line = (theta, rho, line_angle_deg)  # (normal, distance, line angle in °)

    if plot:
        fig, axes = plt.subplots(1, 3, figsize=(15, 5))

        # Masked image
        axes[0].imshow(arr_masked, cmap='gray')
        axes[0].set_title(f'Masked image (erosion {erosion_px}px)')

        # Canny edge map
        axes[1].imshow(edge_map, cmap='gray')
        axes[1].set_title(f'Canny edges ({edge_map.sum()} px)')

        # Hough accumulator
        axes[2].imshow(
            np.log(1 + accumulator),
            extent=[np.degrees(theta_axis[0]), np.degrees(theta_axis[-1]),
                    rho_axis[-1], rho_axis[0]],
            aspect='auto', cmap='hot'
        )
        axes[2].set_xlabel('θ (degrees)')
        axes[2].set_ylabel('ρ (pixels)')
        axes[2].set_title('Hough accumulator (log)')

        # Mark peaks
        for a, dist in zip(peak_angles, peak_dists):
            axes[2].plot(np.degrees(a), dist, 'c+', ms=10, mew=2)

        # Overlay detected line on image
        axes[1].set_title(f'Canny + detected edge ({line_angle_deg:.2f}°)')
        h_img, w_img = edge_map.shape
        if np.abs(np.sin(theta)) > 1e-6:
            x_vals = np.array([0, w_img])
            y_vals = (rho - x_vals * np.cos(theta)) / np.sin(theta)
        else:
            # Vertical line: x is constant
            x_vals = np.array([rho, rho])
            y_vals = np.array([0, h_img])
        axes[1].plot(x_vals, y_vals, 'r-', lw=2, label=f'{line_angle_deg:.2f}°')
        axes[1].legend()

        plt.tight_layout()
        plt.show()

    return line_angle_rad, line_angle_deg, edge_point, edge_line
def compute_mtf_slanted_edge(image_path, mask=None, pixel_size=None, binning_factor=1,
                             roi_half_width=15, nbins=500, smooth_sigma=0.5,
                             use_erf_fit=True, plot=True, outputfile=None):
    """
    Compute the MTF using the slanted-edge method, with automatic edge angle
    and position detection via Hough transform.

    Parameters
    ----------
    image_path : str
        Path to the image file.
    mask : str, optional
        Path to a fabio mask file (0=valid, 1=masked).
    pixel_size : float, optional
        Pixel size in µm. If ``None``, read from the ``'pixel_size'`` entry
        of the image metadata.
    binning_factor : int, optional
        Binning factor applied to the detector (default 1). The pixel size is
        multiplied by this factor.
    roi_half_width : int, optional
        Half-width of the band around the edge (pixels).
    nbins : int, optional
        Number of sub-pixel bins for the ESF.
    smooth_sigma : float, optional
        Sigma of the Gaussian smoothing applied to the ESF (0 disables it).
    use_erf_fit : bool, optional
        Fit the ESF with an error function before differentiation.
    plot : bool, optional
        Display diagnostic plots (also saved to ``debug_mtf_slanted.png``).
    outputfile : str, optional
        If provided, save the MTF to this text file (columns: frequency in
        cycles/pixel, MTF, Wiener epsilon).

    Returns
    -------
    freq_pixel : 1D array
        Spatial frequencies (cycles/pixel).
    mtf : 1D array
        Corresponding MTF values.

    Raises
    ------
    ValueError
        If the pixel size is unavailable, no edge is detected, the ROI holds
        fewer than 100 valid pixels, or the binned ESF has fewer than 10
        points.
    """
    # ------------------------------------------------------------------
    # 1. Load image and mask
    # ------------------------------------------------------------------
    detector_info, image = load_data(image_path, normalize=False, verbose=False)
    if pixel_size is None:
        pixel_size = detector_info.get('pixel_size', None)
        if pixel_size is None:
            raise ValueError("Pixel size not found in metadata.")
    # NOTE(review): the binning factor is applied to both user-supplied and
    # metadata pixel sizes — confirm this matches the intended convention.
    pixel_size = pixel_size * binning_factor

    if mask is not None:
        maskdata = fabio.open(mask).data
        image = image.astype(float)
        image[maskdata != 0] = np.nan

    # ------------------------------------------------------------------
    # 2. Detect edge angle and position (single Hough call)
    # ------------------------------------------------------------------
    edge_angle_rad, edge_angle_deg, edge_point, edge_line = detect_edge_angle_hough(
        image, plot=False
    )
    # The detector signals "no peak found" with None; fail with a clear
    # message instead of a TypeError on unpacking.
    if edge_line is None or edge_point is None:
        raise ValueError("No edge detected in the image. "
                         "Check the image and the mask.")
    theta_hough, rho_hough, _ = edge_line
    x_edge_at_ymid = edge_point[0]  # x position of the edge at mid-height (info/debug)

    print(f"[INFO] Edge detected: angle={edge_angle_deg:.2f}°, "
          f"rho={rho_hough:.1f} px, x_edge≈{x_edge_at_ymid:.1f} px")

    # ------------------------------------------------------------------
    # 3. Signed distance of each pixel to the Hough line
    #    Line equation: x·cos(θ) + y·sin(θ) = ρ
    #    → signed distance: d(x,y) = x·cos(θ) + y·sin(θ) − ρ
    #    (sign encodes which side of the edge the pixel lies on)
    # ------------------------------------------------------------------
    ny, nx = image.shape
    y_idx, x_idx = np.indices((ny, nx))
    cos_t, sin_t = np.cos(theta_hough), np.sin(theta_hough)
    d_raw = x_idx * cos_t + y_idx * sin_t - rho_hough
    d_offset = x_edge_at_ymid * cos_t + (ny / 2.0) * sin_t - rho_hough
    d = d_raw - d_offset

    # ------------------------------------------------------------------
    # 3b. Adapt ROI bounds independently on each side of the edge.
    #     No symmetry is required: the ESF is normalised to [0,1] so
    #     each side only needs enough pixels to establish its plateau.
    #     The beamstop side can be much narrower than the bright side.
    # ------------------------------------------------------------------
    valid = ~np.isnan(image)
    on_pos_side = valid & (d > 0)
    on_neg_side = valid & (d < 0)

    # Maximum available distance on each side within the valid mask
    d_pos_max = d[on_pos_side].max() if on_pos_side.any() else roi_half_width
    d_neg_max = np.abs(d[on_neg_side].min()) if on_neg_side.any() else roi_half_width

    # Independent limits: use as much as available up to roi_half_width
    d_pos_lim = min(roi_half_width, d_pos_max)  # bright side
    d_neg_lim = min(roi_half_width, d_neg_max)  # dark (beamstop) side

    print(f"[INFO] Asymmetric ROI: dark side={d_neg_lim:.1f} px, "
          f"bright side={d_pos_lim:.1f} px "
          f"(available: left={d_neg_max:.1f}, right={d_pos_max:.1f})")

    # ------------------------------------------------------------------
    # 4. Select pixels inside the asymmetric ROI band around the edge
    # ------------------------------------------------------------------
    roi = (d > -d_neg_lim) & (d < d_pos_lim)
    valid_roi = valid & roi

    d_vals = d[valid_roi]
    i_vals = image[valid_roi].astype(float)

    if len(d_vals) < 100:
        raise ValueError("Too few valid pixels in ROI. "
                         "Check the mask or increase roi_half_width.")

    # ------------------------------------------------------------------
    # 5. Sub-pixel binning → ESF
    # ------------------------------------------------------------------
    d_min, d_max = d_vals.min(), d_vals.max()
    bins = np.linspace(d_min, d_max, nbins + 1)
    bin_centers = 0.5 * (bins[:-1] + bins[1:])

    esf_sum = np.zeros(nbins)
    esf_counts = np.zeros(nbins)
    bin_idx = np.clip(np.digitize(d_vals, bins) - 1, 0, nbins - 1)
    np.add.at(esf_sum, bin_idx, i_vals)
    np.add.at(esf_counts, bin_idx, 1)

    valid_bins = esf_counts > 0
    x_esf = bin_centers[valid_bins]
    esf = esf_sum[valid_bins] / esf_counts[valid_bins]

    if len(esf) < 10:
        raise ValueError("ESF too short after binning. "
                         "Increase nbins or roi_half_width.")

    # ------------------------------------------------------------------
    # 6. Normalise ESF to [0, 1] and enforce rising orientation
    # ------------------------------------------------------------------
    esf_min, esf_max = np.nanmin(esf), np.nanmax(esf)
    esf_norm = (esf - esf_min) / (esf_max - esf_min + 1e-12)

    if esf_norm[0] > esf_norm[-1]:
        # Mirror the distance axis (x → −x) rather than merely reversing the
        # array order. The ESF is then rising as a function of an ascending
        # coordinate, so the erf fit below (evaluated on an ascending grid)
        # keeps the rising orientation and the LSF stays positive.
        esf_norm = esf_norm[::-1]
        x_esf = -x_esf[::-1]

    # ------------------------------------------------------------------
    # 7. Optional erf fit → regularised ESF on a uniform grid
    # ------------------------------------------------------------------
    if use_erf_fit:
        def erf_model(x, x0, sigma, a, b):
            """Generalised error function with free amplitude and offset."""
            return a * 0.5 * (1 + erf((x - x0) / (np.sqrt(2) * sigma))) + b

        try:
            p0 = [np.median(x_esf), 1.0, 1.0, 0.0]
            popt, _ = curve_fit(erf_model, x_esf, esf_norm, p0=p0, maxfev=5000)
            x_fit = np.linspace(x_esf.min(), x_esf.max(), nbins)
            esf_fit = erf_model(x_fit, *popt)
            esf_fit = (esf_fit - esf_fit.min()) / (esf_fit.max() - esf_fit.min() + 1e-12)
            x_esf = x_fit
            esf_norm = esf_fit
            print(f"[INFO] erf fit: x0={popt[0]:.2f} px, sigma={popt[1]:.3f} px")
        except Exception as e:
            # Best effort: fall back to the binned ESF if the fit fails.
            print(f"[WARN] erf fit failed ({e}), continuing without fit.")

    # Light Gaussian smoothing (skip if sigma == 0)
    if smooth_sigma > 0:
        esf_smooth = gaussian_filter1d(esf_norm, sigma=smooth_sigma)
    else:
        esf_smooth = esf_norm.copy()

    # ------------------------------------------------------------------
    # 8. LSF = derivative of the ESF
    #    dx is the sub-pixel geometric step (used for np.gradient only)
    # ------------------------------------------------------------------
    dx = np.abs(np.mean(np.diff(x_esf)))  # sub-pixel step in pixels, always > 0
    lsf = np.gradient(esf_smooth, dx)

    # Hanning window to suppress spectral leakage
    window = np.hanning(len(lsf))
    lsf *= window

    # Normalise so that the area under the LSF equals 1
    lsf_sum = np.sum(np.abs(lsf))
    if lsf_sum > 0:
        lsf /= lsf_sum

    # Centre the LSF peak to avoid FFT phase artefacts
    peak_idx = np.argmax(lsf)
    shift = len(lsf) // 2 - peak_idx
    lsf_centered = np.roll(lsf, shift)

    # ------------------------------------------------------------------
    # 8b. Resample LSF onto a 1-pixel grid before FFT
    #     dx < 1 px (sub-pixel binning) would push the Nyquist frequency
    #     above 0.5 cyc/px, which is unphysical.
    #     We interpolate the LSF onto a regular 1-pixel grid so that
    #     freq_pixel is correctly bounded to [0, 0.5] cycles/pixel.
    # ------------------------------------------------------------------
    x_lsf_subpix = np.arange(len(lsf_centered)) * dx                # sub-pixel axis (pixels)
    x_lsf_1px = np.arange(x_lsf_subpix[0], x_lsf_subpix[-1], 1.0)   # 1-pixel-step grid
    lsf_1px = np.interp(x_lsf_1px, x_lsf_subpix, lsf_centered)

    # Re-normalise after resampling
    lsf_sum = np.sum(np.abs(lsf_1px))
    if lsf_sum > 0:
        lsf_1px /= lsf_sum

    # ------------------------------------------------------------------
    # 9. FFT → MTF (on the 1-pixel-grid LSF)
    # ------------------------------------------------------------------
    mtf_complex = np.fft.fft(lsf_1px)
    mtf = np.abs(mtf_complex)
    mtf /= mtf[0]  # normalise to 1 at f = 0

    n_half = len(mtf) // 2
    mtf = mtf[:n_half]
    freq_pixel = np.fft.fftfreq(len(lsf_1px), d=1.0)[:n_half]  # 0 → 0.5 cyc/px

    # Physical and normalised frequencies
    freq_phys = freq_pixel / pixel_size      # cycles/µm
    fnyq_phys = 1.0 / (2.0 * pixel_size)     # Nyquist frequency in cycles/µm
    freq_norm = freq_phys / fnyq_phys        # normalised to Nyquist

    # MTF50 and MTF20
    mtf50_idx = np.argmin(np.abs(mtf - 0.5))
    mtf20_idx = np.argmin(np.abs(mtf - 0.2))
    print(f"MTF50: {freq_norm[mtf50_idx]:.3f} f_Nyq "
          f"({freq_phys[mtf50_idx]:.3f} µm⁻¹)")
    print(f"MTF20: {freq_norm[mtf20_idx]:.3f} f_Nyq "
          f"({freq_phys[mtf20_idx]:.3f} µm⁻¹)")

    # Estimate the Wiener epsilon from pixel patches taken on each side of
    # the detected edge.
    signal_patch, noise_patch = extract_noise_and_signal_patches(image, edge_line)
    wiener_epsilon = estimate_wiener_epsilon_spectral(noise_patch, signal_patch)
    print(f"Estimated Wiener epsilon (noise/signal ratio): {wiener_epsilon:.4f}")

    # ------------------------------------------------------------------
    # 10. Diagnostic plots
    # ------------------------------------------------------------------
    if plot:
        fig, axes = plt.subplots(2, 2, figsize=(14, 10))

        # --- Image with detected edge line and ROI band ---
        axes[0, 0].imshow(image, cmap='gray', origin='upper')
        if np.abs(sin_t) > 1e-6:
            x_line = np.array([0, nx - 1])
            y_line = (rho_hough - x_line * cos_t) / sin_t
        else:
            x_line = np.array([rho_hough, rho_hough])
            y_line = np.array([0, ny - 1])
        axes[0, 0].plot(x_line, y_line, 'r-', lw=2,
                        label=f'Edge {edge_angle_deg:.2f}°')
        axes[0, 0].contour(roi, levels=[0.5], colors='cyan',
                           linewidths=1, linestyles='--')
        axes[0, 0].set_title('Image + detected edge (red) + ROI (cyan)')
        axes[0, 0].legend(fontsize=8)

        # --- ESF ---
        axes[0, 1].plot(x_esf, esf_norm, 'b-', linewidth=2)
        axes[0, 1].set_xlabel('Distance to edge (pixels)')
        axes[0, 1].set_ylabel('Normalised intensity')
        axes[0, 1].set_title('Edge Spread Function (ESF)')
        axes[0, 1].grid(True, alpha=0.3)

        # --- LSF (resampled at 1 px for consistency with MTF) ---
        axes[1, 0].plot(x_lsf_1px, lsf_1px, 'g-', linewidth=2)
        axes[1, 0].set_title('Line Spread Function (LSF, 1-px grid)')
        axes[1, 0].set_xlabel('Position (pixels)')
        axes[1, 0].grid(True, alpha=0.3)

        # --- MTF (frequency axis in cycles/pixel) ---
        axes[1, 1].plot(freq_pixel, mtf, 'r-', linewidth=2, label='Measured MTF')
        axes[1, 1].set_xlabel('Spatial frequency (cycles/pixel)')
        axes[1, 1].set_ylabel('MTF')
        axes[1, 1].set_xlim(0, 0.5)
        axes[1, 1].set_ylim(0, 1.05)
        axes[1, 1].legend()
        axes[1, 1].grid(True, alpha=0.3)
        axes[1, 1].set_title('Modulation Transfer Function (MTF)')

        # The primary axis is in cycles/pixel, so the physical frequency is
        # f / pixel_size (0.5 cyc/px ↔ Nyquist = 1 / (2·pixel_size)).
        secax = axes[1, 1].secondary_xaxis(
            'top',
            functions=(lambda f: f / pixel_size,
                       lambda f: f * pixel_size)
        )
        secax.set_xlabel('Spatial frequency (µm⁻¹)')

        plt.tight_layout()
        plt.savefig('debug_mtf_slanted.png', dpi=100)
        plt.show()

    # ------------------------------------------------------------------
    # 11. Optional output file
    # ------------------------------------------------------------------
    if outputfile is not None:
        header = ("# MTF computed from slanted-edge image\n"
                  "# Col 1: spatial frequency (cycles/pixel)\n"
                  "# Col 2: MTF\n"
                  "# Col 3: Wiener epsilon (noise/signal ratio) used for deconvolution\n")
        np.savetxt(outputfile,
                   np.column_stack((freq_pixel, mtf,
                                    np.full_like(freq_pixel, wiener_epsilon))),
                   header=header, comments='')
        print(f"MTF saved: {outputfile}")

    return freq_pixel, mtf
def _mean_psd_per_sample(values):
    """Return the mean of ``|FFT(values)|²`` divided by the number of samples."""
    spectrum = np.fft.fftn(values)
    # Without the division by the sample count, the mean of |FFT|² grows
    # linearly with the patch size (Parseval), which would bias the ratio of
    # two patches of different sizes.
    return np.mean(np.abs(spectrum) ** 2) / values.size


def estimate_wiener_epsilon_spectral(noise_patch, signal_patch, subtract_noise=True):
    """
    Estimate the Wiener regularisation parameter epsilon from image data.

    Computes epsilon as the square root of the ratio of the mean power
    spectral densities (PSD) of the noise and signal patches:

        ``epsilon = sqrt( <|N(f)|²> / <|S(f)|²> )``

    Each PSD is normalised by the number of samples of its patch, so the
    estimate does not depend on the relative sizes of the two patches.

    This estimate is used to set the noise-to-signal power ratio in the
    Wiener filter: ``W(f) = MTF / (MTF² + epsilon²)``.

    Parameters
    ----------
    noise_patch : ndarray
        2D (or 1D) sub-image extracted from the beamstop region (dark, noisy
        side). NaN values are treated as 0 after centring.
    signal_patch : ndarray
        2D (or 1D) sub-image extracted from the bright background region.
    subtract_noise : bool, optional
        If ``True`` (default), subtract the mean of ``noise_patch`` from
        ``signal_patch`` before computing the signal PSD, to account for any
        DC offset in the background. If ``False``, ``signal_patch`` is
        centred on its own mean instead.

    Returns
    -------
    epsilon : float
        Estimated noise-to-signal PSD ratio, suitable for use as
        ``wiener_epsilon`` in :func:`deconvolve_mtf_2d`.

    Raises
    ------
    ValueError
        If either patch is empty.
    """
    noise_arr = np.asarray(noise_patch, dtype=float)
    signal_arr = np.asarray(signal_patch, dtype=float)
    if noise_arr.size == 0 or signal_arr.size == 0:
        raise ValueError("noise_patch and signal_patch must both be non-empty.")

    # Centre signals
    noise_level = np.nanmean(noise_arr)
    noise = noise_arr - noise_level
    if subtract_noise:
        signal = signal_arr - noise_level
    else:
        signal = signal_arr - np.nanmean(signal_arr)

    # Prepare for FFT (fill NaN with 0)
    noise_f = np.nan_to_num(noise, nan=0.0)
    signal_f = np.nan_to_num(signal, nan=0.0)

    # Average PSD per sample (1D or 2D depending on shape)
    mean_noise_psd = _mean_psd_per_sample(noise_f)
    mean_signal_psd = _mean_psd_per_sample(signal_f)

    # Noise/signal PSD ratio (small constant guards against a zero signal)
    epsilon = np.sqrt(mean_noise_psd / (mean_signal_psd + 1e-12))
    return epsilon
def extract_noise_and_signal_patches(image, edge_line, band_width=500, noise_box=None,
                                     erosion_px=5, plot=True):
    """
    Extract noise and signal pixel patches on each side of the detected edge.

    The image is split along the Hough line into two regions:

    - **signal patch** (bright side, ``d > +erosion_px``): background pixels.
    - **noise patch** (dark side, ``d < -erosion_px``): beamstop pixels.

    An erosion band of ``erosion_px`` pixels around the edge is excluded from
    both patches to avoid contamination by the edge transition itself.

    Parameters
    ----------
    image : ndarray
        2D image, with NaN for masked/invalid pixels.
    edge_line : tuple
        ``(theta, rho, angle_deg)`` as returned by
        :func:`detect_edge_angle_hough`.
    band_width : float, optional
        Total width of the extraction band centred on the edge (pixels).
        Default is 500.
    noise_box : ignored
        Reserved for future use.
    erosion_px : int, optional
        Width of the exclusion zone on each side of the edge (pixels).
        Default is 5.
    plot : bool, optional
        If ``True`` (default), display diagnostic figures showing the two
        zones and the extracted values. Set to ``False`` for non-interactive
        use.

    Returns
    -------
    signal_patch : ndarray
        1D array of pixel values from the bright (background) side.
    noise_patch : ndarray
        1D array of pixel values from the dark (beamstop) side.
    """
    theta, rho, _ = edge_line
    ny, nx = image.shape
    y_idx, x_idx = np.indices((ny, nx))

    # Signed distance to the Hough line (same convention as in the plot)
    d = x_idx * np.cos(theta) + y_idx * np.sin(theta) - rho

    # Build masks for each side of the edge, with erosion
    # Convention: d < 0 = beamstop side (shadow, noise), d > 0 = background side (signal)
    mask_band = np.abs(d) < (band_width / 2)
    valid = ~np.isnan(image)

    # Erosion: exclude a band of +/- erosion_px around the edge
    mask_signal = (d > +erosion_px) & mask_band & valid   # background only
    mask_noise = (d < -erosion_px) & mask_band & valid    # beamstop only

    signal_patch = image[mask_signal]
    noise_patch = image[mask_noise]

    n_signal = signal_patch.size
    n_noise = noise_patch.size
    print(f"Signal patch (background, eroded {erosion_px}px): {n_signal} pixels, "
          f"Noise patch (beamstop, eroded {erosion_px}px): {n_noise} pixels")

    if not plot:
        return signal_patch, noise_patch

    from matplotlib.lines import Line2D

    # Overlay on original image (NaN pixels shown at the median level)
    plt.figure(figsize=(7, 7))
    img_disp = np.copy(image)
    img_disp = np.where(np.isnan(img_disp), np.nanmedian(img_disp), img_disp)
    plt.imshow(img_disp, cmap='gray', origin='upper')

    # Contours do not accept a ``label`` keyword; the legend is built from
    # proxy lines below. Empty masks are skipped because they have no
    # contour to draw.
    # Overlay noise band (beamstop) in red
    if mask_noise.any():
        plt.contour(mask_noise.astype(float), levels=[0.5], colors='red',
                    linewidths=2, linestyles='-')
    # Overlay signal band (background) in blue
    if mask_signal.any():
        plt.contour(mask_signal.astype(float), levels=[0.5], colors='blue',
                    linewidths=1, linestyles='--')

    legend_elements = [
        Line2D([0], [0], color='red', lw=2, label='Noise (beamstop)'),
        Line2D([0], [0], color='blue', lw=2, linestyle='--', label='Signal (background)')
    ]
    plt.title(f'Extracted zones\nNoise (beamstop, red), Signal (background, blue)\n'
              f'Signal: {n_signal} px, Noise: {n_noise} px\nErosion: {erosion_px} px')
    plt.axis('off')
    plt.legend(handles=legend_elements, loc='lower right')
    plt.tight_layout()
    plt.show()

    # 1D patch values (sanity check)
    plt.figure(figsize=(8, 3))
    plt.plot(noise_patch, '.', color='red', alpha=0.7, label='Noise (beamstop)')
    plt.plot(signal_patch, '.', color='blue', alpha=0.5, label='Signal')
    plt.title('Extracted values (1D)')
    plt.xlabel('Index')
    plt.ylabel('Intensity')
    plt.legend()
    plt.tight_layout()
    plt.show()

    return signal_patch, noise_patch
# -----------------------------------------------------------------------
# Wiener 2D deconvolution with radial MTF
# -----------------------------------------------------------------------
def _rolloff_profile(u, window, alpha, order):
    """Evaluate the roll-off window at normalised frequencies ``u = f / u_cutoff``."""
    if window == 'hann':
        return np.where(u <= 1.0, 0.5 * (1.0 + np.cos(np.pi * u)), 0.0)
    if window == 'butterworth':
        return 1.0 / (1.0 + u ** (2 * order))
    if window == 'tukey':
        profile = np.ones_like(u, dtype=float)
        # alpha >= 1 means a fully flat passband (rectangular window): there
        # is no cosine taper, and evaluating it would divide by zero.
        if alpha < 1.0:
            taper = (u >= alpha) & (u <= 1.0)
            profile[taper] = 0.5 * (1.0 + np.cos(
                np.pi * (u[taper] - alpha) / (1.0 - alpha)
            ))
        profile[u > 1.0] = 0.0
        return profile
    raise ValueError(f"Unknown rolloff_window: '{window}'. "
                     f"Choose 'hann', 'butterworth', or 'tukey'.")


def deconvolve_mtf_2d(image, mtf_file, clip=True, wiener_epsilon=None,
                      min_epsilon=0.005, pre_smooth_sigma=0.5,
                      use_rolloff=True, u_cutoff=0.4,
                      rolloff_window='tukey', rolloff_alpha=0.5,
                      rolloff_order=4, plot=False):
    """
    Wiener 2D MTF deconvolution with optional high-frequency roll-off.

    Applies a Wiener filter built from a radially symmetric MTF to restore
    spatial frequencies attenuated by the detector. An optional roll-off
    window suppresses noise amplification at high frequencies.

    Parameters
    ----------
    image : ndarray
        2D image to deconvolve.
    mtf_file : str
        Path to the MTF file (text columns: freq (cyc/px), MTF, and
        optionally epsilon).
    clip : bool, optional
        If ``True`` (default), clip negative values in the output to zero.
    wiener_epsilon : float or None, optional
        Regularisation parameter. If ``None``, read from column 3 of
        ``mtf_file`` (floored at ``min_epsilon``).
    min_epsilon : float, optional
        Minimum allowed epsilon to prevent filter instability. Default is
        0.005.
    pre_smooth_sigma : float, optional
        Sigma (pixels) of Gaussian pre-smoothing applied before
        deconvolution to reduce noise amplification. Default is 0.5. Set to
        0 to disable.
    use_rolloff : bool, optional
        If ``True`` (default), multiply the Wiener filter by a roll-off
        window to suppress noise at frequencies above ``u_cutoff``.
    u_cutoff : float or None, optional
        Roll-off cutoff frequency in cycles/pixel (max 0.5 = Nyquist). If
        ``None``, automatically set to the frequency where ``MTF = epsilon``
        (clipped to [0.1, 0.5]). Default is 0.4.
    rolloff_window : {'tukey', 'hann', 'butterworth'}, optional
        Shape of the roll-off window. Default is ``'tukey'``.
    rolloff_alpha : float, optional
        For the Tukey window: fraction of the passband that is flat
        (0 = Hann, 1 = rectangular). Default is 0.5.
    rolloff_order : int, optional
        For the Butterworth window: filter order (higher = steeper).
        Default is 4.
    plot : bool, optional
        If ``True``, display the Wiener filter profile. Default is ``False``.

    Returns
    -------
    image_deconv : ndarray
        Deconvolved image, same shape as ``image``. NaN pixels are preserved.

    Raises
    ------
    ValueError
        If the MTF file has fewer than two columns, if ``wiener_epsilon`` is
        ``None`` and the file has no epsilon column, or if
        ``rolloff_window`` is unknown.
    """
    # ------------------------------------------------------------------
    # 1. Load MTF
    # ------------------------------------------------------------------
    mtf_data = np.loadtxt(mtf_file, comments='#')
    if mtf_data.ndim != 2 or mtf_data.shape[1] < 2:
        raise ValueError("MTF file must have 2 columns: freq (cyc/px) and MTF.")
    freq_1d = mtf_data[:, 0]
    mtf_1d = mtf_data[:, 1]

    if wiener_epsilon is None:
        # Two-column MTF files are valid input, but then epsilon cannot be
        # read from the file and must be supplied by the caller.
        if mtf_data.shape[1] < 3:
            raise ValueError("MTF file has no epsilon column (column 3); "
                             "pass wiener_epsilon explicitly.")
        wiener_epsilon = max(mtf_data[0, 2], min_epsilon)

    # Make sure the tabulated MTF starts at f = 0 with MTF = 1
    if freq_1d[0] > 0:
        freq_1d = np.concatenate([[0.0], freq_1d])
        mtf_1d = np.concatenate([[1.0], mtf_1d])

    # ------------------------------------------------------------------
    # 2. Handle NaNs (filled with the mean so the FFT stays finite)
    # ------------------------------------------------------------------
    nan_mask = np.isnan(image)
    image_filled = image.copy().astype(float)
    if nan_mask.any():
        image_filled[nan_mask] = np.nanmean(image)

    # ------------------------------------------------------------------
    # 3. Optional pre-smoothing to reduce noise before deconvolution
    # ------------------------------------------------------------------
    if pre_smooth_sigma > 0:
        from scipy.ndimage import gaussian_filter
        image_filled = gaussian_filter(image_filled, sigma=pre_smooth_sigma)
        print(f"[INFO] Pre-smoothing applied: sigma={pre_smooth_sigma} px")

    # ------------------------------------------------------------------
    # 4. Build 2D radial frequency grid
    # ------------------------------------------------------------------
    ny, nx = image_filled.shape
    fy = np.fft.fftfreq(ny)
    fx = np.fft.fftfreq(nx)
    FX, FY = np.meshgrid(fx, fy)
    freq_radial = np.sqrt(FX**2 + FY**2)

    # ------------------------------------------------------------------
    # 5. Interpolate MTF onto 2D grid (0 beyond the tabulated range)
    # ------------------------------------------------------------------
    mtf_2d = np.interp(freq_radial, freq_1d, mtf_1d, left=1.0, right=0.0)

    # ------------------------------------------------------------------
    # 6. Wiener filter normalised to W(0)=1
    # ------------------------------------------------------------------
    wiener_filter = mtf_2d / (mtf_2d**2 + wiener_epsilon**2)
    mtf_at_zero = np.interp(0.0, freq_1d, mtf_1d)
    w_at_zero = mtf_at_zero / (mtf_at_zero**2 + wiener_epsilon**2)
    wiener_filter /= w_at_zero

    # ------------------------------------------------------------------
    # 7. Optional roll-off window
    # ------------------------------------------------------------------
    if use_rolloff:
        # Auto u_cutoff: frequency where MTF(u) = epsilon → amplification ~0.5/epsilon
        # beyond this point, the filter significantly amplifies noise
        if u_cutoff is None:
            mtf_epsilon_idx = np.argmin(np.abs(mtf_1d - wiener_epsilon))
            u_cutoff = freq_1d[mtf_epsilon_idx]
            u_cutoff = np.clip(u_cutoff, 0.1, 0.5)
            print(f"[INFO] Auto u_cutoff = {u_cutoff:.3f} cyc/px "
                  f"(MTF = epsilon = {wiener_epsilon:.4f})")

        wiener_filter *= _rolloff_profile(freq_radial / u_cutoff, rolloff_window,
                                          rolloff_alpha, rolloff_order)

    # ------------------------------------------------------------------
    # 8. Diagnostic plot
    # ------------------------------------------------------------------
    if plot:
        f_plot = np.linspace(0, 0.5, 300)
        mtf_plot = np.interp(f_plot, freq_1d, mtf_1d)
        W_plot = mtf_plot / (mtf_plot**2 + wiener_epsilon**2)
        W_plot /= W_plot[0]

        plt.figure(figsize=(8, 4))
        plt.plot(f_plot, mtf_plot, 'b-', lw=1.5, label='MTF')
        plt.plot(f_plot, W_plot, 'r--', lw=1.5, label=f'Wiener only (ε={wiener_epsilon})')

        if use_rolloff:
            R_plot = _rolloff_profile(f_plot / u_cutoff, rolloff_window,
                                      rolloff_alpha, rolloff_order)
            plt.plot(f_plot, R_plot, 'g--', lw=1.5, label=f'Roll-off ({rolloff_window})')
            plt.plot(f_plot, W_plot * R_plot, 'k-', lw=2.0, label='Wiener × roll-off')
            plt.axvline(u_cutoff, color='gray', linestyle=':', alpha=0.6,
                        label=f'u_cutoff={u_cutoff:.3f}')

        plt.axhline(1.0, color='gray', linestyle=':', alpha=0.4)
        plt.axvline(0.5, color='gray', linestyle=':', alpha=0.4, label='Nyquist')
        plt.xlabel('Spatial frequency (cycles/pixel)')
        plt.ylabel('Amplitude')
        plt.legend(fontsize=8)
        plt.grid(True, alpha=0.3)
        plt.title('Wiener deconvolution filter')
        plt.tight_layout()
        plt.show()

    # ------------------------------------------------------------------
    # 9. FFT → deconvolution → IFFT
    # ------------------------------------------------------------------
    image_fft = np.fft.fft2(image_filled)
    image_fft_deconv = image_fft * wiener_filter
    image_deconv = np.real(np.fft.ifft2(image_fft_deconv))

    # ------------------------------------------------------------------
    # 10. Restore NaNs and clip
    # ------------------------------------------------------------------
    if nan_mask.any():
        image_deconv[nan_mask] = np.nan

    if clip:
        image_deconv = np.clip(image_deconv, 0, None)

    return image_deconv
## Function to compute DQE from flat and dark images, and MTF file
def _radial_bin_mean(values, bin_index, n_bins):
    """Average ``values`` per radial frequency bin (empty bins stay at 0).

    ``bin_index`` gives, for every flattened element of ``values``, the index
    of the bin it falls in. Indices outside ``[0, n_bins)`` (frequencies beyond
    the last bin edge) are ignored.
    """
    flat_values = np.ravel(values)
    keep = (bin_index >= 0) & (bin_index < n_bins)
    idx = bin_index[keep]
    counts = np.bincount(idx, minlength=n_bins)
    sums = np.bincount(idx, weights=flat_values[keep], minlength=n_bins)
    means = np.zeros(n_bins)
    filled = counts > 0
    means[filled] = sums[filled] / counts[filled]
    return means


def compute_dqe(flat_paths, dark_paths, mtf_file, gain_reference=None,
                n_freq=128, plot=False, save=None):
    """
    Compute the radially-averaged DQE from flat-field and dark-field images.

    The DQE is defined as:

    .. math::

        \\mathrm{DQE}(f) = \\frac{\\mathrm{MTF}^2(f)}{\\bar{n} \\cdot \\mathrm{NPS}(f)}

    where :math:`\\bar{n}` is the mean signal per pixel of the dark-subtracted
    flat images and :math:`\\mathrm{NPS}(f)` is the normalised noise power
    spectrum:

    .. math::

        \\mathrm{NPS}(f) = \\frac{1}{N_{\\mathrm{img}}\\, N_x N_y\\, \\bar{n}^2}
        \\sum_k \\left| \\mathcal{F}\\!\\left[ I_k - \\bar{n} \\right](f) \\right|^2

    Note that the quantity subtracted from each frame before the FFT is the
    scalar global mean :math:`\\bar{n}` (computed over all frames and pixels),
    not a per-pixel mean image.

    The mean dark frame is subtracted from each flat-field image before
    :math:`\\bar{n}` and the NPS are computed.

    Parameters
    ----------
    flat_paths : list of str
        Paths to the flat-field (uniform illumination) images. At least 5
        images are recommended for a stable NPS estimate; 20–50 are ideal.
    dark_paths : list of str
        Paths to the dark-field (shutter closed) images. Used to estimate
        and subtract the detector dark offset.
    mtf_file : str
        Path to the MTF file (columns: frequency in cyc/px, MTF value), as
        produced by :func:`compute_mtf_slanted_edge`. If the table does not
        start at frequency 0, a point ``(0, 1)`` is prepended.
    gain_reference : ndarray or None, optional
        2D gain reference map (same shape as the images). When provided, each
        dark-subtracted flat image is divided by ``gain_reference``; pixels
        where the reference is not strictly positive are left undivided.
        ``None`` skips gain correction. Default is ``None``.
    n_freq : int, optional
        Number of radial frequency bins between 0 and 0.5 cyc/px.
        Default is 128.
    plot : bool, optional
        If ``True``, display MTF², NPS, and DQE curves. Default is ``False``.
    save : str or None, optional
        If a file path is given, save the result as a two-column text file
        (bin-centre frequency in cyc/px, DQE value) readable by
        :func:`deconvolve_mtf_2d_rl`. Default is ``None``.

    Returns
    -------
    freq_centers : ndarray, shape (n_freq,)
        Centres of the radial frequency bins, in cycles/pixel.
    dqe : ndarray, shape (n_freq,)
        Radially-averaged DQE. The 2D DQE is clipped to [0, 1] before
        averaging; bins containing no frequency sample are 0.

    Raises
    ------
    ValueError
        If ``flat_paths`` or ``dark_paths`` is empty, or if the mean flat
        signal is non-positive after dark subtraction.

    Notes
    -----
    * All images must have the same shape.
    * Images are expected to be in raw detector counts (electrons or ADU).
    * At very low dose the DQE drops because read noise dominates; at very
      high dose it drops due to detector non-linearity. Run this function at
      several dose levels to characterise the dose dependence.
    """
    flat_paths = list(flat_paths)
    dark_paths = list(dark_paths)
    # An empty dark list would otherwise yield a NaN dark frame and a silently
    # all-NaN result, so reject empty inputs up front.
    if not dark_paths:
        raise ValueError("dark_paths is empty: at least one dark image is required.")
    if not flat_paths:
        raise ValueError("flat_paths is empty: at least one flat image is required.")

    # ------------------------------------------------------------------
    # 1. Load dark images → mean dark frame
    # ------------------------------------------------------------------
    dark_stack = np.array([load_data(p)[1].astype(float) for p in dark_paths])
    dark_mean = dark_stack.mean(axis=0)

    # ------------------------------------------------------------------
    # 2. Load flat images, subtract dark, optional gain correction
    # ------------------------------------------------------------------
    gain_divisor = None
    if gain_reference is not None:
        # Non-positive gain entries are replaced by 1 to avoid division by
        # zero; computed once because it does not depend on the image.
        gain_divisor = np.where(gain_reference > 0, gain_reference, 1.0)

    flat_list = []
    for p in flat_paths:
        img = load_data(p)[1].astype(float) - dark_mean
        if gain_divisor is not None:
            img = img / gain_divisor
        flat_list.append(img)

    flat_stack = np.array(flat_list)  # shape (N, ny, nx)
    n_img, ny, nx = flat_stack.shape

    # Mean signal level over all frames and pixels
    n_bar = float(np.mean(flat_stack))
    if n_bar <= 0:
        raise ValueError("Mean flat signal is non-positive after dark subtraction. "
                         "Check dark and flat images.")

    # ------------------------------------------------------------------
    # 3. Compute NPS
    #    NPS(f) = 1/(N * nx * ny * n_bar²) * Σ_k |FFT(I_k - n_bar)|²
    # ------------------------------------------------------------------
    nps_sum = np.zeros((ny, nx), dtype=float)
    for img in flat_stack:
        nps_sum += np.abs(np.fft.fft2(img - n_bar)) ** 2
    nps_2d = nps_sum / (n_img * nx * ny * n_bar ** 2)

    # ------------------------------------------------------------------
    # 4. Load MTF, build 2D MTF map, compute MTF²
    # ------------------------------------------------------------------
    mtf_data = np.loadtxt(mtf_file, comments='#')
    freq_1d = mtf_data[:, 0]
    mtf_1d = mtf_data[:, 1]
    if freq_1d[0] > 0:
        freq_1d = np.concatenate([[0.0], freq_1d])
        mtf_1d = np.concatenate([[1.0], mtf_1d])

    fy = np.fft.fftfreq(ny)
    fx = np.fft.fftfreq(nx)
    FX, FY = np.meshgrid(fx, fy)
    freq_radial = np.sqrt(FX**2 + FY**2)

    mtf_2d = np.interp(freq_radial, freq_1d, mtf_1d, left=1.0, right=0.0)
    mtf2_2d = mtf_2d ** 2

    # ------------------------------------------------------------------
    # 5. DQE 2D = MTF² / (n_bar * NPS)
    #    Frequencies with zero NPS get DQE = 0 (division by +inf).
    # ------------------------------------------------------------------
    dqe_2d = mtf2_2d / (n_bar * np.where(nps_2d > 0, nps_2d, np.inf))
    dqe_2d = np.clip(dqe_2d, 0.0, 1.0)

    # ------------------------------------------------------------------
    # 6. Radial (azimuthal) average over half-open bins [f_lo, f_hi)
    # ------------------------------------------------------------------
    freq_bins = np.linspace(0, 0.5, n_freq + 1)
    freq_centers = 0.5 * (freq_bins[:-1] + freq_bins[1:])

    # digitize returns i such that bins[i-1] <= f < bins[i]; frequencies at or
    # above 0.5 map to index n_freq and are dropped by the helper.
    bin_index = np.digitize(freq_radial.ravel(), freq_bins) - 1

    dqe_radial = _radial_bin_mean(dqe_2d, bin_index, n_freq)
    nps_radial = _radial_bin_mean(nps_2d, bin_index, n_freq)
    mtf2_radial = _radial_bin_mean(mtf2_2d, bin_index, n_freq)

    # ------------------------------------------------------------------
    # 7. Optional save
    # ------------------------------------------------------------------
    if save is not None:
        # np.savetxt prefixes every header line with '# ' itself, so the
        # header text must not carry its own comment marker.
        header = (f"DQE computed from {n_img} flat / {len(dark_paths)} dark images\n"
                  f"Mean signal level: {n_bar:.2f} counts/pixel\n"
                  f"Columns: frequency (cyc/px) DQE")
        np.savetxt(save, np.column_stack([freq_centers, dqe_radial]),
                   header=header, fmt='%.6f')

    # ------------------------------------------------------------------
    # 8. Optional plot
    # ------------------------------------------------------------------
    if plot:
        fig, axes = plt.subplots(1, 3, figsize=(14, 4))

        axes[0].plot(freq_centers, mtf2_radial, 'b-', lw=2)
        axes[0].set_title('MTF²')
        axes[0].set_xlabel('Frequency (cyc/px)')
        axes[0].set_ylabel('MTF²')
        axes[0].axvline(0.5, color='gray', ls=':', alpha=0.5, label='Nyquist')
        axes[0].set_ylim(0, 1.05)
        axes[0].grid(True, alpha=0.3)
        axes[0].legend()

        axes[1].plot(freq_centers, nps_radial, 'r-', lw=2)
        axes[1].set_title(f'NPS (n̄ = {n_bar:.1f} counts/px)')
        axes[1].set_xlabel('Frequency (cyc/px)')
        axes[1].set_ylabel('NPS (normalised)')
        axes[1].grid(True, alpha=0.3)

        axes[2].plot(freq_centers, dqe_radial, 'g-', lw=2)
        axes[2].set_title('DQE')
        axes[2].set_xlabel('Frequency (cyc/px)')
        axes[2].set_ylabel('DQE')
        axes[2].set_ylim(0, 1.05)
        axes[2].axvline(0.5, color='gray', ls=':', alpha=0.5, label='Nyquist')
        axes[2].grid(True, alpha=0.3)
        axes[2].legend()

        plt.suptitle(f'DQE measurement — {n_img} flat images', y=1.01)
        plt.tight_layout()
        plt.show()

    return freq_centers, dqe_radial
def deconvolve_mtf_dqe_2d(image, mtf_file, dqe_file):
    """
    Filter an image in Fourier space with a DQE/MTF weighting.

    This is a simplified version of :func:`deconvolve_mtf_2d`: no
    pre-smoothing and no roll-off are applied. The frequency weighting is

    .. math::

        W(f) = \\frac{\\mathrm{DQE}(f)}{\\mathrm{MTF}(f)}

    evaluated on the radial frequency grid of the image and then clipped to
    the range [0, 1].

    Parameters
    ----------
    image : ndarray
        2D image to deconvolve. NaN pixels are replaced by the NaN-mean of
        the image for the computation and set back to NaN in the result.
    mtf_file : str
        Path to the MTF file (columns: frequency in cyc/px, MTF value,
        optionally further columns which are ignored).
    dqe_file : str
        Path to the DQE file (columns: frequency in cyc/px, DQE value).

    Returns
    -------
    image_deconv : ndarray
        Filtered image, same shape as ``image``. Negative values are clipped
        to 0 and NaN pixels of the input are preserved.

    Raises
    ------
    ValueError
        If either table is not 2D with at least two columns.
    """
    # --- Read the two tabulated curves --------------------------------
    mtf_table = np.loadtxt(mtf_file, comments='#')
    dqe_table = np.loadtxt(dqe_file, comments='#')

    if mtf_table.ndim != 2 or mtf_table.shape[1] < 2:
        raise ValueError("MTF file must have at least 2 columns: freq (cyc/px) and MTF.")
    if dqe_table.ndim != 2 or dqe_table.shape[1] < 2:
        raise ValueError("DQE file must have 2 columns: freq (cyc/px) and DQE.")

    # The value is always taken from the second column, even when the file
    # has three columns.
    mtf_freq, mtf_vals = mtf_table[:, 0], mtf_table[:, 1]
    dqe_freq, dqe_vals = dqe_table[:, 0], dqe_table[:, 1]

    # Anchor both curves at (f=0, value=1) when the table starts above DC.
    if mtf_freq[0] > 0:
        mtf_freq = np.concatenate([[0.0], mtf_freq])
        mtf_vals = np.concatenate([[1.0], mtf_vals])
    if dqe_freq[0] > 0:
        dqe_freq = np.concatenate([[0.0], dqe_freq])
        dqe_vals = np.concatenate([[1.0], dqe_vals])

    # --- Replace NaNs so the FFT stays finite --------------------------
    invalid = np.isnan(image)
    has_invalid = invalid.any()
    work = image.copy().astype(float)
    if has_invalid:
        work[invalid] = np.nanmean(image)

    # --- Radial frequency of every FFT sample --------------------------
    n_rows, n_cols = work.shape
    grid_x, grid_y = np.meshgrid(np.fft.fftfreq(n_cols), np.fft.fftfreq(n_rows))
    radius = np.sqrt(grid_x**2 + grid_y**2)

    mtf_map = np.interp(radius, mtf_freq, mtf_vals, left=1.0, right=0.0)
    dqe_map = np.interp(radius, dqe_freq, dqe_vals, left=1.0, right=0.0)

    # --- Weighting W = DQE / MTF, with MTF floored to avoid dividing by 0
    mtf_floor = np.where(mtf_map > 1e-6, mtf_map, 1e-6)
    weights = np.clip(dqe_map / mtf_floor, 0.0, 1.0)

    # --- Apply the weighting in Fourier space --------------------------
    spectrum = np.fft.fft2(work) * weights
    restored = np.real(np.fft.ifft2(spectrum))

    # --- Put NaNs back, then remove negative values --------------------
    if has_invalid:
        restored[invalid] = np.nan
    return np.clip(restored, 0, None)
## RL filter for memory
def deconvolve_mtf_2d_rl(image, mtf_file, clip=True, n_iterations=50, tol=1e-2,
                         dqe_file=None, pre_smooth_sigma=0, verbose=False,
                         plot=False):
    """
    Richardson-Lucy 2D deconvolution with a radial MTF.

    Suited to Poisson noise (electron/photon counting). Regularisation is
    implicit: too few iterations under-deconvolves; too many amplify noise.
    The stopping criterion is the relative change of the current estimate *u*:

    .. math::

        \\text{rel} = \\frac{\\|u^{(k+1)} - u^{(k)}\\|_\\infty}{\\|u^{(k)}\\|_\\infty} < \\text{tol}

    **DQE-weighted correction (optional)**

    When ``dqe_file`` is provided, the back-projection step is weighted by
    the 2D DQE map instead of the plain MTF conjugate:

    .. math::

        u^{(k+1)} = u^{(k)} \\cdot \\mathcal{F}^{-1}\\!\\left[
        \\mathrm{DQE}(f)\\, H(f)\\,
        \\mathcal{F}\\!\\left[\\frac{I}{h \\circledast u^{(k)}}\\right] \\right]

    where :math:`\\mathrm{DQE}(f) = \\mathrm{MTF}^2(f) / (\\bar{n}\\,\\mathrm{NPS}(f))`.
    Frequencies where :math:`\\mathrm{DQE}(f) \\approx 0` (noise-dominated)
    are naturally suppressed at every iteration, making the algorithm less
    sensitive to the choice of ``n_iterations`` and removing the need for
    ``pre_smooth_sigma`` in most cases.

    Without ``dqe_file`` the standard R-L update is used and regularisation
    relies entirely on early stopping via ``tol`` and ``n_iterations``.

    Parameters
    ----------
    image : ndarray
        2D image to deconvolve. NaN pixels are replaced by the NaN-mean of
        the image for the computation and set back to NaN in the result.
    mtf_file : str
        Path to the MTF file (columns: frequency in cyc/px, MTF value). The
        MTF is rescaled so that its value at frequency 0 is 1.
    clip : bool, optional
        If ``True``, negative values of the final result are clamped to 0.
        (The iterate itself stays non-negative because the multiplicative
        correction is always clipped at 0.) Default is ``True``.
    n_iterations : int, optional
        Maximum number of iterations (safety cap). Default is 50.
    tol : float or None, optional
        Early-stopping threshold on the relative change ``||Δu||/||u||``.
        ``None`` disables early stopping, in which case exactly
        ``n_iterations`` iterations are run. Default is ``1e-2``.
    dqe_file : str or None, optional
        Path to the DQE file (same format as ``mtf_file``: columns are
        frequency in cyc/px and DQE value in [0, 1]). When provided, the
        correction at each iteration is weighted by the 2D DQE map.
        ``None`` reproduces the standard R-L behaviour. Default is ``None``.
    pre_smooth_sigma : float, optional
        Standard deviation (pixels) for Gaussian pre-smoothing applied
        before deconvolution. ``0`` disables smoothing. Default is ``0``.
    verbose : bool, optional
        If ``True``, print the relative change at each iteration.
        Default is ``False``.
    plot : bool, optional
        If ``True``, display the MTF and a 1D PSF profile.
        Default is ``False``.

    Returns
    -------
    image_deconv : ndarray
        Deconvolved 2D image.

    Raises
    ------
    ValueError
        If the MTF or DQE table is not 2D with at least two columns, or if
        the MTF value at frequency 0 is not strictly positive.
    """
    # ------------------------------------------------------------------
    # 1. Load MTF (and DQE if requested)
    # ------------------------------------------------------------------
    mtf_data = np.loadtxt(mtf_file, comments='#')
    if mtf_data.ndim != 2 or mtf_data.shape[1] < 2:
        raise ValueError("MTF file must have 2 columns: freq (cyc/px) and MTF.")
    freq_1d = mtf_data[:, 0]
    mtf_1d = mtf_data[:, 1]
    if freq_1d[0] > 0:
        freq_1d = np.concatenate([[0.0], freq_1d])
        mtf_1d = np.concatenate([[1.0], mtf_1d])

    # RL requires a normalised PSF: H(0) = 1 (total energy is conserved).
    # If MTF(DC) < 1, each iteration multiplies the estimate by MTF(0)^2 < 1,
    # driving it to zero after enough iterations.
    # Normalise here so RL is independent of the absolute MTF calibration.
    _mtf_dc = np.interp(0.0, freq_1d, mtf_1d)
    if _mtf_dc <= 0:
        raise ValueError(
            f"[RL] MTF value at DC = {_mtf_dc:.4g} ≤ 0. "
            "Check that column 1 of the MTF file contains MTF values (not epsilon)."
        )
    if abs(_mtf_dc - 1.0) > 1e-3:
        print(
            f"[RL] MTF(DC) = {_mtf_dc:.4f} ≠ 1.0 — normalising by MTF(DC). "
            "Without this, each RL iteration multiplies the estimate by "
            f"MTF(DC)² = {_mtf_dc**2:.4f}, collapsing the image to zero."
        )
        mtf_1d = mtf_1d / _mtf_dc

    if dqe_file is not None:
        dqe_data = np.loadtxt(dqe_file, comments='#')
        if dqe_data.ndim != 2 or dqe_data.shape[1] < 2:
            raise ValueError("DQE file must have 2 columns: freq (cyc/px) and DQE.")
        dqe_freq_1d = dqe_data[:, 0]
        dqe_1d = dqe_data[:, 1]
        if dqe_freq_1d[0] > 0:
            dqe_freq_1d = np.concatenate([[0.0], dqe_freq_1d])
            dqe_1d = np.concatenate([[1.0], dqe_1d])

    # ------------------------------------------------------------------
    # 2. Handle NaNs
    # ------------------------------------------------------------------
    nan_mask = np.isnan(image)
    image_filled = image.copy().astype(float)
    if nan_mask.any():
        image_filled[nan_mask] = np.nanmean(image)

    # ------------------------------------------------------------------
    # 3. Optional pre-smoothing
    # ------------------------------------------------------------------
    if pre_smooth_sigma > 0:
        from scipy.ndimage import gaussian_filter
        image_filled = gaussian_filter(image_filled, sigma=pre_smooth_sigma)

    # ------------------------------------------------------------------
    # 4. Build the 2D radial transfer function
    #    Convolutions are performed in the frequency domain
    # ------------------------------------------------------------------
    ny, nx = image_filled.shape
    fy = np.fft.fftfreq(ny)
    fx = np.fft.fftfreq(nx)
    FX, FY = np.meshgrid(fx, fy)
    freq_radial = np.sqrt(FX**2 + FY**2)

    # For RL, the PSF must be non-negative everywhere to prevent divergence.
    # Using right=0.0 (hard cut beyond the last tabulated frequency) creates a
    # sharp edge that causes Gibbs oscillations: the PSF goes negative in the
    # spatial domain, and the clip(correction, 0) in the RL loop kills pixels
    # at every iteration → image collapses to zero.
    # Fix: hold the last tabulated MTF value for frequencies beyond the table.
    mtf_2d = np.interp(freq_radial, freq_1d, mtf_1d,
                       left=1.0, right=mtf_1d[-1])

    # The MTF is real and radially symmetric, so correlation with the flipped
    # PSF uses the same transfer function as the forward convolution. With a
    # DQE table, the back-projection is additionally weighted by the DQE map;
    # the product is loop-invariant and therefore built once.
    if dqe_file is not None:
        dqe_2d = np.interp(freq_radial, dqe_freq_1d, dqe_1d, left=1.0, right=0.0)
        back_filter = dqe_2d * mtf_2d
    else:
        back_filter = mtf_2d

    # ------------------------------------------------------------------
    # 5. Richardson-Lucy algorithm
    #    u^(k+1) = u^(k) * (h* ⊛ (I / (h ⊛ u^(k))))
    #    implemented in the frequency domain (convolution = multiplication)
    # ------------------------------------------------------------------
    # Initialise with the observed image (clipped positive)
    observed = np.clip(image_filled, 1e-6, None)
    u = observed.copy()

    for i in range(n_iterations):
        u_prev = u  # u is rebound below, never modified in place

        # Convolve the current estimate with the PSF
        Hu = np.real(np.fft.ifft2(mtf_2d * np.fft.fft2(u_prev)))
        Hu = np.clip(Hu, 1e-12, None)  # avoid division by zero

        # Observed / convolved-estimate ratio, back-projected through the PSF
        ratio_fft = np.fft.fft2(observed / Hu)
        correction = np.real(np.fft.ifft2(back_filter * ratio_fft))

        # Multiplicative update; negative corrections are suppressed
        u = u_prev * np.clip(correction, 0, None)

        # --------------------------------------------------------------
        # Relative change of the estimate:
        #   rel = ||u^(k+1) - u^(k)||_∞ / ||u^(k)||_∞
        # Convergence: rel → 0. Divergence: rel grows (noise amplification).
        # Always computed so that it is defined for verbose output and for
        # the messages below even when early stopping is disabled.
        # --------------------------------------------------------------
        rel = np.max(np.abs(u - u_prev)) / (np.max(u_prev) + 1e-12)
        if verbose:
            print(f"[RL] iter {i+1:4d} | Δrel={rel:.4e}")

        if tol is not None:
            if rel < tol:
                print(f"[RL] Converged at iteration {i+1} "
                      f"(Δrel={rel:.2e} < {tol})")
                break
            if i + 1 == n_iterations:
                print(f"[RL] Maximum iterations reached ({n_iterations}) "
                      f"without convergence (Δrel={rel:.2e}).\n"
                      f" → Increase n_iterations or tol, "
                      f"or enable pre_smooth_sigma to slow divergence.")

    image_deconv = u

    # ------------------------------------------------------------------
    # 6. Diagnostic PSF plot
    # ------------------------------------------------------------------
    if plot:
        f_plot = np.linspace(0, 0.5, 300)
        mtf_plot = np.interp(f_plot, freq_1d, mtf_1d)

        # 1D PSF ≈ IFFT of the 1D MTF (central profile). The MTF is an even
        # function of frequency, so it must be sampled at |f|: np.interp
        # would otherwise clamp all negative frequencies to the DC value.
        n_psf = 256
        mtf_sym = np.interp(np.abs(np.fft.fftfreq(n_psf, d=1.0)), freq_1d, mtf_1d)
        psf_1d = np.real(np.fft.fftshift(np.fft.ifft(mtf_sym)))
        x_psf = np.arange(n_psf) - n_psf // 2

        fig, axes = plt.subplots(1, 2, figsize=(10, 4))

        axes[0].plot(f_plot, mtf_plot, 'b-', lw=2)
        axes[0].set_xlabel('Spatial frequency (cycles/pixel)')
        axes[0].set_ylabel('MTF')
        axes[0].set_title('MTF used (Richardson-Lucy)')
        axes[0].axvline(0.5, color='gray', linestyle=':', alpha=0.5, label='Nyquist')
        axes[0].grid(True, alpha=0.3)
        axes[0].legend()

        axes[1].plot(x_psf, psf_1d / psf_1d.max(), 'r-', lw=2)
        axes[1].set_xlabel('Position (pixels)')
        axes[1].set_ylabel('Normalised PSF')
        axes[1].set_title('Radial PSF (1D profile)')
        axes[1].set_xlim(-20, 20)
        axes[1].grid(True, alpha=0.3)

        plt.tight_layout()
        plt.show()

    # ------------------------------------------------------------------
    # 7. Restore NaNs and clip
    # ------------------------------------------------------------------
    if nan_mask.any():
        image_deconv[nan_mask] = np.nan
    if clip:
        image_deconv = np.clip(image_deconv, 0, None)

    return image_deconv