diff --git a/willa/chatbot/graph_manager.py b/willa/chatbot/graph_manager.py
index 13dea78..8c948d3 100644
--- a/willa/chatbot/graph_manager.py
+++ b/willa/chatbot/graph_manager.py
@@ -1,4 +1,5 @@
 """Manages the shared state and workflow for Willa chatbots."""
+import re
 from typing import Any, Optional, Annotated, NotRequired
 
 from typing_extensions import TypedDict
@@ -19,10 +20,10 @@ class WillaChatbotState(TypedDict):
     messages: Annotated[list[AnyMessage], add_messages]
     filtered_messages: NotRequired[list[AnyMessage]]
    summarized_messages: NotRequired[list[AnyMessage]]
-    docs_context: NotRequired[str]
     search_query: NotRequired[str]
     tind_metadata: NotRequired[str]
-    context: NotRequired[dict[str, Any]]
+    documents: NotRequired[list[Any]]
+    citations: NotRequired[list[dict[str, Any]]]
 
 
 class GraphManager:  # pylint: disable=too-few-public-methods
@@ -51,13 +52,15 @@ def _create_workflow(self) -> CompiledStateGraph:
         workflow.add_node("summarize", summarization_node)
         workflow.add_node("prepare_search", self._prepare_search_query)
         workflow.add_node("retrieve_context", self._retrieve_context)
+        workflow.add_node("prepare_for_generation", self._prepare_for_generation)
         workflow.add_node("generate_response", self._generate_response)
 
         # Define edges
         workflow.add_edge("filter_messages", "summarize")
         workflow.add_edge("summarize", "prepare_search")
         workflow.add_edge("prepare_search", "retrieve_context")
-        workflow.add_edge("retrieve_context", "generate_response")
+        workflow.add_edge("retrieve_context", "prepare_for_generation")
+        workflow.add_edge("prepare_for_generation", "generate_response")
 
         workflow.set_entry_point("filter_messages")
         workflow.set_finish_point("generate_response")
@@ -87,52 +90,75 @@ def _retrieve_context(self, state: WillaChatbotState) -> dict[str, str]:
         vector_store = self._vector_store
 
         if not search_query or not vector_store:
-            return {"docs_context": "", "tind_metadata": ""}
+            return {"tind_metadata": "", "documents": []}
 
         # Search for relevant documents
         retriever = vector_store.as_retriever(search_kwargs={"k": int(CONFIG['K_VALUE'])})
         matching_docs = retriever.invoke(search_query)
-
-        # Format context and metadata
-        docs_context = '\n\n'.join(doc.page_content for doc in matching_docs)
+        formatted_documents = [
+            {
+                "id": f"{doc.metadata.get('tind_metadata', {}).get('tind_id', [''])[0]}_{i}",
+                "page_content": doc.page_content,
+                "title": doc.metadata.get('tind_metadata', {}).get('title', [''])[0],
+                "project": doc.metadata.get('tind_metadata', {}).get('isPartOf', [''])[0],
+                "tind_link": format_tind_context.get_tind_url(
+                    doc.metadata.get('tind_metadata', {}).get('tind_id', [''])[0])
+            }
+            for i, doc in enumerate(matching_docs, 1)
+        ]
+
+        # Format TIND metadata
         tind_metadata = format_tind_context.get_tind_context(matching_docs)
 
-        return {"docs_context": docs_context, "tind_metadata": tind_metadata}
+        return {"tind_metadata": tind_metadata, "documents": formatted_documents}
 
-    # This should be refactored probably. Very bulky
-    def _generate_response(self, state: WillaChatbotState) -> dict[str, list[AnyMessage]]:
-        """Generate response using the model."""
+    def _prepare_for_generation(self, state: WillaChatbotState) -> dict[str, list[AnyMessage]]:
+        """Prepare the current and past messages for response generation."""
         messages = state["messages"]
         summarized_conversation = state.get("summarized_messages", messages)
-        docs_context = state.get("docs_context", "")
-        tind_metadata = state.get("tind_metadata", "")
-        model = self._model
-
-        if not model:
-            return {"messages": [AIMessage(content="Model not available.")]}
-
-        # Get the latest human message
-        latest_message = next(
-            (msg for msg in reversed(messages) if isinstance(msg, HumanMessage)),
-            None
-        )
-
-        if not latest_message:
+
+        if not any(isinstance(msg, HumanMessage) for msg in messages):
             return {"messages": [AIMessage(content="I'm sorry, I didn't receive a question.")]}
-
+
         prompt = get_langfuse_prompt()
-        system_messages = prompt.invoke({'context': docs_context,
-                                         'question': latest_message.content})
+        system_messages = prompt.invoke({})
+
         if hasattr(system_messages, "messages"):
             all_messages = summarized_conversation + system_messages.messages
         else:
             all_messages = summarized_conversation + [system_messages]
+
+        return {"messages": all_messages}
+
+    def _generate_response(self, state: WillaChatbotState) -> dict[str, list[AnyMessage]]:
+        """Generate response using the model."""
+        tind_metadata = state.get("tind_metadata", "")
+        model = self._model
+        documents = state.get("documents", [])
+        messages = state["messages"]
+
+        if not model:
+            return {"messages": [AIMessage(content="Model not available.")]}
 
         # Get response from model
-        response = model.invoke(all_messages)
+        response = model.invoke(
+            messages,
+            additional_model_request_fields={"documents": documents},
+            additional_model_response_field_paths=["/citations"]
+        )
+        citations = (response.response_metadata.get('additionalModelResponseFields', {}).get('citations')
+                     if response.response_metadata else None)
 
         # Create clean response content
         response_content = str(response.content) if hasattr(response, 'content') else str(response)
+
+        if citations:
+            state['citations'] = citations
+            response_content += "\n\nCitations:\n"
+            for citation in citations:
+                doc_ids = list(dict.fromkeys([re.sub(r'_\d*$', '', doc_id)
+                                              for doc_id in citation.get('document_ids', [])]))
+                response_content += f"- {citation.get('text', '')} ({', '.join(doc_ids)})\n"
 
         response_messages: list[AnyMessage] = [AIMessage(content=response_content),
                                                ChatMessage(content=tind_metadata,
                                                            role='TIND',
                                                            response_metadata={'tind': True})]
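
A note on the grounded-generation call in `_generate_response`: the `additional_model_request_fields` and `additional_model_response_field_paths` kwargs correspond to the Bedrock Converse API's `additionalModelRequestFields` and `additionalModelResponseFieldPaths` parameters, which pass model-native fields through unchanged. Below is a minimal boto3 sketch of the equivalent raw call, assuming a Cohere Command R model; the model ID, question, and document contents are illustrative, not taken from this repo:

```python
import boto3

client = boto3.client("bedrock-runtime")

response = client.converse(
    modelId="cohere.command-r-v1:0",
    messages=[{"role": "user", "content": [{"text": "Who founded Willa?"}]}],
    # Model-native request fields are forwarded to the model as-is;
    # Cohere Command R grounds its answer in this "documents" list.
    additionalModelRequestFields={
        "documents": [
            {"id": "tind123_1",
             "title": "Oral history interview",
             "snippet": "Interview excerpt ..."}
        ]
    },
    # Surface the model-native "citations" field in the Converse response.
    additionalModelResponseFieldPaths=["/citations"],
)

# Citations come back under additionalModelResponseFields, mirroring what
# ChatBedrockConverse exposes via response_metadata.
citations = response.get("additionalModelResponseFields", {}).get("citations")
```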
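
And a note on the citation formatting: `_retrieve_context` mints chunk-level document IDs of the form `{tind_id}_{i}`, so `_generate_response` strips the `_{i}` suffix and de-duplicates before printing, leaving one TIND record ID per citation. A standalone sketch of that logic, with hypothetical sample data in the shape the code above expects:

```python
import re

# Hypothetical citations payload: each citation carries the chunk-level
# document IDs ("<tind_id>_<chunk_index>") it drew from.
citations = [
    {"text": "Willa was founded in 1892.",
     "document_ids": ["tind123_1", "tind123_2", "tind456_1"]},
]

response_content = "Model answer goes here."
if citations:
    response_content += "\n\nCitations:\n"
    for citation in citations:
        # Strip the per-chunk suffix, then de-duplicate while preserving
        # first-seen order (dict.fromkeys keeps insertion order).
        doc_ids = list(dict.fromkeys(re.sub(r'_\d*$', '', doc_id)
                                     for doc_id in citation.get('document_ids', [])))
        response_content += f"- {citation.get('text', '')} ({', '.join(doc_ids)})\n"

print(response_content)
# Model answer goes here.
#
# Citations:
# - Willa was founded in 1892. (tind123, tind456)
```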