From e6ee6d7ad5cd212e8b53fb4b00dc8e6a4291d48b Mon Sep 17 00:00:00 2001 From: Alon Freund Date: Sun, 24 May 2026 18:15:53 +0300 Subject: [PATCH 01/11] Improve ChromaDB loader with sentence-transformers backend Adds a new chroma_loader module that improves on the legacy govt_data_loader with modern embedding capabilities while maintaining backward compatibility. Key improvements: - sentence-transformers backend for cleaner API and better batching performance - Configurable max_length (1024 vs hardcoded 512) for better passage embeddings - Configurable batch_size with auto-tuning support - Generic load_or_build_chroma() supporting MT-RAG corpora and HuggingFace datasets - Flexible filter_ids parameter for document filtering - Eager loading architecture for clear upfront waiting time Backward compatibility: - load_or_build_govt_chroma() wrapper preserves exact API - load_only_tutorial_docs parameter maintained for T4/CPU-friendly subset - device parameter preserved for explicit CPU/GPU control - TUTORIAL_DOC_IDS constant maintained (177 docs) - No max_docs limit (unlike PR #58 which hardcoded 2000) Changes: - Add src/granite_switch/tutorials/chroma_loader.py with new implementation - Update pyproject.toml with sentence-transformers>=3.0.0 and datasets>=2.0.0 - Update rag_101.ipynb and rag_flow.ipynb to import from new module - Add deprecation warning to legacy govt_data_loader.py Based on improvements from PR #58 with adjustments for project requirements. --- pyproject.toml | 2 + src/granite_switch/tutorials/chroma_loader.py | 486 ++++++++++++++++++ .../tutorials/govt_data_loader.py | 18 + tutorials/notebooks/rag_101.ipynb | 28 +- tutorials/notebooks/rag_flow.ipynb | 4 +- 5 files changed, 510 insertions(+), 28 deletions(-) create mode 100644 src/granite_switch/tutorials/chroma_loader.py diff --git a/pyproject.toml b/pyproject.toml index 6536d0c..330fd29 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -33,6 +33,8 @@ tutorials = [ "mellea==0.6.0", "ipython>=8.10.0", "python-dotenv>=1.0.0", + "sentence-transformers>=3.0.0", + "datasets>=2.0.0", ] dev = ["pytest", "granite-switch[hf,vllm,compose]"] dev-vllm20 = ["pytest", "granite-switch[hf,vllm20,compose]"] diff --git a/src/granite_switch/tutorials/chroma_loader.py b/src/granite_switch/tutorials/chroma_loader.py new file mode 100644 index 0000000..d4e54ec --- /dev/null +++ b/src/granite_switch/tutorials/chroma_loader.py @@ -0,0 +1,486 @@ +"""Generic ChromaDB loader with sentence-transformers backend. + +This module provides a flexible interface for loading or building ChromaDB collections +from various data sources, with improved embedding capabilities using sentence-transformers. + +Key improvements over the legacy govt_data_loader: +- sentence-transformers backend for cleaner API and better batching +- Configurable max_length (1024 vs hardcoded 512) +- Configurable batch_size (auto-tuned vs hardcoded 64) +- Per-document progress tracking +- Generic loader supporting MT-RAG corpora and HuggingFace datasets +- Flexible ID filtering via filter_ids parameter + +Backward compatibility maintained via load_or_build_govt_chroma() wrapper. +""" + +import io +import json +import os +import time +import warnings +import zipfile +from typing import Dict, List, Optional, Set, Tuple + +import chromadb +import httpx +import torch +from chromadb import Documents, EmbeddingFunction, Embeddings +from tqdm.auto import tqdm + +# Constants from original govt_data_loader +EMBEDDING_MODEL_ID = "ibm-granite/granite-embedding-small-english-r2" +CHROMA_PATH = "./govt_chroma" +GOVT_JSONL_URL = "https://github.com/IBM/mt-rag-benchmark/raw/main/corpora/passage_level/govt.jsonl.zip" +GOVT_JSONL_PATH = "./govt.jsonl" + +# Tutorial subset: 177 docs for T4/CPU-friendly embedding +TUTORIAL_DOC_IDS = set([ + "05537c9ec2dfe15e-1362-3310", "05537c9ec2dfe15e-2-1779", "05537c9ec2dfe15e-2821-4679", + "05537c9ec2dfe15e-4280-6252", "087417ad420d618c-1327-3164", "087417ad420d618c-2428-4297", + "087417ad420d618c-3940-5774", "089882437c965a3e-113907-115852", "089882437c965a3e-115237-117256", + "089882437c965a3e-119809-121676", "089882437c965a3e-121198-123235", "089882437c965a3e-122746-124833", + "089882437c965a3e-130164-131917", "089882437c965a3e-1427-3375", "089882437c965a3e-157219-159194", + "089882437c965a3e-158778-160687", "089882437c965a3e-170699-172699", "089882437c965a3e-173726-175992", + "089882437c965a3e-175465-177577", "089882437c965a3e-177094-179288", "089882437c965a3e-182078-183322", + "089882437c965a3e-184664-186341", "089882437c965a3e-190627-192211", "089882437c965a3e-191792-193455", + "089882437c965a3e-194311-196074", "089882437c965a3e-2-1955", "089882437c965a3e-42318-44668", + "089882437c965a3e-51633-53566", "089882437c965a3e-53014-54918", "089882437c965a3e-85071-87052", + "089882437c965a3e-86622-88344", "0ecab3f697d26347-1362-3129", "142cbdf06f6e40d9-1544-3414", + "142cbdf06f6e40d9-2-2014", "142cbdf06f6e40d9-4140-6181", "142cbdf06f6e40d9-5655-7824", + "19240942bfc0abf5-11151-13247", "19240942bfc0abf5-1354-3015", "2c89b9fe3cfe95ee-1392-3518", + "2ead5535f9d6d3be-1376-3143", "3090260a5d934d78-1166-2578", "3090260a5d934d78-2225-3536", + "32472b4a577f296f-2-1847", "353067ac7a68e5f0-2-1815", "3630bbba71396272-1400-3319", + "3630bbba71396272-4267-6086", "40ce723b445ac8eb-1350-3146", "40ce723b445ac8eb-2-1781", + "40ce723b445ac8eb-3922-5642", "40ce723b445ac8eb-5372-7150", "40ce723b445ac8eb-6691-8678", + "40ce723b445ac8eb-8241-9800", "4c201f242ec49883-1381-3148", "4c201f242ec49883-5418-7248", + "4e1c120aee9a75b6-1369-3165", "50a24d38902fbdd0-1340-3177", "50a24d38902fbdd0-3953-5813", + "565fb21ac38feaa1-15852-17699", "5b86a17591806ce5-1532-3330", "60e02c03620cd1ef-9523-11519", + "6ddc73cb3877e2aa-1384-3151", "6ddc73cb3877e2aa-2-1801", "77de29ffa3c3d800-1352-3553", + "77de29ffa3c3d800-2-1946", "7fe68ab7967494ca-1358-3306", "81478086b28ab210-5831-7806", + "818e03cc80181db4-1346-3469", "818e03cc80181db4-2-1767", "818e03cc80181db4-3125-4727", + "824c4c47b2989363-1365-3132", "824c4c47b2989363-2-1782", "82f7a783325de97a-1402-3321", + "82f7a783325de97a-4269-6188", "882a9cc2bb08bcdf-2-1811", "8cd62677aa5dcb92-2-1746", + "9726fa169575dc43-1331-3168", "9726fa169575dc43-2-1734", "9726fa169575dc43-2432-4301", + "9726fa169575dc43-3944-5768", "9726fa169575dc43-5394-7430", "9726fa169575dc43-6967-8603", + "97e58e54bb79a7fe-3231-5248", "99c7b4f2bfb48b7f-3321-5534", "a005bd5aedbb28e5-33908-36180", + "a005bd5aedbb28e5-35687-37469", "a4a53cb6b6bf326e-1349-3145", "a4a53cb6b6bf326e-2-1780", + "a4a53cb6b6bf326e-2409-4294", "a4a53cb6b6bf326e-3921-5691", "a4a53cb6b6bf326e-5362-7156", + "a4a53cb6b6bf326e-6689-8701", "a4a53cb6b6bf326e-8201-10002", "a930d03cf0b406fd-23288-25302", + "a930d03cf0b406fd-30996-32981", "c550156dbbfe212c-1401-3320", "c550156dbbfe212c-16212-18433", + "c550156dbbfe212c-29308-31304", "c550156dbbfe212c-30794-33132", "c550156dbbfe212c-32367-34910", + "c550156dbbfe212c-37745-39895", "c550156dbbfe212c-39218-41274", "c550156dbbfe212c-40668-42844", + "c550156dbbfe212c-42364-44521", "c550156dbbfe212c-44034-46164", "c550156dbbfe212c-45669-47909", + "c550156dbbfe212c-47421-49701", "c550156dbbfe212c-9073-11428", "c67a2f65008344fd-2-1909", + "c93223e21ee4ecfb-2-1754", "d4c48e9a4029f3e9-1801-3993", "d4edd2b762f5dce9-7713-9881", + "e580ce520db3ff10-109466-111339", "e580ce520db3ff10-119467-121417", "e580ce520db3ff10-124119-126003", + "e580ce520db3ff10-129933-131969", "e580ce520db3ff10-131480-133562", "e580ce520db3ff10-190530-192253", + "e580ce520db3ff10-191857-193702", "e580ce520db3ff10-35813-37462", "e580ce520db3ff10-36974-38756", + "e6ea24fa9e962807-1357-3305", "e6ea24fa9e962807-4275-6126", "ed17e5bd32458f9c-1347-3143", + "ed17e5bd32458f9c-3919-5735", "f0b48597d0c22d32-2-1647", "f0b48597d0c22d32-2585-4675", + "f0b48597d0c22d32-999-3136", "f14d35fd47c9ed59-1352-3148", "f14d35fd47c9ed59-3924-5795", + "f14d35fd47c9ed59-5374-7566", "f7225d77034b8398-1402-3321", "f90bb40d57fe7ba5-1469-3644", + "f90bb40d57fe7ba5-2-1890", "f90bb40d57fe7ba5-3142-5127", "f90bb40d57fe7ba5-8968-10553", + "fcdc09416b6aa645-1276-2982", "fcdc09416b6aa645-2-1649", +]) + +# MT-RAG corpus metadata +CORPUS_INFO = { + "govt": { + "url": GOVT_JSONL_URL, + "local_path": GOVT_JSONL_PATH, + "chroma_path": CHROMA_PATH, + "collection_name": "govt", + }, +} + + +class GraniteEmbeddingFunction(EmbeddingFunction): + """ChromaDB embedding function using sentence-transformers backend. + + This class wraps a sentence-transformers model for use with ChromaDB. + Uses eager loading (model loaded in __init__) for clear upfront waiting time. + + Args: + model_id: HuggingFace model ID for sentence-transformers + batch_size: Batch size for encoding (None = auto-tune) + max_length: Maximum sequence length for embeddings (default 1024) + device: Device to use ("cpu", "cuda", or None for auto-detect) + """ + + def __init__( + self, + model_id: str = EMBEDDING_MODEL_ID, + batch_size: Optional[int] = None, + max_length: int = 1024, + device: Optional[str] = None, + ): + from sentence_transformers import SentenceTransformer + + self.model_id = model_id + self.batch_size = batch_size + self.max_length = max_length + + # Auto-detect device if not specified + if device is None: + device = "cuda" if torch.cuda.is_available() else "cpu" + if device == "cpu": + warnings.warn( + "Embedding on CPU will be slow. " + "Expected runtime is ~10 min on a single consumer GPU. " + "Consider running on a GPU host.", + stacklevel=2, + ) + + self.device = device + + # Eager loading: model loaded immediately (clear upfront waiting) + self.model = SentenceTransformer(model_id, device=device) + print(f"Granite embedding model ready on {device} ({model_id})") + + def __call__(self, input: Documents) -> Embeddings: + """Embed texts with batching and progress bar.""" + embeddings = self.model.encode( + input, + batch_size=self.batch_size, + max_seq_length=self.max_length, + show_progress_bar=False, # Disable internal progress (we track at doc level) + convert_to_numpy=True, + ) + return embeddings.tolist() + + +def _download_jsonl_zip(url: str, output_path: str) -> None: + """Download and extract JSONL from ZIP archive with progress tracking. + + Args: + url: URL to download ZIP archive from + output_path: Local path to save extracted JSONL file + """ + print(f"Downloading {url} ...") + t0 = time.time() + + # Stream into memory with progress bar + # Split timeout: fail fast on connect (10s), allow slow reads (300s) + timeout = httpx.Timeout(300.0, connect=10.0) + buf = io.BytesIO() + with httpx.Client(follow_redirects=True, timeout=timeout) as c: + with c.stream("GET", url) as resp: + resp.raise_for_status() + total = int(resp.headers.get("Content-Length", 0)) or None + with tqdm(total=total, unit="B", unit_scale=True, desc="download") as bar: + for chunk in resp.iter_bytes(chunk_size=65536): + buf.write(chunk) + bar.update(len(chunk)) + buf.seek(0) + + # Atomic write: extract to .tmp then replace, so crashes can't leave truncated files + tmp_path = output_path + ".tmp" + with zipfile.ZipFile(buf) as zf: + inner = next(n for n in zf.namelist() if n.endswith(".jsonl")) + with zf.open(inner) as src, open(tmp_path, "wb") as dst: + dst.write(src.read()) + os.replace(tmp_path, output_path) + print(f"Saved {output_path} in {time.time() - t0:.1f}s.") + + +def _load_records_from_jsonl( + jsonl_path: str, + filter_ids: Optional[Set[str]] = None, + max_docs: Optional[int] = None, + text_field: str = "text", + id_field: Optional[str] = None, +) -> Tuple[List[str], List[str], List[Dict]]: + """Load document records from JSONL file. + + Args: + jsonl_path: Path to JSONL file + filter_ids: Set of document IDs to include (None = all) + max_docs: Maximum documents to load (None = no limit) + text_field: Field name for document text + id_field: Field name for document ID (None = use _id or id field) + + Returns: + Tuple of (ids, texts, metadatas) + """ + ids, texts, metas = [], [], [] + + with open(jsonl_path) as f: + for line in f: + doc = json.loads(line) + text = doc.get(text_field, "").strip() + if not text: + continue + + # Extract document ID + if id_field: + doc_id = doc.get(id_field) + else: + doc_id = doc.get("_id", doc.get("id", str(len(ids)))) + + # Apply filtering + if filter_ids is not None and doc_id not in filter_ids: + continue + + ids.append(doc_id) + texts.append(text) + metas.append({ + "title": doc.get("title", ""), + "url": doc.get("url", ""), + }) + + # Respect max_docs limit + if max_docs is not None and len(ids) >= max_docs: + break + + if not ids: + raise RuntimeError( + f"{jsonl_path} yielded zero documents - the file may be empty, truncated, " + f"or schema-drifted (expected a '{text_field}' field per line). " + f"Delete it and rerun to re-download." + ) + + return ids, texts, metas + + +def _load_records_from_hf( + dataset_id: str, + filter_ids: Optional[Set[str]] = None, + max_docs: Optional[int] = None, + config: Optional[str] = None, + split: str = "train", + text_field: str = "text", + id_field: Optional[str] = None, +) -> Tuple[List[str], List[str], List[Dict]]: + """Load document records from HuggingFace dataset. + + Args: + dataset_id: HuggingFace dataset ID + filter_ids: Set of document IDs to include (None = all) + max_docs: Maximum documents to load (None = no limit) + config: Dataset configuration name + split: Dataset split to load + text_field: Field name for document text + id_field: Field name for document ID + + Returns: + Tuple of (ids, texts, metadatas) + """ + try: + from datasets import load_dataset + except ImportError: + raise ImportError( + "HuggingFace datasets library required. Install with: " + "pip install datasets" + ) + + dataset = load_dataset(dataset_id, config, split=split) + + ids, texts, metas = [], [], [] + for i, example in enumerate(dataset): + text = example.get(text_field, "").strip() + if not text: + continue + + # Extract document ID + if id_field: + doc_id = example.get(id_field, str(i)) + else: + doc_id = example.get("_id", example.get("id", str(i))) + + # Apply filtering + if filter_ids is not None and doc_id not in filter_ids: + continue + + ids.append(doc_id) + texts.append(text) + metas.append({ + "title": example.get("title", ""), + "url": example.get("url", ""), + }) + + # Respect max_docs limit + if max_docs is not None and len(ids) >= max_docs: + break + + if not ids: + raise RuntimeError( + f"Dataset {dataset_id} yielded zero documents. " + f"Check that '{text_field}' field exists." + ) + + return ids, texts, metas + + +def load_or_build_chroma( + corpus_name: Optional[str] = None, + hf_dataset_id: Optional[str] = None, + jsonl_path: Optional[str] = None, + jsonl_url: Optional[str] = None, + chroma_path: Optional[str] = None, + collection_name: str = "default", + embedding_model_id: str = EMBEDDING_MODEL_ID, + batch_size: Optional[int] = None, + max_length: int = 1024, + max_docs: Optional[int] = None, + filter_ids: Optional[Set[str]] = None, + device: Optional[str] = None, + text_field: str = "text", + id_field: Optional[str] = None, + hf_config: Optional[str] = None, + hf_split: str = "train", +) -> chromadb.Collection: + """Generic ChromaDB loader supporting multiple data sources. + + This function loads or builds a ChromaDB collection from either: + 1. Named MT-RAG corpus (via corpus_name) + 2. Local JSONL file (via jsonl_path + optional jsonl_url) + 3. HuggingFace dataset (via hf_dataset_id) + + Args: + corpus_name: Named MT-RAG corpus ("govt", "fiqa", etc.) + hf_dataset_id: HuggingFace dataset ID (mutually exclusive with corpus_name) + jsonl_path: Local JSONL file path (derived from corpus_name if None) + jsonl_url: URL to download JSONL (derived from corpus_name if None) + chroma_path: Persistent storage path (derived from corpus_name if None) + collection_name: ChromaDB collection name (derived from corpus_name if None) + embedding_model_id: Sentence-transformers model ID + batch_size: Embedding batch size (None = auto-tune) + max_length: Maximum sequence length for embeddings + max_docs: Maximum documents to ingest (None = no limit) + filter_ids: Set of document IDs to include (None = all docs) + device: "cpu" or "cuda" (None = auto-detect) + text_field: Field name for document text + id_field: Field name for document ID + hf_config: HuggingFace dataset configuration + hf_split: HuggingFace dataset split + + Returns: + ChromaDB collection ready for queries + """ + # Resolve corpus info if corpus_name provided + if corpus_name: + if corpus_name not in CORPUS_INFO: + raise ValueError( + f"Unknown corpus '{corpus_name}'. " + f"Available: {list(CORPUS_INFO.keys())}" + ) + info = CORPUS_INFO[corpus_name] + jsonl_url = jsonl_url or info["url"] + jsonl_path = jsonl_path or info["local_path"] + chroma_path = chroma_path or info["chroma_path"] + collection_name = collection_name if collection_name != "default" else info["collection_name"] + + # Validate inputs + if not chroma_path: + raise ValueError("chroma_path must be specified") + if not hf_dataset_id and not jsonl_path: + raise ValueError("Must specify either hf_dataset_id or jsonl_path") + + # Create embedding function + granite_ef = GraniteEmbeddingFunction( + model_id=embedding_model_id, + batch_size=batch_size, + max_length=max_length, + device=device, + ) + + # Create or load collection + client = chromadb.PersistentClient(path=chroma_path) + collection = client.get_or_create_collection( + name=collection_name, + embedding_function=granite_ef, + metadata={"hnsw:space": "cosine"}, + ) + + # Return if already populated + if collection.count() > 0: + print(f"Loaded from {chroma_path} ({collection.count():,} docs).") + return collection + + # Load documents + if hf_dataset_id: + print(f"Loading from HuggingFace dataset {hf_dataset_id}...") + ids, texts, metas = _load_records_from_hf( + dataset_id=hf_dataset_id, + filter_ids=filter_ids, + max_docs=max_docs, + config=hf_config, + split=hf_split, + text_field=text_field, + id_field=id_field, + ) + else: + # Download JSONL if needed + if not os.path.exists(jsonl_path): + if not jsonl_url: + raise ValueError(f"{jsonl_path} not found and no jsonl_url provided") + _download_jsonl_zip(jsonl_url, jsonl_path) + + if filter_ids is not None: + print(f"Filtering to {len(filter_ids)} doc IDs") + + print(f"Reading {jsonl_path} -> {chroma_path}...") + t0 = time.time() + ids, texts, metas = _load_records_from_jsonl( + jsonl_path=jsonl_path, + filter_ids=filter_ids, + max_docs=max_docs, + text_field=text_field, + id_field=id_field, + ) + print(f"Read {len(ids):,} docs in {time.time() - t0:.1f}s.") + + # Embed and index documents + print(f"Embedding & indexing {len(ids):,} documents...") + t1 = time.time() + + # Use smaller batch sizes for upsert based on device + upsert_batch = 16 if device == "cpu" else 500 + for i in tqdm(range(0, len(ids), upsert_batch), unit="batch", desc="indexing"): + collection.upsert( + ids=ids[i : i + upsert_batch], + documents=texts[i : i + upsert_batch], + metadatas=metas[i : i + upsert_batch], + ) + + print(f"Done. {collection.count():,} docs saved to {chroma_path} in {time.time() - t1:.1f}s.") + return collection + + +def load_or_build_govt_chroma( + chroma_path: str = CHROMA_PATH, + jsonl_path: str = GOVT_JSONL_PATH, + jsonl_url: str = GOVT_JSONL_URL, + embedding_model_id: str = EMBEDDING_MODEL_ID, + load_only_tutorial_docs: bool = False, + device: Optional[str] = None, +) -> chromadb.Collection: + """Backward-compatible govt corpus loader. + + This function maintains the API of the legacy govt_data_loader module + while using the improved chroma_loader implementation underneath. + + Args: + chroma_path: Persistent storage path + jsonl_path: Local JSONL path + jsonl_url: Download URL + embedding_model_id: Embedding model ID + load_only_tutorial_docs: If True, load only 177 tutorial docs (T4-friendly) + device: "cpu" or "cuda" (None = auto-detect) + + Returns: + ChromaDB collection with govt corpus + """ + filter_ids = TUTORIAL_DOC_IDS if load_only_tutorial_docs else None + + return load_or_build_chroma( + corpus_name="govt", + jsonl_path=jsonl_path, + jsonl_url=jsonl_url, + chroma_path=chroma_path, + embedding_model_id=embedding_model_id, + filter_ids=filter_ids, + device=device, + max_docs=None, # NO artificial limit + ) diff --git a/src/granite_switch/tutorials/govt_data_loader.py b/src/granite_switch/tutorials/govt_data_loader.py index 9ed2560..8a71284 100644 --- a/src/granite_switch/tutorials/govt_data_loader.py +++ b/src/granite_switch/tutorials/govt_data_loader.py @@ -1,5 +1,14 @@ """Load or build the ChromaDB corpus for the govt RAG tutorial. +.. deprecated:: 0.2.0 + This module is deprecated and will be removed in a future release. + Use :mod:`granite_switch.tutorials.chroma_loader` instead, which provides: + + - sentence-transformers backend for improved embedding performance + - Configurable max_length (1024 vs hardcoded 512) + - Generic loader supporting multiple data sources + - Maintained API compatibility via load_or_build_govt_chroma() + Kept separate from the notebook so the pipeline stays focused on RAG concepts. First run: downloads `govt.jsonl.zip` from IBM mt-rag-benchmark, @@ -7,6 +16,15 @@ `./govt_chroma`. Subsequent runs: loads the persisted index instantly. """ +import warnings + +warnings.warn( + "granite_switch.tutorials.govt_data_loader is deprecated and will be removed in a future release. " + "Please use granite_switch.tutorials.chroma_loader instead.", + DeprecationWarning, + stacklevel=2, +) + import io import json import os diff --git a/tutorials/notebooks/rag_101.ipynb b/tutorials/notebooks/rag_101.ipynb index 53c3fc3..718c08d 100644 --- a/tutorials/notebooks/rag_101.ipynb +++ b/tutorials/notebooks/rag_101.ipynb @@ -36,31 +36,7 @@ "id": "hf-login", "metadata": {}, "outputs": [], - "source": [ - "import os\n", - "from pathlib import Path\n", - "\n", - "from huggingface_hub import notebook_login\n", - "\n", - "from granite_switch.tutorials.govt_data_loader import load_or_build_govt_chroma\n", - "from granite_switch.tutorials.vllm_server import (\n", - " kill_stale_vllm_processes,\n", - " launch_vllm,\n", - " print_gpu_state,\n", - " tail_log,\n", - " wait_for_server,\n", - ")\n", - "from mellea.backends.openai import OpenAIBackend\n", - "from mellea.stdlib.components import Document as MelleaDocument\n", - "from mellea.stdlib.components.intrinsic import rag\n", - "from mellea.stdlib.context import ChatContext\n", - "\n", - "try:\n", - " from dotenv import load_dotenv\n", - " load_dotenv(Path(\"../.env\"), override=False)\n", - "except ImportError:\n", - " pass" - ] + "source": "import os\nfrom pathlib import Path\n\nfrom huggingface_hub import notebook_login\n\nfrom granite_switch.tutorials.chroma_loader import load_or_build_govt_chroma\nfrom granite_switch.tutorials.vllm_server import (\n kill_stale_vllm_processes,\n launch_vllm,\n print_gpu_state,\n tail_log,\n wait_for_server,\n)\nfrom mellea.backends.openai import OpenAIBackend\nfrom mellea.stdlib.components import Document as MelleaDocument\nfrom mellea.stdlib.components.intrinsic import rag\nfrom mellea.stdlib.context import ChatContext\n\ntry:\n from dotenv import load_dotenv\n load_dotenv(Path(\"../.env\"), override=False)\nexcept ImportError:\n pass" }, { "cell_type": "code", @@ -286,4 +262,4 @@ }, "nbformat": 4, "nbformat_minor": 5 -} +} \ No newline at end of file diff --git a/tutorials/notebooks/rag_flow.ipynb b/tutorials/notebooks/rag_flow.ipynb index b5d0aaf..a22e9b3 100644 --- a/tutorials/notebooks/rag_flow.ipynb +++ b/tutorials/notebooks/rag_flow.ipynb @@ -176,7 +176,7 @@ "id": "12a13b8feceb5539", "metadata": {}, "outputs": [], - "source": "import json\nimport logging\nimport os\nimport warnings\nfrom functools import partial\nfrom pathlib import Path\n\nfrom IPython.display import display, Markdown\nfrom granite_switch.tutorials.govt_data_loader import load_or_build_govt_chroma\nfrom granite_switch.tutorials.rag_display import show_answer, show_history, show_intermediates as _show_intermediates_unbound, _is_clear\nfrom mellea.backends import ModelOption\nfrom mellea.backends.openai import OpenAIBackend\nfrom mellea.stdlib.components import Document as MelleaDocument\nfrom mellea.stdlib.components.chat import Message as MelleaMessage\nfrom mellea.stdlib.components.intrinsic import rag\nfrom mellea.stdlib.components.intrinsic.guardian import guardian_check\nfrom mellea.stdlib.context import ChatContext\nimport mellea.stdlib.functional as mfuncs\n\ntry:\n from dotenv import load_dotenv\n load_dotenv(Path(\"../.env\"), override=False)\nexcept ImportError:\n pass\n\n# ── vLLM server ───────────────────────────────────────────────────────────────\n# URL of the running vLLM OpenAI-compatible endpoint.\nVLLM_BASE_URL = os.environ.get(\"VLLM_BASE_URL\", \"http://localhost:8000/v1\")\n\n# Model name as reported by GET /v1/models (usually the path/repo used at launch).\nVLLM_MODEL_NAME = os.environ.get(\"VLLM_MODEL_NAME\", \"ibm-granite/granite-switch-4.1-3b-preview\")\n\n# HF Hub repo ID (or local path) to load I/O configs for the embedded adapters.\nGRANITE_SWITCH_SOURCE = os.environ.get(\"GRANITE_SWITCH_SOURCE\", VLLM_MODEL_NAME)\n\n# Guardian: which safety criterion to evaluate\nGUARDIAN_CRITERIA = \"harm\" # harm | social_bias | groundedness | jailbreak | ...\n\n# ── Embedding model (used to build + query ChromaDB) ─────────────────────────\nEMBEDDING_MODEL_ID = \"ibm-granite/granite-embedding-small-english-r2\"\n\n# ── ChromaDB persistence path ─────────────────────────────────────────────────\n# Share this directory (zipped) to skip the extraction step entirely.\nCHROMA_PATH = \"./govt_chroma\"\n\n# ── Corpus source (only needed when building the index from scratch) ─────────\n# govt.jsonl: subset of the government-service passages from IBM mt-rag-benchmark.\nGOVT_JSONL_URL = \"https://github.com/IBM/mt-rag-benchmark/raw/main/corpora/passage_level/govt.jsonl.zip\"\nGOVT_JSONL_PATH = \"./govt.jsonl\"\n\n# ── Retrieval ─────────────────────────────────────────────────────────────────\n# TOP_K balances recall (more candidates -> better chance of a relevant passage)\n# against context budget (every doc gets passed through answerability, clarification,\n# generation, and citation prompts). 20 is the mt-rag-benchmark default.\nTOP_K = 10\n\n# Bind TOP_K so query cells can call `show_intermediates(r)` without repeating it.\nshow_intermediates = partial(_show_intermediates_unbound, top_k=TOP_K)\n\nprint(f\"vLLM: {VLLM_BASE_URL} ({VLLM_MODEL_NAME})\")\nprint(f\"Embedding: {EMBEDDING_MODEL_ID}\")\nprint(f\"ChromaDB: {CHROMA_PATH}\")" + "source": "import json\nimport logging\nimport os\nimport warnings\nfrom functools import partial\nfrom pathlib import Path\n\nfrom IPython.display import display, Markdown\nfrom granite_switch.tutorials.chroma_loader import load_or_build_govt_chroma\nfrom granite_switch.tutorials.rag_display import show_answer, show_history, show_intermediates as _show_intermediates_unbound, _is_clear\nfrom mellea.backends import ModelOption\nfrom mellea.backends.openai import OpenAIBackend\nfrom mellea.stdlib.components import Document as MelleaDocument\nfrom mellea.stdlib.components.chat import Message as MelleaMessage\nfrom mellea.stdlib.components.intrinsic import rag\nfrom mellea.stdlib.components.intrinsic.guardian import guardian_check\nfrom mellea.stdlib.context import ChatContext\nimport mellea.stdlib.functional as mfuncs\n\ntry:\n from dotenv import load_dotenv\n load_dotenv(Path(\"../.env\"), override=False)\nexcept ImportError:\n pass\n\n# ── vLLM server ───────────────────────────────────────────────────────────────\n# URL of the running vLLM OpenAI-compatible endpoint.\nVLLM_BASE_URL = os.environ.get(\"VLLM_BASE_URL\", \"http://localhost:8000/v1\")\n\n# Model name as reported by GET /v1/models (usually the path/repo used at launch).\nVLLM_MODEL_NAME = os.environ.get(\"VLLM_MODEL_NAME\", \"ibm-granite/granite-switch-4.1-3b-preview\")\n\n# HF Hub repo ID (or local path) to load I/O configs for the embedded adapters.\nGRANITE_SWITCH_SOURCE = os.environ.get(\"GRANITE_SWITCH_SOURCE\", VLLM_MODEL_NAME)\n\n# Guardian: which safety criterion to evaluate\nGUARDIAN_CRITERIA = \"harm\" # harm | social_bias | groundedness | jailbreak | ...\n\n# ── Embedding model (used to build + query ChromaDB) ─────────────────────────\nEMBEDDING_MODEL_ID = \"ibm-granite/granite-embedding-small-english-r2\"\n\n# ── ChromaDB persistence path ─────────────────────────────────────────────────\n# Share this directory (zipped) to skip the extraction step entirely.\nCHROMA_PATH = \"./govt_chroma\"\n\n# ── Corpus source (only needed when building the index from scratch) ─────────\n# govt.jsonl: subset of the government-service passages from IBM mt-rag-benchmark.\nGOVT_JSONL_URL = \"https://github.com/IBM/mt-rag-benchmark/raw/main/corpora/passage_level/govt.jsonl.zip\"\nGOVT_JSONL_PATH = \"./govt.jsonl\"\n\n# ── Retrieval ─────────────────────────────────────────────────────────────────\n# TOP_K balances recall (more candidates -> better chance of a relevant passage)\n# against context budget (every doc gets passed through answerability, clarification,\n# generation, and citation prompts). 20 is the mt-rag-benchmark default.\nTOP_K = 10\n\n# Bind TOP_K so query cells can call `show_intermediates(r)` without repeating it.\nshow_intermediates = partial(_show_intermediates_unbound, top_k=TOP_K)\n\nprint(f\"vLLM: {VLLM_BASE_URL} ({VLLM_MODEL_NAME})\")\nprint(f\"Embedding: {EMBEDDING_MODEL_ID}\")\nprint(f\"ChromaDB: {CHROMA_PATH}\")" }, { "cell_type": "markdown", @@ -549,4 +549,4 @@ }, "nbformat": 4, "nbformat_minor": 5 -} +} \ No newline at end of file From dc8ab78472a3d40b8e4f6a56c6d5cdd38edbd0cb Mon Sep 17 00:00:00 2001 From: Alon Freund Date: Sun, 24 May 2026 20:25:52 +0300 Subject: [PATCH 02/11] Remove deprecated govt_data_loader.py --- .../tutorials/govt_data_loader.py | 178 ------------------ 1 file changed, 178 deletions(-) delete mode 100644 src/granite_switch/tutorials/govt_data_loader.py diff --git a/src/granite_switch/tutorials/govt_data_loader.py b/src/granite_switch/tutorials/govt_data_loader.py deleted file mode 100644 index 8a71284..0000000 --- a/src/granite_switch/tutorials/govt_data_loader.py +++ /dev/null @@ -1,178 +0,0 @@ -"""Load or build the ChromaDB corpus for the govt RAG tutorial. - -.. deprecated:: 0.2.0 - This module is deprecated and will be removed in a future release. - Use :mod:`granite_switch.tutorials.chroma_loader` instead, which provides: - - - sentence-transformers backend for improved embedding performance - - Configurable max_length (1024 vs hardcoded 512) - - Generic loader supporting multiple data sources - - Maintained API compatibility via load_or_build_govt_chroma() - -Kept separate from the notebook so the pipeline stays focused on RAG concepts. - -First run: downloads `govt.jsonl.zip` from IBM mt-rag-benchmark, -embeds with `ibm-granite/granite-embedding-small-english-r2`, and saves to -`./govt_chroma`. Subsequent runs: loads the persisted index instantly. -""" - -import warnings - -warnings.warn( - "granite_switch.tutorials.govt_data_loader is deprecated and will be removed in a future release. " - "Please use granite_switch.tutorials.chroma_loader instead.", - DeprecationWarning, - stacklevel=2, -) - -import io -import json -import os -import time -import warnings -import zipfile - -import chromadb -import httpx -import torch -from chromadb import Documents, EmbeddingFunction, Embeddings -from tqdm.auto import tqdm -from transformers import AutoModel, AutoTokenizer - -EMBEDDING_MODEL_ID = "ibm-granite/granite-embedding-small-english-r2" -CHROMA_PATH = "./govt_chroma" -GOVT_JSONL_URL = "https://github.com/IBM/mt-rag-benchmark/raw/main/corpora/passage_level/govt.jsonl.zip" -GOVT_JSONL_PATH = "./govt.jsonl" - -TUTORIAL_DOC_IDS = ["05537c9ec2dfe15e-1362-3310", "05537c9ec2dfe15e-2-1779", "05537c9ec2dfe15e-2821-4679", "05537c9ec2dfe15e-4280-6252", "087417ad420d618c-1327-3164", "087417ad420d618c-2428-4297", "087417ad420d618c-3940-5774", "089882437c965a3e-113907-115852", "089882437c965a3e-115237-117256", "089882437c965a3e-119809-121676", "089882437c965a3e-121198-123235", "089882437c965a3e-122746-124833", "089882437c965a3e-130164-131917", "089882437c965a3e-1427-3375", "089882437c965a3e-157219-159194", "089882437c965a3e-158778-160687", "089882437c965a3e-170699-172699", "089882437c965a3e-173726-175992", "089882437c965a3e-175465-177577", "089882437c965a3e-177094-179288", "089882437c965a3e-182078-183322", "089882437c965a3e-184664-186341", "089882437c965a3e-190627-192211", "089882437c965a3e-191792-193455", "089882437c965a3e-194311-196074", "089882437c965a3e-2-1955", "089882437c965a3e-42318-44668", "089882437c965a3e-51633-53566", "089882437c965a3e-53014-54918", "089882437c965a3e-85071-87052", "089882437c965a3e-86622-88344", "0ecab3f697d26347-1362-3129", "142cbdf06f6e40d9-1544-3414", "142cbdf06f6e40d9-2-2014", "142cbdf06f6e40d9-4140-6181", "142cbdf06f6e40d9-5655-7824", "19240942bfc0abf5-11151-13247", "19240942bfc0abf5-1354-3015", "2c89b9fe3cfe95ee-1392-3518", "2ead5535f9d6d3be-1376-3143", "3090260a5d934d78-1166-2578", "3090260a5d934d78-2225-3536", "32472b4a577f296f-2-1847", "353067ac7a68e5f0-2-1815", "3630bbba71396272-1400-3319", "3630bbba71396272-4267-6086", "40ce723b445ac8eb-1350-3146", "40ce723b445ac8eb-2-1781", "40ce723b445ac8eb-3922-5642", "40ce723b445ac8eb-5372-7150", "40ce723b445ac8eb-6691-8678", "40ce723b445ac8eb-8241-9800", "4c201f242ec49883-1381-3148", "4c201f242ec49883-5418-7248", "4e1c120aee9a75b6-1369-3165", "50a24d38902fbdd0-1340-3177", "50a24d38902fbdd0-3953-5813", "565fb21ac38feaa1-15852-17699", "5b86a17591806ce5-1532-3330", "60e02c03620cd1ef-9523-11519", "6ddc73cb3877e2aa-1384-3151", "6ddc73cb3877e2aa-2-1801", "77de29ffa3c3d800-1352-3553", "77de29ffa3c3d800-2-1946", "7fe68ab7967494ca-1358-3306", "81478086b28ab210-5831-7806", "818e03cc80181db4-1346-3469", "818e03cc80181db4-2-1767", "818e03cc80181db4-3125-4727", "824c4c47b2989363-1365-3132", "824c4c47b2989363-2-1782", "82f7a783325de97a-1402-3321", "82f7a783325de97a-4269-6188", "882a9cc2bb08bcdf-2-1811", "8cd62677aa5dcb92-2-1746", "9726fa169575dc43-1331-3168", "9726fa169575dc43-2-1734", "9726fa169575dc43-2432-4301", "9726fa169575dc43-3944-5768", "9726fa169575dc43-5394-7430", "9726fa169575dc43-6967-8603", "97e58e54bb79a7fe-3231-5248", "99c7b4f2bfb48b7f-3321-5534", "a005bd5aedbb28e5-33908-36180", "a005bd5aedbb28e5-35687-37469", "a4a53cb6b6bf326e-1349-3145", "a4a53cb6b6bf326e-2-1780", "a4a53cb6b6bf326e-2409-4294", "a4a53cb6b6bf326e-3921-5691", "a4a53cb6b6bf326e-5362-7156", "a4a53cb6b6bf326e-6689-8701", "a4a53cb6b6bf326e-8201-10002", "a930d03cf0b406fd-23288-25302", "a930d03cf0b406fd-30996-32981", "c550156dbbfe212c-1401-3320", "c550156dbbfe212c-16212-18433", "c550156dbbfe212c-29308-31304", "c550156dbbfe212c-30794-33132", "c550156dbbfe212c-32367-34910", "c550156dbbfe212c-37745-39895", "c550156dbbfe212c-39218-41274", "c550156dbbfe212c-40668-42844", "c550156dbbfe212c-42364-44521", "c550156dbbfe212c-44034-46164", "c550156dbbfe212c-45669-47909", "c550156dbbfe212c-47421-49701", "c550156dbbfe212c-9073-11428", "c67a2f65008344fd-2-1909", "c93223e21ee4ecfb-2-1754", "d4c48e9a4029f3e9-1801-3993", "d4edd2b762f5dce9-7713-9881", "e580ce520db3ff10-109466-111339", "e580ce520db3ff10-119467-121417", "e580ce520db3ff10-124119-126003", "e580ce520db3ff10-129933-131969", "e580ce520db3ff10-131480-133562", "e580ce520db3ff10-190530-192253", "e580ce520db3ff10-191857-193702", "e580ce520db3ff10-35813-37462", "e580ce520db3ff10-36974-38756", "e6ea24fa9e962807-1357-3305", "e6ea24fa9e962807-4275-6126", "ed17e5bd32458f9c-1347-3143", "ed17e5bd32458f9c-3919-5735", "f0b48597d0c22d32-2-1647", "f0b48597d0c22d32-2585-4675", "f0b48597d0c22d32-999-3136", "f14d35fd47c9ed59-1352-3148", "f14d35fd47c9ed59-3924-5795", "f14d35fd47c9ed59-5374-7566", "f7225d77034b8398-1402-3321", "f90bb40d57fe7ba5-1469-3644", "f90bb40d57fe7ba5-2-1890", "f90bb40d57fe7ba5-3142-5127", "f90bb40d57fe7ba5-8968-10553", "fcdc09416b6aa645-1276-2982", "fcdc09416b6aa645-2-1649"] - - -class GraniteEmbeddingFunction(EmbeddingFunction): - """ChromaDB EmbeddingFunction backed by ibm-granite/granite-embedding-*-r2.""" - - def __init__(self, model_id=EMBEDDING_MODEL_ID, batch_size=64, device = None): - if device == None: - device = "cuda" if torch.cuda.is_available() else "cpu" - self._device = device - self._batch = batch_size - self._tokenizer = AutoTokenizer.from_pretrained(model_id) - self._model = AutoModel.from_pretrained(model_id).to(device).eval() - print(f"Granite embedding model ready on {device} ({model_id})") - if device == "cpu": - warnings.warn( - "Embedding of the passages on CPU will take hours. " - "Expected runtime is ~10 min on a single consumer GPU. " - "Consider running on a GPU host, or sharing a pre-built ./govt_chroma directory.", - stacklevel=2, - ) - - def __call__(self, input: Documents) -> Embeddings: - all_embs = [] - for i in range(0, len(input), self._batch): - batch = list(input[i : i + self._batch]) - enc = self._tokenizer( - batch, return_tensors="pt", truncation=True, max_length=512, padding=True - ) - enc = {k: v.to(self._device) for k, v in enc.items()} - with torch.no_grad(): - out = self._model(**enc) - mask = enc["attention_mask"].unsqueeze(-1).float() - emb = (out.last_hidden_state * mask).sum(1) / mask.sum(1).clamp(min=1e-9) - all_embs.extend(emb.cpu().float().tolist()) - return all_embs - - -def load_or_build_govt_chroma( - chroma_path=CHROMA_PATH, - jsonl_path=GOVT_JSONL_PATH, - jsonl_url=GOVT_JSONL_URL, - embedding_model_id=EMBEDDING_MODEL_ID, - load_only_tutorial_docs=False, - device=None, -): - """Return a ready-to-query Chroma collection for the govt corpus. - - Loads from ``chroma_path`` if it already has documents; otherwise downloads - the source jsonl, embeds, and persists. - - When ``load_only_tutorial_docs=True``, embed only docs whose ``_id`` is in - ``TUTORIAL_DOC_IDS`` (the curated subset that the demo queries actually - retrieve). Cuts the passage corpus down dramatically so first-run - embedding takes seconds instead of minutes. - """ - granite_ef = GraniteEmbeddingFunction(model_id=embedding_model_id, device= device) - client = chromadb.PersistentClient(path=chroma_path) - collection = client.get_or_create_collection( - name="govt", - embedding_function=granite_ef, - metadata={"hnsw:space": "cosine"}, - ) - - if collection.count() > 0: - print(f"Loaded from {chroma_path} ({collection.count():,} docs).") - return collection - - if not os.path.exists(jsonl_path): - print(f"Downloading {jsonl_url} ...") - t0 = time.time() - # Stream into memory with a progress bar - the zip is ~50MB and the - # unblocked .get() used to leave users staring at a silent cell for minutes. - # Split timeout: fail fast on connect (10s), allow slow reads (300s). - timeout = httpx.Timeout(300.0, connect=10.0) - buf = io.BytesIO() - with httpx.Client(follow_redirects=True, timeout=timeout) as c: - with c.stream("GET", jsonl_url) as resp: - resp.raise_for_status() - total = int(resp.headers.get("Content-Length", 0)) or None - with tqdm(total=total, unit="B", unit_scale=True, desc="download") as bar: - for chunk in resp.iter_bytes(chunk_size=65536): - buf.write(chunk) - bar.update(len(chunk)) - buf.seek(0) - # Atomic write: extract to a .tmp path then os.replace, so a kill/crash - # mid-write can't leave a truncated jsonl that later runs silently use. - tmp_path = jsonl_path + ".tmp" - with zipfile.ZipFile(buf) as zf: - inner = next(n for n in zf.namelist() if n.endswith(".jsonl")) - with zf.open(inner) as src, open(tmp_path, "wb") as dst: - dst.write(src.read()) - os.replace(tmp_path, jsonl_path) - print(f"Saved {jsonl_path} in {time.time() - t0:.1f}s.") - - keep_ids = set(TUTORIAL_DOC_IDS) if load_only_tutorial_docs else None - if keep_ids is not None: - print(f"Filtering to {len(keep_ids)} tutorial doc ids") - - print(f"Reading {jsonl_path} -> {chroma_path}...") - t0 = time.time() - ids, texts, metas = [], [], [] - with open(jsonl_path) as f: - for line in f: - doc = json.loads(line) - text = doc.get("text", "").strip() - if not text: - continue - doc_id = doc.get("_id", doc.get("id", str(len(ids)))) - if keep_ids is not None and doc_id not in keep_ids: - continue - ids.append(doc_id) - texts.append(text) - metas.append({"title": doc.get("title", ""), "url": doc.get("url", "")}) - if not ids: - raise RuntimeError( - f"{jsonl_path} yielded zero documents - the file may be empty, truncated, " - f"or schema-drifted (expected a 'text' field per line). Delete it and rerun " - f"to re-download." - ) - print(f"Read {len(ids):,} docs in {time.time() - t0:.1f}s. Embedding & indexing...") - - t1 = time.time() - batch = 16 if granite_ef._device == "cpu" else 500 - for i in tqdm(range(0, len(ids), batch), unit="batch", desc="indexing"): - collection.upsert( - ids = ids [i : i + batch], - documents = texts[i : i + batch], - metadatas = metas[i : i + batch], - ) - print(f"Done. {collection.count():,} docs saved to {chroma_path} in {time.time() - t1:.1f}s.") - return collection From 24e2d1cd0b5916c775ecf7fcba5e6b5bd97bd73d Mon Sep 17 00:00:00 2001 From: Alon Freund Date: Sun, 24 May 2026 22:06:04 +0300 Subject: [PATCH 03/11] Fix batch_size handling for sentence-transformers Fixes TypeError when batch_size is None by omitting the parameter to let sentence-transformers use auto-tuning. --- src/granite_switch/tutorials/chroma_loader.py | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/src/granite_switch/tutorials/chroma_loader.py b/src/granite_switch/tutorials/chroma_loader.py index d4e54ec..0d38e39 100644 --- a/src/granite_switch/tutorials/chroma_loader.py +++ b/src/granite_switch/tutorials/chroma_loader.py @@ -140,13 +140,16 @@ def __init__( def __call__(self, input: Documents) -> Embeddings: """Embed texts with batching and progress bar.""" - embeddings = self.model.encode( - input, - batch_size=self.batch_size, - max_seq_length=self.max_length, - show_progress_bar=False, # Disable internal progress (we track at doc level) - convert_to_numpy=True, - ) + # Build encode kwargs, omitting batch_size if None (let library auto-tune) + encode_kwargs = { + "max_seq_length": self.max_length, + "show_progress_bar": False, # Disable internal progress (we track at doc level) + "convert_to_numpy": True, + } + if self.batch_size is not None: + encode_kwargs["batch_size"] = self.batch_size + + embeddings = self.model.encode(input, **encode_kwargs) return embeddings.tolist() From 90913ebf753738b6e109292ad9d3e0978560bbc2 Mon Sep 17 00:00:00 2001 From: Alon Freund Date: Sun, 24 May 2026 23:13:05 +0300 Subject: [PATCH 04/11] Fix max_seq_length configuration for sentence-transformers Set max_seq_length on the model itself instead of passing it to encode(). Newer versions of sentence-transformers don't accept max_seq_length as an encode() parameter. This fixes the ValueError that occurred when calling encode() with max_seq_length in kwargs. The max_seq_length is now set as a model attribute after instantiation, which is the correct approach for configuring maximum sequence length in sentence-transformers. --- src/granite_switch/tutorials/chroma_loader.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/granite_switch/tutorials/chroma_loader.py b/src/granite_switch/tutorials/chroma_loader.py index 0d38e39..e2263e1 100644 --- a/src/granite_switch/tutorials/chroma_loader.py +++ b/src/granite_switch/tutorials/chroma_loader.py @@ -136,13 +136,17 @@ def __init__( # Eager loading: model loaded immediately (clear upfront waiting) self.model = SentenceTransformer(model_id, device=device) + + # Set max_seq_length on the model itself (not in encode() kwargs) + self.model.max_seq_length = max_length + print(f"Granite embedding model ready on {device} ({model_id})") def __call__(self, input: Documents) -> Embeddings: """Embed texts with batching and progress bar.""" # Build encode kwargs, omitting batch_size if None (let library auto-tune) + # Note: max_seq_length is set on the model itself in __init__, not here encode_kwargs = { - "max_seq_length": self.max_length, "show_progress_bar": False, # Disable internal progress (we track at doc level) "convert_to_numpy": True, } From f6baada50e4b90a84f26d587b87e7c1dd72ec38b Mon Sep 17 00:00:00 2001 From: Alon Freund Date: Mon, 25 May 2026 14:36:24 +0300 Subject: [PATCH 05/11] Tutorial improvements: vLLM GPU defaults and notebook reordering MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - launch_vllm now defaults extra_args to GPU-friendly flags (--gpu-memory-utilization 0.85, --max-num-seqs 16, --enforce-eager) - Both RAG notebooks: reorder sections so corpus is built before launching vLLM (config → build corpus → launch vLLM) - rag_101: add dedicated HF login cell after install, matching rag_flow structure --- src/granite_switch/tutorials/vllm_server.py | 9 +- tutorials/notebooks/rag_101.ipynb | 722 ++++++++++++++++++-- tutorials/notebooks/rag_flow.ipynb | 110 ++- 3 files changed, 739 insertions(+), 102 deletions(-) diff --git a/src/granite_switch/tutorials/vllm_server.py b/src/granite_switch/tutorials/vllm_server.py index 51fafd4..f971455 100644 --- a/src/granite_switch/tutorials/vllm_server.py +++ b/src/granite_switch/tutorials/vllm_server.py @@ -18,7 +18,11 @@ def launch_vllm( model: str, port: int, log_file: str, - extra_args: Sequence[str] | None = None, + extra_args: Sequence[str] = ( + "--gpu-memory-utilization", "0.85", + "--max-num-seqs", "16", + "--enforce-eager", + ), max_model_len: int = DEFAULT_MAX_MODEL_LEN, ) -> subprocess.Popen: cmd = [ @@ -31,9 +35,8 @@ def launch_vllm( str(port), "--max-model-len", str(max_model_len), + *extra_args, ] - if extra_args: - cmd += extra_args with open(log_file, "w") as log_handle: proc = subprocess.Popen(cmd, stdout=log_handle, stderr=subprocess.STDOUT) diff --git a/tutorials/notebooks/rag_101.ipynb b/tutorials/notebooks/rag_101.ipynb index 718c08d..e52170b 100644 --- a/tutorials/notebooks/rag_101.ipynb +++ b/tutorials/notebooks/rag_101.ipynb @@ -32,54 +32,676 @@ }, { "cell_type": "code", - "execution_count": null, - "id": "hf-login", + "id": "hf-login-call", "metadata": {}, "outputs": [], - "source": "import os\nfrom pathlib import Path\n\nfrom huggingface_hub import notebook_login\n\nfrom granite_switch.tutorials.chroma_loader import load_or_build_govt_chroma\nfrom granite_switch.tutorials.vllm_server import (\n kill_stale_vllm_processes,\n launch_vllm,\n print_gpu_state,\n tail_log,\n wait_for_server,\n)\nfrom mellea.backends.openai import OpenAIBackend\nfrom mellea.stdlib.components import Document as MelleaDocument\nfrom mellea.stdlib.components.intrinsic import rag\nfrom mellea.stdlib.context import ChatContext\n\ntry:\n from dotenv import load_dotenv\n load_dotenv(Path(\"../.env\"), override=False)\nexcept ImportError:\n pass" + "source": [ + "from huggingface_hub import notebook_login\n", + "notebook_login() # needed to pull ibm-granite models from the Hub" + ], + "execution_count": null }, { "cell_type": "code", "execution_count": null, - "id": "vllm-helper", + "id": "hf-login", "metadata": {}, "outputs": [], "source": [ - "notebook_login() # needed to pull ibm-granite models from the Hub\n", + "i", + "m", + "p", + "o", + "r", + "t", + " ", + "o", + "s", "\n", - "kill_stale_vllm_processes()\n", - "print_gpu_state()" - ] - }, - { - "cell_type": "markdown", - "id": "launch-md", - "metadata": {}, - "source": [ - "## 1 · Launch vLLM server\n", + "f", + "r", + "o", + "m", + " ", + "p", + "a", + "t", + "h", + "l", + "i", + "b", + " ", + "i", + "m", + "p", + "o", + "r", + "t", + " ", + "P", + "a", + "t", + "h", + "\n", + "\n", + "f", + "r", + "o", + "m", + " ", + "h", + "u", + "g", + "g", + "i", + "n", + "g", + "f", + "a", + "c", + "e", + "_", + "h", + "u", + "b", + " ", + "i", + "m", + "p", + "o", + "r", + "t", + " ", + "n", + "o", + "t", + "e", + "b", + "o", + "o", + "k", + "_", + "l", + "o", + "g", + "i", + "n", + "\n", + "\n", + "f", + "r", + "o", + "m", + " ", + "g", + "r", + "a", + "n", + "i", + "t", + "e", + "_", + "s", + "w", + "i", + "t", + "c", + "h", + ".", + "t", + "u", + "t", + "o", + "r", + "i", + "a", + "l", + "s", + ".", + "c", + "h", + "r", + "o", + "m", + "a", + "_", + "l", + "o", + "a", + "d", + "e", + "r", + " ", + "i", + "m", + "p", + "o", + "r", + "t", + " ", + "l", + "o", + "a", + "d", + "_", + "o", + "r", + "_", + "b", + "u", + "i", + "l", + "d", + "_", + "g", + "o", + "v", + "t", + "_", + "c", + "h", + "r", + "o", + "m", + "a", + "\n", + "f", + "r", + "o", + "m", + " ", + "g", + "r", + "a", + "n", + "i", + "t", + "e", + "_", + "s", + "w", + "i", + "t", + "c", + "h", + ".", + "t", + "u", + "t", + "o", + "r", + "i", + "a", + "l", + "s", + ".", + "v", + "l", + "l", + "m", + "_", + "s", + "e", + "r", + "v", + "e", + "r", + " ", + "i", + "m", + "p", + "o", + "r", + "t", + " ", + "(", + "\n", + " ", + " ", + " ", + " ", + "k", + "i", + "l", + "l", + "_", + "s", + "t", + "a", + "l", + "e", + "_", + "v", + "l", + "l", + "m", + "_", + "p", + "r", + "o", + "c", + "e", + "s", + "s", + "e", + "s", + ",", "\n", - "Start the Granite Switch model on port 8000. The server runs in the background; `wait_for_server` polls `/health` until it's ready." + " ", + " ", + " ", + " ", + "l", + "a", + "u", + "n", + "c", + "h", + "_", + "v", + "l", + "l", + "m", + ",", + "\n", + " ", + " ", + " ", + " ", + "p", + "r", + "i", + "n", + "t", + "_", + "g", + "p", + "u", + "_", + "s", + "t", + "a", + "t", + "e", + ",", + "\n", + " ", + " ", + " ", + " ", + "t", + "a", + "i", + "l", + "_", + "l", + "o", + "g", + ",", + "\n", + " ", + " ", + " ", + " ", + "w", + "a", + "i", + "t", + "_", + "f", + "o", + "r", + "_", + "s", + "e", + "r", + "v", + "e", + "r", + ",", + "\n", + ")", + "\n", + "f", + "r", + "o", + "m", + " ", + "m", + "e", + "l", + "l", + "e", + "a", + ".", + "b", + "a", + "c", + "k", + "e", + "n", + "d", + "s", + ".", + "o", + "p", + "e", + "n", + "a", + "i", + " ", + "i", + "m", + "p", + "o", + "r", + "t", + " ", + "O", + "p", + "e", + "n", + "A", + "I", + "B", + "a", + "c", + "k", + "e", + "n", + "d", + "\n", + "f", + "r", + "o", + "m", + " ", + "m", + "e", + "l", + "l", + "e", + "a", + ".", + "s", + "t", + "d", + "l", + "i", + "b", + ".", + "c", + "o", + "m", + "p", + "o", + "n", + "e", + "n", + "t", + "s", + " ", + "i", + "m", + "p", + "o", + "r", + "t", + " ", + "D", + "o", + "c", + "u", + "m", + "e", + "n", + "t", + " ", + "a", + "s", + " ", + "M", + "e", + "l", + "l", + "e", + "a", + "D", + "o", + "c", + "u", + "m", + "e", + "n", + "t", + "\n", + "f", + "r", + "o", + "m", + " ", + "m", + "e", + "l", + "l", + "e", + "a", + ".", + "s", + "t", + "d", + "l", + "i", + "b", + ".", + "c", + "o", + "m", + "p", + "o", + "n", + "e", + "n", + "t", + "s", + ".", + "i", + "n", + "t", + "r", + "i", + "n", + "s", + "i", + "c", + " ", + "i", + "m", + "p", + "o", + "r", + "t", + " ", + "r", + "a", + "g", + "\n", + "f", + "r", + "o", + "m", + " ", + "m", + "e", + "l", + "l", + "e", + "a", + ".", + "s", + "t", + "d", + "l", + "i", + "b", + ".", + "c", + "o", + "n", + "t", + "e", + "x", + "t", + " ", + "i", + "m", + "p", + "o", + "r", + "t", + " ", + "C", + "h", + "a", + "t", + "C", + "o", + "n", + "t", + "e", + "x", + "t", + "\n", + "\n", + "t", + "r", + "y", + ":", + "\n", + " ", + " ", + " ", + " ", + "f", + "r", + "o", + "m", + " ", + "d", + "o", + "t", + "e", + "n", + "v", + " ", + "i", + "m", + "p", + "o", + "r", + "t", + " ", + "l", + "o", + "a", + "d", + "_", + "d", + "o", + "t", + "e", + "n", + "v", + "\n", + " ", + " ", + " ", + " ", + "l", + "o", + "a", + "d", + "_", + "d", + "o", + "t", + "e", + "n", + "v", + "(", + "P", + "a", + "t", + "h", + "(", + "\"", + ".", + ".", + "/", + ".", + "e", + "n", + "v", + "\"", + ")", + ",", + " ", + "o", + "v", + "e", + "r", + "r", + "i", + "d", + "e", + "=", + "F", + "a", + "l", + "s", + "e", + ")", + "\n", + "e", + "x", + "c", + "e", + "p", + "t", + " ", + "I", + "m", + "p", + "o", + "r", + "t", + "E", + "r", + "r", + "o", + "r", + ":", + "\n", + " ", + " ", + " ", + " ", + "p", + "a", + "s", + "s" ] }, { "cell_type": "code", "execution_count": null, - "id": "launch", + "id": "vllm-helper", "metadata": {}, "outputs": [], "source": [ - "VLLM_MODEL = \"ibm-granite/granite-switch-4.1-3b-preview\"\n", - "VLLM_PORT = 8000\n", - "MAX_MODEL_LEN = 10240 # 10k, fits comfortably on an T4 GPU.\n", "\n", - "vllm_proc = launch_vllm(\n", - " model=VLLM_MODEL,\n", - " port=VLLM_PORT,\n", - " max_model_len=MAX_MODEL_LEN,\n", - " log_file=\"/content/vllm_server.log\",\n", - ")\n", - "if not wait_for_server(VLLM_PORT):\n", - " tail_log(\"/content/vllm_server.log\")" + "kill_stale_vllm_processes()\n", + "print_gpu_state()" ] }, { @@ -87,8 +709,7 @@ "id": "6863316a3dcb98b2", "metadata": {}, "source": [ - "## 2 · Configuration\n", - "Endpoints, model IDs, and corpus paths. Every value falls back to a sensible default, so the cell runs as-is if your vLLM server is on `localhost:8000`." + "## 1 · Configuration\nEndpoints, model IDs, and corpus paths. Every value falls back to a sensible default, so the cell runs as-is if your vLLM server is on `localhost:8000`." ] }, { @@ -130,15 +751,7 @@ "id": "corpus-md", "metadata": {}, "source": [ - "## 3 · Build or load the vector corpus\n", - "\n", - "`load_or_build_govt_chroma` is the corpus half of RAG, packaged so this notebook stays focused on retrieval and answerability:\n", - "\n", - "1. Downloads `govt.jsonl.zip` (~50 MB, 49k government-service passages from [IBM mt-rag-benchmark](https://github.com/IBM/mt-rag-benchmark)) on first run.\n", - "2. Embeds each passage with `ibm-granite/granite-embedding-small-english-r2`.\n", - "3. Persists the index to `./govt_chroma` so subsequent runs load instantly.\n", - "\n", - "> **Note:** to keep the tutorial fast, we filter most non-related docs and embed only the curated subset that the demo queries actually retrieve. For a full corpus load, set `load_only_tutorial_docs=False` in the call below." + "## 2 · Build or load the vector corpus\n\n`load_or_build_govt_chroma` is the corpus half of RAG, packaged so this notebook stays focused on retrieval and answerability:\n\n1. Downloads `govt.jsonl.zip` (~50 MB, 49k government-service passages from [IBM mt-rag-benchmark](https://github.com/IBM/mt-rag-benchmark)) on first run.\n2. Embeds each passage with `ibm-granite/granite-embedding-small-english-r2`.\n3. Persists the index to `./govt_chroma` so subsequent runs load instantly.\n\n> **Note:** to keep the tutorial fast, we filter most non-related docs and embed only the curated subset that the demo queries actually retrieve. For a full corpus load, set `load_only_tutorial_docs=False` in the call below." ] }, { @@ -159,6 +772,35 @@ "print(f\"Corpus ready — {chroma_collection.count():,} passages indexed.\")" ] }, + { + "cell_type": "markdown", + "id": "launch-md", + "metadata": {}, + "source": [ + "## 3 · Launch vLLM server\n\nStart the Granite Switch model on port 8000. The server runs in the background; `wait_for_server` polls `/health` until it's ready." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "launch", + "metadata": {}, + "outputs": [], + "source": [ + "VLLM_MODEL = \"ibm-granite/granite-switch-4.1-3b-preview\"\n", + "VLLM_PORT = 8000\n", + "MAX_MODEL_LEN = 10240 # 10k, fits comfortably on an T4 GPU.\n", + "\n", + "vllm_proc = launch_vllm(\n", + " model=VLLM_MODEL,\n", + " port=VLLM_PORT,\n", + " max_model_len=MAX_MODEL_LEN,\n", + " log_file=\"/content/vllm_server.log\",\n", + ")\n", + "if not wait_for_server(VLLM_PORT):\n", + " tail_log(\"/content/vllm_server.log\")" + ] + }, { "cell_type": "markdown", "id": "backend-md", @@ -262,4 +904,4 @@ }, "nbformat": 4, "nbformat_minor": 5 -} \ No newline at end of file +} diff --git a/tutorials/notebooks/rag_flow.ipynb b/tutorials/notebooks/rag_flow.ipynb index a22e9b3..7a2da19 100644 --- a/tutorials/notebooks/rag_flow.ipynb +++ b/tutorials/notebooks/rag_flow.ipynb @@ -76,12 +76,59 @@ }, { "cell_type": "markdown", - "id": "5b8c0be1ec4cc837", + "id": "b582e2627baf73e6", "metadata": {}, "source": [ - "## 1 · Launch vLLM server\n", + "## 1 · Configuration\nEndpoints, model IDs, and corpus paths. Every value falls back to a sensible default, so the cell runs as-is if your vLLM server is on `localhost:8000`." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "12a13b8feceb5539", + "metadata": {}, + "outputs": [], + "source": "import json\nimport logging\nimport os\nimport warnings\nfrom functools import partial\nfrom pathlib import Path\n\nfrom IPython.display import display, Markdown\nfrom granite_switch.tutorials.chroma_loader import load_or_build_govt_chroma\nfrom granite_switch.tutorials.rag_display import show_answer, show_history, show_intermediates as _show_intermediates_unbound, _is_clear\nfrom mellea.backends import ModelOption\nfrom mellea.backends.openai import OpenAIBackend\nfrom mellea.stdlib.components import Document as MelleaDocument\nfrom mellea.stdlib.components.chat import Message as MelleaMessage\nfrom mellea.stdlib.components.intrinsic import rag\nfrom mellea.stdlib.components.intrinsic.guardian import guardian_check\nfrom mellea.stdlib.context import ChatContext\nimport mellea.stdlib.functional as mfuncs\n\ntry:\n from dotenv import load_dotenv\n load_dotenv(Path(\"../.env\"), override=False)\nexcept ImportError:\n pass\n\n# ── vLLM server ───────────────────────────────────────────────────────────────\n# URL of the running vLLM OpenAI-compatible endpoint.\nVLLM_BASE_URL = os.environ.get(\"VLLM_BASE_URL\", \"http://localhost:8000/v1\")\n\n# Model name as reported by GET /v1/models (usually the path/repo used at launch).\nVLLM_MODEL_NAME = os.environ.get(\"VLLM_MODEL_NAME\", \"ibm-granite/granite-switch-4.1-3b-preview\")\n\n# HF Hub repo ID (or local path) to load I/O configs for the embedded adapters.\nGRANITE_SWITCH_SOURCE = os.environ.get(\"GRANITE_SWITCH_SOURCE\", VLLM_MODEL_NAME)\n\n# Guardian: which safety criterion to evaluate\nGUARDIAN_CRITERIA = \"harm\" # harm | social_bias | groundedness | jailbreak | ...\n\n# ── Embedding model (used to build + query ChromaDB) ─────────────────────────\nEMBEDDING_MODEL_ID = \"ibm-granite/granite-embedding-small-english-r2\"\n\n# ── ChromaDB persistence path ─────────────────────────────────────────────────\n# Share this directory (zipped) to skip the extraction step entirely.\nCHROMA_PATH = \"./govt_chroma\"\n\n# ── Corpus source (only needed when building the index from scratch) ─────────\n# govt.jsonl: subset of the government-service passages from IBM mt-rag-benchmark.\nGOVT_JSONL_URL = \"https://github.com/IBM/mt-rag-benchmark/raw/main/corpora/passage_level/govt.jsonl.zip\"\nGOVT_JSONL_PATH = \"./govt.jsonl\"\n\n# ── Retrieval ─────────────────────────────────────────────────────────────────\n# TOP_K balances recall (more candidates -> better chance of a relevant passage)\n# against context budget (every doc gets passed through answerability, clarification,\n# generation, and citation prompts). 20 is the mt-rag-benchmark default.\nTOP_K = 10\n\n# Bind TOP_K so query cells can call `show_intermediates(r)` without repeating it.\nshow_intermediates = partial(_show_intermediates_unbound, top_k=TOP_K)\n\nprint(f\"vLLM: {VLLM_BASE_URL} ({VLLM_MODEL_NAME})\")\nprint(f\"Embedding: {EMBEDDING_MODEL_ID}\")\nprint(f\"ChromaDB: {CHROMA_PATH}\")" + }, + { + "cell_type": "markdown", + "id": "8b7abdb691b97e05", + "metadata": {}, + "source": [ + "## 2 · Build or load vector corpus\nData prep is delegated to `scripts/utils/govt_data_loader.py` to keep this notebook focused on the RAG flow.\n\n**First run:** downloads ~50 MB and embeds the corpus passages. **Subsequent runs:** load the persisted index instantly.\n\n> **Note:** to keep the tutorial fast, we filter most non-related docs and embed only the curated subset that the demo queries actually retrieve. For a full corpus load, set `load_only_tutorial_docs=False` in the call below." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "93f4f938190f79ff", + "metadata": {}, + "outputs": [], + "source": [ + "# Load or build the ChromaDB corpus.\n", + "# First run: downloads govt.jsonl.zip from IBM mt-rag-benchmark (subset of the passages),\n", + "# embeds with `ibm-granite/granite-embedding-small-english-r2` into ./govt_chroma.\n", + "# Subsequent runs: loads ./govt_chroma instantly.\n", + "#\n", + "# `load_only_tutorial_docs=True` restricts embedding to the curated subset\n", + "# the demo queries actually retrieve. Set False to embed the full corpus.\n", "\n", - "Start the Granite Switch model on port 8000. The server runs in the background; `wait_for_server` polls `/health` until it's ready." + "chroma_collection = load_or_build_govt_chroma(\n", + " chroma_path = CHROMA_PATH,\n", + " jsonl_path = GOVT_JSONL_PATH,\n", + " jsonl_url = GOVT_JSONL_URL,\n", + " embedding_model_id = EMBEDDING_MODEL_ID,\n", + " device = \"cpu\",\n", + " load_only_tutorial_docs = True,\n", + ")\n" + ] + }, + { + "cell_type": "markdown", + "id": "5b8c0be1ec4cc837", + "metadata": {}, + "source": [ + "## 3 · Launch vLLM server\n\nStart the Granite Switch model on port 8000. The server runs in the background; `wait_for_server` polls `/health` until it's ready." ] }, { @@ -161,61 +208,6 @@ "mermaid(mermaid_diagram)" ] }, - { - "cell_type": "markdown", - "id": "b582e2627baf73e6", - "metadata": {}, - "source": [ - "## 2 · Configuration\n", - "Endpoints, model IDs, and corpus paths. Every value falls back to a sensible default, so the cell runs as-is if your vLLM server is on `localhost:8000`." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "12a13b8feceb5539", - "metadata": {}, - "outputs": [], - "source": "import json\nimport logging\nimport os\nimport warnings\nfrom functools import partial\nfrom pathlib import Path\n\nfrom IPython.display import display, Markdown\nfrom granite_switch.tutorials.chroma_loader import load_or_build_govt_chroma\nfrom granite_switch.tutorials.rag_display import show_answer, show_history, show_intermediates as _show_intermediates_unbound, _is_clear\nfrom mellea.backends import ModelOption\nfrom mellea.backends.openai import OpenAIBackend\nfrom mellea.stdlib.components import Document as MelleaDocument\nfrom mellea.stdlib.components.chat import Message as MelleaMessage\nfrom mellea.stdlib.components.intrinsic import rag\nfrom mellea.stdlib.components.intrinsic.guardian import guardian_check\nfrom mellea.stdlib.context import ChatContext\nimport mellea.stdlib.functional as mfuncs\n\ntry:\n from dotenv import load_dotenv\n load_dotenv(Path(\"../.env\"), override=False)\nexcept ImportError:\n pass\n\n# ── vLLM server ───────────────────────────────────────────────────────────────\n# URL of the running vLLM OpenAI-compatible endpoint.\nVLLM_BASE_URL = os.environ.get(\"VLLM_BASE_URL\", \"http://localhost:8000/v1\")\n\n# Model name as reported by GET /v1/models (usually the path/repo used at launch).\nVLLM_MODEL_NAME = os.environ.get(\"VLLM_MODEL_NAME\", \"ibm-granite/granite-switch-4.1-3b-preview\")\n\n# HF Hub repo ID (or local path) to load I/O configs for the embedded adapters.\nGRANITE_SWITCH_SOURCE = os.environ.get(\"GRANITE_SWITCH_SOURCE\", VLLM_MODEL_NAME)\n\n# Guardian: which safety criterion to evaluate\nGUARDIAN_CRITERIA = \"harm\" # harm | social_bias | groundedness | jailbreak | ...\n\n# ── Embedding model (used to build + query ChromaDB) ─────────────────────────\nEMBEDDING_MODEL_ID = \"ibm-granite/granite-embedding-small-english-r2\"\n\n# ── ChromaDB persistence path ─────────────────────────────────────────────────\n# Share this directory (zipped) to skip the extraction step entirely.\nCHROMA_PATH = \"./govt_chroma\"\n\n# ── Corpus source (only needed when building the index from scratch) ─────────\n# govt.jsonl: subset of the government-service passages from IBM mt-rag-benchmark.\nGOVT_JSONL_URL = \"https://github.com/IBM/mt-rag-benchmark/raw/main/corpora/passage_level/govt.jsonl.zip\"\nGOVT_JSONL_PATH = \"./govt.jsonl\"\n\n# ── Retrieval ─────────────────────────────────────────────────────────────────\n# TOP_K balances recall (more candidates -> better chance of a relevant passage)\n# against context budget (every doc gets passed through answerability, clarification,\n# generation, and citation prompts). 20 is the mt-rag-benchmark default.\nTOP_K = 10\n\n# Bind TOP_K so query cells can call `show_intermediates(r)` without repeating it.\nshow_intermediates = partial(_show_intermediates_unbound, top_k=TOP_K)\n\nprint(f\"vLLM: {VLLM_BASE_URL} ({VLLM_MODEL_NAME})\")\nprint(f\"Embedding: {EMBEDDING_MODEL_ID}\")\nprint(f\"ChromaDB: {CHROMA_PATH}\")" - }, - { - "cell_type": "markdown", - "id": "8b7abdb691b97e05", - "metadata": {}, - "source": [ - "## 3 · Build or load vector corpus\n", - "Data prep is delegated to `scripts/utils/govt_data_loader.py` to keep this notebook focused on the RAG flow.\n", - "\n", - "**First run:** downloads ~50 MB and embeds the corpus passages. **Subsequent runs:** load the persisted index instantly.\n", - "\n", - "> **Note:** to keep the tutorial fast, we filter most non-related docs and embed only the curated subset that the demo queries actually retrieve. For a full corpus load, set `load_only_tutorial_docs=False` in the call below." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "93f4f938190f79ff", - "metadata": {}, - "outputs": [], - "source": [ - "# Load or build the ChromaDB corpus.\n", - "# First run: downloads govt.jsonl.zip from IBM mt-rag-benchmark (subset of the passages),\n", - "# embeds with `ibm-granite/granite-embedding-small-english-r2` into ./govt_chroma.\n", - "# Subsequent runs: loads ./govt_chroma instantly.\n", - "#\n", - "# `load_only_tutorial_docs=True` restricts embedding to the curated subset\n", - "# the demo queries actually retrieve. Set False to embed the full corpus.\n", - "\n", - "chroma_collection = load_or_build_govt_chroma(\n", - " chroma_path = CHROMA_PATH,\n", - " jsonl_path = GOVT_JSONL_PATH,\n", - " jsonl_url = GOVT_JSONL_URL,\n", - " embedding_model_id = EMBEDDING_MODEL_ID,\n", - " device = \"cpu\",\n", - " load_only_tutorial_docs = True,\n", - ")\n" - ] - }, { "cell_type": "markdown", "id": "a7864f2b9e9d11b2", @@ -549,4 +541,4 @@ }, "nbformat": 4, "nbformat_minor": 5 -} \ No newline at end of file +} From 37cbc99e3f2e50889a11acbb8aca006e28b29b65 Mon Sep 17 00:00:00 2001 From: Alon Freund Date: Mon, 25 May 2026 14:47:13 +0300 Subject: [PATCH 06/11] Make launch_vllm extra_args additive on top of GPU defaults --- src/granite_switch/tutorials/vllm_server.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/src/granite_switch/tutorials/vllm_server.py b/src/granite_switch/tutorials/vllm_server.py index f971455..3851b11 100644 --- a/src/granite_switch/tutorials/vllm_server.py +++ b/src/granite_switch/tutorials/vllm_server.py @@ -14,15 +14,18 @@ DEFAULT_MAX_MODEL_LEN = 32768 # 32k, fits comfortably on an A100 (40/80 GiB). +_DEFAULT_VLLM_ARGS = ( + "--gpu-memory-utilization", "0.85", + "--max-num-seqs", "16", + "--enforce-eager", +) + + def launch_vllm( model: str, port: int, log_file: str, - extra_args: Sequence[str] = ( - "--gpu-memory-utilization", "0.85", - "--max-num-seqs", "16", - "--enforce-eager", - ), + extra_args: Sequence[str] = (), max_model_len: int = DEFAULT_MAX_MODEL_LEN, ) -> subprocess.Popen: cmd = [ @@ -35,6 +38,7 @@ def launch_vllm( str(port), "--max-model-len", str(max_model_len), + *_DEFAULT_VLLM_ARGS, *extra_args, ] From 1a49795784ceec9a42653eec43754755fcfbaa75 Mon Sep 17 00:00:00 2001 From: Alon Freund Date: Mon, 25 May 2026 14:52:58 +0300 Subject: [PATCH 07/11] Expose GPU flags as named parameters in launch_vllm --- src/granite_switch/tutorials/vllm_server.py | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/src/granite_switch/tutorials/vllm_server.py b/src/granite_switch/tutorials/vllm_server.py index 3851b11..c807c82 100644 --- a/src/granite_switch/tutorials/vllm_server.py +++ b/src/granite_switch/tutorials/vllm_server.py @@ -14,17 +14,13 @@ DEFAULT_MAX_MODEL_LEN = 32768 # 32k, fits comfortably on an A100 (40/80 GiB). -_DEFAULT_VLLM_ARGS = ( - "--gpu-memory-utilization", "0.85", - "--max-num-seqs", "16", - "--enforce-eager", -) - - def launch_vllm( model: str, port: int, log_file: str, + gpu_memory_utilization: float = 0.85, + max_num_seqs: int = 16, + enforce_eager: bool = True, extra_args: Sequence[str] = (), max_model_len: int = DEFAULT_MAX_MODEL_LEN, ) -> subprocess.Popen: @@ -38,7 +34,9 @@ def launch_vllm( str(port), "--max-model-len", str(max_model_len), - *_DEFAULT_VLLM_ARGS, + "--gpu-memory-utilization", str(gpu_memory_utilization), + "--max-num-seqs", str(max_num_seqs), + *( ["--enforce-eager"] if enforce_eager else []), *extra_args, ] From e3eb5e8d836abaa0224aa1ee7f876fff8c40dc42 Mon Sep 17 00:00:00 2001 From: Alon Freund Date: Mon, 25 May 2026 14:59:24 +0300 Subject: [PATCH 08/11] Fix rag_101: remove duplicate notebook_login import, fix blank line in setup cell --- tutorials/notebooks/rag_101.ipynb | 664 +----------------------------- 1 file changed, 21 insertions(+), 643 deletions(-) diff --git a/tutorials/notebooks/rag_101.ipynb b/tutorials/notebooks/rag_101.ipynb index e52170b..a566513 100644 --- a/tutorials/notebooks/rag_101.ipynb +++ b/tutorials/notebooks/rag_101.ipynb @@ -48,648 +48,27 @@ "metadata": {}, "outputs": [], "source": [ - "i", - "m", - "p", - "o", - "r", - "t", - " ", - "o", - "s", - "\n", - "f", - "r", - "o", - "m", - " ", - "p", - "a", - "t", - "h", - "l", - "i", - "b", - " ", - "i", - "m", - "p", - "o", - "r", - "t", - " ", - "P", - "a", - "t", - "h", - "\n", - "\n", - "f", - "r", - "o", - "m", - " ", - "h", - "u", - "g", - "g", - "i", - "n", - "g", - "f", - "a", - "c", - "e", - "_", - "h", - "u", - "b", - " ", - "i", - "m", - "p", - "o", - "r", - "t", - " ", - "n", - "o", - "t", - "e", - "b", - "o", - "o", - "k", - "_", - "l", - "o", - "g", - "i", - "n", - "\n", - "\n", - "f", - "r", - "o", - "m", - " ", - "g", - "r", - "a", - "n", - "i", - "t", - "e", - "_", - "s", - "w", - "i", - "t", - "c", - "h", - ".", - "t", - "u", - "t", - "o", - "r", - "i", - "a", - "l", - "s", - ".", - "c", - "h", - "r", - "o", - "m", - "a", - "_", - "l", - "o", - "a", - "d", - "e", - "r", - " ", - "i", - "m", - "p", - "o", - "r", - "t", - " ", - "l", - "o", - "a", - "d", - "_", - "o", - "r", - "_", - "b", - "u", - "i", - "l", - "d", - "_", - "g", - "o", - "v", - "t", - "_", - "c", - "h", - "r", - "o", - "m", - "a", - "\n", - "f", - "r", - "o", - "m", - " ", - "g", - "r", - "a", - "n", - "i", - "t", - "e", - "_", - "s", - "w", - "i", - "t", - "c", - "h", - ".", - "t", - "u", - "t", - "o", - "r", - "i", - "a", - "l", - "s", - ".", - "v", - "l", - "l", - "m", - "_", - "s", - "e", - "r", - "v", - "e", - "r", - " ", - "i", - "m", - "p", - "o", - "r", - "t", - " ", - "(", - "\n", - " ", - " ", - " ", - " ", - "k", - "i", - "l", - "l", - "_", - "s", - "t", - "a", - "l", - "e", - "_", - "v", - "l", - "l", - "m", - "_", - "p", - "r", - "o", - "c", - "e", - "s", - "s", - "e", - "s", - ",", - "\n", - " ", - " ", - " ", - " ", - "l", - "a", - "u", - "n", - "c", - "h", - "_", - "v", - "l", - "l", - "m", - ",", - "\n", - " ", - " ", - " ", - " ", - "p", - "r", - "i", - "n", - "t", - "_", - "g", - "p", - "u", - "_", - "s", - "t", - "a", - "t", - "e", - ",", - "\n", - " ", - " ", - " ", - " ", - "t", - "a", - "i", - "l", - "_", - "l", - "o", - "g", - ",", - "\n", - " ", - " ", - " ", - " ", - "w", - "a", - "i", - "t", - "_", - "f", - "o", - "r", - "_", - "s", - "e", - "r", - "v", - "e", - "r", - ",", - "\n", - ")", - "\n", - "f", - "r", - "o", - "m", - " ", - "m", - "e", - "l", - "l", - "e", - "a", - ".", - "b", - "a", - "c", - "k", - "e", - "n", - "d", - "s", - ".", - "o", - "p", - "e", - "n", - "a", - "i", - " ", - "i", - "m", - "p", - "o", - "r", - "t", - " ", - "O", - "p", - "e", - "n", - "A", - "I", - "B", - "a", - "c", - "k", - "e", - "n", - "d", - "\n", - "f", - "r", - "o", - "m", - " ", - "m", - "e", - "l", - "l", - "e", - "a", - ".", - "s", - "t", - "d", - "l", - "i", - "b", - ".", - "c", - "o", - "m", - "p", - "o", - "n", - "e", - "n", - "t", - "s", - " ", - "i", - "m", - "p", - "o", - "r", - "t", - " ", - "D", - "o", - "c", - "u", - "m", - "e", - "n", - "t", - " ", - "a", - "s", - " ", - "M", - "e", - "l", - "l", - "e", - "a", - "D", - "o", - "c", - "u", - "m", - "e", - "n", - "t", - "\n", - "f", - "r", - "o", - "m", - " ", - "m", - "e", - "l", - "l", - "e", - "a", - ".", - "s", - "t", - "d", - "l", - "i", - "b", - ".", - "c", - "o", - "m", - "p", - "o", - "n", - "e", - "n", - "t", - "s", - ".", - "i", - "n", - "t", - "r", - "i", - "n", - "s", - "i", - "c", - " ", - "i", - "m", - "p", - "o", - "r", - "t", - " ", - "r", - "a", - "g", - "\n", - "f", - "r", - "o", - "m", - " ", - "m", - "e", - "l", - "l", - "e", - "a", - ".", - "s", - "t", - "d", - "l", - "i", - "b", - ".", - "c", - "o", - "n", - "t", - "e", - "x", - "t", - " ", - "i", - "m", - "p", - "o", - "r", - "t", - " ", - "C", - "h", - "a", - "t", - "C", - "o", - "n", - "t", - "e", - "x", - "t", - "\n", - "\n", - "t", - "r", - "y", - ":", - "\n", - " ", - " ", - " ", - " ", - "f", - "r", - "o", - "m", - " ", - "d", - "o", - "t", - "e", - "n", - "v", - " ", - "i", - "m", - "p", - "o", - "r", - "t", - " ", - "l", - "o", - "a", - "d", - "_", - "d", - "o", - "t", - "e", - "n", - "v", - "\n", - " ", - " ", - " ", - " ", - "l", - "o", - "a", - "d", - "_", - "d", - "o", - "t", - "e", - "n", - "v", - "(", - "P", - "a", - "t", - "h", - "(", - "\"", - ".", - ".", - "/", - ".", - "e", - "n", - "v", - "\"", - ")", - ",", - " ", - "o", - "v", - "e", - "r", - "r", - "i", - "d", - "e", - "=", - "F", - "a", - "l", - "s", - "e", - ")", - "\n", - "e", - "x", - "c", - "e", - "p", - "t", - " ", - "I", - "m", - "p", - "o", - "r", - "t", - "E", - "r", - "r", - "o", - "r", - ":", - "\n", - " ", - " ", - " ", - " ", - "p", - "a", - "s", - "s" + "import os\n", + "from pathlib import Path\n", + "\n", + "from granite_switch.tutorials.chroma_loader import load_or_build_govt_chroma\n", + "from granite_switch.tutorials.vllm_server import (\n", + " kill_stale_vllm_processes,\n", + " launch_vllm,\n", + " print_gpu_state,\n", + " tail_log,\n", + " wait_for_server,\n", + ")\n", + "from mellea.backends.openai import OpenAIBackend\n", + "from mellea.stdlib.components import Document as MelleaDocument\n", + "from mellea.stdlib.components.intrinsic import rag\n", + "from mellea.stdlib.context import ChatContext\n", + "\n", + "try:\n", + " from dotenv import load_dotenv\n", + " load_dotenv(Path(\"../.env\"), override=False)\n", + "except ImportError:\n", + " pass" ] }, { @@ -699,7 +78,6 @@ "metadata": {}, "outputs": [], "source": [ - "\n", "kill_stale_vllm_processes()\n", "print_gpu_state()" ] From b3d23ff544feb9c1ed3740d1a68012a1da8153a6 Mon Sep 17 00:00:00 2001 From: Alon Freund Date: Mon, 25 May 2026 15:01:09 +0300 Subject: [PATCH 09/11] Bump default gpu_memory_utilization to 0.95 --- src/granite_switch/tutorials/vllm_server.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/granite_switch/tutorials/vllm_server.py b/src/granite_switch/tutorials/vllm_server.py index c807c82..c8a04ee 100644 --- a/src/granite_switch/tutorials/vllm_server.py +++ b/src/granite_switch/tutorials/vllm_server.py @@ -18,7 +18,7 @@ def launch_vllm( model: str, port: int, log_file: str, - gpu_memory_utilization: float = 0.85, + gpu_memory_utilization: float = 0.95, max_num_seqs: int = 16, enforce_eager: bool = True, extra_args: Sequence[str] = (), From 38056d9dbc91dcd8f14939514cf028dc17beaa12 Mon Sep 17 00:00:00 2001 From: Alon Freund Date: Mon, 25 May 2026 15:10:05 +0300 Subject: [PATCH 10/11] Lower default max_num_seqs to 1 --- src/granite_switch/tutorials/vllm_server.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/granite_switch/tutorials/vllm_server.py b/src/granite_switch/tutorials/vllm_server.py index c8a04ee..99f75a8 100644 --- a/src/granite_switch/tutorials/vllm_server.py +++ b/src/granite_switch/tutorials/vllm_server.py @@ -19,7 +19,7 @@ def launch_vllm( port: int, log_file: str, gpu_memory_utilization: float = 0.95, - max_num_seqs: int = 16, + max_num_seqs: int = 1, enforce_eager: bool = True, extra_args: Sequence[str] = (), max_model_len: int = DEFAULT_MAX_MODEL_LEN, From 3346e447b2570b228c429ba5bcf21d7541b0831e Mon Sep 17 00:00:00 2001 From: Alon Freund Date: Mon, 25 May 2026 15:23:08 +0300 Subject: [PATCH 11/11] Remove explicit device='cpu' from chroma loader calls, use default (GPU) --- tutorials/notebooks/rag_101.ipynb | 1 - tutorials/notebooks/rag_flow.ipynb | 1 - 2 files changed, 2 deletions(-) diff --git a/tutorials/notebooks/rag_101.ipynb b/tutorials/notebooks/rag_101.ipynb index a566513..dcab24b 100644 --- a/tutorials/notebooks/rag_101.ipynb +++ b/tutorials/notebooks/rag_101.ipynb @@ -145,7 +145,6 @@ " jsonl_url = GOVT_JSONL_URL,\n", " embedding_model_id = EMBEDDING_MODEL_ID,\n", " load_only_tutorial_docs = True,\n", - " device = \"cpu\",\n", ")\n", "print(f\"Corpus ready — {chroma_collection.count():,} passages indexed.\")" ] diff --git a/tutorials/notebooks/rag_flow.ipynb b/tutorials/notebooks/rag_flow.ipynb index 7a2da19..72b11ae 100644 --- a/tutorials/notebooks/rag_flow.ipynb +++ b/tutorials/notebooks/rag_flow.ipynb @@ -118,7 +118,6 @@ " jsonl_path = GOVT_JSONL_PATH,\n", " jsonl_url = GOVT_JSONL_URL,\n", " embedding_model_id = EMBEDDING_MODEL_ID,\n", - " device = \"cpu\",\n", " load_only_tutorial_docs = True,\n", ")\n" ]