Draft
58 commits
75dc4d6
Add abstraction to log-likelihood function RLDM model
cpaniaguam Oct 9, 2025
1673910
Add tests for JAX log-likelihood functions in RLDM model
cpaniaguam Oct 9, 2025
63ac052
Reduce scope to return drift rates
cpaniaguam Oct 9, 2025
54caf5c
Add TODO comment for adapting make_rldm_logp_op function to changes i…
cpaniaguam Oct 9, 2025
add487f
Drop op test for now
cpaniaguam Oct 9, 2025
b241efb
Update docstring
cpaniaguam Oct 9, 2025
b0acf3a
Add a blank line before the TODO comment in make_rldm_logp_op function
cpaniaguam Oct 9, 2025
6d5f76e
Fix linting
cpaniaguam Oct 9, 2025
3af10da
Rename make_rldm_logp_func to make_rl_logp_func and update references…
cpaniaguam Oct 9, 2025
d4e2f0f
Rename mapping_function to subject_wise_func in make_rl_logp_func and…
cpaniaguam Oct 9, 2025
9bd3357
Update logp docstring for clarity on RL model specification
cpaniaguam Oct 9, 2025
3e965c8
Rename variable v to drift_rates in make_rl_logp_func for clarity
cpaniaguam Oct 9, 2025
132a2da
Reorder action assignment in make_rl_logp_func for clarity
cpaniaguam Oct 14, 2025
4b4c2f0
Update TODO comment for clarity on adapting make_rldm_logp_op function
cpaniaguam Oct 14, 2025
9d8ed86
Add pragma directive to make_rldm_logp_op for test coverage exclusion
cpaniaguam Oct 14, 2025
4d664e9
Refactor test_make_rl_logp_func to use action from data instead of se…
cpaniaguam Oct 15, 2025
d65c3fd
Refactor logp function in make_rl_logp_func to accept variable argume…
cpaniaguam Oct 15, 2025
2c407e8
Update logp docstring
cpaniaguam Oct 15, 2025
ea9fec4
Fix linting
cpaniaguam Oct 15, 2025
67dd08a
Reformat
cpaniaguam Oct 15, 2025
f8faa92
Add column validation and index retrieval functions for data processing
cpaniaguam Oct 20, 2025
25fa64e
Enhance make_rl_logp_func to accept additional parameters for data co…
cpaniaguam Oct 20, 2025
1a1a9c6
Add column validation and index retrieval in make_rl_logp_func
cpaniaguam Oct 20, 2025
ecaf7d2
Refactor make_rl_logp_func to accept a 2D array for trial data and mo…
cpaniaguam Oct 20, 2025
7a941d6
Refactor test_make_rl_logp_func to use rldm_setup fixture for improve…
cpaniaguam Oct 20, 2025
fe8ea53
Enhance column validation functions to handle None values and improve…
cpaniaguam Oct 20, 2025
cb720ee
Format _validate_columns call for improved readability in test cases
cpaniaguam Oct 20, 2025
c155f91
Set default data_cols in make_rl_logp_func to ["rt", "response"]
cpaniaguam Oct 20, 2025
3412c51
Add annotate_function decorator to attach metadata to functions
cpaniaguam Oct 24, 2025
b9b7b9b
Refactor _get_column_indices to rename parameters for clarity
cpaniaguam Oct 24, 2025
55e806b
Refactor _get_column_indices to use cols_to_look_up for index retrieval
cpaniaguam Oct 24, 2025
50e7e63
Update make_rl_logp_func to use subject_wise_func inputs for column v…
cpaniaguam Oct 24, 2025
ae3d256
Update logp function in make_rl_logp_func to accept additional arguments
cpaniaguam Oct 24, 2025
24fe7f8
Add tests for annotate_function decorator to verify attribute assignment
cpaniaguam Oct 24, 2025
4634748
Remove dist_params and extra_fields from make_rl_logp_func call in rl…
cpaniaguam Oct 24, 2025
127dd38
Update type hint for subject_wise_func in make_rl_logp_func and make_…
cpaniaguam Oct 24, 2025
730e82a
Refactor _get_column_indices to remove optional parameters and simpli…
cpaniaguam Nov 3, 2025
3d7ad85
Add TODO comments to reintroduce parameter handling in make_rl_logp_func
cpaniaguam Nov 3, 2025
0315fe0
Fix default value type for data_cols in make_rl_logp_func
cpaniaguam Nov 3, 2025
668a940
Update TODO comment in make_rl_logp_func to clarify workflow reintrod…
cpaniaguam Nov 3, 2025
3f3e7fe
Remove unused _validate_columns function to streamline code
cpaniaguam Nov 4, 2025
15c872c
Remove TestValidateColumns class to eliminate redundant validation tests
cpaniaguam Nov 4, 2025
033b2c6
Refactor _get_column_indices to improve column lookup and add error h…
cpaniaguam Nov 4, 2025
fa606cb
Add tests for column indexing and data collection in RLDM likelihood …
cpaniaguam Nov 4, 2025
675c4f7
Refactor _get_column_indices and update make_rl_logp_func to improve …
cpaniaguam Nov 4, 2025
1abc8cd
Refactor test_get_column_indices to improve parameter order and enhan…
cpaniaguam Nov 4, 2025
2eb711a
Remove breakpoint from _collect_cols_arrays to clean up debugging code
cpaniaguam Nov 4, 2025
7e95244
Rename parameter 'dist_params' to 'list_params' in make_rl_logp_func …
cpaniaguam Nov 4, 2025
ee58e32
Update docstring in make_rl_logp_func to clarify data parameter and a…
cpaniaguam Nov 4, 2025
f83135b
Refactor make_rl_logp_func to utilize _collect_cols_arrays for data p…
cpaniaguam Nov 4, 2025
38bf0bc
Add missing imports for _get_column_indices and _collect_cols_arrays …
cpaniaguam Nov 4, 2025
06aa798
Refactor rldm_setup to streamline parameter definitions and remove un…
cpaniaguam Nov 4, 2025
776d1cc
Fix input parameter name in compute_v_subject_wise for consistency
cpaniaguam Nov 4, 2025
91a06c9
Refactor rldm_setup to initialize feedback and parameter arrays for c…
cpaniaguam Nov 4, 2025
0ccd8b7
Refactor rldm_setup to include list_params and extra_fields in logp_f…
cpaniaguam Nov 4, 2025
e0d3872
Update test for logp_fn to pass additional arguments and correct expe…
cpaniaguam Nov 4, 2025
fdf21c3
Refactor test in TestRldmLikelihoodAbstraction to use unpacked _args …
cpaniaguam Nov 4, 2025
5a56dbb
Fix mypy
cpaniaguam Nov 4, 2025
291 changes: 291 additions & 0 deletions src/hssm/likelihoods/rldm_optimized_abstraction.py
@@ -0,0 +1,291 @@
"""The log-likelihood function for the RLDM model."""

import functools
from typing import Any, Callable

import jax
import jax.numpy as jnp
import numpy as np
from jax.lax import scan
from pytensor.graph import Op

from hssm.distribution_utils.func_utils import make_vjp_func

from ..distribution_utils.jax import make_jax_logp_ops
from ..distribution_utils.onnx import make_jax_matrix_logp_funcs_from_onnx

# Obtain the angle log-likelihood function from an ONNX model.
angle_logp_jax_func = make_jax_matrix_logp_funcs_from_onnx(
Collaborator

loading the angle LAN was hardcoded since the existing likelihood used angle. But we will need to make this generic (based on the decision process defined in the rlssm model_config) and load the correct LAN based on what the user wants.

Collaborator Author

Thanks @krishnbera. This is another element to abstract as well.
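
A minimal sketch of what a generic lookup might look like, assuming the decision process name from the model config maps to an ONNX file name. Only the `angle` to `angle.onnx` pair comes from this PR; `LAN_MODELS` and `resolve_lan_model` are hypothetical names used for illustration.

```python
# Hypothetical registry: only the "angle" entry is confirmed by this PR.
LAN_MODELS: dict[str, str] = {"angle": "angle.onnx"}


def resolve_lan_model(
    decision_process: str, registry: dict[str, str] = LAN_MODELS
) -> str:
    """Return the ONNX file name registered for a decision process."""
    try:
        return registry[decision_process]
    except KeyError:
        known = ", ".join(sorted(registry))
        raise ValueError(
            f"No LAN registered for {decision_process!r}; known: {known}"
        ) from None


model_file = resolve_lan_model("angle")
```

The returned file name could then be passed as the `model` argument of `make_jax_matrix_logp_funcs_from_onnx` in place of the hardcoded string.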

    model="angle.onnx",
)


def annotate_function(**kwargs):
Member

This has parallels with the decorate_simulator decorator?
Can we reuse functionality (either base that one on this one or vice versa)?

Collaborator Author

The annotate_function decorator is more generic than decorate_atomic_simulator as the latter attaches specific metadata

    model_name: str,
    choices: list | np.ndarray | None = None,
    obs_dim: int = 2,  # At least for now ssms models all fall under 2 obs dims

We could refactor decorate_atomic_simulator using a call to annotate_function:

def decorate_atomic_simulator(
    model_name: str,
    choices: list | np.ndarray | None = None,
    obs_dim: int = 2,  # At least for now ssms models all fall under 2 obs dims
):
    choices = [-1, 1] if choices is None else choices

    return annotate_function(
        model_name=model_name,
        choices=choices,
        obs_dim=obs_dim,
    )

Collaborator Author

And this change should probably be introduced in a separate PR.
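
To make the proposal concrete, here is a self-contained sketch of the refactor. `annotate_function` is copied from this PR and `decorate_atomic_simulator` follows the snippet above; the `simulate` function and the `"toy"` model name are made up for the demonstration.

```python
import functools
from typing import Callable


def annotate_function(**kwargs):
    """Attach arbitrary metadata as attributes to a function."""

    def decorator(func: Callable) -> Callable:
        @functools.wraps(func)
        def wrapper(*args, **inner_kwargs):
            return func(*args, **inner_kwargs)

        for key, value in kwargs.items():
            setattr(wrapper, key, value)
        return wrapper

    return decorator


def decorate_atomic_simulator(model_name, choices=None, obs_dim=2):
    """Thin wrapper that fixes which metadata keys a simulator carries."""
    choices = [-1, 1] if choices is None else choices
    return annotate_function(
        model_name=model_name, choices=choices, obs_dim=obs_dim
    )


# Hypothetical simulator, used only to show the attached attributes.
@decorate_atomic_simulator(model_name="toy")
def simulate(theta):
    return theta * 2
```

After decoration, `simulate.model_name`, `simulate.choices`, and `simulate.obs_dim` are available as attributes, and `functools.wraps` keeps the original `__name__`.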

    """Attach arbitrary metadata as attributes to a function.

    Parameters
    ----------
    **kwargs
        Arbitrary keyword arguments to attach as attributes.

    Returns
    -------
    Callable
        Decorator that adds metadata attributes to the wrapped function.
    """

    def decorator(func: Callable) -> Callable:
        @functools.wraps(func)
        def wrapper(*args, **inner_kwargs):
            return func(*args, **inner_kwargs)

        for key, value in kwargs.items():
            setattr(wrapper, key, value)
        return wrapper

    return decorator


# Inner function to compute the drift rate and update q-values for each trial.
# This function is used with `jax.lax.scan` to process each trial in the RLDM model.
def compute_v_trial_wise(
Member

similar to the angle in the name, here we shouldn't hardcode _v_ in the function name.

The learning rule could be applied to any parameter, we should use a name for it that just reflects the type of updating we are doing, regardless of which parameter it is applied to.

    q_val: jnp.ndarray, inputs: jnp.ndarray
) -> tuple[jnp.ndarray, jnp.ndarray]:
    """Compute the drift rate and update the q-values for each trial.

    This function is used with `jax.lax.scan` to process each trial. It takes the
    current q-values and the RL parameters (rl_alpha, scaler), action (response),
    and reward (feedback) for the current trial, computes the drift rate, and
    updates the q-values. The q-values are updated in each iteration and carried
    forward to the next one.

    Parameters
    ----------
    q_val
        A length-2 jnp array containing the current q-values for the two
        alternatives. These values are updated in each iteration and carried
        forward to the next trial.
    inputs
        A 1D jnp array of length 4 containing the RL parameters (rl_alpha, scaler),
        action (response), and reward (feedback) for the current trial.

    Returns
    -------
    tuple
        A tuple containing the updated q-values and the computed drift rate (v).
    """
    rl_alpha, scaler, action, reward = inputs
    action = jnp.astype(action, jnp.int32)

    # drift rate on each trial depends on difference in expected rewards for
    # the two alternatives:
    # drift rate = (q_up - q_low) * scaler where
    # the scaler parameter describes the weight to put on the difference in
    # q-values.
    computed_v = (q_val[1] - q_val[0]) * scaler

    # compute the reward prediction error
    delta_RL = reward - q_val[action]

    # update the q-values using the RL learning rule (here, simple TD rule)
    q_val = q_val.at[action].set(q_val[action] + rl_alpha * delta_RL)

    return q_val, computed_v


# This function computes the drift rates (v) for each subject by processing
# their trials one by one. It uses `jax.lax.scan` to efficiently iterate over
# the trials and compute the drift rates based on the RL parameters, actions,
# and rewards for each trial.
def compute_v_subject_wise(
Member

I would take out _v_ from the name here as well. Also, instead of subject_wise (which suggests the function runs per subject, for each subject), I would use _one_subject.
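
For reference, the per-subject computation under discussion can be written as a plain-Python loop. This is only a sketch of the same recurrence that the `scan` call implements (initial q-values of 0.5, drift computed from the q-values before the update); the name `drift_rates_one_subject` is hypothetical, and the loop is meant for checking results by hand, not as a replacement for the JAX version.

```python
def drift_rates_one_subject(trials, q_init=(0.5, 0.5)):
    """Reference loop for one subject.

    `trials` is an iterable of (rl_alpha, scaler, action, reward) tuples,
    in the same column order the scan-based function expects.
    """
    q = list(q_init)
    drifts = []
    for rl_alpha, scaler, action, reward in trials:
        action = int(action)
        # The drift rate uses the q-values *before* this trial's update.
        drifts.append((q[1] - q[0]) * scaler)
        # Reward prediction error and delta-rule update of the chosen option.
        delta = reward - q[action]
        q[action] += rl_alpha * delta
    return drifts, q


drifts, final_q = drift_rates_one_subject(
    [(0.5, 2.0, 1, 1.0), (0.5, 2.0, 0, 0.0)]
)
```

With these two trials, the first drift rate is 0.0 (both q-values start at 0.5) and the second is 0.5, because the rewarded choice of option 1 raised its q-value to 0.75.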

    subj_trials: jnp.ndarray,
) -> jnp.ndarray:
    """Compute the drift rates (v) for a given subject.

    Parameters
    ----------
    subj_trials:
        A jnp array of dimension (n_trials, 4) containing rl_alpha, scaler,
        action (response), and reward (feedback) for each trial of the subject.

    Returns
    -------
    jnp.ndarray
        The computed drift rates (v) for the RLDM model for the given subject.
    """
    _, v = scan(
        compute_v_trial_wise,
        jnp.ones(2) * 0.5,  # initial q-values for the two alternatives
        subj_trials,
    )

    return v


def _get_column_indices(
    cols_to_look_up: list[str],
    data_cols: list[str],
    list_params: list[str] | None,
    extra_fields: list[str] | None,
) -> dict[str, tuple[str, int]]:
    """Return indices for required columns.

    Parameters
    ----------
    cols_to_look_up : list[str]
        Columns to find indices for
    data_cols : list[str]
        Available data columns
    list_params : list[str] | None
        Available list parameters
    extra_fields : list[str] | None
        Available extra fields

    Returns
    -------
    dict[str, tuple[str, int]]
        Mapping of column names to (source, index) tuples
    """
    list_params = list_params or []
    extra_fields = extra_fields or []
    list_params_extra_fields = list_params + extra_fields
    colidxs = {}
    for col in cols_to_look_up:
        if col in data_cols:
            colidxs[col] = ("data", data_cols.index(col))
        elif col in list_params_extra_fields:
            colidxs[col] = ("args", list_params_extra_fields.index(col))
        else:
            raise ValueError(
                f"Column '{col}' not found in any of `data`, `list_params`, "
                f"or `extra_fields`."
            )
    return colidxs


def _collect_cols_arrays(data, _args, colidxs):
    """Collect the arrays listed in `colidxs` from `data` columns or `_args`."""
    collected = []
    for col in colidxs:
        source, idx = colidxs[col]
        if source == "data":
            collected.append(data[:, idx])
        else:
            collected.append(_args[idx])
    return collected


def make_rl_logp_func(
    subject_wise_func: Callable[..., Any],
    n_participants: int,
    n_trials: int,
    data_cols: list[str] = ["rt", "response"],
    list_params: list[str] | None = None,
    extra_fields: list[str] | None = None,
) -> Callable:
    """Create a function to compute the drift rates (v) for the RLDM model.

    Parameters
    ----------
    subject_wise_func : Callable
        Function that computes drift rates for a subject's trials.
    n_participants : int
        Number of participants in the dataset.
    n_trials : int
        Number of trials per participant.
    data_cols : list[str]
        List of column names in the data array. Defaults to ["rt", "response"].
    list_params : list[str] | None
        List of distribution parameter names required by the RL model.
    extra_fields : list[str] | None
        List of extra field names required by the RL model.

    Returns
    -------
    Callable
        A function that computes drift rates (v) for all subjects given their
        trial data and RLDM parameters.
    """
    inputs = subject_wise_func.inputs  # type: ignore[attr-defined]
    # _validate_columns(data_cols, inputs)
    colidxs = _get_column_indices(
        inputs,
        data_cols,
        list_params,
        extra_fields,
    )

    # Vectorized version of subject_wise_func to handle multiple subjects.
    subject_wise_vmapped = jax.vmap(subject_wise_func, in_axes=0)

    def logp(data, *args) -> np.ndarray:
Collaborator

this function name should be updated since it is not computing the logp?

Member

@cpaniaguam this is basically just the update rule for the target parameter of the reinforcement learning process (v in our example model, but again, doesn't have to be that).

Collaborator Author

@krishnbera @AlexanderFengler Thanks for pointing this out. We intentionally removed a couple of steps from the true logp computation which we intend to reintroduce in a different function.
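
To illustrate what the function currently does instead of a full logp, here is a list-based sketch of the column lookup and collection steps. It mirrors `_get_column_indices` and `_collect_cols_arrays` from this PR but uses plain lists in place of arrays; the names `get_column_indices` and `collect_columns` and all sample values are stand-ins for illustration.

```python
def get_column_indices(cols_to_look_up, data_cols, list_params=None,
                       extra_fields=None):
    """Map each requested column to ("data", idx) or ("args", idx)."""
    args_cols = (list_params or []) + (extra_fields or [])
    colidxs = {}
    for col in cols_to_look_up:
        if col in data_cols:
            colidxs[col] = ("data", data_cols.index(col))
        elif col in args_cols:
            colidxs[col] = ("args", args_cols.index(col))
        else:
            raise ValueError(f"Column '{col}' not found.")
    return colidxs


def collect_columns(data_rows, args, colidxs):
    """Pull each requested column from the data rows or positional args."""
    collected = []
    for source, idx in colidxs.values():
        if source == "data":
            collected.append([row[idx] for row in data_rows])
        else:
            collected.append(list(args[idx]))
    return collected


# Made-up inputs: two trials, data columns (rt, response).
colidxs = get_column_indices(
    ["rl_alpha", "scaler", "response", "feedback"],
    data_cols=["rt", "response"],
    list_params=["rl_alpha", "scaler", "a"],
    extra_fields=["feedback"],
)
columns = collect_columns(
    [[0.5, 1], [0.7, 0]],
    ([0.1, 0.1], [2.0, 2.0], [1.5, 1.5], [1.0, 0.0]),
    colidxs,
)
```

Note that indices into `args` count across `list_params` followed by `extra_fields`, so `feedback` resolves to position 3 here even though it is the first extra field.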

        """Compute the drift rates (v) for each trial in a reinforcement learning model.

        Parameters
        ----------
        data : np.ndarray
            A 2D array containing trial data.
        *args
            Model parameters included in list_params and extra_fields.

        Notes
        -----
        - The function internally reshapes the input data to group trials by
          participant and applies a vectorized mapping function to compute drift
          rates.
        - The function assumes that `n_participants`, `n_trials`, `colidxs`, and
          `subject_wise_vmapped` are defined in the surrounding scope.

        Returns
        -------
        np.ndarray
            The computed drift rates for each trial, reshaped as a 2D array.
        """
        # Stack the collected columns and reshape them into a 3D array of shape
        # (n_participants, n_trials, n_inputs)
        # so we can act on this object with the vmapped version of the mapping function
        _data = _collect_cols_arrays(data, args, colidxs)

        subj_trials = jnp.stack(_data, axis=1).reshape(n_participants, n_trials, -1)

        drift_rates = subject_wise_vmapped(subj_trials).reshape((-1, 1))
        return drift_rates

        # TODO: reintroduce workflow using a jax function and handling the selection
        # dist_params to stack
        # create parameter arrays to be passed to the likelihood function
        # ddm_params_matrix = jnp.stack(dist_params[2:6], axis=1)
        # lan_matrix = jnp.concatenate((v, ddm_params_matrix, data), axis=1)
        # return _logp_jax_func(lan_matrix)

    return logp


# TODO[CP]: Adapt this function given the changes to make_rl_logp_func
# pragma: no cover
def make_rldm_logp_op(
    subject_wise_func: Callable[..., Any],
    n_participants: int,
    n_trials: int,
    n_params: int,
) -> Op:
    """Create a pytensor Op for the likelihood function of RLDM model.

    Parameters
    ----------
    subject_wise_func : Callable
        Function that computes drift rates for a subject's trials.
    n_participants : int
        The number of participants in the dataset.
    n_trials : int
        The number of trials per participant.
    n_params : int
        The number of parameters passed to the likelihood function.

    Returns
    -------
    Op
        A pytensor Op that computes the log likelihood for the RLDM model.
    """
    logp = make_rl_logp_func(subject_wise_func, n_participants, n_trials)
Collaborator

Is this function meant to output the logp or drifts (v) only? I am assuming logp, right?

Member

@cpaniaguam / @krishnbera let's put a meeting on the calendar for monday to discuss this point? ( @digicosmos86 also if you are back).

Collaborator Author

Indeed, we are aware this is computing drift rates only and some renaming of variables is necessary. The main goal of this PR is to build the abstraction and then ingest the drifts to get logp in a separate step.

Collaborator

@AlexanderFengler @krishnbera this is still WIP so none of the code you see is in a final reviewable state. @cpaniaguam can you change this PR to a draft so we can iterate more on it?
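
As a rough sketch of the two-step design described in this thread (compute the learned parameter first, then feed it to a likelihood), the composition could look like the following. Every name here (`make_two_step_logp`, `toy_compute_param`, `toy_likelihood`) is hypothetical, and the toy functions only stand in for the drift-rate function and the LAN-based likelihood.

```python
from typing import Callable


def make_two_step_logp(compute_param: Callable, likelihood: Callable) -> Callable:
    """Compose a parameter-computing step with a likelihood step."""

    def logp(data, *args):
        # Step 1: derive the trial-wise parameter (drift rates in this PR).
        computed = compute_param(data, *args)
        # Step 2: hand the computed values plus the inputs to the likelihood.
        return likelihood(computed, data, *args)

    return logp


# Toy stand-ins used only to exercise the composition.
def toy_compute_param(data, scale):
    return [scale * x for x in data]


def toy_likelihood(computed, data, scale):
    return sum(computed)


toy_logp = make_two_step_logp(toy_compute_param, toy_likelihood)
result = toy_logp([1, 2, 3], 2)
```

Keeping the two steps as separate callables would let the first one be renamed and tested as a parameter update rule while the second carries the actual log-likelihood.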

    vjp_logp = make_vjp_func(logp, params_only=False, n_params=n_params)

    return make_jax_logp_ops(
        logp=jax.jit(logp),
        logp_vjp=jax.jit(vjp_logp),
        logp_nojit=logp,
        n_params=n_params,  # rl_alpha, scaler, a, z, t, theta
    )