Add abstraction to RL likelihood functions #824
base: main
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,291 @@ | ||
| """The log-likelihood function for the RLDM model.""" | ||
|
|
||
| import functools | ||
| from typing import Any, Callable | ||
|
|
||
| import jax | ||
| import jax.numpy as jnp | ||
| import numpy as np | ||
| from jax.lax import scan | ||
| from pytensor.graph import Op | ||
|
|
||
| from hssm.distribution_utils.func_utils import make_vjp_func | ||
|
|
||
| from ..distribution_utils.jax import make_jax_logp_ops | ||
| from ..distribution_utils.onnx import make_jax_matrix_logp_funcs_from_onnx | ||
|
|
||
| # Obtain the angle log-likelihood function from an ONNX model. | ||
| angle_logp_jax_func = make_jax_matrix_logp_funcs_from_onnx( | ||
|
Collaborator
Loading the angle LAN was hardcoded since the existing likelihood used angle. But we will need to make this generic (based on the decision process defined in the rlssm model_config) and load the correct LAN based on what the user wants.
Collaborator
Author
Thanks @krishnbera. This is another element to abstract as well. |
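One way to make the LAN choice generic, as suggested in the thread above, is a small lookup keyed by the decision process name. The sketch below is only an illustration: `LAN_FILES`, `load_lan_logp`, and the `loader` argument are hypothetical names that do not appear in the PR, and only the `angle` to `angle.onnx` entry is taken from the diff. In the real code the loader would presumably be `make_jax_matrix_logp_funcs_from_onnx`.

```python
from typing import Callable

# Hypothetical registry: decision-process name -> ONNX file.
# Only the "angle" entry comes from the diff; others would be added as needed.
LAN_FILES: dict[str, str] = {"angle": "angle.onnx"}


def load_lan_logp(decision_process: str, loader: Callable[..., Callable]) -> Callable:
    """Return the likelihood function for the requested decision process.

    `loader` is injected so the sketch stays self-contained; it is assumed to
    accept a `model` keyword argument, as the call in the diff does.
    """
    try:
        model_file = LAN_FILES[decision_process]
    except KeyError:
        known = ", ".join(sorted(LAN_FILES))
        raise ValueError(
            f"No LAN registered for '{decision_process}'. Known: {known}."
        ) from None
    return loader(model=model_file)
```

Failing early on an unknown name keeps a misconfigured `model_config` from silently falling back to the angle model.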
||
| model="angle.onnx", | ||
| ) | ||
|
|
||
|
|
||
| def annotate_function(**kwargs): | ||
|
Member
This has parallels with the
Collaborator
Author
The
    model_name: str,
    choices: list | np.ndarray | None = None,
    obs_dim: int = 2,  # At least for now ssms models all fall under 2 obs dims
We could refactor
    def decorate_atomic_simulator(
        model_name: str,
        choices: list | np.ndarray | None = None,
        obs_dim: int = 2,  # At least for now ssms models all fall under 2 obs dims
    ):
        choices = [-1, 1] if choices is None else choices
        return annotate_function(
            model_name=model_name,
            choices=choices,
            obs_dim=obs_dim,
        )
Collaborator
Author
And this change should probably be introduced in a separate PR. |
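For reference, a minimal sketch of how the decorator in this diff behaves when applied. The body of `annotate_function` is copied from the diff; the decorated function `compute_v` and the `outputs` attribute are made up for illustration, while the `inputs` names follow the docstrings (rl_alpha, scaler, response, feedback).

```python
import functools
from typing import Callable


def annotate_function(**kwargs):
    """Attach arbitrary metadata as attributes to a function (as in the diff)."""

    def decorator(func: Callable) -> Callable:
        @functools.wraps(func)
        def wrapper(*args, **inner_kwargs):
            return func(*args, **inner_kwargs)

        for key, value in kwargs.items():
            setattr(wrapper, key, value)
        return wrapper

    return decorator


# Hypothetical usage: `outputs` is an illustrative attribute, not from the PR.
@annotate_function(
    inputs=["rl_alpha", "scaler", "response", "feedback"],
    outputs=["v"],
)
def compute_v(x):
    return x * 2


print(compute_v.inputs)    # metadata is readable as plain attributes
print(compute_v.__name__)  # functools.wraps preserves the original name
```

This is the mechanism `make_rl_logp_func` relies on when it reads `subject_wise_func.inputs`.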
||
| """Attach arbitrary metadata as attributes to a function. | ||
|
|
||
| Parameters | ||
| ---------- | ||
| **kwargs | ||
| Arbitrary keyword arguments to attach as attributes. | ||
|
|
||
| Returns | ||
| ------- | ||
| Callable | ||
| Decorator that adds metadata attributes to the wrapped function. | ||
| """ | ||
|
|
||
| def decorator(func: Callable) -> Callable: | ||
| @functools.wraps(func) | ||
| def wrapper(*args, **inner_kwargs): | ||
| return func(*args, **inner_kwargs) | ||
|
|
||
| for key, value in kwargs.items(): | ||
| setattr(wrapper, key, value) | ||
| return wrapper | ||
|
|
||
| return decorator | ||
|
|
||
|
|
||
| # Inner function to compute the drift rate and update q-values for each trial. | ||
| # This function is used with `jax.lax.scan` to process each trial in the RLDM model. | ||
| def compute_v_trial_wise( | ||
|
Member
similar to the The learning rule could be applied to any parameter, we should use a name for it that just reflects the type of updating we are doing, regardless of which parameter it is applied to. |
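A minimal sketch of what a parameter-agnostic version of the trial-wise step could look like. The arithmetic is the same as in `compute_v_trial_wise`; the names `delta_rule_step`, `learning_rate`, and `target` are hypothetical and only illustrate naming the function after the update rule rather than after the drift rate.

```python
import jax.numpy as jnp
from jax.lax import scan


def delta_rule_step(q_val, inputs):
    """One trial of a delta-rule update (hypothetical, parameter-agnostic name)."""
    learning_rate, scaler, action, reward = inputs
    action = action.astype(jnp.int32)
    # Value handed to the decision process, computed BEFORE the update.
    target = (q_val[1] - q_val[0]) * scaler
    # Move the chosen option's value toward the observed reward.
    prediction_error = reward - q_val[action]
    q_val = q_val.at[action].add(learning_rate * prediction_error)
    return q_val, target


# Two trials; columns are (learning_rate, scaler, action, reward).
trials = jnp.array([[0.5, 2.0, 1.0, 1.0], [0.5, 2.0, 0.0, 0.0]])
q_final, targets = scan(delta_rule_step, jnp.ones(2) * 0.5, trials)
```

With these inputs the first trial yields a target of 0 (both q-values start at 0.5) and raises the chosen q-value to 0.75, so the second trial yields (0.75 - 0.5) * 2 = 0.5 and the final q-values are [0.25, 0.75].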
||
| q_val: jnp.ndarray, inputs: jnp.ndarray | ||
| ) -> tuple[jnp.ndarray, jnp.ndarray]: | ||
| """Compute the drift rate and update the q-values for each trial. | ||
|
|
||
| This function is used with `jax.lax.scan` to process each trial. It takes the | ||
| current q-values and the RL parameters (rl_alpha, scaler), action (response), | ||
| and reward (feedback) for the current trial, computes the drift rate, and | ||
| updates the q-values. The q_values are updated in each iteration and carried | ||
| forward to the next one. | ||
|
|
||
| Parameters | ||
| ---------- | ||
| q_val | ||
| A length-2 jnp array containing the current q-values for the two alternatives. | ||
| These values are updated in each iteration and carried forward to the next | ||
| trial. | ||
| inputs | ||
| A 1D jnp array of length 4 containing the RL parameters (rl_alpha, scaler), | ||
| action (response), and reward (feedback) for the current trial. | ||
|
|
||
| Returns | ||
| ------- | ||
| tuple | ||
| A tuple containing the updated q-values and the computed drift rate (v). | ||
| """ | ||
| rl_alpha, scaler, action, reward = inputs | ||
| action = jnp.astype(action, jnp.int32) | ||
|
|
||
| # drift rate on each trial depends on difference in expected rewards for | ||
| # the two alternatives: | ||
| # drift rate = (q_up - q_low) * scaler where | ||
| # the scaler parameter describes the weight to put on the difference in | ||
| # q-values. | ||
| computed_v = (q_val[1] - q_val[0]) * scaler | ||
|
|
||
| # compute the reward prediction error | ||
| delta_RL = reward - q_val[action] | ||
|
|
||
| # update the q-values using the RL learning rule (here, simple TD rule) | ||
| q_val = q_val.at[action].set(q_val[action] + rl_alpha * delta_RL) | ||
|
|
||
| return q_val, computed_v | ||
|
|
||
|
|
||
| # This function computes the drift rates (v) for each subject by processing | ||
| # their trials one by one. It uses `jax.lax.scan` to efficiently iterate over | ||
| # the trials and compute the drift rates based on the RL parameters, actions, | ||
| # and rewards for each trial. | ||
| def compute_v_subject_wise( | ||
|
Member
would take out |
||
| subj_trials: jnp.ndarray, | ||
| ) -> jnp.ndarray: | ||
| """Compute the drift rates (v) for a given subject. | ||
|
|
||
| Parameters | ||
| ---------- | ||
| subj_trials: | ||
| A jnp array of dimension (n_trials, 4) containing rl_alpha, scaler, | ||
| action (response), and reward (feedback) for each trial of the subject. | ||
|
|
||
| Returns | ||
| ------- | ||
| jnp.ndarray | ||
| The computed drift rates (v) for the RLDM model for the given subject. | ||
| """ | ||
| _, v = scan( | ||
| compute_v_trial_wise, | ||
| jnp.ones(2) * 0.5, # initial q-values for the two alternatives | ||
| subj_trials, | ||
| ) | ||
|
|
||
| return v | ||
|
|
||
|
|
||
| def _get_column_indices( | ||
| cols_to_look_up: list[str], | ||
| data_cols: list[str], | ||
| list_params: list[str] | None, | ||
| extra_fields: list[str] | None, | ||
| ) -> dict[str, tuple[str, int]]: | ||
| """Return indices for required columns. | ||
|
|
||
| Parameters | ||
| ---------- | ||
| cols_to_look_up : list[str] | ||
| Columns to find indices for | ||
| data_cols : list[str] | ||
| Available data columns | ||
| list_params : list[str] | None | ||
| Available list parameters | ||
| extra_fields : list[str] | None | ||
| Available extra fields | ||
|
|
||
| Returns | ||
| ------- | ||
| dict[str, tuple[str, int]] | ||
| Mapping of column names to (source, index) tuples | ||
| """ | ||
| list_params = list_params or [] | ||
| extra_fields = extra_fields or [] | ||
| list_params_extra_fields = list_params + extra_fields | ||
| colidxs = {} | ||
| for col in cols_to_look_up: | ||
| if col in data_cols: | ||
| colidxs[col] = ("data", data_cols.index(col)) | ||
| elif col in list_params_extra_fields: | ||
| colidxs[col] = ("args", list_params_extra_fields.index(col)) | ||
| else: | ||
| raise ValueError( | ||
| f"Column '{col}' not found in any of `data`, `list_params`, " | ||
| f"or `extra_fields`." | ||
| ) | ||
| return colidxs | ||
|
|
||
|
|
||
| def _collect_cols_arrays(data, _args, colidxs): | ||
| collected = [] | ||
| for col in colidxs: | ||
| source, idx = colidxs[col] | ||
| if source == "data": | ||
| collected.append(data[:, idx]) | ||
| else: | ||
| collected.append(_args[idx]) | ||
| return collected | ||
|
|
||
|
|
||
| def make_rl_logp_func( | ||
| subject_wise_func: Callable[..., Any], | ||
| n_participants: int, | ||
| n_trials: int, | ||
| data_cols: list[str] = ["rt", "response"], | ||
| list_params: list[str] | None = None, | ||
| extra_fields: list[str] | None = None, | ||
| ) -> Callable: | ||
| """Create a function to compute the drift rates (v) for the RLDM model. | ||
|
|
||
| Parameters | ||
| ---------- | ||
| subject_wise_func : Callable | ||
| Function that computes drift rates for a subject's trials. | ||
| n_participants : int | ||
| Number of participants in the dataset. | ||
| n_trials : int | ||
| Number of trials per participant. | ||
| data_cols : list[str] | None | ||
| List of column names in the data array. | ||
| list_params : list[str] | None | ||
| List of distribution parameter names required by the RL model. | ||
| extra_fields : list[str] | None | ||
| List of extra field names required by the RL model. | ||
|
|
||
| Returns | ||
| ------- | ||
| Callable | ||
| A function that computes drift rates (v) for all subjects given their trial data | ||
| and RLDM parameters. | ||
| """ | ||
| inputs = subject_wise_func.inputs # type: ignore[attr-defined] | ||
| # _validate_columns(data_cols, inputs) | ||
| colidxs = _get_column_indices( | ||
| inputs, | ||
| data_cols, | ||
| list_params, | ||
| extra_fields, | ||
| ) | ||
|
|
||
| # Vectorized version of subject_wise_func to handle multiple subjects. | ||
| subject_wise_vmapped = jax.vmap(subject_wise_func, in_axes=0) | ||
|
|
||
| def logp(data, *args) -> np.ndarray: | ||
|
Collaborator
This function name should be updated since it is not computing the logp?
Member
@cpaniaguam this is basically just the update rule for the target parameter of the reinforcement learning process (
Collaborator
Author
@krishnbera @AlexanderFengler Thanks for pointing this out. We intentionally removed a couple of steps from the true logp computation which we intend to reintroduce in a different function. |
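To make concrete what the inner function currently does instead of a logp: it stacks flat per-trial columns, reshapes them to group trials by participant, and maps a per-subject function over the first axis. The sketch below uses a stand-in per-subject function (`running_total`, a cumulative sum, not the RL update) purely to show that state restarts at each participant boundary. It assumes, as the reshape in the diff does, that rows are ordered by participant and every participant has the same number of trials.

```python
import jax
import jax.numpy as jnp

n_participants, n_trials = 2, 3

# Flat per-trial columns of length n_participants * n_trials (toy values).
x = jnp.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
w = jnp.array([1.0, 1.0, 1.0, 2.0, 2.0, 2.0])


def running_total(subj_trials):
    """Stand-in per-subject function: cumulative sum of x * w over trials."""
    return jnp.cumsum(subj_trials[:, 0] * subj_trials[:, 1])


flat = jnp.stack([x, w], axis=1)                          # (6, 2)
subj_trials = flat.reshape(n_participants, n_trials, -1)  # (2, 3, 2)
per_trial = jax.vmap(running_total, in_axes=0)(subj_trials).reshape((-1, 1))
```

Here `per_trial` has shape (6, 1) and the running total restarts at the fourth row, which is the first trial of the second participant.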
||
| """Compute the drift rates (v) for each trial in a reinforcement learning model. | ||
|
|
||
| data : np.ndarray | ||
| A 2D array containing trial data. | ||
|
|
||
| args: Model parameters included in list_params and extra_fields. | ||
|
|
||
| Notes | ||
| ----- | ||
| - The function internally reshapes the input data to group trials by | ||
| participant and applies a vectorized mapping function to compute drift | ||
| rates. | ||
| - The function assumes that `n_participants`, `n_trials`, `colidxs`, and | ||
| `subject_wise_vmapped` are defined in the surrounding scope. | ||
|
|
||
| Returns | ||
| ------- | ||
| np.ndarray | ||
| The computed drift rates for each trial, reshaped as a 2D array. | ||
| """ | ||
| # Stack the collected columns and reshape them into a 3D array of shape | ||
| # (n_participants, n_trials, n_columns), one column per entry in `colidxs`, | ||
| # so we can act on this object with the vmapped version of the mapping function | ||
| _data = _collect_cols_arrays(data, args, colidxs) | ||
|
|
||
| subj_trials = jnp.stack(_data, axis=1).reshape(n_participants, n_trials, -1) | ||
|
|
||
| drift_rates = subject_wise_vmapped(subj_trials).reshape((-1, 1)) | ||
| return drift_rates | ||
|
|
||
| # TODO: reintroduce workflow using a jax function and handling the selection | ||
| # dist_params to stack | ||
| # create parameter arrays to be passed to the likelihood function | ||
| # ddm_params_matrix = jnp.stack(dist_params[2:6], axis=1) | ||
| # lan_matrix = jnp.concatenate((v, ddm_params_matrix, data), axis=1) | ||
| # return _logp_jax_func(lan_matrix) | ||
|
|
||
| return logp | ||
|
|
||
|
|
||
| # TODO[CP]: Adapt this function given the changes to make_rl_logp_func | ||
| # pragma: no cover | ||
| def make_rldm_logp_op( | ||
| subject_wise_func: Callable[..., Any], | ||
| n_participants: int, | ||
| n_trials: int, | ||
| n_params: int, | ||
| ) -> Op: | ||
| """Create a pytensor Op for the likelihood function of RLDM model. | ||
|
|
||
| Parameters | ||
| ---------- | ||
| n_participants : int | ||
| The number of participants in the dataset. | ||
| n_trials : int | ||
| The number of trials per participant. | ||
|
|
||
| Returns | ||
| ------- | ||
| Op | ||
| A pytensor Op that computes the log likelihood for the RLDM model. | ||
| """ | ||
| logp = make_rl_logp_func(subject_wise_func, n_participants, n_trials) | ||
|
Collaborator
Is this function meant to output the logp or drifts (v) only? I am assuming logp, right?
Member
@cpaniaguam / @krishnbera let's put a meeting on the calendar for Monday to discuss this point? (@digicosmos86 also if you are back).
Collaborator
Author
Indeed, we are aware this is computing drift rates only and some renaming of variables is necessary. The main goal of this PR is to build the abstraction and then ingest the drifts to get logp in a separate step.
Collaborator
@AlexanderFengler @krishnbera this is still WIP so none of the code you see is in a final reviewable state. @cpaniaguam can you change this PR to a draft so we can iterate more on it? |
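A rough sketch of what the separate "ingest the drifts to get logp" step might look like, following the commented-out TODO lines in `make_rl_logp_func`. Everything here is an assumption: `logp_from_drifts` is a hypothetical name, and `toy_likelihood` is a stand-in that is not a real LAN. Only the column order (v, then the remaining decision parameters, then the data) is taken from the TODO.

```python
import jax.numpy as jnp


def logp_from_drifts(v, ddm_params, data, likelihood):
    """Assemble the likelihood input matrix from drifts and evaluate it.

    v: (N, 1) drift rates from the RL step.
    ddm_params: list of length-N arrays (e.g. a, z, t, theta).
    data: (N, 2) array of observed (rt, response).
    likelihood: callable taking the (N, 1 + k + 2) matrix.
    """
    ddm_params_matrix = jnp.stack(ddm_params, axis=1)  # (N, k)
    lan_matrix = jnp.concatenate((v, ddm_params_matrix, data), axis=1)
    return likelihood(lan_matrix)


def toy_likelihood(matrix):
    """Stand-in for a LAN: NOT a real likelihood, just a row-wise reduction."""
    return -jnp.sum(matrix**2, axis=1)


n = 4
v = jnp.zeros((n, 1))
ddm_params = [jnp.ones(n) for _ in range(4)]  # stand-ins for a, z, t, theta
data = jnp.ones((n, 2))                        # stand-ins for rt, response
out = logp_from_drifts(v, ddm_params, data, toy_likelihood)
```

Passing the likelihood in as an argument keeps the drift computation independent of which LAN is eventually loaded.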
||
| vjp_logp = make_vjp_func(logp, params_only=False, n_params=n_params) | ||
|
|
||
| return make_jax_logp_ops( | ||
| logp=jax.jit(logp), | ||
| logp_vjp=jax.jit(vjp_logp), | ||
| logp_nojit=logp, | ||
| n_params=n_params, # rl_alpha, scaler, a, z, t, theta | ||
| ) | ||