diff --git a/README.md b/README.md index 7180efd5a..728de6f8f 100644 --- a/README.md +++ b/README.md @@ -147,27 +147,40 @@ You can follow these steps to generate a PageIndex tree from a PDF document. pip3 install --upgrade -r requirements.txt ``` -### 2. Set your OpenAI API key +### 2. Set your API key -Create a `.env` file in the root directory and add your API key: +Create a `.env` file in the root directory and add your API key for your chosen provider: ```bash +# OpenAI (default) CHATGPT_API_KEY=your_openai_key_here + +# MiniMax (optional) +MINIMAX_API_KEY=your_minimax_key_here ``` ### 3. Run PageIndex on your PDF +**OpenAI (default):** ```bash python3 run_pageindex.py --pdf_path /path/to/your/document.pdf ``` +**MiniMax:** +```bash +python3 run_pageindex.py --pdf_path /path/to/your/document.pdf \ + --provider minimax --model MiniMax-Text-01 +``` +
**Optional parameters**
You can customize the processing with additional optional arguments: ``` ---model OpenAI model to use (default: gpt-4o-2024-11-20) +--model Model to use (default: gpt-4o-2024-11-20) +--provider LLM provider: openai or minimax (default: openai) +--api-base-url Custom API base URL (e.g. https://api.minimax.io/v1 for MiniMax) --toc-check-pages Pages to check for table of contents (default: 20) --max-pages-per-node Max pages per node (default: 10) --max-tokens-per-node Max tokens per node (default: 20000) @@ -175,6 +188,24 @@ You can customize the processing with additional optional arguments: --if-add-node-summary Add node summary (yes/no, default: yes) --if-add-doc-description Add doc description (yes/no, default: yes) ``` + +You can also set the provider via environment variables instead of CLI flags: +```bash +export LLM_PROVIDER=minimax +export API_BASE_URL=https://api.minimax.io/v1 # optional, for custom endpoints +``` +
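+These variables are read once, at import time, in `pageindex/utils.py`, so export them before launching the CLI. If you drive PageIndex from Python instead, set them before the first `pageindex` import; a minimal sketch (the key value is a placeholder, and the import path assumes the repo layout above):
+```python
+import os
+
+# utils.py captures LLM_PROVIDER and the API keys when it is imported,
+# so the environment must be configured first.
+os.environ["LLM_PROVIDER"] = "minimax"
+os.environ["MINIMAX_API_KEY"] = "your_minimax_key_here"
+
+from pageindex.page_index import page_index  # imported only after env setup
+
+tree = page_index("document.pdf", model="MiniMax-Text-01")
+```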
+
+**Supported LLM Providers**
+
+| Provider | Example Models | API Key Env Var | Notes |
+|----------|----------------|-----------------|-------|
+| **OpenAI** (default) | `gpt-4o-2024-11-20`, `gpt-4o-mini` | `CHATGPT_API_KEY` | Full support, recommended |
+| **MiniMax** | `MiniMax-Text-01`, `abab6.5s-chat` | `MINIMAX_API_KEY` | Full support via the OpenAI-compatible API at `https://api.minimax.io/v1` |
+
+**Note:** PageIndex relies on structured JSON output from the LLM. For best results, use capable models (GPT-4o or MiniMax-Text-01); smaller models may produce lower-quality tree structures.
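+
+The provider can also be selected per call in Python. A minimal sketch, assuming `page_index` is imported from `pageindex/page_index.py` and returns the generated tree (its exact shape depends on your document and options):
+```python
+from pageindex.page_index import page_index
+
+# Build a PageIndex tree using MiniMax through its OpenAI-compatible API.
+tree = page_index(
+    "document.pdf",
+    model="MiniMax-Text-01",
+    provider="minimax",
+    api_base_url="https://api.minimax.io/v1",  # optional; this is the MiniMax default
+)
+```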
diff --git a/pageindex/config.yaml b/pageindex/config.yaml
index fd73e3a2c..eea172a58 100644
--- a/pageindex/config.yaml
+++ b/pageindex/config.yaml
@@ -1,4 +1,6 @@
 model: "gpt-4o-2024-11-20"
+provider: "openai" # "openai" or "minimax"
+api_base_url: null # Custom API base URL (e.g. https://api.minimax.io/v1 for MiniMax)
 toc_check_page_num: 20
 max_page_num_each_node: 10
 max_token_num_each_node: 20000
diff --git a/pageindex/page_index.py b/pageindex/page_index.py
index 39018c4df..5cf69d6be 100644
--- a/pageindex/page_index.py
+++ b/pageindex/page_index.py
@@ -1057,7 +1057,17 @@ async def tree_parser(page_list, opt, doc=None, logger=None):
 
 def page_index_main(doc, opt=None):
     logger = JsonLogger(doc)
-    
+
+    # Set provider config from opt so all downstream API calls pick it up
+    from pageindex import utils
+    if hasattr(opt, 'provider') and opt.provider:
+        os.environ['LLM_PROVIDER'] = opt.provider
+        utils.LLM_PROVIDER = opt.provider
+    if hasattr(opt, 'api_base_url') and opt.api_base_url:
+        os.environ['API_BASE_URL'] = opt.api_base_url
+        utils.API_BASE_URL = opt.api_base_url
+
     is_valid_pdf = (
         (isinstance(doc, str) and os.path.isfile(doc) and doc.lower().endswith(".pdf"))
         or isinstance(doc, BytesIO)
@@ -1066,7 +1076,7 @@ def page_index_main(doc, opt=None):
         raise ValueError("Unsupported input type. Expected a PDF file path or BytesIO object.")
 
     print('Parsing PDF...')
-    page_list = get_page_tokens(doc)
+    page_list = get_page_tokens(doc, model=opt.model)
 
     logger.info({'total_page_number': len(page_list)})
     logger.info({'total_token': sum([page[1] for page in page_list])})
@@ -1100,7 +1110,8 @@ async def page_index_builder():
     return asyncio.run(page_index_builder())
 
-def page_index(doc, model=None, toc_check_page_num=None, max_page_num_each_node=None, max_token_num_each_node=None,
+def page_index(doc, model=None, provider=None, api_base_url=None,
+               toc_check_page_num=None, max_page_num_each_node=None, max_token_num_each_node=None,
                if_add_node_id=None, if_add_node_summary=None, if_add_doc_description=None, if_add_node_text=None):
 
     user_opt = {
diff --git a/pageindex/utils.py b/pageindex/utils.py
index dc7acd888..869e5cdef 100644
--- a/pageindex/utils.py
+++ b/pageindex/utils.py
@@ -18,17 +18,80 @@ from types import SimpleNamespace as config
 
 CHATGPT_API_KEY = os.getenv("CHATGPT_API_KEY")
+MINIMAX_API_KEY = os.getenv("MINIMAX_API_KEY")
+LLM_PROVIDER = os.getenv("LLM_PROVIDER", "openai") # "openai" or "minimax"
+API_BASE_URL = os.getenv("API_BASE_URL") # Custom API base URL
+
+MINIMAX_BASE_URL = "https://api.minimax.io/v1"
+
+
+def _is_minimax_model(model):
+    """Check if the model is a MiniMax model by name prefix."""
+    if not model:
+        return False
+    model_lower = model.lower()
+    # Support various MiniMax model naming patterns:
+    # - minimax-m1, minimax-m2.5, minimax-m2.5-highspeed
+    # - MiniMax-Text-01, abab6.5s-chat, etc.
+    return model_lower.startswith("minimax") or model_lower.startswith("abab")
+
+
+def _get_provider_config(provider=None, api_key=None, base_url=None):
+    """Resolve provider, api_key, and base_url from args or environment."""
+    provider = provider or LLM_PROVIDER
+
+    if provider == "minimax":
+        return {
+            "provider": "minimax",
+            "api_key": api_key or MINIMAX_API_KEY,
+            "base_url": base_url or API_BASE_URL or MINIMAX_BASE_URL,
+        }
+    else:  # openai (default)
+        cfg = {
+            "provider": "openai",
+            "api_key": api_key or CHATGPT_API_KEY,
+        }
+        if base_url or API_BASE_URL:
+            cfg["base_url"] = base_url or API_BASE_URL
+        return cfg
+
+
+def _get_client_kwargs(model, api_key=None, provider=None, base_url=None):
+    """Get OpenAI client kwargs based on model name or explicit provider config."""
+    # Fall back to a non-default LLM_PROVIDER from the environment, so that
+    # `export LLM_PROVIDER=minimax` works without an explicit provider argument.
+    if provider is None and LLM_PROVIDER != "openai":
+        provider = LLM_PROVIDER
+
+    # If provider is explicitly set, use provider config
+    if provider:
+        pcfg = _get_provider_config(provider, api_key, base_url)
+        client_kwargs = {"api_key": pcfg["api_key"]}
+        if "base_url" in pcfg:
+            client_kwargs["base_url"] = pcfg["base_url"]
+        return client_kwargs
+
+    # Auto-detect based on model name
+    if _is_minimax_model(model):
+        return {
+            "api_key": api_key or MINIMAX_API_KEY,
+            "base_url": base_url or API_BASE_URL or MINIMAX_BASE_URL,
+        }
+
+    # Default to OpenAI
+    cfg = {"api_key": api_key or CHATGPT_API_KEY}
+    if base_url or API_BASE_URL:
+        cfg["base_url"] = base_url or API_BASE_URL
+    return cfg
 
 
 def count_tokens(text, model=None):
     if not text:
         return 0
-    enc = tiktoken.encoding_for_model(model)
+    try:
+        enc = tiktoken.encoding_for_model(model)
+    except KeyError:
+        enc = tiktoken.get_encoding("cl100k_base")
     tokens = enc.encode(text)
     return len(tokens)
 
 
-def ChatGPT_API_with_finish_reason(model, prompt, api_key=CHATGPT_API_KEY, chat_history=None):
+def ChatGPT_API_with_finish_reason(model, prompt, api_key=None, chat_history=None, provider=None, base_url=None):
     max_retries = 10
-    client = openai.OpenAI(api_key=api_key)
+    client = openai.OpenAI(**_get_client_kwargs(model, api_key, provider, base_url))
     for i in range(max_retries):
         try:
             if chat_history:
@@ -36,7 +99,7 @@ def ChatGPT_API_with_finish_reason(model, prompt, api_key=CHATGPT_API_KEY, chat_
                 messages.append({"role": "user", "content": prompt})
             else:
                 messages = [{"role": "user", "content": prompt}]
-            
+
             response = client.chat.completions.create(
                 model=model,
                 messages=messages,
@@ -51,16 +114,16 @@ def ChatGPT_API_with_finish_reason(model, prompt, api_key=CHATGPT_API_KEY, chat_
             print('************* Retrying *************')
             logging.error(f"Error: {e}")
             if i < max_retries - 1:
-                time.sleep(1) # Wait for 1秒 before retrying
+                time.sleep(1)
             else:
                 logging.error('Max retries reached for prompt: ' + prompt)
                 return "Error"
 
 
-def ChatGPT_API(model, prompt, api_key=CHATGPT_API_KEY, chat_history=None):
+def ChatGPT_API(model, prompt, api_key=None, chat_history=None, provider=None, base_url=None):
     max_retries = 10
-    client = openai.OpenAI(api_key=api_key)
+    client = openai.OpenAI(**_get_client_kwargs(model, api_key, provider, base_url))
     for i in range(max_retries):
         try:
             if chat_history:
@@ -68,30 +131,30 @@ def ChatGPT_API(model, prompt, api_key=CHATGPT_API_KEY, chat_history=None):
                 messages.append({"role": "user", "content": prompt})
             else:
                 messages = [{"role": "user", "content": prompt}]
-            
+
             response = client.chat.completions.create(
                 model=model,
                 messages=messages,
                 temperature=0,
             )
-            
+
             return response.choices[0].message.content
         except Exception as e:
             print('************* Retrying *************')
             logging.error(f"Error: {e}")
             if i < max_retries - 1:
-                time.sleep(1) # Wait for 1秒 before retrying
+                time.sleep(1)
             else:
                 logging.error('Max retries reached for prompt: ' + prompt)
                 return "Error"
 
 
-async def ChatGPT_API_async(model, prompt, api_key=CHATGPT_API_KEY):
+async def ChatGPT_API_async(model, prompt, api_key=None, provider=None, base_url=None):
     max_retries = 10
     messages = [{"role": "user", "content": prompt}]
     for i in range(max_retries):
         try:
-            async with openai.AsyncOpenAI(api_key=api_key) as client:
+            async with openai.AsyncOpenAI(**_get_client_kwargs(model, api_key, provider, base_url)) as client:
                 response = await client.chat.completions.create(
                     model=model,
                     messages=messages,
@@ -102,7 +165,7 @@ async def ChatGPT_API_async(model, prompt, api_key=CHATGPT_API_KEY):
             print('************* Retrying *************')
             logging.error(f"Error: {e}")
             if i < max_retries - 1:
-                await asyncio.sleep(1) # Wait for 1s before retrying
+                await asyncio.sleep(1)
             else:
                 logging.error('Max retries reached for prompt: ' + prompt)
                 return "Error"
@@ -411,7 +474,10 @@ def add_preface_if_needed(data):
 
 def get_page_tokens(pdf_path, model="gpt-4o-2024-11-20", pdf_parser="PyPDF2"):
-    enc = tiktoken.encoding_for_model(model)
+    try:
+        enc = tiktoken.encoding_for_model(model)
+    except KeyError:
+        enc = tiktoken.get_encoding("cl100k_base")
     if pdf_parser == "PyPDF2":
         pdf_reader = PyPDF2.PdfReader(pdf_path)
         page_list = []
@@ -530,10 +596,10 @@ def remove_structure_text(data):
     return data
 
 
-def check_token_limit(structure, limit=110000):
+def check_token_limit(structure, limit=110000, model='gpt-4o'):
     list = structure_to_list(structure)
     for node in list:
-        num_tokens = count_tokens(node['text'], model='gpt-4o')
+        num_tokens = count_tokens(node['text'], model=model)
         if num_tokens > limit:
             print(f"Node ID: {node['node_id']} has {num_tokens} tokens")
             print("Start Index:", node['start_index'])
diff --git a/run_pageindex.py b/run_pageindex.py
index 107024505..a45ef5c05 100644
--- a/run_pageindex.py
+++ b/run_pageindex.py
@@ -10,7 +10,13 @@
     parser.add_argument('--pdf_path', type=str, help='Path to the PDF file')
     parser.add_argument('--md_path', type=str, help='Path to the Markdown file')
-    parser.add_argument('--model', type=str, default='gpt-4o-2024-11-20', help='Model to use')
+    parser.add_argument('--model', type=str, default='gpt-4o-2024-11-20',
+                        help='Model to use (e.g. gpt-4o-2024-11-20, MiniMax-Text-01, abab6.5s-chat)')
+    parser.add_argument('--provider', type=str, default=None,
+                        choices=['openai', 'minimax'],
+                        help='LLM provider: openai or minimax (default: openai, or the LLM_PROVIDER env var)')
+    parser.add_argument('--api-base-url', type=str, default=None,
+                        help='Custom API base URL (e.g. https://api.minimax.io/v1 for MiniMax)')
 
     parser.add_argument('--toc-check-pages', type=int, default=20,
                         help='Number of pages to check for table of contents (PDF only)')
@@ -54,6 +60,8 @@
     # Configure options
     opt = config(
         model=args.model,
+        provider=args.provider,
+        api_base_url=args.api_base_url,
         toc_check_page_num=args.toc_check_pages,
         max_page_num_each_node=args.max_pages_per_node,
         max_token_num_each_node=args.max_tokens_per_node,
@@ -98,6 +106,8 @@
     # Create options dict with user args
     user_opt = {
         'model': args.model,
+        'provider': args.provider,
+        'api_base_url': args.api_base_url,
         'if_add_node_summary': args.if_add_node_summary,
         'if_add_doc_description': args.if_add_doc_description,
         'if_add_node_text': args.if_add_node_text,
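
The provider resolution implemented in `utils._get_client_kwargs` follows a fixed order: an explicit `provider` argument wins, then a non-default `LLM_PROVIDER` from the environment, then model-name auto-detection (`minimax*` / `abab*` prefixes), and finally the OpenAI default. A quick sanity-check sketch of that order, assuming this patch is applied and `LLM_PROVIDER` and `API_BASE_URL` are unset in the environment:

```python
from pageindex import utils

# 1. An explicit provider overrides model-name auto-detection.
kw = utils._get_client_kwargs("gpt-4o-mini", provider="minimax")
assert kw["base_url"] == utils.MINIMAX_BASE_URL

# 2. Without a provider, MiniMax-style model names route to MiniMax.
kw = utils._get_client_kwargs("abab6.5s-chat")
assert kw["base_url"] == utils.MINIMAX_BASE_URL

# 3. Anything else falls back to OpenAI (no base_url override).
kw = utils._get_client_kwargs("gpt-4o-2024-11-20")
assert "base_url" not in kw
```

This mirrors the CLI behavior documented in the README: an explicit `--provider` flag beats the environment, and the environment beats auto-detection.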