Simplest architecture: Multilayer Perceptron (MLP)
Performance can be improved for particular types of data by making use of inductive biases
Figures and data from Ho et al. 2019
What we will feed the model:
Training data from MultiDark Planck 2 N-body simulation (Klypin et al. 2016) with 261287 clusters.
and what is it?
import jax.numpy as np
m = np.eye(10)  # 10x10 identity matrix (`np` is jax.numpy in this file)


def my_func(x):
    """Return the sum of the entries of the matrix-vector product of ``m`` and ``x``.

    Args:
        x: Array accepted by ``m.dot``; with the 10x10 ``m`` defined above,
            a length-10 vector.

    Returns:
        The scalar obtained by reducing ``m.dot(x)`` with ``.sum()``.
    """
    return m.dot(x).sum()


x = np.linspace(0, 1, 10)  # 10 evenly spaced samples from 0 to 1
y = my_func(x)
from jax import grad
# Build a function that evaluates the gradient of my_func with respect to
# its first (and only) argument.
df_dx = grad(my_func, argnums=0)
# NOTE: this rebinds `y`; it now holds the gradient evaluated at `x`.
y = df_dx(x)
import flax.linen as nn
import jax
import jax.numpy as jnp
import optax
class MLP(nn.Module):
    """Multilayer perceptron: two 128-unit ReLU layers and a 1-unit linear output."""

    @nn.compact
    def __call__(self, x):
        """Apply the network to ``x`` and return the output of the last Dense layer."""
        x = nn.relu(nn.Dense(128)(x))
        x = nn.relu(nn.Dense(128)(x))
        x = nn.Dense(1)(x)  # output layer: no activation applied
        return x
# Build the neural network object
model = MLP()
# Create its parameters from a fixed PRNG key and an example input
params = model.init(jax.random.PRNGKey(0), x)
prediction = model.apply(params, x)
# Set up the Adam optimizer and its initial state for these parameters
tx = optax.adam(learning_rate=1e-3)
opt_state = tx.init(params)
# Define loss function
def loss_fn(params, x, y):
    """Return the mean squared error between the model output and ``y``.

    Args:
        params: Parameters forwarded to ``model.apply``.
        x: Input forwarded to ``model.apply``.
        y: Target values subtracted from the model output.

    Returns:
        ``jnp.mean`` of the element-wise squared residuals.
    """
    # Parenthesize the residual so the square applies to the whole difference.
    mse = (model.apply(params, x) - y) ** 2
    return jnp.mean(mse)
# Differentiate the loss with respect to its first argument (the parameters)
grads = jax.grad(loss_fn, argnums=0)(params, x, y)
# Let the optimizer convert the gradients into updates, then apply them
updates, opt_state = tx.update(grads, opt_state)
params = optax.apply_updates(params, updates)
def loss_fn(params, x, y):
    """Return the mean squared error between the model output and ``y``.

    Args:
        params: Parameters forwarded to ``model.apply``.
        x: Input forwarded to ``model.apply``.
        y: Target values subtracted from the model output.

    Returns:
        ``jnp.mean`` of the element-wise squared residuals.
    """
    mse = (model.apply(params, x) - y) ** 2
    return jnp.mean(mse)
Raise your hand if you manage to get a cluster mass prediction!
class MLP(nn.Module):
    """Multilayer perceptron with three hidden layers and a 1-unit linear output.

    Hidden layers: Dense(128) + ReLU, Dense(128) + ReLU, Dense(64) + tanh.
    """

    @nn.compact
    def __call__(self, x):
        """Apply the network to ``x`` and return the output of the last Dense layer."""
        x = nn.relu(nn.Dense(128)(x))
        x = nn.relu(nn.Dense(128)(x))
        x = nn.tanh(nn.Dense(64)(x))
        x = nn.Dense(1)(x)  # output layer: no activation applied
        return x
def loss_fn(params, x, y):
    """Return the mean squared error between the model prediction and ``y``.

    Args:
        params: Parameters forwarded to ``model.apply``.
        x: Input forwarded to ``model.apply``.
        y: Target values subtracted from the prediction.

    Returns:
        ``jnp.mean`` of the element-wise squared residuals.
    """
    prediction = model.apply(params, x)
    return jnp.mean((prediction - y) ** 2)
Let's try to understand the neural network output by looking at the loss function
$$ \mathcal{L} = \sum_{(x_i, y_i) \in \mathcal{D}} \parallel x_i - f_\theta(y_i)\parallel^2 \quad \simeq \quad \int \parallel x - f_\theta(y) \parallel^2 \ p(x,y) \ dx dy $$

In our case here:
import jax
import jax.numpy as jnp
from tensorflow_probability.substrates import jax as tfp
tfd = tfp.distributions
# Mixture of two scalar Gaussians, weighted by the categorical probabilities
gm = tfd.MixtureSameFamily(
    mixture_distribution=tfd.Categorical(probs=[0.3, 0.7]),
    components_distribution=tfd.Normal(loc=[-1.0, 1.0], scale=[0.1, 0.5]),
)
# Log-probability of the mixture evaluated at 1.0
gm.log_prob(1.0)
class MDN(nn.Module):
    """Mixture density network: maps an input to a Gaussian mixture distribution.

    The trunk is Dense(128) + ReLU, Dense(128) + ReLU, Dense(64) + tanh; three
    separate Dense heads then produce the mixture logits, the component means
    and the component scales.
    """

    # Number of Gaussian components in the output mixture
    num_components: int

    @nn.compact
    def __call__(self, x):
        """Return the distribution predicted for ``x`` (not a point estimate)."""
        x = nn.relu(nn.Dense(128)(x))
        x = nn.relu(nn.Dense(128)(x))
        x = nn.tanh(nn.Dense(64)(x))

        # Instead of regressing directly the value of the mass, the network
        # will now try to estimate the parameters of a mass distribution.
        categorical_logits = nn.Dense(self.num_components)(x)
        loc = nn.Dense(self.num_components)(x)
        # softplus output is positive; the 1e-3 offset keeps the scale above zero
        scale = 1e-3 + nn.softplus(nn.Dense(self.num_components)(x))

        dist = tfd.MixtureSameFamily(
            mixture_distribution=tfd.Categorical(logits=categorical_logits),
            components_distribution=tfd.Normal(loc=loc, scale=scale),
        )

        # To make it understand the batch dimension
        dist = tfd.Independent(dist)

        # Returns a distribution !
        return dist
def loss_fn(params, x, y):
    """Return the mean negative log-probability of the targets under the model.

    Args:
        params: Parameters forwarded to ``model.apply``.
        x: Input forwarded to ``model.apply``.
        y: Targets; indexed as ``y[:, 0]``, so it must have at least two
            dimensions and only its first column is used.

    Returns:
        ``jnp.mean`` of ``-q.log_prob(y[:, 0])``, where ``q`` is the object
        returned by ``model.apply(params, x)``.
    """
    q = model.apply(params, x)
    return jnp.mean(-q.log_prob(y[:, 0]))
Raise your hand if you manage to get a cluster mass prediction with this new model!
Still a latent variable model, except the mapping $f_\theta$ is made to be bijective.
We want to solve for the Maximum A Posteriori (MAP) solution:
$$\arg \max_{\color{SkyBlue} x} \; - \frac{1}{2} \parallel {\color{Orchid} y} - {\color{SkyBlue} x} \parallel_2^2 + \log p_\theta({\color{SkyBlue} x})$$

This can be done by gradient descent as long as one has access to the score function ${\color{orange} \frac{d \log p(x)}{d x}}$.

Raise your hand when you reach the Maximum Likelihood Solution, i.e. the end of Step III.