Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion installer/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -333,7 +333,14 @@ def bootstrap():
from deepnote_core.runtime.types import StreamlitSpec

if cfg.server.start_streamlit_servers:
from .module.streamlit import fetch_streamlit_apps
from .module.streamlit import fetch_streamlit_apps, set_integration_env_vars

# Fetch and set integration env vars before starting Streamlit so that
# Streamlit subprocesses inherit them via os.environ.
try:
set_integration_env_vars(logger)
except Exception:
logger.exception("Failed to set integration env vars")

try:
streamlit_apps = fetch_streamlit_apps(logger)
Expand Down
63 changes: 63 additions & 0 deletions installer/module/streamlit.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,69 @@ def fetch_streamlit_apps(logger: logging.Logger) -> List[dict]:
return streamlit_apps


def fetch_integration_env_vars(logger: logging.Logger) -> List[dict]:
"""Fetches integration environment variables from the WebApp.

Returns a list of dicts with 'name' and 'value' keys.
"""
base_url = get_webapp_url()
url = f"{base_url}/integrations/environment-variables"

timeout = 3
max_retries = 3

logger.info(f"Fetching integration environment variables from {url}.")

try:
json_content = request_with_retries(
logger,
url,
max_retries=max_retries,
timeout=timeout,
)
variables = json.loads(json_content)
if not isinstance(variables, list):
logger.error(
"Invalid integration env vars payload type: expected list, "
f"got {type(variables).__name__}."
)
return []
logger.info(f"Fetched {len(variables)} integration environment variables.")
return variables
Comment thread
coderabbitai[bot] marked this conversation as resolved.
except urllib.error.URLError:
logger.exception("Network error while fetching integration env vars.")
except json.JSONDecodeError:
logger.exception("JSON parsing error while fetching integration env vars.")
except Exception:
logger.exception("Unexpected error while fetching integration env vars.")

return []


def set_integration_env_vars(logger: logging.Logger) -> None:
"""Fetches integration env vars and sets them in os.environ.

This ensures that Streamlit processes (and other subprocesses started
after this call) inherit integration environment variables.
"""
variables = fetch_integration_env_vars(logger)
for variable in variables:
if not isinstance(variable, dict):
logger.warning("Skipping integration env var entry with invalid shape.")
continue
name = variable.get("name")
value = variable.get("value")
if not isinstance(name, str) or not isinstance(value, str):
continue
if not name or "=" in name or "\0" in name:
logger.warning(f"Skipping invalid env var name: {name!r}")
continue
if "\0" in value:
logger.warning(f"Skipping env var with invalid value bytes: {name!r}")
continue
os.environ[name] = value


def start_streamlit_servers(
venv: VirtualEnvironment, logger: logging.Logger
) -> List[str]:
Expand Down
134 changes: 133 additions & 1 deletion tests/unit/test_streamlit.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,14 @@
import json
import logging
import os
import unittest
from unittest.mock import MagicMock, patch

from installer.module.streamlit import fetch_streamlit_apps
from installer.module.streamlit import (
fetch_integration_env_vars,
fetch_streamlit_apps,
set_integration_env_vars,
)


class TestFetchStreamlitApps(unittest.TestCase):
Expand Down Expand Up @@ -44,3 +49,130 @@ def test_fetch_streamlit_apps(self):
}
],
)


class TestFetchIntegrationEnvVars(unittest.TestCase):
"""Tests for fetching integration env vars from the WebApp."""

def test_fetch_integration_env_vars(self) -> None:
"""Returns the parsed list of env var dicts on success."""
mock_data = [
{"name": "SNOWFLAKE_USER", "value": "admin"},
{"name": "SNOWFLAKE_PASSWORD", "value": "secret123"},
]

mock_response = MagicMock()
mock_response.read.return_value = json.dumps(mock_data).encode("utf-8")

with patch("urllib.request.urlopen") as mock_urlopen:
mock_urlopen.return_value.__enter__.return_value = mock_response

test_logger = logging.getLogger("testLogger")
variables = fetch_integration_env_vars(test_logger)

mock_urlopen.assert_called_once_with(
"http://localhost:19456/userpod-api/integrations/environment-variables",
timeout=3,
)

self.assertEqual(variables, mock_data)

def test_fetch_integration_env_vars_empty(self) -> None:
"""Returns an empty list when the endpoint returns no vars."""
mock_response = MagicMock()
mock_response.read.return_value = json.dumps([]).encode("utf-8")

with patch("urllib.request.urlopen") as mock_urlopen:
mock_urlopen.return_value.__enter__.return_value = mock_response

test_logger = logging.getLogger("testLogger")
variables = fetch_integration_env_vars(test_logger)

self.assertEqual(variables, [])

def test_fetch_integration_env_vars_network_error(self) -> None:
"""Returns an empty list when a network error occurs."""
import urllib.error

with patch("urllib.request.urlopen") as mock_urlopen:
mock_urlopen.side_effect = urllib.error.URLError("connection refused")

test_logger = logging.getLogger("testLogger")
variables = fetch_integration_env_vars(test_logger)

self.assertEqual(variables, [])

def test_fetch_integration_env_vars_non_list_payload(self) -> None:
"""Returns an empty list when the payload is not a list."""
mock_response = MagicMock()
mock_response.read.return_value = json.dumps({"unexpected": "shape"}).encode(
"utf-8"
)

with patch("urllib.request.urlopen") as mock_urlopen:
mock_urlopen.return_value.__enter__.return_value = mock_response

test_logger = logging.getLogger("testLogger")
variables = fetch_integration_env_vars(test_logger)

self.assertEqual(variables, [])


class TestSetIntegrationEnvVars(unittest.TestCase):
"""Tests for setting integration env vars in os.environ."""

def test_set_integration_env_vars(self) -> None:
"""Valid env vars are set in os.environ."""
self.addCleanup(os.environ.pop, "TEST_INT_VAR_A", None)
self.addCleanup(os.environ.pop, "TEST_INT_VAR_B", None)

mock_data = [
{"name": "TEST_INT_VAR_A", "value": "value_a"},
{"name": "TEST_INT_VAR_B", "value": "value_b"},
]

mock_response = MagicMock()
mock_response.read.return_value = json.dumps(mock_data).encode("utf-8")

with patch("urllib.request.urlopen") as mock_urlopen:
mock_urlopen.return_value.__enter__.return_value = mock_response

test_logger = logging.getLogger("testLogger")
set_integration_env_vars(test_logger)

self.assertEqual(os.environ.get("TEST_INT_VAR_A"), "value_a")
self.assertEqual(os.environ.get("TEST_INT_VAR_B"), "value_b")

def test_set_integration_env_vars_skips_invalid_entries(self) -> None:
"""Invalid entries are skipped without affecting valid ones."""
self.addCleanup(os.environ.pop, "TEST_INT_VAR_C", None)

mock_data = [
{"name": "TEST_INT_VAR_C", "value": "value_c"},
{"name": None, "value": "orphan_value"},
{"name": "TEST_INT_VAR_D", "value": None},
{"name": "", "value": "empty_key"},
{"name": "HAS=EQUALS", "value": "bad"},
{"name": 123, "value": "non_string_key"},
{"name": "GOOD_KEY", "value": 456},
"not_a_dict_entry",
42,
{"name": "NULL_BYTE_VAL", "value": "bad\0value"},
{"name": "BAD\0NAME", "value": "null_byte_name"},
]
Comment thread
coderabbitai[bot] marked this conversation as resolved.

mock_response = MagicMock()
mock_response.read.return_value = json.dumps(mock_data).encode("utf-8")

with patch("urllib.request.urlopen") as mock_urlopen:
mock_urlopen.return_value.__enter__.return_value = mock_response

test_logger = logging.getLogger("testLogger")
set_integration_env_vars(test_logger)

self.assertEqual(os.environ.get("TEST_INT_VAR_C"), "value_c")
self.assertNotIn("TEST_INT_VAR_D", os.environ)
self.assertNotIn("HAS=EQUALS", os.environ)
self.assertNotIn("GOOD_KEY", os.environ)
self.assertNotIn("NULL_BYTE_VAL", os.environ)
self.assertNotIn("BAD\0NAME", os.environ)
Loading