Skip to content

Commit b3e3bf6

Browse files
authored
make client immutable (#94)
Signed-off-by: kerthcet <kerthcet@gmail.com>
1 parent a7330d0 commit b3e3bf6

7 files changed

Lines changed: 48 additions & 45 deletions

File tree

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ fn main() {
6767
.build()
6868
.unwrap();
6969

70-
let mut client = client::Client::new(config);
70+
let client = client::Client::new(config);
7171
let request = chat::CreateChatCompletionRequestArgs::default()
7272
.messages([
7373
chat::ChatCompletionRequestSystemMessage::from("You are a helpful assistant.").into(),

examples/wrr.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ fn main() {
2323
.build()
2424
.unwrap();
2525

26-
let mut client = client::Client::new(config);
26+
let client = client::Client::new(config);
2727
let request = chat::CreateChatCompletionRequestArgs::default()
2828
.messages([
2929
chat::ChatCompletionRequestSystemMessage::from("You are a helpful assistant.").into(),

src/client/client.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ impl Client {
2929
}
3030

3131
pub async fn create_response(
32-
&mut self,
32+
&self,
3333
request: responses::CreateResponse,
3434
) -> Result<responses::Response, OpenAIError> {
3535
let candidate = self.router.sample();
@@ -39,7 +39,7 @@ impl Client {
3939

4040
// This is chat completion endpoint.
4141
pub async fn create_completion(
42-
&mut self,
42+
&self,
4343
request: chat::CreateChatCompletionRequest,
4444
) -> Result<chat::CreateChatCompletionResponse, OpenAIError> {
4545
let candidate = self.router.sample();

src/router/random.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ impl Router for RandomRouter {
1818
"RandomRouter"
1919
}
2020

21-
fn sample(&mut self) -> ModelName {
21+
fn sample(&self) -> ModelName {
2222
let mut rng = rand::rng();
2323
let idx = rng.random_range(0..self.model_infos.len());
2424
self.model_infos[idx].name.clone()

src/router/router.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ pub fn construct_router(mode: RouterMode, models: Vec<ModelConfig>) -> Box<dyn R
2424

2525
pub trait Router {
2626
fn name(&self) -> &'static str;
27-
fn sample(&mut self) -> ModelName;
27+
fn sample(&self) -> ModelName;
2828
}
2929

3030
#[cfg(test)]

src/router/wrr.rs

Lines changed: 14 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,13 @@
1+
use std::sync::atomic::AtomicI32;
2+
13
use crate::client::config::ModelName;
24
use crate::router::router::{ModelInfo, Router};
35

46
pub struct WeightedRoundRobinRouter {
57
total_weight: i32,
68
model_infos: Vec<ModelInfo>,
79
// current_weight is ordered by model_infos index.
8-
current_weights: Vec<i32>,
10+
current_weights: Vec<AtomicI32>,
911
}
1012

1113
impl WeightedRoundRobinRouter {
@@ -16,7 +18,7 @@ impl WeightedRoundRobinRouter {
1618
Self {
1719
model_infos: model_infos,
1820
total_weight: total_weight,
19-
current_weights: vec![0; length],
21+
current_weights: (0..length).map(|_| AtomicI32::new(0)).collect(),
2022
}
2123
}
2224
}
@@ -27,27 +29,28 @@ impl Router for WeightedRoundRobinRouter {
2729
}
2830

2931
// Use Smooth Weighted Round Robin Algorithm.
30-
fn sample(&mut self) -> ModelName {
32+
fn sample(&self) -> ModelName {
3133
// return early if only one model.
3234
if self.model_infos.len() == 1 {
3335
return self.model_infos[0].name.clone();
3436
}
3537

36-
self.current_weights
37-
.iter_mut()
38-
.enumerate()
39-
.for_each(|(i, weight)| {
40-
*weight += self.model_infos[i].weight;
41-
});
38+
// 1. add weight to current weight.
39+
self.model_infos.iter().enumerate().for_each(|(i, weight)| {
40+
self.current_weights[i].fetch_add(weight.weight, std::sync::atomic::Ordering::Relaxed);
41+
});
4242

4343
let mut max_index = 0;
4444
for i in 1..self.current_weights.len() {
45-
if self.current_weights[i] > self.current_weights[max_index] {
45+
if self.current_weights[i].load(std::sync::atomic::Ordering::Relaxed)
46+
> self.current_weights[max_index].load(std::sync::atomic::Ordering::Relaxed)
47+
{
4648
max_index = i;
4749
}
4850
}
4951

50-
self.current_weights[max_index] -= self.total_weight;
52+
self.current_weights[max_index]
53+
.fetch_sub(self.total_weight, std::sync::atomic::Ordering::Relaxed);
5154
self.model_infos[max_index].name.clone()
5255
}
5356
}

tests/client.rs

Lines changed: 28 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,31 @@ use arms::types::responses;
88
mod tests {
99
use super::*;
1010

11+
#[tokio::test]
12+
async fn test_completion() {
13+
from_filename(".env.integration-test").ok();
14+
15+
let config = client::Config::builder()
16+
.provider("faker")
17+
.model(
18+
client::ModelConfig::builder()
19+
.name("fake-completion-model")
20+
.build()
21+
.unwrap(),
22+
)
23+
.build()
24+
.unwrap();
25+
26+
let client = client::Client::new(config);
27+
let request = chat::CreateChatCompletionRequestArgs::default()
28+
.build()
29+
.unwrap();
30+
31+
let response = client.create_completion(request).await.unwrap();
32+
assert!(response.id.starts_with("fake-completion-id"));
33+
assert!(response.model == "fake-completion-model");
34+
}
35+
1136
#[tokio::test]
1237
async fn test_response() {
1338
from_filename(".env.integration-test").ok();
@@ -24,7 +49,7 @@ mod tests {
2449
.build()
2550
.unwrap();
2651

27-
let mut client = client::Client::new(config);
52+
let client = client::Client::new(config);
2853
let request = responses::CreateResponseArgs::default()
2954
.input("tell me the weather today")
3055
.build()
@@ -45,7 +70,7 @@ mod tests {
4570
)
4671
.build()
4772
.unwrap();
48-
let mut client = client::Client::new(config);
73+
let client = client::Client::new(config);
4974
let request = responses::CreateResponseArgs::default()
5075
.model("gpt-3.5-turbo")
5176
.input("tell me a joke")
@@ -74,36 +99,11 @@ mod tests {
7499
)
75100
.build()
76101
.unwrap();
77-
let mut client = client::Client::new(config);
102+
let client = client::Client::new(config);
78103
let request = responses::CreateResponseArgs::default()
79104
.input("give me a poem about nature")
80105
.build()
81106
.unwrap();
82107
let _ = client.create_response(request).await.unwrap();
83108
}
84-
85-
#[tokio::test]
86-
async fn test_completion() {
87-
from_filename(".env.integration-test").ok();
88-
89-
let config = client::Config::builder()
90-
.provider("faker")
91-
.model(
92-
client::ModelConfig::builder()
93-
.name("fake-completion-model")
94-
.build()
95-
.unwrap(),
96-
)
97-
.build()
98-
.unwrap();
99-
100-
let mut client = client::Client::new(config);
101-
let request = chat::CreateChatCompletionRequestArgs::default()
102-
.build()
103-
.unwrap();
104-
105-
let response = client.create_completion(request).await.unwrap();
106-
assert!(response.id.starts_with("fake-completion-id"));
107-
assert!(response.model == "fake-completion-model");
108-
}
109109
}

0 commit comments

Comments
 (0)