**Motivation:**

Previously, we had two separate train engines:
- `TrainEngine` for regular models
- `VisionComposeTrainEngine` for vision-language models
This duplication led to:
- Code maintenance overhead (242 lines of duplicated logic)
- Tight coupling between engine and model-specific details
- Difficulty in extending to new model types
**Changes:**

- **Remove** `VisionComposeTrainEngine` entirely (242 lines deleted)
- **Add** `pre_micro_batch_forward()` and `post_micro_batch_forward()` hooks to `BaseModel`
- `pre_micro_batch_forward()`: Compute data batch statistics before forward pass
- `post_micro_batch_forward()`: Aggregate micro-batch results and compute metrics
- **Unify** `TrainEngine` to handle all model types through the new hook system
- **BaseModel**:
- Add `DataBatchInfo` and `BatchForwardInfo` TypedDicts for return types
- Implement default `pre_micro_batch_forward()` to compute token statistics
- Implement default `post_micro_batch_forward()` to aggregate losses and extra info
- Add overload type hints for `__call__` to improve type inference
- **MoE Model**:
- Override `post_micro_batch_forward()` to handle MoE-specific logic:
- Compute maxvio for router load balancing
- Update router bias based on expert load
- Add `need_update_bias` property for cleaner code
- Properly scale balancing_loss and z_loss by batch_size
- **ComposeModel**:
- Override `pre_micro_batch_forward()` to compute image token statistics
- Add `ComposeDataBatchInfo` with `step_consumed_img_tokens` field
- **TrainEngine**:
- Simplify `train_step()` to delegate statistics to model hooks
- Replace `LossLog` and `OtherLog` with unified `TrainStepInfo`
- Add `_get_total_loss()` to aggregate all losses (with TODO for future refactor)
- Remove all model-specific branching logic
- **EngineConfig**:
- Remove conditional logic for VisionComposeTrainEngine
- Use single TrainEngine.build() path
- Update to use `TrainStepInfo` instead of separate `LossLog` and `OtherLog`
- Simplify hook signatures (from 2 params to 1)
- Remove conditional engine instantiation logic
- Replace `VisionComposeTrainEngine` imports with `TrainEngine`
- Update test assertions to use new `TrainStepInfo` structure
- Remove TypeAdapter validation for deprecated types
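The hook API described above can be sketched as follows. This is a minimal, dependency-free sketch: the hook names, `DataBatchInfo`, `BatchForwardInfo`, `ComposeDataBatchInfo`, and `step_consumed_img_tokens` come from the PR, but the field contents, dict-based outputs, and plain-float losses are illustrative stand-ins for the real tensors and `ModelOutputs`:

```python
from typing import Any, TypedDict


class DataBatchInfo(TypedDict):
    """Statistics computed before the forward pass (fields illustrative)."""
    step_consumed_tokens: int


class ComposeDataBatchInfo(DataBatchInfo):
    """VL models additionally track consumed image tokens."""
    step_consumed_img_tokens: int


class BatchForwardInfo(TypedDict):
    """Aggregated per-micro-batch results (losses simplified to floats)."""
    total_loss: float
    extra_info: dict


class BaseModel:
    """Default hook implementations; the engine calls these around each micro-batch."""

    def pre_micro_batch_forward(self, data_batch: dict[str, Any]) -> DataBatchInfo:
        # Default: count non-padding tokens via the attention mask.
        return {"step_consumed_tokens": sum(data_batch["attention_mask"])}

    def post_micro_batch_forward(self, outputs: dict[str, float]) -> BatchForwardInfo:
        # Default: aggregate every *loss field; subclasses (e.g. the MoE model)
        # override this to add router metrics such as maxvio or bias updates.
        losses = {k: v for k, v in outputs.items() if k.endswith("loss")}
        return {"total_loss": sum(losses.values()), "extra_info": {}}


class ComposeModel(BaseModel):
    def pre_micro_batch_forward(self, data_batch: dict[str, Any]) -> ComposeDataBatchInfo:
        info = super().pre_micro_batch_forward(data_batch)
        # Extend the default statistics with image-token counts.
        return {**info, "step_consumed_img_tokens": sum(data_batch.get("image_mask", []))}
```

A subclass only overrides the hook it needs; the engine never inspects the model type.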
**Design notes:**

Currently, `TrainEngine._get_total_loss()` aggregates losses by iterating
through ModelOutputs fields. This is pragmatic but not ideal:
- **Pros**: Avoids large-scale changes to model forward() logic
- **Cons**: Engine knows about loss field names (coupling)
- **Future**: Model should return total_loss directly (see TODO comment)
`loss_ctx.batch_size` represents the full gradient accumulation batch size,
not intra_layer_micro_batch. This is correctly set in `CELossContext.build_batches()`
and used for scaling balancing_loss and z_loss.
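To make the scaling concrete, here is a hypothetical helper (in the PR this happens inside the MoE model's `post_micro_batch_forward()`, and the exact scaling formula is an assumption):

```python
def scale_aux_losses(balancing_loss: float, z_loss: float, batch_size: int) -> tuple[float, float]:
    # batch_size here is loss_ctx.batch_size: the full gradient-accumulation
    # batch size set in CELossContext.build_batches(), NOT the
    # intra_layer_micro_batch size. Dividing by it means each micro-batch
    # contributes 1/batch_size of the auxiliary losses to the step total.
    return balancing_loss / batch_size, z_loss / batch_size
```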
The pre/post hooks provide clean extension points:
- Subclasses can override to add model-specific logic
- Default implementations in BaseModel handle common cases
- No conditional logic needed in engine layer
**Benefits:**

1. **Code Reduction**: -242 lines (`VisionComposeTrainEngine` removed)
2. **Better Separation of Concerns**: Engine focuses on training orchestration, models handle their own statistics
3. **Extensibility**: New model types can override hooks without changing engine
4. **Type Safety**: Unified TrainStepInfo with clear field definitions
5. **Maintainability**: Single engine implementation to maintain
**Remaining items:**

- Loss reduction logic still needs clarification (minor issue; does not affect training)
- TODO added for future refactor: move total_loss aggregation to model layer
- All format issues (extra blank lines, class formatting) fixed
ghstack-source-id: ba87e63
Pull-Request: InternLM#1518