Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 15 additions & 24 deletions singlestoredb/management/workspace.py
Original file line number Diff line number Diff line change
Expand Up @@ -1391,13 +1391,10 @@ def connect(self, **kwargs: Any) -> connection.Connection:
msg='An endpoint has not been set in this '
'starter workspace configuration',
)
# Parse endpoint as host:port
if ':' in self.endpoint:
host, port = self.endpoint.split(':', 1)
kwargs['host'] = host
kwargs['port'] = int(port)
else:
kwargs['host'] = self.endpoint

kwargs['host'] = self.endpoint
kwargs['database'] = self.database_name

return connection.connect(**kwargs)

def terminate(self) -> None:
Expand Down Expand Up @@ -1455,15 +1452,15 @@ def starter_workspaces(self) -> NamedList['StarterWorkspace']:

def create_user(
self,
user_name: str,
username: str,
password: Optional[str] = None,
) -> Dict[str, str]:
"""
Create a new user for this starter workspace.

Parameters
----------
user_name : str
username : str
The starter workspace user name to connect the new user to the database
password : str, optional
Password for the new user. If not provided, a password will be
Expand All @@ -1485,7 +1482,7 @@ def create_user(
)

payload = {
'userName': user_name,
'userName': username,
}
if password is not None:
payload['password'] = password
Expand Down Expand Up @@ -1854,7 +1851,8 @@ def create_starter_workspace(
self,
name: str,
database_name: str,
workspace_group: dict[str, str],
provider: str,
region_name: str,
) -> 'StarterWorkspace':
"""
Create a new starter (shared tier) workspace.
Expand All @@ -1865,28 +1863,21 @@ def create_starter_workspace(
Name of the starter workspace
database_name : str
Name of the database for the starter workspace
workspace_group : dict[str, str]
Workspace group input (dict with keys: 'cell_id' and 'name').
provider : str
Cloud provider for the starter workspace (e.g., 'aws', 'gcp', 'azure')
region_name : str
Cloud provider region for the starter workspace (e.g., 'us-east-1')

Returns
-------
:class:`StarterWorkspace`
"""
if not workspace_group or not isinstance(workspace_group, dict):
raise ValueError(
'workspace_group must be a dict with keys: '
"'cell_id' and 'name'",
)
if set(workspace_group.keys()) != {'cell_id', 'name'}:
raise ValueError("workspace_group must contain only 'cell_id' and 'name'")

payload = {
'name': name,
'databaseName': database_name,
'workspaceGroup': {
'name': workspace_group['name'],
'cellID': workspace_group['cell_id'],
},
'provider': provider,
'regionName': region_name,
}

res = self._post('sharedtier/virtualWorkspaces', json=payload)
Expand Down
45 changes: 21 additions & 24 deletions singlestoredb/tests/test_management.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,11 @@ def clean_name(s):
return re.sub(r'[^\w]', r'-', s).replace('_', '-').lower()


def shared_database_name(s):
"""Return a shared database name. Cannot contain special characters except -"""
return re.sub(r'[^\w]', '', s).replace('-', '_').lower()


@pytest.mark.management
class TestCluster(unittest.TestCase):

Expand Down Expand Up @@ -370,39 +375,31 @@ class TestStarterWorkspace(unittest.TestCase):

manager = None
starter_workspace = None
starter_workspace_user = {
'username': 'starter_user',
'password': None,
}

@property
def starter_username(self):
"""Return the username for the starter workspace user."""
return self.starter_workspace_user['username']

@property
def password(self):
"""Return the password for the starter workspace user."""
return self.starter_workspace_user['password']

@classmethod
def setUpClass(cls):
cls.manager = s2.manage_workspaces()

shared_tier_regions = [
shared_tier_regions: NamedList[Region] = [
x for x in cls.manager.shared_tier_regions if 'US' in x.name
]
cls.password = secrets.token_urlsafe(20) + '-x&$'
cls.starter_username = 'starter_user'
cls.password = secrets.token_urlsafe(20)

name = clean_name(secrets.token_urlsafe(20)[:20])
name = shared_database_name(secrets.token_urlsafe(20)[:20])

cls.database_name = f'starter_db_{name}'

shared_tier_region: Region = random.choice(shared_tier_regions)

if not shared_tier_region:
raise ValueError('No shared tier regions found')

cls.starter_workspace = cls.manager.create_starter_workspace(
f'starter-ws-test-{name}',
database_name=f'starter_db_{name}',
workspace_group={
'name': f'starter-wg-test-{name}',
'cell_id': random.choice(shared_tier_regions).id,
},
database_name=cls.database_name,
provider=shared_tier_region.provider,
region_name=shared_tier_region.region_name,
)

cls.starter_workspace.create_user(
Expand Down Expand Up @@ -470,14 +467,14 @@ def test_connect(self):
) as conn:
with conn.cursor() as cur:
cur.execute('show databases')
assert 'starter_db' in [x[0] for x in list(cur)]
assert self.database_name in [x[0] for x in list(cur)]

# Test missing endpoint
workspace = self.manager.get_starter_workspace(self.starter_workspace.id)
workspace.endpoint = None

with self.assertRaises(s2.ManagementError) as cm:
workspace.connect(user='admin', password=self.password)
workspace.connect(user=self.starter_username, password=self.password)

assert 'endpoint' in cm.exception.msg, cm.exception.msg

Expand Down