import tifffile
from tifffile import imread
import numpy as np
import dask.array as da
import dask
from dask.distributed import Client, progress
from numba import jit, prange, int64
from . import gap_pixels, nexus
from .pre_proc_opts import PreProcOpts
from scipy import optimize, sparse, special, interpolate
from scipy.ndimage.morphology import binary_dilation
from skimage.morphology import disk
from astropy.convolution import Gaussian2DKernel, convolve
from scipy.ndimage.filters import median_filter
from functools import wraps
from typing import Optional, Tuple, Union, List, Callable, Dict
from warnings import warn, catch_warnings, simplefilter
import pandas as pd
import h5py
def stack_nested(data_list: Union[tuple, list, dict], func: Callable = np.stack):
    """Recursively apply a stacking/concatenation function to a nested collection.

    The input is a sequence of identically structured items (arrays, or tuples/lists/dicts
    which ultimately contain numpy or dask arrays). The structure of the items is preserved,
    and `func` is applied to the sequence of corresponding leaves.

    Args:
        data_list (Union[tuple, list, dict]): Collection of numpy arrays (can be recursive).
            A single scalar-like object (including a single dict) is treated as a
            one-element list.
        func (Callable, optional): Concatenation function to apply. Defaults to np.stack.

    Returns:
        same as data_list: tuple/list/dict with concatenated/stacked numpy arrays

    Raises:
        IndexError: if `data_list` is an empty sequence.
    """
    # Lists and tuples always have at least one dimension, so they never need wrapping.
    # Skipping np.ndim for them is essential: it would first convert the entire nested
    # collection into one array, which raises a ValueError on recent NumPy versions as
    # soon as the leaves have different shapes (and is wasteful otherwise).
    if not isinstance(data_list, (list, tuple)) and np.ndim(data_list) == 0:
        data_list = [data_list]

    first = data_list[0]
    if isinstance(first, tuple):
        return tuple(stack_nested(arr, func) for arr in zip(*data_list))
    if isinstance(first, list):
        return [stack_nested(arr, func) for arr in zip(*data_list)]
    if isinstance(first, dict):
        return {k: stack_nested([o[k] for o in data_list], func) for k in first}
    # lowest level reached: data_list is a sequence of arrays/scalars
    return func(data_list)
def loop_over_stack(fun):
    """
    Decorator to (sequentially) loop a 2D processing function over a stack.

    In brief, if you have a function that either modifies a (single) image or extracts some reduced
    data from it, this decorator wraps it such that it can operate on a whole stack of images.
    Works on all functions with signature fun(imgs: np.ndarray, *args, **kwargs),
    where imgs is a numpy 3D stack or a 2D single image. It has to return either
    a numpy array, in which case it returns a stacked array of the function output,
    or a collection containing numpy arrays, each of which is stacked individually.

    If any of the positional/named arguments is a sized iterable of the same length
    as the image stack, it is distributed over the function calls for each
    image. Strings and bytes are never distributed, even if their length happens to match
    the stack size; neither are iterables without a length (e.g. generators).

    Note:
        `loop_over_stack` only works on functions eating numpy arrays, *not* dask arrays.
        If you want to apply a function to a dask-array image stack, you have to *additionally*
        wrap it in `dask.array.map_blocks`, `diffractem.dataset._map_sub_blocks`,
        `diffractem.compute.map_reduction_func` or similar.

    Note:
        Any other sized argument (e.g. a 2D array) whose first dimension happens to equal the
        number of images is still distributed over the images.

    Args:
        fun (Callable): function to be decorated

    Returns:
        Callable: function that loops over an image stack automatically
    """
    # TODO: handle functions with multiple outputs, and return a list of ndarrays
    # TODO: allow switches for parallelization

    def _distributable(arg, n_img):
        """Return True if `arg` is to be split up into one element per image."""
        # Strings are iterable and sized, but always meant as a single option value.
        if isinstance(arg, (str, bytes)):
            return False
        try:
            iter(arg)
            return len(arg) == n_img
        except TypeError:
            # not iterable (incl. 0-d arrays), or iterable without a length (generators)
            return False

    def _per_image(arg, n_img):
        """Return an iterable yielding the value of `arg` for each image."""
        if _distributable(arg, n_img):
            return arg.squeeze() if isinstance(arg, np.ndarray) else arg
        return [arg] * n_img

    @wraps(fun)
    def loop_fun(imgs, *args, **kwargs):
        # this is NOT intended for dask arrays!
        if not isinstance(imgs, np.ndarray):
            raise TypeError(f'loop_over_stack only works on numpy arrays (not dask etc.). '
                            f'Passed type to {fun} is {type(imgs)}')

        if imgs.ndim == 2:
            # single image: nothing to loop over
            return fun(imgs, *args, **kwargs)

        if imgs.shape[0] == 1:
            # some gymnastics if arrays are on weird dimensions (often after map_blocks)
            args = [a.squeeze() if isinstance(a, np.ndarray) else a for a in args]
            kwargs = {k: a.squeeze() if isinstance(a, np.ndarray) else a for k, a in kwargs.items()}
            return stack_nested([fun(imgs.squeeze(), *args, **kwargs)])

        n_img = len(imgs)
        n_pos = len(args)
        # one "column" per argument; positional ones first, then named ones in dict order
        columns = [_per_image(a, n_img) for a in args]
        columns += [_per_image(a, n_img) for a in kwargs.values()]

        out = []
        for row in zip(imgs, *columns):
            pos_args = row[1:1 + n_pos]
            named_args = dict(zip(kwargs, row[1 + n_pos:]))
            out.append(fun(row[0], *pos_args, **named_args))

        if not out:
            # required for dask map_blocks init runs (zero-size input)
            return np.ndarray(imgs.shape, dtype=imgs.dtype)

        try:
            return stack_nested(out)
        except ValueError:
            print('Function', fun, 'failed for output array construction.')
            raise

    return loop_fun
@loop_over_stack
def _generate_pattern_info(img: np.ndarray, opts: PreProcOpts,
                           reference: Optional[np.ndarray] = None,
                           pxmask: Optional[np.ndarray] = None,
                           centers: Optional[np.ndarray] = None,
                           lorentz_fit: Optional[bool] = True) -> dict:
    """
    'Macro' function computing information from diffraction data and returning it
    as a dictionary. Primarily intended to be called from `get_pattern_info`.

    The steps are: flat-field and dead-pixel correction, pattern center estimation
    (thresholded center of mass, optionally refined by a Lorentz fit; skipped if `centers`
    is given), peak finding (if `opts.find_peaks`), Friedel-mate center refinement
    (if `opts.friedel_refine` and at least `opts.min_peaks` peaks were found), virtual
    ADF detectors, and computation of the detector shift from the refined center.

    Note:
        This function is different from most in `proc2d` in that it returns a
        dictionary, *not* a `np.ndarray`. This has, among others, the implication, that
        it cannot be called through the dask array interface via `map_blocks`.

    Args:
        img (np.ndarray): diffraction image or stack thereof as numpy array.
        opts (PreProcOpts): pre-processing options.
        reference (Optional[np.ndarray], optional): reference image for flat-field
            correction. If None, grabs the file name from the options file and loads it.
            This is discouraged as it requires reloading it over and over. Defaults to None.
        pxmask (Optional[np.ndarray], optional): similar, for pixel mask. Defaults to None.
        centers (Optional[np.ndarray], optional): known pattern center; elements 0 and 1
            are used as x and y. If given, the center-of-mass and Lorentz fit steps are
            skipped and their output fields are set to NaN. Defaults to None.
        lorentz_fit (Optional[bool], optional): if True, refine the center-of-mass
            estimate by a Lorentz fit (`lorentz_fast`). If False, the center of mass is used
            directly and the Lorentz output fields are NaN. Defaults to True.

    Returns:
        dict: Diffraction pattern information, with keys 'com_x', 'com_y', 'lor_pk', 'lor_x',
        'lor_y', 'lor_hwhm', 'center_x', 'center_y', 'center_refine_score', 'adf1', 'adf2',
        the two shift labels `opts.det_shift_x_path` and `opts.det_shift_y_path`,
        'num_peaks' and 'peak_data' (the latter being the peak dictionary itself).
    """
    #TODO consider using a NamedTuple for return values instead of a dict

    # computations on diffraction patterns. To be called from get_pattern_info.
    reference = imread(opts.reference) if reference is None else reference
    pxmask = imread(opts.pxmask) if pxmask is None else pxmask

    # imported at function level -- presumably to avoid a circular import (TODO confirm)
    from diffractem.proc_peaks import _ctr_from_pks

    # apply flatfield and dead-pixel correction to get more accurate COM
    # CONSIDER DOING THIS OUTSIDE GET PATTERN INFO!
    img = apply_flatfield(img, reference, keep_type=False)
    img = correct_dead_pixels(img, pxmask, strategy='replace', mask_gaps=False, replace_val=-1)

    if centers is None:
        # thresholded center-of-mass calculation over x-axis sub-range.
        # The threshold is opts.com_threshold times the (1 - 5e-5) quantile of the sub-range,
        # and the x offset of the sub-range is added back to the result.
        img_ct = img[:,(img.shape[1]-opts.com_xrng)//2:(img.shape[1]+opts.com_xrng)//2]
        com = center_of_mass(img_ct, threshold=opts.com_threshold*np.quantile(img_ct,1-5e-5)) + [(img.shape[1]-opts.com_xrng)//2, 0]

        if lorentz_fit:
            # Lorentz fit of direct beam, starting from the center of mass
            lorentz = lorentz_fast(img, com[0], com[1], radius=opts.lorentz_radius,
                                   limit=opts.lorentz_maxshift, scale=7, threads=False)
            x0, y0 = lorentz[1], lorentz[2]
        else:
            x0, y0 = com[0], com[1]
            lorentz = [np.nan] * 4

    else:
        # center is known: no COM/Lorentz results available
        x0, y0 = centers[0], centers[1]
        lorentz = [np.nan] * 4
        com = [np.nan] * 2

    # Get peaks using peakfinder8. Note that pf8 parameters are taken straight from the options file,
    # with automatic underscore/hyphen replacement.
    # Note that peak positions are CXI convention, i.e. refer to pixel _center_
    if opts.find_peaks:
        peak_data = get_peaks(img, x0, y0, pxmask=pxmask, max_peaks=opts.max_peaks,
                              **{k.replace('-','_'): v for k, v in opts.peak_search_params.items()},
                              as_dict=True)
    else:
        # empty peak list with the same layout
        peak_data = {'nPeaks': 0, 'peakXPosRaw': np.zeros((opts.max_peaks,)),
                     'peakYPosRaw': np.zeros((opts.max_peaks,)),
                     'peakTotalIntensity': np.zeros((opts.max_peaks,))}

    if opts.friedel_refine and (peak_data['nPeaks'] >= opts.min_peaks):
        # prepare peak list. Note the .5, as _ctr_from_pks expects CrystFEL peak convention,
        # i.e. positions refer to pixel _corner_
        pkl = np.stack((peak_data['peakXPosRaw'] + .5, peak_data['peakYPosRaw'] + .5,
                        peak_data['peakTotalIntensity']), -1)[:int(peak_data['nPeaks']),:]

        if opts.friedel_max_radius is not None:
            # only keep peaks within the given radius around the current center estimate
            rsq = (pkl[:, 0] - x0) ** 2 + (pkl[:, 1] - y0) ** 2
            pkl = pkl[rsq < opts.friedel_max_radius ** 2, :]

        ctr_refined, cost, _ = _ctr_from_pks(pkl, np.array([x0, y0]), int_weight=False,
                                             sigma=opts.peak_sigma)
    else:
        # no refinement: keep the center estimate from above
        ctr_refined, cost = np.array([x0, y0]), np.nan

    # virtual ADF detectors
    adf1 = apply_virtual_detector(img, opts.r_adf1[0], opts.r_adf1[1], ctr_refined[0], ctr_refined[1])
    adf2 = apply_virtual_detector(img, opts.r_adf2[0], opts.r_adf2[1], ctr_refined[0], ctr_refined[1])

    # Lab-space shifts. Not really consistent to calculate them here, and even worse, copied code
    # from Dataset. Just go and cast the first stone.
    # R is a rotation by opts.ellipse_angle (degrees); RR scales by ellipse_ratio**(-/+ 0.5)
    # along the two axes in the rotated frame.
    c, s = np.cos(opts.ellipse_angle*np.pi/180), np.sin(opts.ellipse_angle*np.pi/180)
    R = np.array([[c, -s], [s, c]])
    RR = R.T @ ([[opts.ellipse_ratio**(-.5)],[opts.ellipse_ratio**(.5)]] * R)

    # panel-space shift
    # NOTE: x0, y0 are re-used here for the panel center, no longer the beam center estimate.
    # pxs is the pixel size times 1e3 -- presumably a conversion from m to mm (TODO confirm).
    x0, y0, pxs = opts.xsize/2, opts.ysize/2, opts.pixel_size * 1e3
    shift_labels = [opts.det_shift_x_path, opts.det_shift_y_path]
    shift_p = np.array([ctr_refined[0] - x0 + 0.5,
                        ctr_refined[1] - y0 + 0.5])

    # real-space shift
    shift_mm = - pxs * (RR @ shift_p)

    pattern_info = {'com_x': com[0], 'com_y': com[1],
                    'lor_pk': lorentz[0],
                    'lor_x': lorentz[1],
                    'lor_y': lorentz[2],
                    'lor_hwhm': lorentz[3],
                    'center_x': ctr_refined[0],
                    'center_y': ctr_refined[1],
                    'center_refine_score': cost,
                    'adf1': adf1,
                    'adf2': adf2,
                    shift_labels[0]: shift_mm[0],
                    shift_labels[1]: shift_mm[1],
                    'num_peaks': peak_data['nPeaks'],
                    'peak_data': peak_data}

    return pattern_info
def get_pattern_info(img: Union[np.ndarray, da.Array], opts: PreProcOpts, client: Optional[Client] = None,
                     reference: Optional[np.ndarray] = None,
                     pxmask: Optional[np.ndarray] = None,
                     centers: Optional[Union[np.ndarray, da.Array]] = None,
                     lorentz_fit: Optional[bool] = True,
                     lazy: bool = False, sync: bool = True,
                     errors: str = 'raise', via_array: bool = False,
                     output_file: Optional[str] = None,
                     shots: Optional[pd.DataFrame] = None,
                     dummy_stack_name: str = 'corrected') -> Tuple[pd.DataFrame, dict]:
    """'Macro' function for getting information about diffraction patterns.

    `get_pattern_info` finds diffraction peaks and computes information such as pattern center on a given diffraction
    pattern or stack thereof. By default (`lazy=False` and `sync=True`) it will return a pandas DataFrame containing
    general information on each pattern, and a dict holding the found peaks in CXI format.
    The options for preprocessing are passed as a `PreProcOpts` object.

    Note:
        This function is essentially a smart wrapper around `proc2d._generate_pattern_info`. If you'd like to change
        what is actually calculated and how, that is the function to modify!

    Note:
        As this function is computationally heavy, it is **very** advisable to use a *dask.distributed* cluster for
        computation, with a client object supplied to the function call.

    Args:
        img (Union[np.ndarray, da.Array]): stack of diffraction patterns, typically a dask array
        opts (PreProcOpts): pre-processing options.
        client (Optional[Client], optional): Client object for dask.distributed cluster. If None, runs
            computation on the default dask scheduler (discouraged). Defaults to None.
        reference (Optional[np.ndarray], optional): Flat-field reference image. If None, load the one specified in
            preprocessing options. Defaults to None.
        pxmask (Optional[np.ndarray], optional): Pixel mask image. If None, load the one specified in
            preprocessing options. Defaults to None.
        centers (np.array or da.Array, optional): N x 2 matrix with known centers of all diffraction patterns. If set,
            the center-of-mass and Lorentz fit steps are skipped. Depending on the setting of `opts.friedel_refine`,
            Friedel-mate center refinement is still performed. Defaults to None.
        lorentz_fit (bool, optional): Refine the center-of-mass estimate by a Lorentz fit. Defaults to True.
        lazy (bool, optional): Return `dask.delayed` objects for pattern info generation tasks instead of the final
            results. Mostly useful for debugging or embedding into more complex workflows. Defaults to False.
        sync (bool, optional): Immediately compute pattern info. If False (and a client is given), returns futures
            to pattern info dictionaries instead of DataFrame and peak dict. Defaults to True.
        errors (str, optional): Behavior if errors arise during eager computation on a cluster (i.e., `lazy=False`,
            `sync=True`). If 'raise', errors are raised, if 'skip', they are skipped, and the final data is missing
            the corresponding shots, which needs to be handled downstream to avoid making a mess.
            Defaults to 'raise'.
        via_array (bool, optional): Modify calculation such that it avoids `dask.delayed` objects.
            This drastically improves the scheduling behavior for large datasets. It is also required if you
            supply the pattern centers to the function. However, precludes the use of lazy and sync.
            Defaults to False.
        output_file (str, optional): Filename to store calculation results into. The file will be a valid
            diffractem-type data file that can be loaded using Dataset objects.
        shots (pd.DataFrame, optional): Dataframe of shot data of same height as the image array. If not None,
            its columns will be joined to those of the shot data for storing the results into the output file.
        dummy_stack_name (str, optional): Name of virtual data set to be written into the output file in order to
            fake data, if required by another program (e.g. CrystFEL). Defaults to 'corrected'.

    Returns:
        Tuple[pd.DataFrame, dict]: pandas DataFrame holding general pattern information, and dict holding CXI-format
        peaks. (note that return values are different when using `lazy=True` or `sync=False` - see above)

    Raises:
        ValueError: if `centers` is given for a dask array without `via_array=True`, or if `img` is neither
            a dask nor a numpy array.
    """
    #TODO could this be refactored into dataset, automatically applying it to the diffraction data set?
    #TODO would including an option to return cts on top of res_del make sense?

    reference = imread(opts.reference) if reference is None else reference
    pxmask = imread(opts.pxmask) if pxmask is None else pxmask

    if len(img.shape) == 2:
        img = img[np.newaxis, ...]

    if isinstance(img, da.Array) and not via_array:

        if centers is not None:
            raise ValueError('If pattern centers are given, you have to set via_array=True.')

        # ravel (not squeeze) so that a single-chunk stack still yields a 1D array of chunks
        cts = img.to_delayed().ravel()
        res_del = [dask.delayed(_generate_pattern_info, nout=1, pure=True)(
                       c, opts=opts, reference=reference, pxmask=pxmask, lorentz_fit=lorentz_fit,
                       dask_key_name=f'pattern_info-{ii}')
                   for ii, c in enumerate(cts)]

        if lazy:
            return res_del

        if client is not None:
            res_del = client.persist(res_del)
            progress(res_del, notebook=False)
            ftrs = client.compute(res_del)
            if not sync:
                return ftrs
            alldat = stack_nested(client.gather(ftrs, errors=errors), func=np.concatenate)
        else:
            warn('get_pattern_info is run on a dask array without distributed client - might be slow!')
            # compute all chunk results and merge them just like the gathered futures above
            alldat = stack_nested(dask.compute(*res_del), func=np.concatenate)

        shotdata = pd.DataFrame({k: v for k, v in alldat.items() if isinstance(v, np.ndarray) and (v.ndim == 1)})
        peakinfo = alldat['peak_data']

    elif isinstance(img, da.Array) and via_array:
        # do extra step by casting output of _generate_pattern_info into a dask array and
        # keep using the dask array API instead of the delayed api as above. For yet not understood
        # reasons this leads to a much better behavior of the dask scheduler and yields identical results.
        # it's just horribly inelegant.

        def _encode_info(info):
            """Turn the output of _generate_pattern_info into a single 2D array (one row per shot).

            Columns are the general values (sorted by key), followed by the flattened
            peak data arrays (sorted by key).
            """
            general = np.stack([v for k, v in sorted(info.items()) if k != 'peak_data'], axis=1)
            peaks = np.hstack([v.reshape(v.shape[0], -1) for k, v in sorted(info['peak_data'].items())])
            return np.concatenate([general, peaks], axis=1)

        def _info_block(img_block, ctr_block):
            """Compute the encoded pattern info for one block of the image stack."""
            return _encode_info(_generate_pattern_info(img_block, opts, reference=reference, pxmask=pxmask,
                                                       centers=ctr_block, lorentz_fit=lorentz_fit))

        # compute info for a single image to get structure of output
        template = _generate_pattern_info(img[:1, ...].compute(), opts,
                                          reference=reference, pxmask=pxmask,
                                          centers=None if centers is None else np.asarray(centers[:1, :]),
                                          lorentz_fit=lorentz_fit)

        if centers is not None:
            # align the chunking of the centers with that of the image stack
            ctr_chunks = (img.chunks[0], 2)
            if isinstance(centers, da.Array):
                centers = centers.rechunk(ctr_chunks)
            else:
                centers = da.from_array(centers, chunks=ctr_chunks)
            centers = centers.reshape((-1, 2, 1))

        info_array = img.map_blocks(_info_block, centers,
                                    dtype=np.float64, drop_axis=[1, 2], new_axis=1,
                                    chunks=(img.chunks[0], _encode_info(template).shape[1]),
                                    name='pattern_info')

        if client is not None:
            alldat = client.persist(info_array)
            progress(alldat, notebook=False)
            alldat = alldat.compute(sync=True)
        else:
            warn('get_pattern_info is run on a dask array without distributed client - might be slow!')
            alldat = info_array.compute()

        # recreate shot data table
        cols = [k for k in sorted(template) if k != 'peak_data']
        types = {k: v.dtype for k, v in sorted(template.items()) if k != 'peak_data'}
        shotdata = pd.DataFrame(alldat[:, :len(cols)], columns=cols).astype(types)

        # recreate peak info
        pk_cols = [(k, v.reshape(v.shape[0], -1).shape[1], v.dtype)
                   for k, v in sorted(template['peak_data'].items())]
        peakinfo, ii_col = {}, len(cols)
        for col, width, dt in pk_cols:
            peakinfo[col] = alldat[:, ii_col:ii_col + width].squeeze().astype(dt)
            ii_col += width

    elif isinstance(img, np.ndarray):
        alldat = _generate_pattern_info(img, opts=opts, reference=reference,
                                        pxmask=pxmask, centers=centers, lorentz_fit=lorentz_fit)
        shotdata = pd.DataFrame({k: v for k, v in alldat.items() if isinstance(v, np.ndarray) and (v.ndim == 1)})
        peakinfo = alldat['peak_data']

    else:
        raise ValueError('Input image(s) must be a dask or numpy array.')

    if output_file is not None:
        with h5py.File(output_file, 'w') as fh:
            for k, v in peakinfo.items():
                fh.create_dataset('/entry/data/' + k, data=v, compression='gzip', chunks=(1,) + v.shape[1:])
            fh['/entry/data'].attrs['recommended_zchunks'] = -1
            # virtual data set without any sources: reads as fill value everywhere
            dummy_layout = h5py.VirtualLayout(img.shape, dtype='i1')
            fh.create_virtual_dataset('/entry/data/' + dummy_stack_name, dummy_layout, fillvalue=-1)
            fh['/entry/data'].attrs['signal'] = dummy_stack_name

        shotdata_id = pd.concat([shots, shotdata], axis=1)
        nexus.store_table(shotdata_id, file=output_file, subset='entry', path='/%/shots')
        print('Wrote analysis results to file', output_file)

    return shotdata, peakinfo
@loop_over_stack
def _get_corr_img(img: np.ndarray,
                  x0: np.ndarray, y0: Union[np.ndarray, None],
                  nPeaks: Union[np.ndarray, None],
                  peakXPosRaw: Union[np.ndarray, None],
                  peakYPosRaw: Union[np.ndarray, None],
                  opts: PreProcOpts,
                  reference: Optional[Union[np.ndarray, str]] = None,
                  pxmask: Optional[Union[np.ndarray, str]] = None):
    """Inner function of the full correction pipeline, operating on numpy arrays only.

    Intended to be called from `correct_image`, which additionally handles dask arrays.
    The steps are, in this order: conversion to float32, saturation correction (if
    `opts.correct_saturation`), flat-field correction, background removal (if
    `opts.remove_background`), and dead-pixel correction. The result is returned as float32
    if `opts.float` is set, otherwise it is multiplied by `opts.int_factor`, rounded and
    converted to int16.

    Please see the documentation of `correct_image` for the meaning of the arguments.
    """
    # fall back to the files given in the options if no images are supplied
    if reference is None:
        reference = imread(opts.reference)
    if pxmask is None:
        pxmask = imread(opts.pxmask)

    corrected = img.astype(np.float32)

    if opts.correct_saturation:
        corrected = apply_saturation_correction(corrected, opts.shutter_time, opts.dead_time,
                                                opts.dead_time_gap_factor)

    corrected = apply_flatfield(corrected, reference=reference)

    if opts.remove_background:
        corrected = remove_background(corrected, x0, y0, nPeaks, peakXPosRaw, peakYPosRaw, pxmask=pxmask)

    # dead pixels are handled only once, on the final (background-corrected) image
    if opts.interpolate_dead:
        dead_px_strategy = 'interpolate'
    else:
        dead_px_strategy = 'replace'
    corrected = correct_dead_pixels(corrected, pxmask, strategy=dead_px_strategy, mask_gaps=opts.mask_gaps)

    if opts.float:
        return corrected.astype(np.float32)
    return (corrected * opts.int_factor).round().astype(np.int16)
def correct_image(img: Union[np.ndarray, da.Array], opts: PreProcOpts,
                  x0: Union[None, np.ndarray, da.Array, pd.Series] = None,
                  y0: Union[None, np.ndarray, da.Array, pd.Series] = None,
                  peakinfo: Union[None, Dict[str, Union[np.ndarray, da.Array]]] = None,
                  reference: Union[None, Union[np.ndarray, str]] = None,
                  pxmask: Union[None, Union[np.ndarray, str]] = None) -> Union[np.ndarray, da.Array]:
    """Runs correction pipeline on stack of diffraction images (numpy or dask).

    The correction pipeline comprises flat-field, saturation and dead-pixel correction, as well as
    background subtraction, optionally including exclusion of diffraction peaks for computation of the
    background (recommended).

    Note:
        This function essentially wraps `proc2d._get_corr_img` with smart features to take care
        of dask input arrays. If you want to change the
        correction pipeline, that is the function to modify.

    Args:
        img (Union[np.ndarray, da.Array]): Diffraction pattern stack
        opts (PreProcOpts): Pre-processing options. Options used are: (...)
        x0 (Union[None, np.ndarray, da.Array, pd.Series], optional): Pattern X centers
            (None: use image center). Defaults to None.
        y0 (Union[None, np.ndarray, da.Array, pd.Series], optional): Pattern Y centers
            (None: use image center). If either of `x0` and `y0` is None, both are ignored.
            Defaults to None.
        peakinfo (Union[None, Dict[str, Union[np.ndarray, da.Array]]], optional): Diffraction peak dict
            in CXI format (None: no peak exclusion during background subtraction). The keys
            'nPeaks', 'peakXPosRaw' and 'peakYPosRaw' are used. Defaults to None.
        reference (Union[None, Union[np.ndarray, str]], optional): Flat-field reference
            (None: use reference file specified in options). Defaults to None.
        pxmask (Union[None, Union[np.ndarray, str]], optional): Pixel mask reference
            (default: use reference file specified in options). Defaults to None.

    Returns:
        Union[np.ndarray, da.Array]: Corrected image stack of identical dimension as input stack.
    """
    # centers only make sense as a pair
    if (x0 is None) or (y0 is None):
        x0 = y0 = None

    pk_keys = ('nPeaks', 'peakXPosRaw', 'peakYPosRaw')

    if isinstance(img, np.ndarray):
        # take care of numpy image with dask/pandas arguments (just in case)
        def _as_numpy(a):
            if isinstance(a, da.Array):
                return a.compute()
            if isinstance(a, pd.Series):
                return a.values
            return a

        # without peak info, the peak arguments are simply None (no peak exclusion)
        pk_args = [None] * 3 if peakinfo is None else [peakinfo[k] for k in pk_keys]
        innerargs = [_as_numpy(a) for a in [x0, y0] + pk_args]
        return _get_corr_img(img, *innerargs, opts, reference, pxmask)

    reference = imread(opts.reference) if reference is None else reference
    pxmask = imread(opts.pxmask) if pxmask is None else pxmask

    N = img.shape[0]
    stack_chunks = img.chunks[0]

    def _as_dask(a, chunks):
        # each argument is converted on its own, so numpy and dask inputs may be mixed
        if isinstance(a, da.Array):
            return a
        return da.from_array(a.values if isinstance(a, pd.Series) else a, chunks=chunks)

    if x0 is not None:
        x0 = _as_dask(x0, stack_chunks).reshape((N, 1, 1))
        y0 = _as_dask(y0, stack_chunks).reshape((N, 1, 1))

    if peakinfo is None:
        npk = pkx = pky = None
    else:
        # all per-shot arguments get the same (3D) dimensionality and stack chunking as the images
        npk = _as_dask(peakinfo['nPeaks'], stack_chunks).reshape((N, 1, 1))
        pkx = _as_dask(peakinfo['peakXPosRaw'], (stack_chunks, -1)).reshape((N, 1, -1))
        pky = _as_dask(peakinfo['peakYPosRaw'], (stack_chunks, -1)).reshape((N, 1, -1))

    return da.map_blocks(_get_corr_img, img, x0, y0, npk, pkx, pky,
                         reference=reference, pxmask=pxmask, opts=opts,
                         dtype=np.float32 if opts.float else np.int16, chunks=img.chunks)
def analyze_and_correct(imgs: np.ndarray, opts: PreProcOpts,
                        correct_non_hits: bool = False,
                        reference: Union[None, Union[np.ndarray, str]] = None,
                        pxmask: Union[None, Union[np.ndarray, str]] = None) -> Tuple[np.ndarray, dict]:
    """Analyzes a diffraction pattern (centering and peak finding), and immediately applies a correction.

    This function combines `get_pattern_info` and `correct_image`, but works differently in that
    it does not inherently handle any lazy/parallel computations: it only simply loops over a numpy
    array. It is hence especially useful to check if the preprocessing pipeline works on a small
    set, or to embed it into dask delayed objects for parallel execution *outside* the function, which
    may be faster than `get_pattern_info` + `correct_image` (see example below).

    Args:
        imgs (np.ndarray): Input image stack (or single image) as numpy array
        opts (PreProcOpts): pre-processing options
        correct_non_hits (bool, optional): Apply correction also to images that do not have
            sufficient Bragg spots in them (as defined by opts.min_peaks). Defaults to False.
        reference (Union[None, Union[np.ndarray, str]], optional): Reference image as numpy
            array or TIF file name. If None, read file defined in options. Defaults to None.
        pxmask (Union[None, Union[np.ndarray, str]], optional): Pixel mask image as numpy
            array or TIF file name. If None, read file defined in options. Defaults to None.

    Returns:
        Tuple[np.ndarray, dict]: Corrected image stack and pattern info structure, as
        returned by `correct_image` and `get_pattern_info`, respectively. If only some of the
        images are corrected, the stack is returned as float32, with the non-hit images
        left uncorrected.

    Example:
        To run a parallel computation efficiently, use this function like

        >>> results = [dask.delayed(proc2d.analyze_and_correct)(img_chunk, opts)
        ...            for img_chunk in img_stack.to_delayed().ravel()]
        >>> dask.compute(results)
    """
    reference = imread(opts.reference) if reference is None else reference
    pxmask = imread(opts.pxmask) if pxmask is None else pxmask

    info = _generate_pattern_info(imgs, opts, reference=reference, pxmask=pxmask)
    hits = info['num_peaks'] >= opts.min_peaks
    peaks = info['peak_data']

    # np.all/np.any (instead of the builtins) also work if `hits` is a single
    # value, as is the case for a single 2D input image.
    if correct_non_hits or np.all(hits):
        imgs = _get_corr_img(imgs.astype(np.float32), info['center_x'], info['center_y'],
                             peaks['nPeaks'], peaks['peakXPosRaw'], peaks['peakYPosRaw'],
                             opts, reference=reference, pxmask=pxmask)
    else:
        imgs = imgs.astype(np.float32)
        if np.any(hits):
            # correct the hits only, leave the other images as they are
            imgs[hits, ...] = _get_corr_img(imgs[hits, ...],
                                            info['center_x'][hits, ...], info['center_y'][hits, ...],
                                            peaks['nPeaks'][hits, ...],
                                            peaks['peakXPosRaw'][hits, ...],
                                            peaks['peakYPosRaw'][hits, ...],
                                            opts, reference=reference, pxmask=pxmask)

    return imgs, info
#TODO: this might not really belong here -> move to some tools module? Or make private?
def mean_clip(c: np.ndarray, sigma: float = 2.0) -> float:
    """Return the mean of the positive array values after iterative upper clipping.

    Only positive entries are considered. In each iteration, the mean of the remaining
    values is computed, and only values satisfying

        c < c_mean + sigma * sqrt(c_mean)

    are kept. This is repeated until no more values are removed.

    Args:
        c (np.ndarray): input value array
        sigma (float, optional): width of the accepted range above the mean, in units of
            the square root of the mean. Defaults to 2.0.

    Returns:
        float: Mean of clipped values, or -1 if the array contains no positive entries.
    """
    remaining = c[c > 0]
    if remaining.size == 0:
        return -1

    while True:
        c_mean = np.mean(remaining)
        kept = remaining[remaining < c_mean + sigma * np.sqrt(c_mean)]
        if kept.size == remaining.size:
            # nothing was removed in this round: converged
            return c_mean
        remaining = kept
#TODO: as before
def func_lorentz(p: Union[list, tuple, np.ndarray],
                 x: Union[float, np.ndarray],
                 y: Union[float, np.ndarray]) -> Union[float, np.ndarray]:
    """Evaluate a two-dimensional Student's t (generalised Cauchy) profile at (x, y):

        amp * [1 + ((x-x_0)/scale)**2 + ((y-y_0)/scale)**2] ** (-shape/2)

    Args:
        p (Union[list, tuple, np.ndarray]): Parameter array: [amp, x_0, y_0, scale, shape]
        x (Union[float, np.ndarray]): x coordinate(s)
        y (Union[float, np.ndarray]): y coordinate(s)

    Returns:
        Union[float, np.ndarray]: function value at (x, y)
    """
    amp, x_0, y_0, scale, shape = p[0], p[1], p[2], p[3], p[4]
    dx_sq = ((x - x_0) / scale) ** 2.0
    dy_sq = ((y - y_0) / scale) ** 2.0
    return amp * ((1 + dx_sq + dy_sq) ** (-shape / 2.0))
@loop_over_stack
def lorentz_fit(img, amp: float = 1.0,
                x_0: float = 0.0, y_0: float = 0.0,
                scale: float = 5.0, shape: float = 2.0,
                threshold: float = 0):
    """
    Fits a Lorentz-type profile to an image to find the center (x_0, y_0) of a diffraction
    pattern, using only pixels with values above threshold.

    The fit function is `func_lorentz`, i.e.

        amp * [1 + ((x-x_0)/scale)**2 + ((y-y_0)/scale)**2] ** (-shape/2)

    and residuals are weighted by the inverse square root of the pixel values. The fit is
    done by `scipy.optimize.least_squares` with an analytical Jacobian, and with lower bounds
    of 1 for amp, scale and shape, and of 0 for x_0 and y_0.

    Note:
        If possible (i.e. you leave shape at 2.0), do **not** use this function,
        it's really slow. Instead use `lorentz_fast`.

    Args:
        img (np.ndarray): input image (or stack thereof)
        amp (float, optional): initial value for the amplitude. Defaults to 1.0.
        x_0 (float, optional): initial value for the x center. Defaults to 0.0.
        y_0 (float, optional): initial value for the y center. Defaults to 0.0.
        scale (float, optional): initial value for the scale (width). Defaults to 5.0.
        shape (float, optional): initial value for the shape (exponent). Defaults to 2.0.
        threshold (float, optional): only pixels above this value are fitted. Defaults to 0.

    Returns:
        OptimizeResult: result of optimization
    """
    rows_cols = np.where(img > threshold)
    x_px, y_px = rows_cols[1], rows_cols[0]
    start = np.array([amp, x_0, y_0, scale, shape])
    lower_bounds = [1.0, 0.0, 0.0, 1.0, 1.0]

    def residuals(p, x, y, values):
        # deviation of model from data, weighted by 1/sqrt(data)
        return (func_lorentz(p, x, y) - values) / (values ** 0.5)

    def jacobian(p, x, y, values):
        # analytical derivatives of the weighted residuals w.r.t. all five parameters
        radius = ((x - p[1]) / p[3]) ** 2.0 + ((y - p[2]) / p[3]) ** 2.0
        steep = ((1 + radius) ** (-p[4] / 2.0 - 1.0))
        d_amp = ((1 + radius) ** (-p[4] / 2.0))
        d_x0 = (x - p[1]) / (p[3] ** 2.0) * p[0] * p[4] * steep
        d_y0 = (y - p[2]) / (p[3] ** 2.0) * p[0] * p[4] * steep
        d_scale = p[0] * p[4] / p[3] * radius * steep
        d_shape = -0.5 * p[0] * d_amp * np.log(1 + radius)
        return np.transpose([d_amp, d_x0, d_y0, d_scale, d_shape] / (values ** 0.5))

    return optimize.least_squares(residuals, start, jacobian, loss='linear',
                                  max_nfev=1000, args=(x_px, y_px, img[rows_cols]),
                                  bounds=(lower_bounds, np.inf))
@loop_over_stack
def lorentz_fast(img, x_0: float = None, y_0: float = None, amp: float = None,
                 scale: float = 5.0, radius: float = None, limit: float = None,
                 threshold: int = 0, threads: bool = False, verbose: bool = False):
    """Fast Lorentzian fit for finding beam center; especially suited for refinement after a reasonable estimate
    (i.e. to a couple of pixels) has been made by another method such as truncated COM.

    Compared to the other fits, it always assumes a shape parameter 2 (i.e. standard Lorentzian with asymptotic x^-2).
    It can restrict the fit to only a small region around the initial value for the beam center, which massively speeds
    up the function. Also, it auto-estimates the initial parameters somewhat reasonably if nothing else is given.

    Args:
        img (np.ndarray): input image or image stack. If a stack is supplied, it is serially looped.
            Not accepting dask directly.
        x_0 (float, optional): estimated x beam center. If None (or not finite), is assumed to be in the
            center of the image. Defaults to None.
        y_0 (float, optional): analogous. Defaults to None.
        amp (float, optional): estimated peak amplitude. If None, is set to the 99th percentile of the
            (cropped) image. Defaults to None.
        scale (float, optional): peak HWHM estimate in pixels. Defaults to 5.0.
        radius (float, optional): radius of a box around x_0, y_0 where the fit is actually done. If the box
            does not fit into the image, a box around the image center is used instead.
            If None, the entire image is used. Defaults to None.
        limit (float, optional): If not None, the fit result is discarded if the found beam_center is further
            away than this value from the initial estimate. Defaults to None.
        threshold (int, optional): pixel value threshold below which pixels are ignored. Defaults to 0.
        threads (bool, optional): if True, uses scipy.optimize.least_squares, which for larger arrays
            (radius more than around 15) uses multithreaded function evaluation. Especially for radius < 50,
            this may be slower than single-threaded. In this case, best set to False. Defaults to False.
        verbose (bool, optional): if True, a message is printed on some occasions. Defaults to False.

    Returns:
        np.ndarray: numpy array of refined parameters [amp, x0, y0, scale]. If the fit fails or its result
        is out of limit, the initial parameters are returned instead.
    """
    # np.isfinite is False for NaN as well as for +-inf
    if (x_0 is None) or (not np.isfinite(x_0)):
        x_0 = img.shape[1] / 2
    if (y_0 is None) or (not np.isfinite(y_0)):
        y_0 = img.shape[0] / 2

    if radius is not None:
        try:
            x1 = int(x_0 - radius)
            x2 = int(x_0 + radius)
            y1 = int(y_0 - radius)
            y2 = int(y_0 + radius)
        except ValueError:
            print('Weird:', x_0, y_0, radius)
            raise

        if (x1 < 0) or (x2 > img.shape[1]) or (y1 < 0) or (y2 > img.shape[0]):
            print('Cannot cut image around peak. Centering.')
            x1 = int(img.shape[1] / 2 - radius)
            x2 = int(img.shape[1] / 2 + radius)
            y1 = int(img.shape[0] / 2 - radius)
            y2 = int(img.shape[0] / 2 + radius)

        img = img[y1:y2, x1:x2]
    else:
        x1 = 0
        y1 = 0

    if amp is None:
        try:
            amp = np.percentile(img, 99)
        except Exception as err:
            print('Something weird: {} Cannot get image percentile. Img size is {}. Skipping.'.format(err, img.shape))
            return np.array([-1, x_0, y_0, scale])

    # pixel coordinates (in the frame of the full image) and values of all pixels above threshold
    cut = np.where(img > threshold)
    x = cut[1] + x1
    y = cut[0] + y1
    img = img[cut]
    norm = img ** 0.5

    def function(p):
        return p[0] * ((1 + ((x - p[1]) / p[3]) ** 2.0 + ((y - p[2]) / p[3]) ** 2.0) ** (-1))

    def error(p):
        # residuals weighted by 1/sqrt(pixel value)
        return (function(p) - img) / norm

    # The Jacobian is not used anymore, but let's keep it here, just in case
    def jacobian(p):
        radius = ((x - p[1]) / p[3]) ** 2.0 + ((y - p[2]) / p[3]) ** 2.0
        func = ((1 + radius) ** (-2.0))
        d_amp = ((1 + radius) ** (-1.0))
        d_x0 = (x - p[1]) / (p[3] ** 2.0) * p[0] * 2.0 * func
        d_y0 = (y - p[2]) / (p[3] ** 2.0) * p[0] * 2.0 * func
        d_scale = p[0] * 2.0 / p[3] * radius * func
        return np.stack((d_amp / norm, d_x0 / norm, d_y0 / norm, d_scale / norm), axis=-1)

    param = np.array([amp, x_0, y_0, scale])

    try:
        if threads:
            # new algorithm: uses multithreaded evaluation sometimes, which is not always desired!
            out = optimize.least_squares(error, param, loss='linear',
                                         max_nfev=1000, method='lm', verbose=0,
                                         x_scale=(amp, 1, 1, 5), xtol=1e-6).x
        else:
            # old algorithm never uses multithreading. May be better.
            out = optimize.leastsq(error, param, xtol=1e-6)[0]
    except Exception as err:
        # deliberately best-effort: fall back to the initial parameters
        print('Fitting did not work: {} with initial parameters {}'.format(err, param))
        return param

    change = out - param
    if limit and np.abs(change[1:3]).max() >= limit:
        if verbose:
            print('Found out of limit fit result {}. Reverting to init values {}.'
                  .format(out[1:3], param[1:3]))
        out = param

    return out
@loop_over_stack
def center_of_mass(img: np.ndarray, threshold: float = 0.0):
    """
    Compute the intensity-weighted centroid of an image.

    Only pixels whose value is strictly greater than ``threshold`` contribute,
    so just the selected pixels are touched. Fast for sparse images; for more
    crowded ones `center_of_mass2` may be faster.

    Args:
        img (np.ndarray): Input image
        threshold (float, optional): pixels must exceed this value to be included.
            Defaults to 0.0.

    Returns:
        np.ndarray: [x0, y0] -> image center of mass
    """
    selected = np.where(img > threshold)
    weights = img[selected]
    rows, cols = selected
    # Row 0 holds the y (row) indices, row 1 the x (column) indices.
    coords = np.stack((rows, cols))
    y0, x0 = (weights * coords).sum(axis=1) / weights.sum()
    return np.array([x0, y0])
def center_of_mass2(img: np.ndarray, threshold: Optional[float] = None):
    """
    Compute the intensity-weighted centroid of an image or image stack.

    The full pixel grid is contracted against the coordinate grid in one
    tensor product, which can be faster than `center_of_mass` for crowded
    images (just try it out).

    Args:
        img (np.ndarray): Input image, or stack of images along the first axis
        threshold (float, optional): pixels with a value below this are treated
            as zero (values equal to the threshold are kept). If None,
            does not apply a threshold. Defaults to None.

    Returns:
        np.ndarray: array of shape (N, 2) holding [x0, y0] per image; a single
        2D image yields shape (1, 2).
    """
    n_rows, n_cols = img.shape[-2], img.shape[-1]
    xx, yy = np.meshgrid(np.arange(0, n_cols), np.arange(0, n_rows))
    # coords[i, j] == (x, y) of the pixel in row i, column j
    coords = np.stack((xx, yy), axis=-1)
    weights = img if threshold is None else np.where(img >= threshold, img, 0)
    # Contract the two image axes against the two leading coordinate axes.
    moments = np.tensordot(weights, coords, 2)
    totals = weights.sum(axis=(-2, -1)).reshape(-1, 1)
    return moments / totals
@loop_over_stack
def apply_virtual_detector(img: np.ndarray, r_inner: float, r_outer: float,
                           x0: Optional[float] = None, y0: Optional[float] = None) -> float:
    """
    Apply a "virtual STEM detector" to stack, with given inner and outer radii. Returns the mean value of all pixels
    that fall inside this annulus.

    Invalid pixels are excluded from the mean: negative values for integer images,
    non-finite values (nan/inf) for all other images.

    Args:
        img (np.ndarray): input image (or stack thereof)
        r_inner (float): Inner radius (inclusive)
        r_outer (float): Outer radius (exclusive)
        x0 (float): Beam center position along x. If None, assumes center of image. Defaults to None. Should follow
            CXI convention, i.e. relative to pixel center, not corner.
        y0 (float): Similar for y

    Returns:
        float: mean value of pixels inside the annulus defined by r_inner and r_outer
    """
    ysize, xsize = img.shape
    x0 = xsize/2 - 0.5 if x0 is None else x0
    y0 = ysize/2 - 0.5 if y0 is None else y0
    x = np.arange(xsize) - x0
    y = np.arange(ysize) - y0
    X, Y = np.meshgrid(x, y)
    R = np.sqrt(X**2 + Y**2)
    # `img.dtype == np.integer` compares a concrete dtype against an abstract
    # type and does not reliably detect integer images; issubdtype covers
    # every integer width.
    if np.issubdtype(img.dtype, np.integer):
        valid = img >= 0
    else:
        valid = np.isfinite(img)
    mask = ((R < r_outer) & (R >= r_inner)) & valid
    return img[mask].mean()
@loop_over_stack
def get_peaks(img: np.ndarray, x0: float, y0: float, max_peaks: int = 500,
              pxmask: Optional[np.ndarray] = None, min_snr: float = 4., threshold: float = 7.,
              min_pix_count: int = 2, max_pix_count: int = 20, local_bg_radius: int = 3,
              min_res: int = 0, max_res: int = 500,
              as_dict: bool = True, extended_info: bool = False) -> Union[dict, np.ndarray]:
    """Find peaks in diffraction pattern using the peakfinder8 algorithm as used in
    CrystFEL, OnDA and Cheetah. For explanation of the finding parameters, please consult the
    CrystFEL documentation (or just run `man indexamajig`).

    Args:
        img (np.ndarray): image stack
        x0 (float): image stack x center
        y0 (float): image stack y center
        max_peaks (int, optional): maximum number of peaks. Defaults to 500.
        pxmask (Optional[np.ndarray], optional): pixel mask; pixels where the mask is
            non-zero are excluded from the search. Defaults to None.
        min_snr (float, optional): minimum peak SNR. Defaults to 4..
        threshold (float, optional): count threshold. Defaults to 7.
        min_pix_count (int, optional): minimum number of pixels in peak. Defaults to 2.
        max_pix_count (int, optional): maximum number of pixels in peak. Defaults to 20.
        local_bg_radius (int, optional): radius for peak background estimation. Defaults to 3.
        min_res (int, optional): minimum resolution (= radial range) in pixels. Defaults to 0.
        max_res (int, optional): maximum resolution (= radial range) in pixels. Defaults to 500.
        as_dict (bool, optional): return results as a dictionary instead of a
            single numpy array. Defaults to True.
        extended_info (bool, optional): additionally include the entries 'peakIndex',
            'peakNPix', 'PeakMaxIntensity', 'PeakSigma' and 'PeakSNR' in the returned
            dictionary. Has no effect on the array output (as_dict=False). Defaults to False.

    Returns:
        dict: CXI-format peaks information, zero-padded to max_peaks entries.
        If as_dict=False, instead returns a 1d array of size (3 * max_peaks + 1),
        which contains x positions, y positions, intensities,
        and number of peaks concatenated.

    Note:
        The returned peak positions follow CXI convention, that is, they refer to pixel *centers*,
        not corners (as in `CrystFEL`). For `CrystFEL`-convention you have to add 0.5 to the
        returned peak positions.
    """
    from .peakfinder8_extension import peakfinder_8

    n_rows, n_cols = img.shape[0], img.shape[1]
    xx, yy = np.meshgrid(range(n_cols), range(n_rows))
    radius_map = (((xx-x0)**2 + (yy-y0)**2)**.5).astype(np.float32)

    # peakfinder8 mask convention here: 1 = use pixel, 0 = ignore pixel
    if pxmask is None:
        mask = np.ones_like(img, dtype=np.int8)
    else:
        mask = (pxmask == 0).astype(np.int8)
    mask[radius_map > max_res] = 0
    mask[radius_map < min_res] = 0

    pks = peakfinder_8(max_peaks,
                       img.astype(np.float32),
                       mask,
                       radius_map,
                       n_cols,
                       n_rows,
                       1,
                       1,
                       threshold,
                       min_snr,
                       min_pix_count,
                       max_pix_count,
                       local_bg_radius)

    n_found = len(pks[0])
    # zero padding that brings every per-peak list up to max_peaks entries
    fill = [0]*(max_peaks-n_found)

    def padded(values):
        return np.array(values + fill)

    result = {'peakXPosRaw': padded(pks[0]),
              'peakYPosRaw': padded(pks[1]),
              'peakTotalIntensity': padded(pks[2])}
    if extended_info:
        result['peakIndex'] = padded(pks[3])
        result['peakNPix'] = padded(pks[4])
        result['PeakMaxIntensity'] = padded(pks[5])
        result['PeakSigma'] = padded(pks[6])
        result['PeakSNR'] = padded(pks[7])
    # nPeaks is always the last entry
    result['nPeaks'] = np.array(n_found)

    if as_dict:
        return result
    return np.array(pks[0] + fill + pks[1] + fill + pks[2] + fill + [n_found])
@loop_over_stack
def radial_proj(img: np.ndarray, x0: Optional[float] = None, y0: Optional[float] = None,
                scale: float = 1, scale_axis: float = 0,
                my_func: Union[Callable[[np.ndarray], np.ndarray], List[Callable[[np.ndarray], np.ndarray]]] = np.nanmean,
                min_size: int = 600, max_size: int = 850, filter_len: int = 1) -> np.ndarray:
    """
    Applies a function to azimuthal bins of the image around
    the center (x0, y0) for each integer radius and returns the result
    as a radial profile. Skips values that are set to -1 or nan.
    Optionally, a median filter can be applied to the output.

    The profile of each function has between min_size and max_size entries. Entry 0 holds
    the function applied to the pixel closest to the center (if its value is larger than -1);
    radial bins without any valid data are set to -1. If several functions are given,
    their profiles are concatenated in the order of the functions.

    Args:
        img (np.ndarray): input image or stack
        x0 (Optional[float], optional): x center of pattern. Center of image if None. Defaults to None.
        y0 (Optional[float], optional): y center of pattern. Center of image if None. Defaults to None.
        scale (float, optional): ellipticity correction factor. If not 1, pixel coordinates are
            rotated by scale_axis, the rotated x coordinate is multiplied by scale, and the
            coordinates are rotated back before radii are computed. Defaults to 1.
        scale_axis (float, optional): rotation angle (in radians) used for the ellipticity
            correction. Defaults to 0.
        my_func (Union[Callable[[np.ndarray], np.ndarray], List[Callable[[np.ndarray], np.ndarray]]], optional): function
            to call on all pixel values at a given radius, or list/tuple thereof. Defaults to np.nanmean.
        min_size (int, optional): Minimum length of the output profile (per function). Defaults to 600.
        max_size (int, optional): Maximum length of the output profile (per function). Defaults to 850.
        filter_len (int, optional): Kernel size of median filter applied after profile calculation.
            filter_len must be odd, and filtering is at the moment incompatible with multiple functions. Defaults to 1.

    Returns:
        np.ndarray: radial profile calculated using my_func. If the center is nan or outside
        the image, an all-nan array of length max_size (per function) is returned instead.

    Raises:
        ValueError: if filtering is requested together with more than one function, if
            filter_len is even, or if min_size is larger than max_size.

    Note:
        The median filter will currently only work, if a single function is used only! Sorry for that.
    """
    # Normalise to a sequence first, so that the multi-function check below
    # covers lists as well as tuples.
    if not isinstance(my_func, (list, tuple)):
        my_func = [my_func]
    n_func = len(my_func)
    if (n_func > 1) and (filter_len > 1):
        raise ValueError('radial_proj with filtering only works if a single function is used. Sorry.')
    if filter_len//2 == filter_len/2:
        raise ValueError('filter_len must be odd.')

    (ylen, xlen) = img.shape
    (y, x) = np.ogrid[0:ylen, 0:xlen]
    x0 = img.shape[1]/2 - 0.5 if x0 is None else float(x0)
    y0 = img.shape[0]/2 - 0.5 if y0 is None else float(y0)

    # fault tolerance if absurd centers are supplied
    if np.isnan(x0) or np.isnan(y0) or x0 < 0 or x0 >= xlen or y0 < 0 or y0 >= ylen:
        result = np.empty(max_size*n_func)
        result.fill(np.nan)
        return result

    x, y = x - x0, y - y0

    # ellipticity correction
    if scale != 1:
        c, s = np.cos(scale_axis), np.sin(scale_axis)
        x, y = scale*(c*x - s*y), s*x + c*y
        x, y = c*x + s*y, -s*x + c*y

    radius = (np.rint((x**2 + y**2)**0.5)  # radius coordinate of each pixel
              .astype(np.int32))

    # A center within half a pixel of the far edge rounds to xlen/ylen;
    # clamp so the lookup stays inside the image.
    center = img[min(int(np.round(y0)), ylen - 1), min(int(np.round(x0)), xlen - 1)]

    radius[np.where((img == -1) | np.isnan(img))] = 0  # ignore bad pixels by setting radius to zero

    # Sparse matrix with one row per integer radius: row r holds all pixel
    # values at that radius, so mat[r].data is the content of radial bin r.
    row = radius.flatten()
    col = np.arange(len(row))
    mat = sparse.csr_matrix((img.flatten(), (row, col)))

    rng = np.min([1+np.max(radius), max_size])
    size = np.max([rng, min_size])
    # Bounds apply to the length of a single profile; the concatenated result
    # of several functions is legitimately longer than max_size.
    if not (min_size <= size <= max_size):
        raise ValueError('Profile length {} is outside of [min_size, max_size] = [{}, {}].'
                         .format(size, min_size, max_size))

    result = -1 * np.ones(size*n_func)
    fstart = np.arange(0, size*n_func, size)  # start index of each function's profile

    for r in range(1, rng):
        rbin_data = mat[r].data
        if rbin_data.size:
            result[r + fstart] = [fn(rbin_data) for fn in my_func]

    # Bin 0 was used as dump for bad pixels, so the central value is taken
    # directly from the image instead.
    if center > -1:
        result[fstart] = [fn(center) for fn in my_func]

    if filter_len > 1:
        result[filter_len//2:] = median_filter(result, filter_len)[filter_len//2:]

    return result
@loop_over_stack
def cut_peaks(img: np.ndarray, nPeaks: np.ndarray, peakXPosRaw: np.ndarray,
              peakYPosRaw: np.ndarray, radius: int = 2, replaceval: Union[int, float, None] = None) -> np.ndarray:
    """Cuts peaks out of an image and replaces them with replaceval. Peak positions are provided in CXI format.
    This function is mainly interesting for calculation of radial profiles, ignoring Bragg peaks.

    Args:
        img (np.ndarray): Input image (or stack thereof)
        nPeaks (np.ndarray): number of peaks; only the first nPeaks entries of the
            position arrays are used
        peakXPosRaw (np.ndarray): peak X positions
        peakYPosRaw (np.ndarray): peak y positions
        radius (int, optional): Radius of circle within which image values are replaced around each peak.
            Defaults to 2.
        replaceval (Union[int, float, None], optional): Value to paint into the circles. If None,
            uses -1 on integer images and np.nan otherwise. Defaults to None.

    Returns:
        np.ndarray: Image with cut-out peaks.
    """
    if replaceval is None:
        replaceval = -1 if issubclass(img.dtype.type, np.integer) else np.nan

    # asarray also accepts plain Python scalars/sequences
    nPeaks = np.asarray(nPeaks).squeeze()
    peakXPosRaw = np.asarray(peakXPosRaw).squeeze()
    peakYPosRaw = np.asarray(peakYPosRaw).squeeze()

    # Builtin bool instead of the `np.bool` alias, which was removed in NumPy 1.24.
    mask = np.zeros(img.shape, dtype=bool)
    rows = (peakYPosRaw[:nPeaks]).round().astype(int)
    cols = (peakXPosRaw[:nPeaks]).round().astype(int)
    mask[rows, cols] = True

    # Grow each marked peak pixel into a disk of the requested radius.
    mask = binary_dilation(mask, disk(radius), 1)
    img_nopeaks = np.where(mask, replaceval, img)
    return img_nopeaks
@loop_over_stack
def strip_img(img: np.ndarray, prof: np.ndarray,
              x0: Optional[float] = None, y0: Optional[float] = None,
              pxmask: Optional[np.ndarray] = None, truncate: bool = False,
              offset: Union[float, int] = 0, keep_edge_offset: bool = False,
              replaceval: Optional[float] = None, interp: bool = True,
              dtype: Optional[np.dtype] = None) -> np.ndarray:
    """Subtract a radial profile from a diffraction pattern, assuming radial symmetry of the background.

    Args:
        img (np.ndarray): Input image (or stack thereof)
        prof (np.ndarray): Radial profile to be subtracted
        x0 (float, optional): Diffraction pattern center along x. If None, use the image center.
            Defaults to None.
        y0 (float, optional): Diffraction pattern center along y. If None, use the image center.
            Defaults to None.
        pxmask (Optional[np.ndarray], optional): Pixel mask to apply *after* subtraction. Defaults to None.
        truncate (bool, optional): Replace all values below the offset by replaceval. Defaults to False.
        offset (Union[float, int], optional): Offset to apply to the output image. Required if
            you want to keep positive pixel values. Defaults to 0.
        keep_edge_offset (bool, optional): If True, offset is added to the background-subtracted
            image. If False, offset is not added and only serves as the truncation
            threshold. Defaults to False.
        replaceval (Optional[float], optional): Replace value for pixels falling below offset.
            If None, uses nan for floating-point output and -1 otherwise. Defaults to None.
        interp (bool, optional): Interpolate background pixel values, otherwise use nearest
            neighbour. Radii beyond the end of the profile get a background of 0.
            Defaults to True.
        dtype (Optional[np.dtype], optional): If not None, convert output
            image to this data type (rounding first for integer types). Defaults to None.

    Returns:
        np.ndarray: Image with subtracted radial profile. If x0 or y0 is nan, an
        all-zero image is returned.
    """
    x0 = img.shape[1]/2 - 0.5 if x0 is None else float(x0)
    y0 = img.shape[0]/2 - 0.5 if y0 is None else float(y0)
    if np.isnan(x0) or np.isnan(y0):
        # Honour the requested output dtype so that all return paths agree.
        return np.zeros(img.shape) if dtype is None else np.zeros(img.shape, dtype=dtype)

    prof = prof.flatten()  # background profile
    ylen, xlen = img.shape
    y, x = np.ogrid[0:ylen, 0:xlen]

    if interp:
        # Evaluate the profile at the exact (fractional) radius of each pixel.
        iprof = interpolate.interp1d(range(len(prof)), prof, fill_value=0, bounds_error=False)
        radius = ((x-x0)**2 + (y-y0)**2)**0.5
        bkg = iprof(radius)
    else:
        # Nearest-neighbour lookup; the profile is zero-padded or cut to
        # cover the largest radius occurring in the image.
        radius = (np.rint(((x-x0)**2 + (y-y0)**2)**0.5)).astype(np.int32)
        profile = np.zeros(1+np.max(radius))
        comlen = min(len(profile), len(prof))
        np.copyto(profile[:comlen], prof[:comlen])
        bkg = profile[radius]

    img_out = img - bkg + offset if keep_edge_offset else img - bkg
    dtype = img_out.dtype if dtype is None else dtype
    if replaceval is None:
        replaceval = np.nan if np.issubdtype(dtype, np.floating) else -1
    if truncate:
        img_out[img_out < offset] = replaceval
    if pxmask is not None:
        img_out = correct_dead_pixels(img_out, pxmask, 'replace',
                                      replace_val=replaceval, mask_gaps=False)
    if not dtype == img_out.dtype:
        if np.issubdtype(dtype, np.integer):
            img_out = img_out.round()
        img_out = img_out.astype(dtype)
    return img_out
@loop_over_stack
def remove_background(img: np.ndarray, x0: Optional[float] = None, y0: Optional[float] = None,
                      nPeaks: Optional[np.ndarray] = None, peakXPosRaw: Optional[np.ndarray] = None, peakYPosRaw: Optional[np.ndarray] = None,
                      peak_radius=3, filter_len=5, rfunc: Callable[[np.ndarray], np.ndarray] = np.nanmean,
                      pxmask=None, truncate=False, offset=0) -> np.ndarray:
    """Combines `radial_proj`, `cut_peaks` and `strip_img` into a background-removal protocol for diffraction
    patterns, assuming radial symmetry of the background.

    The diffraction pattern is first azimuthally integrated, excluding Bragg peaks, and the resulting radial
    profile is further smoothed. The profile is then re-projected to the full image and subtracted. This procedure
    usually works excellently well - at least, if the peak finding has been done carefully. If there are hard
    issues with peak finding, it might be worth setting rfunc=np.nanmedian.
    Peaks have to be provided in CXI format and convention.

    Args:
        img (np.ndarray): Input image or stack thereof
        x0 (Optional[float], optional): Diffraction pattern center along x. If None, use the image center.
            Defaults to None.
        y0 (Optional[float], optional): Diffraction pattern center along y. If None, use the image center.
            Defaults to None.
        nPeaks (Optional[np.ndarray], optional): Number of peaks. Defaults to None.
        peakXPosRaw (Optional[np.ndarray], optional): peak X positions. Defaults to None.
        peakYPosRaw (Optional[np.ndarray], optional): peak Y positions. Defaults to None.
        peak_radius (int, optional): Radius around each peak excluded from background calculation. Defaults to 3.
        filter_len (int, optional): Range of median filter applied to radial profile. Defaults to 5.
        rfunc (Callable[[np.ndarray], np.ndarray], optional): Function for calculation of the radial profile
            through azimuthal averaging. Defaults to np.nanmean.
        pxmask (np.ndarray, optional): Pixel mask used to exclude pixels from the background
            determination. If None, pixels that are nan or -1 in the input image are masked.
            Defaults to None.
        truncate (bool, optional): Replace all pixels of value < offset in the output image by
            the default replacement value of `strip_img`. Defaults to False.
        offset (int, optional): Offset added to the output image. Defaults to 0.

    Returns:
        np.ndarray: background-subtracted image, of the same data type as the input image
    """
    if np.issubdtype(img.dtype, np.integer) and offset == 0:
        warn('Removing background on an integer image with zero offset will likely cause trouble later on.')

    replace_val = np.nan if np.issubdtype(img.dtype, np.floating) else -1
    x0 = img.shape[1]/2 - .5 if x0 is None else x0
    y0 = img.shape[0]/2 - .5 if y0 is None else y0

    if pxmask is None:
        # `img == np.nan` is always False (nan never compares equal), so nan
        # pixels have to be detected with isnan.
        pxmask = np.isnan(img) | (img == -1)

    if (nPeaks is not None) and (nPeaks > 0):
        img_nopk = cut_peaks(img, nPeaks, peakXPosRaw, peakYPosRaw, radius=peak_radius, replaceval=replace_val)
    else:
        img_nopk = img.copy()

    # ALWAYS mask gaps for the background determination
    # TODO THIS FAILS FOR IMAGES NOT MATCHING THE MAIN DETECTOR GEOMETRY
    img_nopk = correct_dead_pixels(img_nopk, pxmask, mask_gaps=True, strategy='replace', replace_val=replace_val)

    r0 = radial_proj(img_nopk, x0, y0, my_func=rfunc, filter_len=filter_len)

    # don't supply pxmask to strip_image, as it will enforce replace behavior and has no advantage vs. separate
    # correction. The offset is forwarded so that it is actually applied to
    # the output (and used as truncation threshold).
    img_nobg = strip_img(img, prof=r0, x0=x0, y0=y0, pxmask=None, truncate=truncate, offset=offset,
                         keep_edge_offset=True, interp=True, dtype=img.dtype)
    return img_nobg
@jit(['int32[:,:](int32[:,:], float64, float64, int64, int64, int64)',
      'int16[:,:](int16[:,:], float64, float64, int64, int64, int64)',
      'int64[:,:](int64[:,:], float64, float64, int64, int64, int64)',
      'float64[:,:](float64[:,:], float64, float64, int64, int64, float64)',
      'float32[:,:](float32[:,:], float64, float64, int64, int64, float64)'],
     nopython=True, nogil=True)  # ahead-of-time compilation using numba. Otherwise painfully slow.
def _center_sgl_image(img, x0, y0, xsize, ysize, padval):
    """Shifts a *single* image (not applicable to stacks!), such that the original image coordinates x0, y0
    are in the center of the output image, which has a size of xsize, ysize.

    This function is typically used to change diffraction images such that the zero-order beam sits in the
    center of the image. The size of the output image should be sufficiently larger as to not truncate
    the shifted diffraction pattern.

    The shift is done by whole pixels: the window of the input image that maps onto the output
    image is computed, clipped to the input image bounds, and copied into the matching region of
    an output image pre-filled with padval. Regions of the output not covered by the input keep padval.

    Note:
        The coordinates in this function refer to pixel centers (CXI convention), *not* pixel corners
        (CrystFEL convention). I.e., if shifting based on CrystFEL output or similar, the shifts
        must be increased by 0.5.

    Args:
        img (np.ndarray): Input image (2D; the dtype must match one of the signatures in the
            decorator: int16, int32, int64, float32 or float64)
        x0 (float): x position in input image to be shifted to the center of the output image
        y0 (float): y position in input image to be shifted to the center of the output image
        xsize (int): x size of the output image
        ysize (int): y size of the output image
        padval (float or int): value of the pixels used to pad the output image.

    Returns:
        np.ndarray: output image of size (ysize, xsize) with centered diffraction pattern
    """
    # Output canvas, completely filled with padval (cast to the image dtype).
    simg = np.array(padval).astype(img.dtype) * np.ones((ysize, xsize), dtype=img.dtype)
    #int64=np.int64
    #x0 -= 0.5
    # [start, stop) column range of the input image that maps onto the full output width.
    # NOTE(review): the second argument of np.ceil is an output buffer; presumably it is
    # passed explicitly for numba nopython compatibility - confirm.
    xin = np.ceil(np.array([-xsize / 2, xsize / 2]) + x0, np.empty(2)).astype(int64)  # initial coordinate system
    xout = np.array([0, simg.shape[1]], dtype=int64)  # now start constructing the final coordinate system
    # Input window starts left of the image: start reading at column 0 and shift the
    # write position to the right by the same amount.
    if xin[0] < 0:
        xout[0] = -xin[0]
        xin[0] = 0
    # Input window ends right of the image: shorten both ranges by the overhang.
    if xin[1] > img.shape[1]:
        xout[1] = xout[1] - (xin[1] - img.shape[1])
        xin[1] = img.shape[1]
    # Same procedure along y (rows).
    yin = np.ceil(np.array([-ysize / 2, ysize / 2]) + y0, np.empty(2)).astype(int64)
    yout = np.array([0, simg.shape[0]], dtype=int64)
    if yin[0] < 0:
        yout[0] = -yin[0]
        yin[0] = 0
    if yin[1] > img.shape[0]:
        yout[1] = yout[1] - (yin[1] - img.shape[0])
        yin[1] = img.shape[0]
    #print(xin,xout,yin,yout)
    # Copy the clipped input window into the matching region of the output.
    simg[yout[0]:yout[1], xout[0]:xout[1]] = img[yin[0]:yin[1], xin[0]:xin[1]]
    return simg
def center_image(imgs: Union[np.ndarray, da.Array], x0: Union[np.ndarray, da.Array],
                 y0: Union[np.ndarray, da.Array], xsize: int, ysize: int,
                 padval: Union[float, int, None] = None, parallel: bool = True):
    """
    Shifts a stack of images, such that the original image coordinates x0, y0
    are in the center of the output image, which has a size of xsize, ysize.

    This function is typically used to change diffraction images such that the zero-order beam sits in the
    center of the image. The size of the output image should be sufficiently larger as to not truncate
    the shifted diffraction pattern.

    For dask input, the function maps itself over the blocks of the stack; for numpy input,
    the images are shifted one by one. Centers that are nan are replaced by half the image size.
    The supplied x0 and y0 arrays are not modified.

    Note:
        The coordinates in this function refer to pixel centers (CXI convention), *not* pixel corners
        (CrystFEL convention). I.e., if shifting based on CrystFEL output or similar, the shifts
        must be increased by 0.5.

    Args:
        imgs (Union[np.ndarray, da.Array]): Input image stack
        x0 (Union[np.ndarray, da.Array]): x position in input image to be shifted to the center of the output image
        y0 (Union[np.ndarray, da.Array]): y position in input image to be shifted to the center of the output image
        xsize (int): x size of the output image
        ysize (int): y size of the output image
        padval (Union[float, int, None], optional): value of the pixels used to pad the output image.
            If None, use nan for float images and -1 for integer images. Defaults to None.
        parallel (bool, optional): execute operation in parallel. Defaults to True.

    Returns:
        Union[np.ndarray, da.Array]: output image stack of size (ysize, xsize) with centered diffraction patterns
    """
    if padval is None:
        padval = np.nan if not issubclass(imgs.dtype.type, np.integer) else -1
        print('Padding with value ', padval)

    if isinstance(imgs, da.Array):
        # Preprocess arguments and call function again, using map_blocks along the stack direction
        x0 = x0.reshape(-1, 1, 1)
        y0 = y0.reshape(-1, 1, 1)
        if not isinstance(x0, da.Array):
            x0 = da.from_array(x0, (imgs.chunks[0], 1, 1))
        if not isinstance(y0, da.Array):
            y0 = da.from_array(y0, (imgs.chunks[0], 1, 1))
        return imgs.map_blocks(center_image, x0, y0, xsize, ysize, padval,
                               chunks=(imgs.chunks[0], ysize, xsize),
                               dtype=imgs.dtype, parallel=False)

    # condition the input arguments a bit...
    # np.array() copies: reshape alone may return a view, and the nan
    # replacement below would then write into the caller's arrays.
    x0 = np.array(x0, dtype=np.float64).reshape(-1)
    x0[np.isnan(x0)] = imgs.shape[2] / 2
    y0 = np.array(y0, dtype=np.float64).reshape(-1)
    y0[np.isnan(y0)] = imgs.shape[1] / 2

    simgs = np.array(padval).astype(imgs.dtype) * np.ones((imgs.shape[0], ysize, xsize), dtype=imgs.dtype)

    # NOTE(review): prange is used outside of a numba-jitted function here; it
    # presumably behaves like a plain range in that case - confirm whether
    # `parallel` has any effect.
    if parallel:
        it = prange(imgs.shape[0])  # uses numba's prange for parallelization
    else:
        it = range(imgs.shape[0])
    for ii in it:
        simgs[ii, :, :] = _center_sgl_image(imgs[ii, :, :], x0[ii], y0[ii], xsize, ysize, padval)
    return simgs
def apply_saturation_correction(img: np.ndarray, exp_time: float, dead_time: float = 1.9e-3,
                                gap_factor: float = 2):
    """Apply detector correction function to image. Should ideally be done even before flatfield.

    Uses a 5th order polynomial approximation to the Lambert function, which is appropriate
    for a paralyzable detector, up to the point where its signal starts inverting (which is where
    nothing can be done anymore)
    The default dead time value of 1.9 microseconds has been determined for a Medipix3 sensor.

    Args:
        img (np.ndarray): Input image or image stack
        exp_time (float): Exposure time in ms
        dead_time (float, optional): Dead time of detector in ms. Defaults to 1.9e-3.
        gap_factor (float, optional): Factor to scale dead time for gap pixels (as returned
            by gap_pixels()). If 1, all pixels use the same dead time. Defaults to 2.

    Returns:
        np.ndarray: saturation-corrected image
    """
    def lambert_series(x):
        # 5th order series approximation of the Lambert function
        return x - x**2 + 3/2*x**3 - 8/3*x**4 + 125/24*x**5

    if gap_factor != 1:
        # per-pixel dead time: scaled by gap_factor on gap pixels
        dt = dead_time * (1 + (gap_factor-1)*gap_pixels())
    else:
        dt = dead_time

    sat = dt/exp_time  # saturation parameter: dead time/exposure time
    return -lambert_series(-sat*img)/sat
def apply_flatfield(img: Union[np.ndarray, da.Array], reference: Union[np.ndarray, str], keep_type: bool = True,
                    ref_smooth_range: Optional[float] = None,
                    normalize_reference: bool = False) -> Union[np.ndarray, da.Array]:
    """Corrects the detector response by dividing the images in the image (stack) by a reference
    image (gain reference image), which should vary around 1.

    Args:
        img (Union[np.ndarray, da.Array]): Input image
        reference (Union[np.ndarray, str]): array containing the reference image, or filename of
            a TIF file containing the reference image
        keep_type (bool, optional): Keep the image data type, that is, round the pixel values
            back to integers if the input is an integer image. If False, the output image
            will always be a float. Defaults to True.
        ref_smooth_range (Optional[float], optional): If not None, applies a Gaussian blur to the
            reference image before correction, use this parameter to set its width. Defaults to None.
        normalize_reference (bool, optional): Re-normalize the reference image such that its
            average value is exactly 1. Defaults to False.

    Returns:
        np.ndarray: flatfield-corrected image

    Raises:
        TypeError: if reference is neither a numpy array nor a string.
    """
    if isinstance(reference, str):
        reference = imread(reference).astype(np.float32)
    elif isinstance(reference, np.ndarray):
        reference = reference.astype(np.float32)
    else:
        raise TypeError('reference must be either numpy array or TIF filename')

    if normalize_reference:
        reference = reference/np.nanmean(reference)

    if ref_smooth_range is not None:
        reference = convolve(reference, Gaussian2DKernel(ref_smooth_range),
                             boundary='extend', nan_treatment='interpolate')

    if len(img.shape) > 2:
        # add a leading axis so the reference broadcasts over the stack
        reference = reference.reshape((1, reference.shape[-2], reference.shape[-1]))

    corrected = img/reference
    if not keep_type:
        return corrected
    if np.issubdtype(img.dtype, np.integer):
        # astype() alone truncates towards zero; round first, as documented.
        corrected = corrected.round()
    return corrected.astype(img.dtype)
[docs]def correct_dead_pixels(img: Union[np.ndarray, da.Array], pxmask: Union[np.ndarray, str],
strategy: str = 'interpolate',
interp_range: int = 1, replace_val: Union[float, int] = None,
mask_gaps: bool = False, edge_mask_x: Union[int, Tuple] = (100, 30),
edge_mask_y: Union[int, Tuple] = 0, invert_mask: bool = False) -> np.ndarray:
"""Corrects a set of images for dead pixels by either replacing values with a
constant, or interpolation from a Gaussian-smoothed version of the image. It
requires a binary array (pxmask) which is 1 (or 255 or True) for dead pixels.
The function accepts a 3D array where the first dimension corresponds to a stack/movie.
Args:
img (np.ndarray): the image or image stack (first dimension is stack).
For strategy=='replace' it can be a dask or numpy array, otherwise numpy only.
pxmask (Union[np.ndarray, str]): pixel mask with values as described above, or name of
a TIF file containing the pixel mask
strategy (str, optional): 'interpolate' or 'replace'. Defaults to 'interpolate'.
interp_range (int, optional): range of interpolation for 'interpolate' strategy, in pixels.
Defaults to 1.
replace_val (Union[float, int], optional): replacement value for 'replace' strategy. If None, use
-1 for integer images and nan for float images. Defaults to None.
mask_gaps (bool, optional): mask gaps between detector panels as returned by the gap_pixels() function.
Defaults to False.
edge_mask_x (int, optional): Declare this number of pixels near the edges along x as
invalid and replace them with replaceval. Defaults to 70.
edge_mask_y (int, optional): Declare this number of pixels near the edges along y as
invalid and replace them with replaceval. Defaults to 0.
invert_mask (bool, optional): invert the pixel mask, i.e., invalid pixels are zero/False. Defaults to False.
Returns:
np.ndarray: dead-pixel corrected image. Can be da.Array for 'replace' strategy.
"""
assert strategy in ('interpolate', 'replace')
if replace_val is None:
replace_val = -1 if isinstance(img, np.integer) else np.nan
if isinstance(pxmask, str):
pxmask = imread(pxmask)
elif isinstance(pxmask, np.ndarray) or isinstance(pxmask, da.Array):
pxmask = pxmask.astype(np.bool)
if invert_mask:
pxmask = np.logical_not(pxmask)
else:
raise TypeError('pxmask must be either Numpy array, or TIF file name')
if mask_gaps:
pxmask[gap_pixels()] = True
if edge_mask_x:
if isinstance(edge_mask_x, int):
rng = (edge_mask_x, edge_mask_x)
else:
rng = edge_mask_x
# print(rng)
pxmask[:, :rng[0]] = True
pxmask[:, -rng[1]:] = True
if edge_mask_y:
if isinstance(edge_mask_y, int):
rng = (edge_mask_y, edge_mask_y)
else:
rng = edge_mask_y
# print(rng)
pxmask[:rng[0],:] = True
pxmask[-rng[1]:,:] = True
if strategy == 'interpolate':
if (img.ndim > 2) and strategy == 'interpolate':
return np.stack([correct_dead_pixels(theImg, pxmask=pxmask,
strategy='interpolate', interp_range=interp_range,
replace_val=replace_val) for theImg in img])
kernel = Gaussian2DKernel(interp_range)
with catch_warnings():
simplefilter("ignore")
img_flt = convolve(img.astype(float), kernel, boundary='extend',
nan_treatment='interpolate', mask=pxmask)
if isinstance(img, np.integer):
img_flt = np.nan_to_num(img_flt, copy=False, nan=-1).astype(np.int32)
img_out = np.where(pxmask, img_flt, img)
return img_out
elif strategy == 'replace':
if isinstance(img, np.ndarray):
if img.ndim > 2:
# putmask does not support broadcasting
np.putmask(img, np.broadcast_to(pxmask, img.shape),
replace_val)
else:
np.putmask(img, pxmask, replace_val)
return img
elif isinstance(img, da.Array):
#dask arrays are immutable. This requires a slightly different way
sz = pxmask.shape
pml = da.from_array(pxmask.reshape(1, sz[-2], sz[-1]),
chunks=(1, sz[-2], sz[-1]))
pml = da.broadcast_to(pml, img.shape, chunks=img.chunks)
return da.where(pml, replace_val, img)