1 change: 1 addition & 0 deletions .gitignore
@@ -59,6 +59,7 @@ docs/_build/

# Jupyter Notebook
.ipynb_checkpoints
*.ipynb
notebooks/

# IPython
8 changes: 8 additions & 0 deletions bertopic/representation/__init__.py
@@ -33,6 +33,13 @@
msg = "`pip install openai` \n\n"
OpenAI = NotInstalled("OpenAI", "openai", custom_msg=msg)

# Ollama Generator
try:
from bertopic.representation._ollama import Ollama
except ModuleNotFoundError:
msg = "`pip install ollama` \n\n"
Ollama = NotInstalled("Ollama", "ollama", custom_msg=msg)

# LiteLLM Generator
try:
from bertopic.representation._litellm import LiteLLM
@@ -68,6 +75,7 @@
"LiteLLM",
"LlamaCPP",
"MaximalMarginalRelevance",
"Ollama",
"OpenAI",
"PartOfSpeech",
"TextGeneration",
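
The new import follows the same optional-dependency pattern as the other LLM backends: if `ollama` is not installed, `Ollama` becomes a `NotInstalled` stub that raises a helpful message on first use. A minimal usage sketch follows; note that the diff does not include `_ollama.py`, so the constructor arguments below are assumptions modeled on the sibling LLM representations:

import ollama
from bertopic import BERTopic
from bertopic.representation import Ollama

# Assumed constructor: a client plus a model name, mirroring the other
# LLM representations; check _ollama.py for the actual signature.
client = ollama.Client(host="http://localhost:11434")
representation_model = Ollama(client, model="llama3")

# Plug the representation model into BERTopic as usual.
topic_model = BERTopic(representation_model=representation_model)
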
115 changes: 114 additions & 1 deletion bertopic/representation/_base.py
@@ -1,7 +1,10 @@
import pandas as pd
from scipy.sparse import csr_matrix
from sklearn.base import BaseEstimator
from typing import Mapping, List, Tuple
from typing import Mapping, List, Tuple, Union, Callable

from bertopic.representation._prompts import DEFAULT_CHAT_PROMPT
from bertopic.representation._utils import truncate_document, validate_truncate_document_parameters


class BaseRepresentation(BaseEstimator):
@@ -38,3 +41,113 @@ def extract_topics(
model is used.
"""
return topic_model.topic_representations_


class LLMRepresentation(BaseRepresentation):
"""Base class for LLM-based representation models."""

def __init__(
self,
prompt: str | None = None,
nr_docs: int = 4,
diversity: float | None = None,
doc_length: int | None = None,
tokenizer: Union[str, Callable] | None = None,
):
"""Generate a representation model that leverages LLMs.

Arguments:
prompt: The prompt to be used in the model. If no prompt is given,
`bertopic.representation._prompts.DEFAULT_CHAT_PROMPT` is used instead.
NOTE: Use `"[KEYWORDS]"` and `"[DOCUMENTS]"` in the prompt
to decide where the keywords and documents need to be
inserted.
nr_docs: The number of documents to pass to the LLM if a prompt
with the `"[DOCUMENTS]"` tag is used.
diversity: The diversity of documents to pass to the LLM.
Accepts values between 0 and 1. Higher values
result in passing more diverse documents,
whereas lower values pass more similar documents.
doc_length: The maximum length of each document. If a document is longer,
it will be truncated. If None, the entire document is passed.
tokenizer: The tokenizer used to split the document into segments
that are counted to determine the length of a document.
* If tokenizer is 'char', then the document is split up
into characters which are counted to adhere to `doc_length`
* If tokenizer is 'whitespace', the document is split up
into words separated by whitespaces. These words are counted
and truncated depending on `doc_length`
* If tokenizer is 'vectorizer', then the internal CountVectorizer
is used to tokenize the document. These tokens are counted
and truncated depending on `doc_length`
* If tokenizer is a callable, then that callable is used to tokenize
the document. These tokens are counted and truncated depending
on `doc_length`
"""
self.prompt = prompt

# Representative document extraction parameters
self.nr_docs = nr_docs
self.diversity = diversity

# Document truncation
self.doc_length = doc_length
self.tokenizer = tokenizer
validate_truncate_document_parameters(self.tokenizer, self.doc_length)

# Store prompts for inspection
self.prompts_ = []

def _create_prompt(
self, docs: list[str], topic: int, topics: Mapping[str, List[Tuple[str, float]]], topic_model
) -> str:
"""Create prompt for LLM by either using the default prompt or replacing custom tags.
Specifically, [KEYWORDS] and [DOCUMENTS] can be used in custom prompts to insert the topic's keywords and most representative documents.

Arguments:
docs: The most representative documents for a given topic.
topic: The topic for which to create the prompt.
topics: A dictionary with topic (key) and tuple of word and
weight (value) as calculated by c-TF-IDF.
topic_model: The BERTopic model that is fitted until topic
representations are calculated.

Returns:
prompt: The created prompt.
"""
keywords = next(zip(*topics[topic]))

# Use the Default Chat Prompt
if self.prompt == DEFAULT_CHAT_PROMPT:
prompt = self.prompt.replace("[KEYWORDS]", ", ".join(keywords))
prompt = self._replace_documents(prompt, docs, topic_model)

# Use a custom prompt that leverages keywords, documents or both using
# custom tags, namely [KEYWORDS] and [DOCUMENTS] respectively
else:
prompt = self.prompt
if "[KEYWORDS]" in prompt:
prompt = prompt.replace("[KEYWORDS]", ", ".join(keywords))
if "[DOCUMENTS]" in prompt:
prompt = self._replace_documents(prompt, docs, topic_model)

return prompt

def _replace_documents(self, prompt: str, docs: list[str], topic_model) -> str:
"""Replace [DOCUMENTS] tag in prompt with actual documents.

Arguments:
prompt: The prompt containing the [DOCUMENTS] tag.
docs: The most representative documents for a given topic.
topic_model: The BERTopic model that is fitted until topic
representations are calculated.

Returns:
The prompt with the [DOCUMENTS] tag replaced by actual documents.
"""
# Truncate documents if needed
truncated_docs = [truncate_document(topic_model, self.doc_length, self.tokenizer, doc) for doc in docs]

# Replace tag with documents
formatted_docs = "\n".join([f"- {doc}" for doc in truncated_docs])
return prompt.replace("[DOCUMENTS]", formatted_docs)
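
Taken together, `LLMRepresentation` now owns the three concerns every LLM backend previously duplicated: representative-document selection (`nr_docs`, `diversity`), truncation (`doc_length`, `tokenizer`), and prompt rendering (`_create_prompt` / `_replace_documents`). A toy subclass illustrates the intended division of labor; this is a sketch that assumes the private helper `topic_model._extract_representative_docs` keeps its current signature, and the "LLM call" is a stand-in:

from bertopic.representation._base import LLMRepresentation

class EchoRepresentation(LLMRepresentation):
    """Toy backend: uses the rendered prompt itself as the topic label."""

    def extract_topics(self, topic_model, documents, c_tf_idf, topics):
        # Select representative documents exactly as the real backends do.
        repr_docs_mappings, _, _, _ = topic_model._extract_representative_docs(
            c_tf_idf, documents, topics, 500, self.nr_docs, self.diversity
        )
        updated_topics = {}
        for topic, docs in repr_docs_mappings.items():
            # Truncation and tag replacement are handled by the base class.
            prompt = self._create_prompt(docs=docs, topic=topic, topics=topics, topic_model=topic_model)
            self.prompts_.append(prompt)
            label = prompt.splitlines()[-1]  # stand-in for a real LLM call
            updated_topics[topic] = [(label, 1)] + [("", 0) for _ in range(9)]
        return updated_topics
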
94 changes: 17 additions & 77 deletions bertopic/representation/_cohere.py
@@ -3,42 +3,11 @@
from tqdm import tqdm
from scipy.sparse import csr_matrix
from typing import Mapping, List, Tuple, Union, Callable
from bertopic.representation._base import BaseRepresentation
from bertopic.representation._utils import truncate_document, validate_truncate_document_parameters


DEFAULT_PROMPT = """
This is a list of texts where each collection of texts describe a topic. After each collection of texts, the name of the topic they represent is mentioned as a short-highly-descriptive title
---
Topic:
Sample texts from this topic:
- Traditional diets in most cultures were primarily plant-based with a little meat on top, but with the rise of industrial style meat production and factory farming, meat has become a staple food.
- Meat, but especially beef, is the word food in terms of emissions.
- Eating meat doesn't make you a bad person, not eating meat doesn't make you a good one.

Keywords: meat beef eat eating emissions steak food health processed chicken
Topic name: Environmental impacts of eating meat
---
Topic:
Sample texts from this topic:
- I have ordered the product weeks ago but it still has not arrived!
- The website mentions that it only takes a couple of days to deliver but I still have not received mine.
- I got a message stating that I received the monitor but that is not true!
- It took a month longer to deliver than was advised...

Keywords: deliver weeks product shipping long delivery received arrived arrive week
Topic name: Shipping and delivery issues
---
Topic:
Sample texts from this topic:
[DOCUMENTS]
Keywords: [KEYWORDS]
Topic name:"""

DEFAULT_SYSTEM_PROMPT = "You are an assistant that extracts high-level topics from texts."


class Cohere(BaseRepresentation):
from bertopic.representation._base import LLMRepresentation
from bertopic.representation._prompts import DEFAULT_SYSTEM_PROMPT, DEFAULT_CHAT_PROMPT


class Cohere(LLMRepresentation):
"""Use the Cohere API to generate topic labels based on their
generative model.

@@ -49,12 +18,12 @@ class Cohere(BaseRepresentation):
client: A `cohere.Client`
model: Model to use within Cohere, defaults to `"xlarge"`.
prompt: The prompt to be used in the model. If no prompt is given,
`self.default_prompt_` is used instead.
`bertopic.representation._prompts.DEFAULT_CHAT_PROMPT` is used instead.
NOTE: Use `"[KEYWORDS]"` and `"[DOCUMENTS]"` in the prompt
to decide where the keywords and documents need to be
inserted.
system_prompt: The system prompt to be used in the model. If no system prompt is given,
`self.default_system_prompt_` is used instead.
`bertopic.representation._prompts.DEFAULT_SYSTEM_PROMPT` is used instead.
delay_in_seconds: The delay in seconds between consecutive prompts
in order to prevent RateLimitErrors.
nr_docs: The number of documents to pass to Cohere if a prompt
@@ -120,20 +89,19 @@ def __init__(
doc_length: int | None = None,
tokenizer: Union[str, Callable] | None = None,
):
super().__init__(
prompt=prompt if prompt is not None else DEFAULT_CHAT_PROMPT,
nr_docs=nr_docs,
diversity=diversity,
doc_length=doc_length,
tokenizer=tokenizer,
)

# Cohere specific parameters
self.client = client
self.model = model
self.prompt = prompt if prompt is not None else DEFAULT_PROMPT
self.system_prompt = system_prompt if system_prompt is not None else DEFAULT_SYSTEM_PROMPT
self.default_prompt_ = DEFAULT_PROMPT
self.default_system_prompt_ = DEFAULT_SYSTEM_PROMPT
self.delay_in_seconds = delay_in_seconds
self.nr_docs = nr_docs
self.diversity = diversity
self.doc_length = doc_length
self.tokenizer = tokenizer
validate_truncate_document_parameters(self.tokenizer, self.doc_length)

self.prompts_ = []

def extract_topics(
self,
Expand Down Expand Up @@ -161,8 +129,7 @@ def extract_topics(
# Generate using Cohere's Language Model
updated_topics = {}
for topic, docs in tqdm(repr_docs_mappings.items(), disable=not topic_model.verbose):
truncated_docs = [truncate_document(topic_model, self.doc_length, self.tokenizer, doc) for doc in docs]
prompt = self._create_prompt(truncated_docs, topic, topics)
prompt = self._create_prompt(docs=docs, topic=topic, topics=topics, topic_model=topic_model)
self.prompts_.append(prompt)

# Delay
@@ -180,30 +147,3 @@
updated_topics[topic] = [(label, 1)] + [("", 0) for _ in range(9)]

return updated_topics

def _create_prompt(self, docs, topic, topics):
keywords = next(zip(*topics[topic]))

# Use the Default Chat Prompt
if self.prompt == DEFAULT_PROMPT:
prompt = self.prompt.replace("[KEYWORDS]", ", ".join(keywords))
prompt = self._replace_documents(prompt, docs)

# Use a custom prompt that leverages keywords, documents or both using
# custom tags, namely [KEYWORDS] and [DOCUMENTS] respectively
else:
prompt = self.prompt
if "[KEYWORDS]" in prompt:
prompt = prompt.replace("[KEYWORDS]", ", ".join(keywords))
if "[DOCUMENTS]" in prompt:
prompt = self._replace_documents(prompt, docs)

return prompt

@staticmethod
def _replace_documents(prompt, docs):
to_replace = ""
for doc in docs:
to_replace += f"- {doc}\n"
prompt = prompt.replace("[DOCUMENTS]", to_replace)
return prompt
27 changes: 14 additions & 13 deletions bertopic/representation/_langchain.py
@@ -3,13 +3,12 @@
from scipy.sparse import csr_matrix
from typing import Callable, Mapping, List, Tuple, Union

from bertopic.representation._base import BaseRepresentation
from bertopic.representation._utils import truncate_document, validate_truncate_document_parameters
from bertopic.representation._base import LLMRepresentation
from bertopic.representation._utils import truncate_document
from bertopic.representation._prompts import DEFAULT_COMPLETION_PROMPT

DEFAULT_PROMPT = "What are these documents about? Please give a single label."


class LangChain(BaseRepresentation):
class LangChain(LLMRepresentation):
"""Using chains in langchain to generate topic labels.

The classic example uses `langchain.chains.question_answering.load_qa_chain`.
@@ -22,7 +21,7 @@ class LangChain(BaseRepresentation):
Input keys must be `input_documents` and `question`.
Output key must be `output_text`.
prompt: The prompt to be used in the model. If no prompt is given,
`self.default_prompt_` is used instead.
`bertopic.representation._prompts.DEFAULT_COMPLETION_PROMPT` is used instead.
NOTE: Use `"[KEYWORDS]"` in the prompt
to decide where the keywords need to be
inserted. Keywords won't be included unless
@@ -140,15 +139,17 @@ def __init__(
tokenizer: Union[str, Callable] | None = None,
chain_config=None,
):
super().__init__(
prompt=prompt if prompt is not None else DEFAULT_COMPLETION_PROMPT,
nr_docs=nr_docs,
diversity=diversity,
doc_length=doc_length,
tokenizer=tokenizer,
)

# LangChain specific parameters
self.chain = chain
self.prompt = prompt if prompt is not None else DEFAULT_PROMPT
self.default_prompt_ = DEFAULT_PROMPT
self.chain_config = chain_config
self.nr_docs = nr_docs
self.diversity = diversity
self.doc_length = doc_length
self.tokenizer = tokenizer
validate_truncate_document_parameters(self.tokenizer, self.doc_length)

def extract_topics(
self,
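
The LangChain class gets the same treatment: the default prompt moves to `bertopic.representation._prompts.DEFAULT_COMPLETION_PROMPT`, and the shared parameters move into `super().__init__`. A usage sketch with the classic chain the docstring mentions; this assumes `langchain`, `langchain-openai`, and an OpenAI key are available, and `load_qa_chain` is deprecated in newer LangChain releases but matches the documented example:

from langchain.chains.question_answering import load_qa_chain
from langchain_openai import ChatOpenAI
from bertopic import BERTopic
from bertopic.representation import LangChain

# Input keys must be `input_documents` and `question`;
# the output key must be `output_text` (see the docstring above).
chain = load_qa_chain(ChatOpenAI(model="gpt-4o-mini"), chain_type="stuff")
representation_model = LangChain(chain, prompt="What are these documents about? Give a single label.")
topic_model = BERTopic(representation_model=representation_model)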