Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/release_notes.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

- [Orchestration] Added new API `OrchestrationTemplateReference#withScope` to support prompt templates with resource-group scope.
- [Orchestration] Chat completion calls now can have multiple module configs to support [fallback modules](https://sap.github.io/ai-sdk/docs/java/orchestration/chat-completion).
- [Orchestration] `protected_material_code` is now supported as an output content filtering module .

### 🐛 Fixed Issues

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,9 @@ public class AzureContentFilter implements ContentFilter {
/** The filter category for violence content. */
@Nullable AzureFilterThreshold violence;

/** The filter category for protected material content. */
boolean protectedMaterialCode;

/**
* A flag to set prompt shield on input filer.
*
Expand Down Expand Up @@ -90,7 +93,11 @@ public AzureContentSafetyInputFilterConfig createInputFilterConfig() {
@Override
@Nonnull
public AzureContentSafetyOutputFilterConfig createOutputFilterConfig() {
if (hate == null && selfHarm == null && sexual == null && violence == null) {
if (hate == null
&& selfHarm == null
&& sexual == null
&& violence == null
&& !protectedMaterialCode) {
throw new IllegalArgumentException("At least one filter category must be set");
}

Expand All @@ -101,6 +108,7 @@ public AzureContentSafetyOutputFilterConfig createOutputFilterConfig() {
.hate(hate != null ? hate.getAzureThreshold() : null)
.selfHarm(selfHarm != null ? selfHarm.getAzureThreshold() : null)
.sexual(sexual != null ? sexual.getAzureThreshold() : null)
.violence(violence != null ? violence.getAzureThreshold() : null));
.violence(violence != null ? violence.getAzureThreshold() : null)
.protectedMaterialCode(true));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -50,18 +50,21 @@ static class InsideTestClass {
void testStackingInputAndOutputFilter() {
final var config = new OrchestrationModuleConfig().withLlmConfig(GPT_4O);

final var filter =
final var inputFilter =
new AzureContentFilter()
.hate(ALLOW_SAFE_LOW_MEDIUM)
.selfHarm(ALLOW_SAFE_LOW_MEDIUM)
.sexual(ALLOW_SAFE_LOW_MEDIUM)
.violence(ALLOW_SAFE_LOW_MEDIUM);
final var outputFilter = inputFilter.protectedMaterialCode(true);

final var configWithInputFirst = config.withInputFiltering(filter).withOutputFiltering(filter);
final var configWithInputFirst =
config.withInputFiltering(inputFilter).withOutputFiltering(outputFilter);
assertThat(configWithInputFirst.getFilteringConfig()).isNotNull();
assertThat(configWithInputFirst.getFilteringConfig().getInput()).isNotNull();

final var configWithOutputFirst = config.withOutputFiltering(filter).withInputFiltering(filter);
final var configWithOutputFirst =
config.withOutputFiltering(outputFilter).withInputFiltering(inputFilter);
assertThat(configWithOutputFirst.getFilteringConfig()).isNotNull();
assertThat(configWithOutputFirst.getFilteringConfig().getOutput()).isNotNull();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -438,18 +438,21 @@ void filteringLoose() throws IOException {
.withBodyFile("filteringLooseResponse.json")
.withHeader("Content-Type", "application/json")));

final var azureFilter =
final var azureInputFilter =
new AzureContentFilter()
.hate(ALLOW_SAFE_LOW_MEDIUM)
.selfHarm(ALLOW_SAFE_LOW_MEDIUM)
.sexual(ALLOW_SAFE_LOW_MEDIUM)
.violence(ALLOW_SAFE_LOW_MEDIUM);
final var azureOutputFilter = azureInputFilter.protectedMaterialCode(true);

final var llamaFilter = new LlamaGuardFilter().config(LlamaGuard38b.create().selfHarm(true));

client.chatCompletion(
prompt,
config.withInputFiltering(azureFilter, llamaFilter).withOutputFiltering(azureFilter));
config
.withInputFiltering(azureInputFilter, llamaFilter)
.withOutputFiltering(azureOutputFilter));
// the result is asserted in the verify step below

// verify that null fields are absent from the sent request
Expand All @@ -464,19 +467,20 @@ void filteringLooseStream() throws IOException {
post(anyUrl())
.willReturn(aResponse().withBody(res).withHeader("Content-Type", "application/json")));

final var azureFilter =
final var azureInputFilter =
new AzureContentFilter()
.hate(ALLOW_SAFE_LOW_MEDIUM)
.selfHarm(ALLOW_SAFE_LOW_MEDIUM)
.sexual(ALLOW_SAFE_LOW_MEDIUM)
.violence(ALLOW_SAFE_LOW_MEDIUM);
final var azureOutputFilter = azureInputFilter.protectedMaterialCode(true);

final var llamaFilter = new LlamaGuardFilter().config(LlamaGuard38b.create().selfHarm(true));

OrchestrationModuleConfig myConfig =
config
.withInputFiltering(azureFilter, llamaFilter)
.withOutputFiltering(azureFilter)
.withInputFiltering(azureInputFilter, llamaFilter)
.withOutputFiltering(azureOutputFilter)
.withOutputFilteringStreamOptions(FilteringStreamOptions.create().overlap(1_000));

Stream<String> result = client.streamChatCompletion(prompt, myConfig);
Expand Down Expand Up @@ -580,7 +584,8 @@ void outputFilteringStrict() {
.hate(ALLOW_SAFE)
.selfHarm(ALLOW_SAFE)
.sexual(ALLOW_SAFE)
.violence(ALLOW_SAFE);
.violence(ALLOW_SAFE)
.protectedMaterialCode(true);

final var llamaFilter =
new LlamaGuardFilter().config(LlamaGuard38b.create().violentCrimes(true));
Expand All @@ -601,7 +606,8 @@ void outputFilteringStrict() {
"Hate", 6,
"SelfHarm", 0,
"Sexual", 0,
"Violence", 6),
"Violence", 6,
"ProtectedMaterialCode", false),
"llama_guard_3_8b",
Map.of("violent_crimes", true)));
assertThat(e.getErrorResponse()).isNull();
Expand All @@ -612,6 +618,7 @@ void outputFilteringStrict() {
assertThat(e.getAzureContentSafetyOutput().getSelfHarm()).isEqualTo(NUMBER_0);
assertThat(e.getAzureContentSafetyOutput().getSexual()).isEqualTo(NUMBER_0);
assertThat(e.getAzureContentSafetyOutput().getViolence()).isEqualTo(NUMBER_6);
assertThat(e.getAzureContentSafetyOutput().isProtectedMaterialCode()).isFalse();

assertThat(e.getLlamaGuard38b()).isNotNull();
assertThat(e.getLlamaGuard38b().isViolentCrimes()).isTrue();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,8 @@
"Hate": 6,
"SelfHarm": 0,
"Sexual": 0,
"Violence": 6
"Violence": 6,
"ProtectedMaterialCode": false
},
"llama_guard_3_8b": {
"violent_crimes": true
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@
"self_harm": 4,
"sexual": 4,
"violence": 4,
"protected_material_code" : false
"protected_material_code" : true
}
}
],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,7 @@ public OrchestrationChatResponse inputFiltering(@Nonnull final AzureFilterThresh
*/
@Nonnull
public OrchestrationChatResponse outputFiltering(@Nonnull final AzureFilterThreshold policy) {
val isProtected = true; // enforce protected material code filtering

val systemMessage = Message.system("Give three paraphrases for the following sentence");
// Reliably triggering the content filter of models fine-tuned for ethical compliance
Expand All @@ -212,7 +213,12 @@ public OrchestrationChatResponse outputFiltering(@Nonnull final AzureFilterThres
new OrchestrationPrompt("'We shall spill blood tonight', said the operation in-charge.")
.messageHistory(List.of(systemMessage));
val filterConfig =
new AzureContentFilter().hate(policy).selfHarm(policy).sexual(policy).violence(policy);
new AzureContentFilter()
.hate(policy)
.selfHarm(policy)
.sexual(policy)
.violence(policy)
.protectedMaterialCode(isProtected);

val configWithFilter = config.withOutputFiltering(filterConfig);
return client.chatCompletion(prompt, configWithFilter);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -286,6 +286,7 @@ void testOutputFilteringStrict() {
assertThat(actualAzureContentSafety.getSelfHarm()).isEqualTo(NUMBER_0);
assertThat(actualAzureContentSafety.getSexual()).isEqualTo(NUMBER_0);
assertThat(actualAzureContentSafety.getHate()).isEqualTo(NUMBER_0);
assertThat(actualAzureContentSafety.isProtectedMaterialCode()).isFalse();
});
}

Expand All @@ -299,7 +300,8 @@ void testOutputFilteringLenient() {
assertThat(response.getContent()).isNotEmpty();

var filterResult = response.getOriginalResponse().getIntermediateResults().getOutputFiltering();
assertThat(filterResult.getMessage()).containsPattern("Choice 0: Filtering was skipped.");
assertThat(filterResult.getMessage())
.containsPattern("Choice 0: Filtering passed successfully.");
}

@Test
Expand Down