diff --git a/docs/release_notes.md b/docs/release_notes.md index e16ac20e9..6bbe03f34 100644 --- a/docs/release_notes.md +++ b/docs/release_notes.md @@ -18,6 +18,7 @@ - [Orchestration] Added new API `OrchestrationTemplateReference#withScope` to support prompt templates with resource-group scope. - [Orchestration] Chat completion calls now can have multiple module configs to support [fallback modules](https://sap.github.io/ai-sdk/docs/java/orchestration/chat-completion). +- [Orchestration] `protected_material_code` is now supported as an output content filtering module. ### 🐛 Fixed Issues diff --git a/orchestration/src/main/java/com/sap/ai/sdk/orchestration/AzureContentFilter.java b/orchestration/src/main/java/com/sap/ai/sdk/orchestration/AzureContentFilter.java index 7620fd54c..cf4c6c73b 100644 --- a/orchestration/src/main/java/com/sap/ai/sdk/orchestration/AzureContentFilter.java +++ b/orchestration/src/main/java/com/sap/ai/sdk/orchestration/AzureContentFilter.java @@ -48,6 +48,9 @@ public class AzureContentFilter implements ContentFilter { /** The filter category for violence content. */ @Nullable AzureFilterThreshold violence; + /** A flag to enable protected material code detection on the output filter. */ + boolean protectedMaterialCode; + /** * A flag to set prompt shield on input filer. * @@ -90,7 +93,11 @@ public AzureContentSafetyInputFilterConfig createInputFilterConfig() { @Override @Nonnull public AzureContentSafetyOutputFilterConfig createOutputFilterConfig() { - if (hate == null && selfHarm == null && sexual == null && violence == null) { + if (hate == null + && selfHarm == null + && sexual == null + && violence == null + && !protectedMaterialCode) { throw new IllegalArgumentException("At least one filter category must be set"); } @@ -101,6 +108,7 @@ public AzureContentSafetyOutputFilterConfig createOutputFilterConfig() { .hate(hate != null ? hate.getAzureThreshold() : null) .selfHarm(selfHarm != null ? selfHarm.getAzureThreshold() : null) .sexual(sexual != null ? 
sexual.getAzureThreshold() : null) - .violence(violence != null ? violence.getAzureThreshold() : null)); + .violence(violence != null ? violence.getAzureThreshold() : null) + .protectedMaterialCode(protectedMaterialCode)); } } diff --git a/orchestration/src/test/java/com/sap/ai/sdk/orchestration/OrchestrationModuleConfigTest.java b/orchestration/src/test/java/com/sap/ai/sdk/orchestration/OrchestrationModuleConfigTest.java index fcbeaed33..a325c446e 100644 --- a/orchestration/src/test/java/com/sap/ai/sdk/orchestration/OrchestrationModuleConfigTest.java +++ b/orchestration/src/test/java/com/sap/ai/sdk/orchestration/OrchestrationModuleConfigTest.java @@ -50,18 +50,21 @@ static class InsideTestClass { void testStackingInputAndOutputFilter() { final var config = new OrchestrationModuleConfig().withLlmConfig(GPT_4O); - final var filter = + final var inputFilter = new AzureContentFilter() .hate(ALLOW_SAFE_LOW_MEDIUM) .selfHarm(ALLOW_SAFE_LOW_MEDIUM) .sexual(ALLOW_SAFE_LOW_MEDIUM) .violence(ALLOW_SAFE_LOW_MEDIUM); + final var outputFilter = inputFilter.protectedMaterialCode(true); - final var configWithInputFirst = config.withInputFiltering(filter).withOutputFiltering(filter); + final var configWithInputFirst = + config.withInputFiltering(inputFilter).withOutputFiltering(outputFilter); assertThat(configWithInputFirst.getFilteringConfig()).isNotNull(); assertThat(configWithInputFirst.getFilteringConfig().getInput()).isNotNull(); - final var configWithOutputFirst = config.withOutputFiltering(filter).withInputFiltering(filter); + final var configWithOutputFirst = + config.withOutputFiltering(outputFilter).withInputFiltering(inputFilter); assertThat(configWithOutputFirst.getFilteringConfig()).isNotNull(); assertThat(configWithOutputFirst.getFilteringConfig().getOutput()).isNotNull(); } diff --git a/orchestration/src/test/java/com/sap/ai/sdk/orchestration/OrchestrationUnitTest.java b/orchestration/src/test/java/com/sap/ai/sdk/orchestration/OrchestrationUnitTest.java index a3dec5a90..a3abaf540 
100644 --- a/orchestration/src/test/java/com/sap/ai/sdk/orchestration/OrchestrationUnitTest.java +++ b/orchestration/src/test/java/com/sap/ai/sdk/orchestration/OrchestrationUnitTest.java @@ -438,18 +438,21 @@ void filteringLoose() throws IOException { .withBodyFile("filteringLooseResponse.json") .withHeader("Content-Type", "application/json"))); - final var azureFilter = + final var azureInputFilter = new AzureContentFilter() .hate(ALLOW_SAFE_LOW_MEDIUM) .selfHarm(ALLOW_SAFE_LOW_MEDIUM) .sexual(ALLOW_SAFE_LOW_MEDIUM) .violence(ALLOW_SAFE_LOW_MEDIUM); + final var azureOutputFilter = azureInputFilter.protectedMaterialCode(true); final var llamaFilter = new LlamaGuardFilter().config(LlamaGuard38b.create().selfHarm(true)); client.chatCompletion( prompt, - config.withInputFiltering(azureFilter, llamaFilter).withOutputFiltering(azureFilter)); + config + .withInputFiltering(azureInputFilter, llamaFilter) + .withOutputFiltering(azureOutputFilter)); // the result is asserted in the verify step below // verify that null fields are absent from the sent request @@ -464,19 +467,20 @@ void filteringLooseStream() throws IOException { post(anyUrl()) .willReturn(aResponse().withBody(res).withHeader("Content-Type", "application/json"))); - final var azureFilter = + final var azureInputFilter = new AzureContentFilter() .hate(ALLOW_SAFE_LOW_MEDIUM) .selfHarm(ALLOW_SAFE_LOW_MEDIUM) .sexual(ALLOW_SAFE_LOW_MEDIUM) .violence(ALLOW_SAFE_LOW_MEDIUM); + final var azureOutputFilter = azureInputFilter.protectedMaterialCode(true); final var llamaFilter = new LlamaGuardFilter().config(LlamaGuard38b.create().selfHarm(true)); OrchestrationModuleConfig myConfig = config - .withInputFiltering(azureFilter, llamaFilter) - .withOutputFiltering(azureFilter) + .withInputFiltering(azureInputFilter, llamaFilter) + .withOutputFiltering(azureOutputFilter) .withOutputFilteringStreamOptions(FilteringStreamOptions.create().overlap(1_000)); Stream result = client.streamChatCompletion(prompt, myConfig); @@ -580,7 
+584,8 @@ void outputFilteringStrict() { .hate(ALLOW_SAFE) .selfHarm(ALLOW_SAFE) .sexual(ALLOW_SAFE) - .violence(ALLOW_SAFE); + .violence(ALLOW_SAFE) + .protectedMaterialCode(true); final var llamaFilter = new LlamaGuardFilter().config(LlamaGuard38b.create().violentCrimes(true)); @@ -601,7 +606,8 @@ void outputFilteringStrict() { "Hate", 6, "SelfHarm", 0, "Sexual", 0, - "Violence", 6), + "Violence", 6, + "ProtectedMaterialCode", false), "llama_guard_3_8b", Map.of("violent_crimes", true))); assertThat(e.getErrorResponse()).isNull(); @@ -612,6 +618,7 @@ void outputFilteringStrict() { assertThat(e.getAzureContentSafetyOutput().getSelfHarm()).isEqualTo(NUMBER_0); assertThat(e.getAzureContentSafetyOutput().getSexual()).isEqualTo(NUMBER_0); assertThat(e.getAzureContentSafetyOutput().getViolence()).isEqualTo(NUMBER_6); + assertThat(e.getAzureContentSafetyOutput().isProtectedMaterialCode()).isFalse(); assertThat(e.getLlamaGuard38b()).isNotNull(); assertThat(e.getLlamaGuard38b().isViolentCrimes()).isTrue(); diff --git a/orchestration/src/test/resources/__files/outputFilteringStrict.json b/orchestration/src/test/resources/__files/outputFilteringStrict.json index e34f2bce6..9439fe893 100644 --- a/orchestration/src/test/resources/__files/outputFilteringStrict.json +++ b/orchestration/src/test/resources/__files/outputFilteringStrict.json @@ -21,7 +21,8 @@ "Hate": 6, "SelfHarm": 0, "Sexual": 0, - "Violence": 6 + "Violence": 6, + "ProtectedMaterialCode": false }, "llama_guard_3_8b": { "violent_crimes": true diff --git a/orchestration/src/test/resources/filteringLooseRequestStream.json b/orchestration/src/test/resources/filteringLooseRequestStream.json index f359c5659..68c9bfe8d 100644 --- a/orchestration/src/test/resources/filteringLooseRequestStream.json +++ b/orchestration/src/test/resources/filteringLooseRequestStream.json @@ -56,7 +56,7 @@ "self_harm": 4, "sexual": 4, "violence": 4, - "protected_material_code" : false + "protected_material_code" : true } } ], diff --git 
a/sample-code/spring-app/src/main/java/com/sap/ai/sdk/app/services/OrchestrationService.java b/sample-code/spring-app/src/main/java/com/sap/ai/sdk/app/services/OrchestrationService.java index 0ecc9b3cc..ed43c9258 100644 --- a/sample-code/spring-app/src/main/java/com/sap/ai/sdk/app/services/OrchestrationService.java +++ b/sample-code/spring-app/src/main/java/com/sap/ai/sdk/app/services/OrchestrationService.java @@ -204,6 +204,7 @@ public OrchestrationChatResponse inputFiltering(@Nonnull final AzureFilterThresh */ @Nonnull public OrchestrationChatResponse outputFiltering(@Nonnull final AzureFilterThreshold policy) { + val isProtected = true; // enforce protected material code filtering val systemMessage = Message.system("Give three paraphrases for the following sentence"); // Reliably triggering the content filter of models fine-tuned for ethical compliance @@ -212,7 +213,12 @@ public OrchestrationChatResponse outputFiltering(@Nonnull final AzureFilterThres new OrchestrationPrompt("'We shall spill blood tonight', said the operation in-charge.") .messageHistory(List.of(systemMessage)); val filterConfig = - new AzureContentFilter().hate(policy).selfHarm(policy).sexual(policy).violence(policy); + new AzureContentFilter() + .hate(policy) + .selfHarm(policy) + .sexual(policy) + .violence(policy) + .protectedMaterialCode(isProtected); val configWithFilter = config.withOutputFiltering(filterConfig); return client.chatCompletion(prompt, configWithFilter); diff --git a/sample-code/spring-app/src/test/java/com/sap/ai/sdk/app/controllers/OrchestrationTest.java b/sample-code/spring-app/src/test/java/com/sap/ai/sdk/app/controllers/OrchestrationTest.java index 5051e091c..8108728d3 100644 --- a/sample-code/spring-app/src/test/java/com/sap/ai/sdk/app/controllers/OrchestrationTest.java +++ b/sample-code/spring-app/src/test/java/com/sap/ai/sdk/app/controllers/OrchestrationTest.java @@ -286,6 +286,7 @@ void testOutputFilteringStrict() { 
assertThat(actualAzureContentSafety.getSelfHarm()).isEqualTo(NUMBER_0); assertThat(actualAzureContentSafety.getSexual()).isEqualTo(NUMBER_0); assertThat(actualAzureContentSafety.getHate()).isEqualTo(NUMBER_0); + assertThat(actualAzureContentSafety.isProtectedMaterialCode()).isFalse(); }); } @@ -299,7 +300,8 @@ void testOutputFilteringLenient() { assertThat(response.getContent()).isNotEmpty(); var filterResult = response.getOriginalResponse().getIntermediateResults().getOutputFiltering(); - assertThat(filterResult.getMessage()).containsPattern("Choice 0: Filtering was skipped."); + assertThat(filterResult.getMessage()) + .containsPattern("Choice 0: Filtering passed successfully."); } @Test