
[Refactor] Unify TrainEngine by moving model-specific logic to model layer #1518

Open
HAOCHENYE wants to merge 3 commits into gh/HAOCHENYE/18/base from gh/HAOCHENYE/18/head

Conversation


HAOCHENYE commented Mar 2, 2026

Stack from ghstack (oldest at bottom):

Previously, we had two separate train engines:

- `TrainEngine` for regular models
- `VisionComposeTrainEngine` for vision-language models

This duplication led to:

- Code maintenance overhead (242 lines of duplicated logic)
- Tight coupling between the engine and model-specific details
- Difficulty in extending to new model types

This PR makes the following changes:

- **Remove** `VisionComposeTrainEngine` entirely (242 lines deleted)
- **Add** `pre_micro_batch_forward()` and `post_micro_batch_forward()` hooks to `BaseModel`
  - `pre_micro_batch_forward()`: compute data batch statistics before the forward pass
  - `post_micro_batch_forward()`: aggregate micro-batch results and compute metrics
- **Unify** `TrainEngine` to handle all model types through the new hook system
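The division of labor above can be sketched as a toy loop. This is a hypothetical minimal sketch, not xtuner's actual code: the hook names `pre_micro_batch_forward` / `post_micro_batch_forward` come from this PR, while `Model`, `Engine`, and the field names are illustrative stand-ins.

```python
class Model:
    """Stand-in for BaseModel carrying the two new hooks."""

    def pre_micro_batch_forward(self, batch):
        # Default: collect data statistics before the forward pass.
        return {"consumed_tokens": len(batch["tokens"])}

    def forward(self, batch):
        # Toy "loss": the mean of the token values.
        return {"loss": sum(batch["tokens"]) / len(batch["tokens"])}

    def post_micro_batch_forward(self, outputs):
        # Default: aggregate per-micro-batch losses into step-level metrics.
        return {"total_loss": sum(o["loss"] for o in outputs)}


class Engine:
    """Engine only orchestrates; no isinstance() branching on model type."""

    def __init__(self, model):
        self.model = model

    def train_step(self, micro_batches):
        infos, outputs = [], []
        for mb in micro_batches:
            infos.append(self.model.pre_micro_batch_forward(mb))
            outputs.append(self.model.forward(mb))
        step_info = self.model.post_micro_batch_forward(outputs)
        step_info["consumed_tokens"] = sum(i["consumed_tokens"] for i in infos)
        return step_info


engine = Engine(Model())
info = engine.train_step([{"tokens": [1.0, 2.0]}, {"tokens": [3.0]}])
```

A vision-language model would only subclass `Model` and override the hooks; the engine loop stays untouched.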

- **BaseModel**:
  - Add `DataBatchInfo` and `BatchForwardInfo` TypedDicts for return types
  - Implement a default `pre_micro_batch_forward()` to compute token statistics
  - Implement a default `post_micro_batch_forward()` to aggregate losses and extra info
  - Add `@overload` type hints for `__call__` to improve type inference
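The TypedDict return types plus the default hooks might look roughly like the following. The class and hook names follow the PR; the fields (`consumed_tokens`, `total_loss`) and the padding-based token count are assumptions for illustration.

```python
from typing import TypedDict


class DataBatchInfo(TypedDict):
    """Return type of pre_micro_batch_forward(); field name is illustrative."""
    consumed_tokens: int


class BatchForwardInfo(TypedDict):
    """Return type of post_micro_batch_forward(); field name is illustrative."""
    total_loss: float


class BaseModel:
    def pre_micro_batch_forward(self, data_batch) -> DataBatchInfo:
        # Default: count non-padding tokens in the micro-batch.
        tokens = data_batch["input_ids"]
        pad_id = data_batch.get("pad_token_id", -1)
        return DataBatchInfo(consumed_tokens=sum(t != pad_id for t in tokens))

    def post_micro_batch_forward(self, losses) -> BatchForwardInfo:
        # Default: sum per-micro-batch losses into one scalar.
        return BatchForwardInfo(total_loss=float(sum(losses)))
```

Because both hooks return TypedDicts, subclasses can widen the return type (as `ComposeModel` does below) while callers keep full type inference.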
- **MoE model**:
  - Override `post_micro_batch_forward()` to handle MoE-specific logic:
    - Compute `maxvio` for router load balancing
    - Update the router bias based on expert load
  - Add a `need_update_bias` property for cleaner code
  - Properly scale `balancing_loss` and `z_loss` by `batch_size`
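A sketch of what the MoE override could compute. The `maxvio` formula ((max load − mean load) / mean load) and the sign-based bias nudge are common choices in auxiliary-loss-free load balancing, not necessarily xtuner's exact implementation; all numeric constants here are made up.

```python
class MoEModel:
    """Illustrative MoE subclass; only the balancing-related parts shown."""

    def __init__(self, n_experts, bias_update_speed=0.001, update_bias=True):
        self.router_bias = [0.0] * n_experts
        self.bias_update_speed = bias_update_speed
        self._update_bias = update_bias

    @property
    def need_update_bias(self):
        # Property instead of repeated config checks at each call site.
        return self._update_bias

    def post_micro_batch_forward(self, expert_load):
        # maxvio: how far the busiest expert deviates from the mean load.
        mean = sum(expert_load) / len(expert_load)
        maxvio = (max(expert_load) - mean) / mean
        if self.need_update_bias:
            # Nudge under-loaded experts up and over-loaded experts down.
            for i, load in enumerate(expert_load):
                sign = 1.0 if load < mean else -1.0
                self.router_bias[i] += sign * self.bias_update_speed
        return {"maxvio": maxvio}


m = MoEModel(n_experts=4)
info = m.post_micro_batch_forward([10, 10, 10, 30])
```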
- **ComposeModel**:
  - Override `pre_micro_batch_forward()` to compute image token statistics
  - Add `ComposeDataBatchInfo` with a `step_consumed_img_tokens` field
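Extending the return type for the vision case could look like this sketch. `ComposeDataBatchInfo` and `step_consumed_img_tokens` come from the PR; `DataBatchInfo`'s field, the `IMG_TOKEN_ID` placeholder, and the counting logic are illustrative assumptions.

```python
from typing import TypedDict


class DataBatchInfo(TypedDict):
    consumed_tokens: int


class ComposeDataBatchInfo(DataBatchInfo):
    # TypedDict inheritance: adds the image-token counter to the base fields.
    step_consumed_img_tokens: int


def pre_micro_batch_forward(data_batch) -> ComposeDataBatchInfo:
    # Count image-placeholder tokens alongside the regular token count;
    # IMG_TOKEN_ID is a made-up placeholder id for this sketch.
    IMG_TOKEN_ID = 151655
    ids = data_batch["input_ids"]
    return ComposeDataBatchInfo(
        consumed_tokens=len(ids),
        step_consumed_img_tokens=sum(i == IMG_TOKEN_ID for i in ids),
    )


info = pre_micro_batch_forward({"input_ids": [1, 151655, 151655, 2]})
```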
- **TrainEngine**:
  - Simplify `train_step()` to delegate statistics to model hooks
  - Replace `LossLog` and `OtherLog` with a unified `TrainStepInfo`
  - Add `_get_total_loss()` to aggregate all losses (with a TODO for a future refactor)
  - Remove all model-specific branching logic
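The field-iterating aggregation described for `_get_total_loss()` could be sketched as below. The coupling the PR flags is visible here: the engine recognizes loss fields by a naming convention. The `_loss` suffix and the `ModelOutputs` fields are illustrative guesses, not the actual xtuner definitions.

```python
from dataclasses import dataclass, fields


@dataclass
class ModelOutputs:
    llm_loss: float = 0.0
    balancing_loss: float = 0.0
    z_loss: float = 0.0
    logits_meta: str = ""  # non-loss field, skipped by the scan


def _get_total_loss(outputs: ModelOutputs) -> float:
    # TODO (per the PR): move this into the model so the engine no longer
    # needs to know which output fields are losses.
    return sum(
        getattr(outputs, f.name)
        for f in fields(outputs)
        if f.name.endswith("_loss")
    )


total = _get_total_loss(ModelOutputs(llm_loss=2.0, balancing_loss=0.5, z_loss=0.1))
```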
- **EngineConfig**:
  - Remove the conditional logic for `VisionComposeTrainEngine`
  - Use a single `TrainEngine.build()` path

Other updates:

- Update callers to use `TrainStepInfo` instead of separate `LossLog` and `OtherLog`
- Simplify hook signatures (from 2 params to 1)
- Remove conditional engine instantiation logic

Tests:

- Replace `VisionComposeTrainEngine` imports with `TrainEngine`
- Update test assertions to use the new `TrainStepInfo` structure
- Remove `TypeAdapter` validation for deprecated types

Currently, `TrainEngine._get_total_loss()` aggregates losses by iterating through `ModelOutputs` fields. This is pragmatic but not ideal:

- **Pros**: avoids large-scale changes to model `forward()` logic
- **Cons**: the engine knows about loss field names (coupling)
- **Future**: the model should return `total_loss` directly (see the TODO comment)

`loss_ctx.batch_size` represents the full gradient accumulation batch size, not `intra_layer_micro_batch`. This is correctly set in `CELossContext.build_batches()` and used for scaling `balancing_loss` and `z_loss`.
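Toy numbers showing why the auxiliary losses are scaled by the full accumulation batch size rather than the micro-batch size. All values here are invented for illustration; the point is only the arithmetic.

```python
# Each micro-batch contributes (its samples * loss) / batch_size, so after
# gradient accumulation the sum equals the per-sample average over the step.
batch_size = 8                    # full gradient-accumulation batch
micro_batches = [2, 2, 2, 2]      # samples per micro-batch
per_sample_balancing_loss = 0.5   # assume identical across samples

accumulated = sum(
    (n * per_sample_balancing_loss) / batch_size for n in micro_batches
)
```

Dividing by a micro-batch size instead would over-count the loss by the number of accumulation steps.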

The pre/post hooks provide clean extension points:

- Subclasses can override them to add model-specific logic
- Default implementations in `BaseModel` handle the common cases
- No conditional logic is needed in the engine layer

Benefits:

1. **Code reduction**: -242 lines (`VisionComposeTrainEngine` removed)
2. **Better separation of concerns**: the engine focuses on training orchestration; models handle their own statistics
3. **Extensibility**: new model types can override hooks without changing the engine
4. **Type safety**: a unified `TrainStepInfo` with clear field definitions
5. **Maintainability**: a single engine implementation to maintain

Remaining items:

- The loss-reduce logic still needs clarification (minor issue; doesn't affect training)
- A TODO was added for a future refactor: move `total_loss` aggregation to the model layer
- All formatting issues (extra blank lines, class formatting) are fixed

HAOCHENYE added a commit that referenced this pull request Mar 2, 2026
…layer

ghstack-source-id: 0b1735a
Pull-Request: #1518
HAOCHENYE added a commit to HAOCHENYE/xtuner that referenced this pull request Mar 2, 2026
…layer

ghstack-source-id: 0b1735a
Pull-Request: InternLM#1518
HAOCHENYE added a commit that referenced this pull request Mar 2, 2026
…layer

ghstack-source-id: 9ce752d
Pull-Request: #1518
HAOCHENYE added a commit to HAOCHENYE/xtuner that referenced this pull request Mar 3, 2026
…layer

ghstack-source-id: 9ce752d
Pull-Request: InternLM#1518
HAOCHENYE added a commit that referenced this pull request Mar 3, 2026
…layer

ghstack-source-id: ba87e63
Pull-Request: #1518
HAOCHENYE added a commit to HAOCHENYE/xtuner that referenced this pull request Mar 3, 2026
…layer

ghstack-source-id: ba87e63
Pull-Request: InternLM#1518