Skip to content

Commit 79da831

Browse files
darshanmehta17copybara-github
authored andcommitted
feat: RAG - Add Serverless and Spanner modes in preview.
PiperOrigin-RevId: 861235227
1 parent 044c3fa commit 79da831

File tree

5 files changed

+392
-7
lines changed

5 files changed

+392
-7
lines changed

tests/unit/vertex_rag/test_rag_constants_preview.py

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,10 +69,12 @@
6969
RankService,
7070
Ranking,
7171
Scaled,
72+
Serverless,
7273
SharePointSource,
7374
SharePointSources,
7475
SlackChannel,
7576
SlackChannelsSource,
77+
Spanner,
7678
Unprovisioned,
7779
VertexAiSearchConfig,
7880
VertexFeatureStore,
@@ -584,6 +586,34 @@
584586
TEST_RAG_ENGINE_CONFIG_RESOURCE_NAME = (
585587
f"projects/{TEST_PROJECT_NUMBER}/locations/{TEST_REGION}/ragEngineConfig"
586588
)
589+
TEST_RAG_ENGINE_CONFIG_SERVERLESS = RagEngineConfig(
590+
name=TEST_RAG_ENGINE_CONFIG_RESOURCE_NAME,
591+
rag_managed_db_config=RagManagedDbConfig(mode=Serverless()),
592+
)
593+
TEST_RAG_ENGINE_CONFIG_SPANNER_BASIC = RagEngineConfig(
594+
name=TEST_RAG_ENGINE_CONFIG_RESOURCE_NAME,
595+
rag_managed_db_config=RagManagedDbConfig(
596+
mode=Spanner(tier=Basic()),
597+
),
598+
)
599+
TEST_RAG_ENGINE_CONFIG_SPANNER_SCALED = RagEngineConfig(
600+
name=TEST_RAG_ENGINE_CONFIG_RESOURCE_NAME,
601+
rag_managed_db_config=RagManagedDbConfig(
602+
mode=Spanner(tier=Scaled()),
603+
),
604+
)
605+
TEST_RAG_ENGINE_CONFIG_SPANNER_UNPROVISIONED = RagEngineConfig(
606+
name=TEST_RAG_ENGINE_CONFIG_RESOURCE_NAME,
607+
rag_managed_db_config=RagManagedDbConfig(
608+
mode=Spanner(tier=Unprovisioned()),
609+
),
610+
)
611+
TEST_RAG_ENGINE_CONFIG_SPANNER_NO_TIER = RagEngineConfig(
612+
name=TEST_RAG_ENGINE_CONFIG_RESOURCE_NAME,
613+
rag_managed_db_config=RagManagedDbConfig(
614+
mode=Spanner(),
615+
),
616+
)
587617
TEST_RAG_ENGINE_CONFIG_BASIC = RagEngineConfig(
588618
name=TEST_RAG_ENGINE_CONFIG_RESOURCE_NAME,
589619
rag_managed_db_config=RagManagedDbConfig(tier=Basic()),
@@ -604,6 +634,39 @@
604634
name=TEST_RAG_ENGINE_CONFIG_RESOURCE_NAME,
605635
rag_managed_db_config=None,
606636
)
637+
TEST_BAD_RAG_ENGINE_CONFIG_WITH_MODE_AND_TIER = RagEngineConfig(
638+
name=TEST_RAG_ENGINE_CONFIG_RESOURCE_NAME,
639+
rag_managed_db_config=RagManagedDbConfig(
640+
mode=Spanner(tier=Basic()),
641+
tier=Scaled(),
642+
),
643+
)
644+
TEST_GAPIC_RAG_ENGINE_CONFIG_SERVERLESS = GapicRagEngineConfig(
645+
name=TEST_RAG_ENGINE_CONFIG_RESOURCE_NAME,
646+
rag_managed_db_config=GapicRagManagedDbConfig(
647+
serverless=GapicRagManagedDbConfig.Serverless()
648+
),
649+
)
650+
TEST_GAPIC_RAG_ENGINE_CONFIG_SPANNER_BASIC = GapicRagEngineConfig(
651+
name=TEST_RAG_ENGINE_CONFIG_RESOURCE_NAME,
652+
rag_managed_db_config=GapicRagManagedDbConfig(
653+
spanner=GapicRagManagedDbConfig.Spanner(basic=GapicRagManagedDbConfig.Basic())
654+
),
655+
)
656+
TEST_GAPIC_RAG_ENGINE_CONFIG_SPANNER_SCALED = GapicRagEngineConfig(
657+
name=TEST_RAG_ENGINE_CONFIG_RESOURCE_NAME,
658+
rag_managed_db_config=GapicRagManagedDbConfig(
659+
spanner=GapicRagManagedDbConfig.Spanner(scaled=GapicRagManagedDbConfig.Scaled())
660+
),
661+
)
662+
TEST_GAPIC_RAG_ENGINE_CONFIG_SPANNER_UNPROVISIONED = GapicRagEngineConfig(
663+
name=TEST_RAG_ENGINE_CONFIG_RESOURCE_NAME,
664+
rag_managed_db_config=GapicRagManagedDbConfig(
665+
spanner=GapicRagManagedDbConfig.Spanner(
666+
unprovisioned=GapicRagManagedDbConfig.Unprovisioned()
667+
)
668+
),
669+
)
607670
TEST_GAPIC_RAG_ENGINE_CONFIG_BASIC = GapicRagEngineConfig(
608671
name=TEST_RAG_ENGINE_CONFIG_RESOURCE_NAME,
609672
rag_managed_db_config=GapicRagManagedDbConfig(

tests/unit/vertex_rag/test_rag_data_preview.py

Lines changed: 240 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -492,6 +492,91 @@ def update_rag_engine_config_enterprise_mock():
492492
yield update_rag_engine_config_enterprise_mock
493493

494494

495+
@pytest.fixture()
496+
def update_rag_engine_config_serverless_mock():
497+
with mock.patch.object(
498+
VertexRagDataServiceClient,
499+
"update_rag_engine_config",
500+
) as update_rag_engine_config_serverless_mock:
501+
update_rag_engine_config_lro_mock = mock.Mock(ga_operation.Operation)
502+
update_rag_engine_config_lro_mock.done.return_value = True
503+
update_rag_engine_config_lro_mock.result.return_value = (
504+
test_rag_constants_preview.TEST_GAPIC_RAG_ENGINE_CONFIG_SERVERLESS
505+
)
506+
update_rag_engine_config_serverless_mock.return_value = (
507+
update_rag_engine_config_lro_mock
508+
)
509+
yield update_rag_engine_config_serverless_mock
510+
511+
512+
@pytest.fixture()
513+
def update_rag_engine_config_spanner_basic_mock():
514+
with mock.patch.object(
515+
VertexRagDataServiceClient,
516+
"update_rag_engine_config",
517+
) as update_rag_engine_config_spanner_basic_mock:
518+
update_rag_engine_config_lro_mock = mock.Mock(ga_operation.Operation)
519+
update_rag_engine_config_lro_mock.done.return_value = True
520+
update_rag_engine_config_lro_mock.result.return_value = (
521+
test_rag_constants_preview.TEST_GAPIC_RAG_ENGINE_CONFIG_SPANNER_BASIC
522+
)
523+
update_rag_engine_config_spanner_basic_mock.return_value = (
524+
update_rag_engine_config_lro_mock
525+
)
526+
yield update_rag_engine_config_spanner_basic_mock
527+
528+
529+
@pytest.fixture()
530+
def update_rag_engine_config_spanner_scaled_mock():
531+
with mock.patch.object(
532+
VertexRagDataServiceClient,
533+
"update_rag_engine_config",
534+
) as update_rag_engine_config_spanner_scaled_mock:
535+
update_rag_engine_config_lro_mock = mock.Mock(ga_operation.Operation)
536+
update_rag_engine_config_lro_mock.done.return_value = True
537+
update_rag_engine_config_lro_mock.result.return_value = (
538+
test_rag_constants_preview.TEST_GAPIC_RAG_ENGINE_CONFIG_SPANNER_SCALED
539+
)
540+
update_rag_engine_config_spanner_scaled_mock.return_value = (
541+
update_rag_engine_config_lro_mock
542+
)
543+
yield update_rag_engine_config_spanner_scaled_mock
544+
545+
546+
@pytest.fixture()
547+
def update_rag_engine_config_spanner_no_tier_mock():
548+
with mock.patch.object(
549+
VertexRagDataServiceClient,
550+
"update_rag_engine_config",
551+
) as update_rag_engine_config_spanner_no_tier_mock:
552+
update_rag_engine_config_lro_mock = mock.Mock(ga_operation.Operation)
553+
update_rag_engine_config_lro_mock.done.return_value = True
554+
update_rag_engine_config_lro_mock.result.return_value = (
555+
test_rag_constants_preview.TEST_GAPIC_RAG_ENGINE_CONFIG_SPANNER_SCALED
556+
)
557+
update_rag_engine_config_spanner_no_tier_mock.return_value = (
558+
update_rag_engine_config_lro_mock
559+
)
560+
yield update_rag_engine_config_spanner_no_tier_mock
561+
562+
563+
@pytest.fixture()
564+
def update_rag_engine_config_spanner_unprovisioned_mock():
565+
with mock.patch.object(
566+
VertexRagDataServiceClient,
567+
"update_rag_engine_config",
568+
) as update_rag_engine_config_spanner_unprovisioned_mock:
569+
update_rag_engine_config_lro_mock = mock.Mock(ga_operation.Operation)
570+
update_rag_engine_config_lro_mock.done.return_value = True
571+
update_rag_engine_config_lro_mock.result.return_value = (
572+
test_rag_constants_preview.TEST_GAPIC_RAG_ENGINE_CONFIG_SPANNER_UNPROVISIONED
573+
)
574+
update_rag_engine_config_spanner_unprovisioned_mock.return_value = (
575+
update_rag_engine_config_lro_mock
576+
)
577+
yield update_rag_engine_config_spanner_unprovisioned_mock
578+
579+
495580
@pytest.fixture()
496581
def update_rag_engine_config_scaled_mock():
497582
with mock.patch.object(
@@ -584,6 +669,54 @@ def get_rag_engine_enterprise_config_mock():
584669
yield get_rag_engine_enterprise_config_mock
585670

586671

672+
@pytest.fixture()
673+
def get_rag_engine_spanner_basic_config_mock():
674+
with mock.patch.object(
675+
VertexRagDataServiceClient,
676+
"get_rag_engine_config",
677+
) as get_rag_engine_spanner_basic_config_mock:
678+
get_rag_engine_spanner_basic_config_mock.return_value = (
679+
test_rag_constants_preview.TEST_GAPIC_RAG_ENGINE_CONFIG_SPANNER_BASIC
680+
)
681+
yield get_rag_engine_spanner_basic_config_mock
682+
683+
684+
@pytest.fixture()
685+
def get_rag_engine_spanner_scaled_config_mock():
686+
with mock.patch.object(
687+
VertexRagDataServiceClient,
688+
"get_rag_engine_config",
689+
) as get_rag_engine_spanner_scaled_config_mock:
690+
get_rag_engine_spanner_scaled_config_mock.return_value = (
691+
test_rag_constants_preview.TEST_GAPIC_RAG_ENGINE_CONFIG_SPANNER_SCALED
692+
)
693+
yield get_rag_engine_spanner_scaled_config_mock
694+
695+
696+
@pytest.fixture()
697+
def get_rag_engine_spanner_unprovisioned_config_mock():
698+
with mock.patch.object(
699+
VertexRagDataServiceClient,
700+
"get_rag_engine_config",
701+
) as get_rag_engine_spanner_unprovisioned_config_mock:
702+
get_rag_engine_spanner_unprovisioned_config_mock.return_value = (
703+
test_rag_constants_preview.TEST_GAPIC_RAG_ENGINE_CONFIG_SPANNER_UNPROVISIONED
704+
)
705+
yield get_rag_engine_spanner_unprovisioned_config_mock
706+
707+
708+
@pytest.fixture()
709+
def get_rag_engine_serverless_config_mock():
710+
with mock.patch.object(
711+
VertexRagDataServiceClient,
712+
"get_rag_engine_config",
713+
) as get_rag_engine_serverless_config_mock:
714+
get_rag_engine_serverless_config_mock.return_value = (
715+
test_rag_constants_preview.TEST_GAPIC_RAG_ENGINE_CONFIG_SERVERLESS
716+
)
717+
yield get_rag_engine_serverless_config_mock
718+
719+
587720
@pytest.fixture()
588721
def get_rag_engine_config_mock_exception():
589722
with mock.patch.object(
@@ -1765,6 +1898,73 @@ def test_update_rag_engine_config_unprovisioned_success(
17651898
test_rag_constants_preview.TEST_RAG_ENGINE_CONFIG_UNPROVISIONED,
17661899
)
17671900

1901+
def test_update_rag_engine_config_spanner_basic_success(
1902+
self, update_rag_engine_config_spanner_basic_mock
1903+
):
1904+
rag_config = rag.update_rag_engine_config(
1905+
rag_engine_config=test_rag_constants_preview.TEST_RAG_ENGINE_CONFIG_SPANNER_BASIC,
1906+
)
1907+
assert update_rag_engine_config_spanner_basic_mock.call_count == 1
1908+
rag_engine_config_eq(
1909+
rag_config,
1910+
test_rag_constants_preview.TEST_RAG_ENGINE_CONFIG_SPANNER_BASIC,
1911+
)
1912+
1913+
def test_update_rag_engine_config_spanner_scaled_success(
1914+
self, update_rag_engine_config_spanner_scaled_mock
1915+
):
1916+
rag_config = rag.update_rag_engine_config(
1917+
rag_engine_config=test_rag_constants_preview.TEST_RAG_ENGINE_CONFIG_SPANNER_SCALED,
1918+
)
1919+
assert update_rag_engine_config_spanner_scaled_mock.call_count == 1
1920+
rag_engine_config_eq(
1921+
rag_config,
1922+
test_rag_constants_preview.TEST_RAG_ENGINE_CONFIG_SPANNER_SCALED,
1923+
)
1924+
1925+
def test_update_rag_engine_config_spanner_unprovisioned_success(
1926+
self, update_rag_engine_config_spanner_unprovisioned_mock
1927+
):
1928+
rag_config = rag.update_rag_engine_config(
1929+
rag_engine_config=test_rag_constants_preview.TEST_RAG_ENGINE_CONFIG_SPANNER_UNPROVISIONED,
1930+
)
1931+
assert update_rag_engine_config_spanner_unprovisioned_mock.call_count == 1
1932+
rag_engine_config_eq(
1933+
rag_config,
1934+
test_rag_constants_preview.TEST_RAG_ENGINE_CONFIG_SPANNER_UNPROVISIONED,
1935+
)
1936+
1937+
def test_update_rag_engine_config_spanner_no_tier_success(
1938+
self, update_rag_engine_config_spanner_no_tier_mock
1939+
):
1940+
rag_config = rag.update_rag_engine_config(
1941+
rag_engine_config=test_rag_constants_preview.TEST_RAG_ENGINE_CONFIG_SPANNER_NO_TIER,
1942+
)
1943+
assert update_rag_engine_config_spanner_no_tier_mock.call_count == 1
1944+
rag_engine_config_eq(
1945+
rag_config,
1946+
test_rag_constants_preview.TEST_RAG_ENGINE_CONFIG_SPANNER_SCALED,
1947+
)
1948+
1949+
def test_update_rag_engine_config_serverless_success(
1950+
self, update_rag_engine_config_serverless_mock
1951+
):
1952+
rag_config = rag.update_rag_engine_config(
1953+
rag_engine_config=test_rag_constants_preview.TEST_RAG_ENGINE_CONFIG_SERVERLESS,
1954+
)
1955+
assert update_rag_engine_config_serverless_mock.call_count == 1
1956+
rag_engine_config_eq(
1957+
rag_config,
1958+
test_rag_constants_preview.TEST_RAG_ENGINE_CONFIG_SERVERLESS,
1959+
)
1960+
1961+
def test_update_rag_engine_config_with_mode_and_tier_failure(self):
1962+
with pytest.raises(ValueError) as e:
1963+
rag.update_rag_engine_config(
1964+
rag_engine_config=test_rag_constants_preview.TEST_BAD_RAG_ENGINE_CONFIG_WITH_MODE_AND_TIER,
1965+
)
1966+
e.match("mode and tier both cannot be set at the same time")
1967+
17681968
@pytest.mark.usefixtures("update_rag_engine_config_mock_exception")
17691969
def test_update_rag_engine_config_failure(self):
17701970
with pytest.raises(RuntimeError) as e:
@@ -1786,6 +1986,46 @@ def test_update_rag_engine_config_bad_input(
17861986
test_rag_constants_preview.TEST_RAG_ENGINE_CONFIG_BASIC,
17871987
)
17881988

1989+
@pytest.mark.usefixtures("get_rag_engine_serverless_config_mock")
1990+
def test_get_rag_engine_config_serverless_success(self):
1991+
rag_config = rag.get_rag_engine_config(
1992+
name=test_rag_constants_preview.TEST_RAG_ENGINE_CONFIG_RESOURCE_NAME,
1993+
)
1994+
rag_engine_config_eq(
1995+
rag_config,
1996+
test_rag_constants_preview.TEST_RAG_ENGINE_CONFIG_SERVERLESS,
1997+
)
1998+
1999+
@pytest.mark.usefixtures("get_rag_engine_spanner_basic_config_mock")
2000+
def test_get_rag_engine_config_spanner_basic_success(self):
2001+
rag_config = rag.get_rag_engine_config(
2002+
name=test_rag_constants_preview.TEST_RAG_ENGINE_CONFIG_RESOURCE_NAME,
2003+
)
2004+
rag_engine_config_eq(
2005+
rag_config,
2006+
test_rag_constants_preview.TEST_RAG_ENGINE_CONFIG_SPANNER_BASIC,
2007+
)
2008+
2009+
@pytest.mark.usefixtures("get_rag_engine_spanner_scaled_config_mock")
2010+
def test_get_rag_engine_config_spanner_scaled_success(self):
2011+
rag_config = rag.get_rag_engine_config(
2012+
name=test_rag_constants_preview.TEST_RAG_ENGINE_CONFIG_RESOURCE_NAME,
2013+
)
2014+
rag_engine_config_eq(
2015+
rag_config,
2016+
test_rag_constants_preview.TEST_RAG_ENGINE_CONFIG_SPANNER_SCALED,
2017+
)
2018+
2019+
@pytest.mark.usefixtures("get_rag_engine_spanner_unprovisioned_config_mock")
2020+
def test_get_rag_engine_config_spanner_unprovisioned_success(self):
2021+
rag_config = rag.get_rag_engine_config(
2022+
name=test_rag_constants_preview.TEST_RAG_ENGINE_CONFIG_RESOURCE_NAME,
2023+
)
2024+
rag_engine_config_eq(
2025+
rag_config,
2026+
test_rag_constants_preview.TEST_RAG_ENGINE_CONFIG_SPANNER_UNPROVISIONED,
2027+
)
2028+
17892029
@pytest.mark.usefixtures("get_rag_engine_basic_config_mock")
17902030
def test_get_rag_engine_config_success(self):
17912031
rag_config = rag.get_rag_engine_config(

vertexai/preview/rag/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,10 +66,12 @@
6666
RankService,
6767
Ranking,
6868
Scaled,
69+
Serverless,
6970
SharePointSource,
7071
SharePointSources,
7172
SlackChannel,
7273
SlackChannelsSource,
74+
Spanner,
7375
TransformationConfig,
7476
Unprovisioned,
7577
VertexAiSearchConfig,
@@ -111,10 +113,12 @@
111113
"RankService",
112114
"Retrieval",
113115
"Scaled",
116+
"Serverless",
114117
"SharePointSource",
115118
"SharePointSources",
116119
"SlackChannel",
117120
"SlackChannelsSource",
121+
"Spanner",
118122
"TransformationConfig",
119123
"Unprovisioned",
120124
"VertexAiSearchConfig",

0 commit comments

Comments
 (0)