diff --git a/singlestoredb/management/workspace.py b/singlestoredb/management/workspace.py index 827032f98..8ba179c30 100644 --- a/singlestoredb/management/workspace.py +++ b/singlestoredb/management/workspace.py @@ -1391,13 +1391,10 @@ def connect(self, **kwargs: Any) -> connection.Connection: msg='An endpoint has not been set in this ' 'starter workspace configuration', ) - # Parse endpoint as host:port - if ':' in self.endpoint: - host, port = self.endpoint.split(':', 1) - kwargs['host'] = host - kwargs['port'] = int(port) - else: - kwargs['host'] = self.endpoint + + kwargs['host'] = self.endpoint + kwargs['database'] = self.database_name + return connection.connect(**kwargs) def terminate(self) -> None: @@ -1455,7 +1452,7 @@ def starter_workspaces(self) -> NamedList['StarterWorkspace']: def create_user( self, - user_name: str, + username: str, password: Optional[str] = None, ) -> Dict[str, str]: """ @@ -1463,7 +1460,7 @@ def create_user( Parameters ---------- - user_name : str + username : str The starter workspace user name to connect the new user to the database password : str, optional Password for the new user. If not provided, a password will be @@ -1485,7 +1482,7 @@ def create_user( ) payload = { - 'userName': user_name, + 'userName': username, } if password is not None: payload['password'] = password @@ -1854,7 +1851,8 @@ def create_starter_workspace( self, name: str, database_name: str, - workspace_group: dict[str, str], + provider: str, + region_name: str, ) -> 'StarterWorkspace': """ Create a new starter (shared tier) workspace. @@ -1865,28 +1863,21 @@ def create_starter_workspace( Name of the starter workspace database_name : str Name of the database for the starter workspace - workspace_group : dict[str, str] - Workspace group input (dict with keys: 'cell_id' and 'name'). + provider : str + Cloud provider for the starter workspace (e.g., 'aws', 'gcp', 'azure') + region_name : str + Cloud provider region for the starter workspace (e.g., 'us-east-1') Returns ------- :class:`StarterWorkspace` """ - if not workspace_group or not isinstance(workspace_group, dict): - raise ValueError( - 'workspace_group must be a dict with keys: ' - "'cell_id' and 'name'", - ) - if set(workspace_group.keys()) != {'cell_id', 'name'}: - raise ValueError("workspace_group must contain only 'cell_id' and 'name'") payload = { 'name': name, 'databaseName': database_name, - 'workspaceGroup': { - 'name': workspace_group['name'], - 'cellID': workspace_group['cell_id'], - }, + 'provider': provider, + 'regionName': region_name, } res = self._post('sharedtier/virtualWorkspaces', json=payload) diff --git a/singlestoredb/tests/test_management.py b/singlestoredb/tests/test_management.py index cc31e7236..4b2af1bd3 100755 --- a/singlestoredb/tests/test_management.py +++ b/singlestoredb/tests/test_management.py @@ -25,6 +25,11 @@ def clean_name(s): return re.sub(r'[^\w]', r'-', s).replace('_', '-').lower() +def shared_database_name(s): + """Return a shared database name. Cannot contain special characters except -""" + return re.sub(r'[^\w]', '', s).replace('-', '_').lower() + + @pytest.mark.management class TestCluster(unittest.TestCase): @@ -370,39 +375,31 @@ class TestStarterWorkspace(unittest.TestCase): manager = None starter_workspace = None - starter_workspace_user = { - 'username': 'starter_user', - 'password': None, - } - - @property - def starter_username(self): - """Return the username for the starter workspace user.""" - return self.starter_workspace_user['username'] - - @property - def password(self): - """Return the password for the starter workspace user.""" - return self.starter_workspace_user['password'] @classmethod def setUpClass(cls): cls.manager = s2.manage_workspaces() - shared_tier_regions = [ + shared_tier_regions: NamedList[Region] = [ x for x in cls.manager.shared_tier_regions if 'US' in x.name ] - cls.password = secrets.token_urlsafe(20) + '-x&$' + cls.starter_username = 'starter_user' + cls.password = secrets.token_urlsafe(20) - name = clean_name(secrets.token_urlsafe(20)[:20]) + name = shared_database_name(secrets.token_urlsafe(20)[:20]) + + cls.database_name = f'starter_db_{name}' + + shared_tier_region: Region = random.choice(shared_tier_regions) + + if not shared_tier_region: + raise ValueError('No shared tier regions found') cls.starter_workspace = cls.manager.create_starter_workspace( f'starter-ws-test-{name}', - database_name=f'starter_db_{name}', - workspace_group={ - 'name': f'starter-wg-test-{name}', - 'cell_id': random.choice(shared_tier_regions).id, - }, + database_name=cls.database_name, + provider=shared_tier_region.provider, + region_name=shared_tier_region.region_name, ) cls.starter_workspace.create_user( @@ -470,14 +467,14 @@ def test_connect(self): ) as conn: with conn.cursor() as cur: cur.execute('show databases') - assert 'starter_db' in [x[0] for x in list(cur)] + assert self.database_name in [x[0] for x in list(cur)] # Test missing endpoint workspace = self.manager.get_starter_workspace(self.starter_workspace.id) workspace.endpoint = None with self.assertRaises(s2.ManagementError) as cm: - workspace.connect(user='admin', password=self.password) + workspace.connect(user=self.starter_username, password=self.password) assert 'endpoint' in cm.exception.msg, cm.exception.msg