Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion .vscode/settings.json
Original file line number Diff line number Diff line change
Expand Up @@ -13,5 +13,8 @@
},
"python.testing.pytestEnabled": true,
"python.testing.unittestEnabled": false,
"python.testing.pytestArgs": ["python"]
"python.testing.pytestArgs": ["python"],
"[python]": {
"editor.defaultFormatter": "charliermarsh.ruff"
}
}
90 changes: 90 additions & 0 deletions dotnet/src/Client.cs
Original file line number Diff line number Diff line change
Expand Up @@ -336,6 +336,14 @@ public async Task<CopilotSession> CreateSessionAsync(SessionConfig? config = nul
{
var connection = await EnsureConnectedAsync(cancellationToken);

var hasHooks = config?.Hooks != null && (
config.Hooks.OnPreToolUse != null ||
config.Hooks.OnPostToolUse != null ||
config.Hooks.OnUserPromptSubmitted != null ||
config.Hooks.OnSessionStart != null ||
config.Hooks.OnSessionEnd != null ||
config.Hooks.OnErrorOccurred != null);

var request = new CreateSessionRequest(
config?.Model,
config?.SessionId,
Expand All @@ -345,6 +353,9 @@ public async Task<CopilotSession> CreateSessionAsync(SessionConfig? config = nul
config?.ExcludedTools,
config?.Provider,
config?.OnPermissionRequest != null ? true : null,
config?.OnUserInputRequest != null ? true : null,
hasHooks ? true : null,
config?.WorkingDirectory,
config?.Streaming == true ? true : null,
config?.McpServers,
config?.CustomAgents,
Expand All @@ -362,6 +373,14 @@ public async Task<CopilotSession> CreateSessionAsync(SessionConfig? config = nul
{
session.RegisterPermissionHandler(config.OnPermissionRequest);
}
if (config?.OnUserInputRequest != null)
{
session.RegisterUserInputHandler(config.OnUserInputRequest);
}
if (config?.Hooks != null)
{
session.RegisterHooks(config.Hooks);
}

if (!_sessions.TryAdd(response.SessionId, session))
{
Expand Down Expand Up @@ -399,11 +418,23 @@ public async Task<CopilotSession> ResumeSessionAsync(string sessionId, ResumeSes
{
var connection = await EnsureConnectedAsync(cancellationToken);

var hasHooks = config?.Hooks != null && (
config.Hooks.OnPreToolUse != null ||
config.Hooks.OnPostToolUse != null ||
config.Hooks.OnUserPromptSubmitted != null ||
config.Hooks.OnSessionStart != null ||
config.Hooks.OnSessionEnd != null ||
config.Hooks.OnErrorOccurred != null);

var request = new ResumeSessionRequest(
sessionId,
config?.Tools?.Select(ToolDefinition.FromAIFunction).ToList(),
config?.Provider,
config?.OnPermissionRequest != null ? true : null,
config?.OnUserInputRequest != null ? true : null,
hasHooks ? true : null,
config?.WorkingDirectory,
config?.DisableResume == true ? true : null,
config?.Streaming == true ? true : null,
config?.McpServers,
config?.CustomAgents,
Expand All @@ -419,6 +450,14 @@ public async Task<CopilotSession> ResumeSessionAsync(string sessionId, ResumeSes
{
session.RegisterPermissionHandler(config.OnPermissionRequest);
}
if (config?.OnUserInputRequest != null)
{
session.RegisterUserInputHandler(config.OnUserInputRequest);
}
if (config?.Hooks != null)
{
session.RegisterHooks(config.Hooks);
}

// Replace any existing session entry to ensure new config (like permission handler) is used
_sessions[response.SessionId] = session;
Expand Down Expand Up @@ -804,6 +843,8 @@ private async Task<Connection> ConnectToServerAsync(Process? cliProcess, string?
rpc.AddLocalRpcMethod("session.event", handler.OnSessionEvent);
rpc.AddLocalRpcMethod("tool.call", handler.OnToolCall);
rpc.AddLocalRpcMethod("permission.request", handler.OnPermissionRequest);
rpc.AddLocalRpcMethod("userInput.request", handler.OnUserInputRequest);
rpc.AddLocalRpcMethod("hooks.invoke", handler.OnHooksInvoke);
rpc.StartListening();
return new Connection(rpc, cliProcess, tcpClient, networkStream);
}
Expand Down Expand Up @@ -990,6 +1031,37 @@ public async Task<PermissionRequestResponse> OnPermissionRequest(string sessionI
});
}
}

public async Task<UserInputRequestResponse> OnUserInputRequest(string sessionId, string question, List<string>? choices = null, bool? allowFreeform = null)
{
var session = client.GetSession(sessionId);
if (session == null)
{
throw new ArgumentException($"Unknown session {sessionId}");
}

var request = new UserInputRequest
{
Question = question,
Choices = choices,
AllowFreeform = allowFreeform
};

var result = await session.HandleUserInputRequestAsync(request);
return new UserInputRequestResponse(result.Answer, result.WasFreeform);
}

public async Task<HooksInvokeResponse> OnHooksInvoke(string sessionId, string hookType, JsonElement input)
{
var session = client.GetSession(sessionId);
if (session == null)
{
throw new ArgumentException($"Unknown session {sessionId}");
}

var output = await session.HandleHooksInvokeAsync(hookType, input);
return new HooksInvokeResponse(output);
}
}

private class Connection(
Expand Down Expand Up @@ -1024,6 +1096,9 @@ internal record CreateSessionRequest(
List<string>? ExcludedTools,
ProviderConfig? Provider,
bool? RequestPermission,
bool? RequestUserInput,
bool? Hooks,
string? WorkingDirectory,
bool? Streaming,
Dictionary<string, object>? McpServers,
List<CustomAgentConfig>? CustomAgents,
Expand All @@ -1050,6 +1125,10 @@ internal record ResumeSessionRequest(
List<ToolDefinition>? Tools,
ProviderConfig? Provider,
bool? RequestPermission,
bool? RequestUserInput,
bool? Hooks,
string? WorkingDirectory,
bool? DisableResume,
bool? Streaming,
Dictionary<string, object>? McpServers,
List<CustomAgentConfig>? CustomAgents,
Expand Down Expand Up @@ -1079,6 +1158,13 @@ internal record ToolCallResponse(
internal record PermissionRequestResponse(
PermissionRequestResult Result);

internal record UserInputRequestResponse(
string Answer,
bool WasFreeform);

internal record HooksInvokeResponse(
object? Output);

/// <summary>Trace source that forwards all logs to the ILogger.</summary>
internal sealed class LoggerTraceSource : TraceSource
{
Expand Down Expand Up @@ -1131,6 +1217,7 @@ public override void WriteLine(string? message) =>
[JsonSerializable(typeof(DeleteSessionRequest))]
[JsonSerializable(typeof(DeleteSessionResponse))]
[JsonSerializable(typeof(GetLastSessionIdResponse))]
[JsonSerializable(typeof(HooksInvokeResponse))]
[JsonSerializable(typeof(ListSessionsResponse))]
[JsonSerializable(typeof(PermissionRequestResponse))]
[JsonSerializable(typeof(PermissionRequestResult))]
Expand All @@ -1143,6 +1230,9 @@ public override void WriteLine(string? message) =>
[JsonSerializable(typeof(ToolDefinition))]
[JsonSerializable(typeof(ToolResultAIContent))]
[JsonSerializable(typeof(ToolResultObject))]
[JsonSerializable(typeof(UserInputRequestResponse))]
[JsonSerializable(typeof(UserInputRequest))]
[JsonSerializable(typeof(UserInputResponse))]
internal partial class ClientJsonContext : JsonSerializerContext;
}

Expand Down
146 changes: 146 additions & 0 deletions dotnet/src/Session.cs
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,10 @@ public partial class CopilotSession : IAsyncDisposable
private readonly JsonRpc _rpc;
private PermissionHandler? _permissionHandler;
private readonly SemaphoreSlim _permissionHandlerLock = new(1, 1);
private UserInputHandler? _userInputHandler;
private readonly SemaphoreSlim _userInputHandlerLock = new(1, 1);
private SessionHooks? _hooks;
private readonly SemaphoreSlim _hooksLock = new(1, 1);

/// <summary>
/// Gets the unique identifier for this session.
Expand Down Expand Up @@ -330,6 +334,136 @@ internal async Task<PermissionRequestResult> HandlePermissionRequestAsync(JsonEl
return await handler(request, invocation);
}

/// <summary>
/// Registers a handler for user input requests from the agent.
/// </summary>
/// <param name="handler">The handler to invoke when user input is requested.</param>
internal void RegisterUserInputHandler(UserInputHandler handler)
{
_userInputHandlerLock.Wait();
try
{
_userInputHandler = handler;
}
finally
{
_userInputHandlerLock.Release();
}
}

/// <summary>
/// Handles a user input request from the Copilot CLI.
/// </summary>
/// <param name="request">The user input request from the CLI.</param>
/// <returns>A task that resolves with the user's response.</returns>
internal async Task<UserInputResponse> HandleUserInputRequestAsync(UserInputRequest request)
{
await _userInputHandlerLock.WaitAsync();
UserInputHandler? handler;
try
{
handler = _userInputHandler;
}
finally
{
_userInputHandlerLock.Release();
}

if (handler == null)
{
throw new InvalidOperationException("No user input handler registered");
}

var invocation = new UserInputInvocation
{
SessionId = SessionId
};

return await handler(request, invocation);
}

/// <summary>
/// Registers hook handlers for this session.
/// </summary>
/// <param name="hooks">The hooks configuration.</param>
internal void RegisterHooks(SessionHooks hooks)
{
_hooksLock.Wait();
try
{
_hooks = hooks;
}
finally
{
_hooksLock.Release();
}
}

/// <summary>
/// Handles a hook invocation from the Copilot CLI.
/// </summary>
/// <param name="hookType">The type of hook to invoke.</param>
/// <param name="input">The hook input data.</param>
/// <returns>A task that resolves with the hook output.</returns>
internal async Task<object?> HandleHooksInvokeAsync(string hookType, JsonElement input)
{
await _hooksLock.WaitAsync();
SessionHooks? hooks;
try
{
hooks = _hooks;
}
finally
{
_hooksLock.Release();
}

if (hooks == null)
{
return null;
}

var invocation = new HookInvocation
{
SessionId = SessionId
};

return hookType switch
{
"preToolUse" => hooks.OnPreToolUse != null
? await hooks.OnPreToolUse(
JsonSerializer.Deserialize(input.GetRawText(), SessionJsonContext.Default.PreToolUseHookInput)!,
invocation)
: null,
"postToolUse" => hooks.OnPostToolUse != null
? await hooks.OnPostToolUse(
JsonSerializer.Deserialize(input.GetRawText(), SessionJsonContext.Default.PostToolUseHookInput)!,
invocation)
: null,
"userPromptSubmitted" => hooks.OnUserPromptSubmitted != null
? await hooks.OnUserPromptSubmitted(
JsonSerializer.Deserialize(input.GetRawText(), SessionJsonContext.Default.UserPromptSubmittedHookInput)!,
invocation)
: null,
"sessionStart" => hooks.OnSessionStart != null
? await hooks.OnSessionStart(
JsonSerializer.Deserialize(input.GetRawText(), SessionJsonContext.Default.SessionStartHookInput)!,
invocation)
: null,
"sessionEnd" => hooks.OnSessionEnd != null
? await hooks.OnSessionEnd(
JsonSerializer.Deserialize(input.GetRawText(), SessionJsonContext.Default.SessionEndHookInput)!,
invocation)
: null,
"errorOccurred" => hooks.OnErrorOccurred != null
? await hooks.OnErrorOccurred(
JsonSerializer.Deserialize(input.GetRawText(), SessionJsonContext.Default.ErrorOccurredHookInput)!,
invocation)
: null,
_ => throw new ArgumentException($"Unknown hook type: {hookType}")
};
}

/// <summary>
/// Gets the complete list of messages and events in the session.
/// </summary>
Expand Down Expand Up @@ -487,5 +621,17 @@ internal record SessionDestroyRequest
[JsonSerializable(typeof(SessionAbortRequest))]
[JsonSerializable(typeof(SessionDestroyRequest))]
[JsonSerializable(typeof(UserMessageDataAttachmentsItem))]
[JsonSerializable(typeof(PreToolUseHookInput))]
[JsonSerializable(typeof(PreToolUseHookOutput))]
[JsonSerializable(typeof(PostToolUseHookInput))]
[JsonSerializable(typeof(PostToolUseHookOutput))]
[JsonSerializable(typeof(UserPromptSubmittedHookInput))]
[JsonSerializable(typeof(UserPromptSubmittedHookOutput))]
[JsonSerializable(typeof(SessionStartHookInput))]
[JsonSerializable(typeof(SessionStartHookOutput))]
[JsonSerializable(typeof(SessionEndHookInput))]
[JsonSerializable(typeof(SessionEndHookOutput))]
[JsonSerializable(typeof(ErrorOccurredHookInput))]
[JsonSerializable(typeof(ErrorOccurredHookOutput))]
internal partial class SessionJsonContext : JsonSerializerContext;
}
Loading
Loading