Type annotate messengers #3309
ordabayevy commented on Jan 2, 2024
- plate_messenger
- reentrant_messenger
- reparam_messenger
pyro/poutine/runtime.py
Outdated
    args: Tuple
    kwargs: Dict
    value: Optional[torch.Tensor]
    value: Optional[T]
I've learned this neat trick with Generic and TypeVar where the type of value can be inferred from the Callable signature. I also fixed effectful so that it gives the correct signature for the decorated function when the return type is different from torch.Tensor (e.g. reparam_messenger._get_init_messengers).
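For readers unfamiliar with the pattern, here is a minimal sketch of how a Generic container lets a type checker infer the type of value from a Callable signature. The names Msg, make_msg, run, and count_sites are hypothetical and exist only for illustration; this is not Pyro's actual Message type, which may be defined differently.

```python
from dataclasses import dataclass, field
from typing import Any, Callable, Dict, Generic, Optional, Tuple, TypeVar

T = TypeVar("T")


@dataclass
class Msg(Generic[T]):
    # Hypothetical stand-in for a poutine message (not Pyro's Message).
    fn: Callable[..., T]
    args: Tuple[Any, ...] = ()
    kwargs: Dict[str, Any] = field(default_factory=dict)
    value: Optional[T] = None  # the value type follows fn's return type


def make_msg(fn: Callable[..., T], *args: Any, **kwargs: Any) -> Msg[T]:
    # A type checker binds T from the return annotation of fn.
    return Msg(fn=fn, args=args, kwargs=kwargs)


def run(msg: Msg[T]) -> Optional[T]:
    # Call the stored function and record its result on the message.
    msg.value = msg.fn(*msg.args, **msg.kwargs)
    return msg.value


def count_sites(prefix: str, n: int) -> Dict[str, int]:
    # Example function whose return type is not torch.Tensor.
    return {f"{prefix}_{i}": i for i in range(n)}


msg = make_msg(count_sites, "x", n=2)  # checked as Msg[Dict[str, int]]
result = run(msg)                      # checked as Optional[Dict[str, int]]
```

With this shape, no torch.Tensor is hard-coded in the value annotation, so functions with other return types still type-check.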
pyro/poutine/runtime.py
Outdated
    )
    # apply the stack and return its return value
    apply_stack(msg)
    assert msg["value"] is not None
Is this always correct? All tests have passed.
@eb8680 ?
Looks great, just a couple nits.
@fritzo can you have another look? I have addressed your comments.
    ) -> Union[T, torch.Tensor, None]:
    obs: Optional[_T] = None,
    **kwargs: _P.kwargs,
    ) -> Optional[_T]:
Changed this back to return Optional and removed the assert msg["value"] is not None line. One concern I have is that if _T itself is None, the assert would raise an AssertionError.
LGTM, Thanks for the ping!