Add abstraction to RL likelihood functions #824
Conversation
…n make_rldm_logp_func
```python
# TODO[CP]: Not really sure how to adapt this function given the changes to make_rldm_logp_func
def make_rldm_logp_op(
```
@digicosmos86 Due to the scope reduction, should the rest of the workflow be added here?
There should be another function in-between. Let's focus on make_rl_logp_func
So far so good! Just one observation: rldm is just one RLSSM model (you'll see more in rldm.py), so in terms of naming we should be a bit more general
```python
action = data[:, 1]
rl_alpha = dist_params[0]
scaler = dist_params[1]
feedback = dist_params[-1]
```
Here you'd be just stacking data and dist_params together to make a matrix
Right, that's the next and more delicate change. I wanted to ensure everything was in place before tackling this.
@digicosmos86 It seems like the order in which these are coming in might be important given how they are passed to jnp.stack. What's the most intuitive order? Perhaps require (rl_alpha, scaler, action, feedback) as in the current implementation?
```python
subj_trials = jnp.stack((rl_alpha, scaler, action, feedback), axis=1).reshape(
    n_participants, n_trials, -1
)
```
… make_rldm_logp_op for clarity
```python
rl_alpha = dist_params[0]
scaler = dist_params[1]
action = data[:, 1]
feedback = dist_params[-1]
```
I have some ideas about parameter lookup in dist_params for different functions. For now, let's work on the most general case: this function just stacks data and every vector in dist_params and makes a matrix. Is this enough info to unblock you for now?
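A minimal sketch of that general case (hypothetical helper name; assumes every data column and every dist_params vector has length n_participants * n_trials):

```python
import jax.numpy as jnp

def stack_inputs(data, dist_params, n_participants, n_trials):
    # Stack every data column and every dist_params vector into one
    # (n_participants, n_trials, n_cols) matrix for the inner function.
    columns = [data[:, i] for i in range(data.shape[1])] + list(dist_params)
    return jnp.stack(columns, axis=1).reshape(n_participants, n_trials, -1)
```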
@cpaniaguam @AlexanderFengler @krishnbera: about parameter lookup that we discussed today, here's the idea: the inner function would be decorated like this:

```python
@rl(inputs=["response", "rl_alpha", "scaler", "feedback"], outputs=["v"])
def logp_subjwise_inner_func(matrix):
    ...
```

Then this makes it very explicit to everyone what the inputs and outputs of this function are. Inside the function decorator, we just attach the inputs and outputs arguments to the function itself:

```python
def rl(inputs, outputs):
    def decorator(inner_func):
        inner_func.inputs = inputs
        inner_func.outputs = outputs
        return inner_func
    return decorator
```

Then we can access the information inside the factory function:

```python
def make_rl_logp_func(subject_wise, ...):
    subjwise_inputs = subject_wise.inputs
    subjwise_outputs = subject_wise.outputs
    ...
```

Then we can perform lookup based on that information. Potentially, this also enables us to accept multiple inner functions and compose them with the LAN functions, since we know exactly what the outputs are. Let me know what you think.
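For illustration, a rough sketch of how that lookup could work inside make_rl_logp_func (hypothetical names; assumes all_cols lists the column names of the compiled matrix in order):

```python
# Map the annotated input names to column indices of the compiled matrix,
# then slice out exactly the columns the inner function declares as inputs.
col_idx = {name: i for i, name in enumerate(all_cols)}
input_indices = [col_idx[name] for name in subject_wise.inputs]
inner_matrix = matrix[:, :, input_indices]  # (n_participants, n_trials, n_inputs)
```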
    
…nts instead of fixed parameters
          
I like this idea of attaching the metadata to the inner function. I think it's sensible and elegant.
…lumns, distribution parameters, and extra fields
…del parameters, enhancing drift rate computation for multiple participants.
…d data handling and validation
… error handling for missing required columns
```python
def _validate_columns(
    data_cols: list[str] | None,
    dist_params: list[str] | None = None,
    extra_fields: list[str] | None = None,
) -> None:
    """Validate that required columns are present.

    Parameters
    ----------
    data_cols
        List of column names available in the data matrix. May be None when
        called from higher-level factory functions before data is fully known.
    dist_params
        Distribution parameter names required by the RL likelihood.
    extra_fields
        Additional field names required by the RL likelihood.
    """
    dist_params = dist_params or []
    extra_fields = extra_fields or []
    if data_cols is None:
        # If data_cols is None but we have required parameters, raise early.
        if dist_params or extra_fields:
            raise ValueError("data_cols is None but required columns were provided.")
        return
    all_cols = [*dist_params, *extra_fields]
    missing_cols = set(all_cols) - set(data_cols)
    if missing_cols:
        raise ValueError(
            f"The following columns are missing from data_cols: {missing_cols}"
        )
```
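For illustration, a usage sketch of the validator as written (the column names are just examples):

```python
# Passes: every required name is present in data_cols.
_validate_columns(
    data_cols=["rt", "response", "rl_alpha", "scaler", "feedback"],
    dist_params=["rl_alpha", "scaler"],
    extra_fields=["feedback"],
)

# Raises ValueError: 'rl_alpha' is missing from data_cols.
_validate_columns(data_cols=["rt", "response"], dist_params=["rl_alpha"])
```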
I like the idea of validating that the columns exist, but if we are writing make_rl_logp_func as a pure function, we won't be able to validate the columns because we don't know what the columns are
Maybe I'm missing something, but the intent of passing in data_columns is to capture what the columns in the data matrix represent.
I think the misunderstanding is the idea that data_cols = dist_params + extra_fields. Actually, data_cols only means the columns in the data object, which is the only object that is passed in as a matrix
More or less the following applies:

`likelihood(data_cols | dist_params; extra_fields)`

So data_cols refer to observables, dist_params are estimable parameters (even if we fix them in a given model), and extra_fields is some extra fixed set of covariates that affects the internal computation of the likelihood.
In the RL case, extra_fields are the feedback vectors, which are used in the scan to update v (a dist_param).
Great! We are definitely on the right track here, but I still see some lingering confusion with the signature of logp() in make_rl_logp_func. As we discussed, PyMC requires it to have this signature logp(data, *dist_params, *extra_fields). See here for how it is called
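For concreteness, a sketch of that required shape (the parameter names are illustrative, taken from the example model):

```python
# PyMC calls logp positionally: the data matrix first, then one array per
# distribution parameter, then one array per extra field.
def logp(data, rl_alpha, scaler, feedback):
    # equivalent to logp(data, *dist_params, *extra_fields)
    ...
```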
```python
def _get_column_indices(
```
Please see my next comment. I think this is the right idea, but rather than just returning the indices, maybe this function should return the columns themselves. The reason is that you are looking across two objects: the data object and a tuple of args.
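A sketch of what returning the columns rather than the indices might look like, given that the names live across two objects (hypothetical helper, not the PR's implementation):

```python
def _get_columns(data, args, data_cols, dist_params, extra_fields):
    # Resolve each required name to its column, whether it lives in the
    # data matrix or in the tuple of positional args.
    name_to_col = {name: data[:, i] for i, name in enumerate(data_cols)}
    name_to_col.update(zip([*dist_params, *extra_fields], args))
    return name_to_col
```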
```python
# Vectorized version of subject_wise_func to handle multiple subjects.
subject_wise_vmapped = jax.vmap(subject_wise_func, in_axes=0)


def logp(data) -> np.ndarray:
```
We can't change the signature of this function, unless we want to have an RL version of make_distribution in distribution_utils. Otherwise, like we discussed, this function will have to have the logp(data, *dist_params, *extra_fields) signature.
The reason I changed it to this is that, when building traces, jax.jit(logp) will throw a type error if any of the arguments passed to the function are not array data types.
It is safe to assume that all arguments passed to this function are arrays. Is it because some values passed to the jitted function in the tests are scalars? They should not be scalars
I'm thinking logp(data: np.array, dist_param: list[str], extra: list[str]). In the constructor,

```python
def make_rl_logp_func(
    subject_wise_func: Callable,
    n_participants: int,
    n_trials: int,
    data_cols: list[str] | None = None,  # column names in the data matrix passed to logp
    dist_params: list[str] | None = None,  # dist_param column names
    extra_fields: list[str] | None = None,  # extra_fields column names
)
```

With the last three arguments, we can do the lookup for the indices of the names in dist_params and extra_fields within data_columns. With this, one only needs the data matrix in logp, so its signature can be simplified to logp(data), as all the information in args would be available in the scope of make_rl_logp_func.
This is where the misunderstanding is. data_cols contains only the names of the columns in the data matrix (the first argument passed to logp, which typically has 'rt' and 'response' columns). It is not to be confused with the compiled matrix that we pass to the inner function.
```python
    n_participants: int,
    n_trials: int,
    data_cols: list[str] | None = None,
    dist_params: list[str] | None = None,
```
We probably shouldn't have a default of None here, since we are not replacing it within the function with sensible defaults. Might be a good idea to make it required.
The problem there is that a required argument can't follow an argument with a default value. Their order would need to be changed.
The order is not super important for these arguments here, but if we want to keep the order, we can handle the None inside the function
There is handling of None defaults in _validate_columns.
…rldm_logp_op to Callable[..., Any]
```python
from ..distribution_utils.onnx import make_jax_matrix_logp_funcs_from_onnx


# Obtain the angle log-likelihood function from an ONNX model.
angle_logp_jax_func = make_jax_matrix_logp_funcs_from_onnx(
```
Loading the angle LAN was hardcoded since the existing likelihood used angle. But we will need to make this generic (based on the decision process defined in the rlssm model_config) and load the correct LAN based on what the user wants.
Thanks @krishnbera. This is another element to abstract as well.
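One possible shape for that abstraction (a sketch; the model_config keys here are assumptions, not the actual HSSM API):

```python
# Pick the LAN from the decision process named in the rlssm model_config
# instead of hardcoding "angle".
decision_process = model_config["decision_process"]  # e.g. "angle" or "ddm"
lan_logp_jax_func = make_jax_matrix_logp_funcs_from_onnx(
    model_config["lan_onnx_paths"][decision_process]
)
```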
```python
# Vectorized version of subject_wise_func to handle multiple subjects.
subject_wise_vmapped = jax.vmap(subject_wise_func, in_axes=0)


def logp(data, *args) -> np.ndarray:
```
This function name should be updated since it's not computing the logp?
@cpaniaguam this is basically just the update rule for the target parameter of the reinforcement learning process (v in our example model, but again, doesn't have to be that).
@krishnbera @AlexanderFengler Thanks for pointing this out. We intentionally removed a couple of steps from the true logp computation which we intend to reintroduce in a different function.
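For reference, a minimal sketch of such a per-trial update rule under jax.lax.scan: a generic delta-rule formulation with hypothetical names, not the exact HSSM implementation:

```python
import jax
import jax.numpy as jnp

def update_one_trial(q, trial):
    # Carry the q-values forward; emit the updated target parameter per trial.
    action, feedback, rl_alpha, scaler = trial
    idx = action.astype(int)
    rpe = feedback - q[idx]            # reward prediction error
    q = q.at[idx].add(rl_alpha * rpe)  # delta-rule update of the chosen q-value
    v = scaler * (q[1] - q[0])         # target parameter (v in the example model)
    return q, v

# One subject's trials: each array has shape (n_trials,).
actions = jnp.array([0, 1, 1, 0, 1])
feedbacks = jnp.array([1.0, 0.0, 1.0, 1.0, 0.0])
rl_alphas = jnp.full(5, 0.1)
scalers = jnp.full(5, 2.0)

q_final, v_per_trial = jax.lax.scan(
    update_one_trial, jnp.full(2, 0.5), (actions, feedbacks, rl_alphas, scalers)
)
```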
```python
    Op
        A pytensor Op that computes the log likelihood for the RLDM model.
    """
    logp = make_rl_logp_func(subject_wise_func, n_participants, n_trials)
```
Is this function meant to output the logp or drifts (v) only? I am assuming logp, right?
@cpaniaguam / @krishnbera let's put a meeting on the calendar for monday to discuss this point? ( @digicosmos86 also if you are back).
Indeed, we are aware this is computing drift rates only and some renaming of variables is necessary. The main goal of this PR is to build the abstraction and then ingest the drifts to get logp in a separate step.
@AlexanderFengler @krishnbera this is still WIP so none of the code you see is in a final reviewable state. @cpaniaguam can you change this PR to a draft so we can iterate more on it?
Thanks @cpaniaguam, clear progress. For some of the lingering misunderstandings, let's meet again to talk things through a bit more.
It would be awesome if you could make a wire diagram of how you currently see the code pieces fitting together, so that we can interface with and play with it during the meeting.
```python
def annotate_function(**kwargs):
```
This has parallels with the decorate_simulator decorator?
Can we reuse functionality (either base that one on this one or vice versa)?
The annotate_function decorator is more generic than decorate_atomic_simulator, as the latter attaches specific metadata:

```python
    model_name: str,
    choices: list | np.ndarray | None = None,
    obs_dim: int = 2,  # At least for now ssms models all fall under 2 obs dims
```

We could refactor decorate_atomic_simulator using a call to annotate_function:

```python
def decorate_atomic_simulator(
    model_name: str,
    choices: list | np.ndarray | None = None,
    obs_dim: int = 2,  # At least for now ssms models all fall under 2 obs dims
):
    choices = [-1, 1] if choices is None else choices
    return annotate_function(
        model_name=model_name,
        choices=choices,
        obs_dim=obs_dim,
    )
```
And this change should probably be introduced in a separate PR.
```python
# Inner function to compute the drift rate and update q-values for each trial.
# This function is used with `jax.lax.scan` to process each trial in the RLDM model.
def compute_v_trial_wise(
```
Similar to the angle in the name, we shouldn't hardcode _v_ in the function name here.
The learning rule could be applied to any parameter; we should use a name that just reflects the type of updating we are doing, regardless of which parameter it is applied to.
```python
# their trials one by one. It uses `jax.lax.scan` to efficiently iterate over
# the trials and compute the drift rates based on the RL parameters, actions,
# and rewards for each trial.
def compute_v_subject_wise(
```
I would take out _v_ here again in the name. Also, instead of subject_wise (which insinuates the computation runs across all subjects, when this handles a single subject), I'd use _one_subject.
```python
    When data_cols is None, return an empty list so that callers can defer
    indexing until data is available.
    """
    if data_cols is None:
```
data_cols should never be None or am I missing something?
Good eye! _get_column_indices is called after a validation check on data_cols. Thanks for catching that!
…fy logic for column index retrieval.
This pull request introduces a new log-likelihood implementation for the RLDM model, including its JAX-based abstraction and associated unit tests. The main focus is on improving the computation of trial-wise drift rates and integrating the abstraction into the broader model likelihood system. The changes are grouped into new likelihood implementation and testing.
New RLDM likelihood implementation:
src/hssm/likelihoods/rldm_optimized_abstraction.py with JAX-based functions for computing RLDM drift rates and log-likelihoods, including compute_v_trial_wise, compute_v_subject_wise, and factory methods for log-likelihood functions and PyTensor Ops.

Testing and validation:
tests/test_rldm_likelihood_abstraction.py with unit tests for the new RLDM log-likelihood abstraction, verifying correct output shapes and numerical accuracy against expected values.