Source code for pypeit.scripts.rectify_2dspec

"""
Create a FITS file with rectified 2D spectra for all slits/orders.

.. include common links, assuming the primary doc root is up one directory
.. include:: ../include/links.rst
"""

import numpy as np
from pathlib import Path

import matplotlib.pyplot as plt

from pypeit.scripts import scriptbase
from pypeit import spec2dobj
from pypeit import specobjs
from pypeit import log
from pypeit.core import coadd
from pypeit.core.wavecal import wvutils
from pypeit.core.moment import moment1d
from astropy.io import fits

from IPython import embed


class Rectify2DSpec(scriptbase.ScriptBase):
    """
    Script that resamples the sky-subtracted 2D spectrum of every slit/order
    in a spec2d file onto one common wavelength grid and writes the result,
    one image extension per detector, to a new FITS file.
    """

    @classmethod
    def get_parser(cls, width=None):
        """
        Construct the command-line parser for this script.

        Args:
            width (:obj:`int`, optional):
                Passed through unchanged to the base-class ``get_parser``.

        Returns:
            The parser returned by the base class, with the ``files``,
            ``--no_rot``, ``--embed`` and ``--try_old`` arguments added.
        """
        parser = super().get_parser(description='Create an FITS file with rectified '
                                                '2D sky-subtracted spectra for all slits/orders.',
                                    width=width)
        parser.add_argument('files', type=str, nargs='*', help='PypeIt spec2d file(s)')
        parser.add_argument('--no_rot', default=False, action='store_true',
                            help='Do not rotate the rectified image to have wavelength '
                                 'on the x-axis.')
        parser.add_argument('--embed', default=False, action='store_true',
                            help='Embed in IPython shell in each detector loop, i.e., before saving to disk.')
        parser.add_argument('--try_old', default=False, action='store_true',
                            help='Attempt to load old datamodel versions. A crash may ensue..')
        return parser

    @staticmethod
    def main(args):
        """
        Rectify every detector of every input spec2d file and write the result.

        For each file, and for each detector listed in its ``ALLSPEC2D_DETS``
        header card:

            - wavelength samples are gathered, either from the boxcar
              extractions in the companion spec1d file (when that file
              exists) or from three apertures per slit measured on the
              spec2d wavelength image;
            - a common wavelength grid is built from those samples;
            - each slit/order that has pixels in the slit mask is rectified
              with ``coadd.compute_coadd2d`` and placed side by side, with
              padding, in a single image;
            - the image is stored in an ``ImageHDU`` with simple linear WCS
              cards.

        The output is written next to the input, with ``spec2d`` in the file
        name replaced by ``rectified_spec2d``; an existing output file is
        overwritten.

        Args:
            args (`argparse.Namespace`_):
                Parsed command-line arguments.  The attributes read are
                ``files``, ``try_old``, ``embed`` and ``no_rot``.
        """
        chk_version = not args.try_old

        for spec2file in args.files:
            log.info(f'Processing file: {spec2file}')

            # Get list of detectors
            hdr = fits.getheader(spec2file)
            detnames = hdr['HIERARCH ALLSPEC2D_DETS'].split(',')

            # Derive the companion/output names from the *file name only*, so
            # that a directory component containing "spec2d" is left alone.
            spec2path = Path(spec2file)
            spec1dfile = spec2path.with_name(spec2path.name.replace('spec2d', 'spec1d'))
            out_name = spec2path.name.replace('spec2d', 'rectified_spec2d')
            if out_name == spec2path.name:
                # No "spec2d" in the name: never write on top of the input file
                out_name = f'rectified_{out_name}'
            out_file = str(spec2path.with_name(out_name))

            # Do we have the spec1d file?  Read it once, not once per detector.
            # The inequality guards against a name without "spec2d", for which
            # the "spec1d" path would be the spec2d file itself.
            sobjs = None
            if spec1dfile != spec2path and spec1dfile.is_file():
                sobjs = specobjs.SpecObjs.from_fitsfile(str(spec1dfile), chk_version=False)

            # Empty primary HDU
            hdu_list = [fits.PrimaryHDU()]

            for detname in detnames:
                log.info(f'DETECTOR: {detname}')
                spec2d = spec2dobj.Spec2DObj.from_file(spec2file, detname,
                                                       chk_version=chk_version)
                slitmask = spec2d.slits.slit_img(flexure=spec2d.sci_spat_flexure)
                slit_ids = spec2d.slits.spat_id
                # this is just to print useful info in the terminal
                slitord_ids = spec2d.slits.slitord_id

                # Collect the wavelength samples used to build the wave grid
                waves, gpms = [], []
                if sobjs is not None:
                    for sobj in sobjs[sobjs.DET == detname]:
                        # we use the BOX extraction here for simplicity
                        if sobj.has_box_ext() and np.any(sobj.BOX_MASK):
                            waves.append(sobj.BOX_WAVE)
                            gpms.append(sobj.BOX_MASK)
                else:
                    # extract the wave and gpm from the spec2d object directly
                    # (taken from coadd2d)
                    slits_left, slits_right, _ = spec2d.slits.select_edges()
                    row = np.arange(slits_left.shape[0])
                    box_radius = 3  # pixels
                    # Loop on the slits
                    for kk, spat_id in enumerate(slit_ids):
                        slit_waveimg = spec2d.waveimg * (slitmask == spat_id)
                        # Create apertures at 5%, 50%, and 95% of the slit width to
                        # cover full range of wavelengths on this slit
                        trace_spat = slits_left[:, kk][:, np.newaxis] \
                            + np.outer(slits_right[:, kk] - slits_left[:, kk],
                                       [0.05, 0.5, 0.95])
                        box_denom = moment1d(slit_waveimg > 0.0, trace_spat,
                                             2 * box_radius, row=row)[0]
                        # Adding (box_denom == 0) turns a 0/0 into 0/1
                        wave_box = moment1d(slit_waveimg, trace_spat,
                                            2 * box_radius, row=row)[0] \
                            / (box_denom + (box_denom == 0.0))
                        gpm_box = box_denom > 0.
                        # Keep only the apertures that have at least one good sample
                        for wave, gpm in zip(wave_box.T, gpm_box.T):
                            if np.any(gpm):
                                waves.append(wave)
                                gpms.append((wave > 0.) & gpm)

                if len(waves) == 0:
                    log.warning(f'There is a problem with the wavelengths on det {detname}. '
                                'The RECTIFIED 2D spectral image will not be created.')
                    continue

                positive_waves = spec2d.waveimg[spec2d.waveimg > 0]
                wmax = np.ceil(positive_waves.max())
                wmin = np.floor(positive_waves.min())
                wave_grid, wave_grid_mid, dsamp = wvutils.get_wave_grid(waves=waves, gpms=gpms,
                                                                        wave_grid_min=wmin,
                                                                        wave_grid_max=wmax)

                # Loop over slits and rectify each one
                good_pix = spec2d.bpmmask.mask == 0
                imgrect_list = []
                for slitidx, slit_id in enumerate(slit_ids):
                    this_mask = (slitmask == slit_id)
                    slitord_id = slitord_ids[slitidx]
                    # check if this slit was masked. If so, skip it.
                    if not np.any(this_mask):
                        log.warning(f'Slit/order {slitord_id} on {detname} is fully masked. Skipping it.')
                        continue
                    log.info(f'Rectifying slit/order {slitord_id}')
                    slit_cen = spec2d.slits.center[:, slitidx]
                    imgrect_dict = coadd.compute_coadd2d([slit_cen], [spec2d.sciimg],
                                                         [spec2d.ivarmodel], [spec2d.skymodel],
                                                         [good_pix], [this_mask],
                                                         [spec2d.waveimg], wave_grid)
                    imgrect_list.append(imgrect_dict)

                # Get dimensions for each slit
                pad = 10  # pixels to pad on each side
                nspat_vec = np.array([cdict['nspat'] for cdict in imgrect_list])
                nspec_rect = len(wave_grid_mid)
                # NOTE(review): nslits presumably counts skipped slits as well, in
                # which case the image carries some unused trailing columns.
                nspat_rect = int(np.sum(nspat_vec) + (spec2d.slits.nslits + 1) * pad)

                # Initialize output array
                # this is the rectified image on a common wavelength grid for all slits
                image_rect = np.zeros((nspec_rect, nspat_rect))
                # the outputs below are not really used, but are useful for
                # double-checking (their names are advertised by --embed)
                # this is the wave image created by placing the rectified wavelengths
                # for each slit in their proper location in the rectified image
                waveimg_rect = np.zeros((nspec_rect, nspat_rect))
                # this is the common wave image for the whole rectified image, created
                # by repeating the wavelength grid for each spatial pixel
                wavegrid_image = np.repeat(wave_grid_mid[:, np.newaxis], nspat_rect, axis=1)

                # Loop over the rectified slits and place them in the output arrays
                _wave_grid_mid = np.round(wave_grid_mid, 4)
                spat_left = pad
                for imgrect_dict, nspat in zip(imgrect_list, nspat_vec):
                    spat_right = spat_left + nspat
                    # Get wavelength array for this slit
                    wave_slit = imgrect_dict['wave_mid']
                    nspec_slit = len(wave_slit)
                    # get the spectral pixel where the slit should start/stop; values
                    # are matched after rounding to 4 decimals
                    _wave_slit = np.round(wave_slit, 4)
                    wstart = np.where(_wave_grid_mid == _wave_slit[0])[0][0]
                    wend = np.where(_wave_grid_mid == _wave_slit[-1])[0][0] + 1
                    # Place the rectified data for this slit in the output arrays
                    image_rect[wstart:wend, spat_left:spat_right] \
                        = imgrect_dict['imgminsky'][:nspec_slit, :nspat]
                    waveimg_rect[wstart:wend, spat_left:spat_right] = wave_slit[:, np.newaxis]
                    spat_left = spat_right + pad

                if args.embed:
                    embed(header=">>> Rectify2DSpec: useful variables are image_rect, waveimg_rect, wavegrid_image")

                # Rotate if desired, so that wavelength runs along the first FITS axis
                rotate = not args.no_rot
                if rotate:
                    image_rect = np.rot90(image_rect)
                wave_axis, spat_axis = (1, 2) if rotate else (2, 1)

                # Create HDU with WCS header
                hdu = fits.ImageHDU(image_rect, name=detname)
                # (CTYPE, CUNIT, CDELT, CRPIX, CRVAL) for each axis.  FITS pixel
                # coordinates are 1-based, so the first wavelength sample,
                # wave_grid_mid[0], sits at reference pixel 1.
                wcs_cards = {
                    wave_axis: ('LAMBDA ', 'Angstrom', dsamp, 1, wave_grid_mid[0]),
                    spat_axis: ('LINEAR ', 'pixel', 1., 0, 0.),
                }
                for axis in (1, 2):
                    ctype, cunit, cdelt, crpix, crval = wcs_cards[axis]
                    hdu.header[f'CTYPE{axis}'] = ctype
                    hdu.header[f'CUNIT{axis}'] = cunit
                    hdu.header[f'CDELT{axis}'] = cdelt
                    hdu.header[f'CRPIX{axis}'] = crpix
                    hdu.header[f'CRVAL{axis}'] = crval

                # Add other keywords
                hdu.header['NSPEC'] = nspec_rect
                hdu.header['NSPAT'] = nspat_rect
                hdu.header['WAVEMIN'] = wave_grid_mid[0]
                hdu.header['WAVEMAX'] = wave_grid_mid[-1]
                hdu.header['DETNAME'] = detname
                hdu.header['PIPELINE'] = hdr['PIPELINE']
                hdu.header['PYPELINE'] = hdr['PYPELINE']
                hdu.header['PYP_SPEC'] = hdr['PYP_SPEC']

                hdu_list.append(hdu)

            # Write to file
            hdulist = fits.HDUList(hdu_list)
            hdulist.writeto(out_file, overwrite=True)
            log.info(f'Rectified images saved to {out_file}')