Likelihood Approximations with Neural Networks in PyMC

March 30, 2023
By Ricardo Vieira, Alexander Fengler, Aisulu Omar, and Yang Xu
Likelihood Approximations with Flax in PyMC
Leaning on an applied data-analysis problem in the cognitive modeling space, this post develops the tools to use neural networks trained with the Flax package (a neural network library based on JAX) as approximate likelihoods in likelihood-free inference scenarios. We will spend some time setting up the data analysis problem first, including the modeling framework used and computational bottlenecks that may arise (however if you don't care about the particulars, feel free to skip this part). Then, step by step, we will develop the tools necessary to go from a simple data simulator without access to a likelihood function to Bayesian Inference with PyMC via a custom distribution.
We will try to keep the code as general as possible, to facilitate other use cases with minimal hassle.
Setting the Stage
To motivate the modeling effort expounded upon below, let's start by building the case for a particular class of models, beginning with a (somewhat stylized) original data analysis problem.
Consider a dataset from the NeuroRacer experiment, illustrated below with an adapted figure from the original paper.
The player/subject in this experiment is tasked with steering a racing car along a curvy racetrack, while reacting appropriately to appearing traffic signs under time pressure. Traffic signs are either of the target or no target type; the appropriate reaction is a button press for a target sign and no button press otherwise.
In the lingo of cognitive scientists, we may consider this game a Go / NoGo type task (press or withhold depending on traffic sign), under extra cognitive load (steering the car across the racetrack).
This leaves us with four types of responses to analyse (see the figure below):
- Correct button press (Correct Go)
- Correct withhold (Correct NoGo)
- False button press (False Go)
- False withhold (False NoGo)
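The four response types follow mechanically from the trial type and the observed response. A minimal sketch, assuming the `target` / `notarget` and `go` / `nogo` labels used in the example data in this post; the function name `classify_response` is hypothetical:

```python
def classify_response(trial_type: str, response: str) -> str:
    """Map a (trial_type, response) pair to one of the four response types."""
    pressed = response == "go"
    is_target = trial_type == "target"
    if pressed and is_target:
        return "Correct Go"
    if not pressed and not is_target:
        return "Correct NoGo"
    if pressed and not is_target:
        return "False Go"
    return "False NoGo"


trials = [("target", "go"), ("notarget", "nogo"), ("notarget", "go"), ("target", "nogo")]
labels = [classify_response(t, r) for t, r in trials]
print(labels)
```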
What kind of data
Collecting reaction times (rt) and choices (responses) for each of the trials, our dataset will eventually look as follows.
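import numpy as np
import pandas as pd
# Generate example data: 100 trials with uniform random reaction times
data = pd.DataFrame(np.random.uniform(size = (100, 1)), columns = ['rt'])
half = data.shape[0] // 2
# First half of the trials: go responses on target trials
data['response'] = 'go'
data['trial_type'] = 'target'
# Second half: nogo responses on notarget trials
# (.loc assigns in place; writing through .values may not update the DataFrame)
data.loc[half:, 'response'] = 'nogo'
data.loc[half:, 'trial_type'] = 'notarget'
data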
| | rt | response | trial_type |
|---|---|---|---|
| 0 | 0.776609 | go | target |
| 1 | 0.706098 | go | target |
| 2 | 0.395347 | go | target |
| 3 | 0.337480 | go | target |
| 4 | 0.751433 | go | target |
| ... | ... | ... | ... |
100 rows × 3 columns
The model(s)
Cognitive scientists have a powerful framework for the joint analysis of reaction time and choice data: Sequential Sampling Models (SSMs).
The canonical model in this framework is the Drift Diffusion Model (or Diffusion Decision Model). We will take this model as a starting point to explain how it applies to the analysis of NeuroRacer data.
The basic idea behind the Drift Diffusion Model is the following.
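We represent the decision process between two options as a Gaussian random walk of a so-called evidence state. The random walk starts after some non-decision time period $ndt$, from a starting point set by $z$, and drifts at rate $v$ until it crosses one of two bounds, whose separation is set by $a$.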
Which bound is reached, and the time of crossing, jointly determine the reaction time and choice. Hence, this model specifies a stochastic data generating process and we can define a likelihood function for this.
However, this likelihood function may be quite hard to derive, and possibly too computationally expensive. As an alternative, we will use a general function approximation via Neural Networks in this tutorial.
Let's first look at an illustration of the model and identify the quantities relevant for our example.
A nice aspect of the Drift Diffusion Model (or Diffusion Decision Model) is that its parameters have distinct interpretations.
- $ndt$, the non-decision time component, captures all aspects of decision time not explicitly modeled by the random walk process (e.g. motor preparation, initial time-to-attentive-state, etc.).
- $z$ provides a global bias of the process towards one or the other choice. One can think of it as an a priori estimate of the underlying frequency of correct choices as per the experiment design.
- $v$ is the rate with which evidence is consistently accumulated toward one or the other bound (in favor of one or the other choice). One can think of it as speed of processing.
- $a$ represents a measure of the desired level of certainty before a decision is committed to. It is also referred to as decision caution.
The two quantities we will make explicit in the analyses are the following (see also the figure above):
- $f_{Go}(t \mid v, a, z, ndt)$, the likelihood of observing a Go choice at time $t$
- $p(NoGo \mid v, a, z, ndt)$, the likelihood of "observing" a withheld button press, defined as the integral of $f_{NoGo}(t \mid v, a, z, ndt)$ over $t$.
We will focus on a simple analysis case, in which we observe hypothetical data from a single player, who plays the game for a series of Go and NoGo trials.
We will make a simplifying modeling assumption: the rate of evidence accumulation for NoGo trials has the same magnitude as that of Go trials but with a flipped sign, meaning participants are less likely to press a button as time goes by. This allows us to estimate a single drift parameter $v$ for both trial types.
Hence we get $v_{Go} = v$ for Go trials and $v_{NoGo} = -v$ for NoGo trials.
Motivating Simulation Based Inference
The Drift Diffusion Model actually has a (cumbersome) analytical likelihood, with specialized algorithms for fast evaluation. There are however many interesting variants for which fast computations are hampered by a lack of closed form solutions (see for example here and here).
Take as one example the model illustrated in the figure below,
Conceptually the only difference is that the decision criterion, described in our simple Drift Diffusion Model above with a single parameter $a$, is now allowed to change over the course of a trial.
On the other hand simulators for such variants tend to remain easy to code up (often a few lines in python do the job). A simulator but no likelihood? Welcome to the world of simulation based inference (SBI).
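To make the "few lines" claim concrete, here is a minimal sketch of a basic DDM simulator using only the standard library. The conventions are assumptions made for this illustration (bounds at 0 and `a`, start at `z * a`, unit noise, Euler-Maruyama steps of size `dt`) and are not taken from any particular package; the function name `simulate_ddm_trial` is hypothetical.

```python
import math
import random


def simulate_ddm_trial(v, a, z, ndt, dt=0.001, max_t=20.0, rng=random):
    """Simulate one trial of a basic DDM and return (rt, choice).

    Assumed convention: bounds at 0 and a, start at z * a, unit noise.
    choice is +1 for the upper bound and -1 otherwise (a trial that
    reaches max_t without crossing is also reported as -1).
    """
    x = z * a                  # evidence state
    t = 0.0                    # elapsed decision time
    noise_sd = math.sqrt(dt)   # standard deviation of one Gaussian increment
    while 0.0 < x < a and t < max_t:
        x += v * dt + noise_sd * rng.gauss(0.0, 1.0)
        t += dt
    choice = 1 if x >= a else -1
    return ndt + t, choice


rng = random.Random(42)
trials = [simulate_ddm_trial(v=1.0, a=1.5, z=0.5, ndt=0.5, rng=rng) for _ in range(200)]
p_upper = sum(c == 1 for _, c in trials) / len(trials)
print(f"Fraction of upper-bound choices: {p_upper:.2f}")
```

A time-varying decision criterion would only require replacing the fixed bounds in the `while` condition with a function of `t`, which is why simulators stay simple even when likelihoods do not.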
Surveying the field of SBI is beyond the scope of this blog post (the paper above is a good start for those interested), but let it be said that SBI is the overarching paradigm from which we pick a specific method to construct our approach below.
The idea is the following. We start with a simulator for the DDM from which, given a set of parameters ($v$, $a$, $z$, $ndt$), we can generate synthetic reaction times and choices.
From these building blocks, we will construct training data to train two Multilayer Perceptrons (MLPs, read: small Neural Networks), one for each of the two parts of the overall likelihood.
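These MLPs are going to act as our likelihood functions. We will call the network which represents the likelihood of a Go choice at a given reaction time the LAN (likelihood approximation network), and the network which represents the probability of a withheld button press the CPN (choice probability network).
We will then proceed by wrapping these trained networks into a custom PyMC distribution and finally get samples from our posterior of interest.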
With all these steps ahead, let's get going!
From model simulation to PyMC model
Simulating Data
In favor of a digestible reading experience, we will use a convenience package to simulate data from the DDM model. This package not only allows us to simulate trajectories, but also includes utilities to directly produce data in a format suitable for downstream neural network training (which is our target here). The mechanics behind training data generation are described in this paper.
For some intuition, let's start with simulating and plotting a simple collection of 1000 trajectories.
from ssms.basic_simulators import simulator
n_trajectories = 1000
parameter_vector = np.array([1.0, 1.5, 0.5, 0.5, 0.5])
simulation_data = simulator(model = 'angle',
                            theta = parameter_vector,
                            n_samples = n_trajectories,
                            random_state = 42)
simulation_data.keys()
dict_keys(['rts', 'choices', 'metadata'])
The simulator returns a dictionary with three keys.
1. `rts`, the reaction times for each choice listed under 2.
2. `choices`, here coded as $-1$ for lower boundary crossings and $1$ for upper boundary crossings.
3. `metadata`, extra information about the simulator settings
Let's use this to plot the reaction time distribution (negative reals refer to choices at the lower boundary).
Turning it into Training Data
We will now use a couple of convenience functions from the ssm-simulators package, to generate training data for our Neural Networks. This will proceed in two steps. We first define two config dictionaries to specify properties of the simulation runs that will serve as the basis for our training data set.
- The `generator_config`, which specifies how to construct training data on top of basic simulation runs.
- The `model_config`, which specifies the properties of the core simulator.
Second, we will actually run the necessary simulations.
Let's make the config dictionaries.
NOTE:
The details here are quite immaterial. We simply need some way of generating training data of two types.
- One (for the LAN), which has as features vectors of the kind $(v, a, z, ndt, rt, c)$ and as labels the corresponding empirical log-likelihood evaluations.
- One (for the CPN), which takes as features simply the parameter vectors $(v, a, z, ndt)$ and as labels the corresponding empirical choice probabilities.
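To give a feel for what these two kinds of labels are, here is a minimal sketch that derives both from a set of simulated trials. It is an illustration of the idea only, not the procedure used by the ssm-simulators package: the fixed-bandwidth Gaussian kernel density estimate, the function names, and the stand-in simulation data are all assumptions.

```python
import math
import random


def empirical_choice_probability(choices, choice=1):
    """Fraction of simulated trials that ended with `choice` (CPN-style label)."""
    return sum(c == choice for c in choices) / len(choices)


def kde_log_likelihood(rt, choice, sim_rts, sim_choices, bandwidth=0.1):
    """Empirical log-likelihood of (rt, choice) from simulations (LAN-style label).

    Gaussian kernel density estimate over the simulated reaction times with the
    matching choice, scaled by the empirical probability of that choice.
    """
    matching = [r for r, c in zip(sim_rts, sim_choices) if c == choice]
    if not matching:
        return -math.inf
    norm = 1.0 / (bandwidth * math.sqrt(2.0 * math.pi))
    density = sum(norm * math.exp(-0.5 * ((rt - r) / bandwidth) ** 2) for r in matching)
    density /= len(sim_rts)  # dividing by ALL trials folds in p(choice)
    return math.log(density) if density > 0.0 else -math.inf


# Stand-in "simulations" (made-up distributions, not DDM output)
rng = random.Random(0)
sim_rts = [0.5 + rng.gammavariate(2.0, 0.3) for _ in range(2000)]
sim_choices = [1 if rng.random() < 0.8 else -1 for _ in range(2000)]

cpn_label = empirical_choice_probability(sim_choices, choice=1)
lan_label = kde_log_likelihood(1.0, 1, sim_rts, sim_choices)
print(cpn_label, lan_label)
```

Repeating this over many parameter vectors yields one table of `(parameters, rt, choice) -> log-likelihood` rows and one table of `parameters -> choice probability` rows.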
We are now in the position to actually run the simulations.
If you run this by yourself,
- Be aware that the next cell may run for a while (between a few minutes and an hour)
- Make sure the `output_folder` specified above exists.
# MAKE DATA
from ssms.dataset_generators import data_generator
n_datasets = 20
# Instantiate a data generator (we pass our configs)
my_dataset_generator = data_generator(generator_config = ddm_generator_config,
                                      model_config = ddm_model_config)
for i in range(n_datasets):
    print('Dataset: ', i + 1, ' of ', n_datasets)
    training_data = my_dataset_generator.generate_data_training_uniform(save = True,
                                                                        verbose = True)
Let's take a quick look at the type of data we generated here (if you run this by yourself, pick one of the unique file names generated during your run):
import pickle
# Replace the file name with one generated during your own run
with open('data/training_data/ddm_high_prec/training_data_167fc318b85511ed81623ceceff2f96e.pickle', 'rb') as f:
    training_data_example = pickle.load(f)
training_data_example.keys()
dict_keys(['data', 'labels', 'choice_p', 'thetas', 'binned_128', 'binned_256', 'generator_config', 'model_config'])
Under the `data` key (this is a legacy name, it might more appropriately be called `features` directly) we find the feature set we need for LANs: a matrix that contains the columns `[v, a, z, ndt, rt, choice]`. In general, across simulator models, the leading columns contain the parameters of the model, and the remaining columns contain the output data (in our case: reaction times and choices).
training_data_example['data'][:10, :]
array([[-1.320654 , 2.4610643, 0.7317903, 1.5463215, 3.5173764,
-1. ],
[-1.320654 , 2.4610643, 0.7317903, 1.5463215, 5.126489 ,
-1. ],
[-1.320654 , 2.4610643, 0.7317903, 1.5463215, 4.1766562,
-1. ],
[-1.320654 , 2.4610643, 0.7317903, 1.5463215, 5.331864 ,
-1. ],
[-1.320654 , 2.4610643, 0.7317903, 1.5463215, 3.1934366,
-1. ],
[-1.320654 , 2.4610643, 0.7317903, 1.5463215, 3.8244245,
-1. ],
[-1.320654 , 2.4610643, 0.7317903, 1.5463215, 5.069471 ,
-1. ],
[-1.320654 , 2.4610643, 0.7317903, 1.5463215, 6.12916 ,
-1. ],
[-1.320654 , 2.4610643, 0.7317903, 1.5463215, 4.563048 ,
-1. ],
[-1.320654 , 2.4610643, 0.7317903, 1.5463215, 4.2055674,
-1. ]], dtype=float32)
The `labels` key contains the empirical log-likelihoods, one for each row of the feature matrix.
training_data_example['labels'][:10]
array([-0.9136966 , -1.7495332 , -1.1017948 , -1.9210151 , -1.0180298 ,
-0.96904343, -1.7043545 , -2.6625948 , -1.3647802 , -1.1204832 ],
dtype=float32)
The final keys we will be interested in concern the feature and label data useful for training the CPN.
This network is a function from the model parameters (`thetas` key) directly to choice probabilities (`choice_p` key).
training_data_example['thetas'][:10]
array([[-1.320654 , 2.4610643 , 0.7317903 , 1.5463215 ],
[ 0.6651015 , 2.032305 , 0.43455952, 0.39412627],
[-0.50315803, 2.434122 , 0.225012 , 0.9221967 ],
[-0.2956373 , 1.1780746 , 0.3984416 , 1.361724 ],
[-0.39936534, 1.170804 , 0.57806057, 1.2206811 ],
[-1.1943408 , 0.68256736, 0.6976298 , 1.172115 ],
[ 0.7937759 , 2.049422 , 0.45930618, 0.48710603],
[-2.0405245 , 2.453905 , 0.7521208 , 1.9120167 ],
[-2.8156106 , 2.2226427 , 0.69219965, 1.0444195 ],
[ 0.37362418, 1.775074 , 0.42867982, 0.22735284]],
dtype=float32)
training_data_example['choice_p'][:10]
array([2.930e-02, 9.146e-01, 1.540e-02, 2.471e-01, 3.542e-01, 3.422e-01,
9.555e-01, 6.300e-03, 7.000e-04, 7.356e-01], dtype=float32)
There are a few other keys in the `training_data_example` dictionary. We can ignore these for the purposes of this blog post.
We are ready to move forward by turning our raw training data into a `DataLoader` object, which prepares the data for ingestion by the Neural Networks.
The `DataLoader` is supposed to take care of:
- Efficiently reading in data files and
- turning them into batches to be ingested when training a Neural Network.
As has become somewhat of a standard, we will work off of the `Dataset` class supplied by the `torch.utils.data` module in the PyTorch deep learning framework.
The key methods to define in our custom dataset are `__getitem__()` and `__len__()`.
`__len__()` tells us the number of batches contained in a complete run through the data (an epoch in machine learning lingo). `__getitem__()` is the method called to retrieve the next batch of data.
Let's construct it.
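As a rough illustration of the two methods, here is a torch-free sketch of a batch-level dataset. It is an assumption-laden stand-in: the class name `BatchedDataset` and its arguments are hypothetical, it holds data in plain lists rather than reading training files, and a real implementation would subclass `torch.utils.data.Dataset`.

```python
import random


class BatchedDataset:
    """Minimal batch-level dataset: each item is one (features, labels) batch."""

    def __init__(self, features, labels, batch_size, shuffle=True, seed=0):
        assert len(features) == len(labels)
        self.features = features
        self.labels = labels
        self.batch_size = batch_size
        self.indices = list(range(len(features)))
        if shuffle:
            random.Random(seed).shuffle(self.indices)

    def __len__(self):
        # Number of full batches in one epoch (a trailing partial batch is dropped)
        return len(self.indices) // self.batch_size

    def __getitem__(self, batch_idx):
        if not 0 <= batch_idx < len(self):
            raise IndexError(batch_idx)
        start = batch_idx * self.batch_size
        idx = self.indices[start:start + self.batch_size]
        return [self.features[i] for i in idx], [self.labels[i] for i in idx]


# Toy data: ten feature rows, with labels tied to the first feature
features = [[0.1 * i, 1.0] for i in range(10)]
labels = [float(i) for i in range(10)]
dataset = BatchedDataset(features, labels, batch_size=4)
batches = [batch for batch in dataset]  # iteration stops at the IndexError
print(len(dataset), len(batches))
```

Treating one item as one batch means the loader wrapped around such a dataset does not need to do any batching of its own.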
Let's construct our training dataloaders for both our LAN and CPN networks (which we will define next). We use the `DataLoader` class in the `torch.utils.data` module to turn our `Dataset` class into an iterator.
NOTE:
To keep the code blocks in this blog post short, we will only concern ourselves with training data here, instead of including (as one should in a serious machine learning application) `DataLoader` classes for validation data as well. Defining validation data works analogously.
Notice how we change the `features_key` and `label_key` arguments to access the relevant parts of our training data files for the LAN and CPN respectively.
Building and Training the Network
We used the simulator to construct training data and constructed dataloaders on top of that. It is time to build and train our networks!
We will use the Flax python package for this purpose.
Let's first define a basic neural network class, constrained to minimal functionality.
We build such a class by inheriting from the `nn.Module` class in the `flax.linen` module and specifying two methods.
- The `setup()` method, which will be run as a preparatory step upon instantiation.
- The `__call__()` method, which defines the forward pass through the network.
Next we define a Neural Network trainer class. This will take an `MLPJax` instance and build the necessary infrastructure for network training around it. The approach roughly follows the suggestions in the Flax documentation.
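The training calls in this post select a loss by name: `'huber'` for the LAN and `'bcelogit'` for the CPN. The sketch below shows, in plain Python, what these names are assumed to stand for (a Huber loss and a binary cross-entropy evaluated on raw logits). It illustrates the loss functions only and is not the trainer's implementation; the function names and the `delta` default are assumptions.

```python
import math


def huber_loss(prediction, target, delta=1.0):
    """Huber loss: quadratic for small errors, linear for large ones."""
    error = abs(prediction - target)
    if error <= delta:
        return 0.5 * error ** 2
    return delta * (error - 0.5 * delta)


def bce_with_logits(logit, target):
    """Binary cross-entropy on a raw logit, in a numerically stable form.

    Equivalent to -[y * log(s) + (1 - y) * log(1 - s)] with s = sigmoid(logit).
    The target may be a probability in [0, 1], not just 0 or 1.
    """
    return max(logit, 0.0) - logit * target + math.log1p(math.exp(-abs(logit)))


small_error = huber_loss(0.5, 0.0)   # quadratic regime
large_error = huber_loss(3.0, 0.0)   # linear regime
uninformed = bce_with_logits(0.0, 1.0)  # logit 0 means predicted probability 0.5
print(small_error, large_error, uninformed)
```

The Huber loss limits the influence of occasional extreme log-likelihood labels, while the logit form of the cross-entropy accepts choice probabilities as soft targets.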
Preparations are all you need! We can now train our LAN and CPN with a few lines of code, making use of our previously defined classes.
# Initialize LAN
# (if train = False, output applies transform f such that: f(train_output_type) = logprob)
network_lan = MLPJax(train = True,
                     train_output_type = 'logprob')
# Set up the model trainer
ModelTrainerLAN = ModelTrainerJaxMLP(model = network_lan,
                                     train_dl = training_dataloader_lan,
                                     loss = 'huber',
                                     seed = 123)
# Train LAN
model_state_lan = ModelTrainerLAN.train(n_epochs = 10)
100%|██████████| 4880/4880 [00:30<00:00, 159.70it/s]
Epoch: 1 / 10, test_loss: 0.14862245321273804
...
100%|██████████| 4880/4880 [00:28<00:00, 169.85it/s]
Epoch: 10 / 10, test_loss: 0.01756889559328556
# Initialize CPN
network_cpn = MLPJax(train = True,
                     train_output_type = 'logits')
# Set up the model trainer
ModelTrainerCPN = ModelTrainerJaxMLP(model = network_cpn,
                                     train_dl = training_dataloader_cpn,
                                     loss = 'bcelogit',
                                     seed = 456)
# Train CPN
model_state_cpn = ModelTrainerCPN.train(n_epochs = 20)
100%|██████████| 20/20 [00:02<00:00, 9.14it/s]
Epoch: 1 / 20, test_loss: 0.42765116691589355
...
100%|██████████| 20/20 [00:00<00:00, 27.14it/s]
Epoch: 20 / 20, test_loss: 0.30409830808639526
Connecting to PyMC
At this point we have two networks ready (we will later see example output that illustrates the behavior and quality of the approximation), which can be used as differentiable approximations to likelihood evaluations. The figure below illustrates the respective function of each network (e.g. in the Go condition). This may help as a guiding visualization for the subsequent content.
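A CPN, which we will use as an approximator to
$$p(NoGo \mid v, a, z, ndt)$$
and a LAN, which we will use as an approximator to
$$f_{Go}(rt \mid v, a, z, ndt)$$
where $rt$ is the reaction time of an observed Go choice.
Together the CPN and the LAN allow us to construct a likelihood for a complete dataset from the NeuroRacer game.
Take the complete likelihood for a dataset of size $n$. Each trial $i$ contributes one of two terms:
- Go condition, Go choice (we observe a reaction time): $f_{Go}(rt_i \mid v, a, z, ndt)$
- Go condition, NoGo choice (we don't observe a reaction time): $p(NoGo \mid v, a, z, ndt)$
The log likelihood of the Go condition data can now be represented as:
$$\log \ell_{Go}(v, a, z, ndt) = \sum_{i:\, c_i = Go} \log f_{Go}(rt_i \mid v, a, z, ndt) + \sum_{i:\, c_i = NoGo} \log p(NoGo \mid v, a, z, ndt)$$
For the NoGo condition, we essentially apply the same logic so that the log likelihood of the NoGo condition data can now be represented as:
$$\log \ell_{NoGo}(v, a, z, ndt) = \sum_{i:\, c_i = Go} \log f_{Go}(rt_i \mid -v, a, z, ndt) + \sum_{i:\, c_i = NoGo} \log p(NoGo \mid -v, a, z, ndt)$$
As per our modeling assumption, we flip the sign of the drift, using $-v$ in place of $v$ in the NoGo condition.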
Building a custom distribution
All pieces are lined up to start building a custom distribution for eventual use in a PyMC model.
The starting point has to be the construction of a custom likelihood, as a valid PyTensor `Op`.
For this purpose we use the `NetworkLike` class below. It allows us to construct proper log-likelihoods from our two networks.
What do we mean by a proper log-likelihood?
A valid JAX function that takes in parameters, processes the input data, performs the appropriate forward pass through the networks, and finally sums the resulting trial-wise log-likelihoods to give us a data log-likelihood. This is taken care of by the `make_logp_jax_funcs()` method.
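The shape of such a function can be sketched in plain Python. This is an illustration of the structure only (parameters in, trial-wise terms summed, one scalar out), not the JAX implementation in the `NetworkLike` class: the factory `make_logp`, the data layout, and the two toy stand-in "networks" are assumptions, and the stand-ins are not DDM likelihoods.

```python
import math


def make_logp(lan_forward, cpn_forward):
    """Build a dataset-level log-likelihood from two stand-in 'networks'.

    lan_forward(v, a, z, ndt, rt) -> log-likelihood of a Go choice at time rt
    cpn_forward(v, a, z, ndt)     -> log-probability of a withheld response
    """

    def logp(go_trials, nogo_trials, v, a, z, ndt):
        total = 0.0
        # Go choices: (rt, condition_sign) pairs, condition_sign is +1 or -1
        for rt, sign in go_trials:
            total += lan_forward(sign * v, a, z, ndt, rt)
        # Withheld responses: only the condition sign is observed
        for sign in nogo_trials:
            total += cpn_forward(sign * v, a, z, ndt)
        return total

    return logp


def toy_lan(v, a, z, ndt, rt):
    # Toy stand-in: log-density of a shifted exponential with rate exp(v)
    rate = math.exp(v)
    return math.log(rate) - rate * (rt - ndt) if rt > ndt else -math.inf


def toy_cpn(v, a, z, ndt):
    # Toy stand-in: log of a logistic NoGo probability that shrinks as v grows
    return -math.log1p(math.exp(v))


logp = make_logp(toy_lan, toy_cpn)
go_trials = [(0.9, 1), (1.2, 1), (1.5, -1)]   # Go choices with observed rt
nogo_trials = [1, -1, -1]                      # withheld responses
value = logp(go_trials, nogo_trials, v=1.0, a=1.5, z=0.5, ndt=0.5)
print(value)
```

Multiplying `v` by the condition sign is how the single-drift modeling assumption enters the likelihood: the same two networks serve both conditions.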
Finally we need to turn these isolated likelihood functions into a valid PyTensor `Op`, which is taken care of by the `make_jax_logp_ops()` function. Note how we also register our log-likelihood function directly as a JAX log-likelihood (unwrap it) using the `jax_funcify` decorator with the `logp_op_dispatch()` method. This log-likelihood function does not need to be compiled (note how we pass the `logp_nojit` likelihood there); compilation is instead taken care of by any of the JAX samplers that PyMC provides (via NumPyro or BlackJAX).
NOTE:
The below code is a little involved and could be hard to digest on a first pass. Consider looking into the excellent tutorials in the PyMC docs and the PyMC Labs Blog on similar topics.
Specifically, the tutorial on using a blackbox likelihood function, the tutorial on custom distributions, the tutorial on wrapping jax functions into PyTensor Ops.
Finally there is an excellent tutorial from PyMC Labs, which incorporates Flax to train Bayesian Neural Networks (amongst other things): A different spin on our story here, not exactly equivalent, but helpful to understand the scope of use-cases encompassed at the intersection of Neural Networks and the Bayesian workflow.
The likelihood class will come in handy when defining our PyMC model below.
We now construct simple forward functions for our networks (`lan_forward()`, `cpn_forward()`). We use the `make_forward_partial()` method of our previously defined `MLPJax` class.
First we instantiate the networks in evaluation mode. The `make_forward_partial()` method then attaches our trained parameters to the usual Flax forward call (which takes in two arguments, the parameters and the model input) so that we can call `lan_forward()` and `cpn_forward()` with a single argument, the input data to be pushed through the respective network.
As you can check above, the work is done by the `partial()` function.
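The mechanism is easiest to see in isolation. A minimal sketch with `functools.partial`, using a toy two-argument forward function as a stand-in for the network call; the function, parameter values, and names here are made up for illustration:

```python
from functools import partial


def forward(params, x):
    """Toy forward pass: a single linear unit, params = (weights, bias)."""
    weights, bias = params
    return sum(w * xi for w, xi in zip(weights, x)) + bias


trained_params = ([0.5, -1.0, 2.0], 0.1)  # stand-in for trained network weights

# Fix the first argument so that callers only pass the input
forward_fixed = partial(forward, trained_params)

x = [1.0, 2.0, 3.0]
print(forward_fixed(x), forward(trained_params, x))
```

Both calls return the same value; the partial version simply hides the parameters, which is convenient when the function is later handed to code that only knows about inputs.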
# Initialize LAN in evaluation mode
network_lan_eval = MLPJax(train = False,
                          train_output_type = 'logprob')
# Make jitted forward passes (with fixed weights)
lan_forward, _ = network_lan_eval.make_forward_partial(state = ModelTrainerLAN.state.params)
# Initialize CPN in evaluation mode
network_cpn_eval = MLPJax(train = False,
                          train_output_type = 'logits')
# Make jitted forward passes (with fixed weights)
cpn_forward, _ = network_cpn_eval.make_forward_partial(state = ModelTrainerCPN.state.params)
As a quick aside, to illustrate the performance of the Networks, we plot their behavior below.
First, consider the LAN, which gives us choice / reaction time distributions directly.
We will vary the
Next we consider the performance of the CPN which, remember, spits out choice probabilities only.
In this plot we vary the
The outputs of the networks behave very regularly which is reassuring. To see if they match the real data generation process well, let's consider our previous figure on simulator behavior.
As we can see, the network outputs (shown in blue) follow the simulation data very well.
NOTE:
We emphasize that for serious applications we are better served using a much larger training data set. The scale of the simulation run here was chosen to make running the code in this blog-post feasible on local machines in a reasonable amount of time.
Plug the custom likelihoods into a PyMC model
Now the hard work in the previous sections culminates in actual results. We are able to construct our PyMC model by assembling the pieces we built above. We instantiate our LAN and CPN based likelihood ops using the methods defined in our `NetworkLike` class. First, we define simple likelihood functions via the `make_logp_jax_funcs()` method, then we construct the actual PyTensor `LogOp`s, which will be used directly in the PyMC model below.
# Instantiate LAN logp functions
lan_logp_jitted, lan_logp_grad_jitted, lan_logp = NetworkLike.make_logp_jax_funcs(
    model = lan_forward,
    n_params = 4,
    kind = "lan")
# Turn into logp op
lan_logp_op = NetworkLike.make_jax_logp_ops(
    logp = lan_logp_jitted,
    logp_grad = lan_logp_grad_jitted,
    logp_nojit = lan_logp)
# Instantiate CPN logp functions
cpn_logp_jitted, cpn_logp_grad_jitted, cpn_logp = NetworkLike.make_logp_jax_funcs(
    model = cpn_forward,
    n_params = 4,
    kind = "cpn")
# Turn into logp op
cpn_logp_op = NetworkLike.make_jax_logp_ops(
    logp = cpn_logp_jitted,
    logp_grad = cpn_logp_grad_jitted,
    logp_nojit = cpn_logp)
Finally, let's define a function that constructs our PyMC model for us. Note how we use our likelihood ops, `lan_logp_op()` and `cpn_logp_op()` respectively, to define two `pm.Potential()` terms. You can learn more about `pm.Potential()` in the docs and, more closely connected to blackbox likelihoods, in this helpful basic tutorial.
import pymc as pm
import pytensor.tensor as pt
def construct_pymc_model(data: pd.DataFrame):
    """
    Construct our PyMC model given a dataset.
    """
    # Data preprocessing:
    # We expect three columns [rt, choice, is_go_trial (1 for go, -1 for nogo)]
    # We split the data according to whether the choice is go or nogo
    # NoGo choices carry no reaction time, so only the condition column is kept
    data_nogo = data.loc[data.choice < 0, 'is_go_trial'].values
    data_go = data.loc[data.choice > 0, :].values
    with pm.Model() as ddm:
        # Define simple Uniform priors
        v = pm.Uniform("v", -3.0, 3.0)
        a = pm.Uniform("a", 0.3, 2.5)
        # z is fixed rather than estimated
        z = pt.constant(0.5)
        t = pm.Uniform("t", 0.0, 2.0)
        # The LAN scores trials with a reaction time, the CPN scores withheld responses
        pm.Potential("choice_rt", lan_logp_op(data_go, v, a, z, t))
        pm.Potential("choice_only", cpn_logp_op(data_nogo, v, a, z, t))
    return ddm
Inference example
We are nearing the end of this blog post (promised). All that remains is to try it out. At this point we can simulate some synthetic NeuroRacer experiment data, fire up our newly designed PyMC model and run our MCMC sampler for parameter inference.
We pick a set of parameters and, following our modeling assumptions, we flip the sign of $v$ for the NoGo condition.
# Let's make some data
from ssms.basic_simulators import simulator
parameters = {'v': 1.0,
'a': 1.5,
'z': 0.5,
't': 0.5}
parameters_go = [parameters[key_] for key_ in parameters.keys()]
parameters_nogo = [parameters[key_] if key_ != 'v' else ((-1)*parameters[key_]) for key_ in parameters.keys()]
# Run simulations for each condition (go, nogo)
sim_go = simulator(theta = parameters_go, model = 'ddm', n_samples = 500)
sim_nogo = simulator(theta = parameters_nogo, model = 'ddm', n_samples = 500)
# Process data and add a column that signifies whether the trial,
# belongs to a go (1) or nogo (-1) condition
data_go_condition = np.hstack([sim_go['rts'], sim_go['choices'], np.ones((500, 1))])
data_nogo_condition = np.hstack([sim_nogo['rts'], sim_nogo['choices'], (-1)*np.ones((500, 1))])
# Stack the two datasets and turn into DataFrame
data = np.vstack([data_go_condition, data_nogo_condition]).astype(np.float32)
data_pd = pd.DataFrame(data, columns = ['rt', 'choice', 'is_go_trial'])
With our dataset at hand, we can now instantiate the PyMC model.
ddm_blog = construct_pymc_model(data_pd)
Let's visualize the model structure.
pm.model_to_graphviz(ddm_blog)
The graphical model nicely illustrates how we handle the Go choices and NoGo choices via separate likelihood objects, while our basic parameters feed into both of these.
Note that we don't fit the $z$ parameter here; it is fixed to 0.5 in the model definition.
We are now ready to sample...
from pymc.sampling import jax as pmj
# Suppress warnings to keep the blog post output tidy
import warnings
warnings.filterwarnings('ignore')
with ddm_blog:
    ddm_blog_traces_numpyro = pmj.sample_numpyro_nuts(
        chains=2, draws=2000, tune=500, chain_method="vectorized"
    )
Compiling...
Compilation time = 0:00:08.502226
Sampling...
sample: 100%|██████████| 2500/2500 [00:17<00:00, 140.40it/s]
Sampling time = 0:01:03.744460
Transforming variables...
Transformation time = 0:00:01.292058
As a last step we can check our posterior distributions. Did all of this actually work out?
NOTE:
The posterior mass here may be somewhat off the mark when compared to the ground truth parameters. While this hints at a calibration issue, it was a conscious choice to trade off some precision to avoid potentially very long runtimes for this tutorial. In general, we can improve the performance of our neural networks by training on much more synthetic data (which is advisable in real applications). This would make running this notebook rather cumbersome, but we encourage you to try it!
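One simple way to put a number on "off the mark" is to check whether a ground truth value falls inside a central posterior interval. The sketch below does this for made-up draws; the draws, the helper name `central_interval`, and the use of an equal-tailed interval are assumptions for illustration (ArviZ plots a 94% highest-density interval by default, which can differ slightly for skewed posteriors).

```python
import random


def central_interval(samples, mass=0.94):
    """Equal-tailed interval holding roughly `mass` of the samples."""
    ordered = sorted(samples)
    tail = (1.0 - mass) / 2.0
    lower = ordered[int(tail * (len(ordered) - 1))]
    upper = ordered[int((1.0 - tail) * (len(ordered) - 1))]
    return lower, upper


# Stand-in posterior draws for v (made-up numbers, not results from the model)
rng = random.Random(1)
draws_v = [rng.gauss(1.05, 0.1) for _ in range(4000)]
true_v = 1.0

lower, upper = central_interval(draws_v)
covered = lower <= true_v <= upper
print(f"94% interval: [{lower:.2f}, {upper:.2f}], covers ground truth: {covered}")
```

Repeating such a check over many simulated datasets gives a rough picture of how well calibrated the approximate likelihood is.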
import arviz as az
az.plot_posterior(ddm_blog_traces_numpyro,
                  kind = 'hist',
                  **{'color': 'black',
                     'histtype': 'step'},
                  # Reference lines at the ground truth values used for simulation
                  ref_val = {'v': [{'ref_val': parameters['v']}],
                             'a': [{'ref_val': parameters['a']}],
                             't': [{'ref_val': parameters['t']}]
                            },
                  ref_val_color = 'green')
array([<Axes: title={'center': 'v'}>, <Axes: title={'center': 'a'}>,
<Axes: title={'center': 't'}>], dtype=object)
A somewhat long but hopefully rewarding tutorial is hereby finished. We hope you see some potential in this approach. Many extensions are possible: from the choice of neural network architectures to the structure of the PyMC model, a plethora of options arise. At a minimum, we hope that this may serve you as another take on a tutorial concerning custom distributions in PyMC.
For related tutorials check out: