@@ -1,8 +1,10 @@
+import asyncio
 import logging
+from typing import Any, Dict, List, Optional

 import cohere
 from abcs.llm import LLM
-from abcs.models import PromptResponse, UsageStats
+from abcs.models import PromptResponse, StreamingPromptResponse, UsageStats
 from tools.tool_manager import ToolManager

 logging.basicConfig(level=logging.INFO)
@@ -121,3 +123,75 @@ def _translate_response(self, response) -> PromptResponse: |
             )
             raise e

+    # https://github.com/cohere-ai/cohere-python/blob/main/src/cohere/types/streamed_chat_response.py
+    # https://docs.cohere.com/docs/streaming#stream-events
+    # https://docs.cohere.com/docs/streaming#example-responses
+    async def generate_text_stream(
+        self,
+        prompt: str,
+        past_messages: List[Dict[str, str]],
+        tools: Optional[List[Dict[str, Any]]] = None,
+        **kwargs,
+    ) -> StreamingPromptResponse:
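+        """Stream a chat response from Cohere as an async generator of text.
+
+        Illustrative usage (caller-side names are assumed, not part of this
+        commit):
+
+            resp = await llm.generate_text_stream("Hi there", past_messages=[])
+            async for chunk in resp.content:
+                print(chunk, end="")
+        """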
+        try:
+            # Translate the generic history into Cohere's chat format,
+            # which uses "USER"/"CHATBOT" roles and a "message" key.
+            combined_history = []
+            for msg in past_messages:
+                combined_history.append({
+                    "role": "CHATBOT" if msg["role"] == "assistant" else "USER",
+                    "message": msg["content"],
+                })
+            stream = self.client.chat_stream(
+                chat_history=combined_history,
+                message=prompt,
+                tools=tools,
+                model=self.model,
+                # Optionally perform a web search before answering; a custom
+                # connector can be used instead.
+                # connectors=[{"id": "web-search"}],
+            )
+
+            async def content_generator():
+                # NOTE: the sync client's chat_stream yields a blocking
+                # iterator; sleep(0) below lets other tasks run between events.
+                for event in stream:
+                    if isinstance(event, cohere.types.StreamedChatResponse_StreamStart):
+                        # Stream start event; nothing to emit.
+                        pass
+                    elif isinstance(event, cohere.types.StreamedChatResponse_TextGeneration):
+                        # This event carries the actual generated text.
+                        if event.text:
+                            yield event.text
+                    elif isinstance(event, cohere.types.StreamedChatResponse_ToolCallsGeneration):
+                        # TODO: dispatch to handle_tool_call once implemented.
+                        pass
+                    elif isinstance(event, cohere.types.StreamedChatResponse_CitationGeneration):
+                        # TODO: citations for grounded responses; not handled yet.
+                        pass
+                    elif isinstance(event, cohere.types.StreamedChatResponse_ToolCallsChunk):
+                        # TODO: streamed tool-call deltas; not handled yet.
+                        pass
+                    elif isinstance(event, cohere.types.StreamedChatResponse_SearchQueriesGeneration):
+                        # Search queries generated for connectors; not handled yet.
+                        pass
+                    elif isinstance(event, cohere.types.StreamedChatResponse_SearchResults):
+                        # Search results from connectors; not handled yet.
+                        pass
+                    elif isinstance(event, cohere.types.StreamedChatResponse_StreamEnd):
+                        # Stream end; event.response holds the final response,
+                        # including token counts that could populate UsageStats.
+                        pass
+                    # Yield control to allow cooperative multitasking.
+                    await asyncio.sleep(0)
+
+            return StreamingPromptResponse(
+                content=content_generator(),
+                raw_response=stream,
+                error={},
+                usage=UsageStats(
+                    input_tokens=0,  # These will need to be updated after streaming
+                    output_tokens=0,
+                    extra={},
+                ),
+            )
+        except Exception as e:
+            logger.exception(f"An error occurred while streaming from Cohere: {e}")
+            raise e
+
+    async def handle_tool_call(self, tool_calls, combined_history, tools):
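+        # TODO: execute the requested tools (e.g., via the imported ToolManager)
+        # and feed the results back into combined_history before resuming.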
+        pass