@@ -8,10 +8,12 @@ namespace ModelContextProtocol.Client;
88/// <summary>Provides the client side of a stream-based session transport.</summary>
99internal class StreamClientSessionTransport : TransportBase
1010{
11+ private static readonly byte [ ] s_newlineBytes = "\n "u8 . ToArray ( ) ;
12+
1113 internal static UTF8Encoding NoBomUtf8Encoding { get ; } = new ( encoderShouldEmitUTF8Identifier : false ) ;
1214
1315 private readonly TextReader _serverOutput ;
14- private readonly TextWriter _serverInput ;
16+ private readonly Stream _serverInputStream ;
1517 private readonly SemaphoreSlim _sendLock = new ( 1 , 1 ) ;
1618 private CancellationTokenSource ? _shutdownCts = new ( ) ;
1719 private Task ? _readTask ;
@@ -20,12 +22,13 @@ internal class StreamClientSessionTransport : TransportBase
2022 /// Initializes a new instance of the <see cref="StreamClientSessionTransport"/> class.
2123 /// </summary>
2224 /// <param name="serverInput">
23- /// The text writer connected to the server's input stream.
24- /// Messages written to this writer will be sent to the server.
25+ /// The server's input stream. Messages written to this stream will be sent to the server.
2526 /// </param>
2627 /// <param name="serverOutput">
27- /// The text reader connected to the server's output stream.
28- /// Messages read from this reader will be received from the server.
28+ /// The server's output stream. Messages read from this stream will be received from the server.
29+ /// </param>
30+ /// <param name="encoding">
31+ /// The encoding used for reading and writing messages from the input and output streams. Defaults to UTF-8 without BOM if null.
2932 /// </param>
3033 /// <param name="endpointName">
3134 /// A name that identifies this transport endpoint in logs.
@@ -37,12 +40,18 @@ internal class StreamClientSessionTransport : TransportBase
3740 /// This constructor starts a background task to read messages from the server output stream.
3841 /// The transport will be marked as connected once initialized.
3942 /// </remarks>
40- public StreamClientSessionTransport (
41- TextWriter serverInput , TextReader serverOutput , string endpointName , ILoggerFactory ? loggerFactory )
43+ public StreamClientSessionTransport ( Stream serverInput , Stream serverOutput , Encoding ? encoding , string endpointName , ILoggerFactory ? loggerFactory )
4244 : base ( endpointName , loggerFactory )
4345 {
44- _serverOutput = serverOutput ;
45- _serverInput = serverInput ;
46+ Throw . IfNull ( serverInput ) ;
47+ Throw . IfNull ( serverOutput ) ;
48+
49+ _serverInputStream = serverInput ;
50+ #if NET
51+ _serverOutput = new StreamReader ( serverOutput , encoding ?? NoBomUtf8Encoding ) ;
52+ #else
53+ _serverOutput = new CancellableStreamReader ( serverOutput , encoding ?? NoBomUtf8Encoding ) ;
54+ #endif
4655
4756 SetConnected ( ) ;
4857
@@ -57,43 +66,6 @@ public StreamClientSessionTransport(
5766 readTask . Start ( ) ;
5867 }
5968
60- /// <summary>
61- /// Initializes a new instance of the <see cref="StreamClientSessionTransport"/> class.
62- /// </summary>
63- /// <param name="serverInput">
64- /// The server's input stream. Messages written to this stream will be sent to the server.
65- /// </param>
66- /// <param name="serverOutput">
67- /// The server's output stream. Messages read from this stream will be received from the server.
68- /// </param>
69- /// <param name="encoding">
70- /// The encoding used for reading and writing messages from the input and output streams. Defaults to UTF-8 without BOM if null.
71- /// </param>
72- /// <param name="endpointName">
73- /// A name that identifies this transport endpoint in logs.
74- /// </param>
75- /// <param name="loggerFactory">
76- /// Optional factory for creating loggers. If null, a NullLogger is used.
77- /// </param>
78- /// <remarks>
79- /// This constructor starts a background task to read messages from the server output stream.
80- /// The transport will be marked as connected once initialized.
81- /// </remarks>
82- public StreamClientSessionTransport ( Stream serverInput , Stream serverOutput , Encoding ? encoding , string endpointName , ILoggerFactory ? loggerFactory )
83- : this (
84- new StreamWriter ( serverInput , encoding ?? NoBomUtf8Encoding ) ,
85- #if NET
86- new StreamReader ( serverOutput , encoding ?? NoBomUtf8Encoding ) ,
87- #else
88- new CancellableStreamReader ( serverOutput , encoding ?? NoBomUtf8Encoding ) ,
89- #endif
90- endpointName,
91- loggerFactory )
92- {
93- Throw . IfNull ( serverInput ) ;
94- Throw . IfNull ( serverOutput ) ;
95- }
96-
9769 /// <inheritdoc/>
9870 public override async Task SendMessageAsync ( JsonRpcMessage message , CancellationToken cancellationToken = default )
9971 {
@@ -103,16 +75,15 @@ public override async Task SendMessageAsync(JsonRpcMessage message, Cancellation
10375 id = messageWithId . Id . ToString ( ) ;
10476 }
10577
106- var json = JsonSerializer . Serialize ( message , McpJsonUtilities . JsonContext . Default . JsonRpcMessage ) ;
107-
108- LogTransportSendingMessageSensitive ( Name , json ) ;
78+ LogTransportSendingMessageSensitive ( message ) ;
10979
11080 using var _ = await _sendLock . LockAsync ( cancellationToken ) . ConfigureAwait ( false ) ;
11181 try
11282 {
113- // Write the message followed by a newline using our UTF-8 writer
114- await _serverInput . WriteLineAsync ( json ) . ConfigureAwait ( false ) ;
115- await _serverInput . FlushAsync ( cancellationToken ) . ConfigureAwait ( false ) ;
83+ var json = JsonSerializer . SerializeToUtf8Bytes ( message , McpJsonUtilities . JsonContext . Default . JsonRpcMessage ) ;
84+ await _serverInputStream . WriteAsync ( json , cancellationToken ) . ConfigureAwait ( false ) ;
85+ await _serverInputStream . WriteAsync ( s_newlineBytes , cancellationToken ) . ConfigureAwait ( false ) ;
86+ await _serverInputStream . FlushAsync ( cancellationToken ) . ConfigureAwait ( false ) ;
11687 }
11788 catch ( Exception ex )
11889 {
0 commit comments