# This file is part of MICMAC.
# Copyright (C) 2024 CNRS / SciPol developers
#
# MICMAC is free software: you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# MICMAC is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
# See the GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with MICMAC. If not, see <https://www.gnu.org/licenses/>.
import copy
import chex as chx
import jax
import jax.numpy as jnp
import numpy as np
from micmac.foregrounds.templates import (
create_one_template,
create_one_template_from_bdefaultvalue,
get_n_patches_b,
get_values_b,
)
__all__ = ['get_indexes_b', 'get_indexes_patches', 'get_len_params', 'MixingMatrix']
# Note:
# the mixing matrix is supposed to be the same for Q and U Stokes params
# (also we suppose that I is not used)
# Mixing matrix dimensions: n_frequencies*n_components*number_pixels
[docs]
def get_indexes_b(n_frequencies, n_components, spv_nodes_b):
    """
    Return, for every (frequency, component) node, the index of its first parameter

    Parameters
    ----------
    n_frequencies: int
        Number of frequencies
    n_components: int
        Number of components
    spv_nodes_b: list
        List of nodes for b containing info patches to build spv_templates;
        the node of (freq, comp) is read at position ``freq + comp * n_frequencies``

    Returns
    -------
    indexes: array
        Integer array of shape (n_frequencies, n_components) containing, for each node,
        the summed number of patches of all nodes preceding it in column-major order
    """
    shape = (n_frequencies, n_components)
    # Each slot first receives the patch count of the node located just before
    # it in the flattened column-major ordering; the cumulative sum taken
    # afterwards then yields the starting offset of the node itself.
    previous_sizes = np.zeros(shape, dtype=int)
    for freq, comp in np.ndindex(*shape):
        previous_node = spv_nodes_b[freq + comp * n_frequencies - 1]
        previous_sizes[freq, comp] = get_n_patches_b(previous_node)
    # The very first node has no predecessor: index -1 wrapped around to the
    # last node above, so its slot is reset to zero.
    previous_sizes[0, 0] = 0
    offsets = previous_sizes.ravel(order='F').cumsum()
    return offsets.reshape(shape, order='F')
[docs]
def get_indexes_patches(indexes_b, len_params):
    """
    Return, for each b, the range of parameter indexes covering all its patches

    Parameters
    ----------
    indexes_b: array
        First parameter index of each b (indexed along its first axis)
    len_params: int
        Total number of free parameters (per frequency, component, patch)

    Returns
    -------
    indexes_patches_list: list
        One integer JAX array per b, running from the first index of that b
        up to (excluded) the first index of the next b, or ``len_params`` for the last b
    """
    n_b = indexes_b.shape[0]
    # Every b but the last stops where the following b starts
    indexes_patches_list = [jnp.arange(indexes_b[k], indexes_b[k + 1]) for k in range(n_b - 1)]
    # The last b extends to the end of the parameter vector
    indexes_patches_list.append(jnp.arange(indexes_b[-1], len_params))
    return indexes_patches_list
[docs]
def get_len_params(spv_nodes_b):
    """
    Return total number of free parameters (per frequency, component, patch)

    Parameters
    ----------
    spv_nodes_b: list
        List of nodes for b containing info patches to build spv_templates

    Returns
    -------
    len_params: int
        Sum over all nodes of the number of patches reported by ``get_n_patches_b``
        (0 for an empty list)
    """
    return sum(get_n_patches_b(node) for node in spv_nodes_b)
[docs]
class MixingMatrix:
[docs]
def __init__(self, frequency_array, n_components, spv_nodes_b, nside, params=None, pos_special_freqs=(0, -1)):
    """
    Note: units are K_CMB.

    Parameters
    ----------
    frequency_array: array
        Array of frequencies
    n_components: int
        Number of components
    spv_nodes_b: list
        List of nodes for b containing info patches to build spv_templates
    nside: int
        Healpix nside of the expected input maps
    params: array (optional)
        Initial values of the parameters of the mixing matrix, default None (then initialized with zeros)
    pos_special_freqs: sequence of int (optional)
        Indexes of special frequencies (e.g. 0 for synchrotron, -1 for dust).
        Negative indexes are counted from the end. The sequence is copied, never modified in place.

    Raises
    ------
    ValueError
        If ``params`` is given and its first dimension differs from the number of free parameters
    AssertionError
        If ``pos_special_freqs`` contains duplicated entries
    """
    self.frequency_array = np.array(frequency_array, dtype=int)  # all input freq bands
    self.n_frequencies = np.size(frequency_array)  # all input freq bands
    self.n_components = n_components  # all comps (also cmb)
    self.spv_nodes_b = spv_nodes_b  # nodes for b containing info patches to build spv_templates
    self.nside = nside  # nside of the expected input maps
    # total number of free parameters (per frequency, component, patch)
    self.len_params = get_len_params(self.spv_nodes_b)

    if params is None:
        params = np.zeros(self.len_params)
    elif np.ndim(params) == 0 or np.shape(params)[0] != self.len_params:
        raise ValueError(f'params must be of dimensions {self.len_params}')
    self.params = params

    # Work on a private copy: the negative indexes are rewritten below, and
    # doing so in place would corrupt the caller's list (or a shared default)
    # for any later instance built with a different number of frequencies.
    pos_special_freqs = list(pos_special_freqs)

    # Indexes frequency array without the special frequencies
    self.indexes_frequency_array_no_special = np.delete(np.arange(self.n_frequencies), pos_special_freqs)

    ### checks on pos_special_freqs
    # check no duplicates
    assert len(pos_special_freqs) == len(set(pos_special_freqs))
    # make pos_special_freqs only positive
    self.pos_special_freqs = [pos + self.n_frequencies if pos < 0 else pos for pos in pos_special_freqs]

    if self.n_components != 1:
        n_free_freqs = self.n_frequencies - len(self.pos_special_freqs)
        n_comp_fgs = self.n_components - 1

        # Values of the patch nsides corresponding to each node
        self.values_b = (
            jnp.array(get_values_b(self.spv_nodes_b, n_free_freqs, n_comp_fgs))
            .ravel(order='F')
            .reshape((n_free_freqs, n_comp_fgs), order='F')
        )

        # Values of the first index of each Bf parameter in params
        self.indexes_b = jnp.array(get_indexes_b(n_free_freqs, n_comp_fgs, self.spv_nodes_b))

        # Number of patches for each node
        self.size_patches = jnp.array([get_n_patches_b(node) for node in self.spv_nodes_b])
        # Number of parameters preceding each node, indexed per frequency and component
        self.sum_size_patches_indexed_freq_comp = (self.size_patches.cumsum() - self.size_patches).reshape(
            (n_free_freqs, n_comp_fgs), order='F'
        )
        # Maximum number of patches for each node
        self.max_len_patches_Bf = int(self.size_patches.max())

        # Multipatch unless every node has exactly one patch and there is
        # exactly one parameter per (unknown frequency, foreground component)
        n_unknown_freqs = self.n_frequencies - self.n_components + 1
        self.multipatch_bool = not (
            (self.size_patches == 1).all() and (self.len_params == n_comp_fgs * n_unknown_freqs)
        )
    else:
        self.values_b = None  # Values of the patch nsides corresponding to each node
        self.indexes_b = jnp.array([[0]])  # Values of the first index of each Bf parameter in params
        self.size_patches = None  # Number of patches for each node
        self.sum_size_patches_indexed_freq_comp = None  # Cumulative sum of the number of patches for each node
        self.max_len_patches_Bf = None  # Maximum number of patches for each node
        self.multipatch_bool = False
@property
def n_pix(self):
"""
Number of pixels of one input freq map
"""
return 12 * self.nside**2
[docs]
def update_params(self, new_params, jax_use=False):
"""
Update values of the params in the mixing matrix.
Parameters
----------
new_params: array
New values of the parameters of the mixing matrix
jax_use: bool (optional)
If True, use JAX to update it as JAX Array, default False
"""
if jax_use:
chx.assert_shape(new_params, (self.len_params,))
self.params = jnp.array(new_params)
return
assert np.shape(new_params)[0] == self.len_params
self.params = new_params
return
[docs]
def get_params_long_python(self, params, print_bool=False):
# only python version
"""
From the params (Bf) to all the entries of the mixing matrix
only with Python (numpy) without JAX
Parameters
----------
params: array[float]
Flattened version of all free parameters of the mixing matrix per patch
expected to be stored as [Bf1_comp1_patch1, Bf1_comp1_patch2, ..., Bf2_comp1_patch1, ..., Bf1_comp2_patch1, ..., Bfn_comp2_patchn, ...]
print_bool: bool (optional)
If True, print the node names
Returns
-------
params_long: array[float] of dimensions [n_frequencies - n_components + 1, n_components - 1, n_pix]
Reshaped free parameters of the mixing matrix
"""
n_unknown_freqs = self.n_frequencies - self.n_components + 1
n_comp_fgs = self.n_components - 1
params_long = np.zeros((n_unknown_freqs, n_comp_fgs, self.n_pix))
ind_params = 0
for ind_node_b, node_b in enumerate(self.spv_nodes_b):
if print_bool:
print('node: ', node_b.parent.name, node_b.name)
# template of all the patches for this b
spv_template_b = np.array(
create_one_template(node_b, nside=self.nside, all_nsides=None, spv_templates=None)
)
# hp.mollview(spv_template_b)
# plt.show()
# loop over the patches of this b
params_long_b = np.zeros(self.n_pix)
for b in range(get_n_patches_b(node_b)):
params_long_b += np.where(spv_template_b == b, 1, 0) * params[ind_params]
ind_params += 1
# hp.mollview(params_long_b)
# plt.show()
ind_freq = np.where(ind_node_b < n_unknown_freqs, ind_node_b, ind_node_b - n_unknown_freqs)
ind_comp = np.where(ind_node_b < n_unknown_freqs, 0, 1)
params_long[ind_freq, ind_comp, :] = params_long_b
return params_long
# def pure_call_ud_get_params_long_python(self, params):
# """
# JAX Pure call to get_params_long_python
# Parameters
# ----------
# params: compressed version of the parameters of the mixing matrix
# Returns
# -------
# Full parameters of the mixing matrix
# """
# shape_output = (self.n_frequencies-self.n_components+1,self.n_components-1,12*self.nside**2,)
# return jax.pure_callback(self.get_params_long_python, jax.ShapeDtypeStruct(shape_output, np.float64),params,)
[docs]
def get_idx_template_params_long_python(self, idx_template, params, print_bool=False):
    # only python version
    """
    From the params to all the entries of the mixing matrix
    For a given template index, retrieve the corresponding template

    Parameters
    ----------
    idx_template: array[int]
        index of params of the corresponding template which will be saved
    params: array[float]
        flattened compressed array of the free params of the mixing matrix
        expected to be stored as [Bf1_comp1_patch1, Bf1_comp1_patch2, ..., Bf2_comp1_patch1, ..., Bf1_comp2_patch1, ..., Bfn_comp2_patchn, ...]
    print_bool: bool (optional)
        If True, print the node names

    Returns
    -------
    stacked: array[float]
        Full parameters of the mixing matrix re-flattened as [(n_frequencies - n_components + 1)*(n_components - 1), n_pix],
        vertically stacked with the template of every node owning at least one of the indexes in idx_template
    """
    n_unknown_freqs = self.n_frequencies - self.n_components + 1
    n_comp_fgs = self.n_components - 1
    params_long = np.zeros((n_unknown_freqs, n_comp_fgs, self.n_pix))
    all_templates = []

    ind_params = 0  # running position in the flattened params
    for ind_node_b, node_b in enumerate(self.spv_nodes_b):
        if print_bool:
            print('node: ', node_b.parent.name, node_b.name)

        # template of all the patches for this b
        spv_template_b = np.array(
            create_one_template(node_b, nside=self.nside, all_nsides=None, spv_templates=None)
        )

        patch_arange = np.arange(get_n_patches_b(node_b))
        # keep the template of this node if it owns one of the requested parameter indexes
        arange_ind_params = ind_params + patch_arange
        if np.isin(arange_ind_params, idx_template).any():
            all_templates.append(spv_template_b)

        # loop over the patches of this b
        params_long_b = np.zeros(self.n_pix)
        for b in patch_arange:
            params_long_b += np.where(spv_template_b == b, 1, 0) * params[ind_params]
            ind_params += 1

        # Nodes are ordered with the frequency index running fastest, then the
        # component index; divmod recovers both for any number of foreground
        # components (not only one or two).
        ind_comp, ind_freq = divmod(ind_node_b, n_unknown_freqs)
        params_long[ind_freq, ind_comp, :] = params_long_b

    all_templates = np.array(all_templates)
    return np.vstack([params_long.reshape((n_unknown_freqs * n_comp_fgs, self.n_pix)), all_templates.squeeze()])
# def pure_call_ud_get_idx_template_params_long_python(self, idx_template, params):
# """
# JAX Pure call to get_params_long_python
# Parameters
# ----------
# idx_template
# params: compressed version of the parameters of the mixing matrix
# Returns
# -------
# Full parameters of the mixing matrix
# """
# n_unknown_freqs = self.n_frequencies-self.n_components+1
# n_comp_fgs = self.n_components-1
# shape_output = (((n_unknown_freqs*n_comp_fgs+1),self.n_pix))
# output_pure_call_back = jax.pure_callback(self.get_idx_template_params_long_python, jax.ShapeDtypeStruct(shape_output, np.float64),idx_template,params,)
# return output_pure_call_back[:-1].reshape((n_unknown_freqs,n_comp_fgs,self.n_pix,)), output_pure_call_back[-1]
[docs]
def get_all_templates(self):
    """
    Retrieve all templates maps whose values correspond to the indices of params,
    and indexed per frequency and component

    Returns
    -------
    all_templates: array[int] of dimensions [n_frequencies - n_components + 1, n_components - 1, 12*nside**2]
        All templates indexes maps whose values correspond to the indices of params
        for all the patches distributions per frequency and component
    """
    freq_indices = jnp.arange(self.n_frequencies - self.n_components + 1)
    comp_indices = jnp.arange(self.n_components - 1)

    def template_of_node(idx_freq, idx_comp):
        # Patch template of one (frequency, component) node ...
        local_template = create_one_template_from_bdefaultvalue(
            jnp.expand_dims(self.values_b[idx_freq, idx_comp], axis=0),
            self.nside,
            all_nsides=None,
            spv_templates=None,
            use_jax=True,
            print_bool=False,
        )
        # ... shifted by the number of parameters of all preceding nodes, so
        # that its values index directly into params
        return local_template + self.sum_size_patches_indexed_freq_comp[idx_freq, idx_comp]

    def templates_of_frequency(idx_freq):
        # Map over the foreground components at fixed frequency
        return jax.vmap(lambda idx_comp: template_of_node(idx_freq, idx_comp))(comp_indices)

    # Map over the frequencies
    return jax.vmap(templates_of_frequency)(freq_indices)
[docs]
def get_one_template(self, nside_patch):
    """
    Build the template map of a single patch distribution,
    by calling create_one_template_from_bdefaultvalue in JAX mode

    Parameters
    ----------
    nside_patch: int
        Healpix nside of one patch distribution

    Returns
    -------
    template: array[int]
        One template indexes map for the patch distribution given by nside_patch
        (presumably defined on the pixels of the nside map - to be confirmed against
        create_one_template_from_bdefaultvalue)
    """
    # The helper receives the patch nside as a one-element array
    patch_value = jnp.expand_dims(nside_patch, axis=0)
    template = create_one_template_from_bdefaultvalue(
        patch_value,
        self.nside,
        all_nsides=None,
        spv_templates=None,
        use_jax=True,
        print_bool=False,
    )
    return template
[docs]
def get_params_long(self, jax_use=False):
"""
From the params to all the entries of the mixing matrix
Parameters
----------
jax_use: bool (optional)
If True, params are expected as JAX Array, default False
Returns
-------
params_long: array[float] of dimensions [n_frequencies - n_components + 1, n_components - 1, n_pix]
Reshaped free parameters of the mixing matrix
"""
if jax_use:
templates_to_fill = self.get_all_templates()
## Filling the templates with parameters values
return self.params.at[templates_to_fill].get()
return self.get_params_long_python(self.params)
[docs]
def get_B_fgs(self, jax_use=False):
"""
Foreground part of the mixing matrix.
Parameters
----------
jax_use: bool (optional)
If True, params are expected as JAX Array, default False
Returns
-------
B_fgs: array[float] of dimensions [n_frequencies, n_components - 1, n_pix]
Foreground part of the mixing matrix (including special frequencies)
"""
ncomp_fgs = self.n_components - 1
params_long = self.get_params_long(jax_use=jax_use)
if jax_use:
B_fgs = jnp.zeros((self.n_frequencies, ncomp_fgs, self.n_pix))
# insert all the ones given by the pos_special_freqs
B_fgs = B_fgs.at[jnp.array(self.pos_special_freqs), ...].set(
jnp.broadcast_to(jnp.eye(ncomp_fgs), (self.n_pix, ncomp_fgs, ncomp_fgs)).T
)
# insert all the parameters values
B_fgs = B_fgs.at[self.indexes_frequency_array_no_special, ...].set(params_long)
return B_fgs
if ncomp_fgs != 0:
assert params_long.shape == ((self.n_frequencies - len(self.pos_special_freqs)), ncomp_fgs, self.n_pix)
assert len(self.pos_special_freqs) <= ncomp_fgs
B_fgs = np.zeros((self.n_frequencies, ncomp_fgs, self.n_pix))
if len(self.pos_special_freqs) != 0:
# insert all the ones given by the pos_special_freqs
for c in range(len(self.pos_special_freqs)):
B_fgs[self.pos_special_freqs[c]][c] = 1
# insert all the parameters values
f = 0
for i in range(self.n_frequencies):
if i not in self.pos_special_freqs:
B_fgs[i, :] = params_long[f, :, :]
f += 1
return B_fgs
[docs]
def get_B_cmb(self, jax_use=False):
"""
CMB column of the mixing matrix.
Parameters
----------
jax_use: bool (optional)
If True, returned as JAX Array, default False
Returns
-------
B_cmb: array[float] of dimensions [n_frequencies, 1, n_pix]
CMB column of the mixing matrix, filled with ones
"""
if jax_use:
B_cmb = jnp.ones((self.n_frequencies, self.n_pix))
return B_cmb[:, np.newaxis, :]
B_cmb = np.ones((self.n_frequencies, self.n_pix))
B_cmb = B_cmb[:, np.newaxis, :]
return B_cmb
[docs]
def get_B(self, jax_use=False):
"""
Full mixing matrix, (n_frequencies*n_components).
CMB is given as the first component.
Parameters
----------
jax_use: bool (optional)
If True, returned as JAX Array, default False
Returns
-------
B_mat: array[float] of dimensions [n_frequencies, n_components, n_pix]
Full mixing matrix
"""
if jax_use:
if self.n_components != 1:
return jnp.concatenate((self.get_B_cmb(jax_use=jax_use), self.get_B_fgs(jax_use=jax_use)), axis=1)
else:
return self.get_B_cmb(jax_use=jax_use)
if self.n_components != 1:
B_mat = np.concatenate((self.get_B_cmb(), self.get_B_fgs()), axis=1)
else:
B_mat = self.get_B_cmb()
return B_mat
[docs]
def get_B_fgs_from_params(self, params, jax_use=False):
"""
Foreground part of the mixing matrix obtained from the parameters.
Parameters
----------
params: array[float]
Flattened version of all free parameters of the mixing matrix per patch
expected to be stored as [Bf1_comp1_patch1, Bf1_comp1_patch2, ..., Bf2_comp1_patch1, ..., Bf1_comp2_patch1, ..., Bfn_comp2_patchn, ...]
jax_use: bool (optional)
If True, params are expected as JAX Array and B_fgs will be returned as JAX Array, default False
Returns
-------
B_fgs: array[float] of dimensions [n_frequencies, n_components - 1, n_pix]
Foreground part of the mixing matrix (including special frequencies)
"""
ncomp_fgs = self.n_components - 1
if jax_use:
# Get all templates
templates = self.get_all_templates()
B_fgs = jnp.zeros((self.n_frequencies, ncomp_fgs, self.n_pix))
# insert all the ones given by the pos_special_freqs
B_fgs = B_fgs.at[jnp.array(self.pos_special_freqs), ...].set(
jnp.broadcast_to(jnp.eye(ncomp_fgs), (self.n_pix, ncomp_fgs, ncomp_fgs)).T
)
# insert all the parameters values
B_fgs = B_fgs.at[self.indexes_frequency_array_no_special, ...].set(params.at[templates].get())
return B_fgs
params_long = self.get_params_long_python(params)
if ncomp_fgs != 0:
assert params_long.shape == ((self.n_frequencies - len(self.pos_special_freqs)), ncomp_fgs, self.n_pix)
assert len(self.pos_special_freqs) <= ncomp_fgs
B_fgs = np.zeros((self.n_frequencies, ncomp_fgs, self.n_pix))
if len(self.pos_special_freqs) != 0:
# insert all the ones given by the pos_special_freqs
for c in range(len(self.pos_special_freqs)):
B_fgs[self.pos_special_freqs[c]][c] = 1
# insert all the parameters values
f = 0
for i in range(self.n_frequencies):
if i not in self.pos_special_freqs:
B_fgs[i, :] = params_long[f, :, :]
f += 1
return B_fgs
[docs]
def get_B_from_params(self, params, jax_use=False):
"""
Full mixing matrix, (n_frequencies*n_components), obtained from the parameters.
CMB is given as the first component.
Parameters
----------
params: array[float]
Flattened version of all free parameters of the mixing matrix per patch
expected to be stored as [Bf1_comp1_patch1, Bf1_comp1_patch2, ..., Bf2_comp1_patch1, ..., Bf1_comp2_patch1, ..., Bfn_comp2_patchn, ...]
jax_use: bool (optional)
If True, params are expected as JAX Array and B_mat will be returned as JAX Array, default False
Returns
-------
B_mat: array[float] of dimensions [n_frequencies, n_components, n_pix]
Full mixing matrix
"""
if jax_use:
if self.n_components != 1:
return jnp.concatenate(
(self.get_B_cmb(jax_use=jax_use), self.get_B_fgs_from_params(params, jax_use=jax_use)), axis=1
)
else:
return self.get_B_cmb(jax_use=jax_use)
B_mat = np.concatenate((self.get_B_cmb(), self.get_B_fgs_from_params(params)), axis=1)
return B_mat
[docs]
def get_template_B_fgs_from_params(self, nside_patch, params, jax_use=False):
    """
    Foreground (fgs) part of the mixing matrix and one patch distribution template
    obtained from nside_patch expected lower than nside of the input maps.

    Parameters
    ----------
    nside_patch: int
        Healpix nside of one patch distribution, expected lower than nside of the input maps
    params: array[float]
        Flattened version of all free parameters of the mixing matrix per patch
        expected to be stored as [Bf1_comp1_patch1, Bf1_comp1_patch2, ..., Bf2_comp1_patch1, ..., Bf1_comp2_patch1, ..., Bfn_comp2_patchn, ...]
    jax_use: bool (optional)
        If True, params are expected as JAX Array and the results will be returned as JAX Array, default False

    Returns
    -------
    B_fgs: array[float] of dimensions [n_frequencies, n_components - 1, n_pix]
        Foreground part of the mixing matrix (including special frequencies)
    template: array[int]
        One template indexes map for the patch distribution given by nside_patch,
        as returned by get_one_template
    """
    ncomp_fgs = self.n_components - 1
    if jax_use:
        # Get all templates
        templates = self.get_all_templates()

        B_fgs = jnp.zeros((self.n_frequencies, ncomp_fgs, self.n_pix))
        # insert all the ones given by the pos_special_freqs
        B_fgs = B_fgs.at[jnp.array(self.pos_special_freqs), ...].set(
            jnp.broadcast_to(jnp.eye(ncomp_fgs), (self.n_pix, ncomp_fgs, ncomp_fgs)).T
        )
        # insert all the parameters values
        B_fgs = B_fgs.at[self.indexes_frequency_array_no_special, ...].set(params[templates])
        return B_fgs, self.get_one_template(nside_patch)

    assert nside_patch <= self.nside
    # NumPy path: the foreground block is the one of get_B_fgs_from_params
    # (which carries the shape checks on the expanded parameters).
    B_fgs = self.get_B_fgs_from_params(params, jax_use=False)
    # NOTE(review): the template is built by the JAX helper, like in the JAX
    # path, and converted to NumPy - confirm this is the intended template.
    template = np.asarray(self.get_one_template(nside_patch))
    return B_fgs, template
[docs]
def get_patch_B_from_params(self, nside_patch, params, jax_use=False):
    """
    Full mixing matrix, (n_frequencies*n_components) from params and one patch distribution template.
    cmb is given as the first component.

    Parameters
    ----------
    nside_patch: int
        Healpix nside of one patch distribution, expected lower than nside of the input maps
    params: array[float]
        Flattened version of all free parameters of the mixing matrix per patch
        expected to be stored as [Bf1_comp1_patch1, Bf1_comp1_patch2, ..., Bf2_comp1_patch1, ..., Bf1_comp2_patch1, ..., Bfn_comp2_patchn, ...]
    jax_use: bool (optional)
        If True, params are expected as JAX Array and the results will be returned as JAX Array, default False

    Returns
    -------
    B_mat: array[float] of dimensions [n_frequencies, n_components, n_pix]
        Full mixing matrix
    template: array[int]
        Template indexes map returned by get_template_B_fgs_from_params for nside_patch
    """
    fgs_columns, template = self.get_template_B_fgs_from_params(nside_patch, params, jax_use=jax_use)
    cmb_column = self.get_B_cmb(jax_use=jax_use)
    # CMB column first, then the foreground columns, with the matching array library
    concatenate = jnp.concatenate if jax_use else np.concatenate
    return concatenate((cmb_column, fgs_columns), axis=1), template
[docs]
def get_params_db(self, jax_use=False):
# TODO: adjust with spv
"""
STATUS: Not used currently, to be adjusted with spv
Derivatives of the part of the Mixing Matrix w params
(wrt to each entry of first comp and then each entry of second comp)
Note: w/o pixel dimension
"""
nrows = self.n_frequencies - self.n_components + 1
ncols = self.n_components - 1
if jax_use:
def set_1(i):
params_dBi = jnp.zeros((nrows, ncols))
index_i = i // 2
index_j = i % 2
return params_dBi.at[index_i, index_j].set(1).ravel(order='C').reshape((nrows, ncols), order='F')
return jax.vmap(set_1)(jnp.arange(nrows * ncols))
params_dBi = np.zeros((nrows, ncols))
params_dB = []
for j in range(ncols):
for i in range(nrows):
params_dBi_copy = copy.deepcopy(params_dBi)
params_dBi_copy[i, j] = 1
params_dB.append(params_dBi_copy)
return params_dB
[docs]
def get_B_db(self, jax_use=False):
"""
STATUS: Not used currently, to be adjusted with spv
List of derivatives of the Mixing Matrix
(wrt to each entry of first comp and then each entry of second comp)
Note: w/o pixel dimension
"""
params_db = self.get_params_db(jax_use=jax_use)
if jax_use:
B_db = jnp.zeros((self.n_frequencies, self.n_frequencies, self.n_components))
relevant_indexes = jnp.arange(self.pos_special_freqs[0] + 1, self.pos_special_freqs[-1])
B_db = B_db.at[:, relevant_indexes, 1:].set(params_db)
return B_db
B_db = []
for B_db_i in params_db:
# add the zeros of special positions
for i in self.pos_special_freqs:
B_db_i = np.insert(B_db_i, i, np.zeros(self.n_components - 1), axis=0)
# add the zeros of CMB
B_db_i = np.column_stack((np.zeros(self.n_frequencies), B_db_i))
B_db.append(B_db_i)
return B_db