Skip to content

Commit a2edc20

Browse files
committed
code style and api
1 parent 5ccddc3 commit a2edc20

File tree

12 files changed

+405
-89
lines changed

12 files changed

+405
-89
lines changed

Together.Tests/HttpCallsTests.cs

Lines changed: 31 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
using Together.Models.Completions;
66
using Together.Models.Embeddings;
77
using Together.Models.Images;
8+
using Together.Models.Rerank;
89
using ChatMessage = Microsoft.Extensions.AI.ChatMessage;
910

1011
namespace Together.Tests;
@@ -19,6 +20,7 @@ private HttpClient CreateHttpClient()
1920
httpClient.Timeout = TimeSpan.FromSeconds(TogetherConstants.TIMEOUT_SECS);
2021
httpClient.BaseAddress = new Uri(TogetherConstants.BASE_URL);
2122
httpClient.DefaultRequestHeaders.Authorization = new AuthenticationHeaderValue("Bearer", API_KEY);
23+
httpClient.DefaultRequestHeaders.Accept.Add(new MediaTypeWithQualityHeaderValue("application/json"));
2224
return httpClient;
2325
}
2426

@@ -28,7 +30,7 @@ public async Task CompletionTest()
2830
var client = new TogetherClient(CreateHttpClient());
2931

3032

31-
var responseAsync = await client.GetCompletionResponseAsync(new CompletionRequest()
33+
var responseAsync = await client.Completions.CreateAsync(new CompletionRequest()
3234
{
3335
Prompt = "Hi",
3436
Model = "meta-llama/Meta-Llama-3-70B-Instruct-Turbo",
@@ -43,7 +45,7 @@ public async Task ChatCompletionTest()
4345
{
4446
var client = new TogetherClient(CreateHttpClient());
4547

46-
var responseAsync = await client.GetChatCompletionResponseAsync(new ChatCompletionRequest
48+
var responseAsync = await client.ChatCompletions.CreateAsync(new ChatCompletionRequest
4749
{
4850
Messages = new List<ChatCompletionMessage>()
4951
{
@@ -65,7 +67,7 @@ public async Task StreamChatCompletionTest()
6567
{
6668
var client = new TogetherClient(CreateHttpClient());
6769

68-
var responseAsync = await client.GetStreamChatCompletionResponseAsync(new ChatCompletionRequest
70+
var responseAsync = await client.ChatCompletions.CreateStreamAsync(new ChatCompletionRequest
6971
{
7072
Messages = new List<ChatCompletionMessage>()
7173
{
@@ -90,7 +92,7 @@ public async Task EmbeddingTest()
9092
{
9193
var client = new TogetherClient(CreateHttpClient());
9294

93-
var responseAsync = await client.GetEmbeddingResponseAsync(new EmbeddingRequest()
95+
var responseAsync = await client.Embeddings.CreateAsync(new EmbeddingRequest()
9496
{
9597
Input = "Hi",
9698
Model = "togethercomputer/m2-bert-80M-2k-retrieval",
@@ -104,7 +106,7 @@ public async Task ImageTest()
104106
{
105107
var client = new TogetherClient(CreateHttpClient());
106108

107-
var responseAsync = await client.GetImageResponseAsync(new ImageRequest()
109+
var responseAsync = await client.Images.GenerateAsync(new ImageRequest()
108110
{
109111
Model = "black-forest-labs/FLUX.1-dev",
110112
Prompt = "Cats eating popcorn",
@@ -117,14 +119,37 @@ public async Task ImageTest()
117119
Assert.NotEmpty(responseAsync.Data.First().Url);
118120
}
119121

122+
[Fact]
123+
public async Task ModelsTest()
124+
{
125+
var client = new TogetherClient(CreateHttpClient());
126+
127+
var responseAsync = await client.Models.ListModelsAsync();
128+
129+
Assert.NotEmpty(responseAsync);
130+
}
131+
132+
[Fact]
133+
public async Task RerankTest()
134+
{
135+
var client = new TogetherClient(CreateHttpClient());
136+
137+
var responseAsync = await client.Rerank.CreateAsync(new RerankRequest()
138+
{
139+
140+
});
141+
142+
Assert.NotEmpty(responseAsync.Results);
143+
}
144+
120145
[Fact]
121146
public async Task WrongModelTest()
122147
{
123148
var client = new TogetherClient(CreateHttpClient());
124149

125150
await Assert.ThrowsAsync<Exception>(async () =>
126151
{
127-
var responseAsync = await client.GetImageResponseAsync(new ImageRequest()
152+
var responseAsync = await client.Images.GenerateAsync(new ImageRequest()
128153
{
129154
Model = "Wring-Model",
130155
Prompt = "so wrong",

Together/Clients/BaseClient.cs

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
using System.Net.Http.Json;
2+
using Together.Models.Error;
3+
4+
namespace Together.Clients;
5+
6+
public abstract class BaseClient
7+
{
8+
protected readonly HttpClient HttpClient;
9+
10+
protected BaseClient(HttpClient httpClient)
11+
{
12+
HttpClient = httpClient;
13+
}
14+
15+
protected async Task<TResponse> SendRequestAsync<TRequest, TResponse>(string requestUri, TRequest request, CancellationToken cancellationToken)
16+
{
17+
var responseMessage = await HttpClient.PostAsJsonAsync(requestUri, request, cancellationToken);
18+
return await HandleResponseAsync<TResponse>(responseMessage, cancellationToken);
19+
}
20+
21+
protected async Task<TResponse> SendRequestAsync<TResponse>(string requestUri, HttpMethod method, HttpContent? content, CancellationToken cancellationToken)
22+
{
23+
using var request = new HttpRequestMessage(method, requestUri);
24+
if (content != null)
25+
{
26+
request.Content = content;
27+
}
28+
29+
var responseMessage = await HttpClient.SendAsync(request, cancellationToken);
30+
return await HandleResponseAsync<TResponse>(responseMessage, cancellationToken);
31+
}
32+
33+
private static async Task<TResponse> HandleResponseAsync<TResponse>(HttpResponseMessage responseMessage, CancellationToken cancellationToken)
34+
{
35+
if (responseMessage.IsSuccessStatusCode)
36+
{
37+
if (typeof(TResponse) == typeof(HttpResponseMessage) && responseMessage is TResponse response)
38+
{
39+
return response;
40+
}
41+
42+
var result = await responseMessage.Content.ReadFromJsonAsync<TResponse>(cancellationToken: cancellationToken);
43+
return result!;
44+
}
45+
46+
var errorResponse = await responseMessage.Content.ReadFromJsonAsync<ErrorResponse>(cancellationToken: cancellationToken);
47+
if (errorResponse?.Error != null)
48+
{
49+
throw new Exception(errorResponse.Error.Message);
50+
}
51+
52+
var statusCode = responseMessage.StatusCode;
53+
var errorContent = await responseMessage.Content.ReadAsStringAsync(cancellationToken);
54+
throw new Exception($"Request failed with status code {statusCode}: {errorContent}");
55+
}
56+
}
Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
using System.Runtime.CompilerServices;
2+
using System.Text.Json;
3+
using Together.Models.ChatCompletions;
4+
5+
namespace Together.Clients;
6+
7+
public class ChatCompletionClient(HttpClient httpClient) : BaseClient(httpClient)
8+
{
9+
public async Task<ChatCompletionResponse> CreateAsync(ChatCompletionRequest request,
10+
CancellationToken cancellationToken = default)
11+
{
12+
return await SendRequestAsync<ChatCompletionRequest, ChatCompletionResponse>("/chat/completions", request, cancellationToken);
13+
}
14+
15+
public async IAsyncEnumerable<ChatCompletionChunk> CreateStreamAsync(ChatCompletionRequest request,
16+
[EnumeratorCancellation] CancellationToken cancellationToken = default)
17+
{
18+
var responseMessage = await SendRequestAsync<ChatCompletionRequest, HttpResponseMessage>("/chat/completions", request, cancellationToken);
19+
20+
await using var stream = await responseMessage.Content.ReadAsStreamAsync(cancellationToken);
21+
using var reader = new StreamReader(stream);
22+
23+
while (await reader.ReadLineAsync(cancellationToken) is string line)
24+
{
25+
if (!line.StartsWith("data:"))
26+
continue;
27+
28+
var eventData = line.Substring("data:".Length)
29+
.Trim();
30+
if (eventData is null or "[DONE]")
31+
break;
32+
33+
var result = JsonSerializer.Deserialize<ChatCompletionChunk>(eventData);
34+
35+
if (result is not null)
36+
yield return result;
37+
}
38+
}
39+
40+
}
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
using System.Net.Http.Json;
2+
using Together.Models.Completions;
3+
4+
namespace Together.Clients;
5+
6+
public class CompletionClient(HttpClient httpClient) : BaseClient(httpClient)
7+
{
8+
9+
10+
public async Task<CompletionResponse> CreateAsync(CompletionRequest request, CancellationToken cancellationToken = default)
11+
{
12+
return await SendRequestAsync<CompletionRequest, CompletionResponse>("/completions", request, cancellationToken);
13+
}
14+
15+
16+
17+
}
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
using Together.Models.Embeddings;
2+
3+
namespace Together.Clients;
4+
5+
public class EmbeddingClient(HttpClient httpClient) : BaseClient(httpClient)
6+
{
7+
8+
9+
public async Task<EmbeddingResponse> CreateAsync(EmbeddingRequest request, CancellationToken cancellationToken = default)
10+
{
11+
return await SendRequestAsync<EmbeddingRequest, EmbeddingResponse>("/embeddings", request, cancellationToken);
12+
}
13+
14+
15+
}

Together/Clients/FileClient.cs

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
using System.Net.Http.Headers;
2+
using Together.Models.Files;
3+
4+
namespace Together.Clients;
5+
6+
public class FileClient(HttpClient httpClient) : BaseClient(httpClient)
7+
{
8+
public async Task<FileResponse> UploadAsync(
9+
string filePath,
10+
FilePurpose? purpose = null,
11+
bool checkFile = true,
12+
CancellationToken cancellationToken = default)
13+
{
14+
purpose ??= FilePurpose.FineTune;
15+
16+
if (checkFile && !File.Exists(filePath))
17+
{
18+
throw new FileNotFoundException("File not found", filePath);
19+
}
20+
21+
using var form = new MultipartFormDataContent();
22+
using var fileStream = File.OpenRead(filePath);
23+
using var content = new StreamContent(fileStream);
24+
25+
content.Headers.ContentType = new MediaTypeHeaderValue("application/octet-stream");
26+
form.Add(content, "file", Path.GetFileName(filePath));
27+
form.Add(new StringContent(purpose.ToString().ToLowerInvariant()), "purpose");
28+
29+
return await SendRequestAsync<FileResponse>("/files", HttpMethod.Post, form, cancellationToken);
30+
}
31+
32+
public async Task<FileList> ListAsync(CancellationToken cancellationToken = default)
33+
{
34+
return await SendRequestAsync<FileList>("/files", HttpMethod.Get, null, cancellationToken);
35+
}
36+
37+
public async Task<FileResponse> RetrieveAsync(string fileId, CancellationToken cancellationToken = default)
38+
{
39+
return await SendRequestAsync<FileResponse>($"/files/{fileId}", HttpMethod.Get, null, cancellationToken);
40+
}
41+
42+
public async Task<FileObject> RetrieveContentAsync(string fileId, string? outputPath = null, CancellationToken cancellationToken = default)
43+
{
44+
var fileName = outputPath ?? NormalizeKey($"{fileId}.jsonl");
45+
var response = await HttpClient.GetAsync($"/files/{fileId}/content", cancellationToken);
46+
response.EnsureSuccessStatusCode();
47+
48+
await using var fs = File.Create(fileName);
49+
await response.Content.CopyToAsync(fs, cancellationToken);
50+
51+
var fileInfo = new FileInfo(fileName);
52+
return new FileObject
53+
{
54+
Object = "local",
55+
Id = fileId,
56+
Filename = fileName,
57+
Size = (int)fileInfo.Length
58+
};
59+
}
60+
61+
public async Task<FileDeleteResponse> DeleteAsync(string fileId, CancellationToken cancellationToken = default)
62+
{
63+
return await SendRequestAsync<FileDeleteResponse>($"/files/{fileId}", HttpMethod.Delete, null, cancellationToken);
64+
}
65+
66+
private static string NormalizeKey(string key) => string.Join("_", key.Split(Path.GetInvalidFileNameChars()));
67+
}

0 commit comments

Comments
 (0)