diff --git a/.gitignore b/.gitignore index e7058c9e..4cef1977 100644 --- a/.gitignore +++ b/.gitignore @@ -59,6 +59,7 @@ docs/_build/ # Jupyter Notebook .ipynb_checkpoints +*.ipynb notebooks/ # IPython diff --git a/bertopic/representation/__init__.py b/bertopic/representation/__init__.py index f1502982..3df028fe 100644 --- a/bertopic/representation/__init__.py +++ b/bertopic/representation/__init__.py @@ -33,6 +33,13 @@ msg = "`pip install openai` \n\n" OpenAI = NotInstalled("OpenAI", "openai", custom_msg=msg) +# Ollama Generator +try: + from bertopic.representation._ollama import Ollama +except ModuleNotFoundError: + msg = "`pip install ollama` \n\n" + Ollama = NotInstalled("Ollama", "ollama", custom_msg=msg) + # LiteLLM Generator try: from bertopic.representation._litellm import LiteLLM @@ -68,6 +75,7 @@ "LiteLLM", "LlamaCPP", "MaximalMarginalRelevance", + "Ollama", "OpenAI", "PartOfSpeech", "TextGeneration", diff --git a/bertopic/representation/_base.py b/bertopic/representation/_base.py index 63feeda9..54be3d2e 100644 --- a/bertopic/representation/_base.py +++ b/bertopic/representation/_base.py @@ -1,7 +1,10 @@ import pandas as pd from scipy.sparse import csr_matrix from sklearn.base import BaseEstimator -from typing import Mapping, List, Tuple +from typing import Mapping, List, Tuple, Union, Callable + +from bertopic.representation._prompts import DEFAULT_CHAT_PROMPT +from bertopic.representation._utils import truncate_document, validate_truncate_document_parameters class BaseRepresentation(BaseEstimator): @@ -38,3 +41,113 @@ def extract_topics( model is used. """ return topic_model.topic_representations_ + + +class LLMRepresentation(BaseRepresentation): + """Base class for LLM-based representation models.""" + + def __init__( + self, + prompt: str | None = None, + nr_docs: int = 4, + diversity: float | None = None, + doc_length: int | None = None, + tokenizer: Union[str, Callable] | None = None, + ): + """Generate a representation model that leverages LLMs. + + Arguments: + prompt: The prompt to be used in the model. If no prompt is given, + `bertopic.representation._prompts.DEFAULT_CHAT_PROMPT` is used instead. + NOTE: Use `"[KEYWORDS]"` and `"[DOCUMENTS]"` in the prompt + to decide where the keywords and documents need to be + inserted. + nr_docs: The number of documents to pass to Ollama if a prompt + with the `["DOCUMENTS"]` tag is used. + diversity: The diversity of documents to pass to Ollama. + Accepts values between 0 and 1. A higher + values results in passing more diverse documents + whereas lower values passes more similar documents. + doc_length: The maximum length of each document. If a document is longer, + it will be truncated. If None, the entire document is passed. + tokenizer: The tokenizer used to calculate to split the document into segments + used to count the length of a document. + * If tokenizer is 'char', then the document is split up + into characters which are counted to adhere to `doc_length` + * If tokenizer is 'whitespace', the document is split up + into words separated by whitespaces. These words are counted + and truncated depending on `doc_length` + * If tokenizer is 'vectorizer', then the internal CountVectorizer + is used to tokenize the document. These tokens are counted + and truncated depending on `doc_length` + * If tokenizer is a callable, then that callable is used to tokenize + the document. 
These tokens are counted and truncated depending + on `doc_length` + """ + self.prompt = prompt + + # Representative document extraction parameters + self.nr_docs = nr_docs + self.diversity = diversity + + # Document truncation + self.doc_length = doc_length + self.tokenizer = tokenizer + validate_truncate_document_parameters(self.tokenizer, self.doc_length) + + # Store prompts for inspection + self.prompts_ = [] + + def _create_prompt( + self, docs: list[str], topic: int, topics: Mapping[str, List[Tuple[str, float]]], topic_model + ) -> str: + """Create prompt for LLM by either using the default prompt or replacing custom tags. + Specifically, [KEYWORDS] and [DOCUMENTS] can be used in custom prompts to insert the topic's keywords and most representative documents. + + Arguments: + docs: The most representative documents for a given topic. + topic: The topic for which to create the prompt. + topics: A dictionary with topic (key) and tuple of word and + weight (value) as calculated by c-TF-IDF. + topic_model: The BERTopic model that is fitted until topic + representations are calculated. + + Returns: + prompt: The created prompt. + """ + keywords = next(zip(*topics[topic])) + + # Use the Default Chat Prompt + if self.prompt == DEFAULT_CHAT_PROMPT: + prompt = self.prompt.replace("[KEYWORDS]", ", ".join(keywords)) + prompt = self._replace_documents(prompt, docs, topic_model) + + # Use a custom prompt that leverages keywords, documents or both using + # custom tags, namely [KEYWORDS] and [DOCUMENTS] respectively + else: + prompt = self.prompt + if "[KEYWORDS]" in prompt: + prompt = prompt.replace("[KEYWORDS]", ", ".join(keywords)) + if "[DOCUMENTS]" in prompt: + prompt = self._replace_documents(prompt, docs, topic_model) + + return prompt + + def _replace_documents(self, prompt: str, docs: list[str], topic_model) -> str: + """Replace [DOCUMENTS] tag in prompt with actual documents. + + Arguments: + prompt: The prompt containing the [DOCUMENTS] tag. + docs: The most representative documents for a given topic. + topic_model: The BERTopic model that is fitted until topic + representations are calculated. + + Returns: + The prompt with the [DOCUMENTS] tag replaced by actual documents. + """ + # Truncate documents if needed + truncated_docs = [truncate_document(topic_model, self.doc_length, self.tokenizer, doc) for doc in docs] + + # Replace tag with documents + formatted_docs = "\n".join([f"- {doc}" for doc in truncated_docs]) + return prompt.replace("[DOCUMENTS]", formatted_docs) diff --git a/bertopic/representation/_cohere.py b/bertopic/representation/_cohere.py index 9ccd5ae5..9d6c056b 100644 --- a/bertopic/representation/_cohere.py +++ b/bertopic/representation/_cohere.py @@ -3,42 +3,11 @@ from tqdm import tqdm from scipy.sparse import csr_matrix from typing import Mapping, List, Tuple, Union, Callable -from bertopic.representation._base import BaseRepresentation -from bertopic.representation._utils import truncate_document, validate_truncate_document_parameters - - -DEFAULT_PROMPT = """ -This is a list of texts where each collection of texts describe a topic. After each collection of texts, the name of the topic they represent is mentioned as a short-highly-descriptive title ---- -Topic: -Sample texts from this topic: -- Traditional diets in most cultures were primarily plant-based with a little meat on top, but with the rise of industrial style meat production and factory farming, meat has become a staple food. -- Meat, but especially beef, is the word food in terms of emissions. 
-- Eating meat doesn't make you a bad person, not eating meat doesn't make you a good one. - -Keywords: meat beef eat eating emissions steak food health processed chicken -Topic name: Environmental impacts of eating meat ---- -Topic: -Sample texts from this topic: -- I have ordered the product weeks ago but it still has not arrived! -- The website mentions that it only takes a couple of days to deliver but I still have not received mine. -- I got a message stating that I received the monitor but that is not true! -- It took a month longer to deliver than was advised... - -Keywords: deliver weeks product shipping long delivery received arrived arrive week -Topic name: Shipping and delivery issues ---- -Topic: -Sample texts from this topic: -[DOCUMENTS] -Keywords: [KEYWORDS] -Topic name:""" - -DEFAULT_SYSTEM_PROMPT = "You are an assistant that extracts high-level topics from texts." - - -class Cohere(BaseRepresentation): +from bertopic.representation._base import LLMRepresentation +from bertopic.representation._prompts import DEFAULT_SYSTEM_PROMPT, DEFAULT_CHAT_PROMPT + + +class Cohere(LLMRepresentation): """Use the Cohere API to generate topic labels based on their generative model. @@ -49,12 +18,12 @@ class Cohere(BaseRepresentation): client: A `cohere.Client` model: Model to use within Cohere, defaults to `"xlarge"`. prompt: The prompt to be used in the model. If no prompt is given, - `self.default_prompt_` is used instead. + `bertopic.representation._prompts.DEFAULT_CHAT_PROMPT` is used instead. NOTE: Use `"[KEYWORDS]"` and `"[DOCUMENTS]"` in the prompt to decide where the keywords and documents need to be inserted. system_prompt: The system prompt to be used in the model. If no system prompt is given, - `self.default_system_prompt_` is used instead. + `bertopic.representation._prompts.DEFAULT_SYSTEM_PROMPT` is used instead. delay_in_seconds: The delay in seconds between consecutive prompts in order to prevent RateLimitErrors. 
nr_docs: The number of documents to pass to OpenAI if a prompt @@ -120,20 +89,19 @@ def __init__( doc_length: int | None = None, tokenizer: Union[str, Callable] | None = None, ): + super().__init__( + prompt=prompt if prompt is not None else DEFAULT_CHAT_PROMPT, + nr_docs=nr_docs, + diversity=diversity, + doc_length=doc_length, + tokenizer=tokenizer, + ) + + # Cohere specific parameters self.client = client self.model = model - self.prompt = prompt if prompt is not None else DEFAULT_PROMPT self.system_prompt = system_prompt if system_prompt is not None else DEFAULT_SYSTEM_PROMPT - self.default_prompt_ = DEFAULT_PROMPT - self.default_system_prompt_ = DEFAULT_SYSTEM_PROMPT self.delay_in_seconds = delay_in_seconds - self.nr_docs = nr_docs - self.diversity = diversity - self.doc_length = doc_length - self.tokenizer = tokenizer - validate_truncate_document_parameters(self.tokenizer, self.doc_length) - - self.prompts_ = [] def extract_topics( self, @@ -161,8 +129,7 @@ def extract_topics( # Generate using Cohere's Language Model updated_topics = {} for topic, docs in tqdm(repr_docs_mappings.items(), disable=not topic_model.verbose): - truncated_docs = [truncate_document(topic_model, self.doc_length, self.tokenizer, doc) for doc in docs] - prompt = self._create_prompt(truncated_docs, topic, topics) + prompt = self._create_prompt(docs=docs, topic=topic, topics=topics, topic_model=topic_model) self.prompts_.append(prompt) # Delay @@ -180,30 +147,3 @@ def extract_topics( updated_topics[topic] = [(label, 1)] + [("", 0) for _ in range(9)] return updated_topics - - def _create_prompt(self, docs, topic, topics): - keywords = next(zip(*topics[topic])) - - # Use the Default Chat Prompt - if self.prompt == DEFAULT_PROMPT: - prompt = self.prompt.replace("[KEYWORDS]", ", ".join(keywords)) - prompt = self._replace_documents(prompt, docs) - - # Use a custom prompt that leverages keywords, documents or both using - # custom tags, namely [KEYWORDS] and [DOCUMENTS] respectively - else: - prompt = self.prompt - if "[KEYWORDS]" in prompt: - prompt = prompt.replace("[KEYWORDS]", ", ".join(keywords)) - if "[DOCUMENTS]" in prompt: - prompt = self._replace_documents(prompt, docs) - - return prompt - - @staticmethod - def _replace_documents(prompt, docs): - to_replace = "" - for doc in docs: - to_replace += f"- {doc}\n" - prompt = prompt.replace("[DOCUMENTS]", to_replace) - return prompt diff --git a/bertopic/representation/_langchain.py b/bertopic/representation/_langchain.py index b80d9b77..b5560e31 100644 --- a/bertopic/representation/_langchain.py +++ b/bertopic/representation/_langchain.py @@ -3,13 +3,12 @@ from scipy.sparse import csr_matrix from typing import Callable, Mapping, List, Tuple, Union -from bertopic.representation._base import BaseRepresentation -from bertopic.representation._utils import truncate_document, validate_truncate_document_parameters +from bertopic.representation._base import LLMRepresentation +from bertopic.representation._utils import truncate_document +from bertopic.representation._prompts import DEFAULT_COMPLETION_PROMPT -DEFAULT_PROMPT = "What are these documents about? Please give a single label." - -class LangChain(BaseRepresentation): +class LangChain(LLMRepresentation): """Using chains in langchain to generate topic labels. The classic example uses `langchain.chains.question_answering.load_qa_chain`. @@ -22,7 +21,7 @@ class LangChain(BaseRepresentation): Input keys must be `input_documents` and `question`. Output key must be `output_text`. prompt: The prompt to be used in the model. 
If no prompt is given, - `self.default_prompt_` is used instead. + `bertopic.representation._prompts.DEFAULT_COMPLETION_PROMPT` is used instead. NOTE: Use `"[KEYWORDS]"` in the prompt to decide where the keywords need to be inserted. Keywords won't be included unless @@ -140,15 +139,17 @@ def __init__( tokenizer: Union[str, Callable] | None = None, chain_config=None, ): + super().__init__( + prompt=prompt if prompt is not None else DEFAULT_COMPLETION_PROMPT, + nr_docs=nr_docs, + diversity=diversity, + doc_length=doc_length, + tokenizer=tokenizer, + ) + + # LangChain specific parameters self.chain = chain - self.prompt = prompt if prompt is not None else DEFAULT_PROMPT - self.default_prompt_ = DEFAULT_PROMPT self.chain_config = chain_config - self.nr_docs = nr_docs - self.diversity = diversity - self.doc_length = doc_length - self.tokenizer = tokenizer - validate_truncate_document_parameters(self.tokenizer, self.doc_length) def extract_topics( self, diff --git a/bertopic/representation/_litellm.py b/bertopic/representation/_litellm.py index c9f69ae8..9a3ed66c 100644 --- a/bertopic/representation/_litellm.py +++ b/bertopic/representation/_litellm.py @@ -1,23 +1,18 @@ import time from litellm import completion import pandas as pd +from typing import Union, Callable from tqdm import tqdm from scipy.sparse import csr_matrix from typing import Mapping, List, Tuple, Any -from bertopic.representation._base import BaseRepresentation -from bertopic.representation._utils import retry_with_exponential_backoff +from bertopic.representation._base import LLMRepresentation +from bertopic.representation._utils import ( + retry_with_exponential_backoff, +) +from bertopic.representation._prompts import DEFAULT_CHAT_PROMPT -DEFAULT_PROMPT = """ -I have a topic that contains the following documents: -[DOCUMENTS] -The topic is described by the following keywords: [KEYWORDS] -Based on the information above, extract a short topic label in the following format: -topic: -""" - - -class LiteLLM(BaseRepresentation): +class LiteLLM(LLMRepresentation): """Using the LiteLLM API to generate topic labels. For an overview of models see: @@ -27,7 +22,7 @@ class LiteLLM(BaseRepresentation): model: Model to use. Defaults to OpenAI's "gpt-3.5-turbo". generator_kwargs: Kwargs passed to `litellm.completion`. prompt: The prompt to be used in the model. If no prompt is given, - `self.default_prompt_` is used instead. + `bertopic.representation._prompts.DEFAULT_CHAT_PROMPT` is used instead. NOTE: Use `"[KEYWORDS]"` and `"[DOCUMENTS]"` in the prompt to decide where the keywords and documents need to be inserted. @@ -44,6 +39,21 @@ class LiteLLM(BaseRepresentation): Accepts values between 0 and 1. A higher values results in passing more diverse documents whereas lower values passes more similar documents. + doc_length: The maximum length of each document. If a document is longer, + it will be truncated. If None, the entire document is passed. + tokenizer: The tokenizer used to calculate to split the document into segments + used to count the length of a document. + * If tokenizer is 'char', then the document is split up + into characters which are counted to adhere to `doc_length` + * If tokenizer is 'whitespace', the document is split up + into words separated by whitespaces. These words are counted + and truncated depending on `doc_length` + * If tokenizer is 'vectorizer', then the internal CountVectorizer + is used to tokenize the document. 
These tokens are counted + and truncated depending on `doc_length` + * If tokenizer is a callable, then that callable is used to tokenize + the document. These tokens are counted and truncated depending + on `doc_length` Usage: @@ -85,20 +95,22 @@ def __init__( exponential_backoff: bool = False, nr_docs: int = 4, diversity: float | None = None, + doc_length: int | None = None, + tokenizer: Union[str, Callable] | None = None, ): + super().__init__( + prompt=prompt if prompt is not None else DEFAULT_CHAT_PROMPT, + nr_docs=nr_docs, + diversity=diversity, + doc_length=doc_length, + tokenizer=tokenizer, + ) + + # LiteLLM specific parameters self.model = model - self.prompt = prompt if prompt else DEFAULT_PROMPT - self.default_prompt_ = DEFAULT_PROMPT self.delay_in_seconds = delay_in_seconds self.exponential_backoff = exponential_backoff - self.nr_docs = nr_docs - self.diversity = diversity - self.generator_kwargs = generator_kwargs - if self.generator_kwargs.get("model"): - self.model = generator_kwargs.get("model") - if self.generator_kwargs.get("prompt"): - del self.generator_kwargs["prompt"] def extract_topics( self, topic_model, documents: pd.DataFrame, c_tf_idf: csr_matrix, topics: Mapping[str, List[Tuple[str, float]]] @@ -122,7 +134,8 @@ def extract_topics( # Generate using a (Large) Language Model updated_topics = {} for topic, docs in tqdm(repr_docs_mappings.items(), disable=not topic_model.verbose): - prompt = self._create_prompt(docs, topic, topics) + prompt = self._create_prompt(docs=docs, topic=topic, topics=topics, topic_model=topic_model) + self.prompts_.append(prompt) # Delay if self.delay_in_seconds: @@ -133,43 +146,14 @@ def extract_topics( {"role": "user", "content": prompt}, ] kwargs = {"model": self.model, "messages": messages, **self.generator_kwargs} - if self.exponential_backoff: - response = chat_completions_with_backoff(**kwargs) - else: - response = completion(**kwargs) - label = response["choices"][0]["message"]["content"].strip().replace("topic: ", "") + # Generate response + response = chat_completions_with_backoff(**kwargs) if self.exponential_backoff else completion(**kwargs) + label = response["choices"][0]["message"]["content"].strip().replace("topic: ", "") updated_topics[topic] = [(label, 1)] return updated_topics - def _create_prompt(self, docs, topic, topics): - keywords = next(zip(*topics[topic])) - - # Use the Default Chat Prompt - if self.prompt == DEFAULT_PROMPT: - prompt = self.prompt.replace("[KEYWORDS]", " ".join(keywords)) - prompt = self._replace_documents(prompt, docs) - - # Use a custom prompt that leverages keywords, documents or both using - # custom tags, namely [KEYWORDS] and [DOCUMENTS] respectively - else: - prompt = self.prompt - if "[KEYWORDS]" in prompt: - prompt = prompt.replace("[KEYWORDS]", " ".join(keywords)) - if "[DOCUMENTS]" in prompt: - prompt = self._replace_documents(prompt, docs) - - return prompt - - @staticmethod - def _replace_documents(prompt, docs): - to_replace = "" - for doc in docs: - to_replace += f"- {doc[:255]}\n" - prompt = prompt.replace("[DOCUMENTS]", to_replace) - return prompt - def chat_completions_with_backoff(**kwargs): return retry_with_exponential_backoff( diff --git a/bertopic/representation/_llamacpp.py b/bertopic/representation/_llamacpp.py index 730a5f55..8b0b6747 100644 --- a/bertopic/representation/_llamacpp.py +++ b/bertopic/representation/_llamacpp.py @@ -3,54 +3,23 @@ from scipy.sparse import csr_matrix from llama_cpp import Llama from typing import Mapping, List, Tuple, Any, Union, Callable -from 
bertopic.representation._base import BaseRepresentation -from bertopic.representation._utils import truncate_document, validate_truncate_document_parameters - - -DEFAULT_PROMPT = """ -This is a list of texts where each collection of texts describe a topic. After each collection of texts, the name of the topic they represent is mentioned as a short-highly-descriptive title ---- -Topic: -Sample texts from this topic: -- Traditional diets in most cultures were primarily plant-based with a little meat on top, but with the rise of industrial style meat production and factory farming, meat has become a staple food. -- Meat, but especially beef, is the word food in terms of emissions. -- Eating meat doesn't make you a bad person, not eating meat doesn't make you a good one. - -Keywords: meat beef eat eating emissions steak food health processed chicken -Topic name: Environmental impacts of eating meat ---- -Topic: -Sample texts from this topic: -- I have ordered the product weeks ago but it still has not arrived! -- The website mentions that it only takes a couple of days to deliver but I still have not received mine. -- I got a message stating that I received the monitor but that is not true! -- It took a month longer to deliver than was advised... - -Keywords: deliver weeks product shipping long delivery received arrived arrive week -Topic name: Shipping and delivery issues ---- -Topic: -Sample texts from this topic: -[DOCUMENTS] -Keywords: [KEYWORDS] -Topic name:""" - -DEFAULT_SYSTEM_PROMPT = "You are an assistant that extracts high-level topics from texts." - - -class LlamaCPP(BaseRepresentation): +from bertopic.representation._base import LLMRepresentation +from bertopic.representation._prompts import DEFAULT_SYSTEM_PROMPT, DEFAULT_CHAT_PROMPT + + +class LlamaCPP(LLMRepresentation): """A llama.cpp implementation to use as a representation model. Arguments: model: Either a string pointing towards a local LLM or a `llama_cpp.Llama` object. prompt: The prompt to be used in the model. If no prompt is given, - `self.default_prompt_` is used instead. + `bertopic.representation._prompts.DEFAULT_CHAT_PROMPT` is used instead. NOTE: Use `"[KEYWORDS]"` and `"[DOCUMENTS]"` in the prompt to decide where the keywords and documents need to be inserted. system_prompt: The system prompt to be used in the model. If no system prompt is given, - `self.default_system_prompt_` is used instead. + `bertopic.representation._prompts.DEFAULT_SYSTEM_PROMPT` is used instead. pipeline_kwargs: Kwargs that you can pass to the `llama_cpp.Llama` when it is called such as `max_tokens` to be generated. nr_docs: The number of documents to pass to OpenAI if a prompt @@ -123,6 +92,15 @@ def __init__( doc_length: int | None = None, tokenizer: Union[str, Callable] | None = None, ): + super().__init__( + prompt=prompt if prompt is not None else DEFAULT_CHAT_PROMPT, + nr_docs=nr_docs, + diversity=diversity, + doc_length=doc_length, + tokenizer=tokenizer, + ) + + # Llama.cpp specific parameters if isinstance(model, str): self.model = Llama(model_path=model, n_gpu_layers=-1, stop="\n", chat_format="ChatML") elif isinstance(model, Llama): @@ -133,18 +111,8 @@ def __init__( "pass is either a string referring to a" "local LLM or a ` llama_cpp.Llama` object." 
) - self.prompt = prompt if prompt is not None else DEFAULT_PROMPT self.system_prompt = system_prompt if system_prompt is not None else DEFAULT_SYSTEM_PROMPT - self.default_prompt_ = DEFAULT_PROMPT - self.default_system_prompt_ = DEFAULT_SYSTEM_PROMPT self.pipeline_kwargs = pipeline_kwargs - self.nr_docs = nr_docs - self.diversity = diversity - self.doc_length = doc_length - self.tokenizer = tokenizer - validate_truncate_document_parameters(self.tokenizer, self.doc_length) - - self.prompts_ = [] def extract_topics( self, @@ -172,8 +140,7 @@ def extract_topics( updated_topics = {} for topic, docs in tqdm(repr_docs_mappings.items(), disable=not topic_model.verbose): # Prepare prompt - truncated_docs = [truncate_document(topic_model, self.doc_length, self.tokenizer, doc) for doc in docs] - prompt = self._create_prompt(truncated_docs, topic, topics) + prompt = self._create_prompt(docs=docs, topic=topic, topics=topics, topic_model=topic_model) self.prompts_.append(prompt) # Extract result from generator and use that as label @@ -186,30 +153,3 @@ def extract_topics( updated_topics[topic] = [(label, 1)] + [("", 0) for _ in range(9)] return updated_topics - - def _create_prompt(self, docs, topic, topics): - keywords = next(zip(*topics[topic])) - - # Use the Default Chat Prompt - if self.prompt == DEFAULT_PROMPT: - prompt = self.prompt.replace("[KEYWORDS]", ", ".join(keywords)) - prompt = self._replace_documents(prompt, docs) - - # Use a custom prompt that leverages keywords, documents or both using - # custom tags, namely [KEYWORDS] and [DOCUMENTS] respectively - else: - prompt = self.prompt - if "[KEYWORDS]" in prompt: - prompt = prompt.replace("[KEYWORDS]", ", ".join(keywords)) - if "[DOCUMENTS]" in prompt: - prompt = self._replace_documents(prompt, docs) - - return prompt - - @staticmethod - def _replace_documents(prompt, docs): - to_replace = "" - for doc in docs: - to_replace += f"- {doc}\n" - prompt = prompt.replace("[DOCUMENTS]", to_replace) - return prompt diff --git a/bertopic/representation/_ollama.py b/bertopic/representation/_ollama.py new file mode 100644 index 00000000..f7fc4803 --- /dev/null +++ b/bertopic/representation/_ollama.py @@ -0,0 +1,154 @@ +import json +import pandas as pd +from ollama import chat +from ollama import ChatResponse +from tqdm import tqdm +from scipy.sparse import csr_matrix +from typing import Mapping, List, Tuple, Any, Union, Callable +from bertopic.representation._base import LLMRepresentation +from json.decoder import JSONDecodeError +from bertopic.representation._prompts import DEFAULT_SYSTEM_PROMPT, DEFAULT_CHAT_PROMPT, DEFAULT_JSON_SCHEMA + + +class Ollama(LLMRepresentation): + r"""Using the Ollama API to generate topic labels based on a local LLM. + + Arguments: + model: Model to use within Ollama. + generator_kwargs: Kwargs passed to `ollama.chat` for fine-tuning the output. + prompt: The prompt to be used in the model. If no prompt is given, + `bertopic.representation._prompts.DEFAULT_CHAT_PROMPT` is used instead. + NOTE: Use `"[KEYWORDS]"` and `"[DOCUMENTS]"` in the prompt + to decide where the keywords and documents need to be + inserted. + system_prompt: The system prompt to be used in the model. If no system prompt is given, + `bertopic.representation._prompts.DEFAULT_SYSTEM_PROMPT` is used instead. + json_schema: A dictionary representing the JSON schema to enforce structured output. + If set to True, a default schema will be used (`bertopic.representation._prompts.DEFAULT_JSON_SCHEMA`). 
+ nr_docs: The number of documents to pass to Ollama if a prompt + with the `["DOCUMENTS"]` tag is used. + diversity: The diversity of documents to pass to Ollama. + Accepts values between 0 and 1. A higher + values results in passing more diverse documents + whereas lower values passes more similar documents. + doc_length: The maximum length of each document. If a document is longer, + it will be truncated. If None, the entire document is passed. + tokenizer: The tokenizer used to calculate to split the document into segments + used to count the length of a document. + * If tokenizer is 'char', then the document is split up + into characters which are counted to adhere to `doc_length` + * If tokenizer is 'whitespace', the document is split up + into words separated by whitespaces. These words are counted + and truncated depending on `doc_length` + * If tokenizer is 'vectorizer', then the internal CountVectorizer + is used to tokenize the document. These tokens are counted + and truncated depending on `doc_length` + * If tokenizer is a callable, then that callable is used to tokenize + the document. These tokens are counted and truncated depending + on `doc_length` + + Usage: + + To use this, you will need to install the ollama package first: + + `pip install ollama` + + Then, you can use the Ollama representation model as follows: + + ```python + from bertopic.representation import Ollama + from bertopic import BERTopic + + # Create your representation model + representation_model = Ollama("gemma3") + + # Use the representation model in BERTopic on top of the default pipeline + topic_model = BERTopic(representation_model=representation_model) + ``` + + You can also use a custom prompt: + + ```python + prompt = "I have the following documents: [DOCUMENTS] \nThese documents are about the following topic: '" + representation_model = Ollama("gemma3", prompt=prompt) + ``` + """ + + def __init__( + self, + model: str, + prompt: str | None = None, + system_prompt: str | None = None, + json_schema: Mapping[str, Any] | bool = False, + generator_kwargs: Mapping[str, Any] = {}, + nr_docs: int = 4, + diversity: float | None = None, + doc_length: int | None = None, + tokenizer: Union[str, Callable] | None = None, + ): + super().__init__( + prompt=prompt if prompt is not None else DEFAULT_CHAT_PROMPT, + nr_docs=nr_docs, + diversity=diversity, + doc_length=doc_length, + tokenizer=tokenizer, + ) + + # Ollama specific parameters + self.model = model + self.system_prompt = DEFAULT_SYSTEM_PROMPT if system_prompt is None else system_prompt + self.json_schema = DEFAULT_JSON_SCHEMA if json_schema is True else json_schema + self.generator_kwargs = generator_kwargs + + def extract_topics( + self, + topic_model, + documents: pd.DataFrame, + c_tf_idf: csr_matrix, + topics: Mapping[str, List[Tuple[str, float]]], + ) -> Mapping[str, List[Tuple[str, float]]]: + """Extract topics. 
+ + Arguments: + topic_model: A BERTopic model + documents: All input documents + c_tf_idf: The topic c-TF-IDF representation + topics: The candidate topics as calculated with c-TF-IDF + + Returns: + updated_topics: Updated topic representations + """ + # Extract the top n representative documents per topic + repr_docs_mappings, _, _, _ = topic_model._extract_representative_docs( + c_tf_idf, documents, topics, 500, self.nr_docs, self.diversity + ) + + # Generate using Ollama's Language Model + updated_topics = {} + for topic, docs in tqdm(repr_docs_mappings.items(), disable=not topic_model.verbose): + prompt = self._create_prompt(docs=docs, topic=topic, topics=topics, topic_model=topic_model) + self.prompts_.append(prompt) + + # Call Ollama API + messages = [ + {"role": "system", "content": self.system_prompt}, + {"role": "user", "content": prompt}, + ] + kwargs = { + "model": self.model, + "messages": messages, + **self.generator_kwargs, + } + response: ChatResponse = chat(**kwargs) + + # Update labels + if self.json_schema: + try: + label = json.loads(response.message.content)["topic_label"] + except (JSONDecodeError, KeyError): + label = response.message.content.strip().replace("topic: ", "") + else: + label = response.message.content.strip().replace("topic: ", "") + updated_topics[topic] = [(label, 1)] + + return updated_topics diff --git a/bertopic/representation/_openai.py b/bertopic/representation/_openai.py index 71eb8c9a..4f0ecfe8 100644 --- a/bertopic/representation/_openai.py +++ b/bertopic/representation/_openai.py @@ -4,50 +4,14 @@ from tqdm import tqdm from scipy.sparse import csr_matrix from typing import Mapping, List, Tuple, Any, Union, Callable -from bertopic.representation._base import BaseRepresentation +from bertopic.representation._base import LLMRepresentation from bertopic.representation._utils import ( retry_with_exponential_backoff, - truncate_document, - validate_truncate_document_parameters, ) +from bertopic.representation._prompts import DEFAULT_SYSTEM_PROMPT, DEFAULT_CHAT_PROMPT -DEFAULT_CHAT_PROMPT = """You will extract a short topic label from given documents and keywords. -Here are two examples of topics you created before: - -# Example 1 -Sample texts from this topic: -- Traditional diets in most cultures were primarily plant-based with a little meat on top, but with the rise of industrial style meat production and factory farming, meat has become a staple food. -- Meat, but especially beef, is the worst food in terms of emissions. -- Eating meat doesn't make you a bad person, not eating meat doesn't make you a good one. - -Keywords: meat beef eat eating emissions steak food health processed chicken -topic: Environmental impacts of eating meat - -# Example 2 -Sample texts from this topic: -- I have ordered the product weeks ago but it still has not arrived! -- The website mentions that it only takes a couple of days to deliver but I still have not received mine. -- I got a message stating that I received the monitor but that is not true! -- It took a month longer to deliver than was advised... - -Keywords: deliver weeks product shipping long delivery received arrived arrive week -topic: Shipping and delivery issues - -# Your task -Sample texts from this topic: -[DOCUMENTS] - -Keywords: [KEYWORDS] - -Based on the information above, extract a short topic label (three words at most) in the following format: -topic: -""" - -DEFAULT_SYSTEM_PROMPT = "You are an assistant that extracts high-level topics from texts." 
- - -class OpenAI(BaseRepresentation): +class OpenAI(LLMRepresentation): r"""Using the OpenAI API to generate topic labels based on one of their Completion of ChatCompletion models. @@ -60,12 +24,12 @@ class OpenAI(BaseRepresentation): generator_kwargs: Kwargs passed to `openai.Completion.create` for fine-tuning the output. prompt: The prompt to be used in the model. If no prompt is given, - `self.default_prompt_` is used instead. + `bertopic.representation._prompts.DEFAULT_CHAT_PROMPT` is used instead. NOTE: Use `"[KEYWORDS]"` and `"[DOCUMENTS]"` in the prompt to decide where the keywords and documents need to be inserted. system_prompt: The system prompt to be used in the model. If no system prompt is given, - `self.default_system_prompt_` is used instead. + `bertopic.representation._prompts.DEFAULT_SYSTEM_PROMPT` is used instead. delay_in_seconds: The delay in seconds between consecutive prompts in order to prevent RateLimitErrors. exponential_backoff: Retry requests with a random exponential backoff. @@ -143,41 +107,22 @@ def __init__( diversity: float | None = None, doc_length: int | None = None, tokenizer: Union[str, Callable] | None = None, - **kwargs, ): + super().__init__( + prompt=prompt if prompt is not None else DEFAULT_CHAT_PROMPT, + nr_docs=nr_docs, + diversity=diversity, + doc_length=doc_length, + tokenizer=tokenizer, + ) + + # OpenAI specific parameters self.client = client self.model = model - - if prompt is None: - self.prompt = DEFAULT_CHAT_PROMPT - else: - self.prompt = prompt - - if system_prompt is None: - self.system_prompt = DEFAULT_SYSTEM_PROMPT - else: - self.system_prompt = system_prompt - - self.default_prompt_ = DEFAULT_CHAT_PROMPT - self.default_system_prompt_ = DEFAULT_SYSTEM_PROMPT + self.system_prompt = system_prompt if system_prompt is not None else DEFAULT_SYSTEM_PROMPT self.delay_in_seconds = delay_in_seconds self.exponential_backoff = exponential_backoff - self.nr_docs = nr_docs - self.diversity = diversity - self.doc_length = doc_length - self.tokenizer = tokenizer - validate_truncate_document_parameters(self.tokenizer, self.doc_length) - - self.prompts_ = [] - self.generator_kwargs = generator_kwargs - if self.generator_kwargs.get("model"): - self.model = generator_kwargs.get("model") - del self.generator_kwargs["model"] - if self.generator_kwargs.get("prompt"): - del self.generator_kwargs["prompt"] - if not self.generator_kwargs.get("stop"): - self.generator_kwargs["stop"] = "\n" def extract_topics( self, @@ -205,8 +150,7 @@ def extract_topics( # Generate using OpenAI's Language Model updated_topics = {} for topic, docs in tqdm(repr_docs_mappings.items(), disable=not topic_model.verbose): - truncated_docs = [truncate_document(topic_model, self.doc_length, self.tokenizer, doc) for doc in docs] - prompt = self._create_prompt(truncated_docs, topic, topics) + prompt = self._create_prompt(docs=docs, topic=topic, topics=topics, topic_model=topic_model) self.prompts_.append(prompt) # Delay @@ -239,33 +183,6 @@ def extract_topics( return updated_topics - def _create_prompt(self, docs, topic, topics): - keywords = next(zip(*topics[topic])) - - # Use the Default Chat Prompt - if self.prompt == DEFAULT_CHAT_PROMPT: - prompt = self.prompt.replace("[KEYWORDS]", ", ".join(keywords)) - prompt = self._replace_documents(prompt, docs) - - # Use a custom prompt that leverages keywords, documents or both using - # custom tags, namely [KEYWORDS] and [DOCUMENTS] respectively - else: - prompt = self.prompt - if "[KEYWORDS]" in prompt: - prompt = 
prompt.replace("[KEYWORDS]", ", ".join(keywords)) - if "[DOCUMENTS]" in prompt: - prompt = self._replace_documents(prompt, docs) - - return prompt - - @staticmethod - def _replace_documents(prompt, docs): - to_replace = "" - for doc in docs: - to_replace += f"- {doc}\n" - prompt = prompt.replace("[DOCUMENTS]", to_replace) - return prompt - def chat_completions_with_backoff(client, **kwargs): return retry_with_exponential_backoff( diff --git a/bertopic/representation/_prompts.py b/bertopic/representation/_prompts.py new file mode 100644 index 00000000..268c4bb1 --- /dev/null +++ b/bertopic/representation/_prompts.py @@ -0,0 +1,46 @@ +DEFAULT_COMPLETION_PROMPT = "What are these documents about? Please give a single label." +DEFAULT_CHAT_PROMPT = """You will extract a short topic label from given documents and keywords. +Here are two examples of topics you created before: + +# Example 1 +Sample texts from this topic: +- Traditional diets in most cultures were primarily plant-based with a little meat on top, but with the rise of industrial style meat production and factory farming, meat has become a staple food. +- Meat, but especially beef, is the worst food in terms of emissions. +- Eating meat doesn't make you a bad person, not eating meat doesn't make you a good one. + +Keywords: meat beef eat eating emissions steak food health processed chicken +topic: Environmental impacts of eating meat + +# Example 2 +Sample texts from this topic: +- I have ordered the product weeks ago but it still has not arrived! +- The website mentions that it only takes a couple of days to deliver but I still have not received mine. +- I got a message stating that I received the monitor but that is not true! +- It took a month longer to deliver than was advised... + +Keywords: deliver weeks product shipping long delivery received arrived arrive week +topic: Shipping and delivery issues + +# Your task +Sample texts from this topic: +[DOCUMENTS] + +Keywords: [KEYWORDS] + +Based on the information above, extract a short topic label (three words at most) in the following format: +topic: +""" + +DEFAULT_SYSTEM_PROMPT = "You are an assistant that extracts high-level topics from texts." +DEFAULT_JSON_SCHEMA = { + "properties": { + "topic_label": { + "description": "A short, human-readable label that summarizes the main topic.", + "title": "Topic Label", + "type": "string", + } + }, + "required": ["topic_label"], + "title": "Topic", + "type": "object", +} diff --git a/bertopic/representation/_textgeneration.py b/bertopic/representation/_textgeneration.py index 8c8e36f6..e9639fa4 100644 --- a/bertopic/representation/_textgeneration.py +++ b/bertopic/representation/_textgeneration.py @@ -4,17 +4,11 @@ from transformers import pipeline, set_seed from transformers.pipelines.base import Pipeline from typing import Mapping, List, Tuple, Any, Union, Callable -from bertopic.representation._base import BaseRepresentation -from bertopic.representation._utils import truncate_document, validate_truncate_document_parameters +from bertopic.representation._base import LLMRepresentation +from bertopic.representation._prompts import DEFAULT_CHAT_PROMPT -DEFAULT_PROMPT = """ -I have a topic described by the following keywords: [KEYWORDS]. -The name of this topic is: -""" - - -class TextGeneration(BaseRepresentation): +class TextGeneration(LLMRepresentation): """Text2Text or text generation with transformers. Arguments: @@ -23,7 +17,7 @@ class TextGeneration(BaseRepresentation): For example, `pipeline('text-generation', model='gpt2')`. 
If a string is passed, "text-generation" will be selected by default. prompt: The prompt to be used in the model. If no prompt is given, - `self.default_prompt_` is used instead. + `bertopic.representation._prompts.DEFAULT_CHAT_PROMPT` is used instead. NOTE: Use `"[KEYWORDS]"` and `"[DOCUMENTS]"` in the prompt to decide where the keywords and documents need to be inserted. @@ -93,8 +87,18 @@ def __init__( doc_length: int | None = None, tokenizer: Union[str, Callable] | None = None, ): + super().__init__( + prompt=prompt if prompt is not None else DEFAULT_CHAT_PROMPT, + nr_docs=nr_docs, + diversity=diversity, + doc_length=doc_length, + tokenizer=tokenizer, + ) + + # Transformer specific parameters self.random_state = random_state set_seed(random_state) + if isinstance(model, str): self.model = pipeline("text-generation", model=model) elif isinstance(model, Pipeline): @@ -105,16 +109,7 @@ def __init__( "pass is either a string referring to a" "HF model or a `transformers.pipeline` object." ) - self.prompt = prompt if prompt is not None else DEFAULT_PROMPT - self.default_prompt_ = DEFAULT_PROMPT self.pipeline_kwargs = pipeline_kwargs - self.nr_docs = nr_docs - self.diversity = diversity - self.doc_length = doc_length - self.tokenizer = tokenizer - validate_truncate_document_parameters(self.tokenizer, self.doc_length) - - self.prompts_ = [] def extract_topics( self, @@ -134,23 +129,14 @@ def extract_topics( Returns: updated_topics: Updated topic representations """ - # Extract the top 4 representative documents per topic - if self.prompt != DEFAULT_PROMPT and "[DOCUMENTS]" in self.prompt: - repr_docs_mappings, _, _, _ = topic_model._extract_representative_docs( - c_tf_idf, documents, topics, 500, self.nr_docs, self.diversity - ) - else: - repr_docs_mappings = {topic: None for topic in topics.keys()} + # Extract the top n representative documents per topic + repr_docs_mappings, _, _, _ = topic_model._extract_representative_docs( + c_tf_idf, documents, topics, 500, self.nr_docs, self.diversity + ) updated_topics = {} for topic, docs in tqdm(repr_docs_mappings.items(), disable=not topic_model.verbose): - # Prepare prompt - truncated_docs = ( - [truncate_document(topic_model, self.doc_length, self.tokenizer, doc) for doc in docs] - if docs is not None - else docs - ) - prompt = self._create_prompt(truncated_docs, topic, topics) + prompt = self._create_prompt(docs=docs, topic=topic, topics=topics, topic_model=topic_model) self.prompts_.append(prompt) # Extract result from generator and use that as label @@ -165,24 +151,3 @@ def extract_topics( updated_topics[topic] = topic_description return updated_topics - - def _create_prompt(self, docs, topic, topics): - keywords = ", ".join(next(zip(*topics[topic]))) - - # Use the default prompt and replace keywords - if self.prompt == DEFAULT_PROMPT: - prompt = self.prompt.replace("[KEYWORDS]", keywords) - - # Use a prompt that leverages either keywords or documents in - # a custom location - else: - prompt = self.prompt - if "[KEYWORDS]" in prompt: - prompt = prompt.replace("[KEYWORDS]", keywords) - if "[DOCUMENTS]" in prompt: - to_replace = "" - for doc in docs: - to_replace += f"- {doc}\n" - prompt = prompt.replace("[DOCUMENTS]", to_replace) - - return prompt
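
A short usage sketch of the new `Ollama` representation model introduced in this diff, with structured output enabled. It assumes a local Ollama server is running and that the referenced model (here `gemma3`, as in the class docstring) has already been pulled; the 20 newsgroups dataset is only an illustrative corpus.

```python
from bertopic import BERTopic
from bertopic.representation import Ollama
from sklearn.datasets import fetch_20newsgroups

# Any corpus works; 20 newsgroups is only used here as an illustrative dataset
docs = fetch_20newsgroups(subset="all", remove=("headers", "footers", "quotes"))["data"]

representation_model = Ollama(
    model="gemma3",          # must already be pulled locally, e.g. `ollama pull gemma3`
    json_schema=True,        # use DEFAULT_JSON_SCHEMA; the label is read from the "topic_label" field
    nr_docs=4,               # number of representative documents inserted at the [DOCUMENTS] tag
    doc_length=100,          # truncate each document to 100 tokens ...
    tokenizer="whitespace",  # ... where a token is a whitespace-separated word
)

topic_model = BERTopic(representation_model=representation_model, verbose=True)
topics, probs = topic_model.fit_transform(docs)

# The per-topic prompts are stored on the shared base class for inspection
print(representation_model.prompts_[0])
print(topic_model.get_topic_info().head())
```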
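
Because prompt construction, `[KEYWORDS]`/`[DOCUMENTS]` replacement, document truncation, and prompt bookkeeping now live on `LLMRepresentation`, wiring up another backend mostly reduces to an `extract_topics` loop. The sketch below is a hypothetical minimal subclass: the `EchoLLM` name and its fake "generation" (echoing part of the prompt) are illustrative only, while the `_extract_representative_docs`, `_create_prompt`, and `prompts_` usage follows the pattern the built-in backends use in this diff. The imports from the private `_base` and `_prompts` modules mirror what those backends do.

```python
from typing import Mapping, List, Tuple

import pandas as pd
from scipy.sparse import csr_matrix

from bertopic.representation._base import LLMRepresentation
from bertopic.representation._prompts import DEFAULT_CHAT_PROMPT


class EchoLLM(LLMRepresentation):
    """Toy backend that 'generates' a label by echoing the first prompt line.

    Hypothetical example only; real subclasses (Cohere, OpenAI, Ollama, ...)
    call their respective APIs where the echo happens below.
    """

    def __init__(self, prompt: str | None = None, **kwargs):
        # Mirror the built-in backends: fall back to the shared default chat prompt
        super().__init__(prompt=prompt if prompt is not None else DEFAULT_CHAT_PROMPT, **kwargs)

    def extract_topics(
        self,
        topic_model,
        documents: pd.DataFrame,
        c_tf_idf: csr_matrix,
        topics: Mapping[str, List[Tuple[str, float]]],
    ) -> Mapping[str, List[Tuple[str, float]]]:
        # Representative-document selection is shared across all LLM backends
        repr_docs_mappings, _, _, _ = topic_model._extract_representative_docs(
            c_tf_idf, documents, topics, 500, self.nr_docs, self.diversity
        )

        updated_topics = {}
        for topic, docs in repr_docs_mappings.items():
            # _create_prompt fills [KEYWORDS]/[DOCUMENTS] and applies doc_length/tokenizer truncation
            prompt = self._create_prompt(docs=docs, topic=topic, topics=topics, topic_model=topic_model)
            self.prompts_.append(prompt)

            # A real implementation would send `prompt` to an LLM here
            label = prompt.splitlines()[0][:50]
            updated_topics[topic] = [(label, 1)] + [("", 0) for _ in range(9)]

        return updated_topics


# Usage mirrors the other backends, e.g.:
# representation_model = EchoLLM(nr_docs=6, doc_length=80, tokenizer="whitespace")
```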