Source code for pypeit.core.flexure

""" Module for flexure routines

.. include common links, assuming primary doc root is up one directory
.. include:: ../include/links.rst

"""
import copy
import inspect

from astropy.stats import sigma_clipped_stats
from IPython import embed
from matplotlib import gridspec
from matplotlib import pyplot as plt
from matplotlib.lines import Line2D
import numpy as np
from scipy import interpolate
from scipy import optimize
from scipy import signal

from pypeit import log
from pypeit import onespec
from pypeit import PypeItError
from pypeit import wavemodel
from pypeit.core import arc
from pypeit.core import extract
from pypeit.core import fitting
from pypeit.core import parse
from pypeit.core import qa
from pypeit.core import skyspec
from pypeit.core import trace
from pypeit.core.wavecal import autoid


def spat_flexure_shift(sciimg, slits, bpm=None, maxlag=20, sigdetect=10., debug=False,
                       qa_outfile=None, qa_vrange=None):
    """
    Calculate a rigid flexure shift in the spatial dimension
    between the slitmask and the science image.

    It is *important* to use the initial slits (``initial=True``) when
    defining the slitmask, as everything should be relative to the initial
    slits. Otherwise, the WaveTilts could get out of sync with science images.

    The shift is measured by cross-correlating (along the spatial axis) the
    edge images of the science frame and of the slit mask, and locating the
    highest peak within ``maxlag`` pixels of zero lag.

    Args:
        sciimg (`numpy.ndarray`_):
            Science image
        slits (:class:`pypeit.slittrace.SlitTraceSet`):
            Slits object
        bpm (`numpy.ndarray`_, optional):
            Bad pixel mask (True = Bad). It is copied, never modified in place.
        maxlag (:obj:`int`, optional):
            Maximum flexure searched for (pixels)
        sigdetect (:obj:`float`, optional):
            Sigma threshold above fluctuations for the slit detection in the
            collapsed sobel image
        debug (:obj:`bool`, optional):
            Run in debug mode (shows the 1D cross-correlation and the 2D QA)
        qa_outfile (:obj:`str`, optional):
            Path to the output file where the QA is saved. If None, the QA is
            not generated.
        qa_vrange (:obj:`tuple`, optional):
            Tuple with the vmin and vmax values for the imshow plot in the QA.
            If None, the vmin and vmax values are calculated from the data.

    Returns:
        float: The spatial flexure shift relative to the initial slits. 0. is
        returned if no peak is found in the cross-correlation.
    """
    log.info("Measuring spatial flexure")

    # Mask -- Includes short slits and those excluded by the user (e.g. ['rdx']['slitspatnum'])
    slitmask = slits.slit_img(initial=True, exclude_flag=slits.bitmask.exclude_for_flexure)

    # Bring the science image to the shape of the slit mask, if needed
    _sciimg = sciimg if slitmask.shape == sciimg.shape \
        else arc.resize_mask2arc(slitmask.shape, sciimg)

    # Mask (as much as possible) the objects on the slits to help the cross-correlation.
    # Need to copy the bpm to avoid changing the input bpm.
    _bpm = np.zeros_like(_sciimg, dtype=int) if bpm is None else copy.deepcopy(bpm)
    for i in range(slits.nslits):
        left_edge = np.round(slits.left_init[:, i]).astype(int)
        right_edge = np.round(slits.right_init[:, i]).astype(int)
        for j in range(_sciimg.shape[0]):
            # mask the region between the left and right edges leaving a margin of maxlag pixels
            if left_edge[j] + maxlag < right_edge[j] - maxlag:
                _bpm[j, left_edge[j] + maxlag:right_edge[j] - maxlag] = 1

    # Create sobel/edge images of both the science image and the slitmask;
    # only the edge images are used below.
    _, sci_edges = trace.detect_slit_edges(_sciimg, bpm=_bpm, sigdetect=sigdetect)
    _, slits_edges = trace.detect_slit_edges(slitmask, bpm=bpm, sigdetect=1.)

    # Cross-correlate each row along the spatial axis, then collapse along the spectral axis
    corr = signal.fftconvolve(sci_edges, np.fliplr(slits_edges), mode='same', axes=1)
    xcorr = np.sum(corr, axis=0)

    lags = signal.correlation_lags(sci_edges.shape[1], slits_edges.shape[1], mode='same')
    lag0 = np.where(lags == 0)[0][0]
    # Restrict the search to +/- maxlag around zero lag. Clamp the start so that a
    # negative index can never wrap around to the other end of the array.
    lo = max(lag0 - maxlag, 0)
    xcorr_max = xcorr[lo:lag0 + maxlag]
    lags_max = lags[lo:lag0 + maxlag]

    # detect the highest peak in the cross-correlation
    _, _, pix_max, _, _, _, _, _ = arc.detect_lines(xcorr_max, cont_subtract=False,
                                                    input_thresh=0., nfind=1, debug=debug)

    # No peak? -- e.g. data fills the entire detector
    if (len(pix_max) == 0) or pix_max[0] == -999.0:
        log.warning(
            'No peak found in the x-correlation between the traced slits and the science/calib '
            'image. Assuming there is NO SPATIAL FLEXURE.\nIf a flexure is expected, consider '
            'either changing the maximum lag for the cross-correlation, or the '
            '"spat_flexure_sigdetect" parameter, or use the manual flexure correction.'
        )
        return 0.

    lag0_max = np.where(lags_max == 0)[0][0]
    shift = round(pix_max[0] - lag0_max, 3)
    log.info('Spatial flexure measured: {}'.format(shift))

    # Good pixel mask for the QA. np.logical_not(None) would return a scalar,
    # which the QA cannot slice, so pass None when no bpm was provided.
    gpm = None if bpm is None else np.logical_not(bpm)

    if debug:
        # 1D plot of the cross-correlation
        plt.figure(figsize=(10, 6))
        plt.minorticks_on()
        plt.tick_params(axis='both', direction='in', top=True, right=True, which='both')
        # plot the searched cross-correlation with a buffer of 20 pixels on each side
        pad = 20
        plo = max(lag0 - (maxlag + pad), 0)
        phi = lag0 + (maxlag + pad)
        plt.plot(lags[plo:phi], xcorr[plo:phi], 'k-', lw=1)
        plt.axvline(shift, color='r', linestyle='--', label=f'Measured shift = {shift:.1f} pixels')
        plt.axvline(maxlag, color='g', linestyle='--', label='Max lag')
        plt.axvline(-maxlag, color='g', linestyle='--')
        plt.xlabel('Lag (pixels)')
        plt.ylabel('Cross-correlation')
        plt.title('Spatial Flexure Cross-correlation')
        plt.legend()
        plt.tight_layout()
        plt.show()

        # 2D plot
        spat_flexure_qa(sciimg, slits, shift, gpm=gpm, vrange=qa_vrange)

    if qa_outfile is not None:
        # Generate the QA plot
        log.info("Generating QA plot for spatial flexure")
        spat_flexure_qa(sciimg, slits, shift, gpm=gpm, vrange=qa_vrange, outfile=qa_outfile)

    return shift
def spat_flexure_qa(img, slits, shift, gpm=None, vrange=None, outfile=None):
    """
    Generate QA for the spatial flexure

    The initial slit edges (dashed lines) and the same edges shifted by
    ``shift`` (solid lines) are over-plotted on ``img``. If ``outfile`` is
    None, the whole detector is shown on screen. Otherwise, a grid of zoomed-in
    snippets is saved to ``outfile``. The grid has two rows (the upper and lower
    ends of the detector in the spectral direction) and 2-4 spatial ranges,
    depending on the number of slits.

    Args:
        img (`numpy.ndarray`_):
            Image of the detector
        slits (:class:`pypeit.slittrace.SlitTraceSet`):
            Slits object
        shift (:obj:`float`):
            Shift in pixels
        gpm (`numpy.ndarray`_, optional):
            Good pixel mask (True = good). Only used to compute the display
            range when ``vrange`` is None. It is cast to boolean and broadcast
            to the shape of ``img``, so a scalar is also accepted.
        vrange (:obj:`tuple`, optional):
            Tuple with the min and max values for the imshow plot
        outfile (:obj:`str`, optional):
            Path to the output file where the QA is saved. If None, the QA is
            shown on screen and not saved.
    """
    debug = outfile is None

    # check that vrange is a tuple
    if vrange is not None and not isinstance(vrange, tuple):
        log.warning('vrange must be a tuple with the min and max values for the imshow plot. Ignoring vrange.')
        vrange = None

    # Accept scalar / non-boolean masks: slicing below requires a 2D boolean array
    if gpm is not None:
        gpm = np.broadcast_to(np.asarray(gpm, dtype=bool), img.shape)

    # TODO: should we use initial or tweaked slits in this plot?
    left_slits, right_slits, _ = slits.select_edges(initial=True, flexure=None)
    left_flex, right_flex, _ = slits.select_edges(initial=True, flexure=shift)

    if debug:
        # show the whole detector in a single panel
        nxsnip = 1
        spat_starts = [0]
        spat_ends = [img.shape[1]]
        upper_ystart = 0
        upper_yend = img.shape[0]
    else:
        # where to start and end the plot in the spatial direction
        xstart = int(np.floor(np.min([left_slits, left_flex]) - 20))
        xend = int(np.ceil(np.max([right_slits, right_flex]) + 20))
        # how many snippets to plot in the spatial direction
        if slits.nslits == 1:
            # if longslit plot 2 snippets, one for the left edge and one for the right edge
            nxsnip = 2
            snippet = int((xend - xstart) // nxsnip)
            spat_starts = [xstart, xstart + snippet]
            spat_ends = [xend - snippet, xend]
        elif slits.nslits <= 12:
            # if 12 or less slits plot 3-4 snippets equally spaced
            nxsnip = 3 if slits.nslits <= 6 else 4
            snippet = int((xend - xstart) // nxsnip)
            spat_starts = [xstart, xstart + snippet, xstart + 2*snippet]
            spat_ends = [xend - 2*snippet, xend - snippet, xend]
            if slits.nslits > 6:
                # add the 4th snippet
                spat_starts.append(xstart + 3*snippet)
                spat_ends.insert(0, xend - 3*snippet)
        else:
            # if more than 12 slits plot 4 snippets
            nxsnip = 4
            # approximately, we want 3 slits in each snippet
            snippet = int(3 * (xend - xstart) / slits.nslits)
            # this would give nx many snippets
            nx = int((xend - xstart) // snippet)
            # but we want to plot only nxsnip of those snippets
            spat_starts = [xstart + i * snippet for i in np.linspace(0, nx - 1, nxsnip, dtype=int)]
            spat_ends = [xstart + i * snippet for i in np.linspace(1, nx, nxsnip, dtype=int)]

        # where to start and end the plot in the spectral direction for both the upper and lower sections
        lower_ystart = 0
        lower_yend = int(snippet)
        upper_ystart = int(img.shape[0] - snippet)
        upper_yend = img.shape[0]

    # plot the spatial flexure
    rows = 1 if debug else 2
    fig = plt.figure(figsize=(9, 8) if debug else (nxsnip*4, 8))
    gs = gridspec.GridSpec(rows, nxsnip, figure=fig)

    # spectral vector for plotting the slits
    spec = np.tile(np.arange(slits.nspec), (slits.nslits, 1)).T
    # plot only every `thin`-th point of each edge
    thin = 10

    # legend elements
    legend_elements = [Line2D([0], [0], color='C3', lw=1, ls='--', label='initial left edges'),
                       Line2D([0], [0], color='C1', lw=1, ls='--', label='initial right edges'),
                       Line2D([0], [0], color='C3', lw=1, label='shifted left edges'),
                       Line2D([0], [0], color='C1', lw=1, label='shifted right edges')]

    # loop over the 2 rows if we save the plot in the output directory, otherwise plot the whole detector
    for r in range(rows):
        ystart, yend = (upper_ystart, upper_yend) if r == 0 else (lower_ystart, lower_yend)
        # loop over the snippets
        for s in range(nxsnip):
            ax = fig.add_subplot(gs[r, s])
            if vrange is None:
                # get vmin and vmax for imshow from the pixels visible in this snippet
                _xstart = max(spat_starts[s], 0)
                _xend = min(spat_ends[s], img.shape[1])
                _img = img[ystart:yend, _xstart:_xend]
                _gpm = gpm[ystart:yend, _xstart:_xend] if gpm is not None \
                    else np.ones_like(_img, dtype=bool)
                good_pix = _img[_gpm]
                if good_pix.size > 0:
                    m, _, sig = sigma_clipped_stats(good_pix, sigma_lower=5.0, sigma_upper=5.0)
                    vmin = m - 1.0 * sig
                    vmax = m + 4.0 * sig
                else:
                    # no usable pixels in this snippet: let matplotlib autoscale
                    vmin, vmax = None, None
            else:
                vmin, vmax = vrange
            # imshow img instead of _img to show the actual pixel values in each snippet
            ax.imshow(img, origin='lower', vmin=vmin, vmax=vmax)
            ax.set_ylim(ystart, yend)
            ax.set_xlim(spat_starts[s], spat_ends[s])
            # plot the slits
            for i in range(slits.nslits):
                ax.plot(left_slits[::thin, i], spec[::thin, i], color='C3', lw=1, ls='--', zorder=5)
                ax.plot(right_slits[::thin, i], spec[::thin, i], color='C1', lw=1, ls='--', zorder=5)
                ax.plot(left_flex[::thin, i], spec[::thin, i], color='C3', lw=1, zorder=6)
                ax.plot(right_flex[::thin, i], spec[::thin, i], color='C1', lw=1, zorder=6)
            ax.tick_params(axis='both', labelsize=6)
            if r == 0 and s == 0:
                fig.suptitle(f'Shift={shift:.1f} pixels', fontsize=18)
                ax.legend(handles=legend_elements, fontsize=7)
                if not debug:
                    ax.set_ylabel('Upper snippets', fontsize=18)
            elif r == 1 and s == 0:
                ax.set_ylabel('Lower snippets', fontsize=18)
    fig.tight_layout()
    if debug:
        plt.show()
    else:
        fig.savefig(outfile, dpi=200)
        plt.close(fig)
def spec_flex_shift(obj_skyspec, sky_file=None, arx_skyspec=None, arx_fwhm_pix=None,
                    spec_fwhm_pix=None, mxshft=20, excess_shft="crash", method="boxcar",
                    minwave=None, maxwave=None):
    """
    Calculate shift between object sky spectrum and archive sky spectrum

    Args:
        obj_skyspec (:class:`~pypeit.onespec.OneSpec`):
            Spectrum of the sky related to our object
        sky_file (:obj:`str`, optional):
            Name of the archival sky file. If equal to 'model', instead, a
            model sky spectrum will be generated using
            :func:`~pypeit.wavemodel.nearIR_modelsky` and the spectral
            resolution of obj_skyspec. If None, arx_skyspec and arx_fwhm_pix
            must be provided.
        arx_skyspec (:class:`~pypeit.onespec.OneSpec`, optional):
            Archived sky spectrum. If None, it will be loaded from the
            sky_file (sky_file must be provided).
        arx_fwhm_pix (:obj:`float`, optional):
            Spectral FWHM (in pixels) of the archived sky spectrum. If None, it
            will be calculated using sky_file or measured from arx_skyspec.
        spec_fwhm_pix (:obj:`float`, optional):
            Spectral FWHM (in pixels) of the sky spectrum related to our
            object/slit.
        mxshft (:obj:`int`, optional):
            Maximum allowed shift from flexure; note there are cases that have
            been known to exceed even 30 pixels.
        excess_shft (:obj:`str`, optional):
            Behavior of the code when a measured flexure exceeds ``mxshft``.
            Options are "crash", "set_to_zero", "continue", and "use_median".
            "set_to_zero" sets the shift to zero and moves on, "continue"
            simply uses the large flexure shift value, and "use_median"
            returns None so that a shift from another slit/object can be used.
        method (:obj:`str`, optional):
            Which method is used for the spectral flexure correction. Two
            methods are available: 'boxcar' and 'slitcen' (see
            spec_flexure_slit()). In this routine, 'method' is only passed to
            the final dict.
        minwave (:obj:`float`, optional):
            Minimum wavelength to use for the correlation. If ``None`` or less
            than the minimum wavelength of either ``obj_skyspec`` or
            ``arx_skyspec``, this has no effect. Default is None.
        maxwave (:obj:`float`, optional):
            Maximum wavelength to use for the correlation. If ``None`` or
            greater than the maximum wavelength of either ``obj_skyspec`` or
            ``arx_skyspec``, this has no effect. Default is None.

    Returns:
        dict or None: Contains flexure info. Keys are:

            - polyfit= fit to the cross-correlation
            - shift= best shift in pixels
            - subpix= subpixelation of input spectrum
            - corr= correlation function
            - sky_spec= object sky spectrum used (rebinned, etc.)
            - arx_spec= archived sky spectrum used
            - corr_cen= center of the correlation function
            - smooth= Degree of smoothing of input spectrum to match archive
            - method= the input ``method``

        None is returned when the shift cannot be measured (spectral FWHM
        unavailable, insufficient wavelength overlap, bad normalization, failed
        or non-finite fit), or when the shift exceeds ``mxshft`` and
        ``excess_shft`` is "use_median".

    Raises:
        PypeItError: If neither ``sky_file`` nor ``arx_skyspec`` is provided,
            if the archive FWHM cannot be measured, if the shift exceeds
            ``mxshft`` with ``excess_shft="crash"``, or if ``excess_shft``
            is not recognized.
    """
    # TODO None of these routines should have dependencies on XSpectrum1d!

    # Check input mode
    if sky_file is None and arx_skyspec is None:
        raise PypeItError("sky_file or arx_skyspec must be provided")
    elif sky_file is not None and arx_skyspec is not None:
        log.warning("sky_file and arx_skyspec both provided. Using arx_skyspec.")
        sky_file = None

    # Arxiv sky spectrum
    if sky_file is not None:
        # Load arxiv sky spectrum
        log.info("Loading the arxiv sky spectrum and computing its spectral FWHM")
        arx_skyspec, arx_fwhm_pix = get_archive_spectrum(sky_file, obj_skyspec=obj_skyspec,
                                                         spec_fwhm_pix=spec_fwhm_pix)
    elif arx_fwhm_pix is None:
        # get arxiv sky spectrum resolution (FWHM in pixels)
        log.info("Computing the spectral FWHM for the provided arxiv sky spectrum")
        arx_fwhm_pix = autoid.measure_fwhm(arx_skyspec.flux, sigdetect=4., fwhm=4.)
        if arx_fwhm_pix is None:
            raise PypeItError('Failed to measure the spectral FWHM of the archived sky spectrum. '
                              'Not enough sky lines detected. Provide a value using arx_fwhm_pix')

    # initialize smooth_fwhm_pix
    smooth_fwhm_pix = None

    # smooth to the same resolution as the object sky spectrum? Yes, if not using a model sky
    if sky_file != 'model':
        # get gaussian FWHM (pixels) for smoothing
        smooth_fwhm_pix = get_fwhm_gauss_smooth(arx_skyspec, obj_skyspec, arx_fwhm_pix,
                                                spec_fwhm_pix=spec_fwhm_pix)
        if smooth_fwhm_pix is None:
            # smooth_fwhm_pix is None if spec_fwhm_pix<0, i.e., the wavelength calibration is bad
            log.warning('No flexure correction could be computed for this slit/object')
            return None

        if smooth_fwhm_pix > 0:
            arx_skyspec = arx_skyspec.gauss_smooth(smooth_fwhm_pix)

    # Determine region of wavelength overlap
    minwave = 0 if minwave is None else minwave
    maxwave = np.inf if maxwave is None else maxwave
    min_wave = max(np.amin(arx_skyspec.wave), np.amin(obj_skyspec.wave), minwave)
    max_wave = min(np.amax(arx_skyspec.wave), np.amax(obj_skyspec.wave), maxwave)

    # Define wavelengths of overlapping spectra
    keep_idx = np.where((obj_skyspec.wave >= min_wave) & (obj_skyspec.wave <= max_wave))[0]

    # Rebin both spectra onto overlapped wavelength range
    if len(keep_idx) <= 50:
        log.warning("Not enough overlap between sky spectra")
        return None

    # rebin onto object ALWAYS
    keep_wave = obj_skyspec.wave[keep_idx]
    arx_skyspec = arx_skyspec.rebin(keep_wave)
    obj_skyspec = obj_skyspec.rebin(keep_wave)
    # Deal with bad pixels
    log.debug("Need to mask bad pixels")
    # Trim edges (rebinning is junk there)
    arx_skyspec.flux[:2] = 0.
    arx_skyspec.flux[-2:] = 0.
    obj_skyspec.flux[:2] = 0.
    obj_skyspec.flux[-2:] = 0.

    # Set minimum to 0. For bad rebinning and for pernicious extractions
    obj_skyspec.flux = np.clip(obj_skyspec.flux, a_min=0., a_max=None)
    arx_skyspec.flux = np.clip(arx_skyspec.flux, a_min=0., a_max=None)

    # clip too large values (>90%) only in obj_skyspec (assuming arx_skyspec is being vetted before)
    # this is used only for the cross-correlation
    _lower, _upper = get_percentile_clipping(obj_skyspec.flux, percent=90.0)
    obj_skyspec_flux = np.clip(obj_skyspec.flux, _lower, _upper)

    # Normalize spectra to unit average sky count
    norm = np.sum(obj_skyspec_flux) / obj_skyspec.npix
    norm2 = np.sum(arx_skyspec.flux) / arx_skyspec.npix
    if norm <= 0:
        log.warning("Bad normalization of object in flexure algorithm")
        log.warning("Will try the median")
        norm = np.median(obj_skyspec_flux)
        if norm <= 0:
            log.warning("Improper sky spectrum for flexure. Is it too faint??")
            return None
    if norm2 <= 0:
        log.warning('Bad normalization of archive in flexure. You are probably using wavelengths '
                    'well beyond the archive.')
        return None
    obj_skyspec_flux = obj_skyspec_flux / norm
    arx_skyspec.flux = arx_skyspec.flux / norm2

    # Subtract continuum and apply a ceiling to the spectra
    percent_ceil = 50.
    # obj_skyspec
    _, obj_ampl, _, _, _, _, obj_sky_flux, _ = arc.detect_lines(obj_skyspec_flux, sigdetect=5.0)
    if obj_ampl.size > 0:
        obj_lower, obj_upper = get_percentile_clipping(obj_ampl, percent=percent_ceil)
        obj_sky_flux = np.clip(obj_sky_flux, obj_lower, obj_upper)
    # arx_skyspec
    _, arx_ampl, _, _, _, _, arx_sky_flux, _ = arc.detect_lines(arx_skyspec.flux, sigdetect=5.0)
    if arx_ampl.size > 0:
        arx_lower, arx_upper = get_percentile_clipping(arx_ampl, percent=percent_ceil)
        arx_sky_flux = np.clip(arx_sky_flux, arx_lower, arx_upper)

    # Consider sharpness filtering (e.g. LowRedux)
    log.debug("Consider taking median first [5 pixel]")

    # Cross correlation of spectra
    corr = np.correlate(arx_sky_flux, obj_sky_flux, "same")

    # Create array around the max of the correlation function for fitting for subpixel max.
    # Restrict to pixels within mxshft of zero lag; clamp the window to the array so that
    # a negative start can never wrap around to the other end.
    lag0 = corr.size // 2
    lo = max(lag0 - int(mxshft), 0)
    hi = min(lag0 + int(mxshft), corr.size)
    max_corr = np.argmax(corr[lo:hi]) + lo
    subpix_grid = np.linspace(max_corr - 3., max_corr + 3., 7)
    subpix_idx = subpix_grid.astype(int)

    # The 7-point grid must lie inside the correlation function (negative indices would wrap)
    if subpix_idx[0] < 0 or subpix_idx[-1] >= corr.size:
        log.warning('Flexure compensation failed for one of your objects')
        return None

    corr_grid = corr[subpix_idx]
    # Fit a 2-degree polynomial to peak of correlation function. JFH added this check to not
    # crash for bad slits. Every sample must be finite, otherwise the fitted peak is NaN.
    if not np.all(np.isfinite(corr_grid)):
        log.warning('Flexure compensation failed for one of your objects')
        return None

    fit = fitting.PypeItFit(xval=subpix_grid, yval=corr_grid,
                            func='polynomial', order=np.atleast_1d(2))
    fit.fit()
    max_fit = -0.5 * fit.fitc[1] / fit.fitc[2]
    shift = float(max_fit) - lag0
    # A flat/degenerate peak (zero curvature) gives a non-finite vertex
    if not np.isfinite(shift):
        log.warning('Flexure compensation failed for one of your objects')
        return None

    # Deal with the case of shifts greater than ``mxshft``
    # We need to compare the absolute value of shift to ``mxshft``, since shift can be
    # positive or negative, while ``mxshft`` is generally only positive
    # We use the int of abs(shift) to avoid to trigger the error/warning for differences <1pixel
    # TODO :: I'm not convinced that we need int here...
    if int(abs(shift)) > mxshft:
        log.warning(f"Computed shift {shift:.1f} pix is "
                    f"larger than specified maximum {mxshft} pix.")

        if excess_shft == "crash":
            raise PypeItError(
                "Flexure compensation failed for one of your\n"
                "objects. Either adjust the \"spec_maxshift\"\n"
                "FlexurePar Keyword, or see the flexure documentation\n"
                "for information on how to bypass this error using the\n"
                "\"excessive_shift\" keyword.\n"
                "https://pypeit.readthedocs.io/en/release/flexure.html"
            )
        elif excess_shft == "set_to_zero":
            log.warning("Flexure compensation failed for one of your objects.")
            log.warning("Setting the flexure correction shift to 0 pixels.")
            # Return the usual dictionary, but with a shift == 0
            shift = 0.0
        elif excess_shft == "continue":
            log.warning("Applying flexure shift larger than specified max!")
        elif excess_shft == "use_median":
            log.warning("Will try to use a flexure shift from other slit/object. "
                        "If not available, flexure correction will not be applied.")
            return None
        else:
            raise PypeItError(f"FlexurePar Keyword excessive_shift = \"{excess_shft}\" "
                              "not recognized.")
    log.info(f"Flexure correction of {shift:.3f} pixels")

    return dict(polyfit=fit, shift=shift, subpix=subpix_grid,
                corr=corr_grid, sky_spec=obj_skyspec, arx_spec=arx_skyspec,
                corr_cen=lag0, smooth=smooth_fwhm_pix, method=method)
def get_percentile_clipping(arr, percent=90.0):
    """
    Compute lower/upper clipping limits from percentiles of the negative
    and non-negative parts of an array.

    The lower limit is the ``percent``-th percentile of the strictly negative
    elements, and the upper limit is the ``percent``-th percentile of the
    elements that are ``>= 0``. A limit falls back to 0.0 when the
    corresponding subset is empty.

    Args:
        arr (`numpy.ndarray`_):
            Array to clip.
        percent (:obj:`float`):
            Percentile to clip at. Default is 90.0

    Returns:
        :obj:`float`: Lower value for clipping
        :obj:`float`: Upper value for clipping
    """
    negatives = arr[arr < 0.0]
    non_negatives = arr[arr >= 0.0]

    lower = 0.0
    if negatives.size > 0:
        lower = np.percentile(negatives, percent)

    upper = 0.0
    if non_negatives.size > 0:
        upper = np.percentile(non_negatives, percent)

    return lower, upper
def get_fwhm_gauss_smooth(arx_skyspec, obj_skyspec, arx_fwhm_pix, spec_fwhm_pix=None):
    """
    Compute the FWHM of the Gaussian kernel that degrades the archived sky
    spectrum to the spectral resolution of the object sky spectrum.

    Both FWHMs are converted from pixels to wavelength units using the median
    pixel spacing of each spectrum. The kernel FWHM is their quadrature
    difference, expressed in pixels of the archived spectrum.

    Args:
        arx_skyspec (:class:`~pypeit.onespec.OneSpec`):
            Archived sky spectrum.
        obj_skyspec (:class:`~pypeit.onespec.OneSpec`):
            Sky spectrum associated with the science target.
        arx_fwhm_pix (:obj:`float`):
            Spectral FWHM (in pixels) of the archived sky spectrum.
        spec_fwhm_pix (:obj:`float`, optional):
            Spectral FWHM (in pixels) of the sky spectrum related to our
            object. If None, it is measured from ``obj_skyspec.flux``.

    Returns:
        :obj:`float` or None: FWHM of the smoothing Gaussian in pixels. It is
        0. when the object sky FWHM is smaller than the archive FWHM (no
        smoothing). It is None when the object FWHM cannot be measured or is
        not positive.
    """
    # Measure the object sky FWHM (pixels) when it was not supplied,
    # e.g. by the wavelength calibration.
    if spec_fwhm_pix is None:
        spec_fwhm_pix = autoid.measure_fwhm(obj_skyspec.flux, sigdetect=4., fwhm=4.)
        log.info('Measuring spectral FWHM using the boxcar extracted sky spectrum.')
        if spec_fwhm_pix is None:
            log.warning('Failed to measure the spectral FWHM using the boxcar extracted sky spectrum. '
                        'Not enough sky lines detected.')
            return None

    # Median dispersion (wavelength per pixel) of each spectrum
    obj_dispersion = np.median(np.diff(obj_skyspec.wave))
    arx_dispersion = np.median(np.diff(arx_skyspec.wave))

    # FWHMs in wavelength units (Angstrom)
    spec_fwhm = spec_fwhm_pix * obj_dispersion
    arx_fwhm = arx_fwhm_pix * arx_dispersion
    log.info(f"Resolution (FWHM) of Archive={arx_fwhm:.2f} Ang and Observation={spec_fwhm:.2f} Ang")

    if spec_fwhm <= 0:
        log.warning('Negative spectral FWHM, likely due to a bad wavelength calibration.')
        return None

    obj_fwhm_sq = np.power(spec_fwhm, 2)
    arx_fwhm_sq = np.power(arx_fwhm, 2)

    if obj_fwhm_sq >= arx_fwhm_sq:
        # quadrature difference, converted back to archive pixels
        return np.sqrt(obj_fwhm_sq - arx_fwhm_sq) / arx_dispersion

    log.warning("Prefer archival sky spectrum to have higher resolution")
    log.warning("New Sky has higher resolution than Archive. Not smoothing")
    return 0.
def flexure_interp(shift, wave):
    """
    Perform interpolation on wave given a shift in pixels

    The wavelength array is treated as a function of (normalized) pixel
    position and evaluated ``shift`` pixels away from each pixel. Linear
    interpolation is used inside the array and linear extrapolation beyond
    its ends.

    Args:
        shift (float):
            Shift in pixels
        wave (`numpy.ndarray`_):
            extracted wave of size nspec

    Returns:
        `numpy.ndarray`_: Wavelength scale corrected for spectral flexure.
        Arrays with fewer than two elements cannot be interpolated and are
        returned unchanged (as a copy).
    """
    wave = np.asarray(wave)
    npix = wave.size
    # Interpolation needs at least two samples (and avoids a division by zero below)
    if npix < 2:
        return wave.copy()
    # Pixel positions normalized to [0, 1]; a shift of one pixel is 1/(npix-1)
    x = np.linspace(0., 1., npix)
    f = interpolate.interp1d(x, wave, bounds_error=False, fill_value="extrapolate")
    return f(x + shift / (npix - 1))
def spec_flex_shift_global(slit_specs, islit, sky_file, empty_flex_dict,
                           return_later_slits, flex_list, keys_to_update,
                           spec_fwhm_pix=None, mxshft=20, excess_shft="crash",
                           method='slitcen', minwave=None, maxwave=None):
    """
    Calculate flexure shifts using the sky spectrum extracted at the center of the slit

    Args:
        slit_specs (:obj:`list`):
            A list of :class:`~pypeit.onespec.OneSpec` objects. The spectra
            stored in this list are sky spectra, extracted from the center of
            each slit.
        islit (:obj:`int`):
            Index of the slit where the sky spectrum related to our object is.
        sky_file (`str`):
            Name of the archival sky file. If equal to 'model', instead, a
            model sky spectrum will be generated using
            :func:`~pypeit.wavemodel.nearIR_modelsky` and the spectral
            resolution of each spectrum from slit_specs.
        empty_flex_dict (:obj:`dict`):
            Empty dictionary to be filled with flexure results. It is deep
            copied and never modified.
        return_later_slits (:obj:`list`):
            List of slit indexes that failed the shift calculation and we want
            to come back to, to assign a value from a different slit. ``islit``
            is appended in place if the calculation fails.
        flex_list (:obj:`list`):
            A list of :obj:`dict` objects containing flexure results of each
            slit. A new entry is appended in place.
        keys_to_update (:obj:`list`):
            List of flexure dictionary keys that we need to update. The
            'sky_spec' entry is always built here from the shifted slit
            spectrum, whatever its position in this list.
        spec_fwhm_pix (:obj:`float`, optional):
            Spectral FWHM (in pixels) of the sky spectrum related to our
            object.
        mxshft (:obj:`int`, optional):
            Maximum allowed shift from flexure. Passed to spec_flex_shift().
        excess_shft (:obj:`str`, optional):
            Behavior of the code when a measured flexure exceeds ``mxshft``.
            Passed to spec_flex_shift()
        method (:obj:`str`, optional):
            Which method is used for the spectral flexure correction. Two
            methods are available: 'boxcar' and 'slitcen' (see
            spec_flexure_slit()). Passed to spec_flex_shift().
        minwave (:obj:`float`, optional):
            Minimum wavelength to use for the correlation. If ``None`` or less
            than the minimum wavelength of either this sky or
            ``sky_spectrum``, this has no effect. Default is None.
        maxwave (:obj:`float`, optional):
            Maximum wavelength to use for the correlation. If ``None`` or
            greater than the maximum wavelength of either this sky or
            ``sky_spectrum``, this has no effect. Default is None.

    Returns:
        :obj:`list`: A list of :obj:`dict` objects containing flexure results
        of each slit. This is filled with a basically empty dict if the shift
        calculation failed for the relevant slit.
    """
    # Reset the flexure dictionary
    flex_dict = copy.deepcopy(empty_flex_dict)

    this_spec = slit_specs[islit]

    # Calculate the shift
    fdict = spec_flex_shift(this_spec, sky_file=sky_file, mxshft=mxshft,
                            excess_shft=excess_shft, spec_fwhm_pix=spec_fwhm_pix,
                            method=method, minwave=minwave, maxwave=maxwave)
    # Was it successful?
    if fdict is not None:
        # Update dict; 'sky_spec' is filled separately below, so skip it by name
        # rather than assuming it is the last key of keys_to_update.
        for key in keys_to_update:
            if key != 'sky_spec':
                flex_dict[key].append(fdict[key])
        # Store the slit sky spectrum on the flexure-corrected wavelength grid
        sky_wave_new = flexure_interp(fdict['shift'], this_spec.wave)
        flex_dict['sky_spec'].append(onespec.OneSpec(sky_wave_new, None, this_spec.flux))
    else:
        # No success, come back to it later
        return_later_slits.append(islit)
        log.warning("Flexure shift calculation failed for this slit.")
        log.info("Will come back to this slit to attempt "
                 "to use saved estimates from other slits")

    # Append flex_dict, which will be an empty dictionary if the flexure failed for the all the slits
    flex_list.append(flex_dict.copy())

    return flex_list
def spec_flex_shift_local(slits, slitord, specobjs, islit, sky_file, empty_flex_dict,
                          return_later_slits, flex_list, keys_to_update,
                          spec_fwhm_pix=None, mxshft=20, excess_shft="crash",
                          method='boxcar', minwave=None, maxwave=None):
    """
    Calculate flexure shifts using the sky spectrum boxcar-extracted at the
    location of the detected objects

    Objects without a boxcar extraction get ``None`` placeholders in every
    list of the flexure dict, so that list positions match object positions.
    Objects whose shift calculation fails are assigned the entries of the
    object with the median measured shift in the same slit. If no object in
    the slit has a measured shift, the slit index is appended to
    ``return_later_slits`` instead.

    Args:
        slits (:class:`~pypeit.slittrace.SlitTraceSet`):
            Slit trace set.
        slitord (`numpy.ndarray`_):
            Array of slit/order numbers.
        specobjs (:class:`~pypeit.specobjs.SpecObjs`, optional):
            Spectral extractions.
        islit (:obj:`int`):
            Index of the slit where the sky spectrum related to our object is.
        sky_file (`str`):
            Name of the archival sky file. If equal to 'model', instead, a
            model sky spectrum will be generated using
            :func:`~pypeit.wavemodel.nearIR_modelsky` and the spectral
            resolution of each spectrum in specobjs.
        empty_flex_dict (:obj:`dict`):
            Empty dictionary to be filled with flexure results. It is deep
            copied and never modified.
        return_later_slits (:obj:`list`):
            List of slit indexes that failed the shift calculation and we want
            to come back to, to assign a value from a different slit. Modified
            in place.
        flex_list (:obj:`list`):
            A list of :obj:`dict` objects containing flexure results of each
            slit. A new entry is appended in place.
        keys_to_update (:obj:`list`):
            List of flexure dictionary keys that we need to update.
        spec_fwhm_pix (:obj:`float`, optional):
            Spectral FWHM (in pixels) of the sky spectrum related to our
            object.
        mxshft (:obj:`int`, optional):
            Maximum allowed shift from flexure. Passed to spec_flex_shift().
        excess_shft (:obj:`str`, optional):
            Behavior of the code when a measured flexure exceeds ``mxshft``.
            Passed to spec_flex_shift()
        method (:obj:`str`, optional):
            Which method is used for the spectral flexure correction. Two
            methods are available: 'boxcar' and 'slitcen' (see
            spec_flexure_slit()). Passed to spec_flex_shift().
        minwave (:obj:`float`, optional):
            Minimum wavelength to use for the correlation. If ``None`` or less
            than the minimum wavelength of either this sky or
            ``sky_spectrum``, this has no effect. Default is None.
        maxwave (:obj:`float`, optional):
            Maximum wavelength to use for the correlation. If ``None`` or
            greater than the maximum wavelength of either this sky or
            ``sky_spectrum``, this has no effect. Default is None.

    Returns:
        :obj:`list`: A list of :obj:`dict` objects containing flexure results
        of each slit. This is filled with a basically empty dict if the shift
        calculation failed for the relevant slit.
    """
    # Reset the flexure dictionary (deep copy: the template's lists are never shared)
    flex_dict = copy.deepcopy(empty_flex_dict)

    # get objects in this slit
    i_slitord = slitord[islit]
    indx = specobjs.slitorder_indices(i_slitord)
    this_specobjs = specobjs[indx]

    # if no objects in this slit, append an empty dict
    if len(this_specobjs) == 0:
        log.info('No object extracted in this slit.')
        flex_list.append(flex_dict)
        return flex_list

    # Objects in this slit that failed and we want to come back to
    # to assign values from other objects in the same slit (if available)
    return_later_sobjs = []

    # Loop through objects
    for ss, sobj in enumerate(this_specobjs):
        if sobj is None or sobj['BOX_WAVE'] is None:  # Nothing extracted; only the trace exists
            log.info(f'Object # {ss} was not extracted.')
            # Keep the lists aligned with the objects by appending a placeholder
            for key in keys_to_update:
                flex_dict[key].append(None)
            continue

        log.info(f"Working on spectral flexure for object # {ss} in slit {slits.spat_id[islit]}")

        # get 1D spectrum for this object
        obj_sky = onespec.OneSpec(
            sobj.BOX_WAVE[sobj.BOX_MASK], None, sobj.BOX_COUNTS_SKY[sobj.BOX_MASK]
        )

        # Calculate the shift
        fdict = spec_flex_shift(obj_sky, sky_file=sky_file, mxshft=mxshft,
                                excess_shft=excess_shft, spec_fwhm_pix=spec_fwhm_pix,
                                method=method, minwave=minwave, maxwave=maxwave)

        if fdict is not None:
            # Update dict
            for key in keys_to_update:
                flex_dict[key].append(fdict[key])
        else:
            # No success, come back to it later
            return_later_sobjs.append(ss)
            log.warning("Flexure shift calculation failed for this spectrum.")
            log.info("Will come back to this spectrum to attempt "
                     "to use saved estimates from other slits/objects")

    # Positions (in the flex_dict lists) of objects with a measured shift. Objects that were
    # not extracted carry None placeholders, which must not enter the median.
    measured = [i for i, s in enumerate(flex_dict['shift']) if s is not None]

    # Check if we need to go back
    if len(return_later_sobjs) > 0 and len(measured) > 0:
        log.warning(f'Flexure shift calculation failed for {len(return_later_sobjs)} '
                    f'object(s) in slit {slits.spat_id[islit]}')
        # get the median shift among all measured objects in this slit
        measured_shifts = np.array([flex_dict['shift'][i] for i in measured], dtype=float)
        med_shift = np.percentile(measured_shifts, 50, method='nearest')
        idx_med_shift = measured[int(np.where(measured_shifts == med_shift)[0][0])]
        log.info(f"Median value of the measured flexure shifts in this slit, equal to "
                 f"{flex_dict['shift'][idx_med_shift]:.3f} pixels, will be used")

        # Snapshot the median entries *before* inserting anything: every insertion shifts the
        # later list elements, which would make idx_med_shift stale for subsequent objects.
        med_entries = {key: flex_dict[key][idx_med_shift]
                       for key in keys_to_update if key != 'sky_spec'}

        # assign the median shift to the failed objects (ascending order keeps positions right)
        for obj_idx in return_later_sobjs:
            # insert the median value at the location of the object that failed the calculation
            for key, value in med_entries.items():
                flex_dict[key].insert(obj_idx, value)
            # Interpolate
            sky_wave_new = flexure_interp(flex_dict['shift'][obj_idx],
                                          this_specobjs[obj_idx].BOX_WAVE)
            flex_dict['sky_spec'].insert(
                obj_idx,
                onespec.OneSpec(sky_wave_new, None, this_specobjs[obj_idx].BOX_COUNTS_SKY)
            )

    # if flexure failed for every measured object in this slit, save for later to use value
    # from other slits
    elif len(return_later_sobjs) > 0:
        return_later_slits.append(islit)

    # Append flex_dict, which will be an empty dictionary if the flexure failed for the whole slit
    flex_list.append(flex_dict.copy())

    return flex_list
def spec_flexure_slit(slits, slitord, slit_bpm, sky_file, method="boxcar", specobjs=None,
                      slit_specs=None, wv_calib=None, mxshft=None, excess_shft="crash",
                      minwave=None, maxwave=None):
    """Calculate the spectral flexure for every slit (global) or object (local)

    Args:
        slits (:class:`~pypeit.slittrace.SlitTraceSet`):
            Slit trace set
        slitord (`numpy.ndarray`_):
            Array of slit/order numbers
        slit_bpm (`numpy.ndarray`_):
            True = masked slit
        sky_file (:obj:`str`):
            Name of the archival sky file.  If equal to 'model', instead, a
            model sky spectrum will be generated using
            :func:`~pypeit.wavemodel.nearIR_modelsky` and the spectral
            resolution of each spectrum that we want to correct for flexure.
        method (:obj:`str`, optional):
            Two methods are available:

                - 'boxcar': Recommended for object extractions. This method
                  uses the boxcar extracted sky and wavelength spectra from
                  the input specobjs
                - 'slitcen': Recommended when no objects are being extracted.
                  This method uses a spectrum (stored in slitspecs) that is
                  extracted from the center of each slit.

            If ``specobjs`` is None, the 'slitcen' approach is used regardless.
        specobjs (:class:`~pypeit.specobjs.SpecObjs`, optional):
            Spectral extractions
        slit_specs (:obj:`list`, optional):
            A list of :class:`~pypeit.onespec.OneSpec`, one for each slit. The
            spectra stored in this list are sky spectra, extracted from the
            center of each slit. This is only used if ``method='slitcen'``.
        wv_calib (:class:`pypeit.wavecalib.WaveCalib`, optional):
            Wavelength calibration object.  If provided and its ``fwhm_map``
            is not None, the spectral FWHM (pixels) evaluated at the slit
            centre is passed on to the flexure measurement.
        mxshft (:obj:`int`, optional):
            Passed to :func:`spec_flex_shift_global` or
            :func:`spec_flex_shift_local`.
        excess_shft (:obj:`str`, optional):
            Passed to :func:`spec_flex_shift_global` or
            :func:`spec_flex_shift_local`.
        minwave (:obj:`float`, optional):
            Minimum wavelength to use for the correlation.  If ``None`` or
            less than the minimum wavelength of either this sky or
            ``sky_spectrum``, this has no effect. Default is None.
        maxwave (:obj:`float`, optional):
            Maximum wavelength to use for the correlation.  If ``None`` or
            greater than the maximum wavelength of either this sky or
            ``sky_spectrum``, this has no effect. Default is None.

    Returns:
        :obj:`list`: One :obj:`dict` per slit, in slit order, containing the
        flexure results.  Masked slits, and failed slits for which no
        estimate could be borrowed from another slit, get a dict whose
        result lists are empty.  If the measurement failed for an entire
        slit, the median shift among the other slits is adopted; for local
        flexure such a slit gets one entry per object it contains.
    """
    log.debug("Consider doing 2 passes in flexure as in LowRedux")

    # Use the slit-centre spectra (global flexure) when there are no objects
    # or when explicitly requested
    slit_cen = (specobjs is None) or (method == "slitcen")

    # Initialise the flexure list for each slit
    flex_list = []
    # Slits whose flexure calculation failed entirely; filled by the
    # spec_flex_shift_* helpers and revisited at the end
    return_later_slits = []

    # Template for an "empty" result.  Always deep-copy it: a shallow copy
    # would share its inner lists between slits.
    empty_flex_dict = dict(polyfit=[], shift=[], subpix=[], corr=[], corr_cen=[],
                           spec_file=sky_file, smooth=[], arx_spec=[], sky_spec=[], method=[])
    # flex dict keys that we need to update through the routine.  'sky_spec'
    # must remain last: the fallback code below copies every other key and
    # recomputes 'sky_spec' for each spectrum.
    keys_to_update = ['polyfit', 'shift', 'subpix', 'corr', 'corr_cen', 'smooth', 'method',
                      'arx_spec', 'sky_spec']

    for islit in range(slits.nslits):
        log.info(f"Working on spectral flexure of slit: {slits.spat_id[islit]}")
        # Masked slit: keep flex_list aligned with the slit index
        if slit_bpm[islit]:
            flex_list.append(copy.deepcopy(empty_flex_dict))
            continue

        # get spectral FWHM (in pixels) if available; allow for wavelength failures
        spec_fwhm_pix = None
        if wv_calib is not None and wv_calib.fwhm_map is not None:
            # Evaluate the spectral FWHM at the centre of the slit (in both the
            # spectral and spatial directions)
            spec_fwhm_pix = wv_calib.fwhm_map[islit].eval(slits.nspec/2, 0.5)

        if slit_cen:
            # global flexure
            flex_list = spec_flex_shift_global(slit_specs, islit, sky_file, empty_flex_dict,
                                               return_later_slits, flex_list, keys_to_update,
                                               spec_fwhm_pix=spec_fwhm_pix, mxshft=mxshft,
                                               excess_shft=excess_shft, minwave=minwave,
                                               maxwave=maxwave)
        else:
            # local flexure
            flex_list = spec_flex_shift_local(slits, slitord, specobjs, islit, sky_file,
                                              empty_flex_dict, return_later_slits, flex_list,
                                              keys_to_update, spec_fwhm_pix=spec_fwhm_pix,
                                              mxshft=mxshft, excess_shft=excess_shft,
                                              minwave=minwave, maxwave=maxwave)

    # Nothing failed: we are done
    if len(return_later_slits) == 0:
        return flex_list

    log.warning(f'Flexure shift calculation failed for {len(return_later_slits)} slits')

    # Slits with at least one measured shift
    gd_idx = np.array([i for i, flex in enumerate(flex_list) if len(flex['shift']) > 0],
                      dtype=int)
    if gd_idx.size == 0:
        log.warning(f'No previously saved flexure shift estimates available. '
                    f'Flexure corrections cannot be performed.')
        # Reset the failed slits to empty results; flex_list keeps one entry per slit
        for sidx in return_later_slits:
            flex_list[sidx] = copy.deepcopy(empty_flex_dict)
        return flex_list

    # Take the median within each slit (local flexure can have several shifts
    # per slit), then the median over slits.  'nearest' guarantees the result
    # is an actually measured value, so it can be located below.
    slit_shifts = np.array([np.percentile(flex_list[i]['shift'], 50, method='nearest')
                            for i in gd_idx])
    med_shift = np.percentile(slit_shifts, 50, method='nearest')
    islit_med_shift = gd_idx[np.where(slit_shifts == med_shift)[0][0]]
    log.info(f"Median value of all the measured flexure shifts, equal to "
             f"{med_shift:.3f} pixels, will be used")

    med_dict = flex_list[islit_med_shift]
    # Entry within med_dict that holds the median shift
    if slit_cen:
        # global flexure has a single entry per slit
        idx_med_shift = 0
    else:
        idx_med_shift = np.where(np.asarray(med_dict['shift']) == med_shift)[0][0]
    borrowed_shift = med_dict['shift'][idx_med_shift]

    for sidx in return_later_slits:
        flex_dict = copy.deepcopy(empty_flex_dict)
        if slit_cen:
            # global flexure: one entry built from the slit-centre spectrum
            for key in keys_to_update[:-1]:
                flex_dict[key].append(med_dict[key][idx_med_shift])
            sky_wave_new = flexure_interp(borrowed_shift, slit_specs[sidx].wave)
            flex_dict['sky_spec'].append(
                onespec.OneSpec(sky_wave_new, None, slit_specs[sidx].flux)
            )
        else:
            # local flexure: one entry per object in the failed slit, so that
            # per-object indexing of this dict stays valid
            indx = specobjs.slitorder_indices(slitord[sidx])
            for sobj in specobjs[indx]:
                for key in keys_to_update[:-1]:
                    flex_dict[key].append(med_dict[key][idx_med_shift])
                sky_wave_new = flexure_interp(borrowed_shift, sobj.BOX_WAVE)
                flex_dict['sky_spec'].append(
                    onespec.OneSpec(sky_wave_new, None, sobj.BOX_COUNTS_SKY)
                )
        # replace the result of the slit that failed the calculation
        flex_list[sidx] = flex_dict

    return flex_list
def spec_flexure_slit_global(sciImg, waveimg, global_sky, par, slits, slitmask, trace_spat,
                             gd_slits, wv_calib, pypeline, det):
    """Calculate the spectral flexure for every slit

    A sky spectrum is boxcar-extracted along ``trace_spat`` for each good
    slit and passed to :func:`spec_flexure_slit`.

    Args:
        sciImg (:class:`~pypeit.images.pypeitimage.PypeItImage`):
            Science image.
        waveimg (`numpy.ndarray`_):
            Wavelength image - shape (nspec, nspat)
        global_sky (`numpy.ndarray`_):
            2D array of the global_sky fit - shape (nspec, nspat)
        par (:class:`~pypeit.par.pypeitpar.PypeItPar`):
            Parameters of the reduction.  The boxcar radius is read from
            ``par['reduce']['extraction']['boxcar_radius']`` (arcsec) and
            converted to binned pixels.
        slits (:class:`~pypeit.slittrace.SlitTraceSet`):
            Slit trace set
        slitmask (`numpy.ndarray`_):
            An image with the slit index identified for each pixel (returned
            from slittrace.slit_img).
        trace_spat (`numpy.ndarray`_):
            Spatial pixel values (usually the center of each slit) where the
            sky spectrum will be extracted.  The shape of this array should be
            (nspec, nslits)
        gd_slits (`numpy.ndarray`_):
            True = good slit
        wv_calib (:class:`pypeit.wavecalib.WaveCalib`):
            Wavelength calibration
        pypeline (:obj:`str`):
            Name of the ``PypeIt`` pipeline method.  Allowed options are
            MultiSlit, Echelle, or IFU.
        det (:obj:`str`):
            The name of the detector or mosaic from which the spectrum will be
            extracted.  For example, DET01.

    Returns:
        :obj:`list`: A list of :obj:`dict` objects containing flexure
        results of each slit. This is filled with a basically empty dict if
        the slit is skipped.
    """
    # TODO :: Need to think about spatial flexure - is the appropriate spatial flexure already
    #  included in trace_spat via left/right slits?
    slit_specs = []
    # The boxcar radius is in arcsec; the platescale is arcsec per unbinned
    # pixel, so the radius in binned pixels is radius / (platescale * binspat).
    _, binspat = parse.parse_binning(sciImg.detector.binning)
    box_radius = par['reduce']['extraction']['boxcar_radius'] \
                    / (sciImg.detector.platescale * binspat)

    for ss in range(slits.nslits):
        if not gd_slits[ss]:
            # keep slit_specs aligned with the slit index
            slit_specs.append(None)
            continue
        thismask = (slitmask == slits.spat_id[ss])
        # unflagged pixels that belong to this slit
        inmask = sciImg.select_flag(invert=True) & thismask
        # Pack
        slit_specs.append(get_sky_spectrum(sciImg.image, sciImg.ivar, waveimg, inmask, global_sky,
                                           box_radius, slits, trace_spat[:, ss], pypeline, det))

    # Measure flexure
    flex_list = spec_flexure_slit(slits, slits.slitord_id, np.logical_not(gd_slits),
                                  par['flexure']['spectrum'],
                                  method=par['flexure']['spec_method'],
                                  mxshft=par['flexure']['spec_maxshift'],
                                  excess_shft=par['flexure']['excessive_shift'],
                                  specobjs=None, slit_specs=slit_specs, wv_calib=wv_calib,
                                  minwave=par['flexure']['minwave'],
                                  maxwave=par['flexure']['maxwave'])
    return flex_list
def get_archive_spectrum(sky_file, obj_skyspec=None, spec_fwhm_pix=None):
    """
    Load an archival sky spectrum

    Parameters
    ----------
    sky_file : :obj:`str`
        Name of the archival sky file.  If equal to 'model', instead, a model
        sky spectrum will be generated using
        :func:`~pypeit.wavemodel.nearIR_modelsky` and the spectral resolution
        of obj_skyspec. If obj_skyspec is None, then sky_file cannot be
        'model'.
    obj_skyspec : :class:`~pypeit.onespec.OneSpec`, optional
        Sky spectrum associated with the science target. This must be
        provided if sky_file is 'model'.
    spec_fwhm_pix : :obj:`float`, optional
        Spectral FWHM (in pixels) of the sky spectrum related to our object.
        Only used when ``sky_file`` is 'model'; if None, it is measured from
        ``obj_skyspec``.

    Returns
    -------
    sky_spectrum : :class:`~pypeit.onespec.OneSpec`
        The sky spectrum
    arx_fwhm_pix : :obj:`float`
        The FWHM of the sky lines in pixels.

    Raises
    ------
    PypeItError
        If the FWHM of the archived sky spectrum cannot be measured, if
        ``sky_file`` is 'model' and ``obj_skyspec`` is None, or if the FWHM
        of ``obj_skyspec`` is needed but cannot be measured.
    """
    if sky_file != 'model':
        # Load Archive. Save the fwhm to avoid the performance hit from calling it on the
        # archive sky spectrum multiple times
        sky_spectrum = skyspec.load_sky_spectrum(sky_file)
        # get arxiv sky spectrum resolution (FWHM in pixels)
        arx_fwhm_pix = autoid.measure_fwhm(sky_spectrum.flux, sigdetect=4., fwhm=4.)
        if arx_fwhm_pix is None:
            raise PypeItError('Failed to measure the spectral FWHM of the archived sky spectrum. '
                              'Not enough sky lines detected.')
        return sky_spectrum, arx_fwhm_pix

    # A model sky requires the object's own sky spectrum
    if obj_skyspec is None:
        raise PypeItError('Archived sky spectrum cannot be loaded. ')

    if spec_fwhm_pix is None:
        # measure spec_fwhm_pix
        spec_fwhm_pix = autoid.measure_fwhm(obj_skyspec.flux, sigdetect=4., fwhm=4.)
        if spec_fwhm_pix is None:
            # Without a FWHM the model cannot be built (the arithmetic below
            # would fail on None), so stop with an informative error.
            raise PypeItError('Failed to measure the spectral FWHM using the boxcar extracted '
                              'sky spectrum. Choose one of the provided sky files.')

    # get the spectral resolution of obj_skyspec
    # obj_skyspec spectral dispersion (Angstrom/pixel)
    obj_disp = np.median(np.diff(obj_skyspec.wave))
    # FWHM
    spec_fwhm = spec_fwhm_pix * obj_disp
    # Compute the resolution at the midpoints of the spectrum in the spectral direction
    midpix = obj_skyspec.wave.size // 2
    # R = lambda / dlambda
    res = obj_skyspec.wave[midpix] / spec_fwhm
    # get model sky spectrum; wavelengths are passed in microns (Angstrom / 1e4)
    wave_sky, flux_sky = wavemodel.nearIR_modelsky(res, (obj_skyspec.wave.min() / 10000.,
                                                         obj_skyspec.wave.max() / 10000.),
                                                   dlam=obj_disp / 10000., flgd=False)
    sky_spectrum = onespec.OneSpec(wave_sky, None, flux_sky, fluxed=False)
    # the model is generated at the object's resolution
    arx_fwhm_pix = spec_fwhm_pix
    return sky_spectrum, arx_fwhm_pix
def get_sky_spectrum(sciimg, ivar, waveimg, thismask, global_sky, box_radius, slits, trace_spat,
                     pypeline, det):
    """
    Boxcar-extract the sky spectrum along a single spatial trace.

    Args:
        sciimg (`numpy.ndarray`_):
            Science image - shape (nspec, nspat)
        ivar (`numpy.ndarray`_):
            Inverse variance of the science image - shape (nspec, nspat)
        waveimg (`numpy.ndarray`_):
            Wavelength image - shape (nspec, nspat)
        thismask (`numpy.ndarray`_):
            Good pixel mask (True=good) selecting the pixels that may enter
            the boxcar extraction
        global_sky (`numpy.ndarray`_):
            2D array of the global_sky fit - shape (nspec, nspat)
        box_radius (float):
            Radius of the boxcar extraction (in pixels)
        slits (:class:`~pypeit.slittrace.SlitTraceSet`):
            Slit trace set; only ``slits.nspec`` is used here.
        trace_spat (`numpy.ndarray`_):
            Spatial pixel position of the extraction centre at each spectral
            pixel (one value per spectral pixel, i.e. shape (nspec,)).
        pypeline (:obj:`str`):
            Name of the ``PypeIt`` pipeline method.  Not used by this
            function.
        det (:obj:`str`):
            Name of the detector or mosaic (e.g. DET01).  Not used by this
            function.

    Returns:
        :class:`~pypeit.onespec.OneSpec`: Boxcar-extracted sky counts versus
        wavelength, restricted to the pixels selected by the mask returned
        from the extraction.
    """
    # every spectral row is paired with the corresponding trace_spat value
    spec_rows = np.arange(slits.nspec)
    (wave, _, _, _, _, keep, _, _, sky_counts, _, _) = extract.extract_boxcar(
        box_radius, trace_spat, sciimg, ivar, thismask, waveimg, global_sky,
        trace_spec=spec_rows)
    return onespec.OneSpec(wave[keep], None, sky_counts[keep], fluxed=False)
def spec_flexure_corrQA(ax:plt.Axes, this_flex_dict:dict, cntr:int, name:str):
    """Draw one panel of the spectral flexure cross-correlation QA plot.

    The enclosing figure is created and saved by the caller; this function
    only draws onto ``ax``.

    Parameters
    ----------
    ax
        Axes onto which to draw the plot
    this_flex_dict
        Dictionary of flexure-related information needed for the plot
        (``polyfit``, ``corr_cen``, ``shift``, ``subpix`` and ``corr``
        lists).
    cntr
        The index into ``this_flex_dict``'s arrays corresponding to the
        particular object, trace, or location of interest.
    name
        Object, trace, or location name to be printed in the plot
    """
    fit = this_flex_dict['polyfit'][cntr]

    if fit is None:
        # No fit available: label the panel and report the failure
        ax.text(0.5, 0.25, name, transform=ax.transAxes, size='large', ha='center')
        ax.text(0.5, 0.15, 'flex_shift calculation failed', transform=ax.transAxes,
                size='large', ha='center')
        ax.set_xlabel('Lag')
        return

    corr_cen = this_flex_dict['corr_cen'][cntr]
    best_shift = this_flex_dict['shift'][cntr]

    # Evaluate the fitted model over +/-10 lags around the measured shift
    lags = np.linspace(-10., 10, 100) + corr_cen + best_shift
    model = fit.eval(lags)
    peak = np.max(model)
    scaled = model / peak
    floor = np.min(scaled)
    y_bounds = [floor if np.isfinite(floor) else 0.0, 1.3]

    # Fitted model, measured correlation values, and the adopted shift
    ax.plot(lags - corr_cen, scaled, 'k-')
    ax.scatter(this_flex_dict['subpix'][cntr] - corr_cen,
               this_flex_dict['corr'][cntr] / peak, marker='o')
    ax.plot([best_shift] * 2, y_bounds, 'g:')

    # Labels
    ax.text(0.5, 0.25, name, transform=ax.transAxes, size='large', ha='center')
    ax.text(0.5, 0.15, 'flex_shift = {:g}'.format(best_shift),
            transform=ax.transAxes, size='large', ha='center')

    ax.set_ylim(y_bounds)
    ax.set_xlabel('Lag')
# TODO: With Python 3.14's deferred evaluation of annotations, may be able # to annotate `specobjs`; however, should really remove PypeIt-specific # objects from `core`.
def spec_flexure_qa(slitords:np.ndarray, bpm:np.ndarray, basename:str, flex_list:list[dict],
                    specobjs=None, out_dir:str|None=None):
    """
    Generate QA for the spectral flexure calculation

    For every unmasked slit with at least one non-None shift, two figures are
    written: a cross-correlation figure (one panel per object, or a single
    "Slit Center" panel when ``specobjs`` is None) and a sky-line comparison
    figure (up to six sky lines, for a single object/spectrum).

    Args:
        slitords (`numpy.ndarray`_):
            Array of slit/order numbers
        bpm (`numpy.ndarray`_):
            Boolean mask; True = masked slit
        basename (str):
            Used to generate the output file name.  Its last two
            underscore-separated fields are taken as the mode and detector.
        flex_list (list):
            list of :obj:`dict` objects containing the flexure information,
            indexed by slit
        specobjs (:class:`~pypeit.specobjs.SpecObjs`, optional):
            Spectrally extracted objects.  If None, the slit-centre (global)
            QA is produced.
        out_dir (str, optional):
            Path to the output directory for the QA plots. If None, the
            current is used.
    """
    # Extract the mode and detector from the ``basename``
    *_, mode, det = basename.split("_")

    plt.rcdefaults()
    plt.rcParams['font.family'] = 'serif'

    # What type of QA are we doing
    slit_cen = specobjs is None

    # Grab the name of the currently executing function ('spec_flexure_qa');
    # used as the prefix of the QA file names
    method = inspect.stack()[0][3]

    # Indices of the unmasked slits
    gdslits = np.where(np.logical_not(bpm))[0]

    # Loop over slits, and then over objects here
    for islit in gdslits:
        # Slit/order number
        slitord = slitords[islit]
        this_flex_dict = flex_list[islit]

        # Check that the default was overwritten (skip slits with no shifts, or only None shifts)
        if len(this_flex_dict['shift']) == 0 or \
                (len(this_flex_dict['shift']) > 0 and np.all([ss is None for ss in this_flex_dict['shift']])):
            continue

        # Parse and Setup
        if slit_cen:
            nobj = 1
            ncol = 1
        else:
            indx = specobjs.slitorder_indices(slitord)
            this_specobjs = specobjs[indx]
            # NOTE(review): counts objects assuming ``indx`` is a boolean selection
            nobj = np.sum(indx)
            if nobj == 0:
                continue
            ncol = min(3, nobj)
        # ceiling division: enough rows to hold nobj panels
        nrow = nobj // ncol + ((nobj % ncol) > 0)

        # Outfile, one QA file per slit
        outfile = qa.set_qa_filename(
            basename, method + '_corr', slit=slitord, det=det, mode=mode, out_dir=out_dir
        )
        plt.figure(figsize=(8, 5.0))
        plt.clf()
        gs = gridspec.GridSpec(nrow, ncol)
        # Correlation QA
        if slit_cen:
            ax = plt.subplot(gs[0, 0])
            spec_flexure_corrQA(ax, this_flex_dict, 0, 'Slit Center')
        else:
            # iplt counts drawn panels; ss indexes the flexure entries of this slit
            iplt = 0
            for ss, specobj in enumerate(this_specobjs):
                if specobj is None or (specobj.BOX_WAVE is None and specobj.OPT_WAVE is None):
                    continue
                ax = plt.subplot(gs[iplt//ncol, iplt % ncol])
                spec_flexure_corrQA(ax, this_flex_dict, ss, '{:s}'.format(specobj.NAME))
                iplt += 1
        # Finish
        plt.tight_layout(pad=0.2, h_pad=0.0, w_pad=0.0)
        plt.savefig(outfile)#, dpi=400)
        plt.close()

        # Sky line QA (just one object)
        if slit_cen:
            iobj = 0
        else:
            # only show the first object in this slit that does not have None shift
            iobj = np.where([ss is not None for ss in this_flex_dict['shift']])[0][0]
            specobj = this_specobjs[iobj]

        # Repackage
        sky_spec = this_flex_dict['sky_spec'][iobj]
        arx_spec = this_flex_dict['arx_spec'][iobj]
        # wavelength range covered by both the object sky and the archive sky
        min_wave = max(np.amin(arx_spec.wave), np.amin(sky_spec.wave))
        max_wave = min(np.amax(arx_spec.wave), np.amax(sky_spec.wave))

        # Sky lines
        # TODO: Should these be defined / identified somewhere else?  Then they
        # could more easily be included in the documentation.
        sky_lines = np.array([3370.0, 3914.0, 4046.56, 4358.34, 5577.338, 6300.304,
                              7340.885, 7993.332, 8430.174, 8919.610, 9439.660,
                              10013.99, 10372.88])
        # half-width of the window plotted around each sky line
        dwv = 20.
        gdsky = np.where((sky_lines > min_wave) & (sky_lines < max_wave))[0]
        if len(gdsky) == 0:
            log.warning("No sky lines for Flexure QA")
            continue
        if len(gdsky) > 6:
            # keep 6 lines (first two, middle two, last two) to fill the 2x3 grid
            idx = np.array([0, 1, len(gdsky)//2, len(gdsky)//2+1, -2, -1])
            gdsky = gdsky[idx]

        # Outfile
        outfile = qa.set_qa_filename(
            basename, method+'_sky', slit=slitord, det=det, mode=mode, out_dir=out_dir
        )
        # Figure
        plt.figure(figsize=(8, 5.0))
        plt.clf()
        nrow, ncol = 2, 3
        gs = gridspec.GridSpec(nrow, ncol)
        if slit_cen:
            plt.suptitle('Sky Comparison for Slit Center', y=0.99)
        else:
            plt.suptitle('Sky Comparison for {:s}'.format(specobj.NAME), y=0.99)

        for ii, igdsky in enumerate(gdsky):
            skyline = sky_lines[igdsky]
            ax = plt.subplot(gs[ii//ncol, ii % ncol])
            # Norm: scale the archive sky to the object sky's summed flux within +/- dwv
            # NOTE(review): yields inf/nan if the archive flux sums to zero here
            pix1 = np.where(np.abs(sky_spec.wave-skyline) < dwv)[0]
            pix2 = np.where(np.abs(arx_spec.wave-skyline) < dwv)[0]
            f1 = np.sum(sky_spec.flux[pix1])
            f2 = np.sum(arx_spec.flux[pix2])
            norm = f1/f2
            # Plot
            ax.plot(sky_spec.wave[pix1], sky_spec.flux[pix1], 'k-', label='Obj',
                    drawstyle='steps-mid')
            ax.plot(arx_spec.wave[pix2], arx_spec.flux[pix2]*norm, 'r-', label='Arx',
                    drawstyle='steps-mid')
            # Axes
            ax.xaxis.set_major_locator(plt.MultipleLocator(dwv))
            ax.set_xlabel('Wavelength')
            ax.set_ylabel('Counts')

        # Legend (pyplot draws it on the current axes, i.e. the last panel)
        plt.legend(loc='upper left', scatterpoints=1, borderpad=0.3,
                   handletextpad=0.3, fontsize='small', numpoints=1)

        # Finish
        plt.tight_layout(pad=0.2, h_pad=0.0, w_pad=0.0)
        plt.savefig(outfile)#, dpi=400)
        plt.close()
        log.info("Wrote spectral flexure QA: {}".format(outfile))

    # restore the default matplotlib settings changed at the top
    plt.rcdefaults()
def calculate_image_phase(imref, imshift, gpm_ref=None, gpm_shift=None, maskval=None):
    """
    Estimate the subpixel shift of ``imshift`` relative to ``imref``.

    An integer-pixel first guess is obtained with a masked phase
    cross-correlation.  The overlapping parts of the two images (for that
    integer shift) are then compared with a TV-L1 optical flow, and the
    median flow along each axis refines the estimate.  If gpm_ref, gpm_shift,
    and maskval are all None, no pixels are masked.

    This routine (optionally) requires skimage.  If skimage is not installed,
    or the two images have different shapes, :func:`calculate_image_offset`
    (a standard, unmasked cross-correlation) is used instead.

    Parameters
    ----------
    imref : `numpy.ndarray`_
        Reference image
    imshift : `numpy.ndarray`_
        Image that we want to measure the shift of (relative to imref)
    gpm_ref : `numpy.ndarray`_
        Mask of good pixels (True = good) in the reference image
    gpm_shift : `numpy.ndarray`_
        Mask of good pixels (True = good) in the shifted image
    maskval : float, optional
        Value to treat as bad in any image whose good-pixel mask is not
        supplied.

    Returns
    -------
    ra_diff : float
        Relative shift (in pixels) of image relative to imref along the first
        array axis ("x" in this module).  In order to align image with imref,
        ra_diff should be added to the x-coordinates of image
    dec_diff : float
        Relative shift (in pixels) of image relative to imref along the second
        array axis ("y").  In order to align image with imref, dec_diff should
        be added to the y-coordinates of image
    """
    # Fall back to the plain cross-correlation when skimage is unavailable
    try:
        from skimage.registration import optical_flow_tvl1, phase_cross_correlation
    except ImportError:
        log.warning("scikit-image is not installed. Adopting a basic image cross-correlation")
        return calculate_image_offset(imref, imshift)
    # ... or when the images cannot be compared pixel by pixel
    if imref.shape != imshift.shape:
        log.warning("Input images shapes are not equal. Adopting a basic image cross-correlation")
        return calculate_image_offset(imref, imshift)

    # Default good-pixel masks: everything, or everything except maskval
    if gpm_ref is None:
        if maskval is None:
            gpm_ref = np.ones(imref.shape, dtype=bool)
        else:
            gpm_ref = imref != maskval
    if gpm_shift is None:
        if maskval is None:
            gpm_shift = np.ones(imshift.shape, dtype=bool)
        else:
            gpm_shift = imshift != maskval

    # Integer-pixel first guess
    coarse, _, _ = phase_cross_correlation(imref, imshift, reference_mask=gpm_ref,
                                           moving_mask=gpm_shift)
    coarse = coarse.astype(int)

    def _overlap(step):
        """Slices (reference, shifted) selecting the common region along one axis."""
        if step < 0:
            return slice(None, step), slice(-step, None)
        if step > 0:
            return slice(step, None), slice(None, -step)
        return slice(None), slice(None)

    ref_rows, shf_rows = _overlap(coarse[0])
    ref_cols, shf_cols = _overlap(coarse[1])
    ref_cut = imref.copy()[ref_rows, ref_cols]
    shf_cut = imshift.copy()[shf_rows, shf_cols]

    # The median optical-flow vector is the residual (fractional) shift
    flow_rows, flow_cols = optical_flow_tvl1(ref_cut, shf_cut)
    total = coarse.astype(float)
    total[0] -= np.median(flow_rows)
    total[1] -= np.median(flow_cols)
    return total[0], total[1]
def calculate_image_offset(im_ref, image, nfit=3):
    """Calculate the x,y offset between two images

    The median-subtracted images are cross-correlated and a 2D Gaussian is
    fit to the neighbourhood of the correlation peak.  The input arrays are
    not modified.

    Args:
        im_ref (`numpy.ndarray`_):
            Reference image
        image (`numpy.ndarray`_):
            Image that we want to measure the shift of (relative to im_ref)
        nfit (int, optional):
            Number of pixels (left and right of the maximum) to include in
            fitting the peak of the cross correlation.  The window is
            truncated at the array edges.

    Returns:
        tuple: Returns two floats, the x and y offset of the image.

        - ra_diff -- Relative shift (in pixels) of image relative to im_ref
          (x direction, i.e. the first array axis).  In order to align image
          with im_ref, ra_diff should be added to the x-coordinates of image
        - dec_diff -- Relative shift (in pixels) of image relative to im_ref
          (y direction, i.e. the second array axis).  In order to align image
          with im_ref, dec_diff should be added to the y-coordinates of image
    """
    # Subtract median (should be close to zero, anyway).  Build new arrays
    # rather than subtracting in place: in-place subtraction would modify the
    # caller's images and fails for integer-typed images.
    image = image - np.median(image)
    im_ref = im_ref - np.median(im_ref)

    # cross correlate
    ccorr = signal.correlate2d(im_ref, image, boundary='fill', mode='same')

    # Find the maximum
    amax = np.unravel_index(np.argmax(ccorr), ccorr.shape)

    # Region of +/- nfit pixels around the maximum.  The upper bounds are
    # exclusive (used for slicing and arange), so clip them to the array size.
    xlo = max(amax[0] - nfit, 0)
    xhi = min(amax[0] + nfit + 1, ccorr.shape[0])
    ylo = max(amax[1] - nfit, 0)
    yhi = min(amax[1] + nfit + 1, ccorr.shape[1])
    x = np.arange(xlo, xhi)
    y = np.arange(ylo, yhi)

    # Setup some initial parameters
    initial_guess = (np.max(ccorr), amax[0], amax[1], 3, 3, 0, 0)
    xx, yy = np.meshgrid(x, y, indexing='ij')

    # Fit the neighborhood of the maximum with a Gaussian to calculate the offset
    popt, _ = optimize.curve_fit(
        fitting.twoD_Gaussian, (xx, yy), ccorr[xlo:xhi, ylo:yhi].ravel(), p0=initial_guess
    )
    # Return the RA and DEC shift, in pixels
    xoff = 1 - (ccorr.shape[0] % 2)  # Need to add 1 for even shaped array
    yoff = 1 - (ccorr.shape[1] % 2)  # Need to add 1 for even shaped array
    return xoff + popt[1] - ccorr.shape[0]//2, yoff + popt[2] - ccorr.shape[1]//2