Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions plugins/online-data-mixing/artifacts/custom_loop_usage.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,18 +19,18 @@
from fms_acceleration_odm import OnlineMixingDataset
from fms_acceleration_odm.odm.reward import Reward

model_name = "ibm-granite/granite-4.0-h-1b"
model_name = "ibm-granite/granite-4.0-350m"
output_dir = "./odm_custom_use"
max_steps = 125
batch_size = 4
log_file = os.path.join(output_dir, "loss.jsonl")

# odm related
step_idx = 0
update_interval = 1 # every step
update_interval = 10 # every 10 steps

# model
model = AutoModelForCausalLM.from_pretrained(model_name, dtype=torch.bfloat16)
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16)

# tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_name)
Expand Down Expand Up @@ -102,7 +102,7 @@ def collate_fn(batch, tokenizer):
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False, collate_fn=None)

# distributed setup
dataloader_config = DataLoaderConfiguration(split_batches=True, dispatch_batches=True)
dataloader_config = DataLoaderConfiguration(dispatch_batches=False)
accelerator = Accelerator(dataloader_config=dataloader_config)
model, dataloader = accelerator.prepare(model, dataloader)

Expand Down Expand Up @@ -141,7 +141,7 @@ class State:
if step_idx % update_interval == 0:
with torch.no_grad():
model.eval()
dataloader.dataset.update_sampling_weights(model, accelerator, state)
dataset.update_sampling_weights(model, accelerator, state)
model.train()
if step_idx > max_steps:
break
Expand Down
7 changes: 5 additions & 2 deletions plugins/online-data-mixing/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,17 @@ license = {text = "Apache-2.0"}
readme = "README.md"
requires-python = "~=3.11"
keywords = ['fms-hf-tuning', 'acceleration', 'online-data-mixing']
classifiers=[
classifiers = [
"License :: OSI Approved :: Apache Software License",
"Development Status :: 4 - Beta",
"Programming Language :: Python :: 3",
"Programming Language :: Python :: 3.11",
]

dependencies = [
"torch==2.8.0",
"torchvision==0.23.0",
"torchaudio==2.8.0",
"scikit-learn",
"datasets==4.*",
"torchdata==0.11.0",
Expand All @@ -43,4 +46,4 @@ include = [
"src",
"pyproject.toml",
"README.md",
]
]
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@
# Third Party
from datasets import Dataset, DatasetDict
from sentence_transformers import SentenceTransformer
from sklearn.cluster import KMeans
import numpy as np
import torch

Expand Down Expand Up @@ -175,6 +174,14 @@ def _cluster_embeddings(
"Unsupported clustering algorithm '%s'. Only 'kmeans' is currently supported."
% self.config.cluster_algo
)

try:
from cuml import KMeans # pylint: disable=import-outside-toplevel
print("Using GPU accelerated Kmeans")
except ImportError:
print("GPU accelerated KMeans is not avaialble. Falling back to CPU based KMeans")
from sklearn.cluster import KMeans # pylint: disable=import-outside-toplevel

kwargs = {"n_init": 10}
kwargs.update(self.config.cluster_kwargs)
model = KMeans(n_clusters=num_categories, **kwargs)
Expand Down