# Source code for micmac.likelihood.pixel
# This file is part of MICMAC.
# Copyright (C) 2024 CNRS / SciPol developers
#
# MICMAC is free software: you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# MICMAC is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
# See the GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with MICMAC. If not, see <https://www.gnu.org/licenses/>.
import time
import chex as chx
import healpy as hp
import jax
import jax.lax as jlax
import jax.numpy as jnp
import jax.random as random
import jax.scipy as jsp
import numpy as np
from jax import config
from jax_tqdm import scan_tqdm
from micmac.likelihood.sampling import (
SamplingFunctions,
separate_single_MH_step_index_accelerated,
separate_single_MH_step_index_v2b,
separate_single_MH_step_index_v4_pixel,
separate_single_MH_step_index_v4b_pixel,
single_Metropolis_Hasting_step,
)
from micmac.noise.noisecovar import (
get_BtinvN,
get_inv_BtinvNB,
get_inv_BtinvNB_c_ell,
get_Wd,
)
from micmac.toolbox.statistics import get_1d_recursive_empirical_covariance
from micmac.toolbox.tools import (
get_c_ells_from_red_covariance_matrix,
get_cell_from_map_jax,
get_reduced_matrix_from_c_ell,
get_reduced_matrix_from_c_ell_jax,
get_sqrt_reduced_matrix_from_matrix_jax,
maps_x_red_covariance_cell_JAX,
)
from micmac.toolbox.utils import generate_power_spectra_CAMB
# Public API of this module
__all__ = ['MicmacSampler']

# The sampling routines rely on double precision: enable 64-bit types in JAX globally
jax.config.update('jax_enable_x64', True)
# [docs]
class MicmacSampler(SamplingFunctions):
# [docs]
def __init__(
    self,
    nside,
    lmax,
    nstokes,
    frequency_array,
    freq_inverse_noise,
    pos_special_freqs=None,
    freq_noise_c_ell=None,
    n_components=3,
    lmin=2,
    n_iter=8,
    limit_iter_cg=200,
    limit_iter_cg_eta=200,
    tolerance_CG=1e-8,
    atol_CG=1e-8,
    mask=None,
    save_CMB_chain_maps=False,
    save_eta_chain_maps=False,
    save_all_Bf_params=True,
    save_s_c_spectra=False,
    sample_r_Metropolis=True,
    sample_C_inv_Wishart=False,
    perturbation_eta_covariance=True,
    simultaneous_accept_rate=False,
    non_centered_moves=False,
    save_intermediary_centered_moves=False,
    limit_r_value=False,
    min_r_value=0,
    biased_version=False,
    classical_Gibbs=False,
    use_binning=False,
    bin_ell_distribution=None,
    acceptance_posdef=False,
    step_size_r=1e-4,
    covariance_Bf=None,
    use_scam_step_size=False,
    burn_in_scam=50,
    s_param_scam=(2.4) ** 2,
    epsilon_param_scam_r=1e-10,
    epsilon_param_scam_Bf=1e-11,
    scam_iteration_updates=50,
    indexes_free_Bf=False,
    number_iterations_sampling=100,
    number_iterations_done=0,
    seed=0,
    disable_chex=True,
    instrument_name='SO_SAT',
    # fwhm=None,
    spv_nodes_b=None,
):
    """
    Main MICMAC pixel sampling object to initialize and launch the Gibbs sampling in pixel domain.

    The Gibbs sampling always stores the r (or C) samples, and stores the Bf parameters when save_all_Bf_params is True.

    Parameters
    ----------
    nside: int
        nside of the input frequency maps
    lmax: int
        maximum multipole for the spherical harmonics transforms and harmonic domain objects
    nstokes: int
        number of Stokes parameters
    frequency_array: array[float]
        array of frequencies, in GHz
    freq_inverse_noise: array[float]
        array of inverse noise for each frequency, in uK^-2
    pos_special_freqs: list[int] (optional)
        indexes of the special frequencies in the frequency array respectively for synchrotron and dust,
        default None, interpreted as [0,-1] for the first and last frequencies
    freq_noise_c_ell: array[float] of dimensions [frequencies, frequencies, lmax+1-lmin] or [frequencies, frequencies, lmax] (optional)
        noise power spectra for each frequency, in uK^2 ; in the second case it will be cut to lmax+1-lmin, default None
    n_components: int (optional)
        number of components for the mixing matrix, default 3
    lmin: int (optional)
        minimum multipole for the spherical harmonics transforms and harmonic domain objects, default 2
    n_iter: int (optional)
        number of iterations for the spherical harmonics transforms (for map2alm transformations), default 8
    limit_iter_cg: int (optional)
        maximum number of iterations for the conjugate gradient for the CMB map sampling, default 200
    limit_iter_cg_eta: int (optional)
        maximum number of iterations for the conjugate gradient for eta maps sampling, default 200
    tolerance_CG: float (optional)
        tolerance for the conjugate gradient, default 1e-8
    atol_CG: float (optional)
        absolute tolerance for the conjugate gradient, default 1e-8
    mask: None or array[float] of dimensions [n_pix] (optional)
        mask to use in the sampling ; if not given, no mask is used, default None
        Note: the mask WILL NOT be applied to the input maps, it will be only used for the propagated noise covariance
    save_CMB_chain_maps: bool (optional)
        save the CMB chain maps, default False
    save_eta_chain_maps: bool (optional)
        save the eta chain maps, default False
    save_all_Bf_params: bool (optional)
        save the chains of all the mixing matrix Bf parameters, default True
    save_s_c_spectra: bool (optional)
        save the spectra of the sampled CMB maps s_c, default False
    sample_r_Metropolis: bool (optional)
        sample r with a Metropolis-within-Gibbs step from the BB power spectrum of the
        reconstructed sample of the CMB map during the Gibbs iteration, default True
        sample_r_Metropolis and sample_C_inv_Wishart cannot both be True
    sample_C_inv_Wishart: bool (optional)
        sample C_inv with Wishart distribution instead of simply r being sampled, default False
        sample_r_Metropolis and sample_C_inv_Wishart cannot both be True
    perturbation_eta_covariance: bool (optional)
        approach to compute difference between CMB noise component for eta log proba instead of repeating the CG for each Bf sampling, default True
    simultaneous_accept_rate: bool (optional)
        use the simultaneous accept rate for the patches of the Bf sampling, default False
    non_centered_moves: bool (optional)
        use non-centered moves for the C sampling, default False
    save_intermediary_centered_moves: bool (optional)
        save intermediary r values in case of non-centered moves in the sampling, default False
    limit_r_value: bool (optional)
        limit the r values being sampled with the minimum r value given by min_r_value, default False
    min_r_value: float (optional)
        minimum r value accepted for r sample if limit_r_value is True, default 0
    biased_version: bool (optional)
        use the biased version of the likelihood, so no computation of the correction term, default False
    classical_Gibbs: bool (optional)
        sampling only for s_c and the CMB covariance, and neither Bf or eta, default False
    use_binning: bool (optional)
        use binning for the sampling of inverse Wishart CMB covariance, if False bin_ell_distribution will not be used, default False
    bin_ell_distribution: array[int] (optional)
        binning distribution for the sampling of inverse Wishart CMB covariance, default None
    acceptance_posdef: bool (optional)
        accept only positive definite matrices for the C sampling, default False
    step_size_r: float (optional)
        step size for the Metropolis-Hastings sampling of r, default 1e-4
    covariance_Bf: None or array[float] of dimensions [(n_frequencies-len(pos_special_freqs))*(n_components-1), (n_frequencies-len(pos_special_freqs))*(n_components-1)] (optional)
        covariance for the Metropolis-Hastings sampling of Bf ; will be repeated in the multiresolution case, default None
    use_scam_step_size: bool (optional)
        use the SCAM step size for the Metropolis-Hastings sampling of Bf and r (Haario et al. 2005), default False
    burn_in_scam: int (optional)
        number of burn-in iterations before using adaptive step-size (SCAM), not used if use_scam_step_size is False, default 50
    s_param_scam: float (optional)
        s parameter for the SCAM step size (see Haario et al. 2001, Haario et al. 2005), default (2.4)**2
    epsilon_param_scam_r: float (optional)
        epsilon parameter for the SCAM step size for r (see Haario et al. 2001, Haario et al. 2005), default 1e-10
    epsilon_param_scam_Bf: float (optional)
        epsilon parameter for the SCAM step size for Bf (see Haario et al. 2001, Haario et al. 2005), default 1e-11
    scam_iteration_updates: int (optional)
        number of iterations for which the SCAM step size will be updated (the variance is updated successively for every scam_iteration_updates iterations), default 50
    indexes_free_Bf: bool or array[int] (optional)
        indexes of the free Bf parameters to actually sample and leave the rest of the indices fixed, array of integers
        in [0, number of Bf parameters) ; default False (None is accepted too) to sample all Bf
    number_iterations_sampling: int (optional)
        maximum number of iterations for the sampling, default 100
    number_iterations_done: int (optional)
        number of iterations already accomplished, in case the chain is resuming from a previous run, usually set by exterior routines, default 0
    seed: int or array[jnp.uint32] (optional)
        seed for the JAX PRNG random number generator to start the chain or array of a previously computed seed, default 0
    disable_chex: bool (optional)
        disable chex tests (to improve speed), default True
    instrument_name: str (optional)
        name of the instrument as expected by cmbdb or given as 'customized_instrument' if redefined by user, default 'SO_SAT'
        see https://github.com/dpole/cmbdb/blob/master/cmbdb/experiments.yaml
    spv_nodes_b: list[dictionaries] (optional)
        tree for the spatial variability, to generate from a yaml file, default None (interpreted as [])
        in principle set up by get_nodes_b

    Notes
    -----
    A beam (fwhm) option is planned but currently disabled: it is commented out of the signature and not accepted.

    Raises
    ------
    ValueError
        if n_components < 2 or len(pos_special_freqs) != n_components - 1 while classical_Gibbs is False,
        if both sample_r_Metropolis and sample_C_inv_Wishart are True,
        or if indexes_free_Bf contains more entries than Bf parameters or out-of-range indexes
    """
    # Fresh objects for each instance, to avoid sharing a mutable default list between instances
    if pos_special_freqs is None:
        pos_special_freqs = [0, -1]
    if spv_nodes_b is None:
        spv_nodes_b = []

    ## Give the parameters to the parent class
    super().__init__(
        nside=nside,
        lmax=lmax,
        nstokes=nstokes,
        lmin=lmin,
        frequency_array=frequency_array,
        freq_inverse_noise=freq_inverse_noise,
        spv_nodes_b=spv_nodes_b,
        pos_special_freqs=pos_special_freqs,
        n_components=n_components,
        n_iter=n_iter,
        limit_iter_cg=limit_iter_cg,
        limit_iter_cg_eta=limit_iter_cg_eta,
        tolerance_CG=tolerance_CG,
        atol_CG=atol_CG,
        mask=mask,
        bin_ell_distribution=bin_ell_distribution,
    )

    # Run settings
    self.classical_Gibbs = bool(
        classical_Gibbs
    )  # To run the classical Gibbs sampling instead of the full MICMAC sampling
    if not self.classical_Gibbs:
        # Then we expect to have multiple components, with one special frequency per foreground component
        if self.n_components <= 1:
            raise ValueError('At least two components are needed when classical_Gibbs is False')
        if len(pos_special_freqs) != self.n_components - 1:
            raise ValueError('The number of special frequencies should be equal to the number of components - 1')
    self.biased_version = bool(biased_version)  # To have a run without the correction term
    self.perturbation_eta_covariance = bool(
        perturbation_eta_covariance
    )  # To use the perturbation approach for the eta contribution in log-proba of Bf
    self.simultaneous_accept_rate = bool(
        simultaneous_accept_rate
    )  # To use the simultaneous accept rate for the patches of the Bf sampling

    # r can be sampled through Metropolis-Hastings or C through an inverse Wishart, but not both
    if sample_r_Metropolis and sample_C_inv_Wishart:
        raise ValueError('sample_r_Metropolis and sample_C_inv_Wishart cannot both be True')
    self.sample_r_Metropolis = bool(sample_r_Metropolis)  # To sample r with Metropolis-Hastings
    self.sample_C_inv_Wishart = bool(sample_C_inv_Wishart)  # To sample C with an inverse Wishart
    self.use_binning = bool(use_binning)  # To use binning for the sampling of inverse Wishart CMB covariance
    self.acceptance_posdef = bool(acceptance_posdef)  # To accept only positive definite matrices for C sampling
    self.non_centered_moves = bool(non_centered_moves)  # To use non-centered moves for C sampling
    self.save_intermediary_centered_moves = bool(
        save_intermediary_centered_moves
    )  # To save intermediary r values in case of non-centered moves in the sampling
    self.limit_r_value = bool(limit_r_value)  # To bound the sampled r values from below by min_r_value
    self.min_r_value = float(min_r_value)  # Minimum value for r

    # Harmonic noise parameter
    # Noise power spectra for each frequency, in uK^2, dimensions [frequencies, frequencies, lmax+1-lmin]
    # or [frequencies, frequencies, lmax] (in which case it will be cut to lmax+1-lmin)
    self.freq_noise_c_ell = freq_noise_c_ell

    # Metropolis-Hastings parameters
    self.covariance_Bf = covariance_Bf  # Covariance for the Metropolis-Hastings step sampling of Bf
    self.step_size_r = step_size_r  # Step size for the Metropolis-Hastings step sampling of r
    self.use_scam_step_size = bool(
        use_scam_step_size
    )  # Use the SCAM (Single Component Adaptive Metropolis) step size for the Metropolis-Hastings step sampling of Bf and r (Haario et al. 2005)
    self.burn_in_scam = int(burn_in_scam)  # Number of burn-in iterations before using adaptive step-size (SCAM)
    self.s_param_scam = float(s_param_scam)  # s parameter for the SCAM step size
    self.epsilon_param_scam_r = float(epsilon_param_scam_r)  # epsilon parameter for the SCAM step size for r
    self.epsilon_param_scam_Bf = float(epsilon_param_scam_Bf)  # epsilon parameter for the SCAM step size for Bf
    self.scam_iteration_updates = int(scam_iteration_updates)  # Period (in iterations) of the SCAM variance updates
    # if number_iterations_done > 0:
    #     self.burn_in_scam = 0
    if self.use_scam_step_size:
        print(
            'Using SCAM step size for the Metropolis-Hastings step sampling of Bf and r after',
            self.burn_in_scam,
            'iterations, with parameters s',
            self.s_param_scam,
            'epsilon_r',
            self.epsilon_param_scam_r,
            'epsilon_Bf',
            self.epsilon_param_scam_Bf,
            'and updates every',
            self.scam_iteration_updates,
            'iterations',
            flush=True,
        )

    # Sampling parameters
    if indexes_free_Bf is False or indexes_free_Bf is None:
        # If given as False (or None), then we sample all Bf
        indexes_free_Bf = jnp.arange(self.len_params)
    self.indexes_free_Bf = jnp.array(indexes_free_Bf)
    # The number of free parameters cannot exceed the total number of parameters
    if jnp.size(self.indexes_free_Bf) > self.len_params:
        raise ValueError('indexes_free_Bf cannot contain more indexes than the number of Bf parameters')
    # Valid indexes lie in [0, len_params) ; an index equal to len_params would be out of bounds,
    # and JAX silently clamps out-of-bound indexes instead of raising
    if jnp.size(self.indexes_free_Bf) > 0 and (
        jnp.min(self.indexes_free_Bf) < 0 or jnp.max(self.indexes_free_Bf) >= self.len_params
    ):
        raise ValueError('indexes_free_Bf should only contain indexes between 0 and the number of Bf parameters - 1')
    self.number_iterations_sampling = int(
        number_iterations_sampling
    )  # Maximum number of iterations for the sampling
    self.number_iterations_done = int(
        number_iterations_done
    )  # Number of iterations already accomplished, in case the chain is resuming from a previous run
    self.seed = seed  # Seed for the JAX PRNG random number generator to start the chain

    # Saving parameters
    self.save_CMB_chain_maps = bool(save_CMB_chain_maps)  # Save the CMB chain maps
    self.save_eta_chain_maps = bool(save_eta_chain_maps)  # Save the eta chain maps
    self.save_all_Bf_params = bool(save_all_Bf_params)  # Save all the Bf chains
    self.save_s_c_spectra = bool(save_s_c_spectra)  # Save the s_c spectra

    # Instrument parameters
    self.instrument_name = instrument_name  # Name of the instrument
    # if fwhm is not None:
    #     self.fwhm = float(fwhm) # FWHM of the beam in arcmin
    # else:
    #     self.fwhm = None # No beam

    # Check related parameters
    self.disable_chex = disable_chex  # Disable chex tests (to improve speed)

    # Samples preparation: empty chains, filled by update_samples / update_one_sample
    self.all_samples_eta = jnp.empty(0)
    self.all_params_mixing_matrix_samples = jnp.empty(0)
    self.all_samples_wiener_filter_maps = jnp.empty(0)
    self.all_samples_fluctuation_maps = jnp.empty(0)
    self.all_samples_r = jnp.empty(0)
    self.all_samples_CMB_c_ell = jnp.empty(0)
    self.all_samples_s_c_spectra = jnp.empty(0)
@property
def all_samples_s_c(self):
    """
    Full sampled CMB maps s_c of the stored chain, rebuilt as the sum of the
    stored Wiener filter term and the stored fluctuation term.
    """
    wiener_filter_part = self.all_samples_wiener_filter_maps
    fluctuation_part = self.all_samples_fluctuation_maps
    return wiener_filter_part + fluctuation_part
# [docs]
def generate_CMB(self, return_spectra=True):
    """
    Returns the fiducial CAMB CMB spectra of the scalar modes only (total spectra with r=0)
    and of the tensor modes only (with r=1).

    Parameters
    ----------
    return_spectra: bool (optional)
        if True, return the spectra in the usual form [number_correlations, lmax+1] ;
        otherwise return them in the reduced covariance (red_cov) form [lmax+1-lmin, nstokes, nstokes], default True

    Returns
    -------
    theoretical_r0_total: array[float]
        total CMB spectra for r=0, in the form selected by return_spectra
    theoretical_r1_tensor: array[float]
        tensor-mode CMB spectra for r=1, in the form selected by return_spectra
    """
    # Selecting the relevant auto- and cross-correlations from the CAMB spectra,
    # whose columns are ordered TT, EE, BB, TE
    if self.nstokes == 2:
        # EE, BB
        partial_indices_polar = np.array([1, 2])
    elif self.nstokes == 1:
        # TT
        partial_indices_polar = np.array([0])
    else:
        # TT, EE, BB, TE
        partial_indices_polar = np.arange(4)
    number_selected_spectra = partial_indices_polar.size

    # Generating the CMB power spectra
    all_spectra_r0 = generate_power_spectra_CAMB(self.nside * 2, r=0, typeless_bool=True)
    all_spectra_r1 = generate_power_spectra_CAMB(self.nside * 2, r=1, typeless_bool=True)

    # Retrieve the scalar mode spectrum
    camb_cls_r0 = all_spectra_r0['total'][: self.lmax + 1, partial_indices_polar]
    # Retrieve the tensor mode spectrum
    tensor_spectra_r1 = all_spectra_r1['tensor'][: self.lmax + 1, partial_indices_polar]

    theoretical_r1_tensor = np.zeros((self.n_correlations, self.lmax + 1))
    theoretical_r0_total = np.zeros_like(theoretical_r1_tensor)
    # Fill as many leading rows as spectra were selected (the remaining cross-correlations stay at zero).
    # Sizing the slice by nstokes would drop TE for nstokes=3 and fail to broadcast 4 rows into 3.
    theoretical_r1_tensor[:number_selected_spectra, ...] = tensor_spectra_r1.T
    theoretical_r0_total[:number_selected_spectra, ...] = camb_cls_r0.T

    if return_spectra:
        # Return spectra in the form [number_correlations,lmax+1]
        return theoretical_r0_total, theoretical_r1_tensor

    # Return spectra in the form of the reduced covariance matrix, [lmax+1-lmin,nstokes,nstokes]
    theoretical_red_cov_r1_tensor = get_reduced_matrix_from_c_ell(theoretical_r1_tensor)[self.lmin :]
    theoretical_red_cov_r0_total = get_reduced_matrix_from_c_ell(theoretical_r0_total)[self.lmin :]
    return theoretical_red_cov_r0_total, theoretical_red_cov_r1_tensor
# [docs]
def generate_input_freq_maps_from_fgs(
    self, freq_maps_fgs, r_true=0, return_only_freq_maps=True, return_only_maps=False
):
    """
    Build input frequency maps (CMB + foregrounds) by adding one random CMB realisation,
    drawn with healpy.synfast from the fiducial CAMB spectra with tensor-to-scalar ratio r_true,
    to the given foreground maps of every frequency.

    Parameters
    ----------
    freq_maps_fgs: array[float] of dimensions [n_frequencies,nstokes,n_pix]
        input frequency foregrounds maps
    r_true: float, optional
        input tensor-to-scalar ratio r used to generate the CMB, default 0
    return_only_freq_maps: bool (optional)
        if True, only the full frequency maps are returned ; takes precedence over return_only_maps, default True
    return_only_maps: bool (optional)
        if True (and return_only_freq_maps is False), return the full frequency maps and the CMB maps, default False

    Returns
    -------
    input_freq_maps: array[float] of dimensions [n_frequencies,nstokes,n_pix]
        input frequency maps
    input_cmb_maps: array[float] of dimensions [n_frequencies,nstokes,n_pix]
        input CMB maps, the same realisation for every frequency (read-only broadcast view)
    theoretical_red_cov_r0_total: array[float] of dimensions [lmax+1-lmin,nstokes,nstokes]
        theoretical reduced covariance matrix for the CMB scalar modes
    theoretical_red_cov_r1_tensor: array[float] of dimensions [lmax+1-lmin,nstokes,nstokes]
        theoretical reduced covariance matrix for the CMB tensor modes
    """
    # Rows of the 6-spectra array (TT, EE, BB, TE, EB, TB) filled from the reduced spectra: EE, BB, EB
    polarization_rows = np.array([1, 2, 4])

    # Fiducial CMB from CAMB, as reduced covariance matrices
    scalar_red_cov, tensor_red_cov = self.generate_CMB(return_spectra=False)
    fiducial_cls = get_c_ells_from_red_covariance_matrix(scalar_red_cov + r_true * tensor_red_cov)

    cls_all_fields = np.zeros((6, self.lmax + 1))
    cls_all_fields[polarization_rows, self.lmin :] = fiducial_cls

    # synfast produces I, Q, U maps: only Q and U are kept
    cmb_maps_qu = hp.synfast(cls_all_fields, nside=self.nside, new=True, lmax=self.lmax)[1:, ...]
    cmb_maps_per_freq = np.broadcast_to(cmb_maps_qu, (self.n_frequencies, self.nstokes, self.n_pix))
    freq_maps = cmb_maps_per_freq + freq_maps_fgs

    if return_only_freq_maps:
        return freq_maps
    if return_only_maps:
        return freq_maps, cmb_maps_per_freq
    return freq_maps, cmb_maps_per_freq, scalar_red_cov, tensor_red_cov
# [docs]
def update_variable(self, all_samples, new_samples_to_add):
    """
    Append new samples to an existing chain.

    Parameters
    ----------
    all_samples: array[float]
        previously stored samples ; an empty array means nothing has been stored yet
    new_samples_to_add: array[float]
        samples to append ; 1D arrays are concatenated with jnp.hstack,
        arrays with any other number of dimensions are stacked with jnp.vstack (along the first axis)

    Returns
    -------
    array[float]
        the extended chain, or new_samples_to_add itself when all_samples is empty
    """
    # Nothing stored yet: the new samples start the chain
    if jnp.size(all_samples) == 0:
        return new_samples_to_add
    concatenate = jnp.hstack if len(new_samples_to_add.shape) == 1 else jnp.vstack
    return concatenate([all_samples, new_samples_to_add])
# [docs]
def update_samples(self, all_samples):
    """
    Append a batch of Gibbs iterations to the chains stored on the object.

    Parameters
    ----------
    all_samples: dictionary
        samples of a run of the sampler ; the keys read depend on the options:
        'eta_maps' (save_eta_chain_maps and not classical_Gibbs), 'wiener_filter_term' and 'fluctuation_maps'
        (save_CMB_chain_maps), 's_c_spectra' (save_s_c_spectra), 'red_cov_matrix_sample' (sample_C_inv_Wishart),
        'r_sample' (sample_r_Metropolis) and 'params_mixing_matrix_sample' (save_all_Bf_params).
        The dictionary itself is not modified.
    """
    # Update the eta samples if they were saved and/or if they were sampled
    if self.save_eta_chain_maps and not self.classical_Gibbs:
        self.all_samples_eta = self.update_variable(self.all_samples_eta, all_samples['eta_maps'])

    # Update the CMB chain maps if they were saved
    if self.save_CMB_chain_maps:
        self.all_samples_wiener_filter_maps = self.update_variable(
            self.all_samples_wiener_filter_maps, all_samples['wiener_filter_term']
        )
        self.all_samples_fluctuation_maps = self.update_variable(
            self.all_samples_fluctuation_maps, all_samples['fluctuation_maps']
        )

    # Update the s_c spectra if they were saved
    if self.save_s_c_spectra:
        self.all_samples_s_c_spectra = self.update_variable(
            self.all_samples_s_c_spectra, all_samples['s_c_spectra']
        )

    # Update the CMB covariance if they were sampled
    if self.sample_C_inv_Wishart:
        red_cov_samples = all_samples['red_cov_matrix_sample']
        if red_cov_samples.shape[1] == self.lmax + 1 - self.lmin:
            # Samples given as reduced covariance matrices: convert every sample actually present.
            # Iterating over a fixed number_iterations_sampling range instead would silently duplicate
            # (JAX clamps out-of-bound indexes) or drop samples when the counts differ.
            all_samples_CMB_c_ell = jnp.array(
                [get_c_ells_from_red_covariance_matrix(red_cov_sample) for red_cov_sample in red_cov_samples]
            )
        else:
            all_samples_CMB_c_ell = red_cov_samples
        self.all_samples_CMB_c_ell = self.update_variable(self.all_samples_CMB_c_ell, all_samples_CMB_c_ell)

    # Update the r samples if they were sampled
    if self.sample_r_Metropolis:
        r_samples = all_samples['r_sample']
        # Drop singleton axes to match the rank of the stored chain, without mutating the caller's dictionary
        if len(r_samples.shape) != len(self.all_samples_r.shape):
            r_samples = r_samples.squeeze()
        self.all_samples_r = self.update_variable(self.all_samples_r, r_samples)

    # Update the mixing matrix Bf parameters if they were sampled
    if self.save_all_Bf_params:
        self.all_params_mixing_matrix_samples = self.update_variable(
            self.all_params_mixing_matrix_samples, all_samples['params_mixing_matrix_sample']
        )
# [docs]
def update_one_sample(self, one_sample):
    """
    Append the sample of a single Gibbs iteration to the chains stored on the object.

    Parameters
    ----------
    one_sample: dictionary
        sample of one iteration ; the keys read depend on the options:
        'eta_maps' (save_eta_chain_maps and not classical_Gibbs), 'wiener_filter_term' and 'fluctuation_maps'
        (save_CMB_chain_maps), 'red_cov_matrix_sample' (sample_C_inv_Wishart), 'r_sample'
        (sample_r_Metropolis) and 'params_mixing_matrix_sample' (save_all_Bf_params)
    """

    def as_batch(values):
        # Give the single sample a leading iteration axis of length 1
        return jnp.expand_dims(values, axis=0)

    if self.save_eta_chain_maps and not self.classical_Gibbs:
        self.all_samples_eta = self.update_variable(self.all_samples_eta, as_batch(one_sample['eta_maps']))

    if self.save_CMB_chain_maps:
        self.all_samples_wiener_filter_maps = self.update_variable(
            self.all_samples_wiener_filter_maps, as_batch(one_sample['wiener_filter_term'])
        )
        self.all_samples_fluctuation_maps = self.update_variable(
            self.all_samples_fluctuation_maps, as_batch(one_sample['fluctuation_maps'])
        )

    if self.sample_C_inv_Wishart:
        red_cov_sample = one_sample['red_cov_matrix_sample']
        # Reduced covariance matrices are converted to spectra, anything else is stored as is
        if red_cov_sample.shape[0] == self.lmax + 1 - self.lmin:
            c_ell_sample = get_c_ells_from_red_covariance_matrix(red_cov_sample)
        else:
            c_ell_sample = red_cov_sample
        self.all_samples_CMB_c_ell = self.update_variable(self.all_samples_CMB_c_ell, as_batch(c_ell_sample))

    if self.sample_r_Metropolis:
        r_sample = one_sample['r_sample']
        if self.non_centered_moves and self.save_intermediary_centered_moves:
            # NOTE(review): r_sample is stacked with itself, so both stored entries are identical;
            # an intermediary (centered-move) r value was presumably intended here — confirm
            r_entry = as_batch(jnp.stack((r_sample, r_sample)))
        else:
            r_entry = r_sample
        self.all_samples_r = self.update_variable(self.all_samples_r, r_entry)

    if self.save_all_Bf_params:
        self.all_params_mixing_matrix_samples = self.update_variable(
            self.all_params_mixing_matrix_samples,
            as_batch(one_sample['params_mixing_matrix_sample']),
        )
# def update_scam_step_size(self, carry, new_carry, iteration):
# """
# Update the SCAM step size for the Metropolis-Hastings step sampling of Bf and r
# Parameters
# ----------
# carry: dictionary
# dictionary carry from all_sampling_steps function
# new_carry: dictionary
# updated dictionary from all_sampling_steps function
# iteration: int
# current iteration number
# """
# total_number_iterations = iteration + self.number_iterations_done + 1
# # Update the SCAM step size for the Metropolis-Hastings step sampling of r
# new_carry['empirical_variance_r'] = get_1d_recursive_empirical_covariance(
# total_number_iterations,
# new_carry['r_sample'],
# carry['mean_r'],
# carry['empirical_variance_r'],
# s_param=self.s_param_scam,
# epsilon_param=self.epsilon_param_scam_r,
# ).squeeze()
# new_carry['mean_r'] = (total_number_iterations * carry['mean_r'] + new_carry['r_sample']) / (
# total_number_iterations + 1
# )
# # Update the SCAM step size for the Metropolis-Hastings step sampling of Bf
# new_carry['empirical_variance_Bf'] = get_1d_recursive_empirical_covariance(
# total_number_iterations,
# new_carry['params_mixing_matrix_sample'],
# carry['mean_Bf'],
# carry['empirical_variance_Bf'],
# s_param=self.s_param_scam,
# epsilon_param=self.epsilon_param_scam_Bf,
# )
# new_carry['mean_Bf'] = (total_number_iterations * carry['mean_Bf'] + new_carry['params_mixing_matrix_sample']) / (
# total_number_iterations + 1
# )
# [docs]
def perform_Gibbs_sampling(
self,
input_freq_maps,
c_ell_approx,
CMB_c_ell,
init_params_mixing_matrix,
initial_guess_r=1e-8,
initial_wiener_filter_term=None,
initial_fluctuation_maps=None,
theoretical_r0_total=None,
theoretical_r1_tensor=None,
**dictionnary_additional_parameters,
):
r"""
Perform sampling steps with:
1. The sampling of \eta by computing \eta = x + C_approx^(1/2) N_c^{-1/2} y ; where x is band-limited
2. A CG for the Wiener filter (WF) and fluctuation variables s_c: (s_c - s_{c,WF})^t (C^{-1} + N_c^{-1}) (s_c - s_{c,WF})
3. The c_ell sampling, either by parametrizing it by r or by sampling an inverse Wishart distribution
4. Mixing matrix Bf sampling with: -(d - B_c s_c)^t N^{-1} B_f (B_f^t N^{-1} B_f)^{-1} B_f^t N^{-1} (d - B_c s_c) + \eta^t (Id + C_{approx}^{1/2} N_c^{-1} C_{approx}^{1/2}) \eta
The results of the chain will be stored in the class attributes, depending if the save options are put to True or False:
- self.all_samples_eta (if self.save_eta_chain_maps is True)
- self.all_samples_wiener_filter_maps (if self.save_CMB_chain_maps is True)
- self.all_samples_fluctuation_maps (if self.save_CMB_chain_maps is True)
- self.all_samples_r (if self.sample_r_Metropolis is True)
- self.all_samples_CMB_c_ell (if self.sample_C_inv_Wishart is True)
- self.all_params_mixing_matrix_samples (always)
This same function can be used to continue a chain from a previous run, by giving the number of iterations already done in the MicmacSampler object,
giving the chains to the attributes of the object, and giving the last iteration results as initial guesses.
Parameters
----------
input_freq_maps: array[float] of dimensions [frequencies, nstokes, n_pix]
input frequency maps
c_ell_approx: array[float] of dimensions [number_correlations, lmax+1]
approximate CMB power spectra for the latent parameter \eta defining the ad-hoc correction term
CMB_c_ell: array[float] of dimensions [number_correlations, lmax+1]
CMB power spectra, where number_correlations is the number of auto- and cross-correlations relevant considering the number of Stokes parameters
init_params_mixing_matrix: array[float] of dimensions [len_params]
initial parameters for the mixing matrix elements Bf; expected to be given flattened as [Bf_s1, Bf_s2, ..., Bf_sn, Bf_d1, ..., Bf_dn]
initial_guess_r: float (optional)
initial guess for r, default 1e-8
initial_wiener_filter_term: array[float] of dimensions [nstokes, n_pix] or empty (optional)
initial guess for the Wiener filter term, default empty array
initial_fluctuation_maps: array[float] of dimensions [nstokes, n_pix] or empty (optional)
initial guess for the fluctuation maps, default empty array
theoretical_r0_total: array[float] of dimensions [number_correlations, lmax+1-lmin] (optional)
theoretical reduced covariance matrix for the CMB scalar modes, default empty array
theoretical_r1_tensor: array[float] of dimensions [number_correlations, lmax+1-lmin] (optional)
theoretical reduced covariance matrix for the CMB tensor modes, default empty array
dictionnary_additional_parameters: dictionary
additional parameters to give to the function, currently only the ones related to the SCAM step size
Notes
-----
The formalism relies on the ability to have an inverse for C_approx (even though it is never computed effectively in the code), and may lead to numerical instabilities if the C_approx matrix is not well-conditioned.
"""
time_test = time.time()
# Disabling all chex checks to speed up the code
if self.disable_chex:
print('Disabling chex !!!', flush=True)
chx.disable_asserts()
## Getting only the relevant spectra
if self.nstokes == 2:
indices_to_consider = np.array([1, 2, 4])
partial_indices_polar = indices_to_consider[: self.nstokes]
elif self.nstokes == 1:
indices_to_consider = np.array([0])
else:
indices_to_consider = np.arange(6) # All auto- and cross-correlations
## Testing the inverse frequency noise
assert (
self.freq_inverse_noise is not None
), 'The inverse noise for the frequencies should be provided as an attribute of the MicmacSampler object'
assert self.freq_inverse_noise.shape == (
self.n_frequencies,
self.n_frequencies,
self.n_pix,
), 'The inverse noise for the frequencies should have dimensions [n_frequencies,n_frequencies,n_pix]'
## Testing the initial WF term, or initialize it properly
if initial_wiener_filter_term is None:
wiener_filter_term = jnp.zeros((self.nstokes, self.n_pix))
else:
assert len(initial_wiener_filter_term.shape) == 2
assert initial_wiener_filter_term.shape == (self.nstokes, self.n_pix)
wiener_filter_term = initial_wiener_filter_term
## Testing the initial fluctuation term, or initialize it properly
if initial_fluctuation_maps is None:
fluctuation_maps = jnp.zeros((self.nstokes, self.n_pix))
else:
assert len(initial_fluctuation_maps.shape) == 2
assert initial_fluctuation_maps.shape == (self.nstokes, self.n_pix)
fluctuation_maps = initial_fluctuation_maps
## Testing the initial spectra given in case the sampling is done with r
if self.sample_r_Metropolis:
assert len(theoretical_r0_total.shape) == 2
assert theoretical_r0_total.shape[1] == self.lmax + 1 - self.lmin
assert len(theoretical_r1_tensor.shape) == 2
assert theoretical_r1_tensor.shape[1] == theoretical_r0_total.shape[1]
# Transforming into the reduced (red) format [lmax+1-lmin,nstokes,nstokes]
theoretical_red_cov_r0_total = get_reduced_matrix_from_c_ell(theoretical_r0_total)
theoretical_red_cov_r1_tensor = get_reduced_matrix_from_c_ell(theoretical_r1_tensor)
assert theoretical_red_cov_r0_total.shape[1] == self.nstokes
## Testing the initial CMB spectra and C_approx spectra given
if self.nstokes == 2 and (CMB_c_ell.shape[0] != len(indices_to_consider)):
CMB_c_ell = CMB_c_ell[indices_to_consider, :]
if self.nstokes == 2 and (c_ell_approx.shape[0] != len(indices_to_consider)):
c_ell_approx = c_ell_approx[indices_to_consider, :]
assert len(CMB_c_ell.shape) == 2
assert CMB_c_ell.shape[1] == self.lmax + 1
assert len(c_ell_approx.shape) == 2
assert c_ell_approx.shape[1] == self.lmax + 1
red_cov_approx_matrix = jnp.array(get_reduced_matrix_from_c_ell(c_ell_approx)[self.lmin :, ...])
assert (
jnp.linalg.det(red_cov_approx_matrix) != 0
).any(), 'The approximate covariance matrix should be invertible ; if you and to put it to zero, please put it instead to a very small value'
## Testing the initial mixing matrix
if self.n_components != 1:
assert init_params_mixing_matrix.shape == (self.len_params,)
## Testing the input frequency maps
assert len(input_freq_maps.shape) == 3
assert input_freq_maps.shape == (self.n_frequencies, self.nstokes, self.n_pix)
## Testing the mask
assert np.abs(self.mask).sum() != 0
## Testing the initial guess for r
assert np.size(initial_guess_r) == 1
# assert initial_guess_r >= 0 # Not allowing for first guess negative r values
# Preparing for the full Gibbs sampling
len_pos_special_freqs = len(self.pos_special_freqs)
# if self.fwhm is not None:
# ell_range = jnp.arange(self.lmin, self.lmax + 1)
# spin = 2
# self.beam_harmonic = jnp.exp(-0.5 * (ell_range * (ell_range + 1) - spin**2)* self.fwhm**2 / (8 * np.log(2)))
# else:
# self.beam_harmonic = jnp.ones(self.lmax + 1 - self.lmin)
if self.use_binning:
print('Using binning for the sampling of CMB covariance !!!', flush=True)
print('Binning distribution:', self.bin_ell_distribution, flush=True)
## Initial guesses preparation
## eta
initial_eta = jnp.zeros((self.nstokes, self.n_pix))
## CMB covariance preparation in the format [lmax,nstokes,nstokes]
red_cov_matrix = get_reduced_matrix_from_c_ell(CMB_c_ell)[self.lmin :, ...]
## parameters of the mixing matrix
params_mixing_matrix_init_sample = jnp.array(init_params_mixing_matrix, copy=True)
# Preparing the sampling functions
## Function to sample eta
func_logproba_eta = self.get_conditional_proba_correction_likelihood_JAX_v2d
## Function to compute the Wiener filter term
sampling_func_WF = self.solve_generalized_wiener_filter_term_v2d
## Function to sample the fluctuation maps
sampling_func_Fluct = self.get_fluctuating_term_maps_v2d
## Function to sample the CMB covariance from inverse Wishart
func_get_inverse_wishart_sampling_from_c_ells = self.get_inverse_wishart_sampling_from_c_ells
if self.use_binning:
func_get_inverse_wishart_sampling_from_c_ells = self.get_binned_inverse_wishart_sampling_from_c_ells_v3
## Function to sample the CMB covariance parametrize from r
r_sampling_MH = single_Metropolis_Hasting_step
# r_sampling_MH = bounded_single_Metropolis_Hasting_step
if self.sample_r_Metropolis:
log_proba_r = self.get_conditional_proba_C_from_r_wBB
if self.use_binning:
print('Using BB binning for the sampling of r !!!', flush=True)
log_proba_r = self.get_binned_conditional_proba_C_from_r_wBB
## Function to sample the mixing matrix free parameters in the most general way
jitted_Bf_func_sampling = jax.jit(
self.get_conditional_proba_mixing_matrix_v2b_JAX, static_argnames=['biased_bool']
)
sampling_func = separate_single_MH_step_index_accelerated
if self.biased_version or self.perturbation_eta_covariance:
print('Using biased version or perturbation version of mixing matrix sampling !!!', flush=True)
## Function to sample the mixing matrix free parameters through the difference of the log-proba, to have only one CG done
jitted_Bf_func_sampling = jax.jit(
self.get_conditional_proba_mixing_matrix_v3_JAX, static_argnames=['biased_bool']
)
sampling_func = separate_single_MH_step_index_v2b
if self.simultaneous_accept_rate:
## More efficient version of the mixing matrix sampling
## MH step function to sample the mixing matrix free parameters with patches simultaneous computed accept rate
print('Using simultaneous accept rate version of mixing matrix sampling !!!', flush=True)
print(
'---- ATTENTION: This assumes all patches are distributed in the same way for all parameters !',
flush=True,
)
jitted_Bf_func_sampling = jax.jit(
self.get_conditional_proba_mixing_matrix_v3_pixel_JAX,
static_argnames=['biased_bool'],
)
sampling_func = separate_single_MH_step_index_v4_pixel
if (self.size_patches != self.size_patches[0]).any():
sampling_func = separate_single_MH_step_index_v4b_pixel
# raise NotImplemented("All patches should have the same size for the simultaneous accept rate version of mixing matrix sampling for now !!!")
## Redefining the free Bf indexes to sample to the one
# condition_unobserved_patches = self.get_cond_unobserved_patches() ## Get boolean array to identify which free indexes are not relevant
# print("Previous free indexes for Bf", self.indexes_free_Bf, flush=True)
# self.indexes_free_Bf = jnp.array(self.indexes_free_Bf).at[condition_unobserved_patches].get()
# print("New free indexes for Bf", self.indexes_free_Bf, flush=True)
print('Previous free indexes for Bf', self.indexes_free_Bf, self.indexes_free_Bf.size, flush=True)
self.indexes_free_Bf = self.indexes_free_Bf.at[
self.get_cond_unobserved_patches_from_indices_optimized(self.indexes_free_Bf)
].get()
## Get boolean array to identify which free indexes are not relevant
print('New free indexes for Bf', self.indexes_free_Bf, self.indexes_free_Bf.size, flush=True)
indexes_patches_Bf = jnp.array(self.indexes_b.ravel(order='F'), dtype=jnp.int64)
def which_interval(carry, index_Bf):
    """
    Scan step flagging the patches that contain a given free Bf index.

    Patch ``p`` is taken to cover the half-open index range
    ``[indexes_patches_Bf[p], indexes_patches_Bf[p] + self.size_patches[p])``.
    The boolean mask held in ``carry`` is OR-ed with the patches whose range contains ``index_Bf``,
    so that after scanning over all free indexes it marks every patch touched by at least one of them.
    The scanned index itself is returned unchanged as the per-step output.
    """
    patch_starts = indexes_patches_Bf
    patch_ends = indexes_patches_Bf + self.size_patches
    index_in_patch = jnp.logical_and(index_Bf >= patch_starts, index_Bf < patch_ends)
    updated_mask = jnp.logical_or(carry, index_in_patch)
    return updated_mask, index_Bf
condition, _ = jlax.scan(
which_interval, jnp.zeros_like(self.size_patches, dtype=bool), self.indexes_free_Bf
)
first_indices_patches_free_Bf = indexes_patches_Bf[condition]
max_len_patches_Bf = int(np.max(self.size_patches[condition]))
size_patches = self.size_patches[condition]
## Preparing minmum value of r sampling
## Preparing the random JAX PRNG key
if np.size(self.seed) == 1:
PRNGKey = random.PRNGKey(self.seed)
elif np.size(self.seed) == 2:
PRNGKey = jnp.array(self.seed, dtype=jnp.uint32)
else:
raise ValueError('Seed should be either a scalar or a 2D array interpreted as a JAX PRNG Key!')
## Computing the number of iterations to perform
actual_number_of_iterations = self.number_iterations_sampling # - self.number_iterations_done
if not (self.classical_Gibbs):
## Preparing the step-size for Metropolis-within-Gibbs of Bf sampling
## try/except step only because jsp.linalg.sqrtm is not implemented in GPU
try:
initial_step_size_Bf = jnp.array(jnp.diag(jsp.linalg.sqrtm(self.covariance_Bf)), dtype=jnp.float64)
except:
initial_step_size_Bf = jnp.array(jnp.diag(jnp.sqrt(self.covariance_Bf)), dtype=jnp.float64)
assert len(initial_step_size_Bf.shape) == 1
print('Step-size Bf', initial_step_size_Bf, flush=True)
if self.covariance_Bf.shape[0] != self.len_params:
print('Covariance matrix for Bf is not of the right shape !', flush=True)
# initial_step_size_Bf = jnp.repeat(initial_step_size_Bf, self.len_params//self.covariance_Bf.shape[0], axis=0)
if self.covariance_Bf.shape[0] != 2 * (self.n_frequencies - len_pos_special_freqs):
raise ValueError(
f'Covariance matrix for Bf is not of the right shape with shape {self.covariance_Bf.shape[0]}, it cannot be properly expanded with the considered multipatch distribution!'
)
if (
self.size_patches is not None and (self.size_patches == self.size_patches[0]).all()
): # If all patches have the same size
initial_step_size_Bf = jnp.broadcast_to( # Broadcasting the step-size for each patch size
initial_step_size_Bf,
(self.len_params // self.covariance_Bf.shape[0], self.covariance_Bf.shape[0]),
).ravel(order='F')
else: # If patches have different sizes
previous_initial_Bf = jnp.copy(initial_step_size_Bf)
initial_step_size_Bf = jnp.zeros(self.len_params)
number_free_Bf = (self.n_frequencies - len_pos_special_freqs) * (self.n_components - 1)
extended_array = np.zeros((number_free_Bf + 1), dtype=np.int64)
extended_array[0] = 0
extended_array[1:] = self.sum_size_patches_indexed_freq_comp.ravel(order='F') + self.size_patches
for i in range(
self.size_patches.size
): # Loop over the patches to update the step-size for each patch size
initial_step_size_Bf = initial_step_size_Bf.at[extended_array[i] : extended_array[i + 1]].set(
previous_initial_Bf[i]
)
initial_step_size_Bf = initial_step_size_Bf.at[extended_array[-1]].set(previous_initial_Bf[-1])
print('New step-size Bf', initial_step_size_Bf, flush=True)
## Few prints to re-check the toml parameters chosen
if self.classical_Gibbs:
print('Not sampling for eta and Bf, only for s_c and the CMB covariance !', flush=True)
if self.sample_r_Metropolis:
print('Sample for r instead of C !', flush=True)
if self.limit_r_value:
print(f'Limiting the r value to be superior to {self.min_r_value} !', flush=True)
if self.non_centered_moves:
print('Using non-centered moves for C sampling !', flush=True)
if self.save_intermediary_centered_moves:
print('Saving intermediary centered moves for C sampling !', flush=True)
else:
print('Sample for C with inverse Wishart !', flush=True)
# Few steps to improve the speed of the code
## Preparing the square root matrix of C_approx
red_cov_approx_matrix_sqrt = get_sqrt_reduced_matrix_from_matrix_jax(red_cov_approx_matrix)
## Preparing the preconditioner in the case of a full sky and white noise
use_precond = False
if self.mask.sum() == self.n_pix and self.freq_noise_c_ell is not None:
assert len(self.freq_noise_c_ell.shape) == 3
assert self.freq_noise_c_ell.shape[0] == self.n_frequencies
assert self.freq_noise_c_ell.shape[1] == self.n_frequencies
assert (self.freq_noise_c_ell.shape[2] == self.lmax + 1) or (
self.freq_noise_c_ell.shape[2] == self.lmax + 1 - self.lmin
)
if self.freq_noise_c_ell.shape[2] == self.lmax + 1:
self.freq_noise_c_ell = self.freq_noise_c_ell[..., self.lmin :]
self.freq_noise_c_ell = jnp.array(self.freq_noise_c_ell)
print('Full sky case !', flush=True)
use_precond = True
## Finally starting the Gibbs sampling !!!
print(
f'Starting {self.number_iterations_sampling} iterations in addition to {self.number_iterations_done} iterations done',
flush=True,
)
@scan_tqdm(
    actual_number_of_iterations,
)
def all_sampling_steps(carry, iteration):
    """
    1-step Gibbs sampling function, performing the following:
    - Sampling of eta, for the correction term ; perform as well a CG if the perturbation approach is chosen
    - Sampling of s_c, for the constrained CMB map realization ; sampling both Wiener filter and fluctuation maps
    - Sampling of C or r parametrizing C, for the CMB covariance matrix
    - Sampling of the free Bf parameters, for the mixing matrix
    - If SCAM is used, updating the empirical variances and running means driving the Metropolis-Hastings step-sizes

    Parameters
    ----------
    carry: dictionary
        dictionary containing the following variables at 1 iteration depending on the option chosen: WF maps,
        fluctuation maps, CMB covariance, r samples, Bf samples, PRNGKey (plus eta maps, inverse term and
        SCAM statistics when the corresponding options are enabled)
    iteration: int
        current iteration number within this scan call

    Returns
    -------
    new_carry: dictionary
        dictionary containing the variables for the next iteration; it must have the same structure as `carry`,
        as required by `jax.lax.scan`
    all_samples: dictionary
        dictionary containing the variables to save as chains, so depending on the options chosen: eta maps,
        WF maps, fluctuation maps, CMB covariance, r sample, Bf sample, SCAM statistics
    """
    # Extracting the JAX PRNG key from the carry
    PRNGKey = carry['PRNGKey']

    # Preparing the new carry and all_samples to save the chains
    new_carry = dict()
    all_samples = dict()

    # Preparing a new PRNGKey for eta sampling
    PRNGKey, subPRNGKey = random.split(PRNGKey)

    # Building the mixing matrix from the current free parameters
    mixing_matrix_sampled = self.get_B_from_params(carry['params_mixing_matrix_sample'], jax_use=True)
    # Few checks for the mixing matrix
    chx.assert_shape(mixing_matrix_sampled, (self.n_frequencies, self.n_components, self.n_pix))

    # Application of new mixing matrix to the noise covariance and extracted CMB map from data
    invBtinvNB = get_inv_BtinvNB(self.freq_inverse_noise, mixing_matrix_sampled, jax_use=True)
    BtinvN_sqrt = get_BtinvN(jnp.sqrt(self.freq_inverse_noise), mixing_matrix_sampled, jax_use=True)
    s_cML = get_Wd(self.freq_inverse_noise, mixing_matrix_sampled, input_freq_maps, jax_use=True)[0]

    # Harmonic noise term shared by the eta and s_c preconditioners (only built when use_precond is set)
    red_inv_noise_c_ell = None
    if use_precond:
        ## Assuming a harmonic noise with the pixel average of the mixing matrix
        noise_c_ell = get_inv_BtinvNB_c_ell(self.freq_noise_c_ell, mixing_matrix_sampled.mean(axis=2))[0, 0]
        ## Getting N_c^{-1} for the harmonic noise covariance
        red_inv_noise_c_ell = jnp.linalg.pinv(
            get_reduced_matrix_from_c_ell_jax(jnp.stack([noise_c_ell, noise_c_ell, jnp.zeros_like(noise_c_ell)]))
        )

    def _get_harmonic_preconditioner(red_cov_sqrt):
        """Return the CG preconditioner x -> (Id + C^{1/2} N_c^{-1} C^{1/2})^{-1} x acting on flattened maps."""
        red_preconditioner = jnp.linalg.pinv(
            jnp.eye(self.nstokes)
            + jnp.einsum('lij,ljk,lkm->lim', red_cov_sqrt, red_inv_noise_c_ell, red_cov_sqrt)
        )
        return lambda x: maps_x_red_covariance_cell_JAX(
            x.reshape((self.nstokes, self.n_pix)),
            red_preconditioner,
            nside=self.nside,
            lmin=self.lmin,
            n_iter=self.n_iter,
        ).ravel()

    # Sampling step 1: sampling of Gaussian variable eta
    ## Initialize the preconditioner for the eta contribution
    precond_func_eta = None
    ## Sampling of eta if not using the classical Gibbs sampling and neither the biased version
    if not (self.classical_Gibbs) and not (self.biased_version):
        # Sampling eta maps (no pre-drawn random maps are provided)
        new_carry['eta_maps'] = self.get_sampling_eta_v2(
            red_cov_approx_matrix_sqrt,
            invBtinvNB,
            BtinvN_sqrt,
            subPRNGKey,
            map_random_x=None,
            map_random_y=None,
            suppress_low_modes=True,
        )
        # Checking shape of the resulting maps
        chx.assert_shape(new_carry['eta_maps'], (self.nstokes, self.n_pix))

        # Preparing the preconditioner for the CG
        if use_precond:
            precond_func_eta = _get_harmonic_preconditioner(red_cov_approx_matrix_sqrt)

        if self.perturbation_eta_covariance:
            # Computing the inverse associated log proba term fixed correction covariance for the Bf sampling,
            # in case of the perturbative approach
            _, inverse_term = func_logproba_eta(
                invBtinvNB[0, 0],
                new_carry['eta_maps'],
                red_cov_approx_matrix_sqrt,
                first_guess=carry['inverse_term'],
                return_inverse=True,
                precond_func=precond_func_eta,
            )
        else:
            inverse_term = carry['inverse_term']

        if self.save_eta_chain_maps:
            all_samples['eta_maps'] = new_carry['eta_maps']

    # Sampling step 2: sampling of Gaussian variable s_c, constrained CMB map realization
    ## Getting the square root matrix of the sampled CMB covariance
    red_cov_matrix_sqrt = get_sqrt_reduced_matrix_from_matrix_jax(carry['red_cov_matrix_sample'])
    # Preparing the preconditioner to use for the sampling of the CMB maps
    precond_func_s_c = None
    if use_precond:
        precond_func_s_c = _get_harmonic_preconditioner(red_cov_matrix_sqrt)

    ## Pseudo-inverse of C^{1/2}, used to build initial guesses closer to the actual start of the CGs
    red_cov_matrix_sqrt_inv = jnp.linalg.pinv(red_cov_matrix_sqrt)

    ## Computing an initial guess closer to the actual start of the CG for the Wiener filter
    initial_guess_WF = maps_x_red_covariance_cell_JAX(
        carry['wiener_filter_term'],
        red_cov_matrix_sqrt_inv,
        nside=self.nside,
        lmin=self.lmin,
        n_iter=self.n_iter,
    )
    ## Sampling the Wiener filter term
    new_carry['wiener_filter_term'] = sampling_func_WF(
        s_cML, red_cov_matrix_sqrt, invBtinvNB, initial_guess=initial_guess_WF, precond_func=precond_func_s_c
    )

    ## Preparing the random key for the fluctuation term
    PRNGKey, new_subPRNGKey = random.split(PRNGKey)
    ## Getting the fluctuation maps terms, for the variance of the variable s_c
    initial_guess_Fluct = maps_x_red_covariance_cell_JAX(
        carry['fluctuation_maps'],
        red_cov_matrix_sqrt_inv,
        nside=self.nside,
        lmin=self.lmin,
        n_iter=self.n_iter,
    )
    ## Sampling the fluctuation maps (no pre-drawn random maps are provided)
    new_carry['fluctuation_maps'] = sampling_func_Fluct(
        red_cov_matrix_sqrt,
        invBtinvNB,
        BtinvN_sqrt,
        new_subPRNGKey,
        map_random_realization_xi=None,
        map_random_realization_chi=None,
        initial_guess=initial_guess_Fluct,
        precond_func=precond_func_s_c,
    )

    ## Constructing the sampled CMB map
    s_c_sample = new_carry['fluctuation_maps'] + new_carry['wiener_filter_term']

    if self.save_CMB_chain_maps:
        ## Saving the sampled Wiener filter term and fluctuation maps if chosen to
        all_samples['wiener_filter_term'] = new_carry['wiener_filter_term']
        all_samples['fluctuation_maps'] = new_carry['fluctuation_maps']

    ## Checking the shape of the resulting maps
    chx.assert_shape(new_carry['wiener_filter_term'], (self.nstokes, self.n_pix))
    chx.assert_shape(new_carry['fluctuation_maps'], (self.nstokes, self.n_pix))
    chx.assert_shape(s_c_sample, (self.nstokes, self.n_pix))

    # Sampling step 3: sampling of CMB covariance C
    ## Preparing the c_ell which will be used for the sampling
    c_ells_Wishart_ = get_cell_from_map_jax(s_c_sample, lmax=self.lmax, n_iter=self.n_iter)[:, self.lmin :]
    ## Saving the corresponding spectrum
    if self.save_s_c_spectra:
        all_samples['s_c_spectra'] = c_ells_Wishart_

    ### Getting them in the format [lmax,nstokes,nstokes] without the factor 2 ell+1 to take into account the m
    red_c_ells_Wishart_modified = get_reduced_matrix_from_c_ell_jax(c_ells_Wishart_)

    ## Preparing the new PRNGkey
    PRNGKey, new_subPRNGKey_2 = random.split(PRNGKey)

    ## Performing the sampling
    if self.sample_C_inv_Wishart:
        # Sampling C with inverse Wishart
        new_carry['red_cov_matrix_sample'] = func_get_inverse_wishart_sampling_from_c_ells(
            c_ells_Wishart_,
            PRNGKey=new_subPRNGKey_2,
            old_sample=carry['red_cov_matrix_sample'],
            acceptance_posdef=self.acceptance_posdef,
        )
        all_samples['red_cov_matrix_sample'] = new_carry['red_cov_matrix_sample']
    elif self.sample_r_Metropolis:
        # Sampling r which will parametrize C(r) = C_scalar + r*C_tensor
        step_size_r = self.step_size_r
        if self.use_scam_step_size:
            step_size_r = jnp.sqrt(carry['empirical_variance_r'])
            all_samples['empirical_variance_r'] = step_size_r**2
            all_samples['mean_r'] = carry['mean_r']

        new_carry['r_sample'] = r_sampling_MH(
            random_PRNGKey=new_subPRNGKey_2,
            old_sample=carry['r_sample'],
            step_size=step_size_r,
            log_proba=log_proba_r,
            red_sigma_ell=red_c_ells_Wishart_modified,
            theoretical_red_cov_r1_tensor=theoretical_red_cov_r1_tensor,
            theoretical_red_cov_r0_total=theoretical_red_cov_r0_total,
        )
        if self.limit_r_value:
            # Rejecting the move if the proposed r is below the allowed minimum
            new_carry['r_sample'] = jnp.where(
                new_carry['r_sample'] < self.min_r_value, carry['r_sample'], new_carry['r_sample']
            )

        ## Reconstructing the new spectra from r
        new_carry['red_cov_matrix_sample'] = (
            theoretical_red_cov_r0_total + new_carry['r_sample'] * theoretical_red_cov_r1_tensor
        )

        ## Binning if needed
        if self.use_binning:
            new_carry['red_cov_matrix_sample'] = self.bin_and_reproject_red_c_ell(new_carry['red_cov_matrix_sample'])

        ## Saving the r sample
        all_samples['r_sample'] = new_carry['r_sample']
    else:
        raise Exception('C not sampled in any way !!! It must be either inv Wishart or through r sampling !')

    if self.non_centered_moves:
        PRNGKey, new_subPRNGKey_2b = random.split(PRNGKey)
        if self.sample_r_Metropolis:
            new_r_sample = r_sampling_MH(
                random_PRNGKey=new_subPRNGKey_2b,
                old_sample=new_carry['r_sample'],
                step_size=self.step_size_r,
                log_proba=self.get_log_proba_non_centered_move_C_from_r,
                old_r_sample=new_carry['r_sample'],
                invBtinvNB=invBtinvNB,
                s_cML=s_cML,
                s_c_sample=s_c_sample,
                theoretical_red_cov_r1_tensor=theoretical_red_cov_r1_tensor,
                theoretical_red_cov_r0_total=theoretical_red_cov_r0_total,
            )
            if self.limit_r_value:
                new_r_sample = jnp.where(new_r_sample < self.min_r_value, new_carry['r_sample'], new_r_sample)

            new_carry['red_cov_matrix_sample'] = theoretical_red_cov_r0_total + new_r_sample * theoretical_red_cov_r1_tensor

            if self.save_intermediary_centered_moves:
                all_samples['r_sample'] = jnp.stack((new_carry['r_sample'], new_r_sample))
            else:
                all_samples['r_sample'] = new_r_sample

            new_carry['r_sample'] = new_r_sample

    ## Checking the shape of the resulting covariance matrix, and correcting it if needed
    if new_carry['red_cov_matrix_sample'].shape[0] == self.lmax + 1:
        new_carry['red_cov_matrix_sample'] = new_carry['red_cov_matrix_sample'][self.lmin :]

    ## Small check on the shape of the resulting covariance matrix
    chx.assert_shape(new_carry['red_cov_matrix_sample'], (self.lmax + 1 - self.lmin, self.nstokes, self.nstokes))

    # Sampling step 4: sampling of mixing matrix Bf
    ## Preparation of sampling step 4
    ## First preparing the term: d - B_c s_c
    full_data_without_CMB = input_freq_maps - jnp.broadcast_to(
        s_c_sample, (self.n_frequencies, self.nstokes, self.n_pix)
    )
    chx.assert_shape(full_data_without_CMB, (self.n_frequencies, self.nstokes, self.n_pix))

    ## Preparing the new PRNGKey
    PRNGKey, new_subPRNGKey_3 = random.split(PRNGKey)

    ## Performing the sampling
    if not (self.classical_Gibbs):
        # Preparing the step-size
        step_size_Bf = initial_step_size_Bf
        if self.use_scam_step_size:
            step_size_Bf = jnp.sqrt(carry['empirical_variance_Bf'])
            all_samples['empirical_variance_Bf'] = carry['empirical_variance_Bf']
            all_samples['mean_Bf'] = carry['mean_Bf']

        # Sampling Bf
        if self.perturbation_eta_covariance or self.biased_version:
            ## Preparing the parameters to provide for the sampling of Bf
            dict_parameters_sampling_Bf = {
                'indexes_Bf': self.indexes_free_Bf,
                'full_data_without_CMB': full_data_without_CMB,
                'red_cov_approx_matrix_sqrt': red_cov_approx_matrix_sqrt,
                'old_params_mixing_matrix': carry['params_mixing_matrix_sample'],
                'biased_bool': self.biased_version,
            }
            if self.perturbation_eta_covariance:
                ## Precomputing the term C_approx^{1/2} A^{-1} eta
                ##   = C_approx^{1/2} ( Id + C_approx^{1/2} N_{c,old}^{-1} C_approx^{1/2} )^{-1} eta
                inverse_term_x_Capprox_root = maps_x_red_covariance_cell_JAX(
                    inverse_term.reshape(self.nstokes, self.n_pix),
                    red_cov_approx_matrix_sqrt,
                    nside=self.nside,
                    lmin=self.lmin,
                    n_iter=self.n_iter,
                ).ravel()
                dict_parameters_sampling_Bf['previous_inverse_x_Capprox_root'] = inverse_term_x_Capprox_root
                dict_parameters_sampling_Bf['first_guess'] = inverse_term
            if not (self.biased_version):
                ## If not biased, provide the eta maps
                dict_parameters_sampling_Bf['component_eta_maps'] = new_carry['eta_maps']
            if self.simultaneous_accept_rate:
                ## Provide as well the indexes of the patches in case of the uncorrelated patches version
                dict_parameters_sampling_Bf['size_patches'] = size_patches
                dict_parameters_sampling_Bf['max_len_patches_Bf'] = max_len_patches_Bf
                dict_parameters_sampling_Bf['indexes_patches_Bf'] = first_indices_patches_free_Bf
                dict_parameters_sampling_Bf['len_indexes_Bf'] = self.len_params
                # TODO: Accelerate by removing indexes of indexes_patches_Bf if the corresponding patches are not
                # in indexes_free_Bf, nor in the mask

            ## Sampling Bf !
            new_subPRNGKey_3, new_carry['params_mixing_matrix_sample'] = sampling_func(
                random_PRNGKey=new_subPRNGKey_3,
                old_sample=carry['params_mixing_matrix_sample'],
                step_size=step_size_Bf,
                log_proba=jitted_Bf_func_sampling,
                **dict_parameters_sampling_Bf,
            )
        else:
            ## Sampling Bf with older version -> might be slower; it also returns the updated inverse term
            new_subPRNGKey_3, new_carry['params_mixing_matrix_sample'], inverse_term = sampling_func(
                random_PRNGKey=new_subPRNGKey_3,
                old_sample=carry['params_mixing_matrix_sample'],
                step_size=step_size_Bf,
                indexes_Bf=self.indexes_free_Bf,
                log_proba=jitted_Bf_func_sampling,
                full_data_without_CMB=full_data_without_CMB,
                component_eta_maps=new_carry['eta_maps'],
                red_cov_approx_matrix_sqrt=red_cov_approx_matrix_sqrt,
                first_guess=carry['inverse_term'],
                biased_bool=self.biased_version,
                precond_func=precond_func_eta,
            )

        if not (self.biased_version):
            ## Passing the inverse term to the next iteration: the carry holds 'inverse_term' exactly when
            ## neither classical Gibbs nor the biased version is used, and lax.scan needs the same structure back
            new_carry['inverse_term'] = inverse_term

        # Checking the shape of the resulting mixing matrix
        chx.assert_shape(new_carry['params_mixing_matrix_sample'], (self.len_params,))
    else:
        ## Classical Gibbs sampling, no need to sample Bf but still needs to provide them to the next iteration
        ## in case it is used for the CMB noise component
        new_carry['params_mixing_matrix_sample'] = carry['params_mixing_matrix_sample']

    ## Saving the Bf obtained
    all_samples['params_mixing_matrix_sample'] = new_carry['params_mixing_matrix_sample']

    # Updating the step-size in case of SCAM for the Metropolis-Hastings step
    if self.use_scam_step_size:
        ## SCAM counter, which becomes positive once the SCAM burn-in is over
        # NOTE(review): `//` binds tighter than `-`, so this is
        # iteration + number_iterations_done + 1 - (burn_in_scam // scam_iteration_updates);
        # confirm this (rather than dividing the whole difference) is the intended counter.
        total_number_iterations = (
            iteration + self.number_iterations_done + 1 - self.burn_in_scam // self.scam_iteration_updates
        )
        ## SCAM statistics are only updated every `scam_iteration_updates` iterations once the counter is positive
        update_scam_step_size = jnp.logical_and(
            total_number_iterations > 0, total_number_iterations % self.scam_iteration_updates == 0
        )
        scam_counter_started = total_number_iterations > 0

        def _next_running_mean(previous_mean, new_sample):
            """Return the SCAM running mean after seeing `new_sample`.

            On update iterations the new sample is folded into the running mean. Otherwise the previous
            running mean is kept, except while the SCAM counter is still non-positive, where the mean follows
            the latest sample so that the first update starts from the last burn-in sample.
            """
            return jax.lax.cond(
                update_scam_step_size,
                lambda x: (total_number_iterations * previous_mean + x) / (total_number_iterations + 1),
                lambda x: jnp.where(scam_counter_started, previous_mean, x),
                new_sample,
            )

        # Update the SCAM statistics for the Metropolis-Hastings step sampling of r
        new_carry['empirical_variance_r'] = jax.lax.cond(
            update_scam_step_size,
            lambda x: get_1d_recursive_empirical_covariance(
                total_number_iterations,
                new_carry['r_sample'],
                carry['mean_r'],
                x,
                s_param=self.s_param_scam,
                epsilon_param=self.epsilon_param_scam_r,
            ).squeeze(),
            lambda x: x,
            carry['empirical_variance_r'],
        )
        new_carry['mean_r'] = _next_running_mean(carry['mean_r'], new_carry['r_sample'])

        # Update the SCAM statistics for the Metropolis-Hastings step sampling of Bf
        new_carry['empirical_variance_Bf'] = jax.lax.cond(
            update_scam_step_size,
            lambda x: get_1d_recursive_empirical_covariance(
                total_number_iterations,
                new_carry['params_mixing_matrix_sample'],
                carry['mean_Bf'],
                x,
                s_param=self.s_param_scam,
                epsilon_param=self.epsilon_param_scam_Bf,
            ),
            lambda x: x,
            carry['empirical_variance_Bf'],
        )
        new_carry['mean_Bf'] = _next_running_mean(carry['mean_Bf'], new_carry['params_mixing_matrix_sample'])

    ## Passing as well the PRNGKey to the next iteration
    new_carry['PRNGKey'] = PRNGKey
    return new_carry, all_samples
## Preparing the initial carry
initial_carry = {
'wiener_filter_term': wiener_filter_term,
'fluctuation_maps': fluctuation_maps,
'red_cov_matrix_sample': red_cov_matrix,
'params_mixing_matrix_sample': params_mixing_matrix_init_sample,
'PRNGKey': PRNGKey,
}
if not (self.classical_Gibbs) and not (self.biased_version):
initial_carry['eta_maps'] = initial_eta
if not (self.classical_Gibbs) and not (self.biased_version):
initial_carry['inverse_term'] = jnp.zeros_like(initial_eta)
if self.sample_r_Metropolis:
initial_carry['r_sample'] = initial_guess_r
if self.save_s_c_spectra:
self.all_samples_s_c_spectra = self.update_variable(
self.all_samples_s_c_spectra,
jnp.expand_dims(jnp.zeros((self.n_correlations, self.lmax + 1 - self.lmin)), axis=0),
)
## Initialising the first carry to the chains saved
self.update_one_sample(initial_carry)
print(
'###### Time before entering scan and all_sampling_steps',
(time.time() - time_test) / 60,
'minutes',
flush=True,
)
if self.use_scam_step_size:
initial_carry['empirical_variance_r'] = jnp.array(self.step_size_r) ** 2
initial_carry['empirical_variance_Bf'] = initial_step_size_Bf**2
initial_carry['mean_r'] = jnp.array(initial_guess_r)
initial_carry['mean_Bf'] = jnp.array(params_mixing_matrix_init_sample)
if 'empirical_variance_r' in dictionnary_additional_parameters:
print(
'Setting the empirical variance for r to the one provided in the additional parameters!', flush=True
)
initial_carry['empirical_variance_r'] = dictionnary_additional_parameters['empirical_variance_r']
if 'empirical_variance_Bf' in dictionnary_additional_parameters:
print(
'Setting the empirical variance for Bf to the one provided in the additional parameters!',
flush=True,
)
initial_carry['empirical_variance_Bf'] = dictionnary_additional_parameters['empirical_variance_Bf']
if 'mean_r' in dictionnary_additional_parameters:
print('Setting the mean value for r to the one provided in the additional parameters!', flush=True)
initial_carry['mean_r'] = jnp.array(dictionnary_additional_parameters['mean_r']).squeeze()
if 'mean_Bf' in dictionnary_additional_parameters:
print('Setting the mean value for Bf to the one provided in the additional parameters!', flush=True)
initial_carry['mean_Bf'] = dictionnary_additional_parameters['mean_Bf']
assert (initial_carry['empirical_variance_r'] > 0).all()
assert (initial_carry['empirical_variance_Bf'] > 0).all()
## Starting the Gibbs sampling !!!!
time_start_sampling = time.time()
# Start sampling !!!
last_sample, all_samples = jlax.scan(all_sampling_steps, initial_carry, jnp.arange(actual_number_of_iterations))
time_full_chain = (time.time() - time_start_sampling) / 60
print(f'End of Gibbs chain in {time_full_chain} minutes, saving all files !', flush=True)
# Saving the samples as attributes of the Sampler object
time_start_updating = time.time()
self.update_samples(all_samples)
time_end_updating = (time.time() - time_start_updating) / 60
print(f'End of updating in {time_end_updating} minutes', flush=True)
# Saving step-sizes if SCAM is used
if self.use_scam_step_size:
self.all_empirical_variance_Bf = all_samples['empirical_variance_Bf']
self.all_empirical_variance_r = all_samples['empirical_variance_r']
# Saving the corresponding mean values for testing purposes
self.all_mean_r = all_samples['mean_r']
self.all_mean_Bf = all_samples['mean_Bf']
self.number_iterations_done = self.number_iterations_sampling
last_sample['number_iterations_done'] = self.number_iterations_done
print('Last key PRNG', last_sample['PRNGKey'], flush=True)
self.last_PRNGKey = last_sample['PRNGKey']
## Saving the last sample
self.last_sample = last_sample