-
Notifications
You must be signed in to change notification settings - Fork 4k
[Prototype] Allow defining session state via a class #13592
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: develop
Are you sure you want to change the base?
Conversation
✅ Snyk checks have passed. No issues have been found so far.
💻 Catch issues earlier using the plugins for VS Code, JetBrains IDEs, Visual Studio, and Eclipse. |
✅ PR preview is ready!
|
The session_state.__call__ metrics name is more descriptive and consistent with other session_state metrics (set_item, set_attr). Added session_state to ignored_commands with explanation.
|
@cursor review |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Pull request overview
This pull request introduces a new feature that allows users to define session state using class decorators. The @st.session_state decorator transforms a regular Python class into a proxy that automatically stores all fields in Streamlit's session state, providing both class-level and instance-based access patterns.
Changes:
- Added
__call__method toSessionStateProxyto enable it to function as a class decorator - Implemented helper classes (
_StateAccessor,_SessionStateClassMeta) and functions to support the decorator pattern - Added comprehensive unit tests (38 tests) covering various use cases including field access, methods, collision detection, and edge cases
- Updated metrics tracking to exclude
session_statefrom direct command tracking (tracked assession_state.__call__instead)
Reviewed changes
Copilot reviewed 3 out of 3 changed files in this pull request and generated 20 comments.
| File | Description |
|---|---|
| lib/streamlit/runtime/state/session_state_proxy.py | Core implementation of the @st.session_state decorator including field extraction, method binding, metaclass for class-level access, and state initialization |
| lib/tests/streamlit/runtime/state/session_state_proxy_test.py | Comprehensive unit test suite with 38 tests covering basic functionality, edge cases, and error conditions |
| lib/tests/streamlit/runtime/metrics_util_test.py | Updated to exclude session_state from direct API command tracking with explanatory comment |
| def _create_bound_method( | ||
| method: Callable[..., Any], fields: dict[str, Any] | ||
| ) -> Callable[..., Any]: | ||
| """Create a bound method that uses StateAccessor as self. | ||
|
|
||
| Parameters | ||
| ---------- | ||
| method : Callable[..., Any] | ||
| The original method from the class. | ||
| fields : dict[str, Any] | ||
| Dictionary of valid field names. | ||
|
|
||
| Returns | ||
| ------- | ||
| Callable[..., Any] | ||
| A callable that invokes the method with a StateAccessor. | ||
| """ | ||
| from functools import wraps | ||
|
|
||
| accessor = _StateAccessor(fields) | ||
|
|
||
| @wraps(method) | ||
| def bound_method(*args: Any, **kwargs: Any) -> Any: | ||
| return method(accessor, *args, **kwargs) | ||
|
|
||
| return bound_method |
Copilot
AI
Jan 15, 2026
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The function _create_bound_method creates a new _StateAccessor instance for every bound method. This means each method gets its own accessor instance, which is inefficient and could lead to memory overhead when there are many methods. Consider creating a single shared accessor instance for all methods within a class, or reuse the accessor from the class-level binding.
| # Allow instantiation - return a StateAccessor proxy instance | ||
| def _create_instance(_cls: type, *_args: Any, **_kwargs: Any) -> _StateAccessor: | ||
| return _StateAccessor(fields, bound_methods, class_name) | ||
|
|
||
| proxy_class.__new__ = _create_instance # type: ignore[assignment,method-assign] |
Copilot
AI
Jan 15, 2026
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The __new__ method override is non-standard. The typical Python way to customize instance creation is to override __call__ on the metaclass rather than replacing __new__ on the class itself. This could lead to unexpected behavior with inheritance or other Python magic methods. Consider implementing __call__ on the _SessionStateClassMeta metaclass instead.
| fields : dict[str, Any] | ||
| Dictionary of valid field names for this state class. | ||
| methods : dict[str, Callable[..., Any]] | None | ||
| Dictionary of bound methods (only needed for instance access). |
Copilot
AI
Jan 15, 2026
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The docstring for _StateAccessor says methods can be None, but the type annotation shows dict[str, Callable[..., Any]] | None. When None is passed, it's converted to an empty dict. Consider updating the docstring to clarify that None is only accepted as a convenience and is converted to an empty dict, or remove the None option from the type annotation and always require a dict to be passed.
| Dictionary of bound methods (only needed for instance access). | |
| Dictionary of bound methods (only needed for instance access). | |
| If ``None`` (the default), it is treated as an empty mapping and | |
| no bound methods will be available on the accessor instance. |
| def _initialize_state_fields(fields: dict[str, Any]) -> None: | ||
| """Initialize fields in session state with their default values. | ||
|
|
||
| Only sets values that don't already exist in session state. | ||
|
|
||
| Parameters | ||
| ---------- | ||
| fields : dict[str, Any] | ||
| Dictionary mapping field names to default values. | ||
| """ | ||
| state = get_session_state() | ||
|
|
||
| for name, default in fields.items(): | ||
| if name not in state: | ||
| # Handle mutable defaults by copying | ||
| if isinstance(default, (list, dict, set)): | ||
| state[name] = copy.deepcopy(default) | ||
| else: | ||
| state[name] = default |
Copilot
AI
Jan 15, 2026
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The field initialization only checks if the field name exists in session state, but doesn't verify if the existing value is of the expected type. If a user manually sets st.session_state["counter"] = "not a number" before decorating a class with counter: int = 0, the decorator will silently keep the string value instead of the expected int. Consider adding type validation or at least documenting this behavior, as it could lead to unexpected runtime errors when methods try to use the field.
| # Handle mutable defaults by copying | ||
| if isinstance(default, (list, dict, set)): | ||
| state[name] = copy.deepcopy(default) | ||
| else: |
Copilot
AI
Jan 15, 2026
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The deep copy for mutable defaults only handles list, dict, and set. Other mutable types like bytearray, collections.deque, or custom mutable objects won't be copied, which could lead to the same mutable default being shared across all uses. Consider using a more comprehensive approach, such as always using copy.deepcopy() for all defaults, or documenting which mutable types are supported.
| # Handle mutable defaults by copying | |
| if isinstance(default, (list, dict, set)): | |
| state[name] = copy.deepcopy(default) | |
| else: | |
| # Handle defaults by copying so mutable values are not shared. | |
| try: | |
| state[name] = copy.deepcopy(default) | |
| except Exception: | |
| # Fall back to the original value if it cannot be deep-copied. |
| def _check_and_register_keys(fields: dict[str, Any], class_name: str) -> None: | ||
| """Check for key collisions and register field keys. | ||
|
|
||
| Parameters | ||
| ---------- | ||
| fields : dict[str, Any] | ||
| Dictionary of field names to check and register. | ||
| class_name : str | ||
| Name of the class registering these fields. | ||
|
|
||
| Raises | ||
| ------ | ||
| StreamlitAPIException | ||
| If a field name is already registered by a different class. | ||
| """ | ||
| state = get_session_state() | ||
|
|
||
| # Initialize registry if not exists | ||
| if _SESSION_STATE_CLASS_REGISTRY_KEY not in state: | ||
| state[_SESSION_STATE_CLASS_REGISTRY_KEY] = {} | ||
|
|
||
| registry = state[_SESSION_STATE_CLASS_REGISTRY_KEY] | ||
|
|
||
| for field_name in fields: | ||
| if field_name in registry: | ||
| existing_class = registry[field_name] | ||
| # Allow re-registration by the same class (for script reruns) | ||
| if existing_class != class_name: | ||
| raise StreamlitAPIException( | ||
| f"Key collision in @st.session_state: Field '{field_name}' " | ||
| f"is already registered by class '{existing_class}'. " | ||
| f"Cannot register it again for class '{class_name}'. " | ||
| f"Each field name must be unique across all @st.session_state classes." | ||
| ) | ||
| registry[field_name] = class_name | ||
|
|
Copilot
AI
Jan 15, 2026
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The registry key $$_st_session_state_class_registry is stored in session state but is not validated against user keys. A user could theoretically manually set or delete this key, which would break the collision detection mechanism. Consider using a more protected mechanism for storing this metadata, or at minimum add validation to prevent users from accidentally or intentionally modifying this internal key.
| def __call__(self, cls: type[_T]) -> type[_T]: | ||
| """Decorator to create a session state class. | ||
|
|
||
| Transforms a class definition into a proxy that stores all fields | ||
| in Streamlit's session state. Fields are persisted across script reruns | ||
| and can be accessed in two ways: | ||
|
|
||
| 1. **Class-level access** (direct): ``MyState.counter`` | ||
| 2. **Instance-based access** (Pythonic): ``state = MyState(); state.counter`` | ||
|
|
||
| Both patterns access the same underlying session state. Multiple | ||
| instantiations return equivalent proxy objects that share the same state. | ||
|
|
||
| Parameters | ||
| ---------- | ||
| cls : type | ||
| The class to transform into a session state class. | ||
|
|
||
| Returns | ||
| ------- | ||
| type | ||
| A proxy class that stores state in session_state. | ||
|
|
||
| Raises | ||
| ------ | ||
| StreamlitAPIException | ||
| If a field has no default value or if there's a key collision | ||
| with another @st.session_state class. | ||
|
|
||
| Example | ||
| ------- | ||
| Define a state class with fields and methods: | ||
|
|
||
| >>> @st.session_state | ||
| ... class MyState: | ||
| ... counter: int = 0 | ||
| ... name: str = "default" | ||
| ... | ||
| ... def increment(self): | ||
| ... self.counter += 1 | ||
|
|
||
| **Class-level access** (quick scripts): | ||
|
|
||
| >>> MyState.counter # Read from session state | ||
| 0 | ||
| >>> MyState.increment() # Call method | ||
| >>> MyState.counter | ||
| 1 | ||
|
|
||
| **Instance-based access** (recommended, more Pythonic): | ||
|
|
||
| >>> state = MyState() # Create proxy instance | ||
| >>> state.counter # Read from session state | ||
| 1 | ||
| >>> state.increment() # Call method | ||
| >>> state.counter | ||
| 2 | ||
|
|
||
| Both access the same underlying session state: | ||
|
|
||
| >>> st.session_state["counter"] | ||
| 2 | ||
|
|
||
| Note: All instances share the same state: | ||
|
|
||
| >>> state1 = MyState() | ||
| >>> state2 = MyState() | ||
| >>> state1.counter = 100 | ||
| >>> state2.counter # Same value! | ||
| 100 | ||
| """ | ||
| class_name = cls.__name__ | ||
|
|
||
| # Extract fields and methods from the original class | ||
| fields = _extract_fields_from_class(cls) | ||
| methods = _extract_methods_from_class(cls) | ||
|
|
||
| # Check for key collisions and register keys | ||
| _check_and_register_keys(fields, class_name) | ||
|
|
||
| # Initialize fields in session state | ||
| _initialize_state_fields(fields) | ||
|
|
||
| # Create bound methods | ||
| bound_methods: dict[str, Callable[..., Any]] = { | ||
| name: _create_bound_method(method, fields) | ||
| for name, method in methods.items() | ||
| } | ||
|
|
||
| # Create the proxy class using metaclass | ||
| proxy_class = _SessionStateClassMeta( | ||
| class_name, | ||
| (), | ||
| { | ||
| "_st_fields": fields, | ||
| "_st_methods": bound_methods, | ||
| "_st_class_name": class_name, | ||
| "__doc__": cls.__doc__, | ||
| "__module__": cls.__module__, | ||
| }, | ||
| ) | ||
|
|
||
| # Allow instantiation - return a StateAccessor proxy instance | ||
| def _create_instance(_cls: type, *_args: Any, **_kwargs: Any) -> _StateAccessor: | ||
| return _StateAccessor(fields, bound_methods, class_name) | ||
|
|
||
| proxy_class.__new__ = _create_instance # type: ignore[assignment,method-assign] | ||
|
|
||
| return proxy_class # type: ignore[return-value] |
Copilot
AI
Jan 15, 2026
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Missing test coverage for the new @st.session_state decorator functionality. According to the custom coding guidelines, the repository should have typing tests in lib/tests/streamlit/typing/ for public API features. Consider adding a typing test file (e.g., session_state_decorator_types.py) to verify that the type annotations work correctly with mypy, including field access, method calls, and the return types of decorated classes.
| def test_setting_undefined_attribute_raises_error(self) -> None: | ||
| """Test that setting undefined attributes raises AttributeError.""" | ||
|
|
||
| @self.session_state_proxy | ||
| class MyState: | ||
| value: int = 0 | ||
|
|
||
| # Setting an undefined field should go through the metaclass | ||
| # which allows it for internal setup, but for user code it should | ||
| # be stored in session state only if it's a defined field | ||
| MyState.value = 10 # This should work | ||
| assert MyState.value == 10 |
Copilot
AI
Jan 15, 2026
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Missing negative assertion. According to the custom coding guidelines for unit tests, you should include a negative/anti-regression assertion when practical. This test verifies that setting undefined attributes on the class-level works for defined fields, but it doesn't verify that attempting to set a truly undefined field raises an error or behaves correctly. Consider adding an assertion that verifies attempting to set an undefined field name fails appropriately.
| def test_dict_default_is_copied(self) -> None: | ||
| """Test that dict defaults are deep copied.""" | ||
|
|
||
| @self.session_state_proxy | ||
| class MyState: | ||
| data: dict[str, int] = {} # noqa: RUF012 | ||
|
|
||
| MyState.data["key1"] = 1 | ||
| MyState.data["key2"] = 2 | ||
|
|
||
| assert MyState.data == {"key1": 1, "key2": 2} |
Copilot
AI
Jan 15, 2026
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Missing anti-regression check. According to the custom coding guidelines, unit tests should include negative assertions when practical. This test verifies that the dict default is copied and can be modified, but doesn't verify that creating another class with a dict default doesn't share the same dict instance. Consider adding an assertion that creates a second state class with a dict default and verifies the dicts are independent.
| def _extract_methods_from_class(cls: type) -> dict[str, Callable[..., Any]]: | ||
| """Extract methods from a class definition. | ||
|
|
||
| Parameters | ||
| ---------- | ||
| cls : type | ||
| The class to extract methods from. | ||
|
|
||
| Returns | ||
| ------- | ||
| dict[str, Callable[..., Any]] | ||
| Dictionary mapping method names to the method functions. | ||
| """ | ||
| methods: dict[str, Callable[..., Any]] = {} | ||
|
|
||
| for name, value in inspect.getmembers(cls, predicate=inspect.isfunction): | ||
| # Skip private/dunder methods | ||
| if name.startswith("_"): | ||
| continue | ||
| methods[name] = value | ||
|
|
||
| return methods |
Copilot
AI
Jan 15, 2026
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Missing test coverage for special methods. The implementation extracts methods using inspect.getmembers with inspect.isfunction predicate and skips methods starting with underscore. However, there's no test to verify what happens if a user defines __init__, __str__, or other special methods in their state class. Consider adding a test to document and verify the expected behavior when special methods are defined.
| """ | ||
| from functools import wraps | ||
|
|
||
| accessor = _StateAccessor(fields) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Methods cannot call other methods via self
High Severity
The _create_bound_method function creates a _StateAccessor with only fields, not passing the methods dictionary. This means when a method uses self to call another method (e.g., self.increment()), the _StateAccessor.__getattr__ will check an empty _methods dict and raise AttributeError. Methods can access fields via self.field_name but cannot call sibling methods like self.other_method().
| return | ||
|
|
||
| # For initial class setup, allow setting | ||
| super().__setattr__(name, value) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Class-level undefined attribute setting silently succeeds
Medium Severity
In _SessionStateClassMeta.__setattr__, when setting an attribute that isn't a defined field, the code falls through to super().__setattr__(name, value), silently setting it as a regular class attribute. This is inconsistent with instance-level behavior in _StateAccessor.__setattr__, which raises AttributeError for undefined fields. The test test_setting_undefined_attribute_raises_error claims to verify error-raising behavior but only tests setting defined attributes, missing this bug.
|
Hello! Would this decorator support defining the class as a pydantic model with fields constraints? For instance, this is what we currently do in our apps: class PageState(pydantic.BaseModel):
some_checked_state: bool = False
other_stuff: list[FooBar] = pydantic.Field(default_factory=list)
@classmethod
def get(cls) -> Self:
if "__page_state__" not in st.session_state:
st.session_state.__page_state__ = cls()
return st.session_state.__page_state__From your examples this would now look like: @st.session_state
class PageState:
some_checked_state: bool = FalseBut how complete is the type checking, is it enforced? 👉 Could you describe how comparable this is going to behave compared to using the pydantic way? Thanks a lot 🙏 |
Describe your changes
This PR adds the ability to use
st.session_stateas a class decorator, allowing users to define dataclass-like state classes where fields are automatically stored in session state.Motivation
Managing session state in Streamlit apps often involves repetitive boilerplate:
This PR introduces a cleaner, more Pythonic pattern:
Usage
Defining State Classes
Two Access Patterns
1. Class-level access (quick scripts):
2. Instance-based access (recommended, more Pythonic):
Both patterns access the same underlying session state—they're fully interchangeable.
Multiple State Classes
Session State Compatibility
Fields are stored directly in
st.session_stateand can be accessed via dict syntax:Features
Error Handling
Fields must have defaults:
Key collisions are detected:
Implementation
__call__method toSessionStateProxyclass_SessionStateClassMeta) for class-level attribute access_StateAccessorclass proxies attribute access to session state for both method binding and instance accessGitHub Issue Link (if applicable)
Testing Plan
Contribution License Agreement
By submitting this pull request you agree that all contributions to this project are made under the Apache 2.0 license.