Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 29 additions & 13 deletions docs/source/evaluating-a-custom-model.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -13,31 +13,46 @@ Create a Python file containing your custom model implementation. The model must
Here's a basic example:

```python
from transformers import AutoTokenizer

from lighteval.models.abstract_model import LightevalModel
from lighteval.models.model_output import ModelResponse
from lighteval.tasks.requests import Doc, SamplingMethod
from lighteval.utils.cache_management import SampleCache, cached

class MyCustomModel(LightevalModel):
def __init__(self, config):
super().__init__(config)
self.config = config
# Initialize your model here...
self._tokenizer = AutoTokenizer.from_pretrained("gpt2")

# Enable caching (recommended)
self._cache = SampleCache(config)

@property
def tokenizer(self):
return self._tokenizer

@property
def add_special_tokens(self) -> bool:
return False

@property
def max_length(self) -> int:
return 2048

@cached(SamplingMethod.GENERATIVE)
def greedy_until(self, docs: List[Doc]) -> List[ModelResponse]:
def greedy_until(self, docs: list[Doc]) -> list[ModelResponse]:
# Implement generation logic
pass

@cached(SamplingMethod.LOGPROBS)
def loglikelihood(self, docs: List[Doc]) -> List[ModelResponse]:
def loglikelihood(self, docs: list[Doc]) -> list[ModelResponse]:
# Implement loglikelihood computation
pass

@cached(SamplingMethod.PERPLEXITY)
def loglikelihood_rolling(self, docs: List[Doc]) -> List[ModelResponse]:
def loglikelihood_rolling(self, docs: list[Doc]) -> list[ModelResponse]:
# Implement rolling loglikelihood computation
pass
```
Expand All @@ -59,7 +74,7 @@ You can evaluate your custom model using either the command-line interface or th
lighteval custom \
"google-translate" \
"examples/custom_models/google_translate_model.py" \
"wmt20:fr-de \
"lighteval|wmt20:fr-de|0" \
--max-samples 10
```

Expand Down Expand Up @@ -94,7 +109,7 @@ model_config = CustomModelConfig(

# Create and run the pipeline
pipeline = Pipeline(
tasks=truthfulqa:mc,
tasks="truthfulqa:mc",
pipeline_parameters=pipeline_params,
evaluation_tracker=evaluation_tracker,
model_config=model_config
Expand Down Expand Up @@ -174,17 +189,18 @@ from lighteval.utils.cache_management import SampleCache, cached
### Step 2: Initialize Cache in Constructor
```python
def __init__(self, config):
super().__init__(config)
self.config = config
# Your initialization code...
self._cache = SampleCache(config)
```

3. Add cache decorators to your prediction methods:
```python
@cached(SamplingMethod.GENERATIVE)
def greedy_until(self, docs: List[Doc]) -> List[ModelResponse]:
# Your implementation...
```
### Step 3: Add Cache Decorators
```python
@cached(SamplingMethod.GENERATIVE)
def greedy_until(self, docs: list[Doc]) -> list[ModelResponse]:
# Your implementation...
pass
```

For detailed information about the caching system, see the [Caching Documentation](caching).

Expand Down
26 changes: 22 additions & 4 deletions src/lighteval/models/custom/custom_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ class CustomModelConfig(ModelConfig):
This class will be automatically detected and instantiated when loading the model.

Args:
model (str):
model_name (str):
An identifier for the model. This can be used to track which model was evaluated
in the results and logs.

Expand All @@ -46,24 +46,42 @@ class CustomModelConfig(ModelConfig):
```python
# Define config
config = CustomModelConfig(
model="my-custom-model",
model_name="my-custom-model",
model_definition_file_path="path/to/my_model.py"
)

# Example custom model file (my_model.py):
from lighteval.models.abstract_model import LightevalModel
from lighteval.models.model_output import ModelResponse
from lighteval.tasks.requests import Doc

class MyCustomModel(LightevalModel):
def __init__(self, config, env_config):
super().__init__(config, env_config)
def __init__(self, config):
self.config = config
# Custom initialization...

@property
def tokenizer(self):
# Return the tokenizer used by your model
...

@property
def add_special_tokens(self) -> bool:
return False

@property
def max_length(self) -> int:
return 2048

def greedy_until(self, docs: list[Doc]) -> list[ModelResponse]:
# Custom generation logic...
pass

def loglikelihood(self, docs: list[Doc]) -> list[ModelResponse]:
pass

def loglikelihood_rolling(self, docs: list[Doc]) -> list[ModelResponse]:
pass
```

An example of a custom model can be found in `examples/custom_models/google_translate_model.py`.
Expand Down