Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
91 changes: 79 additions & 12 deletions create_index_script.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,10 +34,16 @@


class IndexCreator:
"""Interactive index creation utility."""
"""Interactive index creation utility for LocalGPT RAG system."""

def __init__(self, config_path: Optional[str] = None):
"""Initialize the index creator with optional custom configuration."""
"""
Initialize the index creator with optional custom configuration.

Args:
config_path: Optional path to custom configuration file. If not provided,
uses default configuration.
"""
self.db = ChatDatabase()
self.config = self._load_config(config_path)

Expand All @@ -56,7 +62,15 @@ def __init__(self, config_path: Optional[str] = None):
)

def _load_config(self, config_path: Optional[str] = None) -> dict:
"""Load configuration from file or use default."""
"""
Load configuration from file or use default.

Args:
config_path: Optional path to configuration file.

Returns:
Dictionary containing configuration settings.
"""
if config_path and os.path.exists(config_path):
try:
with open(config_path, 'r') as f:
Expand All @@ -68,14 +82,31 @@ def _load_config(self, config_path: Optional[str] = None) -> dict:
return PIPELINE_CONFIGS.get("default", {})

def get_user_input(self, prompt: str, default: str = "") -> str:
    """
    Prompt the user on stdin and return the stripped response.

    Args:
        prompt: Message displayed to the user.
        default: Fallback value returned when the user just presses Enter.

    Returns:
        The stripped user input, or ``default`` when the input is empty
        and a default was supplied.
    """
    if not default:
        return input(f"{prompt}: ").strip()
    answer = input(f"{prompt} [{default}]: ").strip()
    return answer or default

def select_documents(self) -> List[str]:
"""Interactive document selection."""
"""
Interactive document selection interface.

Provides options to add single documents, entire directories,
and manage the selected document list.

Returns:
List of absolute paths to selected documents.
"""
print("\nπŸ“ Document Selection")
print("=" * 50)

Expand Down Expand Up @@ -141,7 +172,15 @@ def select_documents(self) -> List[str]:
return documents

def configure_processing(self) -> dict:
"""Interactive processing configuration."""
"""
Interactive processing configuration interface.

Allows users to configure document processing parameters including
chunk size, overlap, enrichment options, and model selection.

Returns:
Dictionary containing processing configuration settings.
"""
print("\nβš™οΈ Processing Configuration")
print("=" * 50)

Expand Down Expand Up @@ -175,7 +214,12 @@ def configure_processing(self) -> dict:
}

def create_index_interactive(self) -> None:
"""Run the interactive index creation process."""
"""
Run the interactive index creation process.

Guides the user through the complete index creation workflow including
naming, document selection, configuration, and processing.
"""
print("πŸš€ LocalGPT Index Creation Tool")
print("=" * 50)

Expand Down Expand Up @@ -237,7 +281,12 @@ def create_index_interactive(self) -> None:
traceback.print_exc()

def test_index(self, index_id: str) -> None:
"""Test the created index with a sample query."""
"""
Test the created index with a sample query.

Args:
index_id: The ID of the index to test.
"""
try:
print("\nπŸ§ͺ Testing Index")
print("=" * 50)
Expand All @@ -258,7 +307,15 @@ def test_index(self, index_id: str) -> None:
print(f"❌ Error testing index: {e}")

def batch_create_from_config(self, config_file: str) -> None:
"""Create index from batch configuration file."""
"""
Create index from batch configuration file.

Processes a JSON configuration file containing index settings and
document paths for automated index creation.

Args:
config_file: Path to the batch configuration JSON file.
"""
try:
with open(config_file, 'r') as f:
batch_config = json.load(f)
Expand Down Expand Up @@ -312,7 +369,12 @@ def batch_create_from_config(self, config_file: str) -> None:


def create_sample_batch_config():
"""Create a sample batch configuration file."""
"""
Create a sample batch configuration file.

Generates a JSON configuration file with example settings that can be
used as a template for batch index creation.
"""
sample_config = {
"index_name": "Sample Batch Index",
"index_description": "Example batch index configuration",
Expand Down Expand Up @@ -340,7 +402,12 @@ def create_sample_batch_config():


def main():
"""Main entry point for the script."""
"""
Main entry point for the script.

Parses command line arguments and executes the appropriate index creation
workflow based on the provided options.
"""
parser = argparse.ArgumentParser(description="LocalGPT Index Creation Tool")
parser.add_argument("--batch", help="Batch configuration file", type=str)
parser.add_argument("--config", help="Custom pipeline configuration file", type=str)
Expand Down Expand Up @@ -369,4 +436,4 @@ def main():


if __name__ == "__main__":
main()
main()
106 changes: 94 additions & 12 deletions rag_system/api_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,11 @@
# -------------- Helper ----------------

def _apply_index_embedding_model(idx_ids):
"""Ensure retrieval pipeline uses the embedding model stored with the first index."""
"""Ensure retrieval pipeline uses the embedding model stored with the first index.

Args:
idx_ids (list): List of index IDs to check for embedding model metadata.
"""
debug_info = f"πŸ”§ _apply_index_embedding_model called with idx_ids: {idx_ids}\n"

if not idx_ids:
Expand Down Expand Up @@ -69,7 +73,14 @@ def _apply_index_embedding_model(idx_ids):
f.write(debug_info)

def _get_table_name_for_session(session_id):
"""Get the correct vector table name for a session by looking up its linked indexes."""
"""Get the correct vector table name for a session by looking up its linked indexes.

Args:
session_id (str): The session ID to look up indexes for.

Returns:
str or None: The vector table name for the session, or None if not found.
"""
logger = logging.getLogger(__name__)

if not session_id:
Expand Down Expand Up @@ -113,6 +124,12 @@ def _get_table_name_for_session(session_id):
return default_table

class AdvancedRagApiHandler(http.server.BaseHTTPRequestHandler):
"""HTTP request handler for the RAG API server.

Handles POST requests for chat interactions, streaming responses, and document indexing.
Also handles GET requests for retrieving available models.
"""

def do_OPTIONS(self):
"""Handle CORS preflight requests for frontend integration."""
self.send_response(200)
Expand All @@ -122,7 +139,13 @@ def do_OPTIONS(self):
self.end_headers()

def do_POST(self):
"""Handle POST requests for chat and indexing."""
"""Handle POST requests for chat and indexing endpoints.

Routes requests to appropriate handlers based on the URL path:
- /chat: Regular chat interactions
- /chat/stream: Streaming chat responses
- /index: Document indexing
"""
parsed_path = urlparse(self.path)

if parsed_path.path == '/chat':
Expand All @@ -135,6 +158,11 @@ def do_POST(self):
self.send_json_response({"error": "Not Found"}, status_code=404)

def do_GET(self):
"""Handle GET requests for retrieving information.

Routes requests to appropriate handlers based on the URL path:
- /models: List available models
"""
parsed_path = urlparse(self.path)

if parsed_path.path == '/models':
Expand All @@ -143,7 +171,16 @@ def do_GET(self):
self.send_json_response({"error": "Not Found"}, status_code=404)

def handle_chat(self):
"""Handles a chat query by calling the agentic RAG pipeline."""
"""Handle a chat query by calling the agentic RAG pipeline.

Processes JSON request body containing query, session_id, and various configuration flags.
Updates session title for first messages and stores user/assistant messages in database.
Returns JSON response with the generated answer and metadata.

Raises:
json.JSONDecodeError: If the request body contains invalid JSON.
Exception: For various server errors during processing.
"""
try:
content_length = int(self.headers['Content-Length'])
post_data = self.rfile.read(content_length)
Expand Down Expand Up @@ -302,7 +339,17 @@ def handle_chat(self):
self.send_json_response({"error": f"Server error: {str(e)}"}, status_code=500)

def handle_chat_stream(self):
"""Stream internal phases and final answer using SSE (text/event-stream)."""
"""Stream internal phases and final answer using SSE (text/event-stream).

Similar to handle_chat but streams responses using Server-Sent Events.
Emits events for different processing phases and the final result.
Handles client disconnections gracefully.

Raises:
json.JSONDecodeError: If the request body contains invalid JSON.
BrokenPipeError: If the client disconnects during streaming.
Exception: For various server errors during processing.
"""
try:
content_length = int(self.headers['Content-Length'])
post_data = self.rfile.read(content_length)
Expand Down Expand Up @@ -385,7 +432,15 @@ def handle_chat_stream(self):
self.end_headers()

def emit(event_type: str, payload):
"""Send a single SSE event."""
"""Send a single SSE event.

Args:
event_type (str): The type of event to emit.
payload: The data payload to send with the event.

Raises:
BrokenPipeError: If the client has disconnected.
"""
try:
data_str = json.dumps({"type": event_type, "data": payload})
self.wfile.write(f"data: {data_str}\n\n".encode('utf-8'))
Expand Down Expand Up @@ -499,7 +554,16 @@ def emit(event_type: str, payload):
self.send_json_response({"error": f"Server error: {str(e)}"}, status_code=500)

def handle_index(self):
"""Triggers the document indexing pipeline for specific files."""
"""Trigger the document indexing pipeline for specific files.

Processes JSON request body containing file paths and indexing configuration.
Creates temporary pipeline instances with custom configurations for each indexing job.
Updates index metadata with embedding model information.

Raises:
json.JSONDecodeError: If the request body contains invalid JSON.
Exception: For various server errors during indexing.
"""
try:
content_length = int(self.headers['Content-Length'])
post_data = self.rfile.read(content_length)
Expand Down Expand Up @@ -690,7 +754,15 @@ def handle_index(self):
self.send_json_response({"error": f"Failed to start indexing: {str(e)}"}, status_code=500)

def handle_models(self):
"""Return a list of locally installed Ollama models and supported HuggingFace models, grouped by capability."""
"""Return a list of locally installed Ollama models and supported HuggingFace models, grouped by capability.

Queries the Ollama API for available models and classifies them as generation or embedding models.
Also includes a predefined list of supported HuggingFace embedding models.
Returns JSON response with models grouped by capability.

Raises:
Exception: If there are errors querying the Ollama API or processing model data.
"""
try:
generation_models = []
embedding_models = []
Expand Down Expand Up @@ -732,7 +804,12 @@ def handle_models(self):
self.send_json_response({"error": f"Could not list models: {e}"}, status_code=500)

def send_json_response(self, data, status_code=200):
"""Utility to send a JSON response with CORS headers."""
"""Send a JSON response with CORS headers.

Args:
data: The data to serialize as JSON and send in the response body.
status_code (int): HTTP status code to send. Defaults to 200.
"""
self.send_response(status_code)
self.send_header('Content-Type', 'application/json')
self.send_header('Access-Control-Allow-Origin', '*')
Expand All @@ -741,17 +818,22 @@ def send_json_response(self, data, status_code=200):
self.wfile.write(response.encode('utf-8'))

def start_server(port=8001):
    """Start the API server on the specified port.

    Args:
        port (int): Port number to bind the server to. Defaults to 8001.
    """
    # Use a reusable TCP server to avoid "address in use" errors on restart
    class ReusableTCPServer(socketserver.TCPServer):
        """TCP server that allows address reuse to prevent binding errors on restart."""
        allow_reuse_address = True

    # TCPServer's constructor is (server_address, RequestHandlerClass): the
    # address must be a single (host, port) tuple. Passing "", port and the
    # handler as three positional args would bind to a bogus address and treat
    # the handler as the bind_and_activate flag.
    with ReusableTCPServer(("", port), AdvancedRagApiHandler) as httpd:
        print(f"🚀 Starting Advanced RAG API server on port {port}")
        print(f"💬 Chat endpoint: http://localhost:{port}/chat")
        print(f"✨ Indexing endpoint: http://localhost:{port}/index")
        httpd.serve_forever()

if __name__ == "__main__":
# To run this server: python -m rag_system.api_server
start_server()
start_server()
Loading