Conversation
Code Review
This pull request introduces Direct Preference Optimization (DPO) and its variants (SimPO, ORPO, CPO) to the Twinkle framework, adding new loss functions, specialized data preprocessors, and a Ray-based training recipe. The Trajectory data format was updated to include user_data, and the template encoding logic was enhanced with parallel processing. Feedback identifies critical issues such as a type mismatch in the template encoding return value and a hardcoded parameter that breaks reference-free loss modes. Additionally, logical errors in conversation parsing and multiple inconsistencies between documentation and implementation regarding default values, configurable keys, and supported loss types were noted, along with opportunities to improve the robustness of message role parsing and preprocessor outputs.
```python
trajectory.update(input_feature)
return trajectory
```
The _encode_messages method now updates the input trajectory with input_feature fields and returns the modified trajectory. However, the _invoke_post_pipeline (called by batch_encode) expects a List[InputFeature], not a List[Trajectory]. This type mismatch will cause a runtime error when _invoke_post_pipeline attempts to process Trajectory objects as InputFeature objects, as the internal pipeline functions like _check_max_length are designed for InputFeature.
Suggested change:

```diff
 trajectory.update(input_feature)
-return trajectory
+return input_feature
```
`cookbook/rl/dpo.py` (Outdated)
```python
reference_free = LOSS_TYPE in ['simpo', 'orpo', 'cpo']

# Set up loss function
loss_fn = create_loss(LOSS_TYPE, DPO_BETA, reference_free=False)
```
The create_loss function is called with reference_free=False explicitly, even though reference_free is already determined by LOSS_TYPE on line 196. If LOSS_TYPE is one of the reference-free types (simpo, orpo, cpo), passing reference_free=False here will incorrectly force the loss function to expect a reference model, leading to a runtime error or incorrect behavior. The reference_free parameter should be passed dynamically based on the LOSS_TYPE evaluation.
Suggested change:

```diff
-loss_fn = create_loss(LOSS_TYPE, DPO_BETA, reference_free=False)
+loss_fn = create_loss(LOSS_TYPE, DPO_BETA, reference_free=reference_free)
```
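The effect of the flag can be seen in a minimal, framework-agnostic sketch of a DPO-style pairwise loss (plain Python; `dpo_sigmoid_loss` is a hypothetical name, not Twinkle's `create_loss`): when `reference_free` is `True`, the reference log-ratio drops out of the logits entirely, so no reference model is needed.

```python
import math

def dpo_sigmoid_loss(policy_chosen_logp, policy_rejected_logp,
                     ref_chosen_logp=0.0, ref_rejected_logp=0.0,
                     beta=0.1, reference_free=False):
    """DPO sigmoid loss for one preference pair (scalar log-probs).

    With reference_free=True (the SimPO/ORPO/CPO-style setting), the
    reference log-ratio is dropped, so the reference model's log-probs
    are never consulted.
    """
    pi_logratio = policy_chosen_logp - policy_rejected_logp
    ref_logratio = 0.0 if reference_free else ref_chosen_logp - ref_rejected_logp
    logits = beta * (pi_logratio - ref_logratio)
    # -log(sigmoid(logits)), written via log1p for numerical stability
    return math.log1p(math.exp(-logits))
```

Hardcoding `reference_free=False` in a SimPO/ORPO/CPO run would make the loss depend on reference log-probs that the recipe never computes, which is exactly the failure mode this comment describes.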
```python
messages.append(Message(role=msg.get('role'), content=msg.get('content', '')))
return messages
```
In _build_prompt_messages, when parsing a dictionary message, msg.get('role') is called without a default value. If the 'role' key is missing from the dictionary, this will result in None being assigned as the role, which could lead to unexpected behavior or type errors downstream. It's best to provide a default value or handle the missing key explicitly.
Suggested change:

```diff
-messages.append(Message(role=msg.get('role'), content=msg.get('content', '')))
+messages.append(Message(role=msg.get('role', 'user'), content=msg.get('content', '')))
```
```python
parts = text.split('\n\nHuman: ')
for i, part in enumerate(parts):
    if i == 0 and not part.startswith('Human: '):
        if part.strip():
            if part.startswith('Human: '):
                part = part[7:]
            messages.append(Message(role='user', content=part.strip()))
```
The logic for handling the first part of the conversation in _parse_hh_conversation contains a contradictory condition: if i == 0 and not part.startswith('Human: ') followed by if part.startswith('Human: '). This inner if statement will never be true if the outer if statement's condition not part.startswith('Human: ') is met. This indicates a logical error that might prevent correct parsing of certain conversation formats.
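One way to avoid the dead branch altogether is to split on the speaker tags themselves, so the first chunk needs no special-casing. A hedged sketch (hypothetical regex-based helper, assuming the standard Anthropic-HH `\n\nHuman:` / `\n\nAssistant:` transcript layout):

```python
import re

def parse_hh_conversation(text):
    """Parse an HH-style transcript into (role, content) pairs.

    Splits on the speaker tags directly; note that a literal 'Human:'
    inside a turn's text would still be treated as a tag.
    """
    messages = []
    turns = re.findall(
        r'(Human|Assistant):\s*(.*?)(?=\n\n(?:Human|Assistant):|\Z)',
        text, flags=re.DOTALL)
    for tag, content in turns:
        content = content.strip()
        if content:
            role = 'user' if tag == 'Human' else 'assistant'
            messages.append((role, content))
    return messages
```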
```text
BATCH_SIZE       – global batch size (preference pairs) (default: 8)
MICRO_BATCH_SIZE – per-device micro batch size (default: 2)
```
The default BATCH_SIZE and MICRO_BATCH_SIZE values specified in the docstring comments here (BATCH_SIZE default: 8, MICRO_BATCH_SIZE default: 2) are inconsistent with their actual default values set in the code (lines 76-77, BATCH_SIZE: 4, MICRO_BATCH_SIZE: 4). Please update the docstring to reflect the correct default values.
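Keeping docstring and code in sync is easier when the fallback literal appears in exactly one place. A sketch with a hypothetical `env_int` helper (the 4/4 values follow the code defaults this comment cites, not the docstring):

```python
import os

def env_int(name, default):
    """Read an integer knob from the environment, falling back to `default`.

    Hypothetical helper: whatever fallback is passed here is the number
    the docstring should quote as the default.
    """
    return int(os.environ.get(name, default))

BATCH_SIZE = env_int('BATCH_SIZE', 4)              # docstring should say: default: 4
MICRO_BATCH_SIZE = env_int('MICRO_BATCH_SIZE', 4)  # docstring should say: default: 4
```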
```text
Dataset field mapping (for custom datasets):
    PROMPT_KEY   – key for prompt field (default: 'prompt')
    CHOSEN_KEY   – key for chosen response (default: 'answer_zh')
    REJECTED_KEY – key for rejected response (default: 'answer_en')
```
The docstring mentions PROMPT_KEY, CHOSEN_KEY, and REJECTED_KEY as configurable environment variables for custom datasets. However, the EmojiDPOProcessor used in create_dpo_dataset (lines 94-97) does not read these values from environment variables or accept them as init_args. This creates a discrepancy between the documented configurability and the actual implementation, making these environment variables ineffective for EmojiDPOProcessor.
```text
beta: Temperature parameter controlling how much to deviate from ref policy (default: 0.1).
label_smoothing: Label smoothing parameter for soft labels (default: 0.0).
ignore_index: Index to ignore in labels (default: -100).
loss_type: Type of DPO loss variant ('sigmoid', 'hinge', 'ipo', 'kto_pair') (default: 'sigmoid').
```
The docstring for DPOLoss lists 'kto_pair' as a supported loss_type. While the implementation for 'kto_pair' exists (line 228), this loss type is not exposed or used in the cookbook/rl/dpo.py script or dpo.sh script. This creates an inconsistency between the documentation and the practical usage examples provided.
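For context, these variants differ only in how the pairwise logits are penalized. A scalar sketch following the conventions of common DPO implementations (not Twinkle's `DPOLoss`; `logratio_diff` is the policy-minus-reference log-ratio, and `kto_pair` is omitted here because it additionally needs a KL baseline):

```python
import math

def preference_loss(logratio_diff, loss_type='sigmoid', beta=0.1):
    """Per-pair loss for a few DPO variants (scalar sketch)."""
    if loss_type == 'sigmoid':                        # standard DPO
        return math.log1p(math.exp(-beta * logratio_diff))
    if loss_type == 'hinge':                          # SLiC-style hinge
        return max(0.0, 1.0 - beta * logratio_diff)
    if loss_type == 'ipo':                            # IPO: regress log-ratio to 1/(2*beta)
        return (logratio_diff - 1.0 / (2.0 * beta)) ** 2
    raise ValueError(f'unsupported loss_type: {loss_type!r}')
```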
```python
messages.append(Message(role=msg.get('role', 'assistant'), content=msg.get('content', '')))
return messages
```
In _parse_response, when a dictionary message is encountered, the role defaults to 'assistant' if not present. This might be too restrictive or incorrect if the message could originate from a user or another role. Consider making the default role more general or raising an error if the role is missing and cannot be inferred.
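The stricter option the comment suggests can be sketched with a hypothetical helper: accept an explicit default when the caller can supply one, and fail loudly otherwise instead of silently assuming `'assistant'`.

```python
def coerce_role(msg, default=None):
    """Return msg['role'], falling back to `default`; raise if neither exists."""
    role = msg.get('role', default)
    if role is None:
        raise ValueError(f"message has no 'role' and none can be inferred: {msg!r}")
    return role
```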
```python
results = [self.preprocess(row) for row in rows]
results = [r for r in results if r is not None]
if not results:
    return {}
```
If results is empty, the __call__ method returns an empty dictionary {}. This can cause issues for downstream consumers expecting a dictionary with positive and negative keys, even if those lists are empty. Returning {'positive': [], 'negative': []} would provide a more consistent and predictable output structure.
Suggested change:

```diff
 if not results:
-    return {}
+    return {'positive': [], 'negative': []}
```

```text
chosen_key: Key for chosen response (default: 'answer_zh').
rejected_key: Key for rejected response (default: 'answer_en').
prompt_key: Key for prompt (default: 'prompt').
```
The EmojiDPOProcessor hardcodes chosen_key, rejected_key, and prompt_key to specific values ('answer_zh', 'answer_en', 'prompt'). This contradicts the dpo.py script's docstring, which implies these keys can be configured via environment variables. To align with the documentation, these keys should either be configurable through init_args or the docstring should clarify that EmojiDPOProcessor uses fixed keys.
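One way to reconcile the two, sketched as a hypothetical processor (not the actual `EmojiDPOProcessor`): accept the keys as `init_args`, fall back to the documented environment variables, and only then to the documented defaults.

```python
import os

class PreferenceKeyedProcessor:
    """DPO row preprocessor with configurable field names (sketch)."""

    def __init__(self, prompt_key=None, chosen_key=None, rejected_key=None):
        # init_args win; otherwise the documented env vars; otherwise defaults
        self.prompt_key = prompt_key or os.environ.get('PROMPT_KEY', 'prompt')
        self.chosen_key = chosen_key or os.environ.get('CHOSEN_KEY', 'answer_zh')
        self.rejected_key = rejected_key or os.environ.get('REJECTED_KEY', 'answer_en')

    def preprocess(self, row):
        # Skip rows missing any required field rather than raising mid-run
        if not all(k in row for k in (self.prompt_key, self.chosen_key, self.rejected_key)):
            return None
        return {'prompt': row[self.prompt_key],
                'chosen': row[self.chosen_key],
                'rejected': row[self.rejected_key]}
```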