diff --git a/installer/__main__.py b/installer/__main__.py index 08c4b0d..74982c9 100644 --- a/installer/__main__.py +++ b/installer/__main__.py @@ -333,7 +333,14 @@ def bootstrap(): from deepnote_core.runtime.types import StreamlitSpec if cfg.server.start_streamlit_servers: - from .module.streamlit import fetch_streamlit_apps + from .module.streamlit import fetch_streamlit_apps, set_integration_env_vars + + # Fetch and set integration env vars before starting Streamlit so that + # Streamlit subprocesses inherit them via os.environ. + try: + set_integration_env_vars(logger) + except Exception: + logger.exception("Failed to set integration env vars") try: streamlit_apps = fetch_streamlit_apps(logger) diff --git a/installer/module/streamlit.py b/installer/module/streamlit.py index 70906b6..e0f5b7d 100644 --- a/installer/module/streamlit.py +++ b/installer/module/streamlit.py @@ -70,6 +70,69 @@ def fetch_streamlit_apps(logger: logging.Logger) -> List[dict]: return streamlit_apps +def fetch_integration_env_vars(logger: logging.Logger) -> List[dict]: + """Fetches integration environment variables from the WebApp. + + Returns a list of dicts with 'name' and 'value' keys. + """ + base_url = get_webapp_url() + url = f"{base_url}/integrations/environment-variables" + + timeout = 3 + max_retries = 3 + + logger.info(f"Fetching integration environment variables from {url}.") + + try: + json_content = request_with_retries( + logger, + url, + max_retries=max_retries, + timeout=timeout, + ) + variables = json.loads(json_content) + if not isinstance(variables, list): + logger.error( + "Invalid integration env vars payload type: expected list, " + f"got {type(variables).__name__}." + ) + return [] + logger.info(f"Fetched {len(variables)} integration environment variables.") + return variables + except urllib.error.URLError: + logger.exception("Network error while fetching integration env vars.") + except json.JSONDecodeError: + logger.exception("JSON parsing error while fetching integration env vars.") + except Exception: + logger.exception("Unexpected error while fetching integration env vars.") + + return [] + + +def set_integration_env_vars(logger: logging.Logger) -> None: + """Fetches integration env vars and sets them in os.environ. + + This ensures that Streamlit processes (and other subprocesses started + after this call) inherit integration environment variables. + """ + variables = fetch_integration_env_vars(logger) + for variable in variables: + if not isinstance(variable, dict): + logger.warning("Skipping integration env var entry with invalid shape.") + continue + name = variable.get("name") + value = variable.get("value") + if not isinstance(name, str) or not isinstance(value, str): + continue + if not name or "=" in name or "\0" in name: + logger.warning(f"Skipping invalid env var name: {name!r}") + continue + if "\0" in value: + logger.warning(f"Skipping env var with invalid value bytes: {name!r}") + continue + os.environ[name] = value + + def start_streamlit_servers( venv: VirtualEnvironment, logger: logging.Logger ) -> List[str]: diff --git a/tests/unit/test_streamlit.py b/tests/unit/test_streamlit.py index 9c4fd0e..c7916e6 100644 --- a/tests/unit/test_streamlit.py +++ b/tests/unit/test_streamlit.py @@ -1,9 +1,14 @@ import json import logging +import os import unittest from unittest.mock import MagicMock, patch -from installer.module.streamlit import fetch_streamlit_apps +from installer.module.streamlit import ( + fetch_integration_env_vars, + fetch_streamlit_apps, + set_integration_env_vars, +) class TestFetchStreamlitApps(unittest.TestCase): @@ -44,3 +49,130 @@ def test_fetch_streamlit_apps(self): } ], ) + + +class TestFetchIntegrationEnvVars(unittest.TestCase): + """Tests for fetching integration env vars from the WebApp.""" + + def test_fetch_integration_env_vars(self) -> None: + """Returns the parsed list of env var dicts on success.""" + mock_data = [ + {"name": "SNOWFLAKE_USER", "value": "admin"}, + {"name": "SNOWFLAKE_PASSWORD", "value": "secret123"}, + ] + + mock_response = MagicMock() + mock_response.read.return_value = json.dumps(mock_data).encode("utf-8") + + with patch("urllib.request.urlopen") as mock_urlopen: + mock_urlopen.return_value.__enter__.return_value = mock_response + + test_logger = logging.getLogger("testLogger") + variables = fetch_integration_env_vars(test_logger) + + mock_urlopen.assert_called_once_with( + "http://localhost:19456/userpod-api/integrations/environment-variables", + timeout=3, + ) + + self.assertEqual(variables, mock_data) + + def test_fetch_integration_env_vars_empty(self) -> None: + """Returns an empty list when the endpoint returns no vars.""" + mock_response = MagicMock() + mock_response.read.return_value = json.dumps([]).encode("utf-8") + + with patch("urllib.request.urlopen") as mock_urlopen: + mock_urlopen.return_value.__enter__.return_value = mock_response + + test_logger = logging.getLogger("testLogger") + variables = fetch_integration_env_vars(test_logger) + + self.assertEqual(variables, []) + + def test_fetch_integration_env_vars_network_error(self) -> None: + """Returns an empty list when a network error occurs.""" + import urllib.error + + with patch("urllib.request.urlopen") as mock_urlopen: + mock_urlopen.side_effect = urllib.error.URLError("connection refused") + + test_logger = logging.getLogger("testLogger") + variables = fetch_integration_env_vars(test_logger) + + self.assertEqual(variables, []) + + def test_fetch_integration_env_vars_non_list_payload(self) -> None: + """Returns an empty list when the payload is not a list.""" + mock_response = MagicMock() + mock_response.read.return_value = json.dumps({"unexpected": "shape"}).encode( + "utf-8" + ) + + with patch("urllib.request.urlopen") as mock_urlopen: + mock_urlopen.return_value.__enter__.return_value = mock_response + + test_logger = logging.getLogger("testLogger") + variables = fetch_integration_env_vars(test_logger) + + self.assertEqual(variables, []) + + +class TestSetIntegrationEnvVars(unittest.TestCase): + """Tests for setting integration env vars in os.environ.""" + + def test_set_integration_env_vars(self) -> None: + """Valid env vars are set in os.environ.""" + self.addCleanup(os.environ.pop, "TEST_INT_VAR_A", None) + self.addCleanup(os.environ.pop, "TEST_INT_VAR_B", None) + + mock_data = [ + {"name": "TEST_INT_VAR_A", "value": "value_a"}, + {"name": "TEST_INT_VAR_B", "value": "value_b"}, + ] + + mock_response = MagicMock() + mock_response.read.return_value = json.dumps(mock_data).encode("utf-8") + + with patch("urllib.request.urlopen") as mock_urlopen: + mock_urlopen.return_value.__enter__.return_value = mock_response + + test_logger = logging.getLogger("testLogger") + set_integration_env_vars(test_logger) + + self.assertEqual(os.environ.get("TEST_INT_VAR_A"), "value_a") + self.assertEqual(os.environ.get("TEST_INT_VAR_B"), "value_b") + + def test_set_integration_env_vars_skips_invalid_entries(self) -> None: + """Invalid entries are skipped without affecting valid ones.""" + self.addCleanup(os.environ.pop, "TEST_INT_VAR_C", None) + + mock_data = [ + {"name": "TEST_INT_VAR_C", "value": "value_c"}, + {"name": None, "value": "orphan_value"}, + {"name": "TEST_INT_VAR_D", "value": None}, + {"name": "", "value": "empty_key"}, + {"name": "HAS=EQUALS", "value": "bad"}, + {"name": 123, "value": "non_string_key"}, + {"name": "GOOD_KEY", "value": 456}, + "not_a_dict_entry", + 42, + {"name": "NULL_BYTE_VAL", "value": "bad\0value"}, + {"name": "BAD\0NAME", "value": "null_byte_name"}, + ] + + mock_response = MagicMock() + mock_response.read.return_value = json.dumps(mock_data).encode("utf-8") + + with patch("urllib.request.urlopen") as mock_urlopen: + mock_urlopen.return_value.__enter__.return_value = mock_response + + test_logger = logging.getLogger("testLogger") + set_integration_env_vars(test_logger) + + self.assertEqual(os.environ.get("TEST_INT_VAR_C"), "value_c") + self.assertNotIn("TEST_INT_VAR_D", os.environ) + self.assertNotIn("HAS=EQUALS", os.environ) + self.assertNotIn("GOOD_KEY", os.environ) + self.assertNotIn("NULL_BYTE_VAL", os.environ) + self.assertNotIn("BAD\0NAME", os.environ)