# Source code for micmac.likelihood.harmonic
# This file is part of MICMAC.
# Copyright (C) 2024 CNRS / SciPol developers
#
# MICMAC is free software: you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# MICMAC is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
# See the GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with MICMAC. If not, see <https://www.gnu.org/licenses/>.
import time
from collections import namedtuple
from functools import partial
import chex as chx
import healpy as hp
import jax
import jax.numpy as jnp
import jax.random as random
import numpy as np
import numpyro
from jax import config
from micmac.likelihood.sampling import (
SamplingFunctions,
multivariate_Metropolis_Hasting_step_numpyro_bounded,
)
from micmac.toolbox.tools import (
frequency_alms_x_obj_red_covariance_cell_JAX,
get_c_ells_from_red_covariance_matrix,
get_reduced_matrix_from_c_ell,
get_reduced_matrix_from_c_ell_jax,
)
from micmac.toolbox.utils import generate_power_spectra_CAMB
# Names exported by `from micmac.likelihood.harmonic import *`
__all__ = [
    'HarmonicMicmacSampler',
]
# Switch JAX to 64-bit mode at import time (JAX uses 32-bit types unless this flag is set)
config.update('jax_enable_x64', True)
[docs]
class HarmonicMicmacSampler(SamplingFunctions):
[docs]
def __init__(
    self,
    nside,
    lmax,
    nstokes,
    frequency_array,
    freq_noise_c_ell,
    pos_special_freqs=[0, -1],
    n_components=3,
    lmin=2,
    n_iter=8,
    mask=None,
    spv_nodes_b=[],
    biased_version=False,
    boundary_Bf=None,
    boundary_r=None,
    step_size_r=1e-4,
    covariance_Bf=None,
    indexes_free_Bf=False,
    number_iterations_sampling=100,
    number_iterations_done=0,
    seed=0,
    disable_chex=True,
    instrument_name='SO_SAT',
):
    """
    Main MICMAC Harmonic sampling object to initialize and launch the Metropolis-Hastings (MH) sampling in harmonic domain.

    The MH sampling will store Bf and r parameters.

    Parameters
    ----------
    nside: int
        nside of the input frequency maps
    lmax: int
        maximum multipole for the spherical harmonics transforms and harmonic domain objects
    nstokes: int
        number of Stokes parameters
    frequency_array: array[float]
        array of frequencies, in GHz
    freq_noise_c_ell: array[float] of dimensions [frequencies, frequencies, lmax+1-lmin] or [frequencies, frequencies, lmax+1]
        noise power spectra for each frequency, in uK^2 ; only these two shapes are accepted
        NOTE(review): the array is stored as given, no cut to lmax+1-lmin is done in this constructor
    pos_special_freqs: list[int] (optional)
        indexes of the special frequencies in the frequency array respectively for synchrotron and dust,
        default is [0,-1] for first and last frequencies
    n_components: int (optional)
        number of components for the mixing matrix, default 3
    lmin: int (optional)
        minimum multipole for the spherical harmonics transforms and harmonic domain objects, default 2
    n_iter: int (optional)
        number of iterations the spherical harmonics transforms (for map2alm transformations), default 8
    mask: None or array[float] of dimensions [n_pix] (optional)
        mask to use in the sampling ; if not given, no mask is used, default None
        Note: the mask WILL NOT be applied to the input maps, it will be only used for the propagated noise covariance
        WARNING: Masked input are not currently supported, expect E-to-B leakage
    spv_nodes_b: list[dictionaries] (optional)
        tree for the spatial variability, to generate from a yaml file, default []
        in principle set up by get_nodes_b
        WARNING: The spatial variability is not currently supported, but will be passed to MicmacSampler obj
        when using create_Harmonic_MicmacSampler_from_MicmacSampler_obj
    biased_version: bool (optional)
        use the biased version of the likelihood, so no computation of the correction term, default False
    boundary_Bf: None or array[float] of dimensions [2, (n_frequencies-len(pos_special_freqs))*(n_components-1)] (optional)
        minimum (first row) and maximum (second row) Bf values accepted for Bf sample,
        set to [-inf,inf] for each Bf parameter if None, default None
    boundary_r: None or array[float] of dimensions [2] (optional)
        minimum and maximum r values accepted for r sample, set to [-inf,inf] if None, default None
    step_size_r: float (optional)
        step size for the Metropolis-Hastings sampling of r, default 1e-4
    covariance_Bf: None or array[float] of dimensions [(n_frequencies-len(pos_special_freqs))*(n_components-1), (n_frequencies-len(pos_special_freqs))*(n_components-1)] (optional)
        covariance for the Metropolis-Hastings sampling of Bf ; will be repeated if multiresolution case, default None
    indexes_free_Bf: False or array[int] (optional)
        indexes of the Bf parameters to sample, each in [0, len_params) ; if False, all Bf parameters are sampled, default False
    number_iterations_sampling: int (optional)
        maximum number of iterations for the sampling, default 100
    number_iterations_done: int (optional)
        number of iterations already accomplished, in case the chain is resuming from a previous run,
        usually set by exterior routines, default 0
    seed: int or array[jnp.uint32] (optional)
        seed for the JAX PRNG random number generator to start the chain or array of a previously computed seed, default 0
    disable_chex: bool (optional)
        disable chex tests (to improve speed)
    instrument_name: str (optional)
        name of the instrument as expected by cmbdb or given as 'customized_instrument' if redefined by user, default 'SO_SAT'
        see https://github.com/dpole/cmbdb/blob/master/cmbdb/experiments.yaml
    """
    # Initialising the parent class
    super().__init__(
        nside=nside,
        lmax=lmax,
        nstokes=nstokes,
        lmin=lmin,
        frequency_array=frequency_array,
        pos_special_freqs=pos_special_freqs,
        n_components=n_components,
        freq_inverse_noise=None,
        freq_noise_c_ell=freq_noise_c_ell,
        n_iter=n_iter,
        mask=mask,
        spv_nodes_b=spv_nodes_b,
    )

    # Run settings
    # If True, use the biased version of the likelihood, so no computation of the correction term
    self.biased_version = bool(biased_version)

    # Noise power spectra: either already restricted to [lmin, lmax] or given on [0, lmax]
    assert (freq_noise_c_ell.shape == (self.n_frequencies, self.n_frequencies, self.lmax + 1 - self.lmin)) or (
        freq_noise_c_ell.shape == (self.n_frequencies, self.n_frequencies, self.lmax + 1)
    )
    self.freq_noise_c_ell = freq_noise_c_ell

    # Metropolis-Hastings step-size and covariance parameters
    self.covariance_Bf = covariance_Bf
    self.step_size_r = step_size_r

    # Number of Bf parameters: one per non-special frequency and per non-CMB component
    n_params_Bf = (self.n_frequencies - len(self.pos_special_freqs)) * (self.n_components - 1)
    if boundary_Bf is None:
        boundary_Bf = jnp.zeros((2, n_params_Bf))
        boundary_Bf = boundary_Bf.at[0, :].set(-jnp.inf)
        boundary_Bf = boundary_Bf.at[1, :].set(jnp.inf)
    if boundary_r is None:
        boundary_r = jnp.array([-jnp.inf, jnp.inf])

    # Convert once so that plain lists / NumPy arrays are accepted by the JAX calls below
    boundary_Bf = jnp.asarray(boundary_Bf)
    boundary_r = jnp.asarray(boundary_r)
    assert boundary_Bf.shape == (2, n_params_Bf)
    assert boundary_r.shape == (2,)

    # Bounds of all sampled parameters, [2, n_params_Bf + 1]: Bf columns first, r as last column
    self.boundary_Bf_r = jnp.hstack((boundary_Bf, jnp.expand_dims(boundary_r, axis=0).T))

    # Sampling parameters
    if indexes_free_Bf is False:
        # If given as False, then we sample all Bf
        indexes_free_Bf = jnp.arange(self.len_params)
    self.indexes_free_Bf = jnp.array(indexes_free_Bf)
    # The number of free parameters cannot exceed the total number of parameters
    assert jnp.size(self.indexes_free_Bf) <= self.len_params
    # Valid indexes are 0 .. len_params - 1, so the largest one must be strictly below len_params
    assert jnp.max(self.indexes_free_Bf) < self.len_params
    assert jnp.min(self.indexes_free_Bf) >= 0

    # Maximum number of iterations for the sampling
    self.number_iterations_sampling = int(number_iterations_sampling)
    # Number of iterations already accomplished, in case the chain is resuming from a previous run
    self.number_iterations_done = int(number_iterations_done)
    self.seed = seed

    # Optional parameters
    self.disable_chex = disable_chex
    self.instrument_name = instrument_name

    # Samples preparation: empty chains, filled by update_samples_MH
    self.all_params_mixing_matrix_samples = jnp.empty(0)
    self.all_samples_r = jnp.empty(0)
[docs]
def generate_CMB(self, return_spectra=True):
    """
    Return the CMB spectra of scalar modes only (r=0) and tensor modes only (r=1).

    Parameters
    ----------
    return_spectra: bool (optional)
        if True, return the spectra as arrays of dimensions [number_correlations, lmax+1] ;
        if False, return them as reduced covariance matrices cut to multipoles >= lmin, default True

    Returns
    -------
    theoretical_r0_total, theoretical_r1_tensor: array[float] of dimensions [number_correlations, lmax+1]
        if return_spectra is True
    theoretical_red_cov_r0_total, theoretical_red_cov_r1_tensor: array[float]
        if return_spectra is False, output of get_reduced_matrix_from_c_ell with the first lmin entries removed
    """
    # Selecting the relevant columns of the CAMB spectra
    if self.nstokes == 2:
        # EE, BB
        partial_indices_polar = np.array([1, 2])
    elif self.nstokes == 1:
        # TT
        partial_indices_polar = np.array([0])
    else:
        # First four columns
        # NOTE(review): presumably TT, EE, BB, TE (usual CAMB ordering) — confirm against generate_power_spectra_CAMB
        partial_indices_polar = np.arange(4)
    # Number of rows actually filled below ; equals nstokes for nstokes 1 and 2, and 4 otherwise
    n_selected = len(partial_indices_polar)

    # Generating the CMB power spectra
    all_spectra_r0 = generate_power_spectra_CAMB(self.nside * 2, r=0, typeless_bool=True)
    all_spectra_r1 = generate_power_spectra_CAMB(self.nside * 2, r=1, typeless_bool=True)

    # Retrieve the scalar mode spectrum
    camb_cls_r0 = all_spectra_r0['total'][: self.lmax + 1, partial_indices_polar]
    # Retrieve the tensor mode spectrum
    tensor_spectra_r1 = all_spectra_r1['tensor'][: self.lmax + 1, partial_indices_polar]

    # Remaining rows (cross-correlations not selected above) are left at zero
    theoretical_r1_tensor = np.zeros((self.n_correlations, self.lmax + 1))
    theoretical_r0_total = np.zeros_like(theoretical_r1_tensor)
    # Slice with the number of selected columns: slicing with nstokes would try to fit
    # four rows into three when nstokes == 3
    theoretical_r1_tensor[:n_selected, ...] = tensor_spectra_r1.T
    theoretical_r0_total[:n_selected, ...] = camb_cls_r0.T

    if return_spectra:
        # Return spectra in the form [number_correlations,lmax+1]
        return theoretical_r0_total, theoretical_r1_tensor

    # Return spectra in the form of the reduced covariance matrix, cut to multipoles >= lmin
    theoretical_red_cov_r1_tensor = get_reduced_matrix_from_c_ell(theoretical_r1_tensor)[self.lmin :]
    theoretical_red_cov_r0_total = get_reduced_matrix_from_c_ell(theoretical_r0_total)[self.lmin :]
    return theoretical_red_cov_r0_total, theoretical_red_cov_r1_tensor
[docs]
def generate_input_freq_maps_from_fgs(
    self, freq_maps_fgs, r_true=0, return_only_freq_maps=True, return_only_maps=False
):
    """
    Build input frequency maps (CMB + foregrounds) from foreground-only frequency maps.

    A CMB realisation is drawn with healpy from the fiducial spectra (scalar + r_true * tensor),
    repeated over all frequencies and added to the foreground maps.

    Parameters
    ----------
    freq_maps_fgs: array[float] of dimensions [n_frequencies,nstokes,n_pix]
        input frequency foregrounds maps
    r_true: float (optional)
        tensor-to-scalar ratio multiplying the tensor contribution of the fiducial CMB, default 0
    return_only_freq_maps: bool (optional)
        return only the full frequency maps ; takes precedence over return_only_maps, default True
    return_only_maps: bool (optional)
        return only the full frequency maps and the CMB maps, default False

    Returns
    -------
    input_freq_maps: array[float] of dimensions [n_frequencies,nstokes,n_pix]
        input frequency maps
    input_cmb_maps: array[float] of dimensions [n_frequencies,nstokes,n_pix]
        the CMB realisation broadcast over frequencies (read-only view from np.broadcast_to)
    theoretical_red_cov_r0_total: array[float]
        reduced covariance matrix for the CMB scalar modes, as returned by generate_CMB(return_spectra=False)
    theoretical_red_cov_r1_tensor: array[float]
        reduced covariance matrix for the CMB tensor modes, as returned by generate_CMB(return_spectra=False)
    """
    # Scalar-only and tensor-only (r=1) reduced covariances
    red_cov_scalar, red_cov_tensor = self.generate_CMB(return_spectra=False)

    # Fiducial spectra for the requested tensor-to-scalar ratio
    fiducial_c_ells = get_c_ells_from_red_covariance_matrix(red_cov_scalar + r_true * red_cov_tensor)

    # Six-row spectra array for healpy, filled on rows 1, 2 and 4 from multipole lmin onwards
    # (with new=True healpy orders the rows as TT, EE, BB, TE, EB, TB)
    rows_to_fill = np.array([1, 2, 4])
    spectra_for_synfast = np.zeros((6, self.lmax + 1))
    spectra_for_synfast[rows_to_fill, self.lmin :] = fiducial_c_ells

    # Draw the realisation and drop the first (temperature) map
    cmb_realisation = hp.synfast(spectra_for_synfast, nside=self.nside, new=True, lmax=self.lmax)
    cmb_without_temperature = cmb_realisation[1:, ...]

    # Same CMB at every frequency, then add the foregrounds
    target_shape = (self.n_frequencies, self.nstokes, self.n_pix)
    input_cmb_maps = np.broadcast_to(cmb_without_temperature, target_shape)
    input_freq_maps = input_cmb_maps + freq_maps_fgs

    if return_only_freq_maps:
        return input_freq_maps
    if return_only_maps:
        return input_freq_maps, input_cmb_maps
    return input_freq_maps, input_cmb_maps, red_cov_scalar, red_cov_tensor
[docs]
def update_variable(self, all_samples, new_samples_to_add):
    """
    Append new samples to the samples already stored.

    Parameters
    ----------
    all_samples: array[float]
        previous samples to update ; may be empty (size 0)
    new_samples_to_add: array[float]
        new samples to add

    Returns
    -------
    all_samples: array[float]
        new_samples_to_add itself if all_samples is empty ; otherwise both arrays stacked,
        with jnp.hstack if new_samples_to_add is one-dimensional and jnp.vstack if not
    """
    if jnp.size(all_samples) == 0:
        # Nothing stored yet: the new samples become the chain
        return new_samples_to_add
    # Rank test done on the shape tuple directly instead of feeding a tuple to jnp.size,
    # which is not an array-like input for JAX
    if len(new_samples_to_add.shape) == 1:
        return jnp.hstack([all_samples, new_samples_to_add])
    return jnp.vstack([all_samples, new_samples_to_add])
[docs]
def update_samples_MH(self, all_samples):
    """
    Split new MH samples into r and Bf and append them to the stored chains.

    Parameters
    ----------
    all_samples: array[float]
        new samples ; along the last axis, the last entry is r and all the others are the Bf parameters

    Notes
    -----
    Updates the attributes all_samples_r and all_params_mixing_matrix_samples through update_variable.
    """
    # Last entry of the last axis: r
    new_r_samples = all_samples[..., -1]
    self.all_samples_r = self.update_variable(self.all_samples_r, new_r_samples)

    # Everything before it: Bf
    new_Bf_samples = all_samples[..., :-1]
    self.all_params_mixing_matrix_samples = self.update_variable(
        self.all_params_mixing_matrix_samples, new_Bf_samples
    )
[docs]
def get_alm_from_frequency_maps(self, input_freq_maps):
    """
    Get the alms from the input frequency maps, calling Healpy map2alm through a JAX pure callback.

    Parameters
    ----------
    input_freq_maps : array[float] of dimensions [n_frequencies,nstokes,n_pix]
        input frequency maps

    Returns
    -------
    freq_alms_input_maps : array[complex] of dimensions [n_frequencies,nstokes,(lmax+1)*(lmax+2)//2]
        alms from the input frequency maps
        the (lmax+1)*(lmax+2)//2 dimension is the flattened number of lm coefficients stored according to the Healpy convention

    Notes
    -----
    One zero map is stacked in front of the input maps and the callback reshapes the result to
    three maps, so the stacking only matches for nstokes == 2.
    """
    assert input_freq_maps.shape == (self.n_frequencies, self.nstokes, self.n_pix)

    ## Preparing JAX wrapper for the Healpy map2alm function
    def wrapper_map2alm(maps_, lmax=self.lmax, n_iter=self.n_iter, nside=self.nside):
        # Runs outside of JAX: convert to NumPy and restore the [3, n_pix] layout flattened by the caller
        maps_np = jax.tree.map(np.asarray, maps_).reshape((3, 12 * nside**2))
        alm_T, alm_E, alm_B = hp.map2alm(maps_np, lmax=lmax, iter=n_iter)
        return np.array([alm_T, alm_E, alm_B])

    ## Preparing JAX pure call back for the Healpy map2alm function
    @partial(jax.jit, static_argnums=(1,))
    def pure_call_map2alm(maps_, lmax):
        # Shape of the output alms: 3 sets of alms, each with (lmax+1)*(lmax+2)//2 coefficients
        # (Healpy convention with mmax = lmax). The product is always even, so the integer
        # division is exact for odd lmax as well as even lmax.
        shape_output = (3, (lmax + 1) * (lmax + 2) // 2)
        return jax.pure_callback(wrapper_map2alm, jax.ShapeDtypeStruct(shape_output, np.complex128), maps_.ravel())

    JAX_input_freq_maps = jnp.array(input_freq_maps)

    def get_freq_alm(num_frequency):
        ## Adding empty temperature map
        input_map_extended = jnp.vstack(
            (jnp.zeros_like(JAX_input_freq_maps[num_frequency, 0]), JAX_input_freq_maps[num_frequency, ...])
        )
        ## Getting alms for all stokes parameters
        all_alms = jnp.array(pure_call_map2alm(input_map_extended, lmax=self.lmax))
        ## Removing the empty temperature alms
        return all_alms[3 - self.nstokes :, ...]

    ## Getting alms for all frequencies
    return jax.vmap(get_freq_alm)(jnp.arange(self.n_frequencies))
# def get_masked_alm_from_frequency_maps(self, input_freq_maps):
# """
# For simulation pruposes, get masked alms from input full sky frequency maps using JAX,
# to avoid E->B leakage issues
# Parameters
# ----------
# input_freq_maps : array[float] of dimensions [n_frequencies,nstokes,n_pix]
# input frequency maps
# Returns
# -------
# freq_alms_input_maps : array[float] of dimensions [n_frequencies,nstokes,(lmax+1)*(lmax+2)//2]
# alms from the input frequency maps
# the (lmax+1)*(lmax+2)//2 dimension is the flattened number of lm coefficients stored according to the Healpy convention
# Notes
# -------
# The final alms will see their low multipoles below self.lmin set to zero to avoid mask coupling scales to low multipoles
# """
# assert input_freq_maps.shape == (self.n_frequencies, self.nstokes, self.n_pix)
# # Wrapper for alm2map, to prepare the pure callback of JAX
# def wrapper_alm2map(alm_, lmax=self.lmax, nside=self.nside):
# alm_np = jax.tree.map(np.asarray, alm_)
# return hp.alm2map(alm_np, nside, lmax=lmax)
# ## Preparing JAX pure call back for the Healpy map2alm function
# @partial(jax.jit, static_argnums=(1, 2))
# def pure_call_alm2map(alm_, lmax, nside):
# shape_output = (3, 12 * nside**2)
# return jax.pure_callback(wrapper_alm2map, jax.ShapeDtypeStruct(shape_output, np.float64), alm_)
# def get_freq_alm(num_frequency, JAX_input_freq_alms):
# input_alm_extended = jnp.vstack(
# (jnp.zeros_like(JAX_input_freq_alms[num_frequency, 0]), JAX_input_freq_alms[num_frequency, ...])
# ) ## Adding empty temperature map
# all_maps = jnp.array(
# pure_call_alm2map(input_alm_extended, lmax=self.lmax, nside=self.nside)
# ) ## Getting alms for all stokes parameters
# return all_maps[3 - self.nstokes :, ...] ## Removing the empty temperature alms
# full_sky_alms = self.get_alm_from_frequency_maps(input_freq_maps)
# E_modes_only_alms = jnp.zeros_like(full_sky_alms)
# B_modes_only_alms = jnp.zeros_like(full_sky_alms)
# E_modes_only_alms = E_modes_only_alms.at[:, -2, :].set(full_sky_alms[:, -2, :])
# B_modes_only_alms = B_modes_only_alms.at[:, -1, :].set(full_sky_alms[:, -1, :])
# freq_map_E_modes = jax.vmap(get_freq_alm, in_axes=(0, None))(
# jnp.arange(self.n_frequencies), E_modes_only_alms
# ) ## Getting alms for all frequencies with E modes only
# freq_map_B_modes = jax.vmap(get_freq_alm, in_axes=(0, None))(
# jnp.arange(self.n_frequencies), B_modes_only_alms
# ) ## Getting alms for all frequencies with B modes only
# f_sky = self.mask.sum() / self.mask.size
# masked_alm_E_modes = self.get_alm_from_frequency_maps(freq_map_E_modes * self.mask)
# masked_alm_B_modes = self.get_alm_from_frequency_maps(freq_map_B_modes * self.mask)
# result_alms = jnp.zeros_like(masked_alm_B_modes)
# result_alms = result_alms.at[:, -2, :].set(masked_alm_E_modes[:, -2, :])
# result_alms = result_alms.at[:, -1, :].set(
# masked_alm_B_modes[:, -1, :]
# ) ## Keeping only the E and B modes separately
# return result_alms
# # freq_red_matrix = jnp.einsum(
# # 'fq,l,ij->fqlij',
# # jnp.eye(self.n_frequencies),
# # jnp.ones(self.lmax + 1),
# # jnp.eye(self.nstokes),
# # )
# # freq_red_matrix = freq_red_matrix.at[:, :, : self.lmin, :, :].set(0) # Avoiding mask leakage to low multipoles
# # return frequency_alms_x_obj_red_covariance_cell_JAX(result_alms, freq_red_matrix, 0) #/ jnp.sqrt(f_sky)
# def get_masked_alm_from_maps(self, input_maps, seed=0):
# """
# For simulation pruposes, get masked alms from input full sky frequency maps using JAX,
# to avoid E->B leakage issues
# Parameters
# ----------
# input_maps : array[float] of dimensions [nstokes,n_pix]
# input maps
# Returns
# -------
# freq_alms_input_maps : array[float] of dimensions [n_frequencies,nstokes,(lmax+1)*(lmax+2)//2]
# alms from the input frequency maps
# the (lmax+1)*(lmax+2)//2 dimension is the flattened number of lm coefficients stored according to the Healpy convention
# Notes
# -------
# The final alms will see their low multipoles below self.lmin set to zero to avoid mask coupling scales to low multipoles
# """
# assert input_maps.shape == (self.nstokes, self.n_pix)
# extended_maps = np.vstack((np.zeros_like(input_maps[0]), input_maps))
# c_ells = hp.anafast(extended_maps, lmax=self.lmax, iter=self.n_iter)
# c_ell_ith_modes = [np.zeros_like(c_ells) for i in range(2)]
# alm_list = []
# for i in range(2):
# np.random.seed(seed)
# c_ell_ith_modes[i][i+1] = c_ells[i+1]
# # c_ell_ith_modes[i][4] = c_ells[4]
# # Generate input maps
# input_cmb_maps_alt = hp.synfast(c_ell_ith_modes[i], nside=self.nside, new=True, lmax=self.lmax)
# alm_list.append(hp.map2alm(input_cmb_maps_alt * self.mask, lmax=self.lmax, iter=self.n_iter)[i+1,...])
# return np.array(alm_list)
# def get_masked_alm_from_c_ells(self, c_ells, seed=0):
# """
# For simulation pruposes, get masked alms from input full sky frequency maps using JAX,
# to avoid E->B leakage issues
# Parameters
# ----------
# input_maps : array[float] of dimensions [nstokes,n_pix]
# input maps
# Returns
# -------
# freq_alms_input_maps : array[float] of dimensions [n_frequencies,nstokes,(lmax+1)*(lmax+2)//2]
# alms from the input frequency maps
# the (lmax+1)*(lmax+2)//2 dimension is the flattened number of lm coefficients stored according to the Healpy convention
# Notes
# -------
# The final alms will see their low multipoles below self.lmin set to zero to avoid mask coupling scales to low multipoles
# """
# assert c_ells.shape == (self.n_correlations, self.lmax + 1 - self.lmin)
# alm_list = []
# for i in range(2):
# np.random.seed(seed)
# c_ell_extended = np.zeros((6, self.lmax+1))
# c_ell_extended[i+1, self.lmin:] = c_ells[i]
# # c_ell_extended[4, self.lmin:] = c_ells[3]
# # Generate input maps
# input_cmb_maps_alt = hp.synfast(c_ell_extended, nside=self.nside, new=True, lmax=self.lmax)
# alm_list.append(hp.map2alm(input_cmb_maps_alt * self.mask, lmax=self.lmax, iter=self.n_iter)[i+1,...])
# return np.array(alm_list)
[docs]
def perform_harmonic_minimize(
    self,
    input_freq_maps,
    c_ell_approx,
    init_params_mixing_matrix,
    theoretical_r0_total,
    theoretical_r1_tensor,
    initial_guess_r=0,
    method_used='ScipyMinimize',
    **options_minimizer,
):
    """
    Perform a minimization to find the best r and Bf in harmonic domain.

    The function minimized is the opposite of harmonic_marginal_probability.

    Parameters
    ----------
    input_freq_maps : array[float] of dimensions [n_frequencies,nstokes,n_pix]
        input frequency maps
    c_ell_approx : array[float] of dimensions [number_correlations, lmax+1] or [number_correlations, lmax+1-lmin]
        approximate CMB power spectra for the correction term ; if nstokes == 2 and the number of rows
        is not 3, rows 1, 2 and 4 are selected
    init_params_mixing_matrix : array[float] of dimensions [n_frequencies-len(pos_special_freqs), n_components-1]
        initial parameters for the mixing matrix, or the same values already flattened
    theoretical_r0_total : array[float] of dimensions [number_correlations, lmax+1-lmin]
        theoretical power spectra for the CMB scalar modes
    theoretical_r1_tensor : array[float] of dimensions [number_correlations, lmax+1-lmin]
        theoretical power spectra for the CMB tensor modes
    initial_guess_r : float (optional)
        initial guess for r, default 0
    method_used : str (optional)
        jaxopt solver used for the minimization, one of 'BFGS', 'GradientDescent', 'LBFGS',
        'NonlinearCG', 'ScipyMinimize', default 'ScipyMinimize'
    options_minimizer : dict (optional)
        additional keyword arguments forwarded to the jaxopt solver

    Returns
    -------
    params : array[float] of dimensions [(n_frequencies-len(pos_special_freqs))*(n_components-1) + 1]
        best parameters found, Bf first (flattened in column-major order) and r last

    Raises
    ------
    ImportError
        if jaxopt is not installed
    ValueError
        if method_used is not one of the supported solvers
    """
    try:
        import jaxopt as jopt
    except ImportError as err:
        raise ImportError('jaxopt is not installed. Please install it with "pip install jaxopt"') from err

    # Disabling all chex checks to speed up the code
    if self.disable_chex:
        print('Disabling chex !!!', flush=True)
        chx.disable_asserts()

    ## Getting only the relevant spectra
    if self.nstokes == 2:
        indices_to_consider = np.array([1, 2, 4])
    elif self.nstokes == 1:
        indices_to_consider = np.array([0])
    else:
        indices_to_consider = np.arange(6)  # All auto- and cross-correlations

    ## Testing the shapes of the scalar and tensor modes spectra
    assert len(theoretical_r0_total.shape) == 2
    assert theoretical_r0_total.shape[1] == self.lmax + 1 - self.lmin
    assert len(theoretical_r1_tensor.shape) == 2
    assert theoretical_r1_tensor.shape[1] == theoretical_r0_total.shape[1]

    ## If C_approx was given for all correlations, we need to select only the relevant ones for the polarisation
    if self.nstokes == 2 and (c_ell_approx.shape[0] != len(indices_to_consider)):
        c_ell_approx = c_ell_approx[indices_to_consider, :]

    ## Testing the initial mixing matrix: the number of Bf parameters follows the number of
    ## non-CMB components, as for boundary_Bf in the constructor
    n_free_frequencies = self.n_frequencies - len(self.pos_special_freqs)
    n_foreground_components = self.n_components - 1
    if len(init_params_mixing_matrix.shape) == 1:
        assert len(init_params_mixing_matrix) == n_free_frequencies * n_foreground_components
    else:
        assert init_params_mixing_matrix.shape[0] == n_free_frequencies
        assert init_params_mixing_matrix.shape[1] == n_foreground_components

    ## Preparing the reduced covariance matrix for C_approx, restricted to multipoles >= lmin
    red_cov_approx_matrix = get_reduced_matrix_from_c_ell_jax(c_ell_approx)
    if c_ell_approx.shape[1] != self.lmax + 1 - self.lmin:
        # C_approx not already cut to [lmin, lmax]: drop the first lmin multipoles
        red_cov_approx_matrix = red_cov_approx_matrix[self.lmin :, ...]
    if self.biased_version:
        # Biased version: no correction term
        red_cov_approx_matrix = jnp.zeros_like(red_cov_approx_matrix)
    theoretical_red_cov_r0_total = get_reduced_matrix_from_c_ell(theoretical_r0_total)
    theoretical_red_cov_r1_tensor = get_reduced_matrix_from_c_ell(theoretical_r1_tensor)

    ## Getting alms from the input maps
    freq_alms_input_maps = self.get_alm_from_frequency_maps(input_freq_maps)

    # Preparing the noise weighted alms
    ## Operator built from self.freq_noise_c_ell, in format [frequencies, frequencies, n_ell, nstokes, nstokes]
    ## NOTE(review): used as N^-1 below — confirm that freq_noise_c_ell holds the inverse noise
    freq_red_inverse_noise = jnp.einsum('fgl,sk->fglsk', self.freq_noise_c_ell, jnp.eye(self.nstokes))
    ## Applying the operator to the alms of the input data
    noise_weighted_alm_data = frequency_alms_x_obj_red_covariance_cell_JAX(
        freq_alms_input_maps, freq_red_inverse_noise, lmin=self.lmin
    )

    # Setting up the JAXOpt class:
    if method_used in ['BFGS', 'GradientDescent', 'LBFGS', 'NonlinearCG', 'ScipyMinimize']:
        class_solver = getattr(jopt, method_used)
    else:
        raise ValueError('Method used not recognized for minimization')

    # Setting up the function to minimize: opposite of the marginal probability
    def func_to_minimize(sample_Bf_r):
        return -self.harmonic_marginal_probability(
            sample_Bf_r,
            noise_weighted_alm_data=noise_weighted_alm_data,
            red_cov_approx_matrix=red_cov_approx_matrix,
            theoretical_red_cov_r0_total=theoretical_red_cov_r0_total,
            theoretical_red_cov_r1_tensor=theoretical_red_cov_r1_tensor,
        )

    # Setting up the JAX optimizer
    optimizer = class_solver(fun=func_to_minimize, **options_minimizer)

    # Preparing the initial parameters: flattened Bf followed by r
    init_params_Bf_r = jnp.concatenate(
        (init_params_mixing_matrix.ravel(order='F'), jnp.array(initial_guess_r).reshape(1))
    )

    print('Start of minimization', flush=True)
    params, state = optimizer.run(init_params_Bf_r)
    print('End of minimization', flush=True)
    print('Found parameters', params, flush=True)
    print('With state', state, flush=True)
    return params
[docs]
def perform_harmonic_MH(
    self,
    input_freq_maps,
    c_ell_approx,
    init_params_mixing_matrix,
    theoretical_r0_total,
    theoretical_r1_tensor,
    initial_guess_r=0,
    covariance_Bf_r=None,
    input_freq_alms=None,
    print_bool=True,
):
    """
    Perform Metropolis Hastings to find the best r and Bf in harmonic domain.

    The chains will be stored as object attributes:
        - all_samples_r for r
        - all_params_mixing_matrix_samples for Bf
    The attributes last_sample, number_iterations_done and last_PRNGKey are updated as well.

    Parameters
    ----------
    input_freq_maps : array[float] of dimensions [n_frequencies,nstokes,n_pix]
        input frequency maps ; only used if input_freq_alms is None
    c_ell_approx : array[float] of dimensions [number_correlations, lmax+1] or [number_correlations, lmax+1-lmin]
        approximate CMB power spectra for the correction term ; if nstokes == 2 and the number of rows
        is not 3, rows 1, 2 and 4 are selected
    init_params_mixing_matrix : array[float] of dimensions [n_frequencies-len(pos_special_freqs), n_components-1]
        initial parameters for the mixing matrix, or the same values already flattened
    theoretical_r0_total : array[float] of dimensions [number_correlations, lmax+1-lmin]
        theoretical power spectra for the CMB scalar modes
    theoretical_r1_tensor : array[float] of dimensions [number_correlations, lmax+1-lmin]
        theoretical power spectra for the CMB tensor modes
    initial_guess_r : float (optional)
        initial guess for r, default 0
    covariance_Bf_r : None or array[float] of dimensions [(n_frequencies-len(pos_special_freqs))*(n_components-1) + 1, (n_frequencies-len(pos_special_freqs))*(n_components-1) + 1] (optional)
        covariance for the Metropolis-Hastings sampling of Bf and r ; if None, it is built from the
        attributes covariance_Bf and step_size_r without correlation between Bf and r, default None
    input_freq_alms : array[complex] of dimensions [n_frequencies,nstokes,(lmax+1)*(lmax+2)//2] (optional)
        if provided, input_freq_alms is used instead of input_freq_maps for the MH steps
    print_bool: bool (optional)
        option for test prints, default True

    Raises
    ------
    ValueError
        if the seed is neither a scalar nor an array of size 2, or if no proposal covariance is available
    """
    # Disabling all chex checks to speed up the code
    # chx acts like an assert, but is JAX compatible
    if self.disable_chex:
        print('Disabling chex !!!', flush=True)
        chx.disable_asserts()

    ## Getting only the relevant spectra
    if self.nstokes == 2:
        indices_to_consider = np.array([1, 2, 4])
    elif self.nstokes == 1:
        indices_to_consider = np.array([0])
    else:
        indices_to_consider = np.arange(6)  # All auto- and cross-correlations

    ## Testing the shapes of the scalar and tensor modes spectra
    assert len(theoretical_r0_total.shape) == 2
    assert theoretical_r0_total.shape[1] == self.lmax + 1 - self.lmin
    assert len(theoretical_r1_tensor.shape) == 2
    assert theoretical_r1_tensor.shape[1] == theoretical_r0_total.shape[1]

    ## Getting the theoretical reduced covariance matrices of the CMB scalar and tensor modes
    theoretical_red_cov_r0_total = get_reduced_matrix_from_c_ell(theoretical_r0_total)
    theoretical_red_cov_r1_tensor = get_reduced_matrix_from_c_ell(theoretical_r1_tensor)

    ## Testing shapes of C_approx
    assert len(c_ell_approx.shape) == 2
    ## If C_approx was given for all correlations, we need to select only the relevant ones for the polarisation
    if self.nstokes == 2 and (c_ell_approx.shape[0] != len(indices_to_consider)):
        c_ell_approx = c_ell_approx[indices_to_consider, :]
    ## Cutting the C_ell to the relevant ell range
    if c_ell_approx.shape[1] == self.lmax + 1:
        c_ell_approx = c_ell_approx[:, self.lmin :]
    assert c_ell_approx.shape[1] == self.lmax + 1 - self.lmin

    ## Testing the initial mixing matrix: the number of Bf parameters follows the number of
    ## non-CMB components, as for boundary_Bf in the constructor
    n_free_frequencies = self.n_frequencies - len(self.pos_special_freqs)
    n_foreground_components = self.n_components - 1
    dimension_param_Bf = n_free_frequencies * n_foreground_components
    if len(init_params_mixing_matrix.shape) == 1:
        assert len(init_params_mixing_matrix) == dimension_param_Bf
    else:
        assert init_params_mixing_matrix.shape[0] == n_free_frequencies
        assert init_params_mixing_matrix.shape[1] == n_foreground_components

    # Preparing for the full Metropolis-Hastings sampling
    ## Initial guesses preparation, flattened in column-major order
    params_mixing_matrix_init_sample = jnp.copy(init_params_mixing_matrix).ravel(order='F')

    ## Reduced covariance of C_approx, set to zero in the biased version (no correction term)
    red_cov_approx_matrix = get_reduced_matrix_from_c_ell_jax(c_ell_approx)
    if self.biased_version:
        red_cov_approx_matrix = jnp.zeros_like(red_cov_approx_matrix)

    ## Preparing the JAX PRNG key from the seed of the object
    if np.size(self.seed) == 1:
        PRNGKey = random.PRNGKey(self.seed)
    elif np.size(self.seed) == 2:
        # Seed already given as a PRNG key
        PRNGKey = jnp.array(self.seed, dtype=jnp.uint32)
    else:
        raise ValueError('Seed should be either a scalar or a 2D array interpreted as a JAX PRNG Key!')

    ## Preparing the proposal covariance of Bf and r
    if covariance_Bf_r is None:
        if self.covariance_Bf is None:
            raise ValueError('Please provide a covariance_Bf')
        assert (self.covariance_Bf).shape == (dimension_param_Bf, dimension_param_Bf)

        ## Building the full covariance of both Bf and r, without correlations between Bf and r
        covariance_Bf_r = jnp.zeros((dimension_param_Bf + 1, dimension_param_Bf + 1))
        ## Setting the covariance for Bf
        covariance_Bf_r = covariance_Bf_r.at[:dimension_param_Bf, :dimension_param_Bf].set(self.covariance_Bf)
        ## Setting the variance for r from its step-size
        covariance_Bf_r = covariance_Bf_r.at[dimension_param_Bf, dimension_param_Bf].set(self.step_size_r**2)
    else:
        assert covariance_Bf_r.shape == (dimension_param_Bf + 1, dimension_param_Bf + 1)

    if print_bool:
        print('Covariance Bf, r:', covariance_Bf_r, flush=True)

    ## Getting alms from the input maps
    if input_freq_alms is None:
        input_freq_alms = self.get_alm_from_frequency_maps(input_freq_maps)

    ## Preparing the noise weighted alms
    ## Operator built from self.freq_noise_c_ell, in format [frequencies, frequencies, n_ell, nstokes, nstokes]
    ## NOTE(review): used as N^-1 below — confirm that freq_noise_c_ell holds the inverse noise
    freq_red_inverse_noise = jnp.einsum('fgl,sk->fglsk', self.freq_noise_c_ell, jnp.eye(self.nstokes))
    ## Applying the operator to the alms of the input data
    noise_weighted_alm_data = frequency_alms_x_obj_red_covariance_cell_JAX(
        input_freq_alms, freq_red_inverse_noise, lmin=self.lmin
    )

    print(f'Starting {self.number_iterations_sampling} iterations for harmonic run', flush=True)

    # State carried between MH steps: current sample and PRNG key
    MHState = namedtuple('MHState', ['u', 'rng_key'])

    class MetropolisHastings(numpyro.infer.mcmc.MCMCKernel):
        # Field of the state collected as samples
        sample_field = 'u'

        def __init__(self, log_proba, covariance_matrix, boundary_Bf_r=self.boundary_Bf_r):
            self.log_proba = log_proba
            self.covariance_matrix = covariance_matrix
            self.boundary_Bf_r = boundary_Bf_r

        def init(self, rng_key, num_warmup, init_params, model_args, model_kwargs):
            return MHState(init_params, rng_key)

        def sample(self, state, model_args, model_kwargs):
            """
            One Metropolis-Hastings sampling step
            """
            new_sample, rng_key = multivariate_Metropolis_Hasting_step_numpyro_bounded(
                state,
                covariance_matrix=self.covariance_matrix,
                log_proba=self.log_proba,
                boundary=self.boundary_Bf_r,
                **model_kwargs,
            )
            return MHState(new_sample, rng_key)

    # Only the iterations not already done are sampled
    mcmc_obj = numpyro.infer.mcmc.MCMC(
        MetropolisHastings(log_proba=self.harmonic_marginal_probability, covariance_matrix=covariance_Bf_r),
        num_warmup=0,
        num_samples=self.number_iterations_sampling - self.number_iterations_done,
        progress_bar=True,
    )

    # Initializing r and Bf samples: flattened Bf followed by r
    init_params_mixing_matrix_r = jnp.concatenate(
        (params_mixing_matrix_init_sample, jnp.array(initial_guess_r).reshape(1))
    )

    time_start_sampling = time.time()
    ## Starting the MH sampling !!!
    mcmc_obj.run(
        PRNGKey,
        init_params=init_params_mixing_matrix_r,
        noise_weighted_alm_data=noise_weighted_alm_data,
        theoretical_red_cov_r1_tensor=theoretical_red_cov_r1_tensor,
        theoretical_red_cov_r0_total=theoretical_red_cov_r0_total,
        red_cov_approx_matrix=red_cov_approx_matrix,
    )
    time_full_chain = (time.time() - time_start_sampling) / 60
    print(f'End of MH iterations for harmonic run in {time_full_chain} minutes !', flush=True)

    posterior_samples = mcmc_obj.get_samples()
    if print_bool:
        print('Summary of the run', flush=True)
        mcmc_obj.print_summary()

    # Saving the samples as attributes of the Sampler object
    self.update_samples_MH(posterior_samples)
    self.last_sample = {
        'r_sample': posterior_samples[-1, -1],
        'params_mixing_matrix_sample': posterior_samples[-1, :-1],
        'input_freq_alms': input_freq_alms,
    }
    self.number_iterations_done = self.number_iterations_sampling

    # Store the key reached at the end of the chain, so that a resumed run does not replay the
    # random stream of this one ; fall back on the initial key if the final state is unavailable
    last_state = getattr(mcmc_obj, 'last_state', None)
    self.last_PRNGKey = getattr(last_state, 'rng_key', PRNGKey)
[docs]
def compute_covariance_from_samples(self):
    """
    Compute the covariance matrix from the sample chains of Bf and r.

    The number of samples and the number of Bf parameters are read from the stored chains
    (all_samples_r and all_params_mixing_matrix_samples).

    Returns
    -------
    covariance_Bf_r : array[float] of dimensions [n_params_Bf + 1, n_params_Bf + 1]
        covariance matrix of the samples, Bf parameters first and r last

    Raises
    ------
    ValueError
        if no iteration was done yet or if no sample is stored
    """
    no_sample_message = 'No iterations done yet, please perform some sampling before computing the covariance matrix'
    if self.number_iterations_done == 0:
        raise ValueError(no_sample_message)

    print('Computing the covariance matrix from the samples', flush=True)

    # One r value per sample: its length gives the number of samples actually stored
    samples_r = np.asarray(self.all_samples_r).reshape(-1)
    n_samples = samples_r.size
    if n_samples == 0:
        raise ValueError(no_sample_message)

    # One row per sample ; the number of Bf parameters is inferred from the stored chain
    samples_Bf = np.asarray(self.all_params_mixing_matrix_samples).reshape((n_samples, -1))

    all_samples_Bf_r = np.zeros((n_samples, samples_Bf.shape[1] + 1))
    all_samples_Bf_r[:, :-1] = samples_Bf
    all_samples_Bf_r[:, -1] = samples_r

    # Columns are the variables, rows the samples
    return jnp.cov(all_samples_Bf_r, rowvar=False)