fix style

This commit is contained in:
YeAnbang
2024-07-26 09:55:15 +00:00
parent 9688e19b32
commit 8a3ff4f315
7 changed files with 10 additions and 15 deletions

View File

@@ -46,8 +46,7 @@ def supervised_tokenize_sft(
max_length: the maximum context length
"""
if ignore_index is None:
ignore_index = IGNORE_INDEX
ignore_index = IGNORE_INDEX
messages = data_point["messages"]
template = deepcopy(conversation_template)
@@ -146,8 +145,6 @@ def tokenize_prompt_dataset(
ignore_index: the ignore index when calculating loss during training
max_length: the maximum context length
"""
if ignore_index is None:
ignore_index = IGNORE_INDEX
messages = data_point["messages"]
template = deepcopy(conversation_template)
@@ -226,8 +223,6 @@ def tokenize_rlhf(
{"context": [{"from": "user", "content": "xxx"}, {"from": "assistant", "content": "xxx"}],
"chosen": {"from": "assistant", "content": "xxx"}, "rejected": {"from": "assistant", "content": "xxx"}}
"""
if ignore_index is None:
ignore_index = IGNORE_INDEX
context = data_point["context"]
template = deepcopy(conversation_template)

View File

@@ -26,7 +26,7 @@ from .utils import is_rank_0, to_device
class DPOTrainer(SLTrainer):
"""
Trainer for PPO algorithm.
Trainer for DPO algorithm.
Args:
actor (Actor): the actor model in ppo algorithm

View File

@@ -27,7 +27,7 @@ from .utils import is_rank_0, to_device
class KTOTrainer(SLTrainer):
"""
Trainer for PPO algorithm.
Trainer for KTO algorithm.
Args:
actor (Actor): the actor model in ppo algorithm

View File

@@ -26,7 +26,7 @@ from .utils import is_rank_0, to_device
class ORPOTrainer(SLTrainer):
"""
Trainer for PPO algorithm.
Trainer for ORPO algorithm.
Args:
actor (Actor): the actor model in ppo algorithm