
Conversation

@cpaniaguam
Collaborator

This pull request introduces a new log-likelihood implementation for the RLDM model, including its JAX-based abstraction and associated unit tests. The main focus is on improving the computation of trial-wise drift rates and integrating the abstraction into the broader model likelihood system. The changes are grouped into new likelihood implementation and testing.

New RLDM likelihood implementation:

  • Added src/hssm/likelihoods/rldm_optimized_abstraction.py with JAX-based functions for computing RLDM drift rates and log-likelihoods, including compute_v_trial_wise, compute_v_subject_wise, and factory methods for log-likelihood functions and PyTensor Ops.

Testing and validation:

  • Added tests/test_rldm_likelihood_abstraction.py with unit tests for the new RLDM log-likelihood abstraction, verifying correct output shapes and numerical accuracy against expected values.

@cpaniaguam cpaniaguam linked an issue Oct 9, 2025 that may be closed by this pull request
6 tasks
Comment on lines 161 to 162
# TODO[CP]: Not really sure how to adapt this function given the changes to make_rldm_logp_func
def make_rldm_logp_op(
Collaborator Author

@digicosmos86 Due to the scope reduction, should the rest of the workflow be added here?

Collaborator

There should be another function in-between. Let's focus on make_rl_logp_func

Collaborator

@digicosmos86 digicosmos86 left a comment

So far so good! Just one observation: rldm is just one RLSSM model (you'll see more in rldm.py), so in terms of naming we should be a bit more general

Comment on lines 143 to 146
action = data[:, 1]
rl_alpha = dist_params[0]
scaler = dist_params[1]
feedback = dist_params[-1]
Collaborator

Here you'd be just stacking data and dist_params together to make a matrix

Collaborator Author

Right, that's the next and more delicate change. I wanted to ensure everything was in place before tackling this.

Collaborator Author

@digicosmos86 It seems like the order in which these are coming in might be important given how they are passed to jnp.stack. What's the most intuitive order? Perhaps require (rl_alpha, scaler, action, feedback) as in the current implementation?

        subj_trials = jnp.stack((rl_alpha, scaler, action, feedback), axis=1).reshape(
            n_participants, n_trials, -1
        )

@codecov

codecov bot commented Oct 9, 2025

Codecov Report

❌ Patch coverage is 93.98496% with 8 lines in your changes missing coverage. Please review.

Files with missing lines                              Patch %   Lines
src/hssm/likelihoods/rldm_optimized_abstraction.py    88.33%    7 Missing ⚠️
tests/test_rldm_likelihood_abstraction.py             98.63%    1 Missing ⚠️

Files with missing lines                              Coverage Δ
tests/test_rldm_likelihood_abstraction.py             98.63% <98.63%> (ø)
src/hssm/likelihoods/rldm_optimized_abstraction.py    88.33% <88.33%> (ø)

... and 14 files with indirect coverage changes

Comment on lines 144 to 147
rl_alpha = dist_params[0]
scaler = dist_params[1]
action = data[:, 1]
feedback = dist_params[-1]
Collaborator

I have some ideas about parameter lookup in dist_params for different functions. For now, let's work on the most general case: this function just stacks data and every vector in dist_params into a single matrix. Is this enough info to unblock you for now?

@digicosmos86
Collaborator

@cpaniaguam @AlexanderFengler @krishnbera: about parameter lookup that we discussed today, here's the idea:

The inner function that make_rl_logp_func accepts will have this signature: logp_subjwise_inner_func(matrix) where matrix is a 2-D matrix with a subject-wise slice of data. We can make a function decorator for the inner function that looks like this:

@rl(inputs=["response", "rl_alpha", "scaler", "feedback"], outputs=["v"])
def logp_subjwise_inner_func(matrix):
    ...

Then this makes it very explicit to everyone what the inputs and outputs of this function are. Inside the function decorator, we just attach the inputs and outputs arguments to the function itself:

def rl(inputs, outputs):
    # Called with arguments, so it must return the actual decorator.
    def decorator(inner_func):
        inner_func.inputs = inputs
        inner_func.outputs = outputs
        return inner_func

    return decorator

Then we can access the information inside the make_rl_logp_func like this:

def make_rl_logp_func(subject_wise, ...):
    subjwise_inputs = subject_wise.inputs
    subjwise_outputs = subject_wise.outputs
    ...

Then we can perform look up based on the information. Potentially, this also enables us to accept multiple inner functions and compose them with the LAN functions, since we know exactly what the outputs are.
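
A minimal, runnable sketch of the idea above, written in plain Python. Because the decorator is applied with arguments (`@rl(inputs=..., outputs=...)`), it is written here as a factory that returns the real decorator. The helper `read_metadata` and the placeholder body of `logp_subjwise_inner_func` are hypothetical and only illustrate how the attached names could be read back; they are not the PR's implementation.

```python
from typing import Callable


def rl(inputs: list[str], outputs: list[str]) -> Callable:
    """Decorator factory that attaches input/output names to a function."""

    def decorator(inner_func: Callable) -> Callable:
        # Store copies so later mutation of the caller's lists has no effect.
        inner_func.inputs = list(inputs)
        inner_func.outputs = list(outputs)
        return inner_func

    return decorator


@rl(inputs=["response", "rl_alpha", "scaler", "feedback"], outputs=["v"])
def logp_subjwise_inner_func(matrix):
    # Placeholder body (assumption): the real function would compute the
    # declared outputs from the columns of `matrix`.
    return matrix


def read_metadata(subject_wise: Callable) -> tuple[list[str], list[str]]:
    """Hypothetical helper: read the names a factory would use for lookup."""
    return subject_wise.inputs, subject_wise.outputs


subjwise_inputs, subjwise_outputs = read_metadata(logp_subjwise_inner_func)
```

Attaching plain attributes keeps the inner function callable exactly as before, so the metadata is purely additive.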

Let me know what you think

@cpaniaguam
Collaborator Author

I like this idea of attaching the metadata to the inner function. I think it's sensible and elegant.

@cpaniaguam cpaniaguam marked this pull request as ready for review October 15, 2025 16:09
@cpaniaguam cpaniaguam self-assigned this Oct 20, 2025
Comment on lines +98 to +127
def _validate_columns(
    data_cols: list[str] | None,
    dist_params: list[str] | None = None,
    extra_fields: list[str] | None = None,
) -> None:
    """Validate that required columns are present.

    Parameters
    ----------
    data_cols
        List of column names available in the data matrix. May be None when
        called from higher-level factory functions before data is fully known.
    dist_params
        Distribution parameter names required by the RL likelihood.
    extra_fields
        Additional field names required by the RL likelihood.
    """
    dist_params = dist_params or []
    extra_fields = extra_fields or []
    if data_cols is None:
        # If data_cols is None but we have required parameters, raise early.
        if dist_params or extra_fields:
            raise ValueError("data_cols is None but required columns were provided.")
        return
    all_cols = [*dist_params, *extra_fields]
    missing_cols = set(all_cols) - set(data_cols)
    if missing_cols:
        raise ValueError(
            f"The following columns are missing from data_cols: {missing_cols}"
        )
Collaborator

@digicosmos86 digicosmos86 Oct 20, 2025

I like the idea of validating that the columns exist, but if we are writing make_rl_logp_func as a pure function, we won't be able to validate the columns because we don't know what the columns are

Collaborator Author

Maybe I'm missing something, but the intent of passing in data_columns is to capture what the columns in the data matrix represent.

Collaborator

I think the misunderstanding is the idea that data_cols = dist_params + extra_fields. Actually, data_cols only means the columns in the data object, which is the only object that is passed in as a matrix

Member

More or less, the following applies:

likelihood(data_cols|dist_params;extra_cols)

So data_cols refers to observables, dist_params are estimable parameters (even if we fix them in a given model), and extra_cols (extra_fields in the code) is a fixed set of extra covariates that affects the internal computation of the likelihood.

In the RL case, extra_fields are the feedback vectors, which are used in the scan to update v (a dist_param).

Collaborator

@digicosmos86 digicosmos86 left a comment

Great! We are definitely on the right track here, but I still see some lingering confusion with the signature of logp() in make_rl_logp_func. As we discussed, PyMC requires it to have this signature logp(data, *dist_params, *extra_fields). See here for how it is called

)


def _get_column_indices(
Collaborator

Please see my next comment. I think this is the right idea, but I think that, rather than just returning the indices, maybe this function should return the columns. The reason is that you are looking across two objects, the data object, and a tuple of args
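
To make the "return the columns, not the indices" suggestion concrete, here is a rough sketch in plain Python. It assumes the call shape `logp(data, *dist_params, *extra_fields)` discussed in this thread, with `data` as a row-major matrix and every extra argument as one full column. The name `get_columns` and the use of lists instead of arrays are assumptions made for illustration only.

```python
from typing import Sequence


def get_columns(
    names: Sequence[str],
    data: Sequence[Sequence[float]],
    data_cols: Sequence[str],
    args: Sequence[Sequence[float]],
    arg_names: Sequence[str],
) -> list[list[float]]:
    """Return one column per requested name, looking in data and in args."""
    lookup: dict[str, list[float]] = {}
    # Columns that live inside the data matrix (a list of rows here).
    for j, col in enumerate(data_cols):
        lookup[col] = [row[j] for row in data]
    # Each positional argument after `data` is already a full column.
    for name, values in zip(arg_names, args):
        lookup[name] = list(values)
    missing = [name for name in names if name not in lookup]
    if missing:
        raise ValueError(f"Unknown column names: {missing}")
    return [lookup[name] for name in names]


data = [[0.5, 1], [0.7, 0]]  # columns: rt, response
columns = get_columns(
    ["response", "rl_alpha", "feedback"],
    data,
    ["rt", "response"],
    ([0.1, 0.1], [1.0, 0.0]),  # rl_alpha, feedback
    ["rl_alpha", "feedback"],
)
```

Returning the columns hides from the caller whether a name came from the data matrix or from the argument tuple, which an index alone cannot express.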

# Vectorized version of subject_wise_func to handle multiple subjects.
subject_wise_vmapped = jax.vmap(subject_wise_func, in_axes=0)

def logp(data) -> np.ndarray:
Collaborator

We can't change the signature of this function, unless we want to have a RL version of make_distribution in distribution_utils. Otherwise, like we discussed, this function will have to have logp(data, *dist_params, *extra_fields) signature

Collaborator Author

I changed it because, when building traces, jax.jit(logp) throws a type error if any argument passed to the function is not an array data type.

Collaborator

It is safe to assume that all arguments passed to this function are arrays. Is it because some values passed to the jitted function in the tests are scalars? They should not be scalars

Collaborator Author

I'm thinking of logp(data: np.array, dist_param: list[str], extra: list[str]). In the constructor:

def make_rl_logp_func(
    subject_wise_func: Callable,
    n_participants: int,
    n_trials: int,
    data_cols: list[str] | None = None, # column names in the data matrix passed to logp
    dist_params: list[str] | None = None, # dist_param column names
    extra_fields: list[str] | None = None, # extra_fields column names
):
    ...

With the last three arguments, we can do the lookup for the indices of the names in dist_params and extra_fields within data_columns. With this, one only needs the data matrix in logp so its signature can be simplified to logp(data) as all the information in args would be available in the scope of make_rl_logp_func.

Collaborator

This is where the misunderstanding is. data_cols contains only the names of the columns in the data matrix (the first argument passed to logp, which typically has 'rt' and 'response' columns). It is not to be confused with the compiled matrix that we pass to the inner function.

n_participants: int,
n_trials: int,
data_cols: list[str] | None = None,
dist_params: list[str] | None = None,
Collaborator

We probably shouldn't have a default as None here, since we are not replacing it within the function to have sensible defaults. Might be a good idea to make it required

Collaborator Author

The problem there is that one can't follow an argument with a default value with a required argument. Their order would need to be changed.

Collaborator

The order is not super important for these arguments here, but if we want to keep the order, we can handle the None inside the function
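
One further option for the ordering question, sketched under assumptions: Python's keyword-only marker `*` allows required parameters to follow a defaulted one, so the order can stay as it is. The function name `make_rl_logp_func_sketch` and its dictionary return value are hypothetical stand-ins, not the PR's function.

```python
from __future__ import annotations

from typing import Callable


def make_rl_logp_func_sketch(
    subject_wise_func: Callable,
    n_participants: int,
    n_trials: int,
    data_cols: list[str] | None = None,
    *,  # parameters after this marker are keyword-only
    dist_params: list[str],
    extra_fields: list[str],
) -> dict:
    """Hypothetical stand-in that only records the names it was given."""
    return {
        "data_cols": data_cols,
        "dist_params": list(dist_params),
        "extra_fields": list(extra_fields),
    }


config = make_rl_logp_func_sketch(
    lambda matrix: matrix,
    2,
    3,
    dist_params=["rl_alpha", "scaler"],
    extra_fields=["feedback"],
)
```

Omitting `dist_params` or `extra_fields` in such a call raises a `TypeError` at call time, so no `None` handling is needed for them inside the function.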

Collaborator Author

There is handling of None defaults in _validate_columns.

from ..distribution_utils.onnx import make_jax_matrix_logp_funcs_from_onnx

# Obtain the angle log-likelihood function from an ONNX model.
angle_logp_jax_func = make_jax_matrix_logp_funcs_from_onnx(
Collaborator

Loading the angle LAN was hardcoded because the existing likelihood used angle. We will need to make this generic, based on the decision process defined in the rlssm model_config, and load the LAN that matches what the user wants.

Collaborator Author

Thanks @krishnbera. This is another element to abstract as well.

# Vectorized version of subject_wise_func to handle multiple subjects.
subject_wise_vmapped = jax.vmap(subject_wise_func, in_axes=0)

def logp(data, *args) -> np.ndarray:
Collaborator

Should this function name be updated, since it is not computing the logp?

Member

@cpaniaguam this is basically just the update rule for the target parameter of the reinforcement learning process (v in our example model, but again, doesn't have to be that).

Collaborator Author

@krishnbera @AlexanderFengler Thanks for pointing this out. We intentionally removed a couple of steps from the true logp computation which we intend to reintroduce in a different function.

Op
A pytensor Op that computes the log likelihood for the RLDM model.
"""
logp = make_rl_logp_func(subject_wise_func, n_participants, n_trials)
Collaborator

Is this function meant to output the logp or drifts (v) only? I am assuming logp, right?

Member

@cpaniaguam / @krishnbera let's put a meeting on the calendar for monday to discuss this point? ( @digicosmos86 also if you are back).

Collaborator Author

Indeed, we are aware this is computing drift rates only and some renaming of variables is necessary. The main goal of this PR is to build the abstraction and then ingest the drifts to get logp in a separate step.

Collaborator

@AlexanderFengler @krishnbera this is still WIP so none of the code you see is in a final reviewable state. @cpaniaguam can you change this PR to a draft so we can iterate more on it?

Member

@AlexanderFengler AlexanderFengler left a comment

Thanks @cpaniaguam, clear progress. For some of the lingering misunderstandings, let's meet again to talk things through a bit more.

It would be awesome if you could make a wire diagram of how you currently see the code pieces fitting together, so that we can work with it during the meeting.

)


def annotate_function(**kwargs):
Member

This has parallels with the decorate_simulator decorator?
Can we reuse functionality (either base that one on this one or vice versa)?

Collaborator Author

The annotate_function decorator is more generic than decorate_atomic_simulator, as the latter attaches specific metadata:

    model_name: str,
    choices: list | np.ndarray | None = None,
    obs_dim: int = 2,  # At least for now ssms models all fall under 2 obs dims

We could refactor decorate_atomic_simulator using a call to annotate_function:

def decorate_atomic_simulator(
    model_name: str,
    choices: list | np.ndarray | None = None,
    obs_dim: int = 2,  # At least for now ssms models all fall under 2 obs dims
):
    choices = [-1, 1] if choices is None else choices

    return annotate_function(
        model_name=model_name,
        choices=choices,
        obs_dim=obs_dim,
    )
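
For context, a minimal sketch of what a generic `annotate_function` could look like so that the refactor above works. Only its signature appears in this PR, so the body below is an assumption: it simply sets each keyword argument as an attribute on the decorated function. The `simulate` function and the model name `"toy_model"` are made up for the example.

```python
from typing import Any, Callable


def annotate_function(**kwargs: Any) -> Callable:
    """Assumed behaviour: attach every keyword argument as an attribute."""

    def decorator(func: Callable) -> Callable:
        for name, value in kwargs.items():
            setattr(func, name, value)
        return func

    return decorator


def decorate_atomic_simulator(model_name, choices=None, obs_dim=2):
    # Same defaults as in the snippet above; delegates to the generic helper.
    choices = [-1, 1] if choices is None else choices
    return annotate_function(
        model_name=model_name,
        choices=choices,
        obs_dim=obs_dim,
    )


@decorate_atomic_simulator(model_name="toy_model")
def simulate():
    # Placeholder simulator used only to show the attached metadata.
    return None
```

With this shape, the specialised decorator only decides which metadata to attach and with which defaults, while the attaching itself lives in one place.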

Collaborator Author

And this change should probably be introduced in a separate PR.


# Inner function to compute the drift rate and update q-values for each trial.
# This function is used with `jax.lax.scan` to process each trial in the RLDM model.
def compute_v_trial_wise(
Member

Similar to angle in the other name, we shouldn't hardcode _v_ in the function name here.

The learning rule could be applied to any parameter, we should use a name for it that just reflects the type of updating we are doing, regardless of which parameter it is applied to.
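
A rough illustration of that naming point, in plain Python. The update function below is named after the kind of update it performs and says nothing about which model parameter is later derived from its state. The delta rule shown is an assumed example of a learning rule, not a statement of what the PR computes, and the plain loop in `scan_one_subject` stands in for `jax.lax.scan`. Both function names are hypothetical.

```python
from typing import Callable, Sequence


def delta_rule(q: Sequence[float], trial: tuple[int, float], alpha: float) -> list[float]:
    """One step: move the chosen option's value toward the feedback."""
    action, feedback = trial
    updated = list(q)
    updated[action] = updated[action] + alpha * (feedback - updated[action])
    return updated


def scan_one_subject(
    update: Callable, init: Sequence[float], trials: Sequence[tuple[int, float]], **params: float
) -> tuple[list[float], list[list[float]]]:
    """Apply `update` trial by trial for a single subject."""
    state = list(init)
    history = []
    for trial in trials:
        # Record the state available on this trial, before its feedback.
        history.append(list(state))
        state = update(state, trial, **params)
    return state, history


trials = [(0, 1.0), (1, 0.0), (0, 1.0)]  # (action, feedback) pairs
final_q, q_history = scan_one_subject(delta_rule, [0.5, 0.5], trials, alpha=0.5)
```

Whatever consumes `q_history` can then map it to v or to any other target parameter, so the parameter name never has to appear in the update function.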

# their trials one by one. It uses `jax.lax.scan` to efficiently iterate over
# the trials and compute the drift rates based on the RL parameters, actions,
# and rewards for each trial.
def compute_v_subject_wise(
Member

I would take _v_ out of the name here as well. Also, instead of subject_wise (which suggests that the function runs over every subject, one subject at a time), I would use _one_subject.

When data_cols is None, return an empty list so that callers can defer
indexing until data is available.
"""
if data_cols is None:
Member

data_cols should never be None or am I missing something?

Collaborator Author

Good eye! _get_column_indices is called after a validation check on data_cols. Thanks for catching that!

@cpaniaguam cpaniaguam marked this pull request as draft November 3, 2025 20:49
Development

Successfully merging this pull request may close these issues.

Create decorator/wrapper for make_rldm_logp_func

5 participants