Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 26 additions & 0 deletions tests/unit/vertexai/genai/replays/test_skills_retrieve.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
"""Tests the skills.retrieve() method against the autopush endpoint."""

from tests.unit.vertexai.genai.replays import pytest_helper
from vertexai._genai import types

pytestmark = pytest_helper.setup(
file=__file__,
globals_for_file=globals(),
)


def test_retrieve_skills(client):
# Target the prod endpoint for the Skill Registry API
client._api_client._http_options.base_url = (
"https://us-central1-aiplatform.googleapis.com"
)

response = client.skills.retrieve(query="stubby", config={"top_k": 2})

assert isinstance(response, types.RetrieveSkillsResponse)
assert response.retrieved_skills is not None

for retrieved in response.retrieved_skills:
assert isinstance(retrieved, types.RetrievedSkill)
assert retrieved.skill_name is not None
assert retrieved.description is not None
110 changes: 106 additions & 4 deletions tests/unit/vertexai/genai/test_genai_skills.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# //third_party/py/google/cloud/aiplatform/tests/unit/vertexai/genai/test_genai_skills.py
import json
from unittest import mock
import google.auth.credentials
from vertexai import _genai as genai
from vertexai._genai import client as vertexai_client
from google.genai import types as genai_types
Expand All @@ -9,31 +10,42 @@

@pytest.fixture
def skills_client():
creds = mock.MagicMock()
creds = mock.create_autospec(google.auth.credentials.Credentials, instance=True)
creds.token = "test_token"
client = vertexai_client.Client(
project="test-project", location="test-location", credentials=creds
)
return client.skills


@pytest.fixture
def async_skills_client():
creds = mock.create_autospec(google.auth.credentials.Credentials, instance=True)
creds.token = "test_token"
client = vertexai_client.Client(
project="test-project", location="test-location", credentials=creds
)
return client.aio.skills


class TestGenaiSkills:
mock_get_skill_response = {
"name": "projects/test-project/locations/test-location/skills/test-skill",
"displayName": "My Test Skill",
}

def test_get_skill(self, skills_client):
"""Tests the get_skill method."""
with mock.patch.object(skills_client._api_client, "request") as request_mock:
with mock.patch.object(
skills_client._api_client, "request", autospec=True
) as request_mock:
request_mock.return_value = genai_types.HttpResponse(
body=json.dumps(self.mock_get_skill_response)
)
skill_name = (
"projects/test-project/locations/test-location/skills/test-skill"
)
skill = skills_client.get(name=skill_name)
request_mock.assert_called_with(
request_mock.assert_called_once_with(
"get",
skill_name,
{"_url": {"name": skill_name}},
Expand All @@ -42,3 +54,93 @@ def test_get_skill(self, skills_client):
assert isinstance(skill, genai.types.Skill)
assert skill.name == skill_name
assert skill.display_name == "My Test Skill"

def test_retrieve_skills_response(self, skills_client):
mock_retrieve_response = {
"retrievedSkills": [
{
"skillName": (
"projects/test-project/locations/test-location/skills/skill-1"
),
"description": "Skill 1 Description",
},
{
"skillName": (
"projects/test-project/locations/test-location/skills/skill-2"
),
"description": "Skill 2 Description",
},
]
}

with mock.patch.object(
skills_client._api_client, "request", autospec=True
) as request_mock:
request_mock.return_value = genai_types.HttpResponse(
body=json.dumps(mock_retrieve_response)
)

response = skills_client.retrieve(query="test query", config={"top_k": 5})

assert isinstance(response, genai.types.RetrieveSkillsResponse)
assert len(response.retrieved_skills) == 2
assert response.retrieved_skills[0].skill_name == (
"projects/test-project/locations/test-location/skills/skill-1"
)
assert response.retrieved_skills[0].description == "Skill 1 Description"

def test_retrieve_skills_request_params(self, skills_client):
mock_retrieve_response = {"retrievedSkills": []}

with mock.patch.object(
skills_client._api_client, "request", autospec=True
) as request_mock:
request_mock.return_value = genai_types.HttpResponse(
body=json.dumps(mock_retrieve_response)
)

skills_client.retrieve(query="test query", config={"top_k": 5})

request_mock.assert_called_once_with(
"get",
"skills:retrieve?query=test+query&topK=5",
{"_query": {"query": "test query", "topK": 5}},
None,
)

@pytest.mark.asyncio
async def test_retrieve_skills_async(self, async_skills_client):
mock_retrieve_response = {
"retrievedSkills": [
{
"skillName": (
"projects/test-project/locations/test-location/skills/skill-1"
),
"description": "Skill 1 Description",
}
]
}

with mock.patch.object(
async_skills_client._api_client, "async_request", autospec=True
) as request_mock:
request_mock.return_value = genai_types.HttpResponse(
body=json.dumps(mock_retrieve_response)
)

response = await async_skills_client.retrieve(
query="test query", config={"top_k": 1}
)

assert isinstance(response, genai.types.RetrieveSkillsResponse)
assert len(response.retrieved_skills) == 1
assert response.retrieved_skills[0].skill_name == (
"projects/test-project/locations/test-location/skills/skill-1"
)

request_mock.assert_called_once_with(
"get",
"skills:retrieve?query=test+query&topK=1",
{"_query": {"query": "test query", "topK": 1}},
None,
)
170 changes: 170 additions & 0 deletions vertexai/_genai/skills.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,36 @@ def _GetSkillRequestParameters_to_vertex(
return to_object


def _RetrieveSkillsConfig_to_vertex(
from_object: Union[dict[str, Any], object],
parent_object: Optional[dict[str, Any]] = None,
) -> dict[str, Any]:
to_object: dict[str, Any] = {}

if getv(from_object, ["top_k"]) is not None:
setv(parent_object, ["_query", "topK"], getv(from_object, ["top_k"]))

return to_object


def _RetrieveSkillsRequestParameters_to_vertex(
from_object: Union[dict[str, Any], object],
parent_object: Optional[dict[str, Any]] = None,
) -> dict[str, Any]:
to_object: dict[str, Any] = {}
if getv(from_object, ["query"]) is not None:
setv(to_object, ["_query", "query"], getv(from_object, ["query"]))

if getv(from_object, ["config"]) is not None:
setv(
to_object,
["config"],
_RetrieveSkillsConfig_to_vertex(getv(from_object, ["config"]), to_object),
)

return to_object


class Skills(_api_module.BaseModule):
"""Class for managing Skills in the Skill Registry."""

Expand Down Expand Up @@ -116,6 +146,75 @@ def get(
self._api_client._verify_response(return_value)
return return_value

def retrieve(
self, *, query: str, config: Optional[types.RetrieveSkillsConfigOrDict] = None
) -> types.RetrieveSkillsResponse:
"""
Retrieves skills semantically matched to a query.
"""

parameter_model = types._RetrieveSkillsRequestParameters(
query=query,
config=config,
)

request_url_dict: Optional[dict[str, str]]
if not self._api_client.vertexai:
raise ValueError(
"This method is only supported in the Gemini Enterprise Agent Platform (previously known as Vertex AI) client."
)
else:
request_dict = _RetrieveSkillsRequestParameters_to_vertex(parameter_model)
request_url_dict = request_dict.get("_url")
if request_url_dict:
path = "skills:retrieve".format_map(request_url_dict)
else:
path = "skills:retrieve"

query_params = request_dict.get("_query")
if query_params:
path = f"{path}?{urlencode(query_params)}"
# TODO: remove the hack that pops config.
request_dict.pop("config", None)

http_options: Optional[types.HttpOptions] = None
if (
parameter_model.config is not None
and parameter_model.config.http_options is not None
):
http_options = parameter_model.config.http_options

request_dict = _common.convert_to_dict(request_dict)
request_dict = _common.encode_unserializable_types(request_dict)

response = self._api_client.request("get", path, request_dict, http_options)

response_dict = {} if not response.body else json.loads(response.body)

return_value = types.RetrieveSkillsResponse._from_response(
response=response_dict,
kwargs=(
{
"config": {
"response_schema": getattr(
parameter_model.config, "response_schema", None
),
"response_json_schema": getattr(
parameter_model.config, "response_json_schema", None
),
"include_all_fields": getattr(
parameter_model.config, "include_all_fields", None
),
}
}
if getattr(parameter_model, "config", None)
else {}
),
)

self._api_client._verify_response(return_value)
return return_value


class AsyncSkills(_api_module.BaseModule):
"""Class for managing Skills in the Skill Registry."""
Expand Down Expand Up @@ -190,3 +289,74 @@ async def get(

self._api_client._verify_response(return_value)
return return_value

async def retrieve(
self, *, query: str, config: Optional[types.RetrieveSkillsConfigOrDict] = None
) -> types.RetrieveSkillsResponse:
"""
Retrieves skills semantically matched to a query.
"""

parameter_model = types._RetrieveSkillsRequestParameters(
query=query,
config=config,
)

request_url_dict: Optional[dict[str, str]]
if not self._api_client.vertexai:
raise ValueError(
"This method is only supported in the Gemini Enterprise Agent Platform (previously known as Vertex AI) client."
)
else:
request_dict = _RetrieveSkillsRequestParameters_to_vertex(parameter_model)
request_url_dict = request_dict.get("_url")
if request_url_dict:
path = "skills:retrieve".format_map(request_url_dict)
else:
path = "skills:retrieve"

query_params = request_dict.get("_query")
if query_params:
path = f"{path}?{urlencode(query_params)}"
# TODO: remove the hack that pops config.
request_dict.pop("config", None)

http_options: Optional[types.HttpOptions] = None
if (
parameter_model.config is not None
and parameter_model.config.http_options is not None
):
http_options = parameter_model.config.http_options

request_dict = _common.convert_to_dict(request_dict)
request_dict = _common.encode_unserializable_types(request_dict)

response = await self._api_client.async_request(
"get", path, request_dict, http_options
)

response_dict = {} if not response.body else json.loads(response.body)

return_value = types.RetrieveSkillsResponse._from_response(
response=response_dict,
kwargs=(
{
"config": {
"response_schema": getattr(
parameter_model.config, "response_schema", None
),
"response_json_schema": getattr(
parameter_model.config, "response_json_schema", None
),
"include_all_fields": getattr(
parameter_model.config, "include_all_fields", None
),
}
}
if getattr(parameter_model, "config", None)
else {}
),
)

self._api_client._verify_response(return_value)
return return_value
Loading
Loading