# This file is part of MICMAC.
# Copyright (C) 2024 CNRS / SciPol developers
#
# MICMAC is free software: you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# MICMAC is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
# See the GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with MICMAC. If not, see <https://www.gnu.org/licenses/>.
from functools import partial
import chex as chx
import healpy as hp
import jax
import jax.lax as jlax
import jax.numpy as jnp
import numpy as np
# Public API of the module (each name listed exactly once).
__all__ = [
    'get_reduced_matrix_from_c_ell',
    'get_reduced_matrix_from_c_ell_jax',
    'get_c_ells_from_red_covariance_matrix',
    'get_c_ells_from_red_covariance_matrix_JAX',
    'get_sqrt_reduced_matrix_from_matrix_jax',
    'get_cell_from_map_jax',
    'get_bool_array_in_boundary',
    'alm_dot_product_JAX',
    'JAX_almxfl',
    'maps_x_red_covariance_cell_JAX',
    'alms_x_red_covariance_cell_JAX',
    'frequency_alms_x_obj_red_covariance_cell_JAX',
]
[docs]
def get_reduced_matrix_from_c_ell_jax(c_ells_input):
    """
    Build the "reduced" covariance matrix [ell, nstokes, nstokes] from power spectra.

    The covariance is assumed block diagonal in ell. The spectra must be sorted as
        - TT                        if 1 spectrum is given
        - EE, BB, EB                if 3 spectra are given
        - TT, EE, BB, TE            if 4 spectra are given (EB and TB are set to zero)
        - TT, EE, BB, TE, EB, TB    if 6 spectra are given

    The input spectra do not have to start from ell=0; the output matrix
    starts from the same lmin as the input spectra.

    Parameters
    ----------
    c_ells_input: array of shape (n_correlations, lmax+1-lmin)
        Input c_ells

    Returns
    -------
    reduced_matrix: array of shape (lmax+1-lmin, nstokes, nstokes)
        Reduced format of the covariance matrix (symmetric for every ell)

    Raises
    ------
    ValueError
        If the number of spectra is not 1, 3, 4 or 6
    """
    c_ells_array = jnp.asarray(c_ells_input)
    n_correlations = c_ells_array.shape[0]
    lmax_p1 = c_ells_array.shape[1]

    # Getting number of Stokes parameters from the number of correlations within the input spectrum
    if n_correlations == 1:
        nstokes = 1
    elif n_correlations == 3:
        nstokes = 2
    elif n_correlations in (4, 6):
        nstokes = 3
        if n_correlations != 6:
            # Pad the missing cross-spectra with zero rows so that the indexing
            # below always sees the full set of 6 spectra.
            padding = jnp.zeros((6 - n_correlations, lmax_p1), dtype=c_ells_array.dtype)
            c_ells_array = jnp.vstack((c_ells_array, padding))
    else:
        raise ValueError(
            'C_ells must be given as TT for temperature only ; EE, BB, EB for polarization only ; TT, EE, BB, TE, (TB, EB) for both temperature and polarization'
        )

    # Constructing the reduced matrix
    reduced_matrix = jnp.zeros((lmax_p1, nstokes, nstokes))

    ## First diagonal elements: the first nstokes spectra are the auto-spectra
    diag = jnp.arange(nstokes)
    reduced_matrix = reduced_matrix.at[:, diag, diag].set(c_ells_array[:nstokes].T)

    ## Then off-diagonal elements
    if nstokes > 1:
        # First cross-spectrum (EB for nstokes=2, TE for nstokes=3)
        first_cross = c_ells_array[nstokes, :]
        reduced_matrix = reduced_matrix.at[:, 0, 1].set(first_cross)
        reduced_matrix = reduced_matrix.at[:, 1, 0].set(first_cross)
    if nstokes == 3:
        # Index 5 couples the first and third Stokes parameters (TB),
        # index 4 couples the second and third ones (EB)
        reduced_matrix = reduced_matrix.at[:, 0, 2].set(c_ells_array[5, :])
        reduced_matrix = reduced_matrix.at[:, 2, 0].set(c_ells_array[5, :])
        reduced_matrix = reduced_matrix.at[:, 1, 2].set(c_ells_array[4, :])
        reduced_matrix = reduced_matrix.at[:, 2, 1].set(c_ells_array[4, :])
    return reduced_matrix
[docs]
def get_c_ells_from_red_covariance_matrix_JAX(red_cov_mat, nstokes=0):
    """
    Retrieve the c_ell in the format [number_correlations, lmax+1-lmin],
    from the reduced covariance matrix format [lmax+1-lmin, nstokes, nstokes],
    assuming it's block diagonal

    Depending of nstokes, the number of correlations corresponds to:
        TT
        EE, BB, EB
        TT, EE, BB, TE, EB, TB

    All sizes are derived from static shapes (Python integers), so the
    function can be traced by `jax.jit` as long as `nstokes` is static.

    Parameters
    ----------
    red_cov_mat: array[float] of dimensions [lmax+1-lmin, nstokes, nstokes]
        reduced spectra of the covariance matrix
    nstokes: int
        number of Stokes parameters to extract; 0 (default) means
        "use the second dimension of red_cov_mat"

    Returns
    -------
    c_ells: array of dimensions [n_correlations, lmax+1-lmin]
        power spectrum of the input reduced covariance matrix
    """
    red_cov_mat = jnp.asarray(red_cov_mat)
    n_ell = red_cov_mat.shape[0]

    # Keep nstokes a plain Python int: it drives array shapes, which must be static
    nstokes = red_cov_mat.shape[1] if nstokes == 0 else int(nstokes)

    # Number of independent entries of a symmetric nstokes x nstokes matrix
    n_correl = nstokes * (nstokes + 1) // 2

    c_ells = jnp.zeros((n_correl, n_ell))

    # Auto-spectra: diagonal of the matrix for each ell
    diag = jnp.arange(nstokes)
    c_ells = c_ells.at[:nstokes, :].set(red_cov_mat[:, diag, diag].T)

    # Cross-spectra
    if nstokes > 1:
        c_ells = c_ells.at[nstokes, :].set(red_cov_mat[:, 0, 1])
    if nstokes == 3:
        c_ells = c_ells.at[nstokes + 2, :].set(red_cov_mat[:, 0, 2])
        c_ells = c_ells.at[nstokes + 1, :].set(red_cov_mat[:, 1, 2])
    return c_ells
[docs]
def get_sqrt_reduced_matrix_from_matrix_jax(red_matrix):
    """
    Return matrix square root of covariance matrix in the format [lmax+1-lmin, nstokes, nstokes],
    assuming it's block diagonal

    The input matrix doesn't have to start from ell=0,
    and the output matrix will start from the same lmin as the input matrix

    The initial matrix HAS to be positive semi-definite: the square root is built
    from the absolute value of the eigenvalues, so negative eigenvalues are
    silently replaced by their magnitude.

    Parameters
    ----------
    red_matrix: array of dimensions [lmax+1-lmin, nstokes, nstokes]
        reduced spectra of the covariance matrix

    Returns
    -------
    reduced_sqrtm: array of dimensions [lmax+1-lmin, nstokes, nstokes]
        matrix square root of the covariance matrix
    """
    red_matrix = jnp.array(red_matrix, dtype=jnp.float64)

    # Eigen-decomposition of every ell block at once (eigh is batched over the leading axis)
    eigvals, eigvect = jnp.linalg.eigh(red_matrix)
    inv_eigvect = jnp.linalg.pinv(eigvect)

    # sqrt(M) = V diag(sqrt(|lambda|)) V^-1, evaluated independently for each ell
    sqrt_eigvals = jnp.sqrt(jnp.abs(eigvals))
    return jnp.einsum('ljm,lm,lmn->ljn', eigvect, sqrt_eigvals, inv_eigvect)
[docs]
def get_cell_from_map_jax(pixel_maps, lmax, n_iter=8):
    """
    Return c_ell from pixel_maps with an associated lmax and iteration number of harmonic operations

    The spectra are computed by healpy's anafast through a JAX pure callback.

    Parameters
    ----------
    pixel_maps: array of dimensions [nstokes, n_pix] (or [n_pix] for a single map)
        input maps
    lmax: int
        maximum ell for the spectrum
    n_iter: int
        number of iterations for harmonic operations

    Returns
    -------
    c_ells: array of dimensions [n_correlations, lmax+1]
        power spectrum of the input maps, with n_correlations equal to
        1 for a single map, 3 (rows 1, 2 and 4 of the anafast output) for
        nstokes=2, and 6 otherwise
    """
    # Getting nstokes from the input maps
    if len(pixel_maps.shape) == 1:
        nstokes = 1
    else:
        nstokes = pixel_maps.shape[0]

    # anafast returns a single spectrum for a single map and 6 spectra for T, Q, U maps;
    # the callback output shape must be declared accordingly
    n_spectra = 1 if nstokes == 1 else 6

    # Wrapper for anafast, to prepare the pure callback of JAX
    def wrapper_anafast(maps_, lmax=lmax, n_iter=n_iter):
        maps_np = np.asarray(maps_)
        if nstokes == 1:
            # Hand a flat map to anafast whether the input was [n_pix] or [1, n_pix]
            maps_np = maps_np.reshape(-1)
        # atleast_2d turns the single-spectrum output into a [1, lmax+1] array
        return np.atleast_2d(hp.anafast(maps_np, lmax=lmax, iter=n_iter))

    # Pure call back of anafast, to be used with JAX for JIT compilation
    @partial(jax.jit, static_argnums=1)
    def pure_call_anafast(maps_, lmax):
        """Pure call back of anafast, to be used with JAX for JIT compilation"""
        shape_output = (n_spectra, lmax + 1)
        return jax.pure_callback(wrapper_anafast, jax.ShapeDtypeStruct(shape_output, np.float64), maps_)

    # Extending the pixel maps if they are given with only polarization Stokes parameters (nstokes=2)
    if nstokes == 2:
        pixel_maps_for_Wishart = jnp.vstack((jnp.zeros_like(pixel_maps[0]), pixel_maps))
    else:
        pixel_maps_for_Wishart = jnp.copy(pixel_maps)

    c_ells_output = pure_call_anafast(pixel_maps_for_Wishart, lmax)

    if nstokes == 2:
        polar_indexes = jnp.array([1, 2, 4])
        return c_ells_output[polar_indexes]  # Return only polarization spectra if nstokes=2
    return c_ells_output
[docs]
def get_bool_array_in_boundary(input_array, boundary):
    """
    Flag the entries of an array lying inside a closed interval.

    Parameters
    ----------
    input_array: array
        values to test
    boundary: array of dimension [2,dim(input_array)]
        boundary[0] holds the lower bounds, boundary[1] the upper bounds

    Returns
    -------
    bool_array: array[bool]
        True where boundary[0] <= input_array <= boundary[1] (both ends included),
        False elsewhere
    """
    lower_bound = boundary[0]
    upper_bound = boundary[1]
    not_below = input_array >= lower_bound
    not_above = input_array <= upper_bound
    return not_below & not_above
[docs]
@partial(jax.jit, static_argnames='lmax')
def alm_dot_product_JAX(alm_1, alm_2, lmax):
    """
    Return dot product of two alms

    Parameters
    ----------
    alm_1: array
        input alms of shape (...,(lmax + 1) * (lmax // 2 + 1))
    alm_2: array
        input alms of shape (...,(lmax + 1) * (lmax // 2 + 1))
    lmax: int
        maximum ell for the power spectrum

    Returns
    -------
    dot_product: float
        dot product of the two alms
    """
    # In HEALPix C++ and healpy, coefficients are stored ordered by m
    # (see https://healpy.readthedocs.io/en/latest/generated/healpy.sphtfunc.Alm.getidx.html#healpy.sphtfunc.Alm.getidx):
    # the first lmax+1 entries are the m=0 coefficients, followed by the lmax
    # coefficients with m=1, the lmax-1 coefficients with m=2, and so on down
    # to the single m=lmax coefficient.
    # The first lmax+1 entries therefore get a weight of 1, all the others a weight of 2.
    n_coefficients = alm_1.shape[-1]
    weights = jnp.where(jnp.arange(n_coefficients) <= lmax, 1, 2)

    elementwise_product = alm_1.real * alm_2.real + alm_1.imag * alm_2.imag
    return jnp.sum(elementwise_product * weights)
[docs]
@partial(jax.jit, static_argnames='lmax')
def JAX_almxfl(alm, c_ell_x_, lmax):
    """
    Multiply every a_lm by the ell-dependent factor c_ell_x_[ell],
    without the need of a pure callback to Healpy

    Parameters
    ----------
    alm: array
        input alms of shape ((lmax + 1) * (lmax // 2 + 1))
    c_ell_x_: array of shape [lmax+1]
        input power spectrum, one factor per ell
    lmax: int
        maximum ell of the alms (static under JIT)

    Returns
    -------
    alms_output: array
        updated output alms of shape ((lmax + 1) * (lmax // 2 + 1))
    """
    # Offset of each m-block in the Healpy ordering: the coefficient (ell, m)
    # is stored at position m_block_offsets[m] + ell
    m_values = jnp.arange(lmax + 1)
    m_block_offsets = m_values * (2 * lmax + 1 - m_values) // 2

    def rescale_single_ell(alm_state, ell):
        """
        Scale the coefficients of a given ell, for all m <= ell, by c_ell_x_[ell]
        """
        positions = m_block_offsets + ell
        # Positions obtained for m > ell do not belong to this ell:
        # they are multiplied by 1, i.e. left untouched
        factors = jnp.where(m_values <= ell, c_ell_x_[ell], 1)
        updated = alm_state.at[positions].set(alm_state[positions] * factors)
        return updated, None

    alms_output, _ = jax.lax.scan(rescale_single_ell, jnp.copy(alm), m_values)
    return alms_output
[docs]
def maps_x_red_covariance_cell_JAX(maps_input, red_matrix_sqrt, nside, lmin, n_iter=8):
    """
    Return maps convolved with the harmonic covariance matrix given as input
    in the format [lmax+1-lmin, nstokes, nstokes], assuming it's block diagonal

    The input matrix has to start from ell=lmin, otherwise the lmax associated with the harmonic
    operations will be wrong

    The maps are transformed to alms with healpy (through JAX pure callbacks),
    multiplied by the matrix for each ell, and transformed back to maps.

    Parameters
    ----------
    maps_input: array[float] of shape [nstokes, n_pix]
        input maps
    red_matrix_sqrt: array[float] of shape [lmax+1-lmin, nstokes, nstokes]
        input reduced spectra
    nside: int
        nside of the input maps
    lmin: int
        minimum ell for the spectrum
    n_iter: int
        number of iterations for harmonic operations

    Returns
    -------
    maps_output: array[float]
        input maps convolved with input spectra; the last nstokes rows of the
        3 computed maps if nstokes != 1, all 3 rows otherwise
    """
    # Getting scalar parameters from the input covariance
    all_params = 3  # The harmonic operations are always performed on 3 maps
    lmax = red_matrix_sqrt.shape[0] - 1 + lmin
    nstokes = red_matrix_sqrt.shape[1]
    n_pix = 12 * nside**2
    # Number of alm coefficients for mmax = lmax (healpy Alm.getsize), valid for even and odd lmax
    n_alm = (lmax + 1) * (lmax + 2) // 2

    # Building the full covariance matrix from the covariance matrix
    red_decomp = jnp.zeros((lmax + 1, all_params, all_params))  # 3 is the maximum number of stokes parameters
    if nstokes != 1:
        red_decomp = red_decomp.at[lmin:, 3 - nstokes :, 3 - nstokes :].set(red_matrix_sqrt)
    else:
        # NOTE(review): a [.., 1, 1] matrix is broadcast over the whole 3x3 block here — confirm this is intended
        red_decomp = red_decomp.at[lmin:].set(red_matrix_sqrt)

    # Extending the pixel maps if they are given with only polarization Stokes parameters (nstokes=2)
    if maps_input.shape[0] == 2:
        maps_TQU = jnp.vstack((jnp.zeros_like(maps_input[0]), jnp.copy(maps_input)))
    else:
        maps_TQU = jnp.copy(maps_input)

    # Wrapper for map2alm, to prepare the pure callback of JAX
    def wrapper_map2alm(maps_, lmax=lmax, n_iter=n_iter, nside=nside):
        # The maps are passed flattened to the callback and reshaped here
        maps_np = np.asarray(maps_).reshape((3, 12 * nside**2))
        alm_T, alm_E, alm_B = hp.map2alm(maps_np, lmax=lmax, iter=n_iter)
        return np.array([alm_T, alm_E, alm_B])

    # Wrapper for alm2map, to prepare the pure callback of JAX
    def wrapper_alm2map(alm_, lmax=lmax, nside=nside):
        alm_np = np.asarray(alm_)
        return hp.alm2map(alm_np, nside, lmax=lmax)

    # Pure call back of map2alm, to be used with JAX for JIT compilation
    @partial(jax.jit, static_argnums=(1, 2))
    def pure_call_map2alm(maps_, lmax, nside):
        shape_output = (3, n_alm)
        return jax.pure_callback(
            wrapper_map2alm,
            jax.ShapeDtypeStruct(shape_output, np.complex128),
            maps_.ravel(),
        )

    # Pure call back of alm2map, to be used with JAX for JIT compilation
    @partial(jax.jit, static_argnums=(1, 2))
    def pure_call_alm2map(alm_, lmax, nside):
        shape_output = (3, n_pix)
        return jax.pure_callback(wrapper_alm2map, jax.ShapeDtypeStruct(shape_output, np.float64), alm_)

    alms_input = pure_call_map2alm(maps_TQU, lmax, nside)

    # Multiplying the nstokes's jth alms with the covariance matrix for each stokes parameter contribution
    def scan_func(carry, nstokes_j):
        val_alms_j, nstokes_i = carry
        result_almxfl = JAX_almxfl(alms_input[nstokes_j], red_decomp[:, nstokes_i, nstokes_j], lmax)
        return (val_alms_j + result_almxfl, nstokes_i), None

    # Multiplying the nstokes's ith alms with the covariance matrix:
    # output_i = sum_j red_decomp[:, i, j] x alms_input[j]
    def fmap(nstokes_i):
        (summed_alms, _), _ = jlax.scan(
            scan_func,
            (jnp.zeros_like(alms_input[nstokes_i]), nstokes_i),
            jnp.arange(all_params),
        )
        return summed_alms

    # Multiplying the alms with the covariance matrix
    alms_output = jax.vmap(fmap, in_axes=0)(jnp.arange(all_params))

    # Retrieving the maps from the alms convolved with the input covariance matrix
    maps_output = pure_call_alm2map(alms_output, lmax, nside)

    if nstokes != 1:
        return maps_output[3 - nstokes :, ...]  # If only polarization maps are given, return only polarization maps
    return maps_output
[docs]
def alms_x_red_covariance_cell_JAX(alm_Stokes_input, red_matrix, lmin):
    """
    Apply a block-diagonal harmonic covariance matrix, given in the reduced
    format [lmax+1-lmin, nstokes, nstokes], to a set of alms

    The input matrix has to start from ell=lmin, otherwise the lmax associated with the harmonic
    operations will be wrong; multipoles below lmin are multiplied by zero

    Parameters
    ----------
    alm_Stokes_input: array of shape [nstokes, (lmax + 1) * (lmax // 2 + 1)]
        input alms
    red_matrix: array of shape [lmax+1-lmin, nstokes, nstokes]
        input reduced covariance matrix
    lmin: int
        minimum ell for the spectrum

    Returns
    -------
    alms_output: array of shape [nstokes, (lmax + 1) * (lmax // 2 + 1)]
        output alms, alms_output[i] = sum_j red_matrix[:, i, j] x alm_Stokes_input[j]
    """
    # Scalar parameters deduced from the input covariance
    nstokes = red_matrix.shape[1]
    lmax = red_matrix.shape[0] - 1 + lmin

    alm_input = jnp.copy(alm_Stokes_input)

    # Covariance extended down to ell=0, with zeros below lmin
    red_decomp = jnp.zeros((lmax + 1, nstokes, nstokes)).at[lmin:].set(red_matrix)

    def add_column_contribution(state, column):
        """
        Add to the running sum the alms of index `column` scaled by the (row, column) spectrum
        """
        running_sum, row = state
        contribution = JAX_almxfl(alm_input[column], red_decomp[:, row, column], lmax)
        return (running_sum + contribution, row), None

    def apply_matrix_row(row):
        """
        Sum the contributions of every input alm to the output alm of index `row`
        """
        initial_state = (jnp.zeros_like(alm_input[0]), row)
        (row_alms, _), _ = jlax.scan(add_column_contribution, initial_state, jnp.arange(nstokes))
        return row_alms

    # One output alm per row of the covariance matrix
    return jax.vmap(apply_matrix_row)(jnp.arange(nstokes))
[docs]
def frequency_alms_x_obj_red_covariance_cell_JAX(freq_alm_Stokes_input, freq_red_matrix, lmin, n_iter=8):
    """
    Return frequency alms convolved with the covariance matrix given as input, assuming it's block diagonal

    The freq_red_matrix can have its first dimension representing anything,
    in particular either the number of frequencies or the number of components

    Parameters
    ----------
    freq_alm_Stokes_input: array of shape [frequency, nstokes, n_alm]
        input alms per frequency, with n_alm = (lmax + 1) * (lmax + 2) // 2
    freq_red_matrix: array of shape [first_dim, frequency, lmax+1-lmin, nstokes, nstokes]
        input reduced covariance matrix
    lmin: int
        minimum ell for the power spectrum
    n_iter: int
        not used by this function

    Returns
    -------
    freq_alms_output: array of shape [first_dim, nstokes, n_alm]
        output alms, summed over the frequency axis for each index of the first dimension
    """
    # Getting scalar parameters from the input covariance
    lmax = freq_red_matrix.shape[2] - 1 + lmin
    # Can be any dimension, in the use of the function either n_frequencies or n_components
    first_dim_red_matrix = freq_red_matrix.shape[0]
    n_frequencies = freq_red_matrix.shape[1]
    nstokes = freq_red_matrix.shape[3]
    # Number of alm coefficients for mmax = lmax (healpy Alm.getsize), valid for even and odd lmax
    n_alm = (lmax + 1) * (lmax + 2) // 2

    # Few tests to check the input
    chx.assert_axis_dimension(freq_red_matrix, 1, n_frequencies)
    chx.assert_axis_dimension(freq_red_matrix, 2, lmax + 1 - lmin)
    chx.assert_axis_dimension(freq_red_matrix, 3, nstokes)
    chx.assert_axis_dimension(freq_red_matrix, 4, nstokes)
    chx.assert_shape(freq_alm_Stokes_input, (n_frequencies, nstokes, n_alm))

    freq_alm_input = jnp.copy(freq_alm_Stokes_input)

    def scan_func(carry, frequency_j):
        """
        For a given frequency_j, adds to the running sum the alms of that frequency
        convolved with the covariance matrix of indices (idx_i, frequency_j)
        """
        val_alms_j, idx_i = carry
        result_almxfl = alms_x_red_covariance_cell_JAX(
            freq_alm_input[frequency_j], freq_red_matrix[idx_i, frequency_j, ...], lmin=lmin
        )
        return (val_alms_j + result_almxfl, idx_i), None

    def fmap(idx_i):
        """
        For a given idx_i, returns the alms convolved with the covariance matrix, summed over all frequencies
        """
        (summed_alms, _), _ = jlax.scan(
            scan_func,
            (jnp.zeros_like(freq_alm_input[0]), idx_i),
            jnp.arange(n_frequencies),
        )
        return summed_alms

    # Multiplying the frequency alms with the first dimension-frequency covariance matrix
    freq_alms_output = jax.vmap(fmap, in_axes=0)(jnp.arange(first_dim_red_matrix))
    return freq_alms_output
## Numpy version
import healpy as hp
import numpy as np
[docs]
def get_reduced_matrix_from_c_ell(c_ells_input):
    """
    Build the "reduced" covariance matrix [ell, nstokes, nstokes] from power spectra.

    The covariance is assumed block diagonal in ell. The spectra must be sorted as
        - TT                        if 1 spectrum is given
        - EE, BB, EB                if 3 spectra are given
        - TT, EE, BB, TE            if 4 spectra are given (EB and TB are set to zero)
        - TT, EE, BB, TE, EB, TB    if 6 spectra are given

    The input spectra do not have to start from ell=0; the output matrix
    starts from the same lmin as the input spectra.

    Parameters
    ----------
    c_ells_input: array of shape (n_correlations, lmax+1-lmin)
        input power spectra

    Returns
    -------
    reduced_matrix: array of shape (lmax+1-lmin, nstokes, nstokes)
        reduced covariance matrix (symmetric for every ell)

    Raises
    ------
    ValueError
        If the number of spectra is not 1, 3, 4 or 6
    """
    c_ells_array = np.copy(c_ells_input)
    n_correlations = c_ells_array.shape[0]
    lmax_p1 = c_ells_array.shape[1]

    # Number of Stokes parameters deduced from the number of spectra
    if n_correlations == 1:
        nstokes = 1
    elif n_correlations == 3:
        nstokes = 2
    elif n_correlations in (4, 6):
        nstokes = 3
        if n_correlations != 6:
            # Pad the missing cross-spectra with zero rows so that the indexing
            # below always sees the full set of 6 spectra
            padding = np.zeros((6 - n_correlations, lmax_p1))
            c_ells_array = np.vstack((c_ells_array, padding))
    else:
        raise ValueError(
            'C_ells must be given as TT for temperature only ; EE, BB, EB for polarization only ; TT, EE, BB, TE, (TB, EB) for both temperature and polarization'
        )

    reduced_matrix = np.zeros((lmax_p1, nstokes, nstokes))

    # Auto-spectra on the diagonal
    for i in range(nstokes):
        reduced_matrix[:, i, i] = c_ells_array[i, :]

    # Cross-spectra, mirrored to keep the matrix symmetric
    if nstokes > 1:
        # First cross-spectrum (EB for nstokes=2, TE for nstokes=3)
        reduced_matrix[:, 0, 1] = c_ells_array[nstokes, :]
        reduced_matrix[:, 1, 0] = c_ells_array[nstokes, :]
    if nstokes == 3:
        # Index 5 couples the first and third Stokes parameters (TB),
        # index 4 couples the second and third ones (EB)
        reduced_matrix[:, 0, 2] = c_ells_array[5, :]
        reduced_matrix[:, 2, 0] = c_ells_array[5, :]
        reduced_matrix[:, 1, 2] = c_ells_array[4, :]
        reduced_matrix[:, 2, 1] = c_ells_array[4, :]
    return reduced_matrix
[docs]
def get_c_ells_from_red_covariance_matrix(red_cov_mat):
    """
    Extract the power spectra [number_correlations, lmax+1-lmin] stored in a
    reduced covariance matrix [lmax+1-lmin, nstokes, nstokes],
    assuming it's block diagonal

    Depending of nstokes, the output rows correspond to:
        TT
        EE, BB, EB
        TT, EE, BB, TE, EB, TB

    Parameters
    ----------
    red_cov_mat: array of shape (lmax+1-lmin, nstokes, nstokes)
        reduced covariance matrix

    Returns
    -------
    c_ells: array of shape (n_correlations, lmax+1-lmin)
        power spectra read from the upper triangle of the matrix
    """
    n_ell, nstokes = red_cov_mat.shape[0], red_cov_mat.shape[1]

    # Number of independent entries of a symmetric nstokes x nstokes matrix
    n_spectra = nstokes * (nstokes + 1) // 2
    spectra = np.zeros((n_spectra, n_ell))

    # Auto-spectra: diagonal of every ell block
    spectra[:nstokes] = np.diagonal(red_cov_mat, axis1=1, axis2=2).T

    # Cross-spectra, read from the upper triangle
    if nstokes > 1:
        spectra[nstokes] = red_cov_mat[:, 0, 1]
    if nstokes == 3:
        spectra[4] = red_cov_mat[:, 1, 2]
        spectra[5] = red_cov_mat[:, 0, 2]
    return spectra
def get_sqrt_reduced_matrix_from_matrix(red_matrix, tolerance=10 ** (-15)):
    """
    Return matrix square root of covariance matrix in the format [lmax+1-lmin, nstokes, nstokes],
    assuming it's block diagonal

    The input matrix doesn't have to start from ell=0,
    and the output matrix will start from the same lmin as the input matrix

    The initial matrix HAS to be positive semi-definite: eigenvalues are allowed
    to be negative only within the numerical tolerance

    Parameters
    ----------
    red_matrix: array of shape (lmax, nstokes, nstokes)
        reduced covariance matrix
    tolerance: float
        largest magnitude accepted for a negative eigenvalue

    Returns
    -------
    reduced_sqrtm: array of shape (lmax, nstokes, nstokes)
        reduced matrix square root of the covariance matrix

    Raises
    ------
    ValueError
        If a block has an eigenvalue smaller than -tolerance, or if the
        pseudo-inverse of its eigenvectors cannot be computed
    """
    nstokes = red_matrix.shape[1]

    # Keep floating dtypes as they are, but never store square roots in an integer array
    if np.issubdtype(red_matrix.dtype, np.floating):
        reduced_sqrtm = np.zeros_like(red_matrix)
    else:
        reduced_sqrtm = np.zeros(red_matrix.shape, dtype=np.float64)

    for ell in range(red_matrix.shape[0]):
        eigvals, eigvect = np.linalg.eigh(red_matrix[ell, :, :])

        try:
            inv_eigvect = np.linalg.pinv(eigvect)
        except np.linalg.LinAlgError as err:
            raise ValueError(
                f'Error for ell={ell} eigvals {eigvals} eigvect {eigvect} red_matrix {red_matrix[ell, :, :]}'
            ) from err

        # Reject blocks with at least one significantly negative eigenvalue
        if np.any(eigvals < -tolerance):
            raise ValueError(
                f'Covariance matrix not consistent with a negative eigval for ell={ell} '
                f'eigvals {eigvals} eigvect {eigvect} red_matrix {red_matrix[ell, :, :]}'
            )

        # sqrt(M) = V diag(sqrt(|lambda|)) V^-1
        reduced_sqrtm[ell] = np.einsum('jm,m,mn->jn', eigvect, np.sqrt(np.abs(eigvals)), inv_eigvect)
    return reduced_sqrtm
def get_cell_from_map(pixel_maps, lmax, n_iter=8):
    """
    Compute the power spectra of pixel_maps with healpy's anafast,
    for a given lmax and number of iterations of the harmonic operations

    Parameters
    ----------
    pixel_maps: array of shape (nstokes, n_pix)
        input maps
    lmax: int
        maximum ell for the spectrum
    n_iter: int
        number of iterations for harmonic operations

    Returns
    -------
    c_ells: array
        output of anafast for the input maps; when nstokes=2 only the rows
        1, 2 and 4 of the spectra computed with a zero first map are returned
    """
    nstokes = 1 if len(pixel_maps.shape) == 1 else pixel_maps.shape[0]
    polarization_only = nstokes == 2

    if polarization_only:
        # Prepend a zero map so that anafast receives three maps
        maps_for_anafast = np.vstack((np.zeros_like(pixel_maps[0]), pixel_maps))
    else:
        maps_for_anafast = pixel_maps

    spectra = hp.anafast(maps_for_anafast, lmax=lmax, iter=n_iter)

    if not polarization_only:
        return spectra
    return spectra[np.array([1, 2, 4])]
def maps_x_reduced_matrix_generalized_sqrt_sqrt(maps_TQU_input, red_matrix_sqrt, lmin, n_iter=8):
    """
    NOT USED -- TO BE REMOVED

    Return maps convolved with the harmonic covariance matrix given as input
    in the format [lmax+1-lmin, nstokes, nstokes], assuming it's block diagonal

    The input matrix has to start from ell=lmin, otherwise the lmax associated with the harmonic
    operations will be wrong

    Parameters
    ----------
    maps_TQU_input: input maps of shape (nstokes, n_pix); nside is deduced from n_pix
    red_matrix_sqrt: input reduced spectra of shape (lmax+1-lmin, nstokes, nstokes)
    lmin: minimum ell for the spectrum, int
    n_iter: number of iterations for harmonic operations, int

    Returns
    -------
    maps_output: input maps convolved with input spectra, dimensions (nstokes, n_pix)
    """
    nstokes = red_matrix_sqrt.shape[1]
    lmax = red_matrix_sqrt.shape[0] - 1 + lmin

    # Harmonic operations act on 3 fields as soon as polarization is involved
    n_fields = 3 if nstokes > 1 else 1

    # nside deduced from the number of pixels of one map
    if len(maps_TQU_input.shape) == 1:
        n_pix = len(maps_TQU_input)
    else:
        n_pix = len(maps_TQU_input[0])
    nside = int(np.sqrt(n_pix / 12))

    # Matrix extended down to ell=0 and up to n_fields fields, zero where not provided
    full_sqrt = np.zeros((lmax + 1, n_fields, n_fields))
    if nstokes == 1:
        full_sqrt[lmin:, ...] = red_matrix_sqrt
    else:
        full_sqrt[lmin:, 3 - nstokes :, 3 - nstokes :] = red_matrix_sqrt

    # Polarization-only input: prepend a zero map in front of the two given maps
    maps_TQU = np.copy(maps_TQU_input)
    if maps_TQU_input.shape[0] == 2:
        maps_TQU = np.vstack((np.zeros_like(maps_TQU_input[0]), maps_TQU))

    alms_input = hp.map2alm(maps_TQU, lmax=lmax, iter=n_iter)
    alms_output = np.zeros_like(alms_input)

    # alms_output[row] = sum over col of full_sqrt[:, row, col] x alms_input[col]
    for row in range(n_fields):
        accumulated = np.zeros_like(alms_input[row])
        for col in range(n_fields):
            accumulated += hp.almxfl(alms_input[col], full_sqrt[:, row, col], inplace=False)
        alms_output[row] = accumulated

    maps_output = hp.alm2map(alms_output, nside, lmax=lmax)
    if nstokes == 1:
        return maps_output
    return maps_output[3 - nstokes :, ...]