Skip to content

Commit 54260fd

Browse files
vertex-sdk-botcopybara-github
authored andcommitted
feat: list all Model Garden models
PiperOrigin-RevId: 861346333
1 parent 4636507 commit 54260fd

File tree

3 files changed

+85
-1
lines changed

3 files changed

+85
-1
lines changed

tests/unit/vertexai/model_garden/test_model_garden.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1355,6 +1355,44 @@ def test_list_deployable_models(self, list_publisher_models_mock):
13551355
"google/gemma-2-2b",
13561356
]
13571357

1358+
def test_list_models(self, list_publisher_models_mock):
1359+
"""Tests listing models."""
1360+
aiplatform.init(
1361+
project=_TEST_PROJECT,
1362+
location=_TEST_LOCATION,
1363+
)
1364+
1365+
mg_models = model_garden.list_models()
1366+
list_publisher_models_mock.assert_called_with(
1367+
types.ListPublisherModelsRequest(
1368+
parent="publishers/*",
1369+
list_all_versions=True,
1370+
filter="is_hf_wildcard(false)",
1371+
)
1372+
)
1373+
1374+
assert mg_models == [
1375+
"google/paligemma@001",
1376+
"google/paligemma@002",
1377+
"google/paligemma@003",
1378+
"google/paligemma@004",
1379+
]
1380+
1381+
hf_models = model_garden.list_models(list_hf_models=True)
1382+
list_publisher_models_mock.assert_called_with(
1383+
types.ListPublisherModelsRequest(
1384+
parent="publishers/*",
1385+
list_all_versions=True,
1386+
filter="is_hf_wildcard(true)",
1387+
)
1388+
)
1389+
assert hf_models == [
1390+
"google/gemma-2-2b",
1391+
"google/gemma-2-2b",
1392+
"google/gemma-2-2b",
1393+
"google/gemma-2-2b",
1394+
]
1395+
13581396
def test_batch_prediction_success(self, batch_prediction_mock):
13591397
aiplatform.init(
13601398
project=_TEST_PROJECT,

vertexai/model_garden/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,5 +22,6 @@
2222
OpenModel = _model_garden.OpenModel
2323
PartnerModel = _model_garden.PartnerModel
2424
list_deployable_models = _model_garden.list_deployable_models
25+
list_models = _model_garden.list_models
2526

26-
__all__ = ("OpenModel", "PartnerModel", "list_deployable_models")
27+
__all__ = ("OpenModel", "PartnerModel", "list_deployable_models", "list_models")

vertexai/model_garden/_model_garden.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@ def list_deployable_models(
6262
`{publisher}/{model}@{version}` or Hugging Face model ID in the format
6363
of `{organization}/{model}`.
6464
"""
65+
6566
filter_str = _NATIVE_MODEL_FILTER
6667
if list_hf_models:
6768
filter_str = " AND ".join([_HF_WILDCARD_FILTER, _VERIFIED_DEPLOYMENT_FILTER])
@@ -93,6 +94,50 @@ def list_deployable_models(
9394
return output
9495

9596

97+
def list_models(
98+
*, list_hf_models: bool = False, model_filter: Optional[str] = None
99+
) -> List[str]:
100+
"""Lists the models in Model Garden.
101+
102+
Args:
103+
list_hf_models: Whether to list the Hugging Face models.
104+
model_filter: Optional. A string to filter the models by.
105+
106+
Returns:
107+
The names of the models in Model Garden in the format of
108+
`{publisher}/{model}@{version}` or Hugging Face model ID in the format
109+
of `{organization}/{model}`.
110+
"""
111+
filter_str = _NATIVE_MODEL_FILTER
112+
if list_hf_models:
113+
filter_str = _HF_WILDCARD_FILTER
114+
if model_filter:
115+
filter_str = (
116+
f'{filter_str} AND (model_user_id=~"(?i).*{model_filter}.*" OR'
117+
f' display_name=~"(?i).*{model_filter}.*")'
118+
)
119+
120+
request = types.ListPublisherModelsRequest(
121+
parent="publishers/*",
122+
list_all_versions=True,
123+
filter=filter_str,
124+
)
125+
client = initializer.global_config.create_client(
126+
client_class=_ModelGardenClientWithOverride,
127+
credentials=initializer.global_config.credentials,
128+
location_override="us-central1",
129+
)
130+
response = client.list_publisher_models(request)
131+
output = []
132+
for page in response.pages:
133+
for model in page.publisher_models:
134+
output.append(
135+
re.sub(r"publishers/(hf-|)|models/", "", model.name)
136+
+ ("" if list_hf_models else ("@" + model.version_id))
137+
)
138+
return output
139+
140+
96141
def _is_hugging_face_model(model_name: str) -> bool:
97142
"""Returns whether the model is a Hugging Face model."""
98143
return re.match(r"^(?P<publisher>[^/]+)/(?P<model>[^/@]+)$", model_name)

0 commit comments

Comments
 (0)