From ff6ab7357337dd06dc7487576061b7787fd0a897 Mon Sep 17 00:00:00 2001
From: Jason Raitz
Date: Wed, 17 Dec 2025 17:29:15 -0500
Subject: [PATCH 1/4] Investigate separating documents from the rest of the
 message history and instructions
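
This moves retrieved chunks out of the prompt text and into a structured
"documents" payload on the request. For reference, a minimal sketch of the
call shape being tried here, assuming the model is langchain-aws's
ChatBedrockConverse fronting a Cohere model (the model ID and document
values below are illustrative, not part of this patch; the extra invoke
kwargs mirror the ones used in the diff):

    from langchain_aws import ChatBedrockConverse
    from langchain_core.messages import HumanMessage

    model = ChatBedrockConverse(model="cohere.command-r-plus-v1:0")
    response = model.invoke(
        [HumanMessage(content="When was the archive founded?")],
        # Cohere-style documents ride alongside the chat messages instead
        # of being pasted into the prompt text
        additional_model_request_fields={"documents": [
            {"page_content": "Interview excerpt ...", "start_index": "12"},
        ]},
        # ask Bedrock to pass the model's native citations field through
        additional_model_response_field_paths=["/citations"],
    )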
---
 willa/chatbot/graph_manager.py | 26 ++++++++++++++++++++------
 1 file changed, 20 insertions(+), 6 deletions(-)

diff --git a/willa/chatbot/graph_manager.py b/willa/chatbot/graph_manager.py
index 13dea78..d3347c3 100644
--- a/willa/chatbot/graph_manager.py
+++ b/willa/chatbot/graph_manager.py
@@ -22,6 +22,7 @@ class WillaChatbotState(TypedDict):
     docs_context: NotRequired[str]
     search_query: NotRequired[str]
     tind_metadata: NotRequired[str]
+    documents: NotRequired[list[Any]]
     context: NotRequired[dict[str, Any]]
 
 
@@ -87,17 +88,25 @@ def _retrieve_context(self, state: WillaChatbotState) -> dict[str, str]:
         vector_store = self._vector_store
 
         if not search_query or not vector_store:
-            return {"docs_context": "", "tind_metadata": ""}
+            return {"docs_context": "", "tind_metadata": "", "documents": []}
 
         # Search for relevant documents
         retriever = vector_store.as_retriever(search_kwargs={"k": int(CONFIG['K_VALUE'])})
         matching_docs = retriever.invoke(search_query)
+        formatted_documents = [
+            {
+                "page_content": doc.page_content,
+                "start_index": str(doc.metadata.get('start_index')) if doc.metadata.get('start_index') else '',
+                "total_pages": str(doc.metadata.get('total_pages')) if doc.metadata.get('total_pages') else '',
+            }
+            for doc in matching_docs
+        ]
 
         # Format context and metadata
         docs_context = '\n\n'.join(doc.page_content for doc in matching_docs)
         tind_metadata = format_tind_context.get_tind_context(matching_docs)
 
-        return {"docs_context": docs_context, "tind_metadata": tind_metadata}
+        return {"docs_context": docs_context, "tind_metadata": tind_metadata, "documents": formatted_documents}
 
     # This should be refactored probably. Very bulky
     def _generate_response(self, state: WillaChatbotState) -> dict[str, list[AnyMessage]]:
@@ -107,6 +116,7 @@ def _generate_response(self, state: WillaChatbotState) -> dict[str, list[AnyMess
         docs_context = state.get("docs_context", "")
         tind_metadata = state.get("tind_metadata", "")
         model = self._model
+        documents = state.get("documents", [])
 
         if not model:
             return {"messages": [AIMessage(content="Model not available.")]}
@@ -121,16 +131,20 @@ def _generate_response(self, state: WillaChatbotState) -> dict[str, list[AnyMess
             return {"messages": [AIMessage(content="I'm sorry, I didn't receive a question.")]}
 
         prompt = get_langfuse_prompt()
-        system_messages = prompt.invoke({'context': docs_context,
-                                         'question': latest_message.content})
+        system_messages = prompt.invoke({})
+
         if hasattr(system_messages, "messages"):
             all_messages = summarized_conversation + system_messages.messages
         else:
             all_messages = summarized_conversation + [system_messages]
 
         # Get response from model
-        response = model.invoke(all_messages)
-
+        response = model.invoke(
+            all_messages,
+            additional_model_request_fields={"documents": documents},
+            additional_model_response_field_paths=["/citations"]
+        )
+        # print(response.response_metadata)
         # Create clean response content
         response_content = str(response.content) if hasattr(response, 'content') else str(response)
         response_messages: list[AnyMessage] = [AIMessage(content=response_content),

From e999fa839c0b7474b1d2c7e41654615775d79f56 Mon Sep 17 00:00:00 2001
From: Jason Raitz
Date: Fri, 19 Dec 2025 11:27:10 -0500
Subject: [PATCH 2/4] Preserve Cohere response citations

Pull the Cohere-specific response field that carries the citations for the
response text out of the response metadata and keep it on the graph state.
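
For reference, each entry under additionalModelResponseFields['citations']
looks roughly like this (shape inferred from the fields consumed in this
patch series; the values are made up):

    citations = [
        {
            "text": "founded in 1967",           # cited span of the answer
            "start": 102,                        # character offsets into
            "end": 117,                          # the response text
            "document_ids": ["doc_0", "doc_1"],  # supporting documents
        },
    ]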
---
 willa/chatbot/graph_manager.py | 8 +++++++-
 1 file changed, 7 insertions(+), 1 deletion(-)

diff --git a/willa/chatbot/graph_manager.py b/willa/chatbot/graph_manager.py
index d3347c3..298d973 100644
--- a/willa/chatbot/graph_manager.py
+++ b/willa/chatbot/graph_manager.py
@@ -23,6 +23,7 @@ class WillaChatbotState(TypedDict):
     search_query: NotRequired[str]
     tind_metadata: NotRequired[str]
     documents: NotRequired[list[Any]]
+    citations: NotRequired[list[dict[str, Any]]]
     context: NotRequired[dict[str, Any]]
 
 
@@ -144,7 +145,12 @@ def _generate_response(self, state: WillaChatbotState) -> dict[str, list[AnyMess
             additional_model_request_fields={"documents": documents},
             additional_model_response_field_paths=["/citations"]
         )
-        # print(response.response_metadata)
+        citations = response.response_metadata.get('additionalModelResponseFields', {}).get('citations') if response.response_metadata else None
+
+        # add citations to graph state
+        if citations:
+            state['citations'] = citations
+
         # Create clean response content
         response_content = str(response.content) if hasattr(response, 'content') else str(response)
         response_messages: list[AnyMessage] = [AIMessage(content=response_content),

From cd976ff36305d54a227bc6c5dfe54787b6942b34 Mon Sep 17 00:00:00 2001
From: Jason Raitz
Date: Mon, 22 Dec 2025 14:51:07 -0500
Subject: [PATCH 3/4] Add prepare_for_generation node

Split message preparation out of _generate_response into its own graph node,
and temporarily append the raw citations to the response text.
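
With the temporary rendering below, a response picks up a plain-text footer
along these lines (citation values are illustrative):

    ... the archive opened to researchers in 1967.

    Citations:
    - opened to researchers in 1967 (docs: ['4521_1', '4521_2'])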
---
 willa/chatbot/graph_manager.py | 55 ++++++++++++++++++----------------
 1 file changed, 29 insertions(+), 26 deletions(-)

diff --git a/willa/chatbot/graph_manager.py b/willa/chatbot/graph_manager.py
index 298d973..bd4ae5e 100644
--- a/willa/chatbot/graph_manager.py
+++ b/willa/chatbot/graph_manager.py
@@ -53,13 +53,15 @@ def _create_workflow(self) -> CompiledStateGraph:
         workflow.add_node("summarize", summarization_node)
         workflow.add_node("prepare_search", self._prepare_search_query)
         workflow.add_node("retrieve_context", self._retrieve_context)
+        workflow.add_node("prepare_for_generation", self._prepare_for_generation)
         workflow.add_node("generate_response", self._generate_response)
 
         # Define edges
         workflow.add_edge("filter_messages", "summarize")
         workflow.add_edge("summarize", "prepare_search")
         workflow.add_edge("prepare_search", "retrieve_context")
-        workflow.add_edge("retrieve_context", "generate_response")
+        workflow.add_edge("retrieve_context", "prepare_for_generation")
+        workflow.add_edge("prepare_for_generation", "generate_response")
 
         workflow.set_entry_point("filter_messages")
         workflow.set_finish_point("generate_response")
@@ -109,50 +111,51 @@ def _retrieve_context(self, state: WillaChatbotState) -> dict[str, str]:
 
         return {"docs_context": docs_context, "tind_metadata": tind_metadata, "documents": formatted_documents}
 
-    # This should be refactored probably. Very bulky
-    def _generate_response(self, state: WillaChatbotState) -> dict[str, list[AnyMessage]]:
-        """Generate response using the model."""
+    def _prepare_for_generation(self, state: WillaChatbotState) -> dict[str, list[AnyMessage]]:
+        """Prepare the current and past messages for response generation."""
         messages = state["messages"]
         summarized_conversation = state.get("summarized_messages", messages)
-        docs_context = state.get("docs_context", "")
-        tind_metadata = state.get("tind_metadata", "")
-        model = self._model
-        documents = state.get("documents", [])
-
-        if not model:
-            return {"messages": [AIMessage(content="Model not available.")]}
-
-        # Get the latest human message
-        latest_message = next(
-            (msg for msg in reversed(messages) if isinstance(msg, HumanMessage)),
-            None
-        )
-
-        if not latest_message:
+
+        if not any(isinstance(msg, HumanMessage) for msg in messages):
             return {"messages": [AIMessage(content="I'm sorry, I didn't receive a question.")]}
-
+
         prompt = get_langfuse_prompt()
         system_messages = prompt.invoke({})
-
+
         if hasattr(system_messages, "messages"):
             all_messages = summarized_conversation + system_messages.messages
         else:
             all_messages = summarized_conversation + [system_messages]
+
+        return {"messages": all_messages}
+
+    def _generate_response(self, state: WillaChatbotState) -> dict[str, list[AnyMessage]]:
+        """Generate response using the model."""
+        tind_metadata = state.get("tind_metadata", "")
+        model = self._model
+        documents = state.get("documents", [])
+        messages = state["messages"]
+
+        if not model:
+            return {"messages": [AIMessage(content="Model not available.")]}
 
         # Get response from model
         response = model.invoke(
-            all_messages,
+            messages,
             additional_model_request_fields={"documents": documents},
             additional_model_response_field_paths=["/citations"]
         )
         citations = response.response_metadata.get('additionalModelResponseFields', {}).get('citations') if response.response_metadata else None
 
-        # add citations to graph state
-        if citations:
-            state['citations'] = citations
-
         # Create clean response content
         response_content = str(response.content) if hasattr(response, 'content') else str(response)
+
+        if citations:
+            state['citations'] = citations
+            response_content += "\n\nCitations:\n"
+            for citation in citations:
+                response_content += f"- {citation.get('text', '')} (docs: {citation.get('document_ids', [])})\n"
+
         response_messages: list[AnyMessage] = [AIMessage(content=response_content),
                                                ChatMessage(content=tind_metadata, role='TIND',
                                                            response_metadata={'tind': True})]

From 326d7318448a1dc55995a76f414c08149f7a28e2 Mon Sep 17 00:00:00 2001
From: Jason Raitz
Date: Tue, 23 Dec 2025 16:16:04 -0500
Subject: [PATCH 4/4] Improve citation output preparation
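
The formatted document ids carry a per-chunk suffix ("<tind_id>_<n>"), so
the cleanup strips the suffix and de-duplicates, collapsing multiple chunks
of the same TIND record into one citation id. A quick illustration:

    >>> import re
    >>> ids = ["4521_1", "4521_2", "7788_1"]
    >>> list(dict.fromkeys(re.sub(r'_\d*$', '', i) for i in ids))
    ['4521', '7788']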
---
 willa/chatbot/graph_manager.py | 23 +++++++++++++----------
 1 file changed, 13 insertions(+), 10 deletions(-)

diff --git a/willa/chatbot/graph_manager.py b/willa/chatbot/graph_manager.py
index bd4ae5e..8c948d3 100644
--- a/willa/chatbot/graph_manager.py
+++ b/willa/chatbot/graph_manager.py
@@ -1,4 +1,5 @@
 """Manages the shared state and workflow for Willa chatbots."""
+import re
 from typing import Any, Optional, Annotated, NotRequired
 from typing_extensions import TypedDict
 
@@ -19,12 +20,10 @@ class WillaChatbotState(TypedDict):
     messages: Annotated[list[AnyMessage], add_messages]
     filtered_messages: NotRequired[list[AnyMessage]]
     summarized_messages: NotRequired[list[AnyMessage]]
-    docs_context: NotRequired[str]
     search_query: NotRequired[str]
     tind_metadata: NotRequired[str]
     documents: NotRequired[list[Any]]
     citations: NotRequired[list[dict[str, Any]]]
-    context: NotRequired[dict[str, Any]]
 
 
 class GraphManager:  # pylint: disable=too-few-public-methods
@@ -91,25 +90,27 @@ def _retrieve_context(self, state: WillaChatbotState) -> dict[str, str]:
         vector_store = self._vector_store
 
         if not search_query or not vector_store:
-            return {"docs_context": "", "tind_metadata": "", "documents": []}
+            return {"tind_metadata": "", "documents": []}
 
         # Search for relevant documents
         retriever = vector_store.as_retriever(search_kwargs={"k": int(CONFIG['K_VALUE'])})
         matching_docs = retriever.invoke(search_query)
         formatted_documents = [
             {
+                "id": f"{doc.metadata.get('tind_metadata', {}).get('tind_id', [''])[0]}_{i}",
                 "page_content": doc.page_content,
-                "start_index": str(doc.metadata.get('start_index')) if doc.metadata.get('start_index') else '',
-                "total_pages": str(doc.metadata.get('total_pages')) if doc.metadata.get('total_pages') else '',
+                "title": doc.metadata.get('tind_metadata', {}).get('title', [''])[0],
+                "project": doc.metadata.get('tind_metadata', {}).get('isPartOf', [''])[0],
+                "tind_link": format_tind_context.get_tind_url(
+                    doc.metadata.get('tind_metadata', {}).get('tind_id', [''])[0])
             }
-            for doc in matching_docs
+            for i, doc in enumerate(matching_docs, 1)
         ]
 
-        # Format context and metadata
-        docs_context = '\n\n'.join(doc.page_content for doc in matching_docs)
+        # Format tind metadata
         tind_metadata = format_tind_context.get_tind_context(matching_docs)
 
-        return {"docs_context": docs_context, "tind_metadata": tind_metadata, "documents": formatted_documents}
+        return {"tind_metadata": tind_metadata, "documents": formatted_documents}
 
     def _prepare_for_generation(self, state: WillaChatbotState) -> dict[str, list[AnyMessage]]:
         """Prepare the current and past messages for response generation."""
@@ -154,7 +155,9 @@ def _generate_response(self, state: WillaChatbotState) -> dict[str, list[AnyMess
             state['citations'] = citations
             response_content += "\n\nCitations:\n"
             for citation in citations:
-                response_content += f"- {citation.get('text', '')} (docs: {citation.get('document_ids', [])})\n"
+                doc_ids = list(dict.fromkeys([re.sub(r'_\d*$', '', doc_id)
+                                              for doc_id in citation.get('document_ids', [])]))
+                response_content += f"- {citation.get('text', '')} ({', '.join(doc_ids)})\n"
 
         response_messages: list[AnyMessage] = [AIMessage(content=response_content),
                                                ChatMessage(content=tind_metadata, role='TIND',