Source code for meer21cm.dataanalysis

"""
This module contains the base class for reading and visualizing the map data cube.

Note that, the defined class, :py:class:`Specification`, is the base class for reading and visualizing the map data cube.
It is typically used as a base class for other classes that inherit from it, and not used directly.
"""


import numpy as np
from astropy.io import fits
from .util import (
    check_unit_equiv,
    freq_to_redshift,
    f_21,
    center_to_edges,
    find_ch_id,
    tagging,
    find_property_with_tags,
    angle_in_range,
    create_wcs,
    tightest_ra_interval,
    which_ra_range_is_tighter,
    real_dtype_from_array,
)
from astropy import constants, units
from .io import (
    cal_freq,
    read_map,
    filter_incomplete_los,
    read_pickle,
)
from astropy.wcs.utils import proj_plane_pixel_area
from itertools import chain
from . import telescope
from .telescope import *
from .skymap import HealpixSkyMap, WcsSkyMap
import meer21cm
import logging
import numbers
import inspect

logger = logging.getLogger(__name__)

default_data_dir = meer21cm.__file__.rsplit("/", 1)[0] + "/data/"


def _validate_precision_flag(value):
    if not isinstance(value, bool):
        raise TypeError("precision must be bool: True (float64) or False (float32)")
    return value


def _validate_batch_number(value):
    if type(value) is not int or value < 1:
        raise TypeError("batch_number must be a positive integer")
    return value


default_nu = {
    "meerkat_L": cal_freq(np.arange(4096) + 1, band="L"),
    "meerkat_UHF": cal_freq(np.arange(4096) + 1, band="UHF"),
    "meerklass_2021_L": cal_freq(np.arange(4096) + 1, band="L"),
    "meerklass_2019_L": cal_freq(np.arange(4096) + 1, band="L"),
    "meerklass_UHF": cal_freq(np.arange(4096) + 1, band="UHF"),
}


[docs] class Specification: """ Base class for reading and visualizing the map data cube. Parameters ---------- nu: np.ndarray, default None The frequencies of the survey in Hz. wproj: :py:class:`astropy.wcs.WCS`, default None The WCS object for the map. num_pix_x: int, default None The number of pixels in the first axis of the map data (WCS only). num_pix_y: int, default None The number of pixels in the second axis of the map data (WCS only). hp_nside: int, default None HEALPix :math:`N_{side}`. Implies sparse ``(n_pix, n_chan)`` layout via :class:`HealpixSkyMap`. Incompatible with predefined ``survey``/``band`` maps. Mutually exclusive with passing ``skymap``. healpix_pixel_id: array-like, default None Optional explicit sparse pixel indices at ``hp_nside`` (RING, ``nest=False``). If omitted, pixels are derived from ``ra_range`` and ``dec_range``. map_has_sampling: np.ndarray, default None A binary window for whether a pixel has been sampled. sigma_beam_ch: np.ndarray, default None The beam size parameter for each frequency channel. beam_unit: :py:class:`astropy.units.Unit`, default :py:class:`astropy.units.deg` The unit of the beam size parameter. map_unit: :py:class:`astropy.units.Unit`, default :py:class:`astropy.units.K` The unit of the map data. map_file: str, default None The file path of the map data. Supports automatic reading of the MeerKLASS L-band data. For UHF data use `pickle_file` for the file path of the pickle file. counts_file: str, default None The file path of the hit counts data. Supports automatic reading of the MeerKLASS L-band data. For UHF data use `pickle_file` for the file path of the pickle file. pickle_file: str, default None The file path of the pickle file. Supports automatic reading of the MeerKLASS UHF data. los_axis: int, default -1 The axis of the map data that corresponds to the line of sight. **Warning**: Tranposing the data to align the los axis is not properly taken care of in the code. If your map los axis is not the last axis, it is recommended to manually transpose the data so that the los axis is the last axis. nu_min: float, default None, The minimum frequency of the survey in Hz. Data below this frequency will be clipped. nu_max: float, default None, The maximum frequency of the survey in Hz. Data above this frequency will be clipped. filter_map_los: bool, default True Whether to filter the map data along the line of sight. See :meth:`meer21cm.io.filter_incomplete_los` gal_file: str, default None, The file path of the galaxy catalogue. weighting: str, default "counts" The weighting scheme for the map data. ra_range: tuple, default (0, 360) The range of the right ascension of the map data in degrees. Data outside this range will be masked. dec_range: tuple, default (-90, 90) The range of the declination of the map data in degrees. Data outside this range will be masked. beam_model: str, default "gaussian" The shape of the beam. data: np.ndarray, default None The map data. weights_map_pixel: np.ndarray, default None The weights per map pixel. counts: np.ndarray, default None The number of hits per pixel for the map data. survey: str, default "" The survey name. band: str, default "" The band of the survey. z_interp_max: float, default 6.0 The maximum redshift to interpolate the redshift as a function of comoving distance. See :meth:`meer21cm.dataanalysis.Specification.get_z_as_func_of_comov_dist`. soft_filter_los: bool, default True If `filter_map_los` is True, whether to use a soft criterion. If False, any line of sight that is not 100% sampled will be removed. If True, the maximum sampling fraction of the map cube is calculated and used as the criterion. See :meth:`meer21cm.io.filter_incomplete_los`. filter_los_threshold: float, default None If given, instead of filtering out incomplete los by checking the maximum sampling fraction along the los, a fixed threshold is used to filter out incomplete los. See :meth:`meer21cm.io.filter_incomplete_los`. data_column: str, default "map" The column name of the map data. counts_column: str, default "hit" The column name of the number of sampling for each pixel. freq_column: str, default "freq" The column name of the frequencies of each channel in the data. wcs_column: str, default "wcs" The column name of the :class:`astropy.wcs.WCS` object for the map. auto_set_radecnu_bounds: bool, default True If True, :meth:`read_from_fits` and :meth:`read_from_pickle` call :meth:`set_radecnu_bounds_from_map` after loading so ``ra_range``, ``dec_range``, ``nu_min``, and ``nu_max`` match the loaded grid and channels. precision: bool, default True Floating precision selector for core numeric arrays. If True, use double precision (`np.float64`); if False, use single precision (`np.float32`). batch_number: int, default 1 Number of sequential batches used by various routines. A value of 1 means no batching. skymap: :class:`~meer21cm.skymap.SkyMap`, default None Injected angular geometry (:class:`~meer21cm.skymap.WcsSkyMap` or :class:`~meer21cm.skymap.HealpixSkyMap`). Mutually exclusive with ``hp_nside``. b_ell_l_max: int, default 8192 The maximum multipole for the beam window function for healpix map smoothing. """ def __init__( self, nu=None, wproj=None, num_pix_x=None, num_pix_y=None, map_has_sampling=None, sigma_beam_ch=None, beam_unit=units.deg, map_unit=units.K, map_file=None, counts_file=None, pickle_file=None, los_axis=-1, nu_min=None, nu_max=None, filter_map_los=True, gal_file=None, weighting="counts", ra_range=(0, 360), dec_range=(-90, 90), beam_model="gaussian", data=None, weights_map_pixel=None, counts=None, survey="", band="", z_interp_max=6.0, soft_filter_los=True, filter_los_threshold=None, data_column="map", counts_column="hit", freq_column="freq", wcs_column="wcs", auto_set_radecnu_bounds=True, precision=True, batch_number=1, skymap=None, hp_nside=None, healpix_pixel_id=None, b_ell_l_max=8192, **kwparams, ): self.survey = survey self.band = band spec_key = survey + "_" + band if spec_key in default_nu.keys(): logger.info( f"found {spec_key} in predefined settings, using default settings" " and override the following parameters:" " nu, nu_min, nu_max, num_pix_x, num_pix_y, wproj", ) nu = default_nu[spec_key] nu_min = default_nu_min[spec_key] nu_max = default_nu_max[spec_key] num_pix_x = default_num_pix_x[spec_key] num_pix_y = default_num_pix_y[spec_key] wproj = default_wproj[spec_key] if spec_key in default_nu.keys() and hp_nside is not None: raise ValueError( "Predefined survey/band grids are WCS-only; omit hp_nside when using " "survey=... and band=..., or use a non-default survey/band key." ) self.dependency_dict = find_property_with_tags(self) funcs = list(chain.from_iterable(list(self.dependency_dict.values()))) for func_i in np.unique(np.array(funcs)): setattr(self, func_i + "_dep_attr", []) for dep_attr, inp_func in self.dependency_dict.items(): for func in inp_func: old_dict = getattr(self, func + "_dep_attr") setattr( self, func + "_dep_attr", old_dict + [ "_" + dep_attr, ], ) self.map_file = map_file self.counts_file = counts_file self.pickle_file = pickle_file self.los_axis = los_axis self._precision = _validate_precision_flag(precision) self._batch_number = _validate_batch_number(batch_number) sel_nu = True if nu is None: nu = np.array([f_21 - 1, f_21]) sel_nu = False if nu_min is None: nu_min = -np.inf if nu_max is None: nu_max = np.inf nu_sel = (nu > nu_min) * (nu < nu_max) if sel_nu: if nu_sel.sum() == 0: raise ValueError("input nu is not in the range of nu_min and nu_max") self.nu = nu[nu_sel] else: self.nu = nu self.nu_min = nu_min self.nu_max = nu_max self.ra_range = ra_range self.dec_range = dec_range if hp_nside is not None and skymap is not None: raise ValueError("pass only one of skymap or hp_nside.") if skymap is not None and healpix_pixel_id is not None: raise ValueError( "healpix_pixel_id is invalid when passing skymap; set pixel_id on " "HealpixSkyMap instead." ) if healpix_pixel_id is not None and hp_nside is None and skymap is None: raise ValueError( "healpix_pixel_id requires hp_nside or pass skymap=HealpixSkyMap(...)." ) if hp_nside is not None: self.skymap = HealpixSkyMap( hp_nside, pixel_id=None if healpix_pixel_id is None else np.asarray(healpix_pixel_id, dtype=np.int64), ra_range=self.ra_range if healpix_pixel_id is None else None, dec_range=self.dec_range if healpix_pixel_id is None else None, ) elif skymap is not None: self.skymap = skymap else: if num_pix_x is None: num_pix_x = 3 if num_pix_y is None: num_pix_y = 3 if wproj is None: wproj = create_wcs(0.0, 0.0, [num_pix_x, num_pix_y], 1.0) self.skymap = WcsSkyMap( wproj=wproj, num_pix_x=num_pix_x, num_pix_y=num_pix_y, ) self.sigma_beam_ch = sigma_beam_ch self.beam_unit = beam_unit if map_has_sampling is None: map_has_sampling = np.ones( self.skymap.map_shape_template + (len(self.nu),), dtype="bool" ) if self.skymap.format == "wcs": map_has_sampling[0] = False map_has_sampling[-1] = False map_has_sampling[:, 0] = False map_has_sampling[:, -1] = False self.map_has_sampling = map_has_sampling self.map_unit = map_unit self.map_unit_type self.__dict__.update(kwparams) self.filter_map_los = filter_map_los self.soft_filter_los = soft_filter_los self.filter_los_threshold = filter_los_threshold self.gal_file = gal_file self.weighting = weighting self._sigma_beam_ch_in_mpc = None if data is None: data = np.zeros(self.map_has_sampling.shape, dtype=self.real_dtype) self.data = data if weights_map_pixel is None: weights_map_pixel = np.ones( self.map_has_sampling.shape, dtype=self.real_dtype ) if self.skymap.format == "wcs": weights_map_pixel[0] = 0.0 weights_map_pixel[-1] = 0.0 weights_map_pixel[:, 0] = 0.0 weights_map_pixel[:, -1] = 0.0 self.weights_map_pixel = weights_map_pixel if counts is None: counts = np.ones(self.map_has_sampling.shape, dtype=self.real_dtype) self.counts = counts self.trim_map_to_range() self.beam_type = None self.beam_model = beam_model self._beam_image = None self._beam_window_ch = None self._z_as_func_of_comov_dist = None self.z_interp_max = z_interp_max self.data_column = data_column self.counts_column = counts_column self.freq_column = freq_column self.wcs_column = wcs_column self.auto_set_radecnu_bounds = auto_set_radecnu_bounds self.b_ell_l_max = b_ell_l_max def _set_wcs_skymap(self, wproj, num_pix_x, num_pix_y): """Reset the WCS skymap backend while preserving WCS-only behavior.""" self.skymap = WcsSkyMap( wproj=wproj, num_pix_x=num_pix_x, num_pix_y=num_pix_y, )
[docs] def set_radecnu_bounds_from_map(self): """ Set ``ra_range``, ``dec_range``, ``nu_min``, and ``nu_max`` from the loaded ``_ra_map``, ``_dec_map``, and ``nu`` (tight RA interval, declination min/max, frequency min/max). Only consider unmaksed pixels. """ # in case it is not properly initialized, use the full grid if self.W_HI.sum() > 0: ra = self.ra_map[self.W_HI.sum(-1) > 0] dec = self.dec_map[self.W_HI.sum(-1) > 0] else: ra = self.ra_map dec = self.dec_map ra_range = tightest_ra_interval(ra) nu_min = self.nu.min() nu_max = self.nu.max() nu_min = np.max([self.nu_min, nu_min]) - 1 nu_max = np.min([self.nu_max, nu_max]) + 1 dec_min = np.max([self.dec_range[0], dec.min()]) - 1e-5 dec_max = np.min([self.dec_range[1], dec.max()]) + 1e-5 ra_flag = which_ra_range_is_tighter(ra_range, self.ra_range) if ra_flag > 0: ra_range = self.ra_range else: ra_0 = np.max([ra_range[0] - 1e-5, 0]) ra_1 = np.min([ra_range[1] + 1e-5, 360]) ra_range = (ra_0, ra_1) self.dec_range = (dec_min, dec_max) self.nu_min = nu_min self.nu_max = nu_max self.ra_range = ra_range
@property def map_unit_type(self): """ The type of the map unit. If the map unit is temperature, return "T". If the map unit is flux density, return "F". If the map unit is not temperature or flux density, raise an error. """ map_unit = self.map_unit if not check_unit_equiv(map_unit, units.K): if not check_unit_equiv(map_unit, units.Jy): raise ( ValueError, "map unit has be to either temperature or flux density.", ) else: map_unit_type = "F" else: map_unit_type = "T" return map_unit_type
[docs] def clean_cache(self, attr): """ Set the attributes to None. This is used to clear the cache of the attributes. """ for att in attr: if att in self.__dict__.keys(): setattr(self, att, None)
@property def precision(self): """Floating precision flag. True for float64, False for float32.""" return self._precision @property def real_dtype(self): """Active real floating dtype controlled by `precision`.""" return np.float64 if self.precision else np.float32 @property def batch_number(self): """Number of sequential batches used by gridding routines.""" return self._batch_number def _iter_last_axis_batches(self, axis_size): """ Split ``axis_size`` into ``batch_number`` index ranges along one axis. Used for chunked channel-wise convolution / smoothing etc. Always returns at least one non-empty batch so ``batch_number=1`` matches the unified code path used for ``batch_number>1``. """ n = int(axis_size) all_indx = np.arange(n, dtype=int) return [ sel for sel in np.array_split(all_indx, self.batch_number) if sel.size > 0 ] def _iter_field_los_chunks(self, arr): """ Yield ``(los_sel, arr[..., los_sel])`` along the last axis. Works for box fields ``(nx, ny, n_z)``, WCS map cubes ``(nx, ny, n_ch)``, and HEALPix cubes ``(n_pix, n_ch)``. """ for sel in self._iter_last_axis_batches(arr.shape[-1]): yield sel, arr[..., sel] @property def wproj(self): """The WCS projection object for the map geometry.""" if self.skymap.format != "wcs": raise KeyError("wproj is only defined for WCS sky maps.") return self.skymap.wproj @property def num_pix_x(self): """The number of pixels along the first map axis.""" if self.skymap.format != "wcs": raise KeyError("num_pix_x is only defined for WCS sky maps.") return self.skymap.num_pix_x @property def num_pix_y(self): """The number of pixels along the second map axis.""" if self.skymap.format != "wcs": raise KeyError("num_pix_y is only defined for WCS sky maps.") return self.skymap.num_pix_y @property def hp_nside(self): """HEALPix :math:`N_{side}` for HEALPix-backed specifications.""" if self.skymap.format != "healpix": raise KeyError("hp_nside is only defined for HEALPix sky maps.") return self.skymap.hp_nside @property def pixel_id(self): """Sparse HEALPix pixel indices (RING, ``nest=False``).""" if self.skymap.format != "healpix": raise KeyError("pixel_id is only defined for HEALPix sky maps.") return self.skymap.pixel_id @property def beam_type(self): """ The beam type that can be either be isotropic or anisotropic. """ return self._beam_type @beam_type.setter def beam_type(self, value): self._beam_type = value if "beam_dep_attr" in dir(self): self.clean_cache(self.beam_dep_attr) @property def beam_model(self): """ The name of the beam function. """ return self._beam_model @beam_model.setter def beam_model(self, value): beam_func = value + "_beam" if beam_func not in telescope.__dict__.keys(): raise ValueError(f"{value} is not a beam model") self._beam_model = value self.beam_type = getattr(telescope, value + "_beam").tags[0] self._beam_window_ch = None if "beam_dep_attr" in dir(self): self.clean_cache(self.beam_dep_attr) @property def beam_unit(self): """ The unit of input beam size parameter sigma """ return self._beam_unit @beam_unit.setter def beam_unit(self, value): self._beam_unit = value if "beam_dep_attr" in dir(self): self.clean_cache(self.beam_dep_attr) @property def sigma_beam_ch(self): """ The input beam size parameter sigma for each channel. If one number is provided, it will be used for all channels. """ return self._sigma_beam_ch @sigma_beam_ch.setter def sigma_beam_ch(self, value): if isinstance(value, numbers.Number): value = np.ones(self.nu.size, dtype=self.real_dtype) * float(value) elif value is not None: value = np.asarray(value, dtype=self.real_dtype) self._sigma_beam_ch = value self._beam_window_ch = None if "beam_dep_attr" in dir(self): self.clean_cache(self.beam_dep_attr) @property def sigma_beam_in_mpc(self): """ The channel averaged beam size in Mpc """ if self.sigma_beam_ch_in_mpc is None: return None return self.sigma_beam_ch_in_mpc.mean() @property def nu(self): """ The input frequencies of the survey """ return self._nu @nu.setter def nu(self, value): self._nu = np.asarray(value, dtype=self.real_dtype) if "nu_dep_attr" in dir(self): self.clean_cache(self.nu_dep_attr) # nu dependent, but it calculates on the fly # so no need for tags @property def z_ch(self): """ The redshift of each frequency channel """ return freq_to_redshift(self.nu) @property def z(self): """ The effective centre redshift of the frequency range """ return self.z_ch.mean() @property def dvdf_ch(self): """ velocity resolution per unit frequency in each channel, in km/s/Hz """ return (constants.c / self.nu).to("km/s").value @property def vel_resol_ch(self): """ velocity resolution of each channel in km/s """ return self.dvdf_ch * self.freq_resol @property def dvdf(self): """ velocity resolution per unit frequency on average, in km/s/Hz """ return self.dvdf_ch.mean() @property def vel_resol(self): """ velocity resolution on average in km/s """ return self.vel_resol_ch.mean() @property def freq_resol(self): """ frequency resolution in Hz """ return np.diff(self.nu).mean() @property def pixel_area(self): """ angular area of the map pixel in deg^2 """ return self.skymap.pixel_area @property def pix_resol(self): """ angular resolution of the map pixel in deg """ return self.skymap.pix_resol @property def data(self): """ The map data """ return self._data @data.setter def data(self, value): self._data = np.asarray(value, dtype=self.real_dtype) @property def counts(self): """ The number of hits per pixel for the map data """ return self._counts @counts.setter def counts(self, value): self._counts = np.asarray(value, dtype=self.real_dtype) @property def map_has_sampling(self): """ A binary window for whether a pixel has been sampled """ return self._map_has_sampling @map_has_sampling.setter def map_has_sampling(self, value): self._map_has_sampling = np.asarray(value, dtype=bool) W_HI = map_has_sampling @property def ra_map(self): """ The right ascension of each pixel in the map. """ return self.skymap.ra_map @property def dec_map(self): """ The declination of each pixel in the map. """ return self.skymap.dec_map @property def weights_map_pixel(self): """ The weights per map pixel. """ return self._weights_map_pixel @weights_map_pixel.setter def weights_map_pixel(self, value): self._weights_map_pixel = np.asarray(value, dtype=self.real_dtype) w_HI = weights_map_pixel @property def ra_gal(self): """ The right ascension of each galaxy in the catalogue for cross-correlation. """ return self._ra_gal @property def dec_gal(self): """ The declination of each galaxy in the catalogue for cross-correlation. """ return self._dec_gal @property def z_gal(self): """ The redshifts of each galaxy in the catalogue for cross-correlation. """ return self._z_gal @property def freq_gal(self): """ The 21cm line frequency for each galaxy in Hz. """ return f_21 / (1 + self.z_gal) @property def ch_id_gal(self): """ The channel id (0-indexed) of each galaxy in the catalogue for cross-correlation. Galaxies out of the frequency range will be given len(self.nu) as indices. """ return find_ch_id(self.freq_gal, self.nu)
[docs] def read_gal_cat( self, ra_col="RA", dec_col="DEC", z_col="Z", trim=True, ): """ Read in a galaxy catalogue for cross-correlation and save the data into the class attributes. The data is read from the `gal_file`, which has to be a FITS file. Parameters ---------- ra_col: str, default "RA" The column name of the right ascension in the galaxy catalogue. dec_col: str, default "DEC" The column name of the declination in the galaxy catalogue. z_col: str, default "Z" The column name of the redshift in the galaxy catalogue. trim: bool, default True Whether to trim the galaxy catalogue to the ra,dec,z range of the map. See :meth:`meer21cm.dataanalysis.Specification.trim_gal_to_range`. """ if self.gal_file is None: print("no gal_file specified") return None hdu = fits.open(self.gal_file) ra_g = hdu[1].data[ra_col] # Right ascension (J2000) [deg] dec_g = hdu[1].data[dec_col] # Declination (J2000) [deg] z_g = hdu[1].data[z_col] # Spectroscopic redshift, -1 for none attempted self._ra_gal = ra_g self._dec_gal = dec_g self._z_gal = z_g if trim: self.trim_gal_to_range()
[docs] def read_from_pickle(self): """ Read in a pickle file for cross-correlation and save the data into the class attributes. See :meth:`meer21cm.io.read_pickle` for more details. """ if self.pickle_file is None: print("no pickle_file specified") return None ( self.data, self.counts, self.map_has_sampling, _ra_map, _dec_map, self.nu, wproj, ) = read_pickle( self.pickle_file, nu_min=self.nu_min, nu_max=self.nu_max, los_axis=self.los_axis, data_column=self.data_column, counts_column=self.counts_column, freq_column=self.freq_column, wcs_column=self.wcs_column, ) self._set_wcs_skymap( wproj=wproj, num_pix_x=_ra_map.shape[0], num_pix_y=_ra_map.shape[1], ) if self.filter_map_los: print("filtering map los") (self.data, self.map_has_sampling, _, self.counts,) = filter_incomplete_los( self.data, self.map_has_sampling, self.counts, self.counts, soft_mask=self.soft_filter_los, threshold_instead_of_filter=self.filter_los_threshold, ) if self.weighting.lower()[:5] == "count": self.weights_map_pixel = self.counts elif self.weighting.lower()[:7] == "uniform": self.weights_map_pixel = (self.counts > 0).astype("float") if self.auto_set_radecnu_bounds: self.set_radecnu_bounds_from_map() self.trim_map_to_range()
[docs] def read_from_fits(self): """ Read in a FITS file for the map data and hit counts. The FITS file need to follow the format of the MeerKLASS L-band data. See :meth:`meer21cm.io.read_map` for more details. After reading the data, the map data and hit counts are filtered along the frequency direction (see :meth:`meer21cm.io.filter_incomplete_los`), and trimmed to the specified range (see :meth:`meer21cm.dataanalysis.Specification.trim_map_to_range`). The weights per pixel are set to the hit counts if `self.weighting` is "counts", or set to 1 if `self.weighting` is "uniform". """ if self.map_file is None: print("no map_file specified") return None ( self.data, self.counts, self.map_has_sampling, _ra_map, _dec_map, self.nu, wproj, ) = read_map( self.map_file, counts_file=self.counts_file, nu_min=self.nu_min, nu_max=self.nu_max, los_axis=self.los_axis, band=self.band, ) self._set_wcs_skymap( wproj=wproj, num_pix_x=_ra_map.shape[0], num_pix_y=_ra_map.shape[1], ) if self.filter_map_los: (self.data, self.map_has_sampling, _, self.counts,) = filter_incomplete_los( self.data, self.map_has_sampling, self.counts, self.counts, soft_mask=self.soft_filter_los, threshold_instead_of_filter=self.filter_los_threshold, ) if self.weighting.lower()[:5] == "count": self.weights_map_pixel = self.counts elif self.weighting.lower()[:7] == "uniform": self.weights_map_pixel = (self.counts > 0).astype("float") if self.auto_set_radecnu_bounds: self.set_radecnu_bounds_from_map() self.trim_map_to_range()
[docs] def trim_map_to_range(self): """ Trim the map to the specified range. The map data and counts outside the range will be set to zero. The map_has_sampling and weights_map_pixel will be set to False outside the range. """ logger.debug( "flagging map and weights outside " f"ra_range: {self.ra_range}, dec_range: {self.dec_range}" ) trim = np.asarray( self.skymap.trim_selector(self.ra_range, self.dec_range), dtype=float ) # if trim.shape != self.skymap.map_shape_template: # raise ValueError( # "trim_selector shape mismatch with map_shape_template: " # f"{trim.shape} vs {self.skymap.map_shape_template}." # ) map_sel = trim.reshape(trim.shape + (1,) * (self.data.ndim - trim.ndim)) self.data = self.data * map_sel self.counts = self.counts * map_sel self.map_has_sampling = self.map_has_sampling * map_sel self.weights_map_pixel = self.weights_map_pixel * map_sel
[docs] def trim_gal_to_range(self): """ Trim the galaxy catalogue to the specified range. The galaxy catalogue outside the ra-dec-z range will be removed. Note that, a small buffer corresponding to half of the frequency channel bandwidth is added to the redshift range. """ ra_range = np.array(self.ra_range) dec_range = np.array(self.dec_range) freq_edges = center_to_edges(self.nu) z_edges = freq_to_redshift(freq_edges) logger.debug( f"flagging galaxy catalogue outside ra_range: {ra_range}, dec_range: {dec_range} and " f"z_range: [{z_edges.min()}, {z_edges.max()}]" ) gal_sel = ( angle_in_range(self.ra_gal, ra_range[0], ra_range[1]) * (self.dec_gal > dec_range[0]) * (self.dec_gal < dec_range[1]) ) z_sel = (self.z_gal > z_edges.min()) * (self.z_gal < z_edges.max()) gal_sel *= z_sel self._ra_gal = self.ra_gal[gal_sel] self._dec_gal = self.dec_gal[gal_sel] self._z_gal = self.z_gal[gal_sel] return gal_sel
@property @tagging("beam", "nu") def beam_image(self): """ Returns the beam image projected onto the sky map for the input beam model. """ if self._beam_image is None: self.get_beam_image() return self._beam_image
[docs] def get_beam_image( self, wproj=None, num_pix_x=None, num_pix_y=None, cache=True, ch_sel=None, ): """ Calculate the beam image projected onto the sky map for the input beam model. Parameters ---------- wproj: :py:class:`astropy.wcs.WCS`, default None The WCS object for the map. Default uses `self.wproj`. num_pix_x: int, default None The number of pixels in the first axis of the map data. Default uses `self.num_pix_x`. num_pix_y: int, default None The number of pixels in the second axis of the map data. Default uses `self.num_pix_y`. cache: bool, default True Whether to cache the beam image. Default is True. If True, the beam image will be cached and returned directly if it is already computed. If False, the beam image will be computed and returned. The cache is saved into the class attribute `beam_image`. ch_sel: array-like, default None Optional channel selection. If provided, returns beam image only for selected channels. Caching is only applied when `ch_sel` is None. """ if self.sigma_beam_ch is None: logger.info( f"sigma_beam_ch is None, returning None for {inspect.currentframe().f_code.co_name}" ) return None logger.info( f"invoking {inspect.currentframe().f_code.co_name} to get the beam image" ) logger.info(f"beam_type: {self.beam_type}, sigma_beam_ch: {self.sigma_beam_ch}") if self.skymap.format != "wcs": raise NotImplementedError( "get_beam_image is implemented for WCS maps only; on HEALPix use " "get_beam_window_ch for harmonic beam windows." ) if wproj is None: wproj = self.wproj if num_pix_x is None: num_pix_x = self.num_pix_x if num_pix_y is None: num_pix_y = self.num_pix_y n_ch = len(self.nu) if ch_sel is None: ch_sel = np.arange(n_ch, dtype=int) use_full_channels = True else: ch_sel = np.asarray(ch_sel, dtype=int) full_idx = np.arange(n_ch, dtype=int) use_full_channels = ch_sel.shape == full_idx.shape and np.allclose( ch_sel, full_idx ) if use_full_channels: ch_sel = full_idx if ( use_full_channels and cache and self._beam_image is not None and self._beam_image.shape == (num_pix_x, num_pix_y, len(self.nu)) ): return self._beam_image pix_resol = np.sqrt(proj_plane_pixel_area(wproj)) beam_image = np.zeros( (num_pix_x, num_pix_y, len(ch_sel)), dtype=self.real_dtype ) beam_model = getattr(telescope, self.beam_model + "_beam") if self.beam_type == "isotropic": for i_out, i_ch in enumerate(ch_sel): beam_image[:, :, i_out] = telescope.isotropic_beam_profile( num_pix_x, num_pix_y, wproj, beam_model(self.sigma_beam_ch[i_ch]), ) else: beam_image = beam_model( self.nu[ch_sel], wproj, num_pix_x, num_pix_y, band=self.band, ) sigma_beam_from_image = ( np.sqrt(beam_image.sum(axis=(0, 1)) / 2 / np.pi) * pix_resol ) if use_full_channels: self.sigma_beam_ch = sigma_beam_from_image if cache and use_full_channels: self._beam_image = beam_image return beam_image
@property @tagging("beam", "nu") def beam_window_ch(self): """ Per-channel spherical-harmonic beam window :math:`B_\\ell` for HEALPix ``hp.smoothing`` workflows. Populated lazily via :meth:`get_beam_window_ch`. """ if self._beam_window_ch is None: self.get_beam_window_ch() return self._beam_window_ch
[docs] def get_beam_window_ch( self, ch_sel=None, cache=True, lmax=None, hp_nside=None, ): """ Build per-channel harmonic beam windows for HEALPix smoothing. Parameters ---------- ch_sel : array-like of int, optional Channel indices. If omitted, uses all ``len(self.nu)`` channels. cache : bool, default True Cache the full-channel result in ``_beam_window_ch`` when ``ch_sel`` spans all channels. lmax : int, optional Maximum multipole (inclusive). Defaults to ``min(3 * hp_nside - 1, 8192)``. hp_nside : int, optional HEALPix :math:`N_{\\rm side}` used to set the default ``lmax``. Defaults to ``self.hp_nside``. Returns ------- ndarray of shape ``(len(ch_sel), lmax + 1)`` Beam window rows :math:`B_\\ell` per selected channel. """ if self.sigma_beam_ch is None: logger.info( f"sigma_beam_ch is None, returning None for {inspect.currentframe().f_code.co_name}" ) return None if self.skymap.format != "healpix": raise ValueError( "get_beam_window_ch is only defined for healpix skymaps; " "use get_beam_image on WCS." ) if self.beam_type == "anisotropic": raise NotImplementedError( "HEALPix harmonic beam smoothing does not support anisotropic " "(kat) beams." ) if hp_nside is None: hp_nside = int(self.hp_nside) else: hp_nside = int(hp_nside) if lmax is None: lmax = int(min(3 * hp_nside - 1, self.b_ell_l_max)) else: lmax = int(lmax) n_ch = len(self.nu) if ch_sel is None: ch_sel = np.arange(n_ch, dtype=int) use_full_channels = True else: ch_sel = np.asarray(ch_sel, dtype=int) full_idx = np.arange(n_ch, dtype=int) use_full_channels = ch_sel.shape == full_idx.shape and np.allclose( ch_sel, full_idx ) if use_full_channels: ch_sel = full_idx if ( use_full_channels and cache and self._beam_window_ch is not None and self._beam_window_ch.shape == (len(self.nu), lmax + 1) ): return self._beam_window_ch beam_model = getattr(telescope, self.beam_model + "_beam") out = np.zeros((ch_sel.size, lmax + 1), dtype=np.float64) for i_out, i_ch in enumerate(ch_sel): sigma = self.sigma_beam_ch[i_ch] sigma_rad = (sigma * self.beam_unit).to(units.rad).value if self.beam_model == "gaussian": out[i_out] = telescope.gaussian_beam_window(sigma_rad, lmax) else: beam_func = beam_model(sigma) out[i_out] = telescope.isotropic_beam_window( beam_func, float(sigma_rad), lmax ) if cache and use_full_channels: self._beam_window_ch = out return out
def _convolve_data_healpix_harmonic(self, data, weights): """ HEALPix: weighted harmonic-space smoothing using :func:`weighted_smoothing_healpix` and :meth:`get_beam_window_ch`. """ data = np.asarray(data) weights = np.asarray(weights) if data.shape != weights.shape: raise ValueError( f"data shape {data.shape} does not match weights shape {weights.shape}." ) pix_id = np.asarray(self.pixel_id, dtype=np.int64) nside = int(self.hp_nside) if data.ndim != 2: raise ValueError( "HEALPix map cubes must have shape (n_pix, n_ch); " f"got {data.shape}. Use self.data ordering for the LOS axis." ) if data.shape[0] != pix_id.size: raise ValueError( f"data axis 0 ({data.shape[0]}) must equal len(self.pixel_id) ({pix_id.size})." ) fdtype = np.result_type(real_dtype_from_array(data), self.real_dtype) conv_data = np.zeros_like(data, dtype=fdtype) conv_weights = np.zeros_like(weights, dtype=fdtype) for ch_sel, sl_d in self._iter_field_los_chunks(data): beam_w = self.get_beam_window_ch(ch_sel=ch_sel, cache=False, hp_nside=nside) sl_d = np.asarray(sl_d, dtype=fdtype) sl_w = np.asarray(weights[..., ch_sel], dtype=fdtype) cd, cw = telescope.weighted_smoothing_healpix( sl_d, sl_w, beam_w, nside, pix_id, ) conv_data[:, ch_sel] = cd conv_weights[:, ch_sel] = cw return conv_data, conv_weights
[docs] def convolve_data(self, kernel=None, data=None, weights=None, assign_to_self=True): """ Convolve map data. **WCS** maps use raster-space :func:`~meer21cm.telescope.weighted_convolution` with a beam image kernel (e.g. from :meth:`beam_image`). **HEALPix** maps ignore raster ``kernel`` and use harmonic :func:`~meer21cm.telescope.weighted_smoothing_healpix` with per-channel beam windows from :meth:`get_beam_window_ch` — pass ``kernel=None``. Parameters ---------- kernel : np.ndarray or None Raster beam cube ``(..., nx, ny, n_ch)`` aligned with LOS on ``wcs``. Ignored when ``kernel is None`` on HEALPix (then harmonic smoothing is applied). On WCS maps, passing ``kernel=None`` raises ``ValueError`` (provide e.g. ``self.beam_image``). data : np.ndarray, default None Map to convolve; default ``self.data``. weights : np.ndarray, default None Per-pixel weights (e.g. ``self.w_HI``); required semantics match ``data``. assign_to_self : bool, default True Assign results to ``self.data`` and ``self.w_HI``. Returns ------- data : np.ndarray weights : np.ndarray """ if data is None: data = self.data if weights is None: weights = self.w_HI if self.skymap.format == "healpix": if kernel is not None: raise ValueError( "HEALPix backend does not use a raster ``kernel``. " "Pass ``kernel=None`` to convolve with the harmonic beam windows " "from ``sigma_beam_ch`` / ``beam_model`` (see ``get_beam_window_ch``)." ) logger.info( f"invoking {inspect.currentframe().f_code.co_name} " "(healpix harmonic weighted smoothing)" ) conv_d, conv_w = self._convolve_data_healpix_harmonic(data, weights) if assign_to_self: self.data = conv_d self.w_HI = conv_w return conv_d, conv_w if kernel is None: raise ValueError( "WCS ``convolve_data`` requires ``kernel`` (e.g. ``self.beam_image``)." ) logger.info( f"invoking {inspect.currentframe().f_code.co_name} " f"with raster kernel shaped {np.shape(kernel)}" ) data, weights = telescope.weighted_convolution( data, kernel, weights, ) if assign_to_self: self.data = data self.w_HI = weights return data, weights
@property def maximum_sampling_channel(self): """ Returns the index of the frequency channel with the maximum sampling on the sky map. """ nd = self.map_has_sampling.ndim la = self.los_axis if la < 0: la += nd axes = tuple(i for i in range(nd) if i != la) return np.argmax(self.map_has_sampling.sum(axis=axes))
[docs] def get_weights_none_to_one(self, attr_name): """ Get the weights, and if it is None, convert it to 1.0 of size of kmode. Only used for power spectrum calculation. Defined here for inheritance. """ weights = getattr(self, attr_name) if weights is None: if hasattr(self, "box_ndim"): weights = np.ones(self.box_ndim, dtype=self.real_dtype) else: shape = np.array(self.kmode.shape) shape[-1] = 2 * shape[-1] - 2 weights = np.ones(shape, dtype=self.real_dtype) return weights
[docs] def get_jackknife_patches( self, ra_patch_num, dec_patch_num, nu_patch_num, ra_range=None, dec_range=None, nu_range=None, ): """ Split the map into roughly equal patches. Each patch can then be masked, which can be used for jackknife resampling for covariance estimation. Note that the masks=True is where the pixels **should be masked**. So for example, if you want to exclude a patch, the correct survey window is then ``self.W_HI * (1-mask_arr[i])`` and the weights ``self.w_HI * (1-mask_arr[i])``. If you want to examine the patch splits, you can visualise the mask array by using :func:`meer21cm.plot.visualise_patch_split`. Parameters ---------- ra_patch_num: int The number of patche grids in the right ascension direction. dec_patch_num: int The number of patche grids in the declination direction. nu_patch_num: int The number of patche grids in the frequency direction. ra_range: tuple, default None The range of the right ascension of the map data in degrees. Default uses ``self.ra_range``. dec_range: tuple, default None The range of the declination of the map data in degrees. Default uses ``self.dec_range``. nu_range: tuple, default None The range of the frequency of the map data in Hz. Default uses ``[self.nu.min() - self.freq_resol/2, self.nu.max() + self.freq_resol/2]``. """ if ra_range is None: ra_range = self.ra_range if dec_range is None: dec_range = self.dec_range assert ( dec_range[0] < dec_range[1] ), "dec_range[0] must be less than dec_range[1]" assert dec_range[0] >= -90, "dec must be between -90 and 90" assert dec_range[1] <= 90, "dec must be between -90 and 90" if nu_range is None: nu_range = [ self.nu.min() - self.freq_resol / 2, self.nu.max() + self.freq_resol / 2, ] assert nu_range[0] < nu_range[1], "nu_range[0] must be less than nu_range[1]" assert not ( ra_range[0] == 0 and ra_range[1] == 360 ), "ra_range is whole sky 0-360, check if you have passed a value to it" ra_delta_map = (self.ra_map - ra_range[0]) % 360 ra_delta_bins = np.linspace( 0, (ra_range[1] - ra_range[0]) % 360, ra_patch_num + 1 ) dec_bins = np.linspace(dec_range[0], dec_range[1], dec_patch_num + 1) nu_bins = np.linspace(nu_range[0], nu_range[1], nu_patch_num + 1) ra_indx = np.digitize(ra_delta_map, ra_delta_bins) ra_indx[ra_indx == 0] = len(ra_delta_bins) dec_indx = np.digitize(self.dec_map, dec_bins) dec_indx[dec_indx == 0] = len(dec_bins) nu_indx = np.digitize(self.nu, nu_bins) nu_indx[nu_indx == 0] = len(nu_bins) ra_indx -= 1 dec_indx -= 1 nu_indx -= 1 mask_arr = np.zeros( (ra_patch_num, dec_patch_num, nu_patch_num) + self.W_HI.shape, dtype=bool, ) for i in range(len(ra_delta_bins) - 1): for j in range(len(dec_bins) - 1): for k in range(len(nu_bins) - 1): W_ijk = ((ra_indx == i) * (dec_indx == j))[:, :, None] * ( nu_indx == k )[None, None, :] mask_arr[i, j, k] = W_ijk return mask_arr
[docs] def create_white_noise_map(self, sigma_N, counts=None, seed=None, inf_to_zero=True): """ Create a white noise map with the given standard deviation. The sigma in each pixel is then scaled by the counts 1/sqrt(counts). If you want to generate multiple random catalogues, you need to set a different seed manually for each catalogue. If you want to use different noise level per pixel, you can either pass a 3D array of sigma_N, or a single number and a 3D array of counts. You can usually pass ``self.counts`` as the counts array, but do check the counts are set up correctly by ``plot_map(self.counts, self.wproj)``. Finally, note that the noise map is not masked by the survey selection function. You can mask the noise map manually by ``noise_map *= self.W_HI``. Parameters ---------- sigma_N: float or array. The standard deviation of the white noise. counts: array, default None. The counts in each pixel. If None, the counts will be one across the cube. seed: int, default None. If none, the seed is pulled from OS as described by the numpy documentation. inf_to_zero: bool, default True. If True, the inf values in the noise map will be set to zero. Returns ------- noise_map: array. The white noise map. """ if counts is None: counts = np.ones(self.data.shape, dtype=self.data.dtype) else: counts = np.asarray(counts, dtype=real_dtype_from_array(self.data)) rng = np.random.default_rng(seed=seed) noise_map = rng.normal( scale=sigma_N / np.sqrt(counts), size=self.data.shape ).astype(real_dtype_from_array(self.data), copy=False) if inf_to_zero: noise_map[np.isinf(noise_map)] = 0.0 return noise_map
[docs] def check_is_map_noiselike_using_pca(self, A_mat, data=None, sigma_N=1.0): """ Use the source mixing matrix from eigendecomposition of the covariance matrix, project out the map data with more and more modes, and check if the variance of the residual map behaves like white noise. You can use :func:`meer21cm.util.pca_clean` to retrieve the source mixing matrix: .. code-block:: python N_fg = 15 # check 15 modes removed res_map, A_mat = pca_clean(ps.data, N_fg, weights=ps.W_HI, return_A=True) res, noise = ps.check_is_map_noiselike_using_pca(A_mat) plt.plot(res / noise) If the residual map is noise-like, the plot should decrease and eventually reach a plateau. If you know the expected std of the map (per hit), you can pass it to ``sigma_N`` to scale the noise variance, and the plateau should be close to 1. Note that, the input data should be the mean-centered data. You can use :func:`meer21cm.util.mean_center_signal` to mean-center the data if needed. Parameters ---------- A_mat: array. The source mixing matrix. data: array, default None. The data to be projected out. If None, the class attribute ``self.data`` will be used. sigma_N: float. """ res_var = [] noise_var = [] if data is None: data = self.data for i in range(A_mat.shape[1]): R_mat = np.eye(self.nu.size) - np.dot( A_mat[:, : i + 1], A_mat[:, : i + 1].T ) var_attenuation = np.trace(R_mat.T @ R_mat) / self.nu.size data_res = np.einsum("ij, abj -> abi", R_mat, data) res_var.append((data_res * np.sqrt(self.counts))[self.W_HI > 0].var()) noise_var.append(var_attenuation) res_var = np.array(res_var) noise_var = np.array(noise_var) * sigma_N**2 return res_var, noise_var
[docs] def generate_full_healpix_map(self, data=None, fill_value=np.nan): """ Generate a full healpix map from the data. Pixels not included in the data will be filled with the `fill_value` (default is `np.nan`). Parameters ---------- data: array, default None. The data to be included in the full healpix map. If None, the class attribute ``self.data`` will be used. fill_value: float, default np.nan. The value to fill the pixels not included in the data. Returns ------- full_healpix_map: array. The full healpix map. """ if self.skymap.format != "healpix": raise ValueError("Skymap format must be healpix, got " + self.skymap.format) if data is None: data = self.data num_ch = data.shape[-1] full_healpix_map = np.zeros((hp.nside2npix(self.hp_nside), num_ch)) + fill_value full_healpix_map[self.pixel_id] = data return full_healpix_map