diff --git a/mock_tests/test_collection.py b/mock_tests/test_collection.py index 17b9dacc2..8fd489b60 100644 --- a/mock_tests/test_collection.py +++ b/mock_tests/test_collection.py @@ -406,3 +406,24 @@ def test_grpc_forbidden_exception(forbidden: weaviate.collections.Collection) -> with pytest.raises(weaviate.exceptions.InsufficientPermissionsError): forbidden.data.insert_many([{"name": "test"}]) + + +def test_collection_exists(weaviate_mock: HTTPServer) -> None: + non_existing = "NonExistingCollection" + erroring = "ErroringCollection" + weaviate_mock.expect_request(f"/v1/schema/{non_existing}").respond_with_json( + response_json={"error": [{"message": "collection not found"}]}, status=404 + ) + weaviate_mock.expect_request(f"/v1/schema/{erroring}").respond_with_json( + response_json={"error": [{"message": "this is an error"}]}, status=500 + ) + + with weaviate.connect_to_local( + port=MOCK_PORT, host=MOCK_IP, grpc_port=MOCK_PORT_GRPC, skip_init_checks=True + ) as client: + assert not client.collections.exists(non_existing) + with pytest.raises(weaviate.exceptions.WeaviateInvalidInputError): + client.collections.exists("") + with pytest.raises(weaviate.exceptions.UnexpectedStatusCodeError) as e: + client.collections.exists(erroring) + assert e.value.status_code == 500 diff --git a/weaviate/collections/collection/sync.py b/weaviate/collections/collection/sync.py index 88f728b30..c8eaf7be9 100644 --- a/weaviate/collections/collection/sync.py +++ b/weaviate/collections/collection/sync.py @@ -29,6 +29,7 @@ from weaviate.collections.query import _QueryCollection from weaviate.collections.tenants import _Tenants from weaviate.connect.v4 import ConnectionSync +from weaviate.exceptions import UnexpectedStatusCodeError from weaviate.types import UUID from .base import _CollectionBase @@ -204,8 +205,10 @@ def exists(self) -> bool: try: self.config.get(simple=True) return True - except Exception: - return False + except UnexpectedStatusCodeError as e: + if e.status_code == 404: + return False + raise e def shards(self) -> List[Shard]: """Get the statuses of all the shards of this collection. diff --git a/weaviate/collections/collections/executor.py b/weaviate/collections/collections/executor.py index 69b2baa64..8497cdf51 100644 --- a/weaviate/collections/collections/executor.py +++ b/weaviate/collections/collections/executor.py @@ -318,6 +318,8 @@ def exists(self, name: str) -> executor.Result[bool]: """ _validate_input([_ValidateArgument(expected=[str], name="name", value=name)]) path = f"/schema/{_capitalize_first_letter(name)}" + if name == "": + raise WeaviateInvalidInputError("Collection name cannot be an empty string.") def resp(res: Response) -> bool: return res.status_code == 200