diff --git a/reflex/state.py b/reflex/state.py index 6ba5633b5e6..c7200f5c377 100644 --- a/reflex/state.py +++ b/reflex/state.py @@ -2059,6 +2059,15 @@ def _dirty_computed_vars( if include_backend or not self.computed_vars[cvar]._backend } + @property + def _skip_serialization(self) -> bool: + """Whether to skip serialization for this state. + + Override in a subclass to skip sending state updates to the frontend. + Useful e.g. for permission/role-based state visibility. + """ + return False + def get_delta(self) -> Delta: """Get the delta for the state. @@ -2067,6 +2076,9 @@ def get_delta(self) -> Delta: """ delta = {} + if self._skip_serialization: + return delta + self._mark_dirty_computed_vars() frontend_computed_vars: set[str] = { name for name, cv in self.computed_vars.items() if not cv._backend @@ -2203,6 +2215,9 @@ def dict( Returns: The object as a dictionary. """ + if not initial and self._skip_serialization: + return {} + if include_computed: self._mark_dirty_computed_vars() base_vars = { diff --git a/tests/units/test_state.py b/tests/units/test_state.py index 5802cfa8b6f..36d973301a5 100644 --- a/tests/units/test_state.py +++ b/tests/units/test_state.py @@ -4470,3 +4470,81 @@ async def test_rebind_mutable_proxy(mock_app: rx.App, token: str) -> None: assert state.data["a"] == [2, 3] # Object identity persists across serialization, so data["b"] is also mutated. assert state.data["b"] == [2, 3] + + +class NormalState(rx.State): + """This state should be serialized.""" + + value: int = 42 + + +class SkippedState(NormalState): + """This state should not be serialized.""" + + skipped_value: int = 43 + + @property + def _skip_serialization(self) -> bool: + return True + + +class SkippedSubState(SkippedState): + """This state should not be serialized.""" + + substate_value: int = 44 + + +def test_default_skip_serialization_is_false(): + state = NormalState() + assert state._skip_serialization is False + + +def test_subclass_override_is_respected(): + """Subclass override must actually take effect.""" + assert SkippedState()._skip_serialization is True + assert NormalState()._skip_serialization is False + + +def test_dict_contains_value_by_default(): + state = NormalState() + state_dict = str(state.dict()) + assert "42" in state_dict + assert "43" not in state_dict + assert "44" not in state_dict + + +def test_dict_empty_when_skip_serialization(): + state = SkippedState() + assert state.dict() == {} + + +def test_get_delta_contains_value_after_change(): + state = NormalState() + state.value = 99 + state_delta = str(state.get_delta()) + assert "99" in state_delta + assert "43" not in state_delta + assert "44" not in state_delta + + +def test_get_delta_empty_when_skip_serialization(): + state = SkippedState() + state.skipped_value = 99 + assert state.get_delta() == {} + + +def test_substate_of_skipped_parent_is_also_skipped(): + """Substates of a skipped parent are also skipped, even without overriding _skip_serialization.""" + state = SkippedSubState() + state.substate_value = 99 + assert state.get_delta() == {} + assert state.dict() == {} + + +def test_normal_state_unaffected_by_skipped_substate(): + """NormalState delta must not be affected by its skipped substate.""" + state = NormalState() + state.value = 99 + delta = str(state.get_delta()) + assert "99" in delta + assert "43" not in delta