-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathopenai_client.py
More file actions
94 lines (81 loc) · 3.76 KB
/
openai_client.py
File metadata and controls
94 lines (81 loc) · 3.76 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
from openai import OpenAI
from config import Config
from prompt import Prompt
from chat_history import ChatHistory
from db import TranslationRecord
from dataclasses import dataclass, field
@dataclass
class LLMClient:
    """
    A client for interacting with a language model via the OpenAI API.

    Args:
        config (Config): Configuration settings for the client. ``openai_config``
            supplies the base URL, API key and model name; ``model_config``
            supplies temperature, frequency penalty and presence penalty.
        prompt (Prompt): The prompt used to seed the chat history with the
            system prompt and the task template.
        chat_history (ChatHistory): The conversation sent to the model on each
            request. Defaults to a new ``ChatHistory()``.

    Attributes:
        client (OpenAI): The OpenAI SDK client. It is created in
            ``__post_init__`` and is not accepted by ``__init__``.
    """
    config: Config
    # Built in __post_init__ from config.openai_config; excluded from __init__.
    client: OpenAI = field(init=False)
    prompt: Prompt
    chat_history: ChatHistory = field(default_factory=ChatHistory)

    @classmethod
    def from_config(cls, config: Config) -> "LLMClient":
        """
        Create an LLMClient whose prompt is taken from ``config.prompt``.

        Args:
            config (Config): Configuration settings for the client. Its
                ``prompt`` attribute is used as the client's prompt.

        Returns:
            LLMClient: A new client using the default chat history.
        """
        return cls(config=config, prompt=config.prompt)

    def __post_init__(self):
        """
        Create the OpenAI client and seed the chat history with the system
        prompt and the task template.
        """
        self.client = OpenAI(base_url=self.config.openai_config.base_url,
                             api_key=self.config.openai_config.api_key)
        # NOTE(review): unlike reset_history, this always sets the system prompt
        # (it does not check prompt.system_prompt.use_system_prompt) and uses the
        # set_* rather than add_* ChatHistory methods. Confirm this is intended.
        self.chat_history.set_system_prompt(self.prompt.system_prompt.system_prompt)
        self.chat_history.set_task_prompt(self.prompt.template.task_template)

    def request_completion(self):
        """
        Request a completion from the language model for the current chat history.

        The model name comes from ``config.openai_config``. Temperature,
        frequency penalty and presence penalty come from ``config.model_config``.

        Returns:
            str: The content of the first choice's message. The OpenAI SDK types
                this field as optional, so it may be None.
            Exception: If the API call raises, the error is printed and the
                exception object is *returned* (not raised). Callers must
                check the type of the result.
        """
        try:
            completion = self.client.chat.completions.create(
                model=self.config.openai_config.model_name,
                messages=self.chat_history.chat_history,
                temperature=self.config.model_config.temperature,
                frequency_penalty=self.config.model_config.frequency_penalty,
                presence_penalty=self.config.model_config.presence_penalty,
            )
        except Exception as e:
            # Best-effort contract kept for existing callers: report the error
            # and hand it back instead of propagating it.
            print(f"Error in OpenAI completion. Error: {e}")
            return e
        return completion.choices[0].message.content

    def reset_history(self):
        """
        Reset the chat history and re-seed it from the prompt.

        After clearing, the steps are:
        - Add the system prompt, but only if
          ``prompt.system_prompt.use_system_prompt`` is truthy.
        - Add the task template as user content.
        - Clear the source and target languages (set them to "").
        """
        self.chat_history.reset_history()
        if self.prompt.system_prompt.use_system_prompt:
            self.chat_history.add_system_prompt(self.prompt.system_prompt.system_prompt)
        self.chat_history.add_user_content(self.prompt.template.task_template)
        self.chat_history.set_src_lang("")
        self.chat_history.set_tgt_lang("")

    def apply_latest_translations(self, records: list[TranslationRecord]):
        """
        Append previous translations to the chat history as user/assistant turns.

        Records are replayed in the reverse of the order given. If the caller
        passes them newest-first, they are added oldest-first; that ordering
        is presumed, not confirmed. The caller's list is not modified.

        Args:
            records: Translation records. Each one contributes its ``src_text``
                as a user message followed by its ``tgt_text`` as an
                assistant message.
        """
        # Iterate a reversed view rather than calling records.reverse(). The old
        # code mutated the caller's list and flipped it again on every call.
        for record in reversed(records):
            self.chat_history.add_user_content(record.src_text)
            self.chat_history.add_assistant_content(record.tgt_text)

    def set_language_targets(self, src_lang: str, tgt_lang: str):
        """
        Record the source and target languages on the chat history.

        If ``prompt.template.specify_language`` is truthy, this also appends the
        prompt built by ``get_language_target_prompt(src_lang, tgt_lang)`` as
        user content.

        Args:
            src_lang: Source language.
            tgt_lang: Target language.
        """
        self.chat_history.set_src_lang(src_lang)
        self.chat_history.set_tgt_lang(tgt_lang)
        if self.prompt.template.specify_language:
            self.chat_history.add_user_content(self.prompt.template.get_language_target_prompt(src_lang, tgt_lang))
# NOTE(review): this class is defined after LLMClient and rebinds the module
# name `Prompt`, which was imported from the `prompt` module at the top of this
# file. Module-level code that runs after this point gets this class instead,
# and so does `from openai_client import Prompt`. Its fields also differ from
# what LLMClient reads: LLMClient uses `prompt.system_prompt.system_prompt` and
# `prompt.template.task_template`. Confirm whether this class is still needed.
@dataclass
class Prompt:
    """Container for a system prompt string and a task prompt string."""
    system_prompt: str  # system prompt text
    task_prompt: str  # task prompt text