devstral tool parser for tool calling #3851
@@ -0,0 +1,57 @@ (new file: Devstral generation config builder, implementation)

```cpp
//*****************************************************************************
// Copyright 2025 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************

#include <memory>
#include <string>
#include <utility>
#include <openvino/genai/generation_config.hpp>

#include "generation_config_builder.hpp"

namespace ovms {

void DevstralGenerationConfigBuilder::parseConfigFromRequest(const OpenAIChatCompletionsRequest& request) {
    // Call the base class method to fill in the common configuration.
    BaseGenerationConfigBuilder::parseConfigFromRequest(request);

    // For now the only model-specific part is related to tools, so if there are
    // no tools provided in the request we can exit early.
    if (request.toolNameSchemaMap.empty()) {
        return;
    }

    if (enableToolGuidedGeneration || request.toolChoice == "required") {
        // Set the tool guided generation config specific to the Devstral model.
        auto triggeredTags = std::make_shared<ov::genai::StructuredOutputConfig::TriggeredTags>();
        triggeredTags->triggers.push_back("[TOOL_CALLS]");

        for (const auto& [toolName, toolSchemaWrapper] : request.toolNameSchemaMap) {
            const auto& toolSchema = toolSchemaWrapper.stringRepr;
            ov::genai::StructuredOutputConfig::Tag tagItem;
            tagItem.begin = "[TOOL_CALLS]" + toolName + "[ARGS]";
            // tagItem.end = "</s>";
            tagItem.content = ov::genai::StructuredOutputConfig::JSONSchema(toolSchema);
            triggeredTags->tags.push_back(tagItem);
        }
        if (request.toolChoice == "required") {
            triggeredTags->at_least_one = true;
        }
        ov::genai::StructuredOutputConfig::StructuralTag structuralTag = triggeredTags;
        setStructuralTagsConfig(structuralTag);
    }
}

}  // namespace ovms
```
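For illustration, here is a minimal sketch of the structure this loop builds for a single hypothetical tool named `get_weather` (the tool name and JSON schema are invented for the example; the `ov::genai` types and fields are exactly the ones the code above uses). The intent, as used here, is that generation runs unconstrained until the model emits the `[TOOL_CALLS]` trigger, after which output must match one of the registered begin tags followed by JSON conforming to that tool's schema.

```cpp
#include <memory>
#include <string>
#include <openvino/genai/generation_config.hpp>

// Hedged sketch: what the triggered-tags structure looks like for one tool.
std::shared_ptr<ov::genai::StructuredOutputConfig::TriggeredTags> buildExampleTags() {
    auto triggeredTags = std::make_shared<ov::genai::StructuredOutputConfig::TriggeredTags>();
    // Generation stays unconstrained until the model emits this trigger string.
    triggeredTags->triggers.push_back("[TOOL_CALLS]");

    ov::genai::StructuredOutputConfig::Tag tagItem;
    tagItem.begin = "[TOOL_CALLS]get_weather[ARGS]";  // hypothetical tool name
    // Once the begin tag matches, the arguments are constrained to this schema.
    tagItem.content = ov::genai::StructuredOutputConfig::JSONSchema(
        R"({"type":"object","properties":{"city":{"type":"string"}},"required":["city"]})");
    triggeredTags->tags.push_back(tagItem);
    return triggeredTags;
}
```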
@@ -0,0 +1,33 @@ (new file: Devstral generation config builder, header)

```cpp
//*****************************************************************************
// Copyright 2025 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include "../base_generation_config_builder.hpp"

namespace ovms {

/*
 * DevstralGenerationConfigBuilder extends BaseGenerationConfigBuilder to provide specific configuration for the Devstral model.
 * It overrides the parseConfigFromRequest method to set the tool guided generation config.
 */
class DevstralGenerationConfigBuilder : public BaseGenerationConfigBuilder {
public:
    DevstralGenerationConfigBuilder() = delete;
    explicit DevstralGenerationConfigBuilder(const ov::genai::GenerationConfig& baseConfig, bool enableToolGuidedGeneration, DecodingMethod decodingMethod) :
        BaseGenerationConfigBuilder(baseConfig, enableToolGuidedGeneration, decodingMethod) {}

    void parseConfigFromRequest(const OpenAIChatCompletionsRequest& request) override;
};
}  // namespace ovms
```
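A hedged sketch of how this builder is expected to be wired up. The call site and the function name `configureForDevstral` are assumptions, not part of this PR; `OpenAIChatCompletionsRequest` and `DecodingMethod` come from the diff, and project-specific includes are omitted because file paths are not shown in this view.

```cpp
// Hedged usage sketch, not the actual call site in OVMS.
#include <openvino/genai/generation_config.hpp>

void configureForDevstral(const ov::genai::GenerationConfig& baseConfig,
                          const OpenAIChatCompletionsRequest& request,
                          DecodingMethod decodingMethod) {
    // Tool guided generation enabled; the builder adds the Devstral structural
    // tags on top of the common configuration parsed by the base class.
    DevstralGenerationConfigBuilder builder(baseConfig, /*enableToolGuidedGeneration=*/true, decodingMethod);
    builder.parseConfigFromRequest(request);
}
```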
@@ -0,0 +1,154 @@ (new file: Devstral tool parser, implementation)

```cpp
//*****************************************************************************
// Copyright 2025 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************

#include <openvino/genai/tokenizer.hpp>

#include <algorithm>
#include <string>
#include <vector>

#include "src/port/rapidjson_document.hpp"

#include "../../../logging.hpp"
#include "tool_parser.hpp"
#include "../utils.hpp"
#include "src/stringutils.hpp"

namespace ovms {

void DevstralToolParser::parse(ParsedOutput& parsedOutput, const std::vector<int64_t>& generatedTokens) {
    // Expected format: [TOOL_CALLS]tool_name[ARGS]{"arg1": "value1", ...}
    if (parsedOutput.content.empty() || generatedTokens.empty()) {
        SPDLOG_LOGGER_DEBUG(llm_calculator_logger, "No content to parse for tool calls");
        return;
    }
    size_t firstToolTokenIndex;
    auto it = std::find(generatedTokens.begin(), generatedTokens.end(), this->botTokenId);
    if (it != generatedTokens.end()) {
        firstToolTokenIndex = std::distance(generatedTokens.begin(), it);
    } else {
        return;
    }

    size_t firstArgsTokenIndex;
    auto itArgs = std::find(generatedTokens.begin() + firstToolTokenIndex, generatedTokens.end(), this->argsTokenId);
    if (itArgs != generatedTokens.end()) {
        firstArgsTokenIndex = std::distance(generatedTokens.begin(), itArgs);
    } else {
        return;
    }
    if (firstToolTokenIndex > firstArgsTokenIndex) {
        SPDLOG_LOGGER_DEBUG(llm_calculator_logger, "First tool token index is greater than first args token index.");
        return;
    }
    std::vector<int64_t> toolNameTokens(generatedTokens.begin() + (firstToolTokenIndex + 1), generatedTokens.begin() + firstArgsTokenIndex);
    std::vector<int64_t> argumentsTokens(generatedTokens.begin() + (firstArgsTokenIndex + 1), generatedTokens.end());

    ToolCall toolCall;
    std::string toolName = tokenizer.decode(toolNameTokens, ov::AnyMap{ov::genai::skip_special_tokens(true)});
    std::string arguments = tokenizer.decode(argumentsTokens, ov::AnyMap{ov::genai::skip_special_tokens(true)});
    toolCall.name = toolName;
    toolCall.arguments = arguments;
    toolCall.id = generateRandomId();  // Generate a random ID for the tool call.
    parsedOutput.toolCalls.push_back(toolCall);

    // Keep only the subset of generatedTokens preceding the tool call as content.
    if (firstToolTokenIndex > 0) {
        std::vector<int64_t> contentTokens(generatedTokens.begin(), generatedTokens.begin() + firstToolTokenIndex);
        parsedOutput.content = tokenizer.decode(contentTokens, ov::AnyMap{ov::genai::skip_special_tokens(true)});
    } else {
        parsedOutput.content = "";
    }
}
```
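The splitting logic above reduces to finding two sentinel token IDs and slicing the sequence around them. A self-contained sketch with illustrative token IDs (5 and 6 stand in for `[TOOL_CALLS]` and `[ARGS]`; real IDs come from the tokenizer):

```cpp
#include <algorithm>
#include <cstdint>
#include <iostream>
#include <vector>

int main() {
    // Illustrative IDs: 5 stands in for [TOOL_CALLS], 6 for [ARGS].
    const int64_t botTokenId = 5, argsTokenId = 6;
    std::vector<int64_t> generated = {10, 11, 5, 20, 21, 6, 30, 31, 32};

    auto bot = std::find(generated.begin(), generated.end(), botTokenId);
    if (bot == generated.end()) return 0;  // no tool call emitted
    auto args = std::find(bot, generated.end(), argsTokenId);
    if (args == generated.end()) return 0;  // malformed: tool name without arguments

    std::vector<int64_t> content(generated.begin(), bot);       // {10, 11}: plain content
    std::vector<int64_t> toolName(bot + 1, args);               // {20, 21}: decodes to the tool name
    std::vector<int64_t> arguments(args + 1, generated.end());  // {30, 31, 32}: decodes to JSON args

    std::cout << content.size() << " " << toolName.size() << " " << arguments.size() << "\n";  // prints: 2 2 3
}
```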
```cpp
std::optional<rapidjson::Document> DevstralToolParser::sendFullDelta(ToolCall& toolCall) {
    rapidjson::Document argumentsWrapper;
    argumentsWrapper.SetObject();
    rapidjson::Document::AllocatorType& allocator = argumentsWrapper.GetAllocator();
    // Add the raw toolCall.arguments string to argumentsWrapper under the "arguments" key.
    rapidjson::Value toolCallsString(rapidjson::kStringType);
    toolCallsString.SetString(toolCall.arguments.c_str(), allocator);
    argumentsWrapper.AddMember("arguments", toolCallsString, allocator);
    return wrapDelta(argumentsWrapper, this->toolCallIndex);
}
```
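The wrapper built above embeds the arguments as a single JSON string, matching the OpenAI streaming delta shape. A self-contained sketch of just the RapidJSON part (using the library's standard `Writer`; the surrounding `wrapDelta` call is project-specific and omitted):

```cpp
#include <iostream>
#include <string>
#include "rapidjson/document.h"
#include "rapidjson/stringbuffer.h"
#include "rapidjson/writer.h"

int main() {
    std::string arguments = R"({"city":"Paris"})";  // decoded tool arguments

    rapidjson::Document wrapper;
    wrapper.SetObject();
    auto& alloc = wrapper.GetAllocator();
    rapidjson::Value value(rapidjson::kStringType);
    value.SetString(arguments.c_str(), alloc);
    wrapper.AddMember("arguments", value, alloc);

    rapidjson::StringBuffer buffer;
    rapidjson::Writer<rapidjson::StringBuffer> writer(buffer);
    wrapper.Accept(writer);
    // Prints: {"arguments":"{\"city\":\"Paris\"}"}
    // Note the arguments are embedded as a string, not as a nested object.
    std::cout << buffer.GetString() << "\n";
}
```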
```cpp
std::optional<rapidjson::Document> DevstralToolParser::parseChunk(const std::string& chunk, ov::genai::GenerationFinishReason finishReason) {
    /*
    Devstral format: [TOOL_CALLS]tool_name[ARGS]arguments[</s>]
    Devstral does not support parallel tool calls, so tool calls always appear in sequence.

    We have three processing states:
        AWAITING_START_TAG,
        AWAITING_ARGS_TAG,
        PROCESSING_ARGS

    We store the history of chunks in the streamContent string. After a state change is
    detected, we trim streamContent to keep only the unprocessed part.
    */

    this->streamContent += chunk;
    SPDLOG_LOGGER_TRACE(llm_calculator_logger, "Chunk content: '{}'", chunk);
    if (this->internalState == AWAITING_START_TAG) {
        size_t pos = this->streamContent.find(this->streamingParsingToolCallsStartTag);
        if (pos != std::string::npos) {
            this->internalState = AWAITING_ARGS_TAG;
            this->toolCallIndex++;
            // Drop everything up to and including "[TOOL_CALLS]"; keep the unprocessed remainder.
            this->streamContent = this->streamContent.substr(pos + this->streamingParsingToolCallsStartTag.length());
        } else {
            return std::nullopt;
        }
    }
    if (this->internalState == AWAITING_ARGS_TAG) {
        // Check if the [ARGS] tag is present in the accumulated content and update the state accordingly.
        size_t pos = this->streamContent.find(this->streamingParsingArgsStartTag);
        if (pos != std::string::npos) {
            this->internalState = PROCESSING_ARGS;
            this->toolName = this->streamContent.substr(0, pos);
            // Drop everything up to and including "[ARGS]".
            this->streamContent = this->streamContent.substr(pos + this->streamingParsingArgsStartTag.length());
            return wrapFirstDelta(this->toolName, this->toolCallIndex);
        } else {
            return std::nullopt;
        }
    }
    if (finishReason != ov::genai::GenerationFinishReason::NONE) {
        size_t endPos = this->streamContent.find(this->streamingEndTag);
        std::string arguments;
        if (endPos != std::string::npos) {
            arguments = this->streamContent.substr(0, endPos);
        } else {
            arguments = this->streamContent;
        }
        if (!arguments.empty()) {
            ToolCall toolCall;
            toolCall.arguments = arguments;
            toolCall.name = this->toolName;
            return sendFullDelta(toolCall);
```
Collaborator
Shouldn't we stream partial function argument chunks? If I understand correctly, you send the full delta at the end of generation.

Collaborator
We already accepted such an approach for Qwen3 Coder, so I suppose we can have it in other parsers as well, unless there are specific requirements for "real" streaming.
```cpp
        } else {
            SPDLOG_LOGGER_DEBUG(llm_calculator_logger, "No valid arguments found in streamContent.");
            return std::nullopt;
        }
    }
    return std::nullopt;
}

}  // namespace ovms
```
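To make the chunk-level flow concrete, here is a self-contained, simplified sketch of the same three-state machine. It emits plain strings instead of RapidJSON deltas, and the chunk boundaries are invented for the example; like the parser above, it buffers the arguments and emits them only when generation finishes.

```cpp
#include <iostream>
#include <optional>
#include <string>

enum State { AWAITING_START_TAG, AWAITING_ARGS_TAG, PROCESSING_ARGS };

struct MiniParser {
    State state = AWAITING_START_TAG;
    std::string buffer, toolName;
    const std::string startTag = "[TOOL_CALLS]", argsTag = "[ARGS]", endTag = "</s>";

    std::optional<std::string> feed(const std::string& chunk, bool finished) {
        buffer += chunk;
        if (state == AWAITING_START_TAG) {
            size_t pos = buffer.find(startTag);
            if (pos == std::string::npos) return std::nullopt;
            buffer = buffer.substr(pos + startTag.length());
            state = AWAITING_ARGS_TAG;
        }
        if (state == AWAITING_ARGS_TAG) {
            size_t pos = buffer.find(argsTag);
            if (pos == std::string::npos) return std::nullopt;
            toolName = buffer.substr(0, pos);
            buffer = buffer.substr(pos + argsTag.length());
            state = PROCESSING_ARGS;
            return "first delta: name=" + toolName;
        }
        if (finished) {
            size_t end = buffer.find(endTag);
            std::string args = (end == std::string::npos) ? buffer : buffer.substr(0, end);
            if (!args.empty()) return "final delta: arguments=" + args;
        }
        return std::nullopt;  // arguments are buffered until generation finishes
    }
};

int main() {
    MiniParser p;
    const char* chunks[] = {"Sure.", "[TOOL_", "CALLS]get_weather[AR", "GS]{\"city\"", ": \"Paris\"}</s>"};
    for (int i = 0; i < 5; ++i) {
        auto out = p.feed(chunks[i], i == 4);
        if (out) std::cout << *out << "\n";
    }
    // Output:
    // first delta: name=get_weather
    // final delta: arguments={"city": "Paris"}
}
```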
@@ -0,0 +1,95 @@ (new file: Devstral tool parser, header)

```cpp
//*****************************************************************************
// Copyright 2025 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once

#include <openvino/genai/tokenizer.hpp>
#include <optional>
#include <string>
#include <vector>

#include "src/port/rapidjson_document.hpp"

#include "src/llm/io_processing/base_output_parser.hpp"
#include "src/llm/io_processing/partial_json_builder.hpp"
#include "src/llm/apis/tool_schema_wrapper.hpp"

namespace ovms {
class DevstralToolParser : public BaseOutputParser {
    const int64_t argsTokenId;  // [ARGS]
    const int64_t botTokenId;   // [TOOL_CALLS]

    // In streaming mode we rely on the tags in string form, as token IDs are not available.
    const std::string streamingParsingArgsStartTag = "[ARGS]";
    const std::string streamingParsingToolCallsStartTag = "[TOOL_CALLS]";
```
Comment on lines +31 to +36

Collaborator
Those tags/tokens are not specific to streaming, so I think we can drop the `streamingParsing` prefix.
```cpp
    const std::string streamingEndTag = "</s>";

    enum InternalState {
        AWAITING_START_TAG,
        AWAITING_ARGS_TAG,
        PROCESSING_ARGS
    };

    InternalState internalState = AWAITING_START_TAG;
    const ToolsSchemas_t& toolSchemas;
    // Index tracking the current tool call being processed (-1 means no tool call has started yet).
    int toolCallIndex = -1;
    std::string streamContent = "";  // content accumulated from stream chunks
    std::string toolName = "";
    std::optional<rapidjson::Document> sendFullDelta(ToolCall& toolCall);

public:
    DevstralToolParser() = delete;
    DevstralToolParser(ov::genai::Tokenizer& tokenizer, const ToolsSchemas_t& toolSchemas) :
        BaseOutputParser(tokenizer),
        argsTokenId([&tokenizer]() {
            // Cannot use streamingParsingArgsStartTag here because the object is not fully initialized yet.
            auto encoded = tokenizer.encode("[ARGS]", {{"add_special_tokens", false}}).input_ids;
            if (encoded.get_size() != 1) {
                throw std::runtime_error("[ARGS] must be a single token in the tokenizer vocabulary.");
            }
            return encoded.data<int64_t>()[0];
        }()),
        botTokenId([&tokenizer]() {
            // Cannot use streamingParsingToolCallsStartTag here because the object is not fully initialized yet.
            auto encoded = tokenizer.encode("[TOOL_CALLS]", {{"add_special_tokens", false}}).input_ids;
            if (encoded.get_size() != 1) {
                throw std::runtime_error("[TOOL_CALLS] must be a single token in the tokenizer vocabulary.");
            }
            return encoded.data<int64_t>()[0];
        }()),
        toolSchemas(toolSchemas) {}

    void parse(ParsedOutput& parsedOutput, const std::vector<int64_t>& generatedTokens) override;
    std::optional<rapidjson::Document> parseChunk(const std::string& chunk, ov::genai::GenerationFinishReason finishReason) override;
    const std::vector<std::string>& getParsingStartTags() const override {
        static const std::vector<std::string> toolCallStartTags{streamingParsingToolCallsStartTag};
        return toolCallStartTags;
    }
    const std::vector<std::string>& getSpecialParsingStartTags() const override {
        static const std::vector<std::string> specialParsingStartTags{};
        return specialParsingStartTags;
    }
    // Tool calls are expected to be the last part of the content, so the end-of-sequence string serves as the end tag.
    const std::string& getParsingEndTag() const override {
        static const std::string toolCallEndTag = "</s>";
        return toolCallEndTag;
    }

    bool requiresStreamingWithSpecialTokens() const override {
        return true;
    }
};
}  // namespace ovms
```
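The two lambda initializers above duplicate the same lookup. A sketch of a helper that could factor it out (`requireSingleTokenId` is a hypothetical name; the `encode` call mirrors the one used in this diff):

```cpp
#include <openvino/genai/tokenizer.hpp>
#include <stdexcept>
#include <string>

// Hypothetical helper: encode a tag and require that it maps to exactly one token ID.
static int64_t requireSingleTokenId(ov::genai::Tokenizer& tokenizer, const std::string& tag) {
    auto encoded = tokenizer.encode(tag, {{"add_special_tokens", false}}).input_ids;
    if (encoded.get_size() != 1) {
        throw std::runtime_error(tag + " must be a single token in the tokenizer vocabulary.");
    }
    return encoded.data<int64_t>()[0];
}

// The constructor's initializer list could then read:
//   argsTokenId(requireSingleTokenId(tokenizer, "[ARGS]")),
//   botTokenId(requireSingleTokenId(tokenizer, "[TOOL_CALLS]")),
```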
Collaborator
Technically we check streamContent, but that will only be the case if [ARGS] was added in the same chunk. Otherwise we would be in a different state.