From 601feecf6d2d1c63a065bdc8a894e52a5111b3a6 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Tue, 19 May 2026 14:00:33 -0400 Subject: [PATCH 1/2] Add GSPO loss Group Sequence Policy Optimization: per-segment geometric-mean IS-ratio clipping. Mirrors GRPO's structure via shared abstract bases (LanguageModelPolicyGradientLossConfig / LanguageModelPolicyGradientLoss); the kernel matches GRPO except for a segment-aggregation block that produces per-segment R and A and broadcasts them back, so the softmax-chain backward is identical to GRPO. SDP-aware via optional all-reduce of segment sums; per-token weighting (mask / token_count_s) lets the SUM reduction at LossDef level give the canonical result without further correction. PyTorch kernel only; no Triton variant yet. Also lifts document_index_q/k from MixerKwargs to BlockKwargs so the LM head can read them without cross-namespace coupling, and renames grpo.py -> policy_gradient.py. Co-Authored-By: Claude Opus 4.7 (1M context) --- fast_llm/layers/attention/config.py | 2 - fast_llm/layers/block/config.py | 2 + fast_llm/layers/language_model/loss/config.py | 42 +++- .../loss/{grpo.py => policy_gradient.py} | 209 ++++++++++++++++-- tests/layers/test_lm_losses.py | 126 ++++++++++- 5 files changed, 357 insertions(+), 24 deletions(-) rename fast_llm/layers/language_model/loss/{grpo.py => policy_gradient.py} (60%) diff --git a/fast_llm/layers/attention/config.py b/fast_llm/layers/attention/config.py index 7282de090..c8840d0f9 100644 --- a/fast_llm/layers/attention/config.py +++ b/fast_llm/layers/attention/config.py @@ -21,8 +21,6 @@ class MixerKwargs(BlockKwargs): cu_seqlens_k = "cu_seqlens_k" max_seqlen_q = "max_seqlen_q" max_seqlen_k = "max_seqlen_k" - document_index_q = "document_index_q" - document_index_k = "document_index_k" position_ids = "position_ids" diff --git a/fast_llm/layers/block/config.py b/fast_llm/layers/block/config.py index 25c5fcc82..9093f730c 100644 --- a/fast_llm/layers/block/config.py +++ 
b/fast_llm/layers/block/config.py @@ -40,6 +40,8 @@ class BlockKwargs: # TODO: These are confusing sequence_length = "sequence_length" lengths = "lengths" + document_index_q = "document_index_q" + document_index_k = "document_index_k" # TODO: Belongs elsewhere? grad_output = "grad_output" activation_distillation_targets = "activation_distillation_targets" diff --git a/fast_llm/layers/language_model/loss/config.py b/fast_llm/layers/language_model/loss/config.py index 8e9594534..9cdb8c962 100644 --- a/fast_llm/layers/language_model/loss/config.py +++ b/fast_llm/layers/language_model/loss/config.py @@ -9,15 +9,17 @@ from fast_llm.utils import Assert if typing.TYPE_CHECKING: - pass - from fast_llm.layers.language_model.loss.dpo import LanguageModelDPOLoss from fast_llm.layers.language_model.loss.entropy_loss import ( LanguageModelDistillationLoss, LanguageModelLabelEntropyLoss, ) - from fast_llm.layers.language_model.loss.grpo import LanguageModelGRPOLoss from fast_llm.layers.language_model.loss.loss import LanguageModelLoss + from fast_llm.layers.language_model.loss.policy_gradient import ( + LanguageModelGRPOLoss, + LanguageModelGSPOLoss, + LanguageModelPolicyGradientLoss, + ) from fast_llm.layers.language_model.loss.z_loss import LanguageModelZLoss @@ -209,13 +211,26 @@ class GRPOMetricsLevel(enum.StrEnum): with_entropy = "with_entropy" -@config_class(dynamic_type={LanguageModelLossConfig: "grpo"}) -class LanguageModelGRPOLossConfig(LanguageModelLossConfig): +@config_class() +class LanguageModelPolicyGradientLossConfig(LanguageModelLossConfig): + """Shared base for policy-gradient losses (GRPO, GSPO).""" - _abstract: typing.ClassVar[bool] = False + _abstract: typing.ClassVar[bool] = True epsilon_low: float = Field(default=0.2, desc="Lower clip parameter for ratio of log probs") epsilon_high: float = Field(default=0.2, desc="Upper clip parameter for ratio of log probs") + + @property + def loss_class(self) -> "type[LanguageModelPolicyGradientLoss]": + raise 
NotImplementedError() + + +@config_class(dynamic_type={LanguageModelLossConfig: "grpo"}) +class LanguageModelGRPOLossConfig(LanguageModelPolicyGradientLossConfig): + """Group-Relative Policy Optimization: per-token IS-ratio clipping.""" + + _abstract: typing.ClassVar[bool] = False + use_triton: bool | None = Field( default=None, desc="Enable triton implementation. Default: use if available.", @@ -234,6 +249,19 @@ class LanguageModelGRPOLossConfig(LanguageModelLossConfig): @property def loss_class(self) -> "type[LanguageModelGRPOLoss]": - from fast_llm.layers.language_model.loss.grpo import LanguageModelGRPOLoss + from fast_llm.layers.language_model.loss.policy_gradient import LanguageModelGRPOLoss return LanguageModelGRPOLoss + + +@config_class(dynamic_type={LanguageModelLossConfig: "gspo"}) +class LanguageModelGSPOLossConfig(LanguageModelPolicyGradientLossConfig): + """Group Sequence Policy Optimization: sequence-level geometric-mean IS-ratio clipping.""" + + _abstract: typing.ClassVar[bool] = False + + @property + def loss_class(self) -> "type[LanguageModelGSPOLoss]": + from fast_llm.layers.language_model.loss.policy_gradient import LanguageModelGSPOLoss + + return LanguageModelGSPOLoss diff --git a/fast_llm/layers/language_model/loss/grpo.py b/fast_llm/layers/language_model/loss/policy_gradient.py similarity index 60% rename from fast_llm/layers/language_model/loss/grpo.py rename to fast_llm/layers/language_model/loss/policy_gradient.py index 4bbaeb581..0e190280a 100644 --- a/fast_llm/layers/language_model/loss/grpo.py +++ b/fast_llm/layers/language_model/loss/policy_gradient.py @@ -9,11 +9,14 @@ from fast_llm.functional.config import TritonConfig from fast_llm.functional.entropy_loss import fused_predicted_logits_from_labels, fused_softmax_base from fast_llm.functional.utils import reduce_losses +from fast_llm.layers.block.config import BlockKwargs from fast_llm.layers.language_model.config import LanguageModelKwargs from 
fast_llm.layers.language_model.loss.config import ( GRPOMetricsLevel, LanguageModelGRPOLossConfig, + LanguageModelGSPOLossConfig, LanguageModelLossKwargs, + LanguageModelPolicyGradientLossConfig, ) from fast_llm.layers.language_model.loss.loss import LanguageModelLoss from fast_llm.utils import Assert @@ -33,7 +36,42 @@ class GRPOMetrics(typing.NamedTuple): entropy: torch.Tensor | None -class LanguageModelGRPOLoss[ConfigType: LanguageModelGRPOLossConfig](LanguageModelLoss[ConfigType]): +class LanguageModelPolicyGradientLoss[ConfigType: LanguageModelPolicyGradientLossConfig]( + LanguageModelLoss[ConfigType] +): + """Shared scaffolding for policy-gradient losses (GRPO, GSPO). + + Subclasses implement `_forward_backward` to call their specific kernel. + """ + + def _register_new_logprobs( + self, + new_logprobs_mean: torch.Tensor | None, + kwargs: dict[str, typing.Any], + losses: dict | None, + ) -> None: + if new_logprobs_mean is not None: + new_logprobs_mean = new_logprobs_mean / kwargs[LanguageModelKwargs.num_documents_in_batch] + self._register_loss( + self._logprob_metric_name, new_logprobs_mean, losses, reduce_op=torch.distributed.ReduceOp.SUM + ) + + def get_loss_definitions(self) -> list[LossDef]: + defs = super().get_loss_definitions() + defs.append(LossDef(self._logprob_metric_name)) + return defs + + def get_preprocessing_config(self) -> dict[str, typing.Any]: + return {"use_grpo_data": True, "return_label_counts": True, "return_document_count": True} + + @functools.cached_property + def _logprob_metric_name(self) -> str: + return f"{self._name}_new_logprobs" + + +class LanguageModelGRPOLoss[ConfigType: LanguageModelGRPOLossConfig](LanguageModelPolicyGradientLoss[ConfigType]): + """GRPO: per-token IS-ratio clipping.""" + def __init__( self, config: ConfigType, @@ -99,11 +137,7 @@ def _forward_backward( divisor=self._get_label_count(kwargs), ) - if new_logprobs_mean is not None: - new_logprobs_mean = new_logprobs_mean / 
kwargs[LanguageModelKwargs.num_documents_in_batch] - self._register_loss( - self._logprob_metric_name, new_logprobs_mean, losses, reduce_op=torch.distributed.ReduceOp.SUM - ) + self._register_new_logprobs(new_logprobs_mean, kwargs, losses) # Skip the extra softmax pass when there is nothing to register. if losses is not None and self._config.metrics != GRPOMetricsLevel.none: @@ -167,7 +201,6 @@ def _register_extra_metrics( def get_loss_definitions(self) -> list[LossDef]: defs = super().get_loss_definitions() - defs.append(LossDef(self._logprob_metric_name)) if self._config.metrics != GRPOMetricsLevel.none: defs.extend( [ @@ -187,14 +220,59 @@ def get_loss_definitions(self) -> list[LossDef]: defs.append(LossDef(f"{self._name}_entropy")) return defs - def get_preprocessing_config( + +class LanguageModelGSPOLoss[ConfigType: LanguageModelGSPOLossConfig](LanguageModelPolicyGradientLoss[ConfigType]): + """GSPO: sequence-level geometric-mean IS-ratio clipping.""" + + def _forward_backward( self, - ) -> dict[str, typing.Any]: - return {"use_grpo_data": True, "return_label_counts": True, "return_document_count": True} + logits: "torch.Tensor", + kwargs: dict[str, typing.Any], + losses: dict | None = None, + split_index: int = 0, + grad_logits: torch.Tensor | None = None, + ) -> tuple[torch.Tensor, torch.Tensor | None]: + document_index = self._prepare_target( + kwargs[BlockKwargs.document_index_q], split_index, split_by_distance=False + ) + # `document_index_q` is 1-based per the data preprocessor convention; shift to 0-based. + document_index = document_index.long() - 1 + # Buffer size must agree across SDP ranks (sharing the sequence) for the per-segment + # all-reduce inside the kernel; MAX-reduce the local max here. 
+ local_max = int(document_index.max().item()) if document_index.numel() > 0 else -1 + if self._sdp_active: + max_buffer = document_index.new_tensor(local_max) + torch.distributed.all_reduce(max_buffer, op=torch.distributed.ReduceOp.MAX, group=self._sdp_dim.group) + local_max = int(max_buffer.item()) + num_segments = local_max + 1 + + loss, grad, new_logprobs_mean = fused_gspo_loss_forward_backward( + logits, + self._get_labels(kwargs, split_index), + self._prepare_target(kwargs[LanguageModelLossKwargs.advantages], split_index), + self._prepare_target(kwargs[LanguageModelLossKwargs.old_log_probabilities], split_index), + document_index, + num_segments, + grad_logits=grad_logits, + grad_output=self._get_grad_output(kwargs), + group=self._parallel_dim.group if self._vocab_parallel else None, + sdp_group=self._sdp_dim.group if self._sdp_active else None, + epsilon_low=self._config.epsilon_low, + epsilon_high=self._config.epsilon_high, + logits_scale_factor=self._logits_scale_factor, + num_labels_in_seq=( + None + if losses is None + else self._prepare_target(kwargs[LanguageModelLossKwargs.label_counts], split_index) + ), + divisor=kwargs[LanguageModelKwargs.num_documents_in_batch], + ) - @functools.cached_property - def _logprob_metric_name(self) -> str: - return f"{self._name}_new_logprobs" + self._register_new_logprobs(new_logprobs_mean, kwargs, losses) + return loss, grad + + def get_preprocessing_config(self) -> dict[str, typing.Any]: + return super().get_preprocessing_config() | {"return_document_index": True} @torch.compile @@ -326,3 +404,108 @@ def fused_grpo_loss_forward_backward( grad_logits.add_(grad) return loss, grad_logits, new_logprobs_mean + + +def fused_gspo_loss_forward_backward( + logits: torch.Tensor, # (*batch, vocab) + target: torch.Tensor, # (*batch,) + advantages: torch.Tensor, # (*batch,) + old_log_probabilities: torch.Tensor, # (*batch,) + document_index: torch.Tensor, # (*batch,) int — 0-based segment ID per token + num_segments: int, # 
buffer size, ≥ document_index.max() + 1 + grad_logits: torch.Tensor | None = None, + grad_output: float | None = None, + group: torch.distributed.ProcessGroup | None = None, # TP vocab group + sdp_group: torch.distributed.ProcessGroup | None = None, # SDP group for cross-rank segment aggregation + epsilon_low: float = 0.2, + epsilon_high: float = 0.2, + logits_scale_factor: float = 1.0, + num_labels_in_seq: torch.Tensor | None = None, + divisor: float | None = None, # default: num_segments +) -> tuple[torch.Tensor, torch.Tensor | None, torch.Tensor | None]: + """GSPO loss: sequence-level geometric-mean IS-ratio clipping. + + Per-segment ratio R_s = exp(mean_t(log(p_new_t / p_old_t))), clipped per segment. + Per-segment loss = -min(R_s * A_s, clip(R_s) * A_s), summed over segments and divided by `divisor`. + + Computed as an equivalent per-token sum: each token contributes `(1 / token_count_s) * -min(...)`, + making the gradient chain (softmax → log_prob → log_ratio → R_s) identical in structure to GRPO, + only with the per-token IS ratio replaced by the segment-broadcast R_s, the per-token advantage + by the segment-mean A_s, and the loss mask scaled by 1 / token_count_s. + + With `sdp_group`, segment sums are all-reduced so each rank computes the same global R_s/A_s. + Token-level contributions remain per-rank, so summing the kernel loss across SDP via SUM reduction + matches the canonical single-rank result without further correction. 
+ """ + if divisor is None: + divisor = float(num_segments) if num_segments > 0 else 1.0 + grad_output_scaled = None if grad_output is None else grad_output / divisor * logits_scale_factor + loss_mask = target >= 0 + + # === Setup phase (identical to GRPO) === + logits_norm, exp_logits, sum_exp_logits, _ = fused_softmax_base(logits, logits_scale_factor, group) + predicted_logits, target_masked, target_mask = fused_predicted_logits_from_labels( + logits_norm, target, loss_mask, group + ) + new_log_probs = predicted_logits - sum_exp_logits.log() + log_ratio = new_log_probs - old_log_probabilities + + # === Segment aggregation === + flat_doc = document_index.reshape(-1).long() + flat_mask = loss_mask.reshape(-1).to(log_ratio.dtype) + log_ratio_sum = log_ratio.new_zeros(num_segments).index_add_( + 0, flat_doc, (log_ratio.reshape(-1) * flat_mask).to(log_ratio.dtype) + ) + advantage_sum = advantages.new_zeros(num_segments).index_add_( + 0, flat_doc, (advantages.reshape(-1) * flat_mask).to(advantages.dtype) + ) + token_count = log_ratio.new_zeros(num_segments).index_add_(0, flat_doc, flat_mask) + if sdp_group is not None: + torch.distributed.all_reduce(log_ratio_sum, op=torch.distributed.ReduceOp.SUM, group=sdp_group) + torch.distributed.all_reduce(advantage_sum, op=torch.distributed.ReduceOp.SUM, group=sdp_group) + torch.distributed.all_reduce(token_count, op=torch.distributed.ReduceOp.SUM, group=sdp_group) + + safe_count = token_count.clamp(min=1) + segment_ratio = (log_ratio_sum / safe_count).exp() # (num_segments,) — geometric-mean IS ratio + segment_advantage = (advantage_sum / safe_count).detach() # (num_segments,) — no grad through A + inv_token_count = torch.where(token_count > 0, 1.0 / safe_count, token_count.new_zeros(())) + + # Broadcast back to per-token + probability_ratio = segment_ratio[flat_doc].reshape(log_ratio.shape) + seg_advantage = segment_advantage[flat_doc].reshape(log_ratio.shape) + # Per-token weight 1/token_count_s so each segment contributes 
once to the sum. + token_weight = flat_mask.reshape(log_ratio.shape) * inv_token_count[flat_doc].reshape(log_ratio.shape) + + # === Per-token loss (mirrors GRPO; mask folded into token_weight) === + losses = -torch.min( + probability_ratio * seg_advantage, + torch.clamp(probability_ratio, 1 - epsilon_low, 1 + epsilon_high) * seg_advantage, + ) + loss = (losses * token_weight).sum() / divisor + + new_logprobs_mean = ( + None if num_labels_in_seq is None else (new_log_probs * loss_mask / num_labels_in_seq.clamp(min=1)).sum() + ) + + if grad_output_scaled is not None: + probability_ratio_grad = ( + grad_output_scaled + * ( + torch.clamp_min(seg_advantage, 0) * (probability_ratio <= 1 + epsilon_high) + + torch.clamp_max(seg_advantage, 0) * (probability_ratio >= 1 - epsilon_low) + ) + * token_weight + ) + predicted_probabilities = exp_logits / sum_exp_logits.unsqueeze_(-1) + grad = (probability_ratio_grad * probability_ratio).unsqueeze(-1) * predicted_probabilities.scatter_add( + -1, + target_masked.unsqueeze(-1), + -(loss_mask if target_mask is None else target_mask).unsqueeze(-1).to(torch.float32), + ) + grad = grad.to(logits.dtype) + if grad_logits is None: + grad_logits = grad + else: + grad_logits.add_(grad) + + return loss, grad_logits, new_logprobs_mean diff --git a/tests/layers/test_lm_losses.py b/tests/layers/test_lm_losses.py index 19200476a..8f1cf9424 100644 --- a/tests/layers/test_lm_losses.py +++ b/tests/layers/test_lm_losses.py @@ -16,12 +16,13 @@ from fast_llm.functional.triton.grpo_loss import triton_grpo_loss_forward_backward from fast_llm.functional.triton.z_loss import triton_z_loss_forward_backward from fast_llm.layers.language_model.loss.dpo import dpo_loss -from fast_llm.layers.language_model.loss.grpo import ( +from fast_llm.layers.language_model.loss.loss import loss_forward_backward +from fast_llm.layers.language_model.loss.policy_gradient import ( GRPOMetrics, compute_grpo_metrics, fused_grpo_loss_forward_backward, + 
fused_gspo_loss_forward_backward, ) -from fast_llm.layers.language_model.loss.loss import loss_forward_backward from fast_llm.layers.language_model.loss.z_loss import fused_z_loss_forward_backward, z_loss from fast_llm.utils import Assert from tests.utils.dataset import get_random_spans @@ -167,6 +168,49 @@ def reference_grpo_metrics( ) +def reference_gspo_loss( + logits: torch.Tensor, + labels: torch.Tensor, + advantages: torch.Tensor, + old_log_probabilities: torch.Tensor, + document_index: torch.Tensor, + num_segments: int, + epsilon_low: float = 0.2, + epsilon_high: float = 0.2, + logits_scale_factor: float = 1.0, +) -> tuple[torch.Tensor, torch.Tensor]: + logits_ = logits.float() + loss_mask = labels >= 0 + labels_safe = labels * loss_mask + target_log_probabilities = ( + torch.nn.functional.log_softmax(logits_ * logits_scale_factor, dim=-1) + .gather(dim=-1, index=labels_safe.unsqueeze(-1)) + .squeeze(-1) + ) + log_ratio = target_log_probabilities - old_log_probabilities + + flat_doc = document_index.reshape(-1) + flat_mask = loss_mask.reshape(-1) + flat_log_ratio = log_ratio.reshape(-1) + flat_advantages = advantages.reshape(-1) + + total = log_ratio.new_zeros(()) + for segment in range(num_segments): + in_segment = (flat_doc == segment) & flat_mask + count = in_segment.sum() + if int(count) == 0: + continue + ratio = (flat_log_ratio[in_segment].sum() / count.float()).exp() + advantage = (flat_advantages[in_segment].sum() / count.float()).detach() + clipped_ratio = ratio.clamp(1 - epsilon_low, 1 + epsilon_high) + total = total + -torch.minimum(ratio * advantage, clipped_ratio * advantage) + total = total / max(num_segments, 1) + + log_probs = torch.nn.functional.log_softmax(logits_, -1).gather(-1, labels_safe.unsqueeze(-1)).squeeze(-1) + new_logprobs = (log_probs * loss_mask).sum() / max(float(loss_mask.sum()), 1.0) + return total, new_logprobs + + def reference_grpo_loss( logits: torch.Tensor, labels: torch.Tensor, @@ -350,6 +394,55 @@ def _test_grpo_loss( 
Assert.rms_close_relative(new_logprobs_triton, new_logprobs_fused, 1e-5, 1e-6) +def _test_gspo_loss( + batch_shape, + num_columns, + grad_output, + logits_scale_factor, + loss_masking, + dtype, + num_segments, + accumulate, + group=None, +): + logits, target, advantages, old_log_probabilities = _get_grpo_loss_inputs( + num_columns, loss_masking, batch_shape, dtype + ) + # Build per-token segment IDs by partitioning each batch row into `num_segments` contiguous spans. + seq_len = batch_shape[-1] if len(batch_shape) > 1 else batch_shape[0] + span = max(seq_len // num_segments, 1) + base = torch.arange(seq_len, device=target.device) // span + document_index = base.clamp(max=num_segments - 1).expand(batch_shape).contiguous() + out_ref, grad_ref = loss_forward_backward( + grad_output, + lambda *args, **kwargs: reference_gspo_loss(*args, **kwargs)[0], + logits, + target, + advantages, + old_log_probabilities, + document_index, + num_segments, + logits_scale_factor=logits_scale_factor, + ) + if accumulate: + previous_grad = torch.randn_like(grad_ref) + grad_ref = grad_ref + previous_grad + local_previous_grad = split_op(previous_grad, group, -1).contiguous() + out_fused, grad_fused, new_logprobs_fused = fused_gspo_loss_forward_backward( + split_op(logits, group, -1), + target, + advantages, + old_log_probabilities, + document_index, + num_segments, + grad_logits=local_previous_grad.clone() if accumulate else None, + grad_output=grad_output, + group=group, + logits_scale_factor=logits_scale_factor, + ) + _compare_losses_and_grads(out_fused, out_ref, grad_output is not None, grad_fused, grad_ref, group=group) + + def _check_grpo_metrics(ref: GRPOMetrics, got: GRPOMetrics, threshold: float) -> None: for name in GRPOMetrics._fields: ref_value = getattr(ref, name) @@ -514,6 +607,35 @@ def test_grpo_loss( ) +_GSPO_PARAMETERS = ( + # (num_columns, grad_output, logits_scale_factor, loss_masking, dtype, num_segments, accumulate) + (500, 1.0, 1.0, False, DataType.float32, 4, 
False), # Simple + (256, 1.0, 1.0, False, DataType.float32, 4, False), # Power of 2 + (500, None, 1.0, False, DataType.float32, 4, False), # No grad + (500, 1.0, 1.0, False, DataType.float32, 4, True), # Accumulate + (500, 1.0, 4.0, False, DataType.float32, 4, False), # Loss scaling + (500, 4.0, 1.0, False, DataType.float32, 4, False), # Grad scaling + (500, 1.0, 1.0, True, DataType.float32, 4, False), # Loss masking + (500, 1.0, 1.0, False, DataType.float16, 4, False), # Fp16 + (500, 1.0, 1.0, False, DataType.float32, 1, False), # One segment + (500, 1.0, 1.0, True, DataType.float32, 16, True), # Many segments + masking + accumulate +) + + +@pytest.mark.slow +@pytest.mark.parametrize("batch_shape", _BATCH_SHAPES) +@pytest.mark.parametrize( + ("num_columns", "grad_output", "logits_scale_factor", "loss_masking", "dtype", "num_segments", "accumulate"), + _GSPO_PARAMETERS, +) +def test_gspo_loss( + batch_shape, num_columns, grad_output, logits_scale_factor, loss_masking, dtype, num_segments, accumulate +): + _test_gspo_loss( + batch_shape, num_columns, grad_output, logits_scale_factor, loss_masking, dtype, num_segments, accumulate + ) + + @pytest.mark.slow @pytest.mark.parametrize("batch_shape", _BATCH_SHAPES) @pytest.mark.parametrize( From 82d223deae2d284cd66f33efa34ea761f04a8c13 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Tue, 19 May 2026 17:10:30 -0400 Subject: [PATCH 2/2] Add Triton GSPO kernel MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Triton backward kernel mirrors GRPO's backward — same softmax chain rule through (softmax_k - delta_{k,target}), with the per-token IS ratio replaced by the segment-broadcast R_{s(t)} and the loss mask scaled by 1/token_count_s (token_weight). 
The forward pass reuses the existing triton_cross_entropy_forward_from_labels_parallel_kernel to produce max/sum/predicted_logit per token (with TP support via the same parallel_sum_exp_logits dance as GRPO); segment aggregation, loss, and the SDP all-reduce live in PyTorch between the two Triton passes. Triton use is controlled by a new LanguageModelGSPOLossConfig.use_triton field (mirrors GRPO config; by default Triton is used if available). Test coverage: `test_gspo_loss` now also runs the Triton path when available — 20 cases pass under TRITON_INTERPRET=1. Co-Authored-By: Claude Opus 4.7 (1M context) --- fast_llm/functional/triton/gspo_loss.py | 218 ++++++++++++++++++ fast_llm/layers/language_model/loss/config.py | 6 + .../language_model/loss/policy_gradient.py | 8 +- tests/layers/test_lm_losses.py | 17 ++ 4 files changed, 248 insertions(+), 1 deletion(-) create mode 100644 fast_llm/functional/triton/gspo_loss.py diff --git a/fast_llm/functional/triton/gspo_loss.py b/fast_llm/functional/triton/gspo_loss.py new file mode 100644 index 000000000..a010447bb --- /dev/null +++ b/fast_llm/functional/triton/gspo_loss.py @@ -0,0 +1,218 @@ +import torch + +from fast_llm.core.distributed import ReduceOp, all_reduce +from fast_llm.functional.triton import tl, tl_arange, tl_constexpr, triton, triton_jit +from fast_llm.functional.triton.entropy_loss import ( + parallel_sum_exp_logits, + triton_cross_entropy_forward_from_labels_parallel_kernel, +) + + +@triton_jit() +def triton_gspo_loss_backward_kernel( + logits_ptr, + labels_ptr, + max_logits_ptr, + sum_exp_logits_ptr, + probability_ratio_ptr, + seg_advantage_ptr, + token_weight_ptr, + grad_logits_ptr, + n_cols: tl_constexpr, + logits_stride_0: tl_constexpr, + grad_logits_stride_0: tl_constexpr, + block_size: tl_constexpr, + grad_losses, + col_min: tl_constexpr = 0, + logits_scale_factor: tl_constexpr = 1.0, + epsilon_low: tl_constexpr = 0.2, + epsilon_high: tl_constexpr = 0.2, + accumulate: tl_constexpr = False, +): + block_idx = tl.program_id(0).to(tl.int64) + + # 
token_weight = mask_t / token_count_{s(t)}; zero for masked tokens and empty segments. + token_weight = tl.load(token_weight_ptr + block_idx).to(tl.float32) + if token_weight == 0.0: + if not accumulate: + for col_offset in tl.static_range(0, n_cols, block_size): + col_offsets = tl_arange(int(col_offset), int(col_offset + block_size)) + tl.store( + grad_logits_ptr + block_idx * grad_logits_stride_0 + col_offsets, 0, mask=col_offsets < n_cols + ) + return + + label_idx = tl.load(labels_ptr + block_idx) - col_min + max_logits = tl.load(max_logits_ptr + block_idx) + sum_exp_logits = tl.load(sum_exp_logits_ptr + block_idx) + probability_ratio = tl.load(probability_ratio_ptr + block_idx).to(tl.float32) + seg_advantage = tl.load(seg_advantage_ptr + block_idx).to(tl.float32) + + # effective_grad = grad_losses * scale * weight * R_s * clip_factor + # clip_factor = clamp_min(A_s, 0) * (R_s <= 1+eps_h) + clamp_max(A_s, 0) * (R_s >= 1-eps_l) + grad_scale = grad_losses + if logits_scale_factor != 1.0: + grad_scale *= logits_scale_factor + effective_grad = ( + ( + tl.maximum(seg_advantage, 0.0) * (probability_ratio <= 1.0 + epsilon_high) + + tl.minimum(seg_advantage, 0.0) * (probability_ratio >= 1.0 - epsilon_low) + ) + * probability_ratio + * grad_scale + * token_weight + ) + + logits_ptr = logits_ptr + block_idx * logits_stride_0 + + # grad_logit_i = effective_grad * (softmax_i - delta_{i, label}) + col_offset_start: tl.constexpr = (n_cols - 1) // block_size * block_size + for col_offset in tl.static_range(col_offset_start, -1, -block_size): + col_offsets = tl_arange(col_offset, col_offset + block_size) + mask = col_offsets < n_cols + logits = tl.load(logits_ptr + col_offsets, mask=mask, other=-float("inf")).to(tl.float32) + if logits_scale_factor != 1.0: + logits *= logits_scale_factor + prob = tl.exp(logits - max_logits) / sum_exp_logits + if label_idx < 0 or label_idx >= n_cols: + # Target not in this TP shard. 
+ grad = effective_grad * prob + else: + grad = effective_grad * tl.where(col_offsets == label_idx, prob - 1.0, prob) + grad_col_ptr = grad_logits_ptr + block_idx * grad_logits_stride_0 + col_offsets + if accumulate: + grad += tl.load(grad_col_ptr, mask=mask) + tl.store(grad_col_ptr, grad, mask=mask) + + +def triton_gspo_loss_forward_backward( + logits: torch.Tensor, # (*batch, vocab) + target: torch.Tensor, # (*batch,) + advantages: torch.Tensor, # (*batch,) + old_log_probabilities: torch.Tensor, # (*batch,) + document_index: torch.Tensor, # (*batch,) int — 0-based segment ID per token + num_segments: int, # buffer size, >= document_index.max() + 1 + grad_logits: torch.Tensor | None = None, + grad_output: float | None = None, + group: torch.distributed.ProcessGroup | None = None, + sdp_group: torch.distributed.ProcessGroup | None = None, + epsilon_low: float = 0.2, + epsilon_high: float = 0.2, + logits_scale_factor: float = 1.0, + num_labels_in_seq: torch.Tensor | None = None, + divisor: float | None = None, + block_size: int | None = None, + num_warps: int | None = None, +) -> tuple[torch.Tensor, torch.Tensor | None, torch.Tensor | None]: + """Triton GSPO loss. Forward fuses softmax + predicted-logit lookup; backward fuses the + softmax chain rule with the per-token GSPO gradient factor (R_s * clip * token_weight). + Segment aggregation, loss, and the SDP all-reduce live in PyTorch between the two passes. + + See `fused_gspo_loss_forward_backward` in policy_gradient.py for the math derivation; + this kernel produces identical outputs. 
+ """ + assert logits.is_contiguous() + assert target.is_contiguous() + assert advantages.is_contiguous() + assert old_log_probabilities.is_contiguous() + assert document_index.is_contiguous() + + n_rows = logits.shape[:-1].numel() + n_cols = logits.size(-1) + if divisor is None: + divisor = float(num_segments) if num_segments > 0 else 1.0 + if block_size is None: + block_size = min(triton.next_power_of_2(n_cols), 32768) + if num_warps is None: + num_warps = 4 if block_size < 2048 else (8 if block_size < 8192 else 16) + col_min = n_cols * group.rank() if group is not None else 0 + + # === Forward (Triton): per-token softmax, save max / sum / predicted_logit === + max_logits = torch.empty(n_rows, dtype=torch.float, device=logits.device) + sum_exp_logits = torch.empty_like(max_logits) + predicted_logits = torch.empty_like(max_logits) + triton_cross_entropy_forward_from_labels_parallel_kernel[(n_rows,)]( + logits, + target, + max_logits_ptr=max_logits, + sum_exp_logits_ptr=sum_exp_logits, + predicted_logits_ptr=predicted_logits, + col_min=col_min, + n_cols=n_cols, + logits_stride_0=logits.stride(-2), + block_size=block_size, + num_warps=num_warps, + logits_scale_factor=logits_scale_factor, + ) + if group is not None: + # Merge per-shard local max / sum_exp into global values. 
+ max_logits, sum_exp_logits = parallel_sum_exp_logits(sum_exp_logits, max_logits, group) + all_reduce(predicted_logits, op=ReduceOp.SUM, group=group) + + # === Segment aggregation (PyTorch) === + flat_target = target.reshape(-1) + flat_doc = document_index.reshape(-1).long() + flat_advantages = advantages.reshape(-1).float() + loss_mask = (flat_target >= 0).to(max_logits.dtype) + + new_log_probs = predicted_logits - max_logits - sum_exp_logits.log() + log_ratio = (new_log_probs - old_log_probabilities.reshape(-1).float()) * loss_mask + + new_logprobs_mean = ( + (new_log_probs * loss_mask / num_labels_in_seq.reshape(-1).clamp(min=1).to(new_log_probs.dtype)).sum() + if num_labels_in_seq is not None + else None + ) + + log_ratio_sum = log_ratio.new_zeros(num_segments).index_add_(0, flat_doc, log_ratio) + advantage_sum = log_ratio.new_zeros(num_segments).index_add_(0, flat_doc, flat_advantages * loss_mask) + token_count = log_ratio.new_zeros(num_segments).index_add_(0, flat_doc, loss_mask) + if sdp_group is not None: + torch.distributed.all_reduce(log_ratio_sum, op=torch.distributed.ReduceOp.SUM, group=sdp_group) + torch.distributed.all_reduce(advantage_sum, op=torch.distributed.ReduceOp.SUM, group=sdp_group) + torch.distributed.all_reduce(token_count, op=torch.distributed.ReduceOp.SUM, group=sdp_group) + + safe_count = token_count.clamp(min=1) + segment_ratio = (log_ratio_sum / safe_count).exp() + segment_advantage = advantage_sum / safe_count + inv_token_count = torch.where(token_count > 0, 1.0 / safe_count, token_count.new_zeros(())) + + probability_ratio = segment_ratio[flat_doc].contiguous() + seg_advantage = segment_advantage[flat_doc].contiguous() + token_weight = (loss_mask * inv_token_count[flat_doc]).contiguous() + + losses = -torch.min( + probability_ratio * seg_advantage, + torch.clamp(probability_ratio, 1 - epsilon_low, 1 + epsilon_high) * seg_advantage, + ) + loss = (losses * token_weight).sum() / divisor + + if grad_output is None: + return loss, 
grad_logits, new_logprobs_mean + + # === Backward (Triton) === + accumulate = grad_logits is not None + grad_logits = torch.empty_like(logits) if grad_logits is None else grad_logits + triton_gspo_loss_backward_kernel[(n_rows,)]( + logits, + target, + max_logits_ptr=max_logits, + sum_exp_logits_ptr=sum_exp_logits, + probability_ratio_ptr=probability_ratio, + seg_advantage_ptr=seg_advantage, + token_weight_ptr=token_weight, + grad_logits_ptr=grad_logits, + n_cols=n_cols, + logits_stride_0=logits.stride(-2), + grad_logits_stride_0=grad_logits.stride(-2), + block_size=block_size, + grad_losses=grad_output / divisor, + col_min=col_min, + logits_scale_factor=logits_scale_factor, + epsilon_low=epsilon_low, + epsilon_high=epsilon_high, + accumulate=accumulate, + num_warps=num_warps, + ) + + return loss, grad_logits, new_logprobs_mean diff --git a/fast_llm/layers/language_model/loss/config.py b/fast_llm/layers/language_model/loss/config.py index 9cdb8c962..9a220aacf 100644 --- a/fast_llm/layers/language_model/loss/config.py +++ b/fast_llm/layers/language_model/loss/config.py @@ -260,6 +260,12 @@ class LanguageModelGSPOLossConfig(LanguageModelPolicyGradientLossConfig): _abstract: typing.ClassVar[bool] = False + use_triton: bool | None = Field( + default=None, + desc="Enable triton implementation. 
Default: use if available.", + hint=FieldHint.expert, + ) + @property def loss_class(self) -> "type[LanguageModelGSPOLoss]": from fast_llm.layers.language_model.loss.policy_gradient import LanguageModelGSPOLoss diff --git a/fast_llm/layers/language_model/loss/policy_gradient.py b/fast_llm/layers/language_model/loss/policy_gradient.py index 0e190280a..69dae518e 100644 --- a/fast_llm/layers/language_model/loss/policy_gradient.py +++ b/fast_llm/layers/language_model/loss/policy_gradient.py @@ -246,7 +246,13 @@ def _forward_backward( local_max = int(max_buffer.item()) num_segments = local_max + 1 - loss, grad, new_logprobs_mean = fused_gspo_loss_forward_backward( + if TritonConfig.enabled(logits.device, self._config.use_triton): + from fast_llm.functional.triton.gspo_loss import triton_gspo_loss_forward_backward + + fn = triton_gspo_loss_forward_backward + else: + fn = fused_gspo_loss_forward_backward + loss, grad, new_logprobs_mean = fn( logits, self._get_labels(kwargs, split_index), self._prepare_target(kwargs[LanguageModelLossKwargs.advantages], split_index), diff --git a/tests/layers/test_lm_losses.py b/tests/layers/test_lm_losses.py index 8f1cf9424..b79bc51fb 100644 --- a/tests/layers/test_lm_losses.py +++ b/tests/layers/test_lm_losses.py @@ -14,6 +14,7 @@ from fast_llm.functional.triton import triton_available from fast_llm.functional.triton.entropy_loss import triton_entropy_loss_forward_backward from fast_llm.functional.triton.grpo_loss import triton_grpo_loss_forward_backward +from fast_llm.functional.triton.gspo_loss import triton_gspo_loss_forward_backward from fast_llm.functional.triton.z_loss import triton_z_loss_forward_backward from fast_llm.layers.language_model.loss.dpo import dpo_loss from fast_llm.layers.language_model.loss.loss import loss_forward_backward @@ -442,6 +443,22 @@ def _test_gspo_loss( ) _compare_losses_and_grads(out_fused, out_ref, grad_output is not None, grad_fused, grad_ref, group=group) + if not triton_available: + return + 
out_triton, grad_triton, new_logprobs_triton = triton_gspo_loss_forward_backward( + split_op(logits, group, -1).contiguous(), + target, + advantages, + old_log_probabilities, + document_index, + num_segments, + grad_logits=local_previous_grad.clone() if accumulate else None, + grad_output=grad_output, + group=group, + logits_scale_factor=logits_scale_factor, + ) + _compare_losses_and_grads(out_triton, out_ref, grad_output is not None, grad_triton, grad_ref, group=group) + def _check_grpo_metrics(ref: GRPOMetrics, got: GRPOMetrics, threshold: float) -> None: for name in GRPOMetrics._fields: