From 96fc9992356bd3c3c6aafb0379b8846aacdf5274 Mon Sep 17 00:00:00 2001 From: Michal Kulakowski Date: Thu, 26 Feb 2026 16:13:32 +0100 Subject: [PATCH 1/8] Responses api init --- src/http_rest_api_handler.cpp | 2 +- src/llm/apis/openai_completions.cpp | 789 ++++++++++++++++++ src/llm/apis/openai_completions.hpp | 5 + .../continuous_batching/servable.cpp | 15 +- src/llm/servable.cpp | 56 +- .../continuous_batching/servable.cpp | 6 +- .../visual_language_model/legacy/servable.cpp | 6 +- src/test/http_openai_handler_test.cpp | 526 ++++++++++++ .../complete_flow_test.cpp | 191 +++++ 9 files changed, 1587 insertions(+), 9 deletions(-) diff --git a/src/http_rest_api_handler.cpp b/src/http_rest_api_handler.cpp index afe163e6dc..ab6a2c8868 100644 --- a/src/http_rest_api_handler.cpp +++ b/src/http_rest_api_handler.cpp @@ -531,7 +531,7 @@ static Status createV3HttpPayload( return Status(StatusCode::JSON_INVALID, "model field is not a string"); } - bool isTextGenerationEndpoint = uri.find("completions") != std::string_view::npos; + bool isTextGenerationEndpoint = (uri.find("completions") != std::string_view::npos) || (uri.find("responses") != std::string_view::npos); if (isTextGenerationEndpoint) { auto streamIt = parsedJson->FindMember("stream"); if (streamIt != parsedJson->MemberEnd()) { diff --git a/src/llm/apis/openai_completions.cpp b/src/llm/apis/openai_completions.cpp index 0402017564..0d383a48f4 100644 --- a/src/llm/apis/openai_completions.cpp +++ b/src/llm/apis/openai_completions.cpp @@ -17,6 +17,7 @@ #include "openai_completions.hpp" #include +#include #include #include #include @@ -95,6 +96,328 @@ ov::genai::JsonContainer rapidJsonValueToJsonContainer(const rapidjson::Value& v throw std::invalid_argument("Unsupported JSON value type"); } +std::string serializeResponsesUnaryResponse( + const std::vector& parsedOutputs, + const CompletionUsageStatistics& usage, + const OpenAIChatCompletionsRequest& request, + const ToolsSchemas_t& toolNameSchemaMap, + std::chrono::time_point created) { + const auto createdAt = std::chrono::duration_cast(created.time_since_epoch()).count(); + const std::string responseId = "resp-" + std::to_string(createdAt); + + auto serializeResponsesToolChoice = [&request](Writer& writer) { + writer.String("tool_choice"); + if (request.toolChoice.empty()) { + writer.String("auto"); + } else if (request.toolChoice == "auto" || request.toolChoice == "none" || request.toolChoice == "required") { + writer.String(request.toolChoice.c_str()); + } else { + writer.StartObject(); + writer.String("type"); + writer.String("function"); + writer.String("name"); + writer.String(request.toolChoice.c_str()); + writer.EndObject(); + } + }; + + auto serializeResponsesTools = [&toolNameSchemaMap](Writer& writer) { + writer.String("tools"); + writer.StartArray(); + for (const auto& [toolName, toolSchemaWrapper] : toolNameSchemaMap) { + writer.StartObject(); + writer.String("type"); + writer.String("function"); + writer.String("name"); + writer.String(toolName.c_str()); + writer.String("parameters"); + writer.RawValue(toolSchemaWrapper.stringRepr.c_str(), toolSchemaWrapper.stringRepr.size(), rapidjson::kObjectType); + writer.EndObject(); + } + writer.EndArray(); + }; + + StringBuffer buffer; + Writer writer(buffer); + + writer.StartObject(); + writer.String("id"); + writer.String(responseId.c_str()); + writer.String("object"); + writer.String("response"); + writer.String("created_at"); + writer.Int64(createdAt); + writer.String("completed_at"); + writer.Int64(createdAt); + writer.String("model"); + writer.String(request.model.c_str()); + writer.String("status"); + writer.String("completed"); + + writer.String("parallel_tool_calls"); + writer.Bool(false); + serializeResponsesToolChoice(writer); + serializeResponsesTools(writer); + + if (request.maxTokens.has_value()) { + writer.String("max_output_tokens"); + writer.Uint64(static_cast(request.maxTokens.value())); + } + + writer.String("output"); + writer.StartArray(); + int outputIndex = 0; + for (const auto& parsedOutput : parsedOutputs) { + const std::string outputId = "msg-" + std::to_string(outputIndex++); + + writer.StartObject(); + writer.String("id"); + writer.String(outputId.c_str()); + writer.String("type"); + writer.String("message"); + writer.String("role"); + writer.String("assistant"); + writer.String("status"); + writer.String("completed"); + writer.String("content"); + writer.StartArray(); + writer.StartObject(); + writer.String("type"); + writer.String("output_text"); + writer.String("text"); + writer.String(parsedOutput.content.c_str()); + writer.String("annotations"); + writer.StartArray(); + writer.EndArray(); + writer.EndObject(); + writer.EndArray(); + writer.EndObject(); + } + writer.EndArray(); + + writer.String("usage"); + writer.StartObject(); + writer.String("input_tokens"); + writer.Uint64(static_cast(usage.promptTokens)); + writer.String("input_tokens_details"); + writer.StartObject(); + writer.String("cached_tokens"); + writer.Uint64(0); + writer.EndObject(); + writer.String("output_tokens"); + writer.Uint64(static_cast(usage.completionTokens)); + writer.String("output_tokens_details"); + writer.StartObject(); + writer.String("reasoning_tokens"); + writer.Uint64(0); + writer.EndObject(); + writer.String("total_tokens"); + writer.Uint64(static_cast(usage.calculateTotalTokens())); + writer.EndObject(); + + writer.EndObject(); + + return buffer.GetString(); +} + +absl::Status normalizeResponsesFunctionToolsInPlace(rapidjson::Document& doc) { + auto toolsIt = doc.FindMember("tools"); + if (toolsIt == doc.MemberEnd() || toolsIt->value.IsNull()) { + return absl::OkStatus(); + } + if (!toolsIt->value.IsArray()) { + return absl::InvalidArgumentError("Tools are not an array"); + } + + auto& allocator = doc.GetAllocator(); + for (auto& toolValue : toolsIt->value.GetArray()) { + if (!toolValue.IsObject()) { + return absl::InvalidArgumentError("Tool is not a JSON object"); + } + auto toolObj = toolValue.GetObject(); + auto typeIt = toolObj.FindMember("type"); + if (typeIt == toolObj.MemberEnd() || !typeIt->value.IsString()) { + return absl::InvalidArgumentError("Tool type is missing or invalid"); + } + if (std::string(typeIt->value.GetString()) != "function") { + return absl::InvalidArgumentError("Only function tools are supported"); + } + + auto functionIt = toolObj.FindMember("function"); + if (functionIt != toolObj.MemberEnd()) { + if (!functionIt->value.IsObject()) { + return absl::InvalidArgumentError("Function is not a valid JSON object"); + } + continue; + } + + auto nameIt = toolObj.FindMember("name"); + if (nameIt == toolObj.MemberEnd() || !nameIt->value.IsString()) { + return absl::InvalidArgumentError("Function object does not contain a valid name field"); + } + + rapidjson::Value functionObj(rapidjson::kObjectType); + functionObj.AddMember("name", rapidjson::Value(nameIt->value.GetString(), allocator), allocator); + + auto descriptionIt = toolObj.FindMember("description"); + if (descriptionIt != toolObj.MemberEnd() && descriptionIt->value.IsString()) { + functionObj.AddMember("description", rapidjson::Value(descriptionIt->value.GetString(), allocator), allocator); + } + + auto parametersIt = toolObj.FindMember("parameters"); + if (parametersIt != toolObj.MemberEnd()) { + if (!parametersIt->value.IsObject()) { + return absl::InvalidArgumentError("Function parameters are not a valid JSON object"); + } + rapidjson::Value parametersCopy(rapidjson::kObjectType); + parametersCopy.CopyFrom(parametersIt->value, allocator); + functionObj.AddMember("parameters", parametersCopy, allocator); + } + + toolValue.AddMember("function", functionObj, allocator); + } + + auto toolChoiceIt = doc.FindMember("tool_choice"); + if (toolChoiceIt != doc.MemberEnd() && !toolChoiceIt->value.IsNull() && toolChoiceIt->value.IsObject()) { + auto toolChoiceObj = toolChoiceIt->value.GetObject(); + auto functionIt = toolChoiceObj.FindMember("function"); + if (functionIt == toolChoiceObj.MemberEnd()) { + auto typeIt = toolChoiceObj.FindMember("type"); + auto nameIt = toolChoiceObj.FindMember("name"); + if (typeIt != toolChoiceObj.MemberEnd() && typeIt->value.IsString() && std::string(typeIt->value.GetString()) == "function") { + if (nameIt == toolChoiceObj.MemberEnd() || !nameIt->value.IsString()) { + return absl::InvalidArgumentError("tool_choice.name is not a valid string"); + } + + rapidjson::Value functionObj(rapidjson::kObjectType); + functionObj.AddMember("name", rapidjson::Value(nameIt->value.GetString(), allocator), allocator); + toolChoiceIt->value.AddMember("function", functionObj, allocator); + } + } + } + + return absl::OkStatus(); +} + +absl::Status normalizeResponsesInputToMessagesInPlace(rapidjson::Document& doc) { + auto inputIt = doc.FindMember("input"); + if (inputIt == doc.MemberEnd()) { + return absl::InvalidArgumentError("input missing in request"); + } + auto& allocator = doc.GetAllocator(); + if (inputIt->value.IsString()) { + rapidjson::Value messages(rapidjson::kArrayType); + rapidjson::Value messageObj(rapidjson::kObjectType); + messageObj.AddMember("role", "user", allocator); + messageObj.AddMember("content", rapidjson::Value(inputIt->value.GetString(), allocator), allocator); + messages.PushBack(messageObj, allocator); + + auto existingMessages = doc.FindMember("messages"); + if (existingMessages != doc.MemberEnd()) { + existingMessages->value = messages; + } else { + doc.AddMember("messages", messages, allocator); + } + return absl::OkStatus(); + } + if (!inputIt->value.IsArray()) { + return absl::InvalidArgumentError("input is not a string or array"); + } + + rapidjson::Value messages(rapidjson::kArrayType); + for (auto& item : inputIt->value.GetArray()) { + if (!item.IsObject()) { + return absl::InvalidArgumentError("input array items must be objects"); + } + + auto itemObj = item.GetObject(); + auto roleIt = itemObj.FindMember("role"); + if (roleIt == itemObj.MemberEnd() || !roleIt->value.IsString()) { + return absl::InvalidArgumentError("input item role is missing or invalid"); + } + + rapidjson::Value messageObj(rapidjson::kObjectType); + messageObj.AddMember("role", rapidjson::Value(roleIt->value.GetString(), allocator), allocator); + + auto contentIt = itemObj.FindMember("content"); + if (contentIt == itemObj.MemberEnd()) { + return absl::InvalidArgumentError("input item content is missing"); + } + + if (contentIt->value.IsString()) { + messageObj.AddMember("content", rapidjson::Value(contentIt->value.GetString(), allocator), allocator); + messages.PushBack(messageObj, allocator); + continue; + } + + if (!contentIt->value.IsArray()) { + return absl::InvalidArgumentError("input item content must be a string or array"); + } + + rapidjson::Value normalizedContent(rapidjson::kArrayType); + for (auto& contentItem : contentIt->value.GetArray()) { + if (!contentItem.IsObject()) { + return absl::InvalidArgumentError("input content items must be objects"); + } + auto contentObj = contentItem.GetObject(); + auto typeIt = contentObj.FindMember("type"); + if (typeIt == contentObj.MemberEnd() || !typeIt->value.IsString()) { + return absl::InvalidArgumentError("input content item type is missing or invalid"); + } + + std::string type = typeIt->value.GetString(); + if (type == "input_text") { + auto textIt = contentObj.FindMember("text"); + if (textIt == contentObj.MemberEnd() || !textIt->value.IsString()) { + return absl::InvalidArgumentError("input_text requires a valid text field"); + } + rapidjson::Value textObj(rapidjson::kObjectType); + textObj.AddMember("type", "text", allocator); + textObj.AddMember("text", rapidjson::Value(textIt->value.GetString(), allocator), allocator); + normalizedContent.PushBack(textObj, allocator); + } else if (type == "input_image") { + std::string imageUrl; + auto imageUrlIt = contentObj.FindMember("image_url"); + if (imageUrlIt == contentObj.MemberEnd()) { + return absl::InvalidArgumentError("input_image requires image_url field"); + } + if (imageUrlIt->value.IsString()) { + imageUrl = imageUrlIt->value.GetString(); + } else if (imageUrlIt->value.IsObject()) { + auto imageUrlObj = imageUrlIt->value.GetObject(); + auto urlIt = imageUrlObj.FindMember("url"); + if (urlIt == imageUrlObj.MemberEnd() || !urlIt->value.IsString()) { + return absl::InvalidArgumentError("input_image.image_url.url is missing or invalid"); + } + imageUrl = urlIt->value.GetString(); + } else { + return absl::InvalidArgumentError("input_image.image_url must be a string or object"); + } + + rapidjson::Value imageUrlObj(rapidjson::kObjectType); + imageUrlObj.AddMember("url", rapidjson::Value(imageUrl.c_str(), allocator), allocator); + + rapidjson::Value imageObj(rapidjson::kObjectType); + imageObj.AddMember("type", "image_url", allocator); + imageObj.AddMember("image_url", imageUrlObj, allocator); + normalizedContent.PushBack(imageObj, allocator); + } else { + return absl::InvalidArgumentError("Unsupported content type"); + } + } + messageObj.AddMember("content", normalizedContent, allocator); + messages.PushBack(messageObj, allocator); + } + + auto existingMessages = doc.FindMember("messages"); + if (existingMessages != doc.MemberEnd()) { + existingMessages->value = messages; + } else { + doc.AddMember("messages", messages, allocator); + } + return absl::OkStatus(); +} + } // namespace absl::Status OpenAIChatCompletionsHandler::parseCompletionsPart() { @@ -656,6 +979,120 @@ absl::Status OpenAIChatCompletionsHandler::parseChatCompletionsPart(std::optiona return absl::OkStatus(); } +absl::Status OpenAIChatCompletionsHandler::parseResponsesPart(std::optional maxTokensLimit, std::optional allowedLocalMediaPath, std::optional> allowedMediaDomains) { + // input: string; required + auto it = doc.FindMember("input"); + if (it == doc.MemberEnd()) { + return absl::InvalidArgumentError("input missing in request"); + } + + auto normalizeInputStatus = normalizeResponsesInputToMessagesInPlace(doc); + if (!normalizeInputStatus.ok()) { + return normalizeInputStatus; + } + + it = doc.FindMember("input"); + if (it == doc.MemberEnd()) { + return absl::InvalidArgumentError("input missing in request"); + } + + if (it->value.IsString()) { + request.prompt = it->value.GetString(); + if (!request.prompt.has_value() || !request.prompt.value().size()) { + return absl::InvalidArgumentError("input cannot be empty"); + } + } + + auto messagesStatus = parseMessages(allowedLocalMediaPath, allowedMediaDomains); + if (!messagesStatus.ok()) { + return messagesStatus; + } + + // logprobs: bool; optional - defaults to false + it = doc.FindMember("logprobs"); + if (it != doc.MemberEnd() && !it->value.IsNull()) { + if (!it->value.IsBool()) + return absl::InvalidArgumentError("logprobs accepts values true or false"); + request.logprobschat = it->value.GetBool(); + } + if (request.logprobschat && request.stream) { + return absl::InvalidArgumentError("logprobs are not supported in streaming mode."); + } + + auto toolsStatus = normalizeResponsesFunctionToolsInPlace(doc); + if (!toolsStatus.ok()) { + return toolsStatus; + } + toolsStatus = parseTools(); + if (!toolsStatus.ok()) { + return toolsStatus; + } + + std::optional maxCompletionTokens; + std::optional maxOutputTokens; + + // max_completion_tokens: uint; optional + it = doc.FindMember("max_completion_tokens"); + if (it != doc.MemberEnd() && !it->value.IsNull()) { + if (!it->value.IsUint()) { + if (it->value.IsUint64()) + return absl::InvalidArgumentError("max_completion_tokens value can't be greater than 4294967295"); + return absl::InvalidArgumentError("max_completion_tokens is not an unsigned integer"); + } + if (maxTokensLimit.has_value() && it->value.GetUint() > maxTokensLimit.value()) + return absl::InvalidArgumentError(absl::StrCat("max_completion_tokens exceeds limit provided in graph config: ", maxTokensLimit.value())); + maxCompletionTokens = it->value.GetUint(); + } + + // max_output_tokens: uint; optional + // OpenAI Responses API uses this field for output token limit. + it = doc.FindMember("max_output_tokens"); + if (it != doc.MemberEnd() && !it->value.IsNull()) { + if (!it->value.IsUint()) { + if (it->value.IsUint64()) + return absl::InvalidArgumentError("max_output_tokens value can't be greater than 4294967295"); + return absl::InvalidArgumentError("max_output_tokens is not an unsigned integer"); + } + if (maxTokensLimit.has_value() && it->value.GetUint() > maxTokensLimit.value()) + return absl::InvalidArgumentError(absl::StrCat("max_output_tokens exceeds limit provided in graph config: ", maxTokensLimit.value())); + maxOutputTokens = it->value.GetUint(); + } + + if (maxCompletionTokens.has_value() && maxOutputTokens.has_value() && maxCompletionTokens.value() != maxOutputTokens.value()) { + return absl::InvalidArgumentError("max_output_tokens and max_completion_tokens must match when both are provided"); + } + if (maxOutputTokens.has_value()) { + request.maxTokens = maxOutputTokens.value(); + } else if (maxCompletionTokens.has_value()) { + request.maxTokens = maxCompletionTokens.value(); + } + + // specific part of max_tokens validation + if (request.maxTokens == 0) { + return absl::InvalidArgumentError("max_tokens value should be greater than 0"); + } + + // parse response_format + it = doc.FindMember("response_format"); + if (it != doc.MemberEnd()) { + if (it->value.IsNull()) + return absl::OkStatus(); + if (!it->value.IsObject()) + return absl::InvalidArgumentError("response_format is not an object"); + const rapidjson::Value& responseFormat = it->value; + request.responseFormat = convertOpenAIResponseFormatToStructuralTagStringFormat(responseFormat); + } + + { + StringBuffer buffer; + Writer writer(buffer); + doc.Accept(writer); + request.processedJson = buffer.GetString(); + } + + return absl::OkStatus(); +} + absl::Status OpenAIChatCompletionsHandler::parseCommonPart(std::optional maxTokensLimit, uint32_t bestOfLimit, std::optional maxModelLength) { OVMS_PROFILE_FUNCTION(); // stream: bool; optional @@ -937,6 +1374,8 @@ absl::Status OpenAIChatCompletionsHandler::parseRequest(std::optional return status; if (endpoint == Endpoint::COMPLETIONS) status = parseCompletionsPart(); + else if (endpoint == Endpoint::RESPONSES) + status = parseResponsesPart(maxTokensLimit, allowedLocalMediaPath, allowedMediaDomains); else status = parseChatCompletionsPart(maxTokensLimit, allowedLocalMediaPath, allowedMediaDomains); @@ -987,6 +1426,16 @@ ParsedOutput OpenAIChatCompletionsHandler::parseOutputIfNeeded(const std::vector std::string OpenAIChatCompletionsHandler::serializeUnaryResponse(const std::vector& generationOutputs) { OVMS_PROFILE_FUNCTION(); + if (endpoint == Endpoint::RESPONSES) { + std::vector parsedOutputs; + usage.completionTokens = 0; + for (const ov::genai::GenerationOutput& generationOutput : generationOutputs) { + updateUsage(usage, generationOutput.generated_ids, request.echo); + parsedOutputs.push_back(parseOutputIfNeeded(generationOutput.generated_ids)); + } + return serializeResponsesUnaryResponse(parsedOutputs, usage, request, request.toolNameSchemaMap, created); + } + OpenAiJsonResponse jsonResponse; jsonResponse.StartObject(); @@ -1112,6 +1561,15 @@ std::string OpenAIChatCompletionsHandler::serializeUnaryResponse(ov::genai::Enco OVMS_PROFILE_FUNCTION(); usage.promptTokens = results.perf_metrics.get_num_input_tokens(); usage.completionTokens = results.perf_metrics.get_num_generated_tokens(); + if (endpoint == Endpoint::RESPONSES) { + std::vector parsedOutputs; + for (const auto& tokens : results.tokens) { + updateUsage(usage, tokens, request.echo); + parsedOutputs.push_back(parseOutputIfNeeded(tokens)); + } + return serializeResponsesUnaryResponse(parsedOutputs, usage, request, request.toolNameSchemaMap, created); + } + OpenAiJsonResponse jsonResponse; jsonResponse.StartObject(); @@ -1172,6 +1630,27 @@ std::string OpenAIChatCompletionsHandler::serializeUnaryResponse(ov::genai::VLMD OVMS_PROFILE_FUNCTION(); usage.promptTokens = results.perf_metrics.get_num_input_tokens(); usage.completionTokens = results.perf_metrics.get_num_generated_tokens(); + if (endpoint == Endpoint::RESPONSES) { + std::vector parsedOutputs; + usage.completionTokens = 0; + for (const std::string& text : results.texts) { + auto result = tokenizer.encode(text); + auto& input_ids = result.input_ids; + if (input_ids.get_shape().size() != 2) + throw std::runtime_error("input_ids should have 2 dimensions"); + if (input_ids.get_shape()[0] != 1) + throw std::runtime_error("input_ids should have 1 batch size"); + if (input_ids.get_element_type() != ov::element::i64) + throw std::runtime_error("input_ids should have i64 element type"); + + int64_t* input_ids_data = reinterpret_cast(input_ids.data()); + std::vector generatedTokens(input_ids_data, input_ids_data + input_ids.get_shape()[1]); + updateUsage(usage, generatedTokens, request.echo); + parsedOutputs.push_back(parseOutputIfNeeded(generatedTokens)); + } + return serializeResponsesUnaryResponse(parsedOutputs, usage, request, request.toolNameSchemaMap, created); + } + OpenAiJsonResponse jsonResponse; jsonResponse.StartObject(); @@ -1248,6 +1727,313 @@ std::string OpenAIChatCompletionsHandler::serializeUnaryResponse(ov::genai::VLMD std::string OpenAIChatCompletionsHandler::serializeStreamingChunk(const std::string& chunkResponse, ov::genai::GenerationFinishReason finishReason) { OVMS_PROFILE_FUNCTION(); + if (endpoint == Endpoint::RESPONSES) { + const auto createdAt = std::chrono::duration_cast(created.time_since_epoch()).count(); + const std::string responseId = "resp-" + std::to_string(createdAt); + const std::string outputItemId = "msg-0"; + + auto serializeResponsesToolChoice = [this](Writer& writer) { + writer.String("tool_choice"); + if (request.toolChoice.empty()) { + writer.String("auto"); + } else if (request.toolChoice == "auto" || request.toolChoice == "none" || request.toolChoice == "required") { + writer.String(request.toolChoice.c_str()); + } else { + writer.StartObject(); + writer.String("type"); + writer.String("function"); + writer.String("name"); + writer.String(request.toolChoice.c_str()); + writer.EndObject(); + } + }; + + auto serializeResponsesTools = [this](Writer& writer) { + writer.String("tools"); + writer.StartArray(); + for (const auto& [toolName, toolSchemaWrapper] : request.toolNameSchemaMap) { + writer.StartObject(); + writer.String("type"); + writer.String("function"); + writer.String("name"); + writer.String(toolName.c_str()); + writer.String("parameters"); + writer.RawValue(toolSchemaWrapper.stringRepr.c_str(), toolSchemaWrapper.stringRepr.size(), rapidjson::kObjectType); + writer.EndObject(); + } + writer.EndArray(); + }; + + auto serializeResponseObject = [this, &responseId, createdAt, &serializeResponsesToolChoice, &serializeResponsesTools](Writer& writer, const char* status, const std::string& fullOutputText, bool includeUsage) { + writer.StartObject(); + writer.String("id"); + writer.String(responseId.c_str()); + writer.String("object"); + writer.String("response"); + writer.String("created_at"); + writer.Int64(createdAt); + if (std::string(status) == "completed") { + writer.String("completed_at"); + writer.Int64(createdAt); + } + writer.String("model"); + writer.String(request.model.c_str()); + writer.String("status"); + writer.String(status); + + writer.String("parallel_tool_calls"); + writer.Bool(false); + serializeResponsesToolChoice(writer); + serializeResponsesTools(writer); + + if (request.maxTokens.has_value()) { + writer.String("max_output_tokens"); + writer.Uint64(static_cast(request.maxTokens.value())); + } + + writer.String("output"); + writer.StartArray(); + if (!fullOutputText.empty()) { + writer.StartObject(); + writer.String("id"); + writer.String("msg-0"); + writer.String("type"); + writer.String("message"); + writer.String("role"); + writer.String("assistant"); + writer.String("status"); + writer.String(std::string(status) == "completed" ? "completed" : "in_progress"); + writer.String("content"); + writer.StartArray(); + writer.StartObject(); + writer.String("type"); + writer.String("output_text"); + writer.String("text"); + writer.String(fullOutputText.c_str()); + writer.String("annotations"); + writer.StartArray(); + writer.EndArray(); + writer.EndObject(); + writer.EndArray(); + writer.EndObject(); + } + writer.EndArray(); + + if (includeUsage) { + writer.String("usage"); + writer.StartObject(); + writer.String("input_tokens"); + writer.Uint64(static_cast(usage.promptTokens)); + writer.String("input_tokens_details"); + writer.StartObject(); + writer.String("cached_tokens"); + writer.Uint64(0); + writer.EndObject(); + writer.String("output_tokens"); + writer.Uint64(static_cast(usage.completionTokens)); + writer.String("output_tokens_details"); + writer.StartObject(); + writer.String("reasoning_tokens"); + writer.Uint64(0); + writer.EndObject(); + writer.String("total_tokens"); + writer.Uint64(static_cast(usage.calculateTotalTokens())); + writer.EndObject(); + } + + writer.EndObject(); + }; + + auto serializeOutputItem = [&outputItemId](Writer& writer, const std::string& text, const char* status, bool withContent) { + writer.StartObject(); + writer.String("id"); + writer.String(outputItemId.c_str()); + writer.String("type"); + writer.String("message"); + writer.String("role"); + writer.String("assistant"); + writer.String("status"); + writer.String(status); + writer.String("content"); + writer.StartArray(); + if (withContent) { + writer.StartObject(); + writer.String("type"); + writer.String("output_text"); + writer.String("text"); + writer.String(text.c_str()); + writer.String("annotations"); + writer.StartArray(); + writer.EndArray(); + writer.EndObject(); + } + writer.EndArray(); + writer.EndObject(); + }; + + auto serializePart = [](Writer& writer, const std::string& text) { + writer.StartObject(); + writer.String("type"); + writer.String("output_text"); + writer.String("text"); + writer.String(text.c_str()); + writer.String("annotations"); + writer.StartArray(); + writer.EndArray(); + writer.EndObject(); + }; + + auto serializeResponsesEvent = [](const std::function&)>& eventSerializer) { + StringBuffer eventBuffer; + Writer eventWriter(eventBuffer); + eventSerializer(eventWriter); + return std::string(eventBuffer.GetString()); + }; + + std::vector events; + if (!responsesStreamingInitialized) { + events.emplace_back(serializeResponsesEvent([this, &serializeResponseObject](Writer& writer) { + writer.StartObject(); + writer.String("type"); + writer.String("response.created"); + writer.String("sequence_number"); + writer.Uint64(responsesStreamingSequenceNumber++); + writer.String("response"); + serializeResponseObject(writer, "in_progress", "", false); + writer.EndObject(); + })); + + events.emplace_back(serializeResponsesEvent([this, &outputItemId, &serializeOutputItem](Writer& writer) { + writer.StartObject(); + writer.String("type"); + writer.String("response.output_item.added"); + writer.String("sequence_number"); + writer.Uint64(responsesStreamingSequenceNumber++); + writer.String("output_index"); + writer.Uint64(0); + writer.String("item"); + serializeOutputItem(writer, "", "in_progress", false); + writer.EndObject(); + })); + + events.emplace_back(serializeResponsesEvent([this, &outputItemId, &serializePart](Writer& writer) { + writer.StartObject(); + writer.String("type"); + writer.String("response.content_part.added"); + writer.String("sequence_number"); + writer.Uint64(responsesStreamingSequenceNumber++); + writer.String("output_index"); + writer.Uint64(0); + writer.String("content_index"); + writer.Uint64(0); + writer.String("item_id"); + writer.String(outputItemId.c_str()); + writer.String("part"); + serializePart(writer, ""); + writer.EndObject(); + })); + + responsesStreamingInitialized = true; + } + + if (!chunkResponse.empty()) { + responsesStreamingOutputText += chunkResponse; + events.emplace_back(serializeResponsesEvent([this, &chunkResponse, &outputItemId](Writer& writer) { + writer.StartObject(); + writer.String("type"); + writer.String("response.output_text.delta"); + writer.String("sequence_number"); + writer.Uint64(responsesStreamingSequenceNumber++); + writer.String("output_index"); + writer.Uint64(0); + writer.String("content_index"); + writer.Uint64(0); + writer.String("item_id"); + writer.String(outputItemId.c_str()); + writer.String("delta"); + writer.String(chunkResponse.c_str()); + writer.String("logprobs"); + writer.StartArray(); + writer.EndArray(); + writer.EndObject(); + })); + } + + if (finishReason != ov::genai::GenerationFinishReason::NONE) { + events.emplace_back(serializeResponsesEvent([this, &outputItemId](Writer& writer) { + writer.StartObject(); + writer.String("type"); + writer.String("response.output_text.done"); + writer.String("sequence_number"); + writer.Uint64(responsesStreamingSequenceNumber++); + writer.String("output_index"); + writer.Uint64(0); + writer.String("content_index"); + writer.Uint64(0); + writer.String("item_id"); + writer.String(outputItemId.c_str()); + writer.String("text"); + writer.String(responsesStreamingOutputText.c_str()); + writer.String("logprobs"); + writer.StartArray(); + writer.EndArray(); + writer.EndObject(); + })); + + events.emplace_back(serializeResponsesEvent([this, &outputItemId, &serializePart](Writer& writer) { + writer.StartObject(); + writer.String("type"); + writer.String("response.content_part.done"); + writer.String("sequence_number"); + writer.Uint64(responsesStreamingSequenceNumber++); + writer.String("output_index"); + writer.Uint64(0); + writer.String("content_index"); + writer.Uint64(0); + writer.String("item_id"); + writer.String(outputItemId.c_str()); + writer.String("part"); + serializePart(writer, responsesStreamingOutputText); + writer.EndObject(); + })); + + events.emplace_back(serializeResponsesEvent([this, &serializeOutputItem](Writer& writer) { + writer.StartObject(); + writer.String("type"); + writer.String("response.output_item.done"); + writer.String("sequence_number"); + writer.Uint64(responsesStreamingSequenceNumber++); + writer.String("output_index"); + writer.Uint64(0); + writer.String("item"); + serializeOutputItem(writer, responsesStreamingOutputText, "completed", true); + writer.EndObject(); + })); + + events.emplace_back(serializeResponsesEvent([this, &serializeResponseObject](Writer& writer) { + writer.StartObject(); + writer.String("type"); + writer.String("response.completed"); + writer.String("sequence_number"); + writer.Uint64(responsesStreamingSequenceNumber++); + writer.String("response"); + serializeResponseObject(writer, "completed", responsesStreamingOutputText, true); + writer.EndObject(); + })); + } + + if (events.empty()) { + return ""; + } + + std::stringstream ss; + ss << events.front(); + for (size_t i = 1; i < events.size(); ++i) { + ss << "\n\ndata: " << events[i]; + } + return ss.str(); + } + Document doc; doc.SetObject(); Document::AllocatorType& allocator = doc.GetAllocator(); @@ -1334,6 +2120,9 @@ std::string OpenAIChatCompletionsHandler::serializeStreamingChunk(const std::str std::string OpenAIChatCompletionsHandler::serializeStreamingUsageChunk() { OVMS_PROFILE_FUNCTION(); + if (endpoint == Endpoint::RESPONSES) { + return ""; + } StringBuffer buffer; Writer writer(buffer); diff --git a/src/llm/apis/openai_completions.hpp b/src/llm/apis/openai_completions.hpp index 0b513fd528..832e6cd316 100644 --- a/src/llm/apis/openai_completions.hpp +++ b/src/llm/apis/openai_completions.hpp @@ -47,6 +47,7 @@ namespace ovms { enum class Endpoint { CHAT_COMPLETIONS, COMPLETIONS, + RESPONSES, TOKENIZE, }; @@ -69,12 +70,16 @@ class OpenAIChatCompletionsHandler { std::chrono::time_point created; ov::genai::Tokenizer tokenizer; size_t processedTokens = 0; // tracks overall number of tokens processed by the pipeline + size_t responsesStreamingSequenceNumber = 0; + bool responsesStreamingInitialized = false; + std::string responsesStreamingOutputText; // Output parser is used to parse chat completions response to extract specific fields like tool calls and reasoning. std::unique_ptr outputParser = nullptr; absl::Status parseCompletionsPart(); absl::Status parseChatCompletionsPart(std::optional maxTokensLimit, std::optional allowedLocalMediaPath, std::optional> allowedMediaDomains); + absl::Status parseResponsesPart(std::optional maxTokensLimit, std::optional allowedLocalMediaPath, std::optional> allowedMediaDomains); absl::Status parseCommonPart(std::optional maxTokensLimit, uint32_t bestOfLimit, std::optional maxModelLength); ParsedOutput parseOutputIfNeeded(const std::vector& generatedIds); diff --git a/src/llm/language_model/continuous_batching/servable.cpp b/src/llm/language_model/continuous_batching/servable.cpp index 470e170a09..1c14944385 100644 --- a/src/llm/language_model/continuous_batching/servable.cpp +++ b/src/llm/language_model/continuous_batching/servable.cpp @@ -103,6 +103,15 @@ static ov::genai::GenerationOutput prepareEmptyStopReasonOutput() { return out; } +static ov::genai::GenerationOutput prepareEmptyNoneReasonOutput() { + static ov::genai::GenerationOutput out = { + std::vector(), // generated_ids + std::vector(), // generated_log_probs + 0.0f, // score + ov::genai::GenerationFinishReason::NONE}; + return out; +} + absl::Status ContinuousBatchingServable::readCompleteExecutionResults(std::shared_ptr& executionContext) { auto cbExecutionContext = std::static_pointer_cast(executionContext); if (cbExecutionContext->payload.client->isDisconnected()) { @@ -136,7 +145,11 @@ absl::Status ContinuousBatchingServable::readPartialExecutionResults(std::shared ov::genai::GenerationOutputs generationOutputs = cbExecutionContext->generationHandle->read(); RET_CHECK(generationOutputs.size() <= 1); // TODO: Support multiple generations if (generationOutputs.size() == 0) { - cbExecutionContext->generationOutputs = {prepareEmptyStopReasonOutput()}; + if (cbExecutionContext->generationHandle->get_status() == ov::genai::GenerationStatus::RUNNING) { + cbExecutionContext->generationOutputs = {prepareEmptyNoneReasonOutput()}; + } else { + cbExecutionContext->generationOutputs = {prepareEmptyStopReasonOutput()}; + } } else { cbExecutionContext->generationOutputs = {generationOutputs.begin()->second}; } diff --git a/src/llm/servable.cpp b/src/llm/servable.cpp index 75480efe37..f360c38184 100644 --- a/src/llm/servable.cpp +++ b/src/llm/servable.cpp @@ -68,10 +68,12 @@ absl::Status GenAiServable::loadRequest(std::shared_ptrendpoint = Endpoint::CHAT_COMPLETIONS; } else if (payload.uri == "/v3/completions" || payload.uri == "/v3/v1/completions") { executionContext->endpoint = Endpoint::COMPLETIONS; + } else if (payload.uri == "/v3/responses" || payload.uri == "/v3/v1/responses") { + executionContext->endpoint = Endpoint::RESPONSES; } else if (TokenizeParser::isTokenizeEndpoint(payload.uri)) { executionContext->endpoint = Endpoint::TOKENIZE; } else { - return absl::InvalidArgumentError("Wrong endpoint. Allowed endpoints: /v3/chat/completions, /v3/completions"); + return absl::InvalidArgumentError("Wrong endpoint. Allowed endpoints: /v3/chat/completions, /v3/completions, /v3/responses, /v3/tokenize"); } executionContext->payload = payload; return absl::OkStatus(); @@ -204,6 +206,50 @@ absl::Status GenAiServable::prepareInputs(std::shared_ptrapiHandler->getChatHistory().size() > 0) { +#if (PYTHON_DISABLE == 0) + bool success; + if (executionContext->apiHandler->getProcessedJson().size() > 0) { + success = PyJinjaTemplateProcessor::applyChatTemplate(getProperties()->templateProcessor, getProperties()->modelsPath, executionContext->apiHandler->getProcessedJson(), inputText); + } else { + success = PyJinjaTemplateProcessor::applyChatTemplate(getProperties()->templateProcessor, getProperties()->modelsPath, executionContext->payload.body, inputText); + } + if (!success) { + return absl::Status(absl::StatusCode::kInvalidArgument, inputText); + } +#else + ov::genai::ChatHistory& chatHistory = executionContext->apiHandler->getChatHistory(); + constexpr bool add_generation_prompt = true; + auto toolsStatus = executionContext->apiHandler->parseToolsToJsonContainer(); + if (!toolsStatus.ok()) { + return toolsStatus.status(); + } + const auto& tools = toolsStatus.value(); + auto chatTemplateKwargsStatus = executionContext->apiHandler->parseChatTemplateKwargsToJsonContainer(); + if (!chatTemplateKwargsStatus.ok()) { + return chatTemplateKwargsStatus.status(); + } + const auto& chatTemplateKwargs = chatTemplateKwargsStatus.value(); + try { + inputText = getProperties()->tokenizer.apply_chat_template(chatHistory, add_generation_prompt, {}, tools, chatTemplateKwargs); + } catch (const std::exception& e) { + SPDLOG_LOGGER_DEBUG(llm_calculator_logger, "Failed to apply chat template: {}", e.what()); + return absl::Status(absl::StatusCode::kInvalidArgument, "Failed to apply chat template. The model either does not have chat template or has an invalid one."); + } +#endif + if (inputText.size() == 0) { + return absl::Status(absl::StatusCode::kInvalidArgument, "Final prompt after applying chat template is empty"); + } + } else { + auto prompt = executionContext->apiHandler->getPrompt(); + if (!prompt.has_value()) { + return absl::Status(absl::StatusCode::kInvalidArgument, "input is missing"); + } + inputText = prompt.value(); + } + break; + } case Endpoint::COMPLETIONS: { inputText = executionContext->apiHandler->getPrompt().value(); break; @@ -277,8 +323,12 @@ absl::Status GenAiServable::preparePartialResponse(std::shared_ptrresponse = wrapTextInServerSideEventMessage(serializedChunk); } - if (executionContext->apiHandler->getStreamOptions().includeUsage) - executionContext->response += wrapTextInServerSideEventMessage(executionContext->apiHandler->serializeStreamingUsageChunk()); + if (executionContext->apiHandler->getStreamOptions().includeUsage) { + std::string usageChunk = executionContext->apiHandler->serializeStreamingUsageChunk(); + if (!usageChunk.empty()) { + executionContext->response += wrapTextInServerSideEventMessage(usageChunk); + } + } executionContext->response += wrapTextInServerSideEventMessage("[DONE]"); diff --git a/src/llm/visual_language_model/continuous_batching/servable.cpp b/src/llm/visual_language_model/continuous_batching/servable.cpp index be33838d9f..94aef05387 100644 --- a/src/llm/visual_language_model/continuous_batching/servable.cpp +++ b/src/llm/visual_language_model/continuous_batching/servable.cpp @@ -45,10 +45,12 @@ absl::Status VisualLanguageModelServable::loadRequest(std::shared_ptrendpoint = Endpoint::CHAT_COMPLETIONS; + } else if (payload.uri == "/v3/responses" || payload.uri == "/v3/v1/responses") { + executionContext->endpoint = Endpoint::RESPONSES; } else if (TokenizeParser::isTokenizeEndpoint(payload.uri)) { executionContext->endpoint = Endpoint::TOKENIZE; } else { - return absl::InvalidArgumentError("Wrong endpoint. VLM Servable allowed only on /v3/chat/completions endpoint or /v3/tokenize"); + return absl::InvalidArgumentError("Wrong endpoint. VLM Servable allowed only on /v3/chat/completions, /v3/responses endpoint or /v3/tokenize"); } executionContext->payload = payload; return absl::OkStatus(); @@ -67,7 +69,7 @@ absl::Status VisualLanguageModelServable::prepareInputs(std::shared_ptrapiHandler == nullptr) { return absl::Status(absl::StatusCode::kInvalidArgument, "API handler is not initialized"); } - if (executionContext->endpoint == Endpoint::CHAT_COMPLETIONS) { + if (executionContext->endpoint == Endpoint::CHAT_COMPLETIONS || executionContext->endpoint == Endpoint::RESPONSES) { ov::genai::ChatHistory& chatHistory = vlmExecutionContext->apiHandler->getChatHistory(); for (size_t i = 0; i < chatHistory.size(); i++) { diff --git a/src/llm/visual_language_model/legacy/servable.cpp b/src/llm/visual_language_model/legacy/servable.cpp index 2834072410..307723415a 100644 --- a/src/llm/visual_language_model/legacy/servable.cpp +++ b/src/llm/visual_language_model/legacy/servable.cpp @@ -53,10 +53,12 @@ absl::Status VisualLanguageModelLegacyServable::loadRequest(std::shared_ptrendpoint = Endpoint::CHAT_COMPLETIONS; + } else if (payload.uri == "/v3/responses" || payload.uri == "/v3/v1/responses") { + executionContext->endpoint = Endpoint::RESPONSES; } else if (TokenizeParser::isTokenizeEndpoint(payload.uri)) { executionContext->endpoint = Endpoint::TOKENIZE; } else { - return absl::InvalidArgumentError("Wrong endpoint. VLM Servable allowed only on /v3/chat/completions endpoint or /v3/tokenize"); + return absl::InvalidArgumentError("Wrong endpoint. VLM Servable allowed only on /v3/chat/completions, /v3/responses endpoint or /v3/tokenize"); } executionContext->payload = payload; return absl::OkStatus(); @@ -237,7 +239,7 @@ absl::Status VisualLanguageModelLegacyServable::prepareInputs(std::shared_ptrapiHandler == nullptr) { return absl::Status(absl::StatusCode::kInvalidArgument, "API handler is not initialized"); } - if (executionContext->endpoint == Endpoint::CHAT_COMPLETIONS) { + if (executionContext->endpoint == Endpoint::CHAT_COMPLETIONS || executionContext->endpoint == Endpoint::RESPONSES) { ov::genai::ChatHistory& chatHistory = vlmExecutionContext->apiHandler->getChatHistory(); for (size_t i = 0; i < chatHistory.size(); i++) { diff --git a/src/test/http_openai_handler_test.cpp b/src/test/http_openai_handler_test.cpp index 94648d0e68..471a12070a 100644 --- a/src/test/http_openai_handler_test.cpp +++ b/src/test/http_openai_handler_test.cpp @@ -269,6 +269,27 @@ TEST_F(HttpOpenAIHandlerTest, Stream) { ASSERT_EQ(response, ""); } + TEST_F(HttpOpenAIHandlerTest, ResponsesStream) { + std::string requestBody = R"( + { + "model": "gpt", + "stream": true, + "input": "What is OpenVINO?" + } + )"; + + EXPECT_CALL(*writer, PartialReplyBegin(::testing::_)).WillOnce(testing::Invoke([](std::function fn) { fn(); })); + EXPECT_CALL(*writer, PartialReplyEnd()).Times(1); + EXPECT_CALL(*writer, PartialReply(::testing::_)).Times(9); + EXPECT_CALL(*writer, IsDisconnected()).Times(9); + + ASSERT_EQ( + handler->dispatchToProcessor("/v3/responses", requestBody, &response, comp, responseComponents, writer, multiPartParser), + ovms::StatusCode::PARTIAL_END); + + ASSERT_EQ(response, ""); + } + TEST_F(HttpOpenAIHandlerTest, BodyNotAJson) { std::string requestBody = "not a json"; @@ -557,6 +578,85 @@ TEST_F(HttpOpenAIHandlerParsingTest, serializeUnaryResponseVLMSupportsToolCallsF ASSERT_NE(serialized.find("\"finish_reason\":\"tool_calls\""), std::string::npos) << serialized; } +TEST_F(HttpOpenAIHandlerParsingTest, serializeUnaryResponseForResponsesContainsOutputText) { + std::string json = R"({ + "model": "llama", + "input": "What is OpenVINO?", + "max_output_tokens": 5 + })"; + doc.Parse(json.c_str()); + ASSERT_FALSE(doc.HasParseError()); + + auto apiHandler = std::make_shared(doc, ovms::Endpoint::RESPONSES, std::chrono::system_clock::now(), *tokenizer); + std::optional maxTokensLimit; + uint32_t bestOfLimit = 0; + std::optional maxModelLength; + ASSERT_EQ(apiHandler->parseRequest(maxTokensLimit, bestOfLimit, maxModelLength), absl::OkStatus()); + + ov::genai::EncodedResults results; + ov::Tensor outputIds = tokenizer->encode("OVMS", ov::genai::add_special_tokens(false)).input_ids; + ASSERT_EQ(outputIds.get_shape().size(), 2); + ASSERT_EQ(outputIds.get_shape()[0], 1); + ASSERT_EQ(outputIds.get_element_type(), ov::element::i64); + int64_t* outputIdsData = reinterpret_cast(outputIds.data()); + results.tokens = {std::vector(outputIdsData, outputIdsData + outputIds.get_shape()[1])}; + + std::string serialized = apiHandler->serializeUnaryResponse(results); + ASSERT_NE(serialized.find("\"object\":\"response\""), std::string::npos) << serialized; + ASSERT_NE(serialized.find("\"output\":"), std::string::npos) << serialized; + ASSERT_NE(serialized.find("\"type\":\"output_text\""), std::string::npos) << serialized; + ASSERT_NE(serialized.find("\"text\":"), std::string::npos) << serialized; +} + +TEST_F(HttpOpenAIHandlerParsingTest, serializeStreamingChunkForResponsesContainsRequiredEvents) { + std::string json = R"({ + "model": "llama", + "input": "What is OpenVINO?", + "stream": true + })"; + doc.Parse(json.c_str()); + ASSERT_FALSE(doc.HasParseError()); + + auto apiHandler = std::make_shared(doc, ovms::Endpoint::RESPONSES, std::chrono::system_clock::now(), *tokenizer); + std::optional maxTokensLimit; + uint32_t bestOfLimit = 0; + std::optional maxModelLength; + ASSERT_EQ(apiHandler->parseRequest(maxTokensLimit, bestOfLimit, maxModelLength), absl::OkStatus()); + + std::string firstChunk = apiHandler->serializeStreamingChunk("Hello", ov::genai::GenerationFinishReason::NONE); + ASSERT_NE(firstChunk.find("\"type\":\"response.created\""), std::string::npos) << firstChunk; + ASSERT_NE(firstChunk.find("\"type\":\"response.output_item.added\""), std::string::npos) << firstChunk; + ASSERT_NE(firstChunk.find("\"type\":\"response.content_part.added\""), std::string::npos) << firstChunk; + ASSERT_NE(firstChunk.find("\"type\":\"response.output_text.delta\""), std::string::npos) << firstChunk; + ASSERT_NE(firstChunk.find("\"delta\":\"Hello\""), std::string::npos) << firstChunk; + + std::string finalChunk = apiHandler->serializeStreamingChunk(" world", ov::genai::GenerationFinishReason::STOP); + ASSERT_NE(finalChunk.find("\"type\":\"response.output_text.done\""), std::string::npos) << finalChunk; + ASSERT_NE(finalChunk.find("\"type\":\"response.content_part.done\""), std::string::npos) << finalChunk; + ASSERT_NE(finalChunk.find("\"type\":\"response.output_item.done\""), std::string::npos) << finalChunk; + ASSERT_NE(finalChunk.find("\"type\":\"response.completed\""), std::string::npos) << finalChunk; + ASSERT_NE(finalChunk.find("\"text\":\"Hello world\""), std::string::npos) << finalChunk; +} + +TEST_F(HttpOpenAIHandlerParsingTest, serializeStreamingUsageChunkForResponsesIsEmpty) { + std::string json = R"({ + "model": "llama", + "input": "What is OpenVINO?", + "stream": true, + "stream_options": {"include_usage": true} + })"; + doc.Parse(json.c_str()); + ASSERT_FALSE(doc.HasParseError()); + + auto apiHandler = std::make_shared(doc, ovms::Endpoint::RESPONSES, std::chrono::system_clock::now(), *tokenizer); + std::optional maxTokensLimit; + uint32_t bestOfLimit = 0; + std::optional maxModelLength; + ASSERT_EQ(apiHandler->parseRequest(maxTokensLimit, bestOfLimit, maxModelLength), absl::OkStatus()); + + ASSERT_EQ(apiHandler->serializeStreamingUsageChunk(), ""); +} + TEST_F(HttpOpenAIHandlerParsingTest, ParsingMessagesSucceedsBase64) { std::string json = R"({ "model": "llama", @@ -1318,6 +1418,432 @@ TEST_F(HttpOpenAIHandlerParsingTest, ParsingRequestWithNullParametersCompletions } } +TEST_F(HttpOpenAIHandlerParsingTest, ParsingResponsesMaxOutputTokensSetsMaxTokens) { + std::string json = R"({ + "model": "llama", + "input": "valid prompt", + "max_output_tokens": 7 + })"; + doc.Parse(json.c_str()); + ASSERT_FALSE(doc.HasParseError()); + std::optional maxTokensLimit; + uint32_t bestOfLimit = 0; + std::optional maxModelLength; + std::shared_ptr apiHandler = std::make_shared(doc, ovms::Endpoint::RESPONSES, std::chrono::system_clock::now(), *tokenizer); + EXPECT_EQ(apiHandler->parseRequest(maxTokensLimit, bestOfLimit, maxModelLength), absl::OkStatus()); + EXPECT_TRUE(apiHandler->getMaxTokens().has_value()); + EXPECT_EQ(apiHandler->getMaxTokens().value(), 7); +} + +TEST_F(HttpOpenAIHandlerParsingTest, ParsingResponsesStringInputCreatesUserChatMessage) { + std::string json = R"({ + "model": "llama", + "input": "What is OpenVINO?" + })"; + doc.Parse(json.c_str()); + ASSERT_FALSE(doc.HasParseError()); + std::optional maxTokensLimit; + uint32_t bestOfLimit = 0; + std::optional maxModelLength; + std::shared_ptr apiHandler = std::make_shared(doc, ovms::Endpoint::RESPONSES, std::chrono::system_clock::now(), *tokenizer); + EXPECT_EQ(apiHandler->parseRequest(maxTokensLimit, bestOfLimit, maxModelLength), absl::OkStatus()); + + auto& chatHistory = apiHandler->getChatHistory(); + ASSERT_EQ(chatHistory.size(), 1); + ASSERT_TRUE(chatHistory[0].contains("role")); + ASSERT_TRUE(chatHistory[0].contains("content")); + EXPECT_EQ(chatHistory[0]["role"], "user"); + EXPECT_EQ(chatHistory[0]["content"], "What is OpenVINO?"); + EXPECT_NE(apiHandler->getProcessedJson().find("\"messages\""), std::string::npos); + EXPECT_NE(apiHandler->getProcessedJson().find("\"role\":\"user\""), std::string::npos); + EXPECT_NE(apiHandler->getProcessedJson().find("\"input\":\"What is OpenVINO?\""), std::string::npos); +} + +TEST_F(HttpOpenAIHandlerParsingTest, ParsingResponsesConflictingOutputAndCompletionTokensFails) { + std::string json = R"({ + "model": "llama", + "input": "valid prompt", + "max_output_tokens": 5, + "max_completion_tokens": 7 + })"; + doc.Parse(json.c_str()); + ASSERT_FALSE(doc.HasParseError()); + std::optional maxTokensLimit; + uint32_t bestOfLimit = 0; + std::optional maxModelLength; + std::shared_ptr apiHandler = std::make_shared(doc, ovms::Endpoint::RESPONSES, std::chrono::system_clock::now(), *tokenizer); + EXPECT_EQ(apiHandler->parseRequest(maxTokensLimit, bestOfLimit, maxModelLength), absl::InvalidArgumentError("max_output_tokens and max_completion_tokens must match when both are provided")); +} + +TEST_F(HttpOpenAIHandlerParsingTest, ParsingResponsesFlatFunctionToolsSucceeds) { + std::string json = R"({ + "model": "llama", + "input": "What is the weather like in Boston today?", + "tool_choice": "auto", + "tools": [ + { + "type": "function", + "name": "get_current_weather", + "description": "Get the current weather in a given location", + "parameters": { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "The city and state, e.g. San Francisco, CA" + }, + "unit": { + "type": "string", + "enum": ["celsius", "fahrenheit"] + } + }, + "required": ["location", "unit"] + } + } + ] + })"; + doc.Parse(json.c_str()); + ASSERT_FALSE(doc.HasParseError()); + std::optional maxTokensLimit; + uint32_t bestOfLimit = 0; + std::optional maxModelLength; + std::shared_ptr apiHandler = std::make_shared(doc, ovms::Endpoint::RESPONSES, std::chrono::system_clock::now(), *tokenizer); + EXPECT_EQ(apiHandler->parseRequest(maxTokensLimit, bestOfLimit, maxModelLength), absl::OkStatus()); + EXPECT_TRUE(apiHandler->areToolsAvailable()); + EXPECT_EQ(apiHandler->getToolChoice(), "auto"); +} + +TEST_F(HttpOpenAIHandlerParsingTest, ParsingResponsesToolChoiceFunctionObjectSucceeds) { + std::string json = R"({ + "model": "llama", + "input": "What is the weather like in Boston today?", + "tool_choice": { + "type": "function", + "name": "get_current_weather" + }, + "tools": [ + { + "type": "function", + "name": "get_current_weather", + "parameters": { + "type": "object", + "properties": { + "location": {"type": "string"} + }, + "required": ["location"] + } + }, + { + "type": "function", + "name": "unused_tool", + "parameters": { + "type": "object", + "properties": { + "arg": {"type": "string"} + } + } + } + ] + })"; + doc.Parse(json.c_str()); + ASSERT_FALSE(doc.HasParseError()); + std::optional maxTokensLimit; + uint32_t bestOfLimit = 0; + std::optional maxModelLength; + std::shared_ptr apiHandler = std::make_shared(doc, ovms::Endpoint::RESPONSES, std::chrono::system_clock::now(), *tokenizer); + EXPECT_EQ(apiHandler->parseRequest(maxTokensLimit, bestOfLimit, maxModelLength), absl::OkStatus()); + EXPECT_TRUE(apiHandler->areToolsAvailable()); + EXPECT_EQ(apiHandler->getToolChoice(), "get_current_weather"); +} + +TEST_F(HttpOpenAIHandlerParsingTest, SerializeResponsesUnaryResponseContainsFunctionTools) { + std::string json = R"({ + "model": "llama", + "input": "What is the weather like in Boston today?", + "tool_choice": "auto", + "tools": [ + { + "type": "function", + "name": "get_current_weather", + "parameters": { + "type": "object", + "properties": { + "location": {"type": "string"} + }, + "required": ["location"] + } + } + ] + })"; + doc.Parse(json.c_str()); + ASSERT_FALSE(doc.HasParseError()); + std::optional maxTokensLimit; + uint32_t bestOfLimit = 0; + std::optional maxModelLength; + std::shared_ptr apiHandler = std::make_shared(doc, ovms::Endpoint::RESPONSES, std::chrono::system_clock::now(), *tokenizer); + ASSERT_EQ(apiHandler->parseRequest(maxTokensLimit, bestOfLimit, maxModelLength), absl::OkStatus()); + + ov::genai::EncodedResults results; + ov::Tensor outputIds = tokenizer->encode("Sunny", ov::genai::add_special_tokens(false)).input_ids; + ASSERT_EQ(outputIds.get_shape().size(), 2); + ASSERT_EQ(outputIds.get_shape()[0], 1); + ASSERT_EQ(outputIds.get_element_type(), ov::element::i64); + int64_t* outputIdsData = reinterpret_cast(outputIds.data()); + results.tokens = {std::vector(outputIdsData, outputIdsData + outputIds.get_shape()[1])}; + + std::string serialized = apiHandler->serializeUnaryResponse(results); + ASSERT_NE(serialized.find("\"object\":\"response\""), std::string::npos) << serialized; + ASSERT_NE(serialized.find("\"tools\":[{"), std::string::npos) << serialized; + ASSERT_NE(serialized.find("\"type\":\"function\""), std::string::npos) << serialized; + ASSERT_NE(serialized.find("\"name\":\"get_current_weather\""), std::string::npos) << serialized; +} + +TEST_F(HttpOpenAIHandlerParsingTest, SerializeResponsesUnaryResponseContainsFunctionToolChoiceObject) { + std::string json = R"({ + "model": "llama", + "input": "What is the weather like in Boston today?", + "tool_choice": { + "type": "function", + "name": "get_current_weather" + }, + "tools": [ + { + "type": "function", + "name": "get_current_weather", + "parameters": { + "type": "object", + "properties": { + "location": {"type": "string"} + }, + "required": ["location"] + } + } + ] + })"; + doc.Parse(json.c_str()); + ASSERT_FALSE(doc.HasParseError()); + std::optional maxTokensLimit; + uint32_t bestOfLimit = 0; + std::optional maxModelLength; + std::shared_ptr apiHandler = std::make_shared(doc, ovms::Endpoint::RESPONSES, std::chrono::system_clock::now(), *tokenizer); + ASSERT_EQ(apiHandler->parseRequest(maxTokensLimit, bestOfLimit, maxModelLength), absl::OkStatus()); + + ov::genai::EncodedResults results; + ov::Tensor outputIds = tokenizer->encode("Sunny", ov::genai::add_special_tokens(false)).input_ids; + ASSERT_EQ(outputIds.get_shape().size(), 2); + ASSERT_EQ(outputIds.get_shape()[0], 1); + ASSERT_EQ(outputIds.get_element_type(), ov::element::i64); + int64_t* outputIdsData = reinterpret_cast(outputIds.data()); + results.tokens = {std::vector(outputIdsData, outputIdsData + outputIds.get_shape()[1])}; + + std::string serialized = apiHandler->serializeUnaryResponse(results); + ASSERT_NE(serialized.find("\"tool_choice\":{"), std::string::npos) << serialized; + ASSERT_NE(serialized.find("\"type\":\"function\""), std::string::npos) << serialized; + ASSERT_NE(serialized.find("\"name\":\"get_current_weather\""), std::string::npos) << serialized; +} + +TEST_F(HttpOpenAIHandlerParsingTest, ParsingResponsesToolChoiceFunctionObjectMissingNameFails) { + std::string json = R"({ + "model": "llama", + "input": "What is the weather like in Boston today?", + "tool_choice": { + "type": "function" + }, + "tools": [ + { + "type": "function", + "name": "get_current_weather", + "parameters": { + "type": "object", + "properties": { + "location": {"type": "string"} + } + } + } + ] + })"; + doc.Parse(json.c_str()); + ASSERT_FALSE(doc.HasParseError()); + std::optional maxTokensLimit; + uint32_t bestOfLimit = 0; + std::optional maxModelLength; + std::shared_ptr apiHandler = std::make_shared(doc, ovms::Endpoint::RESPONSES, std::chrono::system_clock::now(), *tokenizer); + EXPECT_EQ(apiHandler->parseRequest(maxTokensLimit, bestOfLimit, maxModelLength), absl::InvalidArgumentError("tool_choice.name is not a valid string")); +} + +TEST_F(HttpOpenAIHandlerParsingTest, ParsingResponsesToolChoiceFunctionObjectNameNotStringFails) { + std::string json = R"({ + "model": "llama", + "input": "What is the weather like in Boston today?", + "tool_choice": { + "type": "function", + "name": 7 + }, + "tools": [ + { + "type": "function", + "name": "get_current_weather", + "parameters": { + "type": "object", + "properties": { + "location": {"type": "string"} + } + } + } + ] + })"; + doc.Parse(json.c_str()); + ASSERT_FALSE(doc.HasParseError()); + std::optional maxTokensLimit; + uint32_t bestOfLimit = 0; + std::optional maxModelLength; + std::shared_ptr apiHandler = std::make_shared(doc, ovms::Endpoint::RESPONSES, std::chrono::system_clock::now(), *tokenizer); + EXPECT_EQ(apiHandler->parseRequest(maxTokensLimit, bestOfLimit, maxModelLength), absl::InvalidArgumentError("tool_choice.name is not a valid string")); +} + +TEST_F(HttpOpenAIHandlerParsingTest, ParsingResponsesInputImageUrlStringSucceeds) { + std::string json = R"({ + "model": "llama", + "input": [ + { + "role": "user", + "content": [ + {"type": "input_text", "text": "what is in this image?"}, + {"type": "input_image", "image_url": "data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAIAAACQd1PeAAAAEElEQVR4nGLK27oAEAAA//8DYAHGgEvy5AAAAABJRU5ErkJggg=="} + ] + } + ] + })"; + doc.Parse(json.c_str()); + ASSERT_FALSE(doc.HasParseError()); + std::optional maxTokensLimit; + uint32_t bestOfLimit = 0; + std::optional maxModelLength; + std::shared_ptr apiHandler = + std::make_shared(doc, ovms::Endpoint::RESPONSES, std::chrono::system_clock::now(), *tokenizer); + EXPECT_EQ(apiHandler->parseRequest(maxTokensLimit, bestOfLimit, maxModelLength), absl::OkStatus()); + EXPECT_EQ(apiHandler->getImageHistory().size(), 1); +} + +TEST_F(HttpOpenAIHandlerParsingTest, ParsingResponsesInputImageUrlObjectSucceeds) { + std::string json = R"({ + "model": "llama", + "input": [ + { + "role": "user", + "content": [ + {"type": "input_text", "text": "what is in this image?"}, + {"type": "input_image", "image_url": {"url": "data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAIAAACQd1PeAAAAEElEQVR4nGLK27oAEAAA//8DYAHGgEvy5AAAAABJRU5ErkJggg=="}} + ] + } + ] + })"; + doc.Parse(json.c_str()); + ASSERT_FALSE(doc.HasParseError()); + std::optional maxTokensLimit; + uint32_t bestOfLimit = 0; + std::optional maxModelLength; + std::shared_ptr apiHandler = + std::make_shared(doc, ovms::Endpoint::RESPONSES, std::chrono::system_clock::now(), *tokenizer); + EXPECT_EQ(apiHandler->parseRequest(maxTokensLimit, bestOfLimit, maxModelLength), absl::OkStatus()); + EXPECT_EQ(apiHandler->getImageHistory().size(), 1); +} + +TEST_F(HttpOpenAIHandlerParsingTest, ParsingResponsesInputImageWithoutImageUrlFails) { + std::string json = R"({ + "model": "llama", + "input": [ + { + "role": "user", + "content": [ + {"type": "input_text", "text": "what is in this image?"}, + {"type": "input_image"} + ] + } + ] + })"; + doc.Parse(json.c_str()); + ASSERT_FALSE(doc.HasParseError()); + std::optional maxTokensLimit; + uint32_t bestOfLimit = 0; + std::optional maxModelLength; + std::shared_ptr apiHandler = + std::make_shared(doc, ovms::Endpoint::RESPONSES, std::chrono::system_clock::now(), *tokenizer); + EXPECT_EQ(apiHandler->parseRequest(maxTokensLimit, bestOfLimit, maxModelLength), absl::InvalidArgumentError("input_image requires image_url field")); +} + +TEST_F(HttpOpenAIHandlerParsingTest, ParsingResponsesInputImageUrlInvalidTypeFails) { + std::string json = R"({ + "model": "llama", + "input": [ + { + "role": "user", + "content": [ + {"type": "input_text", "text": "what is in this image?"}, + {"type": "input_image", "image_url": 123} + ] + } + ] + })"; + doc.Parse(json.c_str()); + ASSERT_FALSE(doc.HasParseError()); + std::optional maxTokensLimit; + uint32_t bestOfLimit = 0; + std::optional maxModelLength; + std::shared_ptr apiHandler = + std::make_shared(doc, ovms::Endpoint::RESPONSES, std::chrono::system_clock::now(), *tokenizer); + EXPECT_EQ(apiHandler->parseRequest(maxTokensLimit, bestOfLimit, maxModelLength), absl::InvalidArgumentError("input_image.image_url must be a string or object")); +} + +TEST_F(HttpOpenAIHandlerParsingTest, ParsingResponsesUnsupportedToolTypeFails) { + std::string json = R"({ + "model": "llama", + "input": "What is the weather like in Boston today?", + "tool_choice": "auto", + "tools": [ + { + "type": "web_search_preview" + } + ] + })"; + doc.Parse(json.c_str()); + ASSERT_FALSE(doc.HasParseError()); + std::optional maxTokensLimit; + uint32_t bestOfLimit = 0; + std::optional maxModelLength; + std::shared_ptr apiHandler = std::make_shared(doc, ovms::Endpoint::RESPONSES, std::chrono::system_clock::now(), *tokenizer); + EXPECT_EQ(apiHandler->parseRequest(maxTokensLimit, bestOfLimit, maxModelLength), absl::InvalidArgumentError("Only function tools are supported")); +} + +TEST_F(HttpOpenAIHandlerParsingTest, ParsingResponsesToolChoiceNoneRemovesTools) { + std::string json = R"({ + "model": "llama", + "input": "What is the weather like in Boston today?", + "tool_choice": "none", + "tools": [ + { + "type": "function", + "name": "get_current_weather", + "parameters": { + "type": "object", + "properties": { + "location": {"type": "string"} + } + } + } + ] + })"; + doc.Parse(json.c_str()); + ASSERT_FALSE(doc.HasParseError()); + std::optional maxTokensLimit; + uint32_t bestOfLimit = 0; + std::optional maxModelLength; + std::shared_ptr apiHandler = std::make_shared(doc, ovms::Endpoint::RESPONSES, std::chrono::system_clock::now(), *tokenizer); + EXPECT_EQ(apiHandler->parseRequest(maxTokensLimit, bestOfLimit, maxModelLength), absl::OkStatus()); + EXPECT_FALSE(apiHandler->areToolsAvailable()); + EXPECT_EQ(apiHandler->getToolChoice(), "none"); +} + // Provide get_weather2 but take none TEST_F(HttpOpenAIHandlerParsingTest, ParseRequestWithTools_Provided1_ChoiceNone) { std::string providedTools = R"( diff --git a/src/test/llm/visual_language_model/complete_flow_test.cpp b/src/test/llm/visual_language_model/complete_flow_test.cpp index 4dc22d6fa3..5f2b380556 100644 --- a/src/test/llm/visual_language_model/complete_flow_test.cpp +++ b/src/test/llm/visual_language_model/complete_flow_test.cpp @@ -49,6 +49,7 @@ class VLMServableExecutionTest : public ::testing::Test { std::unordered_map headers{{"content-type", "application/json"}}; ovms::HttpRequestComponents comp; const std::string endpointChatCompletions = "/v3/chat/completions"; + const std::string endpointResponses = "/v3/responses"; std::shared_ptr writer; std::shared_ptr multiPartParser; std::string response; @@ -129,6 +130,50 @@ static std::string createRequestBody(const std::string& modelName, const std::ve return oss.str(); } +static std::string createResponsesRequestBody(const std::string& modelName, const std::vector>& fields, bool includeText = true, int numberOfImages = 1, const std::string contentOfTheFirstMessage = "What is in this image?") { + std::ostringstream oss; + oss << R"( + { + "model": ")" + << modelName << R"(", + "input": [ + { + "role": "user", + "content": [)"; + if (includeText) { + oss << R"( + { + "type": "input_text", + "text": ")"; + oss << contentOfTheFirstMessage; + oss << R"("})"; + if (numberOfImages > 0) { + oss << ","; + } + } + for (int i = 0; i < numberOfImages; i++) { + oss << R"( + { + "type": "input_image", + "image_url": "data:image/jpeg;base64,iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAIAAACQd1PeAAAAEElEQVR4nGIy+/oREAAA//8DiQIftNKCRwAAAABJRU5ErkJggg==" + })"; + if (i < numberOfImages - 1) { + oss << ","; + } + } + oss << R"( + ] + } + ] + )"; + for (const auto& field : fields) { + oss << R"(, ")" << field.first << R"(": )" << field.second << R"()" + << "\n"; + } + oss << "\n}"; + return oss.str(); +} + class VLMServableExecutionTestParameterized : public VLMServableExecutionTest, public ::testing::WithParamInterface {}; // Unary flow @@ -304,6 +349,152 @@ TEST_P(VLMServableExecutionTestParameterized, unaryBasicWithTools) { EXPECT_STREQ(parsedResponse["model"].GetString(), modelName.c_str()); } +TEST_P(VLMServableExecutionTestParameterized, unaryResponsesWithImageInput) { + auto modelName = GetParam(); + std::vector> fields = { + {"max_output_tokens", "5"}, + {"temperature", "0.0"}}; + std::string requestBody = createResponsesRequestBody(modelName, fields); + + ovms::HttpRequestComponents responsesComp; + ASSERT_EQ(handler->parseRequestComponents(responsesComp, "POST", endpointResponses, headers), ovms::StatusCode::OK); + + ASSERT_EQ( + handler->dispatchToProcessor(endpointResponses, requestBody, &response, responsesComp, responseComponents, writer, multiPartParser), + ovms::StatusCode::OK); + + parsedResponse.Parse(response.c_str()); + ASSERT_TRUE(parsedResponse.IsObject()); + ASSERT_TRUE(parsedResponse.HasMember("object")); + EXPECT_STREQ(parsedResponse["object"].GetString(), "response"); + ASSERT_TRUE(parsedResponse.HasMember("model")); + EXPECT_STREQ(parsedResponse["model"].GetString(), modelName.c_str()); + ASSERT_TRUE(parsedResponse.HasMember("output")); + ASSERT_TRUE(parsedResponse["output"].IsArray()); + ASSERT_GT(parsedResponse["output"].GetArray().Size(), 0); + ASSERT_TRUE(parsedResponse["output"][0].IsObject()); + ASSERT_TRUE(parsedResponse["output"][0].HasMember("type")); + EXPECT_STREQ(parsedResponse["output"][0]["type"].GetString(), "message"); + ASSERT_TRUE(parsedResponse["output"][0].HasMember("content")); + ASSERT_TRUE(parsedResponse["output"][0]["content"].IsArray()); + ASSERT_GT(parsedResponse["output"][0]["content"].GetArray().Size(), 0); + ASSERT_TRUE(parsedResponse["output"][0]["content"][0].HasMember("type")); + EXPECT_STREQ(parsedResponse["output"][0]["content"][0]["type"].GetString(), "output_text"); + + ASSERT_TRUE(parsedResponse.HasMember("usage")); + ASSERT_TRUE(parsedResponse["usage"].IsObject()); + ASSERT_TRUE(parsedResponse["usage"].HasMember("input_tokens")); + ASSERT_TRUE(parsedResponse["usage"].HasMember("output_tokens")); + ASSERT_TRUE(parsedResponse["usage"].HasMember("total_tokens")); +} + +TEST_P(VLMServableExecutionTestParameterized, unaryResponsesOnlyImageInput) { + auto modelName = GetParam(); + std::vector> fields = { + {"max_output_tokens", "5"}, + {"temperature", "0.0"}}; + std::string requestBody = createResponsesRequestBody(modelName, fields, false, 1); + + ovms::HttpRequestComponents responsesComp; + ASSERT_EQ(handler->parseRequestComponents(responsesComp, "POST", endpointResponses, headers), ovms::StatusCode::OK); + + ASSERT_EQ( + handler->dispatchToProcessor(endpointResponses, requestBody, &response, responsesComp, responseComponents, writer, multiPartParser), + ovms::StatusCode::OK); + + parsedResponse.Parse(response.c_str()); + ASSERT_TRUE(parsedResponse.IsObject()); + ASSERT_TRUE(parsedResponse.HasMember("object")); + EXPECT_STREQ(parsedResponse["object"].GetString(), "response"); + ASSERT_TRUE(parsedResponse.HasMember("output")); + ASSERT_TRUE(parsedResponse["output"].IsArray()); + ASSERT_GT(parsedResponse["output"].GetArray().Size(), 0); +} + +TEST_P(VLMServableExecutionTestParameterized, unaryResponsesWithTools) { + auto modelName = GetParam(); + std::vector> fields = { + {"max_output_tokens", "5"}, + {"temperature", "0.0"}, + {"tool_choice", R"("auto")"}, + {"tools", R"([ + { + "type": "function", + "name": "get_weather", + "description": "Get weather by city", + "parameters": { + "type": "object", + "properties": { + "city": { + "type": "string" + } + }, + "required": ["city"] + } + } + ])"}}; + std::string requestBody = createResponsesRequestBody(modelName, fields); + + ovms::HttpRequestComponents responsesComp; + ASSERT_EQ(handler->parseRequestComponents(responsesComp, "POST", endpointResponses, headers), ovms::StatusCode::OK); + + ASSERT_EQ( + handler->dispatchToProcessor(endpointResponses, requestBody, &response, responsesComp, responseComponents, writer, multiPartParser), + ovms::StatusCode::OK); + + parsedResponse.Parse(response.c_str()); + ASSERT_TRUE(parsedResponse.IsObject()); + ASSERT_TRUE(parsedResponse.HasMember("object")); + EXPECT_STREQ(parsedResponse["object"].GetString(), "response"); + ASSERT_TRUE(parsedResponse.HasMember("tools")); + ASSERT_TRUE(parsedResponse["tools"].IsArray()); + ASSERT_GT(parsedResponse["tools"].GetArray().Size(), 0); + ASSERT_TRUE(parsedResponse.HasMember("tool_choice")); + ASSERT_TRUE(parsedResponse["tool_choice"].IsString()); + EXPECT_STREQ(parsedResponse["tool_choice"].GetString(), "auto"); +} + +TEST_P(VLMServableExecutionTestParameterized, unaryResponsesWithFunctionToolChoiceObject) { + auto modelName = GetParam(); + std::vector> fields = { + {"max_output_tokens", "5"}, + {"temperature", "0.0"}, + {"tool_choice", R"({"type":"function","name":"get_weather"})"}, + {"tools", R"([ + { + "type": "function", + "name": "get_weather", + "description": "Get weather by city", + "parameters": { + "type": "object", + "properties": { + "city": { + "type": "string" + } + }, + "required": ["city"] + } + } + ])"}}; + std::string requestBody = createResponsesRequestBody(modelName, fields); + + ovms::HttpRequestComponents responsesComp; + ASSERT_EQ(handler->parseRequestComponents(responsesComp, "POST", endpointResponses, headers), ovms::StatusCode::OK); + + ASSERT_EQ( + handler->dispatchToProcessor(endpointResponses, requestBody, &response, responsesComp, responseComponents, writer, multiPartParser), + ovms::StatusCode::OK); + + parsedResponse.Parse(response.c_str()); + ASSERT_TRUE(parsedResponse.IsObject()); + ASSERT_TRUE(parsedResponse.HasMember("tool_choice")); + ASSERT_TRUE(parsedResponse["tool_choice"].IsObject()); + ASSERT_TRUE(parsedResponse["tool_choice"].HasMember("type")); + EXPECT_STREQ(parsedResponse["tool_choice"]["type"].GetString(), "function"); + ASSERT_TRUE(parsedResponse["tool_choice"].HasMember("name")); + EXPECT_STREQ(parsedResponse["tool_choice"]["name"].GetString(), "get_weather"); +} + // Stream flow TEST_P(VLMServableExecutionTestParameterized, streamBasic) { From 5a12817dbc47fc08a7273f7be26479556500f36a Mon Sep 17 00:00:00 2001 From: Michal Kulakowski Date: Thu, 5 Mar 2026 16:03:05 +0100 Subject: [PATCH 2/8] fix --- src/llm/apis/openai_completions.cpp | 325 ++++++++++-------- src/llm/apis/openai_completions.hpp | 1 + .../continuous_batching/servable.cpp | 15 +- 3 files changed, 190 insertions(+), 151 deletions(-) diff --git a/src/llm/apis/openai_completions.cpp b/src/llm/apis/openai_completions.cpp index 0d383a48f4..66b40b7287 100644 --- a/src/llm/apis/openai_completions.cpp +++ b/src/llm/apis/openai_completions.cpp @@ -299,125 +299,6 @@ absl::Status normalizeResponsesFunctionToolsInPlace(rapidjson::Document& doc) { return absl::OkStatus(); } -absl::Status normalizeResponsesInputToMessagesInPlace(rapidjson::Document& doc) { - auto inputIt = doc.FindMember("input"); - if (inputIt == doc.MemberEnd()) { - return absl::InvalidArgumentError("input missing in request"); - } - auto& allocator = doc.GetAllocator(); - if (inputIt->value.IsString()) { - rapidjson::Value messages(rapidjson::kArrayType); - rapidjson::Value messageObj(rapidjson::kObjectType); - messageObj.AddMember("role", "user", allocator); - messageObj.AddMember("content", rapidjson::Value(inputIt->value.GetString(), allocator), allocator); - messages.PushBack(messageObj, allocator); - - auto existingMessages = doc.FindMember("messages"); - if (existingMessages != doc.MemberEnd()) { - existingMessages->value = messages; - } else { - doc.AddMember("messages", messages, allocator); - } - return absl::OkStatus(); - } - if (!inputIt->value.IsArray()) { - return absl::InvalidArgumentError("input is not a string or array"); - } - - rapidjson::Value messages(rapidjson::kArrayType); - for (auto& item : inputIt->value.GetArray()) { - if (!item.IsObject()) { - return absl::InvalidArgumentError("input array items must be objects"); - } - - auto itemObj = item.GetObject(); - auto roleIt = itemObj.FindMember("role"); - if (roleIt == itemObj.MemberEnd() || !roleIt->value.IsString()) { - return absl::InvalidArgumentError("input item role is missing or invalid"); - } - - rapidjson::Value messageObj(rapidjson::kObjectType); - messageObj.AddMember("role", rapidjson::Value(roleIt->value.GetString(), allocator), allocator); - - auto contentIt = itemObj.FindMember("content"); - if (contentIt == itemObj.MemberEnd()) { - return absl::InvalidArgumentError("input item content is missing"); - } - - if (contentIt->value.IsString()) { - messageObj.AddMember("content", rapidjson::Value(contentIt->value.GetString(), allocator), allocator); - messages.PushBack(messageObj, allocator); - continue; - } - - if (!contentIt->value.IsArray()) { - return absl::InvalidArgumentError("input item content must be a string or array"); - } - - rapidjson::Value normalizedContent(rapidjson::kArrayType); - for (auto& contentItem : contentIt->value.GetArray()) { - if (!contentItem.IsObject()) { - return absl::InvalidArgumentError("input content items must be objects"); - } - auto contentObj = contentItem.GetObject(); - auto typeIt = contentObj.FindMember("type"); - if (typeIt == contentObj.MemberEnd() || !typeIt->value.IsString()) { - return absl::InvalidArgumentError("input content item type is missing or invalid"); - } - - std::string type = typeIt->value.GetString(); - if (type == "input_text") { - auto textIt = contentObj.FindMember("text"); - if (textIt == contentObj.MemberEnd() || !textIt->value.IsString()) { - return absl::InvalidArgumentError("input_text requires a valid text field"); - } - rapidjson::Value textObj(rapidjson::kObjectType); - textObj.AddMember("type", "text", allocator); - textObj.AddMember("text", rapidjson::Value(textIt->value.GetString(), allocator), allocator); - normalizedContent.PushBack(textObj, allocator); - } else if (type == "input_image") { - std::string imageUrl; - auto imageUrlIt = contentObj.FindMember("image_url"); - if (imageUrlIt == contentObj.MemberEnd()) { - return absl::InvalidArgumentError("input_image requires image_url field"); - } - if (imageUrlIt->value.IsString()) { - imageUrl = imageUrlIt->value.GetString(); - } else if (imageUrlIt->value.IsObject()) { - auto imageUrlObj = imageUrlIt->value.GetObject(); - auto urlIt = imageUrlObj.FindMember("url"); - if (urlIt == imageUrlObj.MemberEnd() || !urlIt->value.IsString()) { - return absl::InvalidArgumentError("input_image.image_url.url is missing or invalid"); - } - imageUrl = urlIt->value.GetString(); - } else { - return absl::InvalidArgumentError("input_image.image_url must be a string or object"); - } - - rapidjson::Value imageUrlObj(rapidjson::kObjectType); - imageUrlObj.AddMember("url", rapidjson::Value(imageUrl.c_str(), allocator), allocator); - - rapidjson::Value imageObj(rapidjson::kObjectType); - imageObj.AddMember("type", "image_url", allocator); - imageObj.AddMember("image_url", imageUrlObj, allocator); - normalizedContent.PushBack(imageObj, allocator); - } else { - return absl::InvalidArgumentError("Unsupported content type"); - } - } - messageObj.AddMember("content", normalizedContent, allocator); - messages.PushBack(messageObj, allocator); - } - - auto existingMessages = doc.FindMember("messages"); - if (existingMessages != doc.MemberEnd()) { - existingMessages->value = messages; - } else { - doc.AddMember("messages", messages, allocator); - } - return absl::OkStatus(); -} - } // namespace absl::Status OpenAIChatCompletionsHandler::parseCompletionsPart() { @@ -570,6 +451,193 @@ absl::Status OpenAIChatCompletionsHandler::ensureArgumentsInToolCalls(Value& mes return absl::OkStatus(); } +absl::Status OpenAIChatCompletionsHandler::parseResponsesInputDirectly(std::optional allowedLocalMediaPath, std::optional> allowedMediaDomains) { + auto inputIt = doc.FindMember("input"); + if (inputIt == doc.MemberEnd()) { + return absl::InvalidArgumentError("input missing in request"); + } + + auto& allocator = doc.GetAllocator(); + rapidjson::Value messages(rapidjson::kArrayType); + + if (inputIt->value.IsString()) { + request.prompt = inputIt->value.GetString(); + if (!request.prompt.has_value() || request.prompt.value().empty()) { + return absl::InvalidArgumentError("input cannot be empty"); + } + + request.chatHistory.push_back({}); + request.chatHistory.last()["role"] = "user"; + request.chatHistory.last()["content"] = request.prompt.value(); + + rapidjson::Value messageObj(rapidjson::kObjectType); + messageObj.AddMember("role", "user", allocator); + messageObj.AddMember("content", rapidjson::Value(request.prompt->c_str(), allocator), allocator); + messages.PushBack(messageObj, allocator); + } else if (inputIt->value.IsArray()) { + if (inputIt->value.GetArray().Size() == 0) { + return absl::InvalidArgumentError("Messages array cannot be empty"); + } + + for (size_t i = 0; i < inputIt->value.GetArray().Size(); ++i) { + auto& item = inputIt->value.GetArray()[i]; + if (!item.IsObject()) { + return absl::InvalidArgumentError("input array items must be objects"); + } + + auto itemObj = item.GetObject(); + auto roleIt = itemObj.FindMember("role"); + if (roleIt == itemObj.MemberEnd() || !roleIt->value.IsString()) { + return absl::InvalidArgumentError("input item role is missing or invalid"); + } + + request.chatHistory.push_back({}); + request.chatHistory.last()["role"] = roleIt->value.GetString(); + + rapidjson::Value messageObj(rapidjson::kObjectType); + messageObj.AddMember("role", rapidjson::Value(roleIt->value.GetString(), allocator), allocator); + + auto contentIt = itemObj.FindMember("content"); + if (contentIt == itemObj.MemberEnd()) { + return absl::InvalidArgumentError("input item content is missing"); + } + + if (contentIt->value.IsString()) { + messageObj.AddMember("content", rapidjson::Value(contentIt->value.GetString(), allocator), allocator); + request.chatHistory.last()["content"] = contentIt->value.GetString(); + messages.PushBack(messageObj, allocator); + continue; + } + + if (!contentIt->value.IsArray()) { + return absl::InvalidArgumentError("input item content must be a string or array"); + } + if (contentIt->value.GetArray().Size() == 0) { + return absl::InvalidArgumentError("Invalid message structure - content array is empty"); + } + + std::string contentText; + for (auto& contentItem : contentIt->value.GetArray()) { + if (!contentItem.IsObject()) { + return absl::InvalidArgumentError("input content items must be objects"); + } + auto contentObj = contentItem.GetObject(); + auto typeIt = contentObj.FindMember("type"); + if (typeIt == contentObj.MemberEnd() || !typeIt->value.IsString()) { + return absl::InvalidArgumentError("input content item type is missing or invalid"); + } + + const std::string type = typeIt->value.GetString(); + if (type == "input_text") { + auto textIt = contentObj.FindMember("text"); + if (textIt == contentObj.MemberEnd() || !textIt->value.IsString()) { + return absl::InvalidArgumentError("input_text requires a valid text field"); + } + contentText = textIt->value.GetString(); + } else if (type == "input_image") { + std::string imageUrl; + auto imageUrlIt = contentObj.FindMember("image_url"); + if (imageUrlIt == contentObj.MemberEnd()) { + return absl::InvalidArgumentError("input_image requires image_url field"); + } + if (imageUrlIt->value.IsString()) { + imageUrl = imageUrlIt->value.GetString(); + } else if (imageUrlIt->value.IsObject()) { + auto imageUrlObj = imageUrlIt->value.GetObject(); + auto urlIt = imageUrlObj.FindMember("url"); + if (urlIt == imageUrlObj.MemberEnd() || !urlIt->value.IsString()) { + return absl::InvalidArgumentError("input_image.image_url.url is missing or invalid"); + } + imageUrl = urlIt->value.GetString(); + } else { + return absl::InvalidArgumentError("input_image.image_url must be a string or object"); + } + + std::string pattern = "base64,"; + std::size_t pos = imageUrl.find(pattern); + std::string decoded; + ov::Tensor tensor; + if (pos != std::string::npos) { + SPDLOG_LOGGER_TRACE(llm_calculator_logger, "Loading image from base64 string"); + size_t offset = pos + pattern.length(); + if (!absl::Base64Unescape(std::string_view(imageUrl.data() + offset, imageUrl.size() - offset), &decoded)) { + return absl::InvalidArgumentError("Invalid base64 string in request"); + } + try { + tensor = loadImageStbiFromMemory(decoded); + } catch (std::runtime_error& e) { + std::stringstream ss; + ss << "Image parsing failed: " << e.what(); + SPDLOG_LOGGER_DEBUG(llm_calculator_logger, ss.str()); + return absl::InvalidArgumentError(ss.str()); + } + } else if (std::regex_match(imageUrl.c_str(), std::regex("^(http|https|ftp|sftp|)://(.*)"))) { + SPDLOG_LOGGER_TRACE(llm_calculator_logger, "Loading image using curl"); + int64_t sizeLimit = 20000000; // restrict single image size to 20MB + if (!allowedMediaDomains.has_value() || !isDomainAllowed(allowedMediaDomains.value(), imageUrl.c_str())) { + return absl::InvalidArgumentError("Given url does not match any allowed domain from allowed_media_domains"); + } + auto status = downloadImage(imageUrl.c_str(), decoded, sizeLimit); + if (status != absl::OkStatus()) { + return status; + } + try { + tensor = loadImageStbiFromMemory(decoded); + } catch (std::runtime_error& e) { + std::stringstream ss; + ss << "Image parsing failed: " << e.what(); + SPDLOG_LOGGER_DEBUG(llm_calculator_logger, ss.str()); + return absl::InvalidArgumentError("Image parsing failed"); + } + } else { + if (!allowedLocalMediaPath.has_value()) { + return absl::InvalidArgumentError("Loading images from local filesystem is disabled."); + } + if (FileSystem::isPathEscaped(imageUrl)) { + std::stringstream ss; + ss << "Path " << imageUrl.c_str() << " escape with .. is forbidden."; + SPDLOG_LOGGER_DEBUG(llm_calculator_logger, ss.str()); + return absl::InvalidArgumentError(ss.str()); + } + SPDLOG_LOGGER_TRACE(llm_calculator_logger, "Loading image from local filesystem"); + const auto firstMissmatch = std::mismatch(imageUrl.begin(), imageUrl.end(), allowedLocalMediaPath.value().begin(), allowedLocalMediaPath.value().end()); + if (firstMissmatch.second != allowedLocalMediaPath.value().end()) { + return absl::InvalidArgumentError("Given filepath is not subpath of allowed_local_media_path"); + } + try { + tensor = loadImageStbiFromFile(imageUrl.c_str()); + } catch (std::runtime_error& e) { + std::stringstream ss; + ss << "Image file " << imageUrl.c_str() << " parsing failed: " << e.what(); + SPDLOG_LOGGER_DEBUG(llm_calculator_logger, ss.str()); + return absl::InvalidArgumentError(ss.str()); + } + } + request.imageHistory.push_back({i, tensor}); + } else { + return absl::InvalidArgumentError("Unsupported content type"); + } + } + + messageObj.AddMember("content", rapidjson::Value(contentText.c_str(), allocator), allocator); + request.chatHistory.last()["content"] = contentText; + messages.PushBack(messageObj, allocator); + } + } else { + return absl::InvalidArgumentError("input is not a string or array"); + } + + auto existingMessages = doc.FindMember("messages"); + if (existingMessages != doc.MemberEnd()) { + existingMessages->value = messages; + } else { + doc.AddMember("messages", messages, allocator); + } + + SPDLOG_LOGGER_DEBUG(llm_calculator_logger, "Parsed responses input directly to chat history"); + return absl::OkStatus(); +} + absl::Status OpenAIChatCompletionsHandler::parseMessages(std::optional allowedLocalMediaPath, std::optional> allowedMediaDomains) { auto it = doc.FindMember("messages"); if (it == doc.MemberEnd()) @@ -986,24 +1054,7 @@ absl::Status OpenAIChatCompletionsHandler::parseResponsesPart(std::optionalvalue.IsString()) { - request.prompt = it->value.GetString(); - if (!request.prompt.has_value() || !request.prompt.value().size()) { - return absl::InvalidArgumentError("input cannot be empty"); - } - } - - auto messagesStatus = parseMessages(allowedLocalMediaPath, allowedMediaDomains); + auto messagesStatus = parseResponsesInputDirectly(allowedLocalMediaPath, allowedMediaDomains); if (!messagesStatus.ok()) { return messagesStatus; } diff --git a/src/llm/apis/openai_completions.hpp b/src/llm/apis/openai_completions.hpp index 832e6cd316..254e02050c 100644 --- a/src/llm/apis/openai_completions.hpp +++ b/src/llm/apis/openai_completions.hpp @@ -80,6 +80,7 @@ class OpenAIChatCompletionsHandler { absl::Status parseCompletionsPart(); absl::Status parseChatCompletionsPart(std::optional maxTokensLimit, std::optional allowedLocalMediaPath, std::optional> allowedMediaDomains); absl::Status parseResponsesPart(std::optional maxTokensLimit, std::optional allowedLocalMediaPath, std::optional> allowedMediaDomains); + absl::Status parseResponsesInputDirectly(std::optional allowedLocalMediaPath, std::optional> allowedMediaDomains); absl::Status parseCommonPart(std::optional maxTokensLimit, uint32_t bestOfLimit, std::optional maxModelLength); ParsedOutput parseOutputIfNeeded(const std::vector& generatedIds); diff --git a/src/llm/language_model/continuous_batching/servable.cpp b/src/llm/language_model/continuous_batching/servable.cpp index 1c14944385..470e170a09 100644 --- a/src/llm/language_model/continuous_batching/servable.cpp +++ b/src/llm/language_model/continuous_batching/servable.cpp @@ -103,15 +103,6 @@ static ov::genai::GenerationOutput prepareEmptyStopReasonOutput() { return out; } -static ov::genai::GenerationOutput prepareEmptyNoneReasonOutput() { - static ov::genai::GenerationOutput out = { - std::vector(), // generated_ids - std::vector(), // generated_log_probs - 0.0f, // score - ov::genai::GenerationFinishReason::NONE}; - return out; -} - absl::Status ContinuousBatchingServable::readCompleteExecutionResults(std::shared_ptr& executionContext) { auto cbExecutionContext = std::static_pointer_cast(executionContext); if (cbExecutionContext->payload.client->isDisconnected()) { @@ -145,11 +136,7 @@ absl::Status ContinuousBatchingServable::readPartialExecutionResults(std::shared ov::genai::GenerationOutputs generationOutputs = cbExecutionContext->generationHandle->read(); RET_CHECK(generationOutputs.size() <= 1); // TODO: Support multiple generations if (generationOutputs.size() == 0) { - if (cbExecutionContext->generationHandle->get_status() == ov::genai::GenerationStatus::RUNNING) { - cbExecutionContext->generationOutputs = {prepareEmptyNoneReasonOutput()}; - } else { - cbExecutionContext->generationOutputs = {prepareEmptyStopReasonOutput()}; - } + cbExecutionContext->generationOutputs = {prepareEmptyStopReasonOutput()}; } else { cbExecutionContext->generationOutputs = {generationOutputs.begin()->second}; } From 6168f25f172a5adda2c6d7f537d3bb055e8db318 Mon Sep 17 00:00:00 2001 From: Michal Kulakowski Date: Thu, 5 Mar 2026 16:04:40 +0100 Subject: [PATCH 3/8] style --- src/test/http_openai_handler_test.cpp | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/test/http_openai_handler_test.cpp b/src/test/http_openai_handler_test.cpp index 471a12070a..e0b3a40bba 100644 --- a/src/test/http_openai_handler_test.cpp +++ b/src/test/http_openai_handler_test.cpp @@ -269,7 +269,7 @@ TEST_F(HttpOpenAIHandlerTest, Stream) { ASSERT_EQ(response, ""); } - TEST_F(HttpOpenAIHandlerTest, ResponsesStream) { +TEST_F(HttpOpenAIHandlerTest, ResponsesStream) { std::string requestBody = R"( { "model": "gpt", @@ -284,11 +284,11 @@ TEST_F(HttpOpenAIHandlerTest, Stream) { EXPECT_CALL(*writer, IsDisconnected()).Times(9); ASSERT_EQ( - handler->dispatchToProcessor("/v3/responses", requestBody, &response, comp, responseComponents, writer, multiPartParser), - ovms::StatusCode::PARTIAL_END); + handler->dispatchToProcessor("/v3/responses", requestBody, &response, comp, responseComponents, writer, multiPartParser), + ovms::StatusCode::PARTIAL_END); ASSERT_EQ(response, ""); - } +} TEST_F(HttpOpenAIHandlerTest, BodyNotAJson) { std::string requestBody = "not a json"; From 205030637b5f9e2c03e26b0af9a314ad78a6bfe6 Mon Sep 17 00:00:00 2001 From: Michal Kulakowski Date: Thu, 5 Mar 2026 16:25:36 +0100 Subject: [PATCH 4/8] fix --- src/test/http_openai_handler_test.cpp | 269 ++++++++++++++++++++------ 1 file changed, 215 insertions(+), 54 deletions(-) diff --git a/src/test/http_openai_handler_test.cpp b/src/test/http_openai_handler_test.cpp index e0b3a40bba..5a79272f9d 100644 --- a/src/test/http_openai_handler_test.cpp +++ b/src/test/http_openai_handler_test.cpp @@ -14,6 +14,7 @@ // limitations under the License. //***************************************************************************** #include +#include #include #include #include @@ -314,60 +315,6 @@ TEST_F(HttpOpenAIHandlerTest, JsonBodyValidButNotAnObject) { ASSERT_EQ(status.string(), "The file is not valid json - JSON body must be an object"); } -TEST_F(HttpOpenAIHandlerTest, ModelFieldMissing) { - std::string requestBody = R"( - { - "stream": true, - "messages": [] - } - )"; - - EXPECT_CALL(*writer, PartialReplyEnd()).Times(0); - EXPECT_CALL(*writer, PartialReply(::testing::_)).Times(0); - EXPECT_CALL(*writer, IsDisconnected()).Times(0); - - auto status = handler->dispatchToProcessor("/v3/completions", requestBody, &response, comp, responseComponents, writer, multiPartParser); - ASSERT_EQ(status, ovms::StatusCode::JSON_INVALID); - ASSERT_EQ(status.string(), "The file is not valid json - model field is missing in JSON body"); -} - -TEST_F(HttpOpenAIHandlerTest, ModelFieldNotAString) { - std::string requestBody = R"( - { - "model": 2, - "stream": true, - "messages": [] - } - )"; - - EXPECT_CALL(*writer, PartialReplyEnd()).Times(0); - EXPECT_CALL(*writer, PartialReply(::testing::_)).Times(0); - EXPECT_CALL(*writer, IsDisconnected()).Times(0); - - auto status = handler->dispatchToProcessor("/v3/completions", requestBody, &response, comp, responseComponents, writer, multiPartParser); - ASSERT_EQ(status, ovms::StatusCode::JSON_INVALID); - ASSERT_EQ(status.string(), "The file is not valid json - model field is not a string"); -} - -TEST_F(HttpOpenAIHandlerTest, StreamFieldNotABoolean) { - std::string requestBody = R"( - { - "model": "gpt", - "stream": 2, - "messages": [] - } - )"; - - EXPECT_CALL(*writer, PartialReplyBegin(::testing::_)).Times(0); - EXPECT_CALL(*writer, PartialReplyEnd()).Times(0); - EXPECT_CALL(*writer, PartialReply(::testing::_)).Times(0); - EXPECT_CALL(*writer, IsDisconnected()).Times(0); - - auto status = handler->dispatchToProcessor("/v3/completions", requestBody, &response, comp, responseComponents, writer, multiPartParser); - ASSERT_EQ(status, ovms::StatusCode::JSON_INVALID); - ASSERT_EQ(status.string(), "The file is not valid json - stream field is not a boolean"); -} - TEST_F(HttpOpenAIHandlerTest, GraphWithANameDoesNotExist) { std::string requestBody = R"( { @@ -423,6 +370,220 @@ class HttpOpenAIHandlerParsingTest : public ::testing::Test { } }; +class HttpOpenAIHandlerCommonParsingValidationTest : public HttpOpenAIHandlerParsingTest, + public ::testing::WithParamInterface { +protected: + ovms::Endpoint endpoint() const { + return GetParam(); + } + + std::string createRequestWithRawStreamValue(const std::string& streamRawValue) const { + if (endpoint() == ovms::Endpoint::COMPLETIONS) { + return std::string("{\"model\":\"llama\",\"stream\":") + streamRawValue + ",\"prompt\":\"valid prompt\"}"; + } + if (endpoint() == ovms::Endpoint::RESPONSES) { + return std::string("{\"model\":\"llama\",\"stream\":") + streamRawValue + ",\"input\":\"valid prompt\"}"; + } + return std::string("{\"model\":\"llama\",\"stream\":") + streamRawValue + ",\"messages\":[{\"role\":\"user\",\"content\":\"valid prompt\"}]}"; + } + + std::string createRequestWithoutModel() const { + if (endpoint() == ovms::Endpoint::COMPLETIONS) { + return "{\"prompt\":\"valid prompt\"}"; + } + if (endpoint() == ovms::Endpoint::RESPONSES) { + return "{\"input\":\"valid prompt\"}"; + } + return "{\"messages\":[{\"role\":\"user\",\"content\":\"valid prompt\"}]}"; + } + + std::string createRequestWithNonStringModel() const { + if (endpoint() == ovms::Endpoint::COMPLETIONS) { + return "{\"model\":2,\"prompt\":\"valid prompt\"}"; + } + if (endpoint() == ovms::Endpoint::RESPONSES) { + return "{\"model\":2,\"input\":\"valid prompt\"}"; + } + return "{\"model\":2,\"messages\":[{\"role\":\"user\",\"content\":\"valid prompt\"}]}"; + } +}; + +TEST_P(HttpOpenAIHandlerCommonParsingValidationTest, StreamFieldNotABooleanFails) { + std::string json = createRequestWithRawStreamValue("2"); + doc.Parse(json.c_str()); + ASSERT_FALSE(doc.HasParseError()); + + std::optional maxTokensLimit; + uint32_t bestOfLimit = 0; + std::optional maxModelLength; + std::shared_ptr apiHandler = + std::make_shared(doc, endpoint(), std::chrono::system_clock::now(), *tokenizer); + + EXPECT_EQ(apiHandler->parseRequest(maxTokensLimit, bestOfLimit, maxModelLength), absl::InvalidArgumentError("Stream is not bool")); +} + +TEST_P(HttpOpenAIHandlerCommonParsingValidationTest, ModelFieldMissingFails) { + std::string json = createRequestWithoutModel(); + doc.Parse(json.c_str()); + ASSERT_FALSE(doc.HasParseError()); + + std::optional maxTokensLimit; + uint32_t bestOfLimit = 0; + std::optional maxModelLength; + std::shared_ptr apiHandler = + std::make_shared(doc, endpoint(), std::chrono::system_clock::now(), *tokenizer); + + EXPECT_EQ(apiHandler->parseRequest(maxTokensLimit, bestOfLimit, maxModelLength), absl::InvalidArgumentError("model missing in request")); +} + +TEST_P(HttpOpenAIHandlerCommonParsingValidationTest, ModelFieldNotStringFails) { + std::string json = createRequestWithNonStringModel(); + doc.Parse(json.c_str()); + ASSERT_FALSE(doc.HasParseError()); + + std::optional maxTokensLimit; + uint32_t bestOfLimit = 0; + std::optional maxModelLength; + std::shared_ptr apiHandler = + std::make_shared(doc, endpoint(), std::chrono::system_clock::now(), *tokenizer); + + EXPECT_EQ(apiHandler->parseRequest(maxTokensLimit, bestOfLimit, maxModelLength), absl::InvalidArgumentError("model is not a string")); +} + +INSTANTIATE_TEST_SUITE_P( + CommonParsingValidation, + HttpOpenAIHandlerCommonParsingValidationTest, + ::testing::Values(ovms::Endpoint::CHAT_COMPLETIONS, ovms::Endpoint::COMPLETIONS, ovms::Endpoint::RESPONSES), + [](const testing::TestParamInfo& info) { + switch (info.param) { + case ovms::Endpoint::CHAT_COMPLETIONS: + return "ChatCompletions"; + case ovms::Endpoint::COMPLETIONS: + return "Completions"; + case ovms::Endpoint::RESPONSES: + return "Responses"; + default: + return "Unknown"; + } + }); + + class HttpOpenAIHandlerChatAndResponsesParsingTest : public HttpOpenAIHandlerParsingTest, + public ::testing::WithParamInterface { + protected: + ovms::Endpoint endpoint() const { + return GetParam(); + } + + std::string createTextRequest(const std::string& text, const std::string& extraJsonFields = "") const { + if (endpoint() == ovms::Endpoint::RESPONSES) { + return std::string("{\"model\":\"llama\",\"input\":\"") + text + "\"" + extraJsonFields + "}"; + } + return std::string("{\"model\":\"llama\",\"messages\":[{\"role\":\"user\",\"content\":\"") + text + "\"}]" + extraJsonFields + "}"; + } + + std::string createMultimodalRequestWithImageUrl(const std::string& dataUrl) const { + if (endpoint() == ovms::Endpoint::RESPONSES) { + return std::string("{\"model\":\"llama\",\"input\":[{\"role\":\"user\",\"content\":[{\"type\":\"input_text\",\"text\":\"what is in this image?\"},{\"type\":\"input_image\",\"image_url\":\"") + dataUrl + "\"}]}] }"; + } + return std::string("{\"model\":\"llama\",\"messages\":[{\"role\":\"user\",\"content\":[{\"type\":\"text\",\"text\":\"what is in this image?\"},{\"type\":\"image_url\",\"image_url\":{\"url\":\"") + dataUrl + "\"}}]}]}"; + } + + std::string createToolRequest(const std::string& toolChoiceJson) const { + std::string base = createTextRequest("What is the weather like in Boston today?", ",\"tools\":[{\"type\":\"function\",\"function\":{\"name\":\"get_current_weather\",\"parameters\":{\"type\":\"object\",\"properties\":{\"location\":{\"type\":\"string\"}},\"required\":[\"location\"]}}}]"); + if (toolChoiceJson.empty()) { + return base; + } + base.pop_back(); // remove trailing '}' + base += ",\"tool_choice\":" + toolChoiceJson + "}"; + return base; + } + + std::shared_ptr parseCurrentRequest(const std::string& json) { + doc.Parse(json.c_str()); + EXPECT_FALSE(doc.HasParseError()) << json; + std::optional maxTokensLimit; + uint32_t bestOfLimit = 0; + std::optional maxModelLength; + std::shared_ptr apiHandler = + std::make_shared(doc, endpoint(), std::chrono::system_clock::now(), *tokenizer); + EXPECT_EQ(apiHandler->parseRequest(maxTokensLimit, bestOfLimit, maxModelLength), absl::OkStatus()) << json; + return apiHandler; + } + }; + + TEST_P(HttpOpenAIHandlerChatAndResponsesParsingTest, ParsingTextInputCreatesUserChatMessage) { + std::string json = createTextRequest("What is OpenVINO?"); + auto apiHandler = parseCurrentRequest(json); + + auto& chatHistory = apiHandler->getChatHistory(); + ASSERT_EQ(chatHistory.size(), 1); + ASSERT_TRUE(chatHistory[0].contains("role")); + ASSERT_TRUE(chatHistory[0].contains("content")); + EXPECT_EQ(chatHistory[0]["role"], "user"); + EXPECT_EQ(chatHistory[0]["content"], "What is OpenVINO?"); + if (endpoint() == ovms::Endpoint::RESPONSES) { + EXPECT_NE(apiHandler->getProcessedJson().find("\"messages\""), std::string::npos); + } else { + EXPECT_TRUE(apiHandler->getProcessedJson().empty()); + } + } + + TEST_P(HttpOpenAIHandlerChatAndResponsesParsingTest, ParsingTokenLimitSetsMaxTokens) { + std::string tokenField = endpoint() == ovms::Endpoint::RESPONSES ? "max_output_tokens" : "max_completion_tokens"; + std::string json = createTextRequest("valid prompt", ",\"" + tokenField + "\":7"); + auto apiHandler = parseCurrentRequest(json); + + EXPECT_TRUE(apiHandler->getMaxTokens().has_value()); + EXPECT_EQ(apiHandler->getMaxTokens().value(), 7); + } + + TEST_P(HttpOpenAIHandlerChatAndResponsesParsingTest, ParsingFunctionToolsWithAutoChoiceSucceeds) { + std::string json = createToolRequest("\"auto\""); + auto apiHandler = parseCurrentRequest(json); + + EXPECT_TRUE(apiHandler->areToolsAvailable()); + EXPECT_EQ(apiHandler->getToolChoice(), "auto"); + } + + TEST_P(HttpOpenAIHandlerChatAndResponsesParsingTest, ParsingToolChoiceFunctionObjectSucceeds) { + std::string json = createToolRequest("{\"type\":\"function\",\"function\":{\"name\":\"get_current_weather\"}}"); + auto apiHandler = parseCurrentRequest(json); + + EXPECT_TRUE(apiHandler->areToolsAvailable()); + EXPECT_EQ(apiHandler->getToolChoice(), "get_current_weather"); + } + + TEST_P(HttpOpenAIHandlerChatAndResponsesParsingTest, ParsingToolChoiceNoneRemovesTools) { + std::string json = createToolRequest("\"none\""); + auto apiHandler = parseCurrentRequest(json); + + EXPECT_FALSE(apiHandler->areToolsAvailable()); + EXPECT_EQ(apiHandler->getToolChoice(), "none"); + } + + TEST_P(HttpOpenAIHandlerChatAndResponsesParsingTest, ParsingMultimodalInputImageSucceeds) { + const std::string base64Image = "data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAIAAACQd1PeAAAAEElEQVR4nGLK27oAEAAA//8DYAHGgEvy5AAAAABJRU5ErkJggg=="; + std::string json = createMultimodalRequestWithImageUrl(base64Image); + auto apiHandler = parseCurrentRequest(json); + + EXPECT_EQ(apiHandler->getImageHistory().size(), 1); + } + + INSTANTIATE_TEST_SUITE_P( + ChatAndResponses, + HttpOpenAIHandlerChatAndResponsesParsingTest, + ::testing::Values(ovms::Endpoint::CHAT_COMPLETIONS, ovms::Endpoint::RESPONSES), + [](const testing::TestParamInfo& info) { + switch (info.param) { + case ovms::Endpoint::CHAT_COMPLETIONS: + return "ChatCompletions"; + case ovms::Endpoint::RESPONSES: + return "Responses"; + default: + return "Unknown"; + } + }); + static std::vector createHermes3ToolCallTokens(ov::genai::Tokenizer& tokenizer) { std::string toolCall = R"({"name": "example_tool", "arguments": {"arg1": "value1", "arg2": 42}})"; auto generatedTensor = tokenizer.encode(toolCall, ov::genai::add_special_tokens(true)).input_ids; From 299e2bea9312ea8b6a9d17c59a84ac31adad22a0 Mon Sep 17 00:00:00 2001 From: Michal Kulakowski Date: Thu, 5 Mar 2026 16:27:12 +0100 Subject: [PATCH 5/8] style --- src/test/http_openai_handler_test.cpp | 242 +++++++++++++------------- 1 file changed, 121 insertions(+), 121 deletions(-) diff --git a/src/test/http_openai_handler_test.cpp b/src/test/http_openai_handler_test.cpp index 5a79272f9d..232e7b4ae4 100644 --- a/src/test/http_openai_handler_test.cpp +++ b/src/test/http_openai_handler_test.cpp @@ -371,147 +371,147 @@ class HttpOpenAIHandlerParsingTest : public ::testing::Test { }; class HttpOpenAIHandlerCommonParsingValidationTest : public HttpOpenAIHandlerParsingTest, - public ::testing::WithParamInterface { + public ::testing::WithParamInterface { protected: - ovms::Endpoint endpoint() const { - return GetParam(); - } - - std::string createRequestWithRawStreamValue(const std::string& streamRawValue) const { - if (endpoint() == ovms::Endpoint::COMPLETIONS) { - return std::string("{\"model\":\"llama\",\"stream\":") + streamRawValue + ",\"prompt\":\"valid prompt\"}"; - } - if (endpoint() == ovms::Endpoint::RESPONSES) { - return std::string("{\"model\":\"llama\",\"stream\":") + streamRawValue + ",\"input\":\"valid prompt\"}"; + ovms::Endpoint endpoint() const { + return GetParam(); } - return std::string("{\"model\":\"llama\",\"stream\":") + streamRawValue + ",\"messages\":[{\"role\":\"user\",\"content\":\"valid prompt\"}]}"; - } - std::string createRequestWithoutModel() const { - if (endpoint() == ovms::Endpoint::COMPLETIONS) { - return "{\"prompt\":\"valid prompt\"}"; - } - if (endpoint() == ovms::Endpoint::RESPONSES) { - return "{\"input\":\"valid prompt\"}"; + std::string createRequestWithRawStreamValue(const std::string& streamRawValue) const { + if (endpoint() == ovms::Endpoint::COMPLETIONS) { + return std::string("{\"model\":\"llama\",\"stream\":") + streamRawValue + ",\"prompt\":\"valid prompt\"}"; + } + if (endpoint() == ovms::Endpoint::RESPONSES) { + return std::string("{\"model\":\"llama\",\"stream\":") + streamRawValue + ",\"input\":\"valid prompt\"}"; + } + return std::string("{\"model\":\"llama\",\"stream\":") + streamRawValue + ",\"messages\":[{\"role\":\"user\",\"content\":\"valid prompt\"}]}"; } - return "{\"messages\":[{\"role\":\"user\",\"content\":\"valid prompt\"}]}"; - } - std::string createRequestWithNonStringModel() const { - if (endpoint() == ovms::Endpoint::COMPLETIONS) { - return "{\"model\":2,\"prompt\":\"valid prompt\"}"; + std::string createRequestWithoutModel() const { + if (endpoint() == ovms::Endpoint::COMPLETIONS) { + return "{\"prompt\":\"valid prompt\"}"; + } + if (endpoint() == ovms::Endpoint::RESPONSES) { + return "{\"input\":\"valid prompt\"}"; + } + return "{\"messages\":[{\"role\":\"user\",\"content\":\"valid prompt\"}]}"; } - if (endpoint() == ovms::Endpoint::RESPONSES) { - return "{\"model\":2,\"input\":\"valid prompt\"}"; + + std::string createRequestWithNonStringModel() const { + if (endpoint() == ovms::Endpoint::COMPLETIONS) { + return "{\"model\":2,\"prompt\":\"valid prompt\"}"; + } + if (endpoint() == ovms::Endpoint::RESPONSES) { + return "{\"model\":2,\"input\":\"valid prompt\"}"; + } + return "{\"model\":2,\"messages\":[{\"role\":\"user\",\"content\":\"valid prompt\"}]}"; } - return "{\"model\":2,\"messages\":[{\"role\":\"user\",\"content\":\"valid prompt\"}]}"; - } }; TEST_P(HttpOpenAIHandlerCommonParsingValidationTest, StreamFieldNotABooleanFails) { - std::string json = createRequestWithRawStreamValue("2"); - doc.Parse(json.c_str()); - ASSERT_FALSE(doc.HasParseError()); + std::string json = createRequestWithRawStreamValue("2"); + doc.Parse(json.c_str()); + ASSERT_FALSE(doc.HasParseError()); - std::optional maxTokensLimit; - uint32_t bestOfLimit = 0; - std::optional maxModelLength; - std::shared_ptr apiHandler = - std::make_shared(doc, endpoint(), std::chrono::system_clock::now(), *tokenizer); + std::optional maxTokensLimit; + uint32_t bestOfLimit = 0; + std::optional maxModelLength; + std::shared_ptr apiHandler = + std::make_shared(doc, endpoint(), std::chrono::system_clock::now(), *tokenizer); - EXPECT_EQ(apiHandler->parseRequest(maxTokensLimit, bestOfLimit, maxModelLength), absl::InvalidArgumentError("Stream is not bool")); + EXPECT_EQ(apiHandler->parseRequest(maxTokensLimit, bestOfLimit, maxModelLength), absl::InvalidArgumentError("Stream is not bool")); } TEST_P(HttpOpenAIHandlerCommonParsingValidationTest, ModelFieldMissingFails) { - std::string json = createRequestWithoutModel(); - doc.Parse(json.c_str()); - ASSERT_FALSE(doc.HasParseError()); + std::string json = createRequestWithoutModel(); + doc.Parse(json.c_str()); + ASSERT_FALSE(doc.HasParseError()); - std::optional maxTokensLimit; - uint32_t bestOfLimit = 0; - std::optional maxModelLength; - std::shared_ptr apiHandler = - std::make_shared(doc, endpoint(), std::chrono::system_clock::now(), *tokenizer); + std::optional maxTokensLimit; + uint32_t bestOfLimit = 0; + std::optional maxModelLength; + std::shared_ptr apiHandler = + std::make_shared(doc, endpoint(), std::chrono::system_clock::now(), *tokenizer); - EXPECT_EQ(apiHandler->parseRequest(maxTokensLimit, bestOfLimit, maxModelLength), absl::InvalidArgumentError("model missing in request")); + EXPECT_EQ(apiHandler->parseRequest(maxTokensLimit, bestOfLimit, maxModelLength), absl::InvalidArgumentError("model missing in request")); } TEST_P(HttpOpenAIHandlerCommonParsingValidationTest, ModelFieldNotStringFails) { - std::string json = createRequestWithNonStringModel(); - doc.Parse(json.c_str()); - ASSERT_FALSE(doc.HasParseError()); + std::string json = createRequestWithNonStringModel(); + doc.Parse(json.c_str()); + ASSERT_FALSE(doc.HasParseError()); - std::optional maxTokensLimit; - uint32_t bestOfLimit = 0; - std::optional maxModelLength; - std::shared_ptr apiHandler = - std::make_shared(doc, endpoint(), std::chrono::system_clock::now(), *tokenizer); + std::optional maxTokensLimit; + uint32_t bestOfLimit = 0; + std::optional maxModelLength; + std::shared_ptr apiHandler = + std::make_shared(doc, endpoint(), std::chrono::system_clock::now(), *tokenizer); - EXPECT_EQ(apiHandler->parseRequest(maxTokensLimit, bestOfLimit, maxModelLength), absl::InvalidArgumentError("model is not a string")); + EXPECT_EQ(apiHandler->parseRequest(maxTokensLimit, bestOfLimit, maxModelLength), absl::InvalidArgumentError("model is not a string")); } INSTANTIATE_TEST_SUITE_P( - CommonParsingValidation, - HttpOpenAIHandlerCommonParsingValidationTest, - ::testing::Values(ovms::Endpoint::CHAT_COMPLETIONS, ovms::Endpoint::COMPLETIONS, ovms::Endpoint::RESPONSES), - [](const testing::TestParamInfo& info) { - switch (info.param) { - case ovms::Endpoint::CHAT_COMPLETIONS: - return "ChatCompletions"; - case ovms::Endpoint::COMPLETIONS: - return "Completions"; - case ovms::Endpoint::RESPONSES: - return "Responses"; - default: - return "Unknown"; - } - }); + CommonParsingValidation, + HttpOpenAIHandlerCommonParsingValidationTest, + ::testing::Values(ovms::Endpoint::CHAT_COMPLETIONS, ovms::Endpoint::COMPLETIONS, ovms::Endpoint::RESPONSES), + [](const testing::TestParamInfo& info) { + switch (info.param) { + case ovms::Endpoint::CHAT_COMPLETIONS: + return "ChatCompletions"; + case ovms::Endpoint::COMPLETIONS: + return "Completions"; + case ovms::Endpoint::RESPONSES: + return "Responses"; + default: + return "Unknown"; + } + }); - class HttpOpenAIHandlerChatAndResponsesParsingTest : public HttpOpenAIHandlerParsingTest, - public ::testing::WithParamInterface { - protected: +class HttpOpenAIHandlerChatAndResponsesParsingTest : public HttpOpenAIHandlerParsingTest, + public ::testing::WithParamInterface { +protected: ovms::Endpoint endpoint() const { - return GetParam(); + return GetParam(); } std::string createTextRequest(const std::string& text, const std::string& extraJsonFields = "") const { - if (endpoint() == ovms::Endpoint::RESPONSES) { - return std::string("{\"model\":\"llama\",\"input\":\"") + text + "\"" + extraJsonFields + "}"; - } - return std::string("{\"model\":\"llama\",\"messages\":[{\"role\":\"user\",\"content\":\"") + text + "\"}]" + extraJsonFields + "}"; + if (endpoint() == ovms::Endpoint::RESPONSES) { + return std::string("{\"model\":\"llama\",\"input\":\"") + text + "\"" + extraJsonFields + "}"; + } + return std::string("{\"model\":\"llama\",\"messages\":[{\"role\":\"user\",\"content\":\"") + text + "\"}]" + extraJsonFields + "}"; } std::string createMultimodalRequestWithImageUrl(const std::string& dataUrl) const { - if (endpoint() == ovms::Endpoint::RESPONSES) { - return std::string("{\"model\":\"llama\",\"input\":[{\"role\":\"user\",\"content\":[{\"type\":\"input_text\",\"text\":\"what is in this image?\"},{\"type\":\"input_image\",\"image_url\":\"") + dataUrl + "\"}]}] }"; - } - return std::string("{\"model\":\"llama\",\"messages\":[{\"role\":\"user\",\"content\":[{\"type\":\"text\",\"text\":\"what is in this image?\"},{\"type\":\"image_url\",\"image_url\":{\"url\":\"") + dataUrl + "\"}}]}]}"; + if (endpoint() == ovms::Endpoint::RESPONSES) { + return std::string("{\"model\":\"llama\",\"input\":[{\"role\":\"user\",\"content\":[{\"type\":\"input_text\",\"text\":\"what is in this image?\"},{\"type\":\"input_image\",\"image_url\":\"") + dataUrl + "\"}]}] }"; + } + return std::string("{\"model\":\"llama\",\"messages\":[{\"role\":\"user\",\"content\":[{\"type\":\"text\",\"text\":\"what is in this image?\"},{\"type\":\"image_url\",\"image_url\":{\"url\":\"") + dataUrl + "\"}}]}]}"; } std::string createToolRequest(const std::string& toolChoiceJson) const { - std::string base = createTextRequest("What is the weather like in Boston today?", ",\"tools\":[{\"type\":\"function\",\"function\":{\"name\":\"get_current_weather\",\"parameters\":{\"type\":\"object\",\"properties\":{\"location\":{\"type\":\"string\"}},\"required\":[\"location\"]}}}]"); - if (toolChoiceJson.empty()) { + std::string base = createTextRequest("What is the weather like in Boston today?", ",\"tools\":[{\"type\":\"function\",\"function\":{\"name\":\"get_current_weather\",\"parameters\":{\"type\":\"object\",\"properties\":{\"location\":{\"type\":\"string\"}},\"required\":[\"location\"]}}}]"); + if (toolChoiceJson.empty()) { + return base; + } + base.pop_back(); // remove trailing '}' + base += ",\"tool_choice\":" + toolChoiceJson + "}"; return base; - } - base.pop_back(); // remove trailing '}' - base += ",\"tool_choice\":" + toolChoiceJson + "}"; - return base; } std::shared_ptr parseCurrentRequest(const std::string& json) { - doc.Parse(json.c_str()); - EXPECT_FALSE(doc.HasParseError()) << json; - std::optional maxTokensLimit; - uint32_t bestOfLimit = 0; - std::optional maxModelLength; - std::shared_ptr apiHandler = - std::make_shared(doc, endpoint(), std::chrono::system_clock::now(), *tokenizer); - EXPECT_EQ(apiHandler->parseRequest(maxTokensLimit, bestOfLimit, maxModelLength), absl::OkStatus()) << json; - return apiHandler; + doc.Parse(json.c_str()); + EXPECT_FALSE(doc.HasParseError()) << json; + std::optional maxTokensLimit; + uint32_t bestOfLimit = 0; + std::optional maxModelLength; + std::shared_ptr apiHandler = + std::make_shared(doc, endpoint(), std::chrono::system_clock::now(), *tokenizer); + EXPECT_EQ(apiHandler->parseRequest(maxTokensLimit, bestOfLimit, maxModelLength), absl::OkStatus()) << json; + return apiHandler; } - }; +}; - TEST_P(HttpOpenAIHandlerChatAndResponsesParsingTest, ParsingTextInputCreatesUserChatMessage) { +TEST_P(HttpOpenAIHandlerChatAndResponsesParsingTest, ParsingTextInputCreatesUserChatMessage) { std::string json = createTextRequest("What is OpenVINO?"); auto apiHandler = parseCurrentRequest(json); @@ -522,66 +522,66 @@ INSTANTIATE_TEST_SUITE_P( EXPECT_EQ(chatHistory[0]["role"], "user"); EXPECT_EQ(chatHistory[0]["content"], "What is OpenVINO?"); if (endpoint() == ovms::Endpoint::RESPONSES) { - EXPECT_NE(apiHandler->getProcessedJson().find("\"messages\""), std::string::npos); + EXPECT_NE(apiHandler->getProcessedJson().find("\"messages\""), std::string::npos); } else { - EXPECT_TRUE(apiHandler->getProcessedJson().empty()); + EXPECT_TRUE(apiHandler->getProcessedJson().empty()); } - } +} - TEST_P(HttpOpenAIHandlerChatAndResponsesParsingTest, ParsingTokenLimitSetsMaxTokens) { +TEST_P(HttpOpenAIHandlerChatAndResponsesParsingTest, ParsingTokenLimitSetsMaxTokens) { std::string tokenField = endpoint() == ovms::Endpoint::RESPONSES ? "max_output_tokens" : "max_completion_tokens"; std::string json = createTextRequest("valid prompt", ",\"" + tokenField + "\":7"); auto apiHandler = parseCurrentRequest(json); EXPECT_TRUE(apiHandler->getMaxTokens().has_value()); EXPECT_EQ(apiHandler->getMaxTokens().value(), 7); - } +} - TEST_P(HttpOpenAIHandlerChatAndResponsesParsingTest, ParsingFunctionToolsWithAutoChoiceSucceeds) { +TEST_P(HttpOpenAIHandlerChatAndResponsesParsingTest, ParsingFunctionToolsWithAutoChoiceSucceeds) { std::string json = createToolRequest("\"auto\""); auto apiHandler = parseCurrentRequest(json); EXPECT_TRUE(apiHandler->areToolsAvailable()); EXPECT_EQ(apiHandler->getToolChoice(), "auto"); - } +} - TEST_P(HttpOpenAIHandlerChatAndResponsesParsingTest, ParsingToolChoiceFunctionObjectSucceeds) { +TEST_P(HttpOpenAIHandlerChatAndResponsesParsingTest, ParsingToolChoiceFunctionObjectSucceeds) { std::string json = createToolRequest("{\"type\":\"function\",\"function\":{\"name\":\"get_current_weather\"}}"); auto apiHandler = parseCurrentRequest(json); EXPECT_TRUE(apiHandler->areToolsAvailable()); EXPECT_EQ(apiHandler->getToolChoice(), "get_current_weather"); - } +} - TEST_P(HttpOpenAIHandlerChatAndResponsesParsingTest, ParsingToolChoiceNoneRemovesTools) { +TEST_P(HttpOpenAIHandlerChatAndResponsesParsingTest, ParsingToolChoiceNoneRemovesTools) { std::string json = createToolRequest("\"none\""); auto apiHandler = parseCurrentRequest(json); EXPECT_FALSE(apiHandler->areToolsAvailable()); EXPECT_EQ(apiHandler->getToolChoice(), "none"); - } +} - TEST_P(HttpOpenAIHandlerChatAndResponsesParsingTest, ParsingMultimodalInputImageSucceeds) { +TEST_P(HttpOpenAIHandlerChatAndResponsesParsingTest, ParsingMultimodalInputImageSucceeds) { const std::string base64Image = "data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAIAAACQd1PeAAAAEElEQVR4nGLK27oAEAAA//8DYAHGgEvy5AAAAABJRU5ErkJggg=="; std::string json = createMultimodalRequestWithImageUrl(base64Image); auto apiHandler = parseCurrentRequest(json); EXPECT_EQ(apiHandler->getImageHistory().size(), 1); - } +} - INSTANTIATE_TEST_SUITE_P( +INSTANTIATE_TEST_SUITE_P( ChatAndResponses, HttpOpenAIHandlerChatAndResponsesParsingTest, ::testing::Values(ovms::Endpoint::CHAT_COMPLETIONS, ovms::Endpoint::RESPONSES), [](const testing::TestParamInfo& info) { - switch (info.param) { - case ovms::Endpoint::CHAT_COMPLETIONS: - return "ChatCompletions"; - case ovms::Endpoint::RESPONSES: - return "Responses"; - default: - return "Unknown"; - } + switch (info.param) { + case ovms::Endpoint::CHAT_COMPLETIONS: + return "ChatCompletions"; + case ovms::Endpoint::RESPONSES: + return "Responses"; + default: + return "Unknown"; + } }); static std::vector createHermes3ToolCallTokens(ov::genai::Tokenizer& tokenizer) { From f03f34dc3d6c252d9acd67fe49fa41e0685cf635 Mon Sep 17 00:00:00 2001 From: Michal Kulakowski Date: Fri, 6 Mar 2026 10:42:11 +0100 Subject: [PATCH 6/8] remove redundant tests --- src/test/http_openai_handler_test.cpp | 65 --------------------------- 1 file changed, 65 deletions(-) diff --git a/src/test/http_openai_handler_test.cpp b/src/test/http_openai_handler_test.cpp index 232e7b4ae4..88b788ac8e 100644 --- a/src/test/http_openai_handler_test.cpp +++ b/src/test/http_openai_handler_test.cpp @@ -1579,47 +1579,6 @@ TEST_F(HttpOpenAIHandlerParsingTest, ParsingRequestWithNullParametersCompletions } } -TEST_F(HttpOpenAIHandlerParsingTest, ParsingResponsesMaxOutputTokensSetsMaxTokens) { - std::string json = R"({ - "model": "llama", - "input": "valid prompt", - "max_output_tokens": 7 - })"; - doc.Parse(json.c_str()); - ASSERT_FALSE(doc.HasParseError()); - std::optional maxTokensLimit; - uint32_t bestOfLimit = 0; - std::optional maxModelLength; - std::shared_ptr apiHandler = std::make_shared(doc, ovms::Endpoint::RESPONSES, std::chrono::system_clock::now(), *tokenizer); - EXPECT_EQ(apiHandler->parseRequest(maxTokensLimit, bestOfLimit, maxModelLength), absl::OkStatus()); - EXPECT_TRUE(apiHandler->getMaxTokens().has_value()); - EXPECT_EQ(apiHandler->getMaxTokens().value(), 7); -} - -TEST_F(HttpOpenAIHandlerParsingTest, ParsingResponsesStringInputCreatesUserChatMessage) { - std::string json = R"({ - "model": "llama", - "input": "What is OpenVINO?" - })"; - doc.Parse(json.c_str()); - ASSERT_FALSE(doc.HasParseError()); - std::optional maxTokensLimit; - uint32_t bestOfLimit = 0; - std::optional maxModelLength; - std::shared_ptr apiHandler = std::make_shared(doc, ovms::Endpoint::RESPONSES, std::chrono::system_clock::now(), *tokenizer); - EXPECT_EQ(apiHandler->parseRequest(maxTokensLimit, bestOfLimit, maxModelLength), absl::OkStatus()); - - auto& chatHistory = apiHandler->getChatHistory(); - ASSERT_EQ(chatHistory.size(), 1); - ASSERT_TRUE(chatHistory[0].contains("role")); - ASSERT_TRUE(chatHistory[0].contains("content")); - EXPECT_EQ(chatHistory[0]["role"], "user"); - EXPECT_EQ(chatHistory[0]["content"], "What is OpenVINO?"); - EXPECT_NE(apiHandler->getProcessedJson().find("\"messages\""), std::string::npos); - EXPECT_NE(apiHandler->getProcessedJson().find("\"role\":\"user\""), std::string::npos); - EXPECT_NE(apiHandler->getProcessedJson().find("\"input\":\"What is OpenVINO?\""), std::string::npos); -} - TEST_F(HttpOpenAIHandlerParsingTest, ParsingResponsesConflictingOutputAndCompletionTokensFails) { std::string json = R"({ "model": "llama", @@ -1862,30 +1821,6 @@ TEST_F(HttpOpenAIHandlerParsingTest, ParsingResponsesToolChoiceFunctionObjectNam EXPECT_EQ(apiHandler->parseRequest(maxTokensLimit, bestOfLimit, maxModelLength), absl::InvalidArgumentError("tool_choice.name is not a valid string")); } -TEST_F(HttpOpenAIHandlerParsingTest, ParsingResponsesInputImageUrlStringSucceeds) { - std::string json = R"({ - "model": "llama", - "input": [ - { - "role": "user", - "content": [ - {"type": "input_text", "text": "what is in this image?"}, - {"type": "input_image", "image_url": "data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAIAAACQd1PeAAAAEElEQVR4nGLK27oAEAAA//8DYAHGgEvy5AAAAABJRU5ErkJggg=="} - ] - } - ] - })"; - doc.Parse(json.c_str()); - ASSERT_FALSE(doc.HasParseError()); - std::optional maxTokensLimit; - uint32_t bestOfLimit = 0; - std::optional maxModelLength; - std::shared_ptr apiHandler = - std::make_shared(doc, ovms::Endpoint::RESPONSES, std::chrono::system_clock::now(), *tokenizer); - EXPECT_EQ(apiHandler->parseRequest(maxTokensLimit, bestOfLimit, maxModelLength), absl::OkStatus()); - EXPECT_EQ(apiHandler->getImageHistory().size(), 1); -} - TEST_F(HttpOpenAIHandlerParsingTest, ParsingResponsesInputImageUrlObjectSucceeds) { std::string json = R"({ "model": "llama", From e18b165e0fbb554a627ed87b8f37cf13a15021f2 Mon Sep 17 00:00:00 2001 From: Michal Kulakowski Date: Fri, 6 Mar 2026 10:45:18 +0100 Subject: [PATCH 7/8] update tools parsing in responses --- src/llm/apis/openai_completions.cpp | 190 +++++++++++----------------- 1 file changed, 75 insertions(+), 115 deletions(-) diff --git a/src/llm/apis/openai_completions.cpp b/src/llm/apis/openai_completions.cpp index 66b40b7287..358f846511 100644 --- a/src/llm/apis/openai_completions.cpp +++ b/src/llm/apis/openai_completions.cpp @@ -220,85 +220,6 @@ std::string serializeResponsesUnaryResponse( return buffer.GetString(); } -absl::Status normalizeResponsesFunctionToolsInPlace(rapidjson::Document& doc) { - auto toolsIt = doc.FindMember("tools"); - if (toolsIt == doc.MemberEnd() || toolsIt->value.IsNull()) { - return absl::OkStatus(); - } - if (!toolsIt->value.IsArray()) { - return absl::InvalidArgumentError("Tools are not an array"); - } - - auto& allocator = doc.GetAllocator(); - for (auto& toolValue : toolsIt->value.GetArray()) { - if (!toolValue.IsObject()) { - return absl::InvalidArgumentError("Tool is not a JSON object"); - } - auto toolObj = toolValue.GetObject(); - auto typeIt = toolObj.FindMember("type"); - if (typeIt == toolObj.MemberEnd() || !typeIt->value.IsString()) { - return absl::InvalidArgumentError("Tool type is missing or invalid"); - } - if (std::string(typeIt->value.GetString()) != "function") { - return absl::InvalidArgumentError("Only function tools are supported"); - } - - auto functionIt = toolObj.FindMember("function"); - if (functionIt != toolObj.MemberEnd()) { - if (!functionIt->value.IsObject()) { - return absl::InvalidArgumentError("Function is not a valid JSON object"); - } - continue; - } - - auto nameIt = toolObj.FindMember("name"); - if (nameIt == toolObj.MemberEnd() || !nameIt->value.IsString()) { - return absl::InvalidArgumentError("Function object does not contain a valid name field"); - } - - rapidjson::Value functionObj(rapidjson::kObjectType); - functionObj.AddMember("name", rapidjson::Value(nameIt->value.GetString(), allocator), allocator); - - auto descriptionIt = toolObj.FindMember("description"); - if (descriptionIt != toolObj.MemberEnd() && descriptionIt->value.IsString()) { - functionObj.AddMember("description", rapidjson::Value(descriptionIt->value.GetString(), allocator), allocator); - } - - auto parametersIt = toolObj.FindMember("parameters"); - if (parametersIt != toolObj.MemberEnd()) { - if (!parametersIt->value.IsObject()) { - return absl::InvalidArgumentError("Function parameters are not a valid JSON object"); - } - rapidjson::Value parametersCopy(rapidjson::kObjectType); - parametersCopy.CopyFrom(parametersIt->value, allocator); - functionObj.AddMember("parameters", parametersCopy, allocator); - } - - toolValue.AddMember("function", functionObj, allocator); - } - - auto toolChoiceIt = doc.FindMember("tool_choice"); - if (toolChoiceIt != doc.MemberEnd() && !toolChoiceIt->value.IsNull() && toolChoiceIt->value.IsObject()) { - auto toolChoiceObj = toolChoiceIt->value.GetObject(); - auto functionIt = toolChoiceObj.FindMember("function"); - if (functionIt == toolChoiceObj.MemberEnd()) { - auto typeIt = toolChoiceObj.FindMember("type"); - auto nameIt = toolChoiceObj.FindMember("name"); - if (typeIt != toolChoiceObj.MemberEnd() && typeIt->value.IsString() && std::string(typeIt->value.GetString()) == "function") { - if (nameIt == toolChoiceObj.MemberEnd() || !nameIt->value.IsString()) { - return absl::InvalidArgumentError("tool_choice.name is not a valid string"); - } - - rapidjson::Value functionObj(rapidjson::kObjectType); - functionObj.AddMember("name", rapidjson::Value(nameIt->value.GetString(), allocator), allocator); - toolChoiceIt->value.AddMember("function", functionObj, allocator); - } - } - } - - return absl::OkStatus(); -} - } // namespace absl::Status OpenAIChatCompletionsHandler::parseCompletionsPart() { @@ -806,8 +727,9 @@ absl::Status OpenAIChatCompletionsHandler::parseTools() { if (tool_choice != "none" && tool_choice != "auto" && tool_choice != "required") return absl::InvalidArgumentError("tool_choice should be either 'none' or 'auto' or 'required'"); } else if (tool_choice_it->value.IsObject()) { - auto tool_choice_functionIt = tool_choice_it->value.GetObject().FindMember("function"); - if (tool_choice_functionIt != tool_choice_it->value.GetObject().MemberEnd() && tool_choice_functionIt->value.IsObject()) { + auto toolChoiceObj = tool_choice_it->value.GetObject(); + auto tool_choice_functionIt = toolChoiceObj.FindMember("function"); + if (tool_choice_functionIt != toolChoiceObj.MemberEnd() && tool_choice_functionIt->value.IsObject()) { auto nameIt = tool_choice_functionIt->value.GetObject().FindMember("name"); if (nameIt != tool_choice_functionIt->value.GetObject().MemberEnd() && nameIt->value.IsString()) { tool_choice = nameIt->value.GetString(); @@ -815,7 +737,16 @@ absl::Status OpenAIChatCompletionsHandler::parseTools() { return absl::InvalidArgumentError("tool_choice.function.name is not a valid string"); } } else { - return absl::InvalidArgumentError("tool_choice.function is not a valid JSON object"); + auto typeIt = toolChoiceObj.FindMember("type"); + auto nameIt = toolChoiceObj.FindMember("name"); + if (typeIt != toolChoiceObj.MemberEnd() && typeIt->value.IsString() && std::string(typeIt->value.GetString()) == "function") { + if (nameIt == toolChoiceObj.MemberEnd() || !nameIt->value.IsString()) { + return absl::InvalidArgumentError("tool_choice.name is not a valid string"); + } + tool_choice = nameIt->value.GetString(); + } else { + return absl::InvalidArgumentError("tool_choice.function is not a valid JSON object"); + } } } else { return absl::InvalidArgumentError("tool_choice is not a valid JSON object or string"); @@ -835,38 +766,71 @@ absl::Status OpenAIChatCompletionsHandler::parseTools() { auto& obj = it->value.GetArray()[i]; if (!obj.IsObject()) return absl::InvalidArgumentError("Tool is not a JSON object"); + const rapidjson::Value* functionObj = nullptr; + const rapidjson::Value* parametersValue = nullptr; + const char* functionNameCStr = nullptr; + auto functionIt = obj.FindMember("function"); - if (functionIt != obj.MemberEnd() && functionIt->value.IsObject()) { - auto nameIt = functionIt->value.GetObject().FindMember("name"); - if (nameIt != functionIt->value.GetObject().MemberEnd() && nameIt->value.IsString()) { - std::string functionName = nameIt->value.GetString(); - // If tool_choice is set to "auto", we keep all tools - // If tool_choice is set to a specific function name, we keep only that tool - if (tool_choice != "auto" && tool_choice != "required" && tool_choice != functionName) { - it->value.Erase(&obj); - jsonChanged = true; - } else { - i++; - // If we keep the tool, add tool name and schema to the request - auto parametersIt = functionIt->value.GetObject().FindMember("parameters"); - if (parametersIt != functionIt->value.GetObject().MemberEnd() && parametersIt->value.IsObject()) { - // now we want to insert to a mapping of - // tool name -> tool schema representations struct - // Dump parameters object to string since this is the schema format expected by GenAI - // Keep the rapidjson::Value object as well to avoid re-parsing in outputParsers - rapidjson::StringBuffer buffer; - rapidjson::Writer writer(buffer); - parametersIt->value.Accept(writer); - std::string parametersStr = buffer.GetString(); - ToolSchemaWrapper schemaReprs{¶metersIt->value, std::move(parametersStr)}; - request.toolNameSchemaMap[nameIt->value.GetString()] = std::move(schemaReprs); - } - } - } else { + if (functionIt != obj.MemberEnd()) { + if (!functionIt->value.IsObject()) { + return absl::InvalidArgumentError("Function is not a valid JSON object"); + } + functionObj = &functionIt->value; + auto nameIt = functionObj->GetObject().FindMember("name"); + if (nameIt == functionObj->GetObject().MemberEnd() || !nameIt->value.IsString()) { return absl::InvalidArgumentError("Function object does not contain a valid name field"); } + functionNameCStr = nameIt->value.GetString(); + auto parametersIt = functionObj->GetObject().FindMember("parameters"); + if (parametersIt != functionObj->GetObject().MemberEnd()) { + parametersValue = ¶metersIt->value; + } } else { - return absl::InvalidArgumentError("Function is not a valid JSON object"); + auto typeIt = obj.FindMember("type"); + if (typeIt == obj.MemberEnd() || !typeIt->value.IsString()) { + return absl::InvalidArgumentError("Tool type is missing or invalid"); + } + if (std::string(typeIt->value.GetString()) != "function") { + return absl::InvalidArgumentError("Only function tools are supported"); + } + + auto nameIt = obj.FindMember("name"); + if (nameIt == obj.MemberEnd() || !nameIt->value.IsString()) { + return absl::InvalidArgumentError("Function object does not contain a valid name field"); + } + functionNameCStr = nameIt->value.GetString(); + + auto parametersIt = obj.FindMember("parameters"); + if (parametersIt != obj.MemberEnd()) { + parametersValue = ¶metersIt->value; + } + } + + std::string functionName = functionNameCStr; + // If tool_choice is set to "auto", we keep all tools + // If tool_choice is set to a specific function name, we keep only that tool + if (tool_choice != "auto" && tool_choice != "required" && tool_choice != functionName) { + it->value.Erase(&obj); + jsonChanged = true; + continue; + } + + i++; + // If we keep the tool, add tool name and schema to the request + if (parametersValue != nullptr) { + if (!parametersValue->IsObject()) { + return absl::InvalidArgumentError("Function parameters are not a valid JSON object"); + } + // now we want to insert to a mapping of + // tool name -> tool schema representations struct + // Dump parameters object to string since this is the schema format expected by GenAI + // Keep the rapidjson::Value object as well to avoid re-parsing in outputParsers + rapidjson::StringBuffer buffer; + rapidjson::Writer writer(buffer); + parametersValue->Accept(writer); + std::string parametersStr = buffer.GetString(); + ToolSchemaWrapper schemaReprs{parametersValue, std::move(parametersStr)}; + request.toolNameSchemaMap[functionNameCStr] = std::move(schemaReprs); } } } else { @@ -1070,11 +1034,7 @@ absl::Status OpenAIChatCompletionsHandler::parseResponsesPart(std::optional Date: Fri, 6 Mar 2026 13:56:26 +0100 Subject: [PATCH 8/8] fix --- src/llm/apis/openai_completions.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/llm/apis/openai_completions.cpp b/src/llm/apis/openai_completions.cpp index 358f846511..1ebdbfcad7 100644 --- a/src/llm/apis/openai_completions.cpp +++ b/src/llm/apis/openai_completions.cpp @@ -766,8 +766,8 @@ absl::Status OpenAIChatCompletionsHandler::parseTools() { auto& obj = it->value.GetArray()[i]; if (!obj.IsObject()) return absl::InvalidArgumentError("Tool is not a JSON object"); - const rapidjson::Value* functionObj = nullptr; - const rapidjson::Value* parametersValue = nullptr; + rapidjson::Value* functionObj = nullptr; + rapidjson::Value* parametersValue = nullptr; const char* functionNameCStr = nullptr; auto functionIt = obj.FindMember("function");