# Source code for pypeit.core.wavecal.kdtree_generator

"""This script is used to generate the KD Tree that is needed for
the kdtree pattern matching wavelength calibration algorithm. At
present, this method is only used for calibrating ThAr lamps.

You should not run this script unless you know what you're doing,
since you could mess up the ThAr patterns that are used in the
wavelength calibration routine. This script should not be called
from within PypeIt - it should be run as a standalone script, and
its only purpose is to generate a KD Tree with the desired patterns.
"""

# TODO: Is the statement below true?  It seems like this *can be* used, even if
# it's rare.
# NOTE: No longer used.  Use KD tree in scikit-learn:
#   https://scikit-learn.org/stable/modules/generated/sklearn.neighbors.KDTree.html
# See benchmarks here:
#   https://jakevdp.github.io/blog/2013/04/29/benchmarking-nearest-neighbor-searches-in-python/

# TODO: We should not be using pickle.  Consider using the KDTree generator in
# scikit-learn and a skops.io serialization that allows users to validate the
# contents of the file.  See: https://scikit-learn.org/stable/model_persistence.html
import pickle

from astropy.table import vstack
import numpy as np
from scipy.spatial import cKDTree

from pypeit import dataPaths
from pypeit import log
from pypeit.core.wavecal import waveio

def trigon(linelist, numsrch, maxlin):
    """
    Generate all trigon patterns from a sorted list of line positions.

    A trigon is made of a left anchor, a right anchor, and one line lying
    strictly between them.  Every interior line is used.  Each pattern is
    described by the position of the interior line relative to the two
    anchors, ``(x - l) / (r - l)``.

    Parameters
    ----------
    linelist : ndarray
        Positions of detected lines (pixels) or of line-list wavelengths
        (Angstroms).  Must be sorted in ascending order.
    numsrch : int
        Number of candidate right anchors considered for each left anchor.
        For left anchor ``l`` the right anchor runs over indices
        ``l + 2, ..., l + 1 + numsrch`` (truncated at the end of
        ``linelist``).  For example, with ``numsrch = 2`` and lines numbered
        1 2 3 4, the patterns anchored on line #1 are::

            1 2 3   (line #3 is the right anchor)
            1 2 4   (line #4 is the right anchor)
            1 3 4   (line #4 is the right anchor)

    maxlin : float
        Maximum anchor separation (pixels for detections, Angstroms for a
        line list) over which the wavelength solution can be considered
        linear.  Anchor pairs separated by more than this are skipped.

    Returns
    -------
    pattern : ndarray, shape (npattern, 1)
        Relative position of the interior line of each pattern.
    index : ndarray of uint64, shape (npattern, 3)
        Indices into ``linelist`` of the (left, interior, right) lines of
        each pattern.
    """
    nptn = 3  # Number of lines in each pattern
    nlines = linelist.shape[0]

    # Collect every (left, right) anchor pair whose separation is small
    # enough for the solution to be treated as linear.
    anchors = []
    for left in range(nlines - nptn + 1):
        stop = min(left + nptn - 1 + numsrch, nlines)
        for right in range(left + nptn - 1, stop):
            if (linelist[right] - linelist[left]) > maxlin:
                continue
            anchors.append((left, right))

    # Each anchor pair yields one pattern per line strictly between them,
    # so the output arrays can be sized up front.
    npattern = sum(right - left - 1 for left, right in anchors)
    index = np.zeros((npattern, nptn), dtype=np.uint64)
    pattern = np.zeros((npattern, nptn - 2), dtype=float)

    row = 0
    for left, right in anchors:
        base = linelist[left]
        span = linelist[right] - base
        for mid in range(left + 1, right):
            index[row] = (left, mid, right)
            pattern[row, 0] = (linelist[mid] - base) / span
            row += 1
    return pattern, index
def tetragon(linelist, numsrch, maxlin):
    """
    Generate a series of tetragon patterns, given an input list of
    detections or lines from a linelist.

    A tetragon is made of a left anchor, a right anchor, and two lines lying
    strictly between them.  Every pair of interior lines is used.  Each
    pattern is described by the positions of the two interior lines relative
    to the anchors, ``(x - l) / (r - l)``.

    Parameters
    ----------
    linelist : ndarray
        Positions of detected lines (pixels) or of line-list wavelengths
        (Angstroms).  Must be sorted in ascending order.
    numsrch : int
        Number of candidate right anchors considered for each left anchor.
        For left anchor ``l`` the right anchor runs over indices
        ``l + 3, ..., l + 2 + numsrch`` (truncated at the end of
        ``linelist``).  For example, with ``numsrch = 2`` and five lines
        (called 1 2 3 4 5), the patterns anchored on line #1 are::

            1 2 3 4   (line #4 is the right anchor)
            1 2 3 5   (line #5 is the right anchor)
            1 2 4 5   (line #5 is the right anchor)
            1 3 4 5   (line #5 is the right anchor)

    maxlin : float
        Maximum anchor separation (pixels for detections, Angstroms for a
        line list) over which the wavelength solution can be considered
        linear.  Anchor pairs separated by more than this are skipped.

    Returns
    -------
    pattern : ndarray, shape (npattern, 2)
        Relative positions of the two interior lines of each pattern.
    index : ndarray of uint64, shape (npattern, 4)
        Indices into ``linelist`` of the (left, interior, interior, right)
        lines of each pattern.
    """
    nptn = 4  # Number of lines used to create a pattern
    sz_l = linelist.shape[0]

    # Find the (left, right) anchor pairs over which the solution is linear
    anchors = []
    for left in range(sz_l - nptn + 1):
        nup = min(left + nptn - 1 + numsrch, sz_l)
        for right in range(left + nptn - 1, nup):
            if (linelist[right] - linelist[left]) > maxlin:
                continue
            anchors.append((left, right))

    # An anchor pair with n interior lines yields C(n, 2) patterns
    cnt = 0
    for left, right in anchors:
        n = right - left - 1
        cnt += n * (n - 1) // 2
    index = np.zeros((cnt, nptn), dtype=np.uint64)
    pattern = np.zeros((cnt, nptn - 2), dtype=float)

    # Generate the patterns.  The first interior line may go up to
    # right - 2 so that the adjacent pair (right - 2, right - 1) -- and, for
    # the narrowest window, the only possible pair -- is included.
    cnt = 0
    for left, right in anchors:
        span = linelist[right] - linelist[left]
        for x in range(left + 1, right - 1):
            for xx in range(x + 1, right):
                index[cnt] = (left, x, xx, right)
                pattern[cnt, 0] = (linelist[x] - linelist[left]) / span
                pattern[cnt, 1] = (linelist[xx] - linelist[left]) / span
                cnt += 1
    return pattern, index
def pentagon(linelist, numsrch, maxlin):
    """
    Generate a series of pentagon patterns, given an input list of
    detections or lines from a linelist.

    A pentagon is made of a left anchor, a right anchor, and three lines
    lying strictly between them.  Every triple of interior lines is used.
    Each pattern is described by the positions of the three interior lines
    relative to the anchors, ``(x - l) / (r - l)``.

    Parameters
    ----------
    linelist : ndarray
        Positions of detected lines (pixels) or of line-list wavelengths
        (Angstroms).  Must be sorted in ascending order.
    numsrch : int
        Number of candidate right anchors considered for each left anchor.
        For left anchor ``l`` the right anchor runs over indices
        ``l + 4, ..., l + 3 + numsrch`` (truncated at the end of
        ``linelist``).
    maxlin : float
        Maximum anchor separation (pixels for detections, Angstroms for a
        line list) over which the wavelength solution can be considered
        linear.  Anchor pairs separated by more than this are skipped.

    Returns
    -------
    pattern : ndarray, shape (npattern, 3)
        Relative positions of the three interior lines of each pattern.
    index : ndarray of uint64, shape (npattern, 5)
        Indices into ``linelist`` of the five lines of each pattern
        (left anchor, three interior lines, right anchor).
    """
    nptn = 5  # Number of lines used to create a pattern
    sz_l = linelist.shape[0]

    # Find the (left, right) anchor pairs over which the solution is linear
    anchors = []
    for left in range(sz_l - nptn + 1):
        nup = min(left + nptn - 1 + numsrch, sz_l)
        for right in range(left + nptn - 1, nup):
            if (linelist[right] - linelist[left]) > maxlin:
                continue
            anchors.append((left, right))

    # An anchor pair with n interior lines yields C(n, 3) patterns
    cnt = 0
    for left, right in anchors:
        n = right - left - 1
        cnt += n * (n - 1) * (n - 2) // 6
    index = np.zeros((cnt, nptn), dtype=np.uint64)
    pattern = np.zeros((cnt, nptn - 2), dtype=float)

    # Generate the patterns.  Each interior index stops just early enough
    # that the later interior indices can still be placed before `right`,
    # so every triple (including consecutive lines) is produced.
    cnt = 0
    for left, right in anchors:
        span = linelist[right] - linelist[left]
        for x in range(left + 1, right - 2):
            for xx in range(x + 1, right - 1):
                for xxx in range(xx + 1, right):
                    index[cnt] = (left, x, xx, xxx, right)
                    pattern[cnt, 0] = (linelist[x] - linelist[left]) / span
                    pattern[cnt, 1] = (linelist[xx] - linelist[left]) / span
                    pattern[cnt, 2] = (linelist[xxx] - linelist[left]) / span
                    cnt += 1
    return pattern, index
def hexagon(linelist, numsrch, maxlin):
    """
    Generate a series of hexagon patterns, given an input list of
    detections or lines from a linelist.

    A hexagon is made of a left anchor, a right anchor, and four lines lying
    strictly between them.  Every set of four interior lines is used.  Each
    pattern is described by the positions of the four interior lines
    relative to the anchors, ``(x - l) / (r - l)``.

    Parameters
    ----------
    linelist : ndarray
        Positions of detected lines (pixels) or of line-list wavelengths
        (Angstroms).  Must be sorted in ascending order.
    numsrch : int
        Number of candidate right anchors considered for each left anchor.
        For left anchor ``l`` the right anchor runs over indices
        ``l + 5, ..., l + 4 + numsrch`` (truncated at the end of
        ``linelist``).
    maxlin : float
        Maximum anchor separation (pixels for detections, Angstroms for a
        line list) over which the wavelength solution can be considered
        linear.  Anchor pairs separated by more than this are skipped.

    Returns
    -------
    pattern : ndarray, shape (npattern, 4)
        Relative positions of the four interior lines of each pattern.
    index : ndarray of uint64, shape (npattern, 6)
        Indices into ``linelist`` of the six lines of each pattern
        (left anchor, four interior lines, right anchor).
    """
    nptn = 6  # Number of lines used to create a pattern
    sz_l = linelist.shape[0]

    # Find the (left, right) anchor pairs over which the solution is linear
    anchors = []
    for left in range(sz_l - nptn + 1):
        nup = min(left + nptn - 1 + numsrch, sz_l)
        for right in range(left + nptn - 1, nup):
            if (linelist[right] - linelist[left]) > maxlin:
                continue
            anchors.append((left, right))

    # An anchor pair with n interior lines yields C(n, 4) patterns
    cnt = 0
    for left, right in anchors:
        n = right - left - 1
        cnt += n * (n - 1) * (n - 2) * (n - 3) // 24
    index = np.zeros((cnt, nptn), dtype=np.uint64)
    pattern = np.zeros((cnt, nptn - 2), dtype=float)

    # Generate the patterns.  Each interior index stops just early enough
    # that the later interior indices can still be placed before `right`,
    # so every 4-subset (including consecutive lines) is produced.
    cnt = 0
    for left, right in anchors:
        span = linelist[right] - linelist[left]
        for x in range(left + 1, right - 3):
            for xx in range(x + 1, right - 2):
                for xxx in range(xx + 1, right - 1):
                    for xxxx in range(xxx + 1, right):
                        index[cnt] = (left, x, xx, xxx, xxxx, right)
                        pattern[cnt, 0] = (linelist[x] - linelist[left]) / span
                        pattern[cnt, 1] = (linelist[xx] - linelist[left]) / span
                        pattern[cnt, 2] = (linelist[xxx] - linelist[left]) / span
                        pattern[cnt, 3] = (linelist[xxxx] - linelist[left]) / span
                        cnt += 1
    return pattern, index
def load_tree(polygon=4, numsearch=20):
    """
    Load a KDTree of ThAr patterns that is stored on disk.

    If either the pickled tree or its index file cannot be found, the tree
    is built (and written to disk) by calling :func:`main`.

    Parameters
    ----------
    polygon : int
        Number of sides to the polygon used in pattern matching:

            - polygon=3 --> trigon (two anchor lines and one floating line)
            - polygon=4 --> tetragon (two anchor lines and two floating lines)
            - polygon=5 --> pentagon (two anchor lines and three floating lines)
            - polygon=6 --> hexagon (two anchor lines and four floating lines)

    numsearch : int
        Number of candidate right-anchor lines considered for each left
        anchor when the patterns are generated.  For example, if numsearch
        is 2, then for a trigon the following patterns will be generated
        (assuming line #1 is the left anchor):

            - 1 2 3 (in this case line #3 is the right anchor)
            - 1 2 4 (in this case line #4 is the right anchor)
            - 1 3 4 (in this case line #4 is the right anchor)

    Returns
    -------
    file_load : KDTree instance
        The KDTree containing the patterns
    index : ndarray
        For each pattern in the KDTree, this array stores the corresponding
        index in the linelist

    Raises
    ------
    ValueError
        If ``polygon`` is not one of 3, 4, 5 or 6.
    """
    # main() can only build trees for these polygons; fail early with a clear
    # message rather than trying to unpack main()'s ``None`` return.
    if polygon not in (3, 4, 5, 6):
        raise ValueError(f'polygon must be 3, 4, 5 or 6; got {polygon}')

    filename = dataPaths.linelist.get_file_path(
        f'ThAr_patterns_poly{polygon}_search{numsearch}.kdtree')
    fileindx = dataPaths.linelist.get_file_path(
        f'ThAr_patterns_poly{polygon}_search{numsearch}.index.npy')
    try:
        # Binary mode must not be given an ``encoding`` argument: open()
        # raises ValueError if it is, which would bypass the handler below.
        with open(filename, "rb") as f_obj:
            # TODO: We should not be using pickle.  Only load trusted files.
            file_load = pickle.load(f_obj)
        index = np.load(fileindx)
    except FileNotFoundError:
        log.info(
            'The requested KDTree was not found on disk\nplease be patient while the ThAr KDTree '
            'is built and saved to disk.'
        )
        file_load, index = main(
            polygon, numsearch=numsearch, verbose=True, ret_treeindx=True, outname=filename
        )
    return file_load, index
def main(polygon, numsearch=8, maxlinear=100.0, use_unknowns=True, leafsize=30,
         verbose=False, ret_treeindx=False, outname=None,
         ):
    """Driving method for generating the KD Tree

    Parameters
    ----------
    polygon : int
        Number of sides to the polygon used in pattern matching (3 to 6).
    numsearch : int
        Number of candidate right-anchor lines considered for each left
        anchor when deriving patterns.
    maxlinear : float
        Over how many Angstroms is the solution deemed to be linear
    use_unknowns : bool
        Include unknown lines in the wavelength calibration (these may
        arise from lines other than Th I/II and Ar I/II)
    leafsize : int
        The leaf size of the tree
    verbose : bool
        If True, print which pattern type is being generated (or why none
        could be generated).
    ret_treeindx : bool
        If True, return the tree and the index array.
    outname : str or path-like, optional
        File to which the pickled tree is written.  Defaults to the
        ``ThAr_patterns_poly{polygon}_search{numsearch}.kdtree`` file in the
        line-list data directory.  The index array is saved next to it, with
        ``.kdtree`` replaced by ``.index`` (``np.save`` appends ``.npy``).

    Returns
    -------
    tree : `scipy.spatial.cKDTree`
        The tree of patterns (only if ``ret_treeindx`` is True).
    index : ndarray
        Line-list indices of each pattern (only if ``ret_treeindx`` is True).

    ``None`` is returned when ``ret_treeindx`` is False, or when ``polygon``
    is not supported.
    """
    # Load the ThAr linelist
    line_lists_all = waveio.load_line_lists(['ThAr'])
    line_lists = line_lists_all[np.where(line_lists_all['ion'] != 'UNKNWN')]
    unknwns = line_lists_all[np.where(line_lists_all['ion'] == 'UNKNWN')]
    if use_unknowns:
        tot_list = vstack([line_lists, unknwns])
    else:
        tot_list = line_lists
    wvdata = np.array(tot_list['wave'].data)  # Removes mask if any
    # The pattern generators require the lines in ascending order
    wvdata.sort()
    # Alternative (unused) line selection:
    # NIST_lines = (line_lists_all['NIST'] > 0) & (np.char.find(line_lists_all['Source'].data, 'MURPHY') >= 0)
    # wvdata = line_lists_all['wave'].data[NIST_lines]
    # wvdata.sort()

    # Map the number of polygon sides to its name and pattern generator
    generators = {
        3: ('trigon', trigon),
        4: ('tetragon', tetragon),
        5: ('pentagon', pentagon),
        6: ('hexagon', hexagon),
    }
    if polygon not in generators:
        if verbose:
            print("Patterns can only be generated with 3 <= polygon <= 6")
        return None
    shape_name, generate = generators[polygon]
    if verbose:
        print(f"Generating patterns for a {shape_name}")
    pattern, index = generate(wvdata, numsearch, maxlinear)

    if outname is None:
        outname = dataPaths.linelist.get_file_path(f'ThAr_patterns_poly{polygon}_search{numsearch}.kdtree')
    # outname may be a path-like object rather than a str (e.g. when passed
    # in from load_tree -- TODO confirm what get_file_path returns); str.replace
    # and the ``:s`` format spec below both require a str.
    outname = str(outname)
    outindx = outname.replace('.kdtree', '.index')
    print("Generating Tree")
    tree = cKDTree(pattern, leafsize=leafsize)
    print("Saving Tree")
    # Close (and flush) the output file deterministically
    with open(outname, 'wb') as f_obj:
        pickle.dump(tree, f_obj)
    print("Written KD Tree file:\n{0:s}".format(outname))
    np.save(outindx, index)
    print("Written index file:\n{0:s}".format(outindx))
    if ret_treeindx:
        return tree, index
    return None
# Test
if __name__ == '__main__':
    # Set the number of sides to the polygon.  Some examples:
    #
    # =========================================================
    # A trigon (polygon=3) contains (1) a starting point (s),
    # (2) an end point (e), and (3) something in between (b):
    #
    #   |      |          |
    #   |      |          |
    #   |      |          |
    #   s      b          e
    #
    # Then, the value (b-s)/(e-s) is in the same coordinate system
    # for both detlines and linelist.
    #
    # =========================================================
    # A tetragon (polygon=4) contains (1) a left line (l),
    # (2) a right line (r), and (3) two lines in between (a, b):
    #
    #   |      |     |        |
    #   |      |     |        |
    #   |      |     |        |
    #   l      a     b        r
    #
    # Then, the values (a-l)/(r-l) and (b-l)/(r-l) are in the same
    # coordinate system for both detlines and linelist.
    polygon = 4
    numsearch = 10
    main(polygon, numsearch=numsearch, verbose=True)