Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 51 additions & 13 deletions reflex/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,8 +254,8 @@ class BaseConfig:
# List of plugins to use in the app.
plugins: list[Plugin] = dataclasses.field(default_factory=list)

# List of fully qualified import paths of plugins to disable in the app (e.g. reflex.plugins.sitemap.SitemapPlugin).
disable_plugins: list[str] = dataclasses.field(default_factory=list)
# List of plugin types to disable in the app.
disable_plugins: list[type[Plugin]] = dataclasses.field(default_factory=list)

# The transport method for client-server communication.
transport: Literal["websocket", "polling"] = "websocket"
Expand Down Expand Up @@ -353,6 +353,9 @@ def _post_init(self, **kwargs):
for key, env_value in env_kwargs.items():
setattr(self, key, env_value)

# Normalize disable_plugins: convert strings and Plugin subclasses to instances.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The comment says "convert strings and Plugin subclasses to instances", but the actual logic does the opposite: it converts Plugin instances (and strings) to Plugin classes. The word "subclasses" and "instances" are swapped.

Suggested change
# Normalize disable_plugins: convert strings and Plugin subclasses to instances.
# Normalize disable_plugins: convert strings and Plugin instances to Plugin subclasses.

self._normalize_disable_plugins()

# Add builtin plugins if not disabled.
if not self._skip_plugins_checks:
self._add_builtin_plugins()
Expand All @@ -369,16 +372,58 @@ def _post_init(self, **kwargs):
msg = f"{self._prefixes[0]}REDIS_URL is required when using the redis state manager."
raise ConfigError(msg)

def _normalize_disable_plugins(self):
"""Normalize disable_plugins list entries to Plugin subclasses.

Handles backward compatibility by converting strings (fully qualified
import paths) and Plugin instances to their associated classes.
"""
normalized: list[type[Plugin]] = []
for entry in self.disable_plugins:
if isinstance(entry, type) and issubclass(entry, Plugin):
normalized.append(entry)
elif isinstance(entry, Plugin):
console.deprecate(
feature_name="Passing Plugin instances to disable_plugins",
reason="pass Plugin classes directly instead, e.g. disable_plugins=[SitemapPlugin]",
deprecation_version="0.8.28",
removal_version="0.9.0",
)
normalized.append(type(entry))
elif isinstance(entry, str):
console.deprecate(
feature_name="Passing strings to disable_plugins",
reason="pass Plugin classes directly instead, e.g. disable_plugins=[SitemapPlugin]",
deprecation_version="0.8.28",
removal_version="0.9.0",
)
try:
from reflex.environment import interpret_plugin_class_env

normalized.append(
interpret_plugin_class_env(entry, "disable_plugins")
)
except Exception:
console.warn(
f"Failed to import plugin from string {entry!r} in disable_plugins. "
"Please pass Plugin subclasses directly.",
)
else:
console.warn(
f"reflex.Config.disable_plugins should contain Plugin subclasses, but got {entry!r}.",
)
self.disable_plugins = normalized

def _add_builtin_plugins(self):
"""Add the builtin plugins to the config."""
for plugin in _PLUGINS_ENABLED_BY_DEFAULT:
plugin_name = plugin.__module__ + "." + plugin.__qualname__
if plugin_name not in self.disable_plugins:
if plugin not in self.disable_plugins:
if not any(isinstance(p, plugin) for p in self.plugins):
console.warn(
f"`{plugin_name}` plugin is enabled by default, but not explicitly added to the config. "
"If you want to use it, please add it to the `plugins` list in your config inside of `rxconfig.py`. "
f"To disable this plugin, set `disable_plugins` to `{[plugin_name, *self.disable_plugins]!r}`.",
f"To disable this plugin, add `{plugin.__name__}` to the `disable_plugins` list.",
)
self.plugins.append(plugin())
else:
Expand All @@ -389,16 +434,9 @@ def _add_builtin_plugins(self):
)

for disabled_plugin in self.disable_plugins:
if not isinstance(disabled_plugin, str):
console.warn(
f"reflex.Config.disable_plugins should only contain strings, but got {disabled_plugin!r}. "
)
if not any(
plugin.__module__ + "." + plugin.__qualname__ == disabled_plugin
for plugin in _PLUGINS_ENABLED_BY_DEFAULT
):
if disabled_plugin not in _PLUGINS_ENABLED_BY_DEFAULT:
console.warn(
f"`{disabled_plugin}` is disabled in the config, but it is not a built-in plugin. "
f"`{disabled_plugin!r}` is disabled in the config, but it is not a built-in plugin. "
"Please remove it from the `disable_plugins` list in your config inside of `rxconfig.py`.",
)

Expand Down
40 changes: 35 additions & 5 deletions reflex/environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,15 +149,17 @@ def interpret_path_env(value: str, field_name: str) -> Path:
return Path(value)


def interpret_plugin_env(value: str, field_name: str) -> Plugin:
"""Interpret a plugin environment variable value.
def interpret_plugin_class_env(value: str, field_name: str) -> type[Plugin]:
"""Interpret an environment variable value as a Plugin subclass.

Resolves a fully qualified import path to the Plugin subclass it refers to.

Args:
value: The environment variable value.
value: The environment variable value (e.g. "reflex.plugins.sitemap.SitemapPlugin").
field_name: The field name.

Returns:
The interpreted value.
The Plugin subclass.

Raises:
EnvironmentVarValueError: If the value is invalid.
Expand All @@ -184,10 +186,30 @@ def interpret_plugin_env(value: str, field_name: str) -> Plugin:
msg = f"Invalid plugin class: {plugin_name!r} for {field_name}. Must be a subclass of Plugin."
raise EnvironmentVarValueError(msg)

return plugin_class


def interpret_plugin_env(value: str, field_name: str) -> Plugin:
"""Interpret a plugin environment variable value.

Resolves a fully qualified import path and returns an instance of the Plugin.

Args:
value: The environment variable value (e.g. "reflex.plugins.sitemap.SitemapPlugin").
field_name: The field name.

Returns:
An instance of the Plugin subclass.

Raises:
EnvironmentVarValueError: If the value is invalid.
"""
plugin_class = interpret_plugin_class_env(value, field_name)

try:
return plugin_class()
except Exception as e:
msg = f"Failed to instantiate plugin {plugin_name!r} for {field_name}: {e}"
msg = f"Failed to instantiate plugin {plugin_class.__name__!r} for {field_name}: {e}"
raise EnvironmentVarValueError(msg) from e


Expand Down Expand Up @@ -268,6 +290,14 @@ def interpret_env_var_value(
return interpret_existing_path_env(value, field_name)
if field_type is Plugin:
return interpret_plugin_env(value, field_name)
if get_origin(field_type) is type:
type_args = get_args(field_type)
if (
type_args
and isinstance(type_args[0], type)
and issubclass(type_args[0], Plugin)
):
return interpret_plugin_class_env(value, field_name)
if get_origin(field_type) is Literal:
literal_values = get_args(field_type)
for literal_value in literal_values:
Expand Down
60 changes: 60 additions & 0 deletions tests/units/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
interpret_enum_env,
interpret_int_env,
)
from reflex.plugins import Plugin
from reflex.plugins.sitemap import SitemapPlugin


def test_requires_app_name():
Expand Down Expand Up @@ -402,3 +404,61 @@ def test_env_file(
)
for key, value in exp_env_vars.items():
assert os.environ.get(key) == value


class TestDisablePlugins:
"""Tests for the disable_plugins config option."""

def test_disable_with_plugin_class(self):
"""Test disabling a plugin by passing the class (type)."""
config = rx.Config(app_name="test", disable_plugins=[SitemapPlugin])
assert not any(isinstance(p, SitemapPlugin) for p in config.plugins)

def test_disable_with_plugin_instance_backward_compat(self):
"""Test disabling a plugin by passing an instance (deprecated)."""
config = rx.Config(app_name="test", disable_plugins=[SitemapPlugin()]) # pyright: ignore[reportArgumentType]
assert not any(isinstance(p, SitemapPlugin) for p in config.plugins)

def test_disable_with_string_backward_compat(self):
"""Test disabling a plugin by passing a string (deprecated)."""
config = rx.Config(
app_name="test",
disable_plugins=["reflex.plugins.sitemap.SitemapPlugin"], # pyright: ignore[reportArgumentType]
)
assert not any(isinstance(p, SitemapPlugin) for p in config.plugins)

def test_disable_plugins_normalized_to_classes(self):
"""Test that disable_plugins entries are normalized to Plugin subclasses."""
config = rx.Config(app_name="test", disable_plugins=[SitemapPlugin])
assert all(
isinstance(dp, type) and issubclass(dp, Plugin)
for dp in config.disable_plugins
)

def test_disable_instance_normalized_to_class(self):
"""Test that a Plugin instance in disable_plugins is normalized to its class."""
config = rx.Config(app_name="test", disable_plugins=[SitemapPlugin()]) # pyright: ignore[reportArgumentType]
assert config.disable_plugins == [SitemapPlugin]

def test_disable_string_normalized_to_class(self):
"""Test that a string in disable_plugins is normalized to the class."""
config = rx.Config(
app_name="test",
disable_plugins=["reflex.plugins.sitemap.SitemapPlugin"], # pyright: ignore[reportArgumentType]
)
assert config.disable_plugins == [SitemapPlugin]

def test_disable_and_plugins_conflict_warns(self):
"""Test that a warning is issued when a plugin is both enabled and disabled."""
config = rx.Config(
app_name="test",
plugins=[SitemapPlugin()],
disable_plugins=[SitemapPlugin],
)
# Plugin should still be in plugins list (just warned)
assert any(isinstance(p, SitemapPlugin) for p in config.plugins)

def test_no_disable_adds_builtin(self):
"""Test that builtin plugins are added when not disabled."""
config = rx.Config(app_name="test")
assert any(isinstance(p, SitemapPlugin) for p in config.plugins)
34 changes: 34 additions & 0 deletions tests/units/test_environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
interpret_existing_path_env,
interpret_int_env,
interpret_path_env,
interpret_plugin_class_env,
interpret_plugin_env,
)
from reflex.plugins import Plugin
Expand Down Expand Up @@ -125,6 +126,30 @@ def test_interpret_plugin_env_invalid_class(self):
with pytest.raises(EnvironmentVarValueError, match="Invalid plugin class"):
interpret_plugin_env("tests.units.test_environment.TestEnum", "TEST_FIELD")

def test_interpret_plugin_class_env_valid(self):
"""Test plugin class interpretation returns the class, not an instance."""
result = interpret_plugin_class_env(
"tests.units.test_environment.TestPlugin", "TEST_FIELD"
)
assert result is TestPlugin

def test_interpret_plugin_class_env_invalid_format(self):
"""Test plugin class interpretation with invalid format."""
with pytest.raises(EnvironmentVarValueError, match="Invalid plugin value"):
interpret_plugin_class_env("invalid_format", "TEST_FIELD")

def test_interpret_plugin_class_env_import_error(self):
"""Test plugin class interpretation with import error."""
with pytest.raises(EnvironmentVarValueError, match="Failed to import module"):
interpret_plugin_class_env("non.existent.module.Plugin", "TEST_FIELD")

def test_interpret_plugin_class_env_invalid_class(self):
"""Test plugin class interpretation with invalid class."""
with pytest.raises(EnvironmentVarValueError, match="Invalid plugin class"):
interpret_plugin_class_env(
"tests.units.test_environment.TestEnum", "TEST_FIELD"
)

def test_interpret_enum_env_valid(self):
"""Test enum interpretation with valid values."""
result = interpret_enum_env("value1", _TestEnum, "TEST_FIELD")
Expand Down Expand Up @@ -172,6 +197,15 @@ def test_interpret_plugin(self):
)
assert isinstance(result, TestPlugin)

def test_interpret_plugin_class(self):
"""Test type[Plugin] interpretation returns the class."""
result = interpret_env_var_value(
"tests.units.test_environment.TestPlugin",
type[Plugin],
"TEST_FIELD",
)
assert result is TestPlugin

def test_interpret_list(self):
"""Test list interpretation."""
result = interpret_env_var_value("1:2:3", list[int], "TEST_FIELD")
Expand Down
Loading