diff --git a/Releases/0.10.10.md b/Releases/0.10.10.md new file mode 100644 index 00000000..66e6b23d --- /dev/null +++ b/Releases/0.10.10.md @@ -0,0 +1,5 @@ +# 0.10.10 release + +- Adds `WithMaxIterations()` to `ToolsConfigurationBuilder` and `McpContext` to override the default tool-call iteration limit. +- Fix: MCP loop now sends a final synthesis request instead of returning an error string when the iteration cap is reached. +- Fix: Gemini/Vertex backends now throw `NotSupportedException` when `WithMaxIterations` is used instead of silently ignoring the value. diff --git a/src/MaIN.Core/.nuspec b/src/MaIN.Core/.nuspec index 8ed72828..ba30e6a5 100644 --- a/src/MaIN.Core/.nuspec +++ b/src/MaIN.Core/.nuspec @@ -2,7 +2,7 @@ MaIN.NET - 0.10.9 + 0.10.10 Wisedev Wisedev favicon.png @@ -34,4 +34,4 @@ - \ No newline at end of file + diff --git a/src/MaIN.Core/Hub/Contexts/Interfaces/McpContext/IMcpContext.cs b/src/MaIN.Core/Hub/Contexts/Interfaces/McpContext/IMcpContext.cs index e0c0aa5b..166f7ea9 100644 --- a/src/MaIN.Core/Hub/Contexts/Interfaces/McpContext/IMcpContext.cs +++ b/src/MaIN.Core/Hub/Contexts/Interfaces/McpContext/IMcpContext.cs @@ -1,4 +1,4 @@ -using MaIN.Domain.Configuration; +using MaIN.Domain.Configuration; using MaIN.Domain.Entities; using MaIN.Services.Services.Models; @@ -22,10 +22,20 @@ public interface IMcpContext /// The context instance implementing for method chaining. IMcpContext WithBackend(BackendType backendType); + /// + /// Sets the maximum number of tool-call iterations allowed in a single MCP prompt. + /// Overrides the default limit of 10. Must be at least 1. + /// + /// + /// Not supported for and backends - + /// a will be thrown at runtime when is called. + /// + IMcpContext WithMaxIterations(int maxIterations); + /// /// Asynchronously processes a prompt through the configured MCP service, sending the prompt to the MCP server and returning the processed result. /// /// The text prompt to be processed by the MCP service /// A object containing the processed response from the MCP server. Task PromptAsync(string prompt); -} \ No newline at end of file +} diff --git a/src/MaIN.Core/Hub/Contexts/McpContext.cs b/src/MaIN.Core/Hub/Contexts/McpContext.cs index a29b45d4..fccb40a8 100644 --- a/src/MaIN.Core/Hub/Contexts/McpContext.cs +++ b/src/MaIN.Core/Hub/Contexts/McpContext.cs @@ -2,6 +2,7 @@ using MaIN.Domain.Configuration; using MaIN.Domain.Entities; using MaIN.Domain.Exceptions.MPC; +using MaIN.Domain.Exceptions.Tools; using MaIN.Services.Constants; using MaIN.Services.Services.Abstract; using MaIN.Services.Services.Models; @@ -13,6 +14,7 @@ public sealed class McpContext : IMcpContext private readonly IMcpService _mcpService; private Mcp? _mcpConfig; private BackendType? _explicitBackend; + private int? _maxIterations; internal McpContext(IMcpService mcpService) { @@ -24,7 +26,10 @@ public IMcpContext WithConfig(Mcp mcpConfig) { _mcpConfig = mcpConfig; if (_explicitBackend.HasValue) + { _mcpConfig.Backend = _explicitBackend; + } + return this; } @@ -35,18 +40,25 @@ public IMcpContext WithBackend(BackendType backendType) return this; } + public IMcpContext WithMaxIterations(int maxIterations) + { + InvalidToolIterationsException.ThrowIfInvalid(maxIterations); + _maxIterations = maxIterations; + return this; + } + public async Task PromptAsync(string prompt) { if (_mcpConfig == null) { throw new MPCConfigNotFoundException(); } - - return await _mcpService.Prompt(_mcpConfig!, [new Message() + + return await _mcpService.Prompt(_mcpConfig, [new Message() { Content = prompt, Role = ServiceConstants.Roles.User, Type = MessageType.CloudLLM - }]); + }], _maxIterations); } -} \ No newline at end of file +} diff --git a/src/MaIN.Core/Hub/Utils/ToolConfigurationBuilder.cs b/src/MaIN.Core/Hub/Utils/ToolConfigurationBuilder.cs index 3c353bb6..a6f7ce06 100644 --- a/src/MaIN.Core/Hub/Utils/ToolConfigurationBuilder.cs +++ b/src/MaIN.Core/Hub/Utils/ToolConfigurationBuilder.cs @@ -1,150 +1,96 @@ using System.Text.Json; using MaIN.Domain.Entities.Tools; +using MaIN.Domain.Exceptions.Tools; namespace MaIN.Core.Hub.Utils; -//TODO try to share logic of adding tool to the list across methods https://github.com/wisedev-code/MaIN.NET/pull/98#discussion_r2454997846 + public sealed class ToolsConfigurationBuilder { + private static readonly JsonSerializerOptions s_deserializeOptions = new() { PropertyNameCaseInsensitive = true }; private readonly ToolsConfiguration _config = new() { Tools = [] }; - - public ToolsConfigurationBuilder AddDefaultTool( - string type) + + public ToolsConfigurationBuilder AddDefaultTool(string type) { - _config.Tools.Add(new ToolDefinition - { - Type = type - }); + _config.Tools.Add(new ToolDefinition { Type = type }); return this; } - + public ToolsConfigurationBuilder AddTool( - string name, - string description, + string name, + string description, object parameters, Func> execute) { - _config.Tools.Add(new ToolDefinition - { - Function = new FunctionDefinition - { - Name = name, - Description = description, - Parameters = parameters - }, - Execute = execute - }); - return this; + return AddToolCore(name, description, parameters, execute); } public ToolsConfigurationBuilder AddTool( - string name, - string description, + string name, + string description, object parameters, Func execute) { - _config.Tools!.Add(new ToolDefinition - { - Function = new FunctionDefinition - { - Name = name, - Description = description, - Parameters = parameters - }, - Execute = args => Task.FromResult(execute(args)) - }); - return this; + return AddToolCore(name, description, parameters, args => Task.FromResult(execute(args))); } public ToolsConfigurationBuilder AddTool( - string name, - string description, + string name, + string description, object parameters, Func> execute) where TArgs : class { - _config.Tools.Add(new ToolDefinition - { - Function = new FunctionDefinition + return AddToolCore(name, description, parameters, async argsJson => { - Name = name, - Description = description, - Parameters = parameters - }, - Execute = async (argsJson) => - { - var args = JsonSerializer.Deserialize(argsJson, - new JsonSerializerOptions { PropertyNameCaseInsensitive = true })!; - var result = await execute(args); - return JsonSerializer.Serialize(result); - } - }); - return this; + var args = JsonSerializer.Deserialize(argsJson, s_deserializeOptions)!; + return JsonSerializer.Serialize(await execute(args)); + }); } public ToolsConfigurationBuilder AddTool( - string name, - string description, + string name, + string description, object parameters, Func execute) where TArgs : class { - _config.Tools!.Add(new ToolDefinition - { - Function = new FunctionDefinition + return AddToolCore(name, description, parameters, argsJson => { - Name = name, - Description = description, - Parameters = parameters - }, - Execute = (argsJson) => - { - var args = JsonSerializer.Deserialize(argsJson, - new JsonSerializerOptions { PropertyNameCaseInsensitive = true })!; - var result = execute(args); - return Task.FromResult(JsonSerializer.Serialize(result)); - } - }); - return this; + var args = JsonSerializer.Deserialize(argsJson, s_deserializeOptions)!; + return Task.FromResult(JsonSerializer.Serialize(execute(args))); + }); } public ToolsConfigurationBuilder AddTool( - string name, + string name, string description, Func> execute) { - _config.Tools.Add(new ToolDefinition - { - Function = new FunctionDefinition - { - Name = name, - Description = description, - Parameters = new { type = "object", properties = new { } } - }, - Execute = async (args) => - { - var result = await execute(); - return JsonSerializer.Serialize(result); - } - }); - return this; + return AddToolCore( + name, + description, + new { type = "object", properties = new { } }, + async _ => JsonSerializer.Serialize(await execute())); } public ToolsConfigurationBuilder AddTool( - string name, + string name, string description, Func execute) + => AddToolCore( + name, + description, + new { type = "object", properties = new { } }, + _ => Task.FromResult(JsonSerializer.Serialize(execute()))); + + private ToolsConfigurationBuilder AddToolCore( + string name, + string description, + object parameters, + Func> execute) { _config.Tools.Add(new ToolDefinition { - Function = new FunctionDefinition - { - Name = name, - Description = description, - Parameters = new { type = "object", properties = new { } } - }, - Execute = (args) => - { - var result = execute(); - return Task.FromResult(JsonSerializer.Serialize(result)); - } + Function = new FunctionDefinition { Name = name, Description = description, Parameters = parameters }, + Execute = execute }); return this; } @@ -155,5 +101,12 @@ public ToolsConfigurationBuilder WithToolChoice(string choice) return this; } + public ToolsConfigurationBuilder WithMaxIterations(int maxIterations) + { + InvalidToolIterationsException.ThrowIfInvalid(maxIterations); + _config.MaxIterations = maxIterations; + return this; + } + public ToolsConfiguration Build() => _config; -} \ No newline at end of file +} diff --git a/src/MaIN.Domain/Entities/Tools/ToolsConfiguration.cs b/src/MaIN.Domain/Entities/Tools/ToolsConfiguration.cs index 6bf99588..22227f72 100644 --- a/src/MaIN.Domain/Entities/Tools/ToolsConfiguration.cs +++ b/src/MaIN.Domain/Entities/Tools/ToolsConfiguration.cs @@ -4,9 +4,10 @@ public class ToolsConfiguration { public required List Tools { get; set; } public string? ToolChoice { get; set; } - + public int? MaxIterations { get; set; } + public Func>? GetExecutor(string functionName) { return Tools.FirstOrDefault(t => t.Function!.Name == functionName)?.Execute; } -} \ No newline at end of file +} diff --git a/src/MaIN.Domain/Exceptions/Tools/InvalidToolIterationsException.cs b/src/MaIN.Domain/Exceptions/Tools/InvalidToolIterationsException.cs new file mode 100644 index 00000000..ac8696ab --- /dev/null +++ b/src/MaIN.Domain/Exceptions/Tools/InvalidToolIterationsException.cs @@ -0,0 +1,18 @@ +using System.Net; + +namespace MaIN.Domain.Exceptions.Tools; + +public class InvalidToolIterationsException(int value) + : MaINCustomException($"MaxIterations must be at least 1, but received {value}.") +{ + public override string PublicErrorMessage => Message; + public override HttpStatusCode HttpStatusCode => HttpStatusCode.BadRequest; + + public static void ThrowIfInvalid(int value) + { + if (value < 1) + { + throw new InvalidToolIterationsException(value); + } + } +} diff --git a/src/MaIN.Services/Services/Abstract/IMcpService.cs b/src/MaIN.Services/Services/Abstract/IMcpService.cs index 67a0f761..79817b20 100644 --- a/src/MaIN.Services/Services/Abstract/IMcpService.cs +++ b/src/MaIN.Services/Services/Abstract/IMcpService.cs @@ -1,9 +1,9 @@ -using MaIN.Domain.Entities; +using MaIN.Domain.Entities; using MaIN.Services.Services.Models; namespace MaIN.Services.Services.Abstract; public interface IMcpService { - Task Prompt(Mcp config, List messageHistory); -} \ No newline at end of file + Task Prompt(Mcp config, List messageHistory, int? maxIterations = null); +} diff --git a/src/MaIN.Services/Services/LLMService/AnthropicService.cs b/src/MaIN.Services/Services/LLMService/AnthropicService.cs index 8c7b961a..acb79611 100644 --- a/src/MaIN.Services/Services/LLMService/AnthropicService.cs +++ b/src/MaIN.Services/Services/LLMService/AnthropicService.cs @@ -1,20 +1,20 @@ -using MaIN.Domain.Entities; -using MaIN.Domain.Models; -using MaIN.Services.Constants; -using MaIN.Services.Services.Abstract; -using MaIN.Services.Services.Models; -using MaIN.Services.Utils; -using Microsoft.Extensions.Logging; using System.Collections.Concurrent; -using System.Text.Json; using System.Text; +using System.Text.Json; using LLama.Common; using MaIN.Domain.Configuration; +using MaIN.Domain.Configuration.BackendInferenceParams; +using MaIN.Domain.Entities; using MaIN.Domain.Entities.Tools; using MaIN.Domain.Exceptions; +using MaIN.Domain.Models; using MaIN.Domain.Models.Concrete; +using MaIN.Services.Constants; +using MaIN.Services.Services.Abstract; using MaIN.Services.Services.LLMService.Utils; -using MaIN.Domain.Configuration.BackendInferenceParams; +using MaIN.Services.Services.Models; +using MaIN.Services.Utils; +using Microsoft.Extensions.Logging; namespace MaIN.Services.Services.LLMService; @@ -43,7 +43,9 @@ private HttpClient CreateAnthropicHttpClient(bool requireSkillsBeta = false) client.DefaultRequestHeaders.Add("anthropic-version", "2023-06-01"); if (requireSkillsBeta) + { client.DefaultRequestHeaders.Add("anthropic-beta", ServiceConstants.AnthropicBetaFeatures.SkillsBetaHeader); + } return client; } @@ -73,7 +75,9 @@ private void ValidateApiKey() ValidateApiKey(); if (!chat.Messages.Any()) + { return null; + } var lastMessage = chat.Messages.Last(); await ChatHelper.ExtractImageFromFiles(lastMessage); @@ -158,8 +162,9 @@ private async Task ProcessWithToolsAsync( StringBuilder fullResponseBuilder = new(); int iterations = 0; List? currentToolUses = null; + var maxToolIterations = chat.ToolsConfiguration?.MaxIterations ?? MaxToolIterations; - while (iterations < MaxToolIterations) + while (iterations < maxToolIterations) { if (iterations > 0 && fullResponseBuilder.Length > 0) { @@ -167,7 +172,9 @@ private async Task ProcessWithToolsAsync( tokens.Add(spaceToken); if (options.TokenCallback != null) + { await options.TokenCallback(spaceToken); + } if (options.InteractiveUpdates) { @@ -291,10 +298,10 @@ await notificationService.DispatchNotification( iterations++; } - if (iterations >= MaxToolIterations) + if (iterations >= maxToolIterations) { logger?.LogWarning("Maximum tool iterations ({MaxIterations}) reached for chat {ChatId}", - MaxToolIterations, chat.Id); + maxToolIterations, chat.Id); } var finalResponse = fullResponseBuilder.ToString(); @@ -350,9 +357,15 @@ await notificationService.DispatchNotification( while (true) { var line = await reader.ReadLineAsync(cancellationToken); - if (line is null) break; + if (line is null) + { + break; + } + if (string.IsNullOrWhiteSpace(line)) + { continue; + } if (line.StartsWith("event:")) { @@ -392,7 +405,9 @@ await notificationService.DispatchNotification( resultBuilder.Append(chunk.Delta.Text); if (options.TokenCallback != null) + { await options.TokenCallback(token); + } if (options.InteractiveUpdates) { @@ -429,12 +444,12 @@ await notificationService.DispatchNotification( return null; } - + private async Task HandleApiError(HttpResponseMessage response, CancellationToken cancellationToken = default) { var errorResponseBody = await response.Content.ReadAsStringAsync(cancellationToken); var errorMessage = ExtractApiErrorMessage(errorResponseBody); - + throw new LLMApiException(LLMApiRegistry.Anthropic.ApiName, response.StatusCode, errorMessage ?? errorResponseBody); } @@ -529,15 +544,28 @@ private async Task> BuildAnthropicRequestBody(Chat ch if (anthParams != null) { - if (anthParams.Temperature.HasValue) requestBody["temperature"] = anthParams.Temperature.Value; - if (anthParams.TopP.HasValue) requestBody["top_p"] = anthParams.TopP.Value; - if (anthParams.TopK.HasValue) requestBody["top_k"] = anthParams.TopK.Value; + if (anthParams.Temperature.HasValue) + { + requestBody["temperature"] = anthParams.Temperature.Value; + } + + if (anthParams.TopP.HasValue) + { + requestBody["top_p"] = anthParams.TopP.Value; + } + + if (anthParams.TopK.HasValue) + { + requestBody["top_k"] = anthParams.TopK.Value; + } } if (chat.BackendParams?.AdditionalParams != null) { foreach (var (key, value) in chat.BackendParams.AdditionalParams) + { requestBody[key] = value; + } } if (systemMessage != null && systemMessage.Content is string systemContent) @@ -587,7 +615,9 @@ private async Task> BuildAnthropicRequestBody(Chat ch } if (toolsList.Count > 0) + { requestBody["tools"] = toolsList; + } return requestBody; } @@ -604,7 +634,7 @@ public async Task GetCurrentModels() var httpClient = CreateAnthropicHttpClient(); using var response = await httpClient.GetAsync(ModelsUrl); - + if (!response.IsSuccessStatusCode) { await HandleApiError(response); @@ -683,9 +713,15 @@ private async Task ProcessStreamingChatAsync( while (true) { var line = await reader.ReadLineAsync(cancellationToken); - if (line is null) break; + if (line is null) + { + break; + } + if (string.IsNullOrWhiteSpace(line)) + { continue; + } if (line.StartsWith("data:")) { @@ -741,7 +777,7 @@ private async Task ProcessNonStreamingChatAsync( var content = new StringContent(requestJson, Encoding.UTF8, "application/json"); using var response = await httpClient.PostAsync(CompletionsUrl, content, cancellationToken); - + if (!response.IsSuccessStatusCode) { await HandleApiError(response, cancellationToken); @@ -884,4 +920,4 @@ file class AnthropicModelListResponse file class AnthropicModelInfo { public required string Id { get; set; } -} \ No newline at end of file +} diff --git a/src/MaIN.Services/Services/LLMService/LLMService.cs b/src/MaIN.Services/Services/LLMService/LLMService.cs index d031d532..6b317ae7 100644 --- a/src/MaIN.Services/Services/LLMService/LLMService.cs +++ b/src/MaIN.Services/Services/LLMService/LLMService.cs @@ -1,3 +1,6 @@ +using System.Collections.Concurrent; +using System.Text; +using System.Text.Json; using LLama; using LLama.Batched; using LLama.Common; @@ -5,9 +8,9 @@ using LLama.Sampling; using MaIN.Domain.Configuration; using MaIN.Domain.Entities; +using MaIN.Domain.Entities.Tools; using MaIN.Domain.Exceptions; using MaIN.Domain.Exceptions.Models; -using MaIN.Domain.Entities.Tools; using MaIN.Domain.Models; using MaIN.Domain.Models.Abstract; using MaIN.Services.Constants; @@ -17,9 +20,6 @@ using MaIN.Services.Services.Models; using MaIN.Services.Utils; using Microsoft.KernelMemory; -using System.Collections.Concurrent; -using System.Text; -using System.Text.Json; using Grammar = LLama.Sampling.Grammar; using LocalInferenceParams = MaIN.Domain.Entities.LocalInferenceParams; #pragma warning disable KMEXP00 @@ -57,7 +57,7 @@ public LLMService( CancellationToken cancellationToken = default) { chat.BackendParams ??= new LocalInferenceParams(); - + if (chat.BackendParams is not LocalInferenceParams) { throw new InvalidBackendParamsException("Local LLM", nameof(LocalInferenceParams), chat.BackendParams.GetType().Name); @@ -86,6 +86,7 @@ public LLMService( { return await ProcessWithToolsAsync(chat, requestOptions, cancellationToken); } + var tokens = await ProcessChatRequest(chat, model, lastMsg, requestOptions, cancellationToken); lastMsg.MarkProcessed(); return await CreateChatResult(chat, tokens, requestOptions); @@ -150,18 +151,24 @@ public Task CleanSessionCache(string? id) ModelLoader.RemoveModel(model.FileName); textGenerator.Dispose(); } + generator._embedder.Dispose(); generator._embedder._weights.Dispose(); generator.Dispose(); var ctxBuilder = new StringBuilder(); foreach (var citation in searchResult.Results.SelectMany(r => r.Partitions)) + { ctxBuilder.AppendLine(citation.Text); + } var originalContent = userMessage.Content; if (ctxBuilder.Length > 0) + { userMessage.Content = $"Use the following context to answer the question:\n\n{ctxBuilder}\n\nQuestion: {originalContent}"; + } + userMessage.Files = null; var chatResult = await Send(chat, requestOptions, cancellationToken); @@ -169,7 +176,6 @@ public Task CleanSessionCache(string? id) return chatResult; } - MemoryAnswer result; var tokens = new List(); @@ -574,7 +580,9 @@ private static LocalModel GetLocalModel(Chat chat) { // Try registry lookup (TryGetById to avoid throwing for unregistered models) if (ModelRegistry.TryGetById(chat.ModelId, out var model) && model is LocalModel localModel) + { return localModel; + } // 3. Fallback: create generic local model for unregistered models var modelId = chat.ModelId; @@ -592,8 +600,11 @@ private static LocalModel GetLocalModel(Chat chat) // Model name only — replace ':' with '-' (colon is illegal in Windows file names) fileName = modelId.Replace(':', '-'); if (!fileName.EndsWith(".gguf", StringComparison.OrdinalIgnoreCase)) + { fileName += ".gguf"; + } } + return new GenericLocalModel(FileName: fileName); } @@ -613,7 +624,10 @@ private string GetModelsPath() private string ResolvePath(string? customPath, string fileName) { if (Path.IsPathFullyQualified(fileName)) + { return fileName; + } + return Path.Combine(customPath ?? modelsPath, fileName); } @@ -667,8 +681,9 @@ private async Task ProcessWithToolsAsync( var iterations = 0; var lastResponseTokens = new List(); var lastResponse = string.Empty; + var maxToolIterations = chat.ToolsConfiguration?.MaxIterations ?? MaxToolIterations; - while (iterations < MaxToolIterations) + while (iterations < maxToolIterations) { var lastMsg = chat.Messages.Last(); var tokenCallbackOrg = requestOptions.TokenCallback; @@ -797,7 +812,7 @@ await requestOptions.ToolCallback.Invoke(new ToolInvocation iterations++; } - if (iterations >= MaxToolIterations) + if (iterations >= maxToolIterations) { var errorMessage = "Maximum tool invocation iterations reached. Ending the tool-loop prematurely."; var iterationMessage = new Message @@ -834,5 +849,4 @@ private static string GetFinalPrompt(Message message, AIModel model, bool startS ? $"{message.Content}{additionalPrompt}" : message.Content; } - } diff --git a/src/MaIN.Services/Services/LLMService/OpenAiCompatibleService.cs b/src/MaIN.Services/Services/LLMService/OpenAiCompatibleService.cs index 4bd89120..d3b0562a 100644 --- a/src/MaIN.Services/Services/LLMService/OpenAiCompatibleService.cs +++ b/src/MaIN.Services/Services/LLMService/OpenAiCompatibleService.cs @@ -1,22 +1,22 @@ -using MaIN.Domain.Models; -using MaIN.Services.Constants; -using MaIN.Services.Services.Abstract; -using MaIN.Services.Services.LLMService.Utils; -using MaIN.Services.Services.Models; -using MaIN.Services.Utils; -using Microsoft.Extensions.Logging; -using Microsoft.KernelMemory; using System.Collections.Concurrent; using System.Net.Http.Headers; using System.Net.Mime; using System.Text; using System.Text.Json; using System.Text.Json.Serialization; -using MaIN.Domain.Entities; -using MaIN.Services.Services.LLMService.Memory; using LLama.Common; +using MaIN.Domain.Entities; using MaIN.Domain.Entities.Tools; using MaIN.Domain.Exceptions; +using MaIN.Domain.Models; +using MaIN.Services.Constants; +using MaIN.Services.Services.Abstract; +using MaIN.Services.Services.LLMService.Memory; +using MaIN.Services.Services.LLMService.Utils; +using MaIN.Services.Services.Models; +using MaIN.Services.Utils; +using Microsoft.Extensions.Logging; +using Microsoft.KernelMemory; namespace MaIN.Services.Services.LLMService; @@ -36,7 +36,6 @@ public abstract class OpenAiCompatibleService( private static readonly JsonSerializerOptions DefaultJsonSerializerOptions = new() { PropertyNameCaseInsensitive = true }; - protected abstract string GetApiKey(); protected abstract string GetApiName(); protected abstract void ValidateApiKey(); @@ -51,14 +50,16 @@ public abstract class OpenAiCompatibleService( ChatRequestOptions options, CancellationToken cancellationToken = default) { - if (chat.BackendParams != null && chat.BackendParams?.GetType() != ExpectedParamsType) + if (chat.BackendParams is not null && chat.BackendParams?.GetType() != ExpectedParamsType) { throw new InvalidBackendParamsException(GetApiName(), ExpectedParamsType.Name, chat.BackendParams!.GetType().Name); } ValidateApiKey(); if (!chat.Messages.Any()) + { return null; + } List tokens = new(); string apiKey = GetApiKey(); @@ -78,7 +79,7 @@ public abstract class OpenAiCompatibleService( return CreateChatResult(chat, resultBuilder.ToString(), memoryResult.Message.Tokens); } - if (chat.ToolsConfiguration?.Tools != null && chat.ToolsConfiguration.Tools.Any()) + if (chat.ToolsConfiguration?.Tools is not null && chat.ToolsConfiguration.Tools.Any()) { return await ProcessWithToolsAsync( chat, @@ -89,7 +90,7 @@ public abstract class OpenAiCompatibleService( cancellationToken); } - if (options.InteractiveUpdates || options.TokenCallback != null) + if (options.InteractiveUpdates || options.TokenCallback is not null) { await ProcessStreamingChatAsync( chat, @@ -135,26 +136,29 @@ private async Task ProcessWithToolsAsync( CancellationToken cancellationToken) { StringBuilder resultBuilder = new(); - StringBuilder fullResponseBuilder = new(); + StringBuilder fullResponseBuilder = new(); int iterations = 0; + var maxToolIterations = chat.ToolsConfiguration?.MaxIterations ?? MaxToolIterations; - while (iterations < MaxToolIterations) + while (iterations < maxToolIterations) { if (iterations > 0 && options.InteractiveUpdates && fullResponseBuilder.Length > 0) { var spaceToken = new LLMTokenValue { Text = " ", Type = TokenType.Message }; tokens.Add(spaceToken); - - if (options.TokenCallback != null) + + if (options.TokenCallback is not null) + { await options.TokenCallback(spaceToken); - + } + await _notificationService.DispatchNotification( NotificationMessageBuilder.CreateChatCompletion(chat.Id, spaceToken, false), ServiceConstants.Notifications.ReceiveMessageUpdate); } List? currentToolCalls; - if (options.InteractiveUpdates || options.TokenCallback != null) + if (options.InteractiveUpdates || options.TokenCallback is not null) { currentToolCalls = await ProcessStreamingChatWithToolsAsync( chat, @@ -175,17 +179,18 @@ await _notificationService.DispatchNotification( options, cancellationToken); } - - if (resultBuilder.Length > 0) + + if (resultBuilder.Length > 0) { if (fullResponseBuilder.Length > 0) { - fullResponseBuilder.Append(" "); + fullResponseBuilder.Append(' '); } + fullResponseBuilder.Append(resultBuilder); } - - if (currentToolCalls == null || !currentToolCalls.Any()) + + if (currentToolCalls is null || currentToolCalls.Count == 0) { break; } @@ -208,7 +213,7 @@ await _notificationService.DispatchNotification( var executor = chat.ToolsConfiguration?.GetExecutor(toolCall.Function.Name); - if (executor == null) + if (executor is null) { var errorMessage = $"No executor found for tool: {toolCall.Function.Name}"; logger?.LogError(errorMessage); @@ -240,7 +245,7 @@ await _notificationService.DispatchNotification( catch (Exception ex) { logger?.LogError(ex, "Error executing tool {ToolName}", toolCall.Function.Name); - + var errorResult = JsonSerializer.Serialize(new { error = ex.Message }); var toolMessage = new ChatMessage(ServiceConstants.Roles.Tool, errorResult) { @@ -255,10 +260,10 @@ await _notificationService.DispatchNotification( iterations++; } - if (iterations >= MaxToolIterations) + if (iterations >= maxToolIterations) { - logger?.LogWarning("Maximum tool iterations ({MaxIterations}) reached for chat {ChatId}", - MaxToolIterations, chat.Id); + logger?.LogWarning("Maximum tool iterations ({MaxIterations}) reached for chat {ChatId}", + maxToolIterations, chat.Id); } var finalResponse = fullResponseBuilder.ToString(); @@ -317,15 +322,23 @@ await _notificationService.DispatchNotification( while (true) { var line = await reader.ReadLineAsync(cancellationToken); - if (line is null) break; + if (line is null) + { + break; + } + if (string.IsNullOrWhiteSpace(line)) + { continue; + } if (line.StartsWith("data: ")) { var data = line.Substring("data: ".Length).Trim(); if (data == "[DONE]") + { break; + } try { @@ -333,7 +346,7 @@ await _notificationService.DispatchNotification( new JsonSerializerOptions { PropertyNameCaseInsensitive = true }); var choice = chunk?.Choices?.FirstOrDefault(); - if (choice?.Delta != null) + if (choice?.Delta is not null) { // Handle content if (!string.IsNullOrEmpty(choice.Delta.Content)) @@ -356,7 +369,7 @@ await _notificationService.DispatchNotification( } } - if (choice.Delta.ToolCalls != null) + if (choice.Delta.ToolCalls is not null) { foreach (var toolCallChunk in choice.Delta.ToolCalls) { @@ -368,18 +381,26 @@ await _notificationService.DispatchNotification( var builder = toolCallsBuilder[toolCallChunk.Index]; if (!string.IsNullOrEmpty(toolCallChunk.Id)) + { builder.Id = toolCallChunk.Id; + } if (!string.IsNullOrEmpty(toolCallChunk.Type)) + { builder.Type = toolCallChunk.Type; + } - if (toolCallChunk.Function != null) + if (toolCallChunk.Function is not null) { if (!string.IsNullOrEmpty(toolCallChunk.Function.Name)) + { builder.FunctionName = toolCallChunk.Function.Name; + } if (!string.IsNullOrEmpty(toolCallChunk.Function.Arguments)) + { builder.FunctionArguments.Append(toolCallChunk.Function.Arguments); + } } } } @@ -393,9 +414,9 @@ await _notificationService.DispatchNotification( } // Build final tool calls from accumulated chunks - if (toolCallsBuilder.Any()) + if (toolCallsBuilder.Count != 0) { - return toolCallsBuilder.Values.Select(b => b.Build()).ToList(); + return [.. toolCallsBuilder.Values.Select(b => b.Build())]; } return null; @@ -418,7 +439,7 @@ await _notificationService.DispatchNotification( var content = new StringContent(requestJson, Encoding.UTF8, MediaTypeNames.Application.Json); using var response = await client.PostAsync(ChatCompletionsUrl, content, cancellationToken); - + if (!response.IsSuccessStatusCode) { await HandleApiError(response, cancellationToken); @@ -430,7 +451,7 @@ await _notificationService.DispatchNotification( var message = chatResponse?.Choices?.FirstOrDefault()?.Message; - if (message?.Content != null) + if (message?.Content is not null) { resultBuilder.Append(message.Content); } @@ -444,8 +465,10 @@ await _notificationService.DispatchNotification( ChatRequestOptions requestOptions, CancellationToken cancellationToken = default) { - if (!chat.Messages.Any()) + if (chat.Messages.Count == 0) + { return null; + } var kernel = memoryFactory.CreateMemoryWithOpenAi(GetApiKey(), chat.MemoryParams); await memoryService.ImportDataToMemory((kernel, null), memoryOptions, cancellationToken); @@ -453,7 +476,7 @@ await _notificationService.DispatchNotification( var lastMessage = chat.Messages.Last(); var userQuery = lastMessage.Content; - if (chat.MemoryParams.Grammar != null) + if (chat.MemoryParams.Grammar is not null) { var jsonGrammarConverter = new GrammarToJsonConverter(); var jsonGrammar = jsonGrammarConverter.ConvertToJson(chat.MemoryParams.Grammar); @@ -496,7 +519,7 @@ await _notificationService.DispatchNotification( var tokens = new List(); var resultBuilder = new StringBuilder(); - if (requestOptions.InteractiveUpdates || requestOptions.TokenCallback != null) + if (requestOptions.InteractiveUpdates || requestOptions.TokenCallback is not null) { await ProcessStreamingChatAsync(chat, conversation, GetApiKey(), tokens, resultBuilder, requestOptions, cancellationToken); } @@ -522,7 +545,7 @@ await _notificationService.DispatchNotification( MemoryAnswer retrievedContext; var standardTokens = new List(); - if (requestOptions.InteractiveUpdates || requestOptions.TokenCallback != null) + if (requestOptions.InteractiveUpdates || requestOptions.TokenCallback is not null) { var responseBuilder = new StringBuilder(); @@ -594,10 +617,15 @@ await notificationService.DispatchNotification( var contextBuilder = new StringBuilder(); foreach (var partition in retrievedChunks.OrderByDescending(p => p.Relevance)) { - if (string.IsNullOrWhiteSpace(partition.Text)) continue; + if (string.IsNullOrWhiteSpace(partition.Text)) + { + continue; + } + contextBuilder.AppendLine(partition.Text); contextBuilder.AppendLine(); } + var fallback = contextBuilder.ToString().TrimEnd(); logger?.LogInformation( @@ -626,20 +654,20 @@ public virtual async Task GetCurrentModels() SetAuthorizationIfNeeded(client, GetApiKey()); using var response = await client.GetAsync(ModelsUrl); - + if (!response.IsSuccessStatusCode) { await HandleApiError(response); } var responseJson = await response.Content.ReadAsStringAsync(); - var modelsResponse = JsonSerializer.Deserialize(responseJson, + var modelsResponse = JsonSerializer.Deserialize(responseJson, new JsonSerializerOptions { PropertyNameCaseInsensitive = true }); return (modelsResponse?.Data? .Select(m => m.Id) - .Where(id => id != null) - .ToArray() + .Where(id => id is not null) + .ToArray() ?? [])!; } @@ -675,27 +703,31 @@ protected void UpdateSessionCache(string chatId, string assistantResponse, bool protected static async Task ExtractImageFromFiles(Message message) { - if (message.Files == null || message.Files.Count == 0) + if (message.Files is null || message.Files.Count == 0) + { return; + } var imageFiles = message.Files .Where(f => ImageExtensions.Contains(f.Extension.ToLowerInvariant())) .ToList(); if (imageFiles.Count == 0) + { return; + } var imageBytesList = new List(); foreach (var imageFile in imageFiles) { - if (imageFile.StreamContent != null) + if (imageFile.StreamContent is not null) { using var ms = new MemoryStream(); imageFile.StreamContent.Position = 0; await imageFile.StreamContent.CopyToAsync(ms); imageBytesList.Add(ms.ToArray()); } - else if (imageFile.Path != null) + else if (imageFile.Path is not null) { imageBytesList.Add(await File.ReadAllBytesAsync(imageFile.Path)); } @@ -706,12 +738,14 @@ protected static async Task ExtractImageFromFiles(Message message) message.Images = imageBytesList; if (message.Files.Count == 0) + { message.Files = null; + } } protected static bool HasFiles(Message message) { - return message.Files != null && message.Files.Count > 0; + return message.Files is not null && message.Files.Count > 0; } private static void SetAuthorizationIfNeeded(HttpClient client, string apiKey) @@ -734,7 +768,7 @@ private async Task ProcessStreamingChatAsync( SetAuthorizationIfNeeded(client, apiKey); var requestBody = BuildRequestBody(chat, conversation, true); - + var requestJson = JsonSerializer.Serialize(requestBody); var content = new StringContent(requestJson, Encoding.UTF8, MediaTypeNames.Application.Json); @@ -761,15 +795,23 @@ private async Task ProcessStreamingChatAsync( cancellationToken.ThrowIfCancellationRequested(); var line = await reader.ReadLineAsync(cancellationToken); - if (line is null) break; + if (line is null) + { + break; + } + if (string.IsNullOrWhiteSpace(line)) + { continue; + } if (line.StartsWith("data: ")) { var data = line.Substring("data: ".Length).Trim(); if (data == "[DONE]") + { break; + } try { @@ -805,7 +847,7 @@ private async Task HandleApiError(HttpResponseMessage response, CancellationToke { var errorResponseBody = await response.Content.ReadAsStringAsync(cancellationToken); var errorMessage = ExtractApiErrorMessage(errorResponseBody); - + throw new LLMApiException(GetApiName(), response.StatusCode, errorMessage ?? errorResponseBody); } @@ -814,7 +856,7 @@ private async Task HandleApiError(HttpResponseMessage response, CancellationToke try { using var jasonDocument = JsonDocument.Parse(json); - + if (jasonDocument.RootElement.ValueKind == JsonValueKind.Array) { var firstElement = jasonDocument.RootElement[0]; @@ -836,7 +878,7 @@ private async Task HandleApiError(HttpResponseMessage response, CancellationToke // we fall back to the raw response body in the calling method. return null; } - + return null; } @@ -869,7 +911,7 @@ private async Task ProcessNonStreamingChatAsync( var content = new StringContent(requestJson, Encoding.UTF8, MediaTypeNames.Application.Json); using var response = await client.PostAsync(ChatCompletionsUrl, content, cancellationToken); - + if (!response.IsSuccessStatusCode) { await HandleApiError(response, cancellationToken); @@ -880,7 +922,7 @@ private async Task ProcessNonStreamingChatAsync( JsonSerializer.Deserialize(responseJson, DefaultJsonSerializerOptions); var responseContent = chatResponse?.Choices?.FirstOrDefault()?.Message?.Content; - if (responseContent != null) + if (responseContent is not null) { resultBuilder.Append(responseContent); } @@ -898,12 +940,12 @@ private object BuildRequestBody(Chat chat, List conversation, bool ApplyBackendParams(requestBody, chat); ApplyAdditionalParams(requestBody, chat); - if (chat.ToolsConfiguration?.Tools != null && chat.ToolsConfiguration.Tools.Any()) + if (chat.ToolsConfiguration?.Tools is not null && chat.ToolsConfiguration.Tools.Any()) { requestBody["tools"] = chat.ToolsConfiguration.Tools.Select(t => new { type = t.Type, - function = t.Function != null ? new + function = t.Function is not null ? new { name = t.Function.Name, description = t.Function.Description, @@ -926,14 +968,17 @@ protected virtual void ApplyBackendParams(Dictionary requestBody private static void ApplyAdditionalParams(Dictionary requestBody, Chat chat) { - if (chat.BackendParams?.AdditionalParams == null) return; + if (chat.BackendParams?.AdditionalParams is null) + { + return; + } + foreach (var (key, value) in chat.BackendParams.AdditionalParams) { requestBody[key] = value; } } - protected static ChatResult CreateChatResult(Chat chat, string content, List tokens) { return new ChatResult @@ -953,7 +998,7 @@ protected static ChatResult CreateChatResult(Chat chat, string content, List? callback, LLMTokenValue token) { - if (callback != null) + if (callback is not null) { await callback.Invoke(token); } @@ -1034,7 +1079,7 @@ file class ChoiceChunk file class Delta { public string? Content { get; set; } - + [JsonPropertyName("tool_calls")] public List? ToolCalls { get; set; } } @@ -1054,11 +1099,11 @@ file class FunctionCallChunk } file class OpenAiModelsResponse -{ +{ public List? Data { get; set; } } file class OpenAiModel { public string? Id { get; set; } -} \ No newline at end of file +} diff --git a/src/MaIN.Services/Services/LLMService/OpenAiService.cs b/src/MaIN.Services/Services/LLMService/OpenAiService.cs index acca2784..ad30a585 100644 --- a/src/MaIN.Services/Services/LLMService/OpenAiService.cs +++ b/src/MaIN.Services/Services/LLMService/OpenAiService.cs @@ -250,7 +250,7 @@ private async Task RunResponsesToolLoopAsync( string apiKey, CancellationToken cancellationToken) { - const int maxIterations = 5; + var maxIterations = chat.ToolsConfiguration?.MaxIterations ?? MaxToolIterations; string responseJson = string.Empty; for (var iteration = 0; iteration < maxIterations; iteration++) diff --git a/src/MaIN.Services/Services/McpService.cs b/src/MaIN.Services/Services/McpService.cs index 3d1c5f5d..85ce5aca 100644 --- a/src/MaIN.Services/Services/McpService.cs +++ b/src/MaIN.Services/Services/McpService.cs @@ -1,27 +1,26 @@ +using System.Net.Http.Headers; +using System.Text; +using System.Text.Json; using MaIN.Domain.Configuration; using MaIN.Domain.Entities; using MaIN.Domain.Models.Concrete; using MaIN.Services.Services.Abstract; using MaIN.Services.Services.LLMService.Auth; -using MaIN.Services.Services.LLMService.Utils; using MaIN.Services.Services.Models; using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Logging; using Microsoft.SemanticKernel; using Microsoft.SemanticKernel.ChatCompletion; using Microsoft.SemanticKernel.Connectors.Google; using ModelContextProtocol.Client; -using System.Net.Http.Headers; -using System.Text; -using System.Text.Json; #pragma warning disable SKEXP0001 -#pragma warning disable SKEXP0070 namespace MaIN.Services.Services; -public class McpService(MaINSettings settings, IServiceProvider serviceProvider) : IMcpService +public class McpService(MaINSettings settings, IServiceProvider serviceProvider, ILogger? logger = null) : IMcpService { - public async Task Prompt(Mcp config, List messageHistory) + public async Task Prompt(Mcp config, List messageHistory, int? maxIterations = null) { await using var mcpClient = await McpClientFactory.CreateAsync( new StdioClientTransport( @@ -38,13 +37,15 @@ public async Task Prompt(Mcp config, List messageHistory) return backendType switch { + BackendType.Gemini or BackendType.Vertex when maxIterations.HasValue => + throw new NotSupportedException($"MaxIterations is not supported for {backendType} backend."), BackendType.Gemini or BackendType.Vertex => await PromptWithSK(mcpClient, tools, config, messageHistory, backendType), BackendType.Anthropic => - await PromptWithAnthropic(mcpClient, tools, config, messageHistory), + await PromptWithAnthropic(mcpClient, tools, config, messageHistory, maxIterations), BackendType.DeepSeek or BackendType.Ollama or BackendType.Self => throw new NotSupportedException($"{backendType} does not support MCP integration."), - _ => await PromptWithHttp(mcpClient, tools, config, messageHistory, backendType) + _ => await PromptWithHttp(mcpClient, tools, config, messageHistory, backendType, maxIterations) }; } @@ -55,7 +56,8 @@ private async Task PromptWithHttp( IList tools, Mcp config, List messageHistory, - BackendType backendType) + BackendType backendType, + int? maxIterations = null) { var (url, apiKey) = GetEndpointAndKey(backendType, config); @@ -82,8 +84,8 @@ private async Task PromptWithHttp( }) .ToList(); - const int maxIterations = 10; - for (int i = 0; i < maxIterations; i++) + var effectiveMaxIterations = maxIterations ?? 10; + for (int i = 0; i < effectiveMaxIterations; i++) { var requestBody = new Dictionary { @@ -123,7 +125,7 @@ private async Task PromptWithHttp( continue; } - return BuildResult(content, config.Model); + return McpService.BuildResult(content, config.Model); } // Add assistant message with tool calls (preserve raw JSON element) @@ -142,28 +144,7 @@ private async Task PromptWithHttp( var argsJson = toolCall.GetProperty("function").GetProperty("arguments").GetString() ?? "{}"; var toolCallId = toolCall.GetProperty("id").GetString()!; - var argsDict = JsonSerializer - .Deserialize>(argsJson) - ?.ToDictionary( - kvp => kvp.Key, - kvp => (object?)(kvp.Value.ValueKind switch - { - JsonValueKind.String => (object)kvp.Value.GetString()!, - JsonValueKind.True => true, - JsonValueKind.False => false, - JsonValueKind.Number when kvp.Value.TryGetInt64(out var l) => l, - JsonValueKind.Number => (object)kvp.Value.GetDouble(), - _ => (object)kvp.Value - })) - ?? new Dictionary(); - - var toolResult = await mcpClient.CallToolAsync(toolName, argsDict); - var resultText = string.Join("\n", toolResult.Content - .Where(c => c.Text != null) - .Select(c => c.Text!)); - - if (toolResult.IsError == true) - Console.WriteLine($"[MCP] Tool '{toolName}' returned error: {resultText}"); + var resultText = await ExecuteToolAsync(mcpClient, toolName, argsJson); messages.Add(new Dictionary { @@ -174,7 +155,28 @@ JsonValueKind.Number when kvp.Value.TryGetInt64(out var l) => l, } } - return BuildResult("Max tool iterations reached.", config.Model); + logger?.LogWarning("Max tool iterations ({MaxIterations}) reached. Sending final synthesis request.", effectiveMaxIterations); + + var finalRequestBody = new Dictionary + { + ["model"] = config.Model, + ["messages"] = messages, + ["tools"] = toolDefs + }; + + var finalJson = JsonSerializer.Serialize(finalRequestBody); + var finalResponse = await client.PostAsync(url, + new StringContent(finalJson, Encoding.UTF8, "application/json")); + finalResponse.EnsureSuccessStatusCode(); + + var finalResponseText = await finalResponse.Content.ReadAsStringAsync(); + var finalDoc = JsonDocument.Parse(finalResponseText); + var finalMessage = finalDoc.RootElement + .GetProperty("choices")[0] + .GetProperty("message"); + var finalContent = finalMessage.TryGetProperty("content", out var fc) ? fc.GetString() ?? "" : ""; + + return McpService.BuildResult(finalContent, config.Model); } // Anthropic uses a different protocol: x-api-key header, input_schema instead of parameters, @@ -183,7 +185,8 @@ private async Task PromptWithAnthropic( IMcpClient mcpClient, IList tools, Mcp config, - List messageHistory) + List messageHistory, + int? maxIterations = null) { var apiKey = GetAnthropicKey() ?? throw new InvalidOperationException("Anthropic API key not configured."); var httpClientFactory = serviceProvider.GetRequiredService(); @@ -211,8 +214,8 @@ private async Task PromptWithAnthropic( }) .ToList(); - const int maxIterations = 10; - for (int i = 0; i < maxIterations; i++) + var effectiveMaxIterations = maxIterations ?? 10; + for (int i = 0; i < effectiveMaxIterations; i++) { var requestBody = new Dictionary { @@ -224,8 +227,10 @@ private async Task PromptWithAnthropic( ? (object)new Dictionary { ["type"] = "any" } : new Dictionary { ["type"] = "auto" } }; - if (systemContent != null) + if (systemContent is not null) + { requestBody["system"] = systemContent; + } var json = JsonSerializer.Serialize(requestBody); var response = await client.PostAsync("https://api.anthropic.com/v1/messages", @@ -262,14 +267,19 @@ private async Task PromptWithAnthropic( }); continue; } - return BuildResult(textContent, config.Model); + + return McpService.BuildResult(textContent, config.Model); } // Add assistant turn with tool_use blocks var assistantContent = new List(); if (!string.IsNullOrEmpty(textContent)) + { assistantContent.Add(new Dictionary { ["type"] = "text", ["text"] = textContent }); + } + foreach (var tu in toolUses) + { assistantContent.Add(new Dictionary { ["type"] = "tool_use", @@ -277,6 +287,8 @@ private async Task PromptWithAnthropic( ["name"] = tu.GetProperty("name").GetString()!, ["input"] = tu.GetProperty("input") }); + } + messages.Add(new Dictionary { ["role"] = "assistant", ["content"] = assistantContent }); // Execute tools and collect tool_result blocks @@ -285,30 +297,7 @@ private async Task PromptWithAnthropic( { var toolName = tu.GetProperty("name").GetString()!; var toolId = tu.GetProperty("id").GetString()!; - var inputElement = tu.GetProperty("input"); - - var argsDict = JsonSerializer - .Deserialize>(inputElement.GetRawText()) - ?.ToDictionary( - kvp => kvp.Key, - kvp => (object?)(kvp.Value.ValueKind switch - { - JsonValueKind.String => (object)kvp.Value.GetString()!, - JsonValueKind.True => true, - JsonValueKind.False => false, - JsonValueKind.Number when kvp.Value.TryGetInt64(out var l) => l, - JsonValueKind.Number => (object)kvp.Value.GetDouble(), - _ => (object)kvp.Value - })) - ?? new Dictionary(); - - var toolResult = await mcpClient.CallToolAsync(toolName, argsDict); - var resultText = string.Join("\n", toolResult.Content - .Where(c => c.Text != null) - .Select(c => c.Text!)); - - if (toolResult.IsError == true) - Console.WriteLine($"[MCP] Tool '{toolName}' returned error: {resultText}"); + var resultText = await ExecuteToolAsync(mcpClient, toolName, tu.GetProperty("input").GetRawText()); toolResults.Add(new Dictionary { @@ -317,10 +306,38 @@ JsonValueKind.Number when kvp.Value.TryGetInt64(out var l) => l, ["content"] = resultText }); } + messages.Add(new Dictionary { ["role"] = "user", ["content"] = toolResults }); } - return BuildResult("Max tool iterations reached.", config.Model); + logger?.LogWarning("Max tool iterations ({MaxIterations}) reached. Sending final synthesis request.", effectiveMaxIterations); + + var finalRequestBody = new Dictionary + { + ["model"] = config.Model, + ["max_tokens"] = 4096, + ["messages"] = messages, + ["tools"] = toolDefs + }; + if (systemContent is not null) + { + finalRequestBody["system"] = systemContent; + } + + var finalJson = JsonSerializer.Serialize(finalRequestBody); + var finalResponse = await client.PostAsync("https://api.anthropic.com/v1/messages", + new StringContent(finalJson, Encoding.UTF8, "application/json")); + finalResponse.EnsureSuccessStatusCode(); + + var finalResponseText = await finalResponse.Content.ReadAsStringAsync(); + var finalDoc = JsonDocument.Parse(finalResponseText); + var finalContent = string.Concat(finalDoc.RootElement + .GetProperty("content") + .EnumerateArray() + .Where(b => b.TryGetProperty("type", out var t) && t.GetString() == "text") + .Select(b => b.TryGetProperty("text", out var txt) ? txt.GetString() ?? "" : "")); + + return McpService.BuildResult(finalContent, config.Model); } private (string url, string apiKey) GetEndpointAndKey(BackendType backendType, Mcp config) @@ -369,7 +386,7 @@ private async Task PromptWithSK( var chatService = kernel.GetRequiredService(); var result = await chatService.GetChatMessageContentsAsync(chatHistory, promptSettings, kernel); - return BuildResult(result.Last().Content!, config.Model); + return McpService.BuildResult(result.Last().Content!, config.Model); } private PromptExecutionSettings InitializeGoogleChatCompletions(IKernelBuilder kernelBuilder, Mcp config, BackendType backendType) @@ -410,7 +427,37 @@ private PromptExecutionSettings InitializeGoogleChatCompletions(IKernelBuilder k }; } - private McpResult BuildResult(string content, string model) => new() + private static Dictionary DeserializeToolArgs(string argsJson) + { + return JsonSerializer.Deserialize>(argsJson) + ?.ToDictionary( + kvp => kvp.Key, + kvp => (object?)(kvp.Value.ValueKind switch + { + JsonValueKind.String => (object)kvp.Value.GetString()!, + JsonValueKind.True => true, + JsonValueKind.False => false, + JsonValueKind.Number when kvp.Value.TryGetInt64(out var l) => l, + JsonValueKind.Number => (object)kvp.Value.GetDouble(), + _ => (object)kvp.Value + })) + ?? []; + } + + private async Task ExecuteToolAsync(IMcpClient mcpClient, string toolName, string argsJson) + { + var argsDict = DeserializeToolArgs(argsJson); + var result = await mcpClient.CallToolAsync(toolName, argsDict); + var text = string.Join("\n", result.Content.Where(c => c.Text is not null).Select(c => c.Text!)); + if (result.IsError == true) + { + logger?.LogError("MCP tool '{ToolName}' returned error: {Error}", toolName, text); + } + + return text; + } + + private static McpResult BuildResult(string content, string model) => new() { CreatedAt = DateTime.Now, Message = new Message diff --git a/src/MaIN.Services/Services/Steps/Commands/AnswerCommandHandler.cs b/src/MaIN.Services/Services/Steps/Commands/AnswerCommandHandler.cs index 75ef26e4..8164ff66 100644 --- a/src/MaIN.Services/Services/Steps/Commands/AnswerCommandHandler.cs +++ b/src/MaIN.Services/Services/Steps/Commands/AnswerCommandHandler.cs @@ -1,3 +1,4 @@ +using System.Text.Json; using MaIN.Domain.Configuration; using MaIN.Domain.Entities; using MaIN.Domain.Entities.Agents.Knowledge; @@ -12,7 +13,6 @@ using MaIN.Services.Services.Models.Commands; using MaIN.Services.Services.Steps.Commands.Abstract; using MaIN.Services.Utils; -using System.Text.Json; namespace MaIN.Services.Services.Steps.Commands; @@ -74,7 +74,9 @@ public class AnswerCommandHandler( private async Task ShouldUseKnowledge(Knowledge? knowledge, Chat chat, BackendType backend) { if (knowledge?.Index.Items is not { Count: > 0 }) + { return false; + } var originalContent = chat.Messages.Last().Content; diff --git a/src/MaIN.Services/Services/Steps/Commands/McpCommandHandler.cs b/src/MaIN.Services/Services/Steps/Commands/McpCommandHandler.cs index 8c6b96be..69b8bf30 100644 --- a/src/MaIN.Services/Services/Steps/Commands/McpCommandHandler.cs +++ b/src/MaIN.Services/Services/Steps/Commands/McpCommandHandler.cs @@ -11,7 +11,7 @@ public class McpCommandHandler( { public async Task HandleAsync(McpCommand command) { - var result = await mcpService.Prompt(command.McpConfig, command.Chat.Messages); + var result = await mcpService.Prompt(command.McpConfig, command.Chat.Messages, command.Chat.ToolsConfiguration?.MaxIterations); return result.Message; } -} \ No newline at end of file +}