mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-06-26 15:32:22 +00:00
fix missing tags parameter
This commit is contained in:
parent
88e3b09c79
commit
78a06f5ce3
@ -120,16 +120,7 @@ class GRPOConsumer(BaseConsumer):
|
|||||||
"either max_tokens (vllm) or max_new_tokens (transformers) must be set in generate_config."
|
"either max_tokens (vllm) or max_new_tokens (transformers) must be set in generate_config."
|
||||||
)
|
)
|
||||||
# Initialize verifiable reward.
|
# Initialize verifiable reward.
|
||||||
response_format_tags = (
|
response_format_tags = grpo_config.get("response_format_tags", None)
|
||||||
{
|
|
||||||
"think_start": {"text": "<think>", "num_occur": 1},
|
|
||||||
"think_end": {"text": "</think>", "num_occur": 1},
|
|
||||||
"answer_start": {"text": "<answer>", "num_occur": 1},
|
|
||||||
"answer_end": {"text": "</answer>", "num_occur": 1},
|
|
||||||
}
|
|
||||||
if grpo_config.get("reward_fn_type") == "think_answer_tags"
|
|
||||||
else None
|
|
||||||
)
|
|
||||||
reward_model_kwargs = {
|
reward_model_kwargs = {
|
||||||
k: v
|
k: v
|
||||||
for k, v in grpo_config.items()
|
for k, v in grpo_config.items()
|
||||||
|
@ -100,6 +100,7 @@ def launch_distributed(
|
|||||||
eval_dataset_config=eval_dataset_config,
|
eval_dataset_config=eval_dataset_config,
|
||||||
eval_interval=eval_interval,
|
eval_interval=eval_interval,
|
||||||
evaluation_function_type=grpo_config["reward_fn_type"],
|
evaluation_function_type=grpo_config["reward_fn_type"],
|
||||||
|
response_format_tags=grpo_config["response_format_tags"],
|
||||||
eval_save_dir=eval_save_dir,
|
eval_save_dir=eval_save_dir,
|
||||||
eval_generation_config=eval_generation_config,
|
eval_generation_config=eval_generation_config,
|
||||||
project_name=project_name,
|
project_name=project_name,
|
||||||
|
@ -46,6 +46,7 @@ class BaseProducer:
|
|||||||
eval_dataset_config=None,
|
eval_dataset_config=None,
|
||||||
eval_interval=-1, # disable evaluation
|
eval_interval=-1, # disable evaluation
|
||||||
evaluation_function_type="think_answer_tags",
|
evaluation_function_type="think_answer_tags",
|
||||||
|
response_format_tags=None,
|
||||||
eval_save_dir: str = "./eval",
|
eval_save_dir: str = "./eval",
|
||||||
project_name: str = None,
|
project_name: str = None,
|
||||||
run_name: str = None,
|
run_name: str = None,
|
||||||
@ -148,6 +149,7 @@ class BaseProducer:
|
|||||||
self.evaluation_function = boxed_math_reward_fn
|
self.evaluation_function = boxed_math_reward_fn
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unknown evaluation function type {evaluation_function_type}")
|
raise ValueError(f"Unknown evaluation function type {evaluation_function_type}")
|
||||||
|
self.response_format_tags = response_format_tags
|
||||||
else:
|
else:
|
||||||
raise ValueError("eval_dataset_config is not defined")
|
raise ValueError("eval_dataset_config is not defined")
|
||||||
self.device = get_current_device()
|
self.device = get_current_device()
|
||||||
@ -217,6 +219,7 @@ class BaseProducer:
|
|||||||
eval_outputs["response_idx"][m][n],
|
eval_outputs["response_idx"][m][n],
|
||||||
tokenizer=self.tokenizer,
|
tokenizer=self.tokenizer,
|
||||||
eval_mode=True,
|
eval_mode=True,
|
||||||
|
tags=self.response_format_tags,
|
||||||
)
|
)
|
||||||
for m in range(eval_outputs["input_ids"].size(0))
|
for m in range(eval_outputs["input_ids"].size(0))
|
||||||
for n in range(eval_outputs["input_ids"].size(1))
|
for n in range(eval_outputs["input_ids"].size(1))
|
||||||
@ -324,6 +327,7 @@ class SimpleProducer(BaseProducer):
|
|||||||
eval_dataset_config=None,
|
eval_dataset_config=None,
|
||||||
eval_interval=-1, # disable evaluation
|
eval_interval=-1, # disable evaluation
|
||||||
evaluation_function_type="think_answer_tags",
|
evaluation_function_type="think_answer_tags",
|
||||||
|
response_format_tags=None,
|
||||||
eval_save_dir: str = "./eval",
|
eval_save_dir: str = "./eval",
|
||||||
eval_generation_config={},
|
eval_generation_config={},
|
||||||
project_name: str = None,
|
project_name: str = None,
|
||||||
@ -349,6 +353,7 @@ class SimpleProducer(BaseProducer):
|
|||||||
eval_dataset_config=eval_dataset_config,
|
eval_dataset_config=eval_dataset_config,
|
||||||
eval_interval=eval_interval,
|
eval_interval=eval_interval,
|
||||||
evaluation_function_type=evaluation_function_type,
|
evaluation_function_type=evaluation_function_type,
|
||||||
|
response_format_tags=response_format_tags,
|
||||||
eval_save_dir=eval_save_dir,
|
eval_save_dir=eval_save_dir,
|
||||||
project_name=project_name,
|
project_name=project_name,
|
||||||
run_name=run_name,
|
run_name=run_name,
|
||||||
|
@ -231,6 +231,16 @@ if __name__ == "__main__":
|
|||||||
"reward_fn_type": args.reward_type,
|
"reward_fn_type": args.reward_type,
|
||||||
"max_length": args.max_new_tokens + args.max_prompt_tokens,
|
"max_length": args.max_new_tokens + args.max_prompt_tokens,
|
||||||
"max_new_tokens": args.max_new_tokens,
|
"max_new_tokens": args.max_new_tokens,
|
||||||
|
"response_format_tags": (
|
||||||
|
{
|
||||||
|
"think_start": {"text": "<think>", "num_occur": 1},
|
||||||
|
"think_end": {"text": "</think>", "num_occur": 1},
|
||||||
|
"answer_start": {"text": "<answer>", "num_occur": 1},
|
||||||
|
"answer_end": {"text": "</answer>", "num_occur": 1},
|
||||||
|
}
|
||||||
|
if args.reward_type == "think_answer_tags"
|
||||||
|
else None
|
||||||
|
),
|
||||||
}
|
}
|
||||||
elif args.algo == "DAPO":
|
elif args.algo == "DAPO":
|
||||||
# DAPO variant settings
|
# DAPO variant settings
|
||||||
@ -250,6 +260,16 @@ if __name__ == "__main__":
|
|||||||
"cache_length": min(1024, int(args.max_new_tokens / 4)),
|
"cache_length": min(1024, int(args.max_new_tokens / 4)),
|
||||||
"filter_truncated_response": True,
|
"filter_truncated_response": True,
|
||||||
"reward_fn_type": args.reward_type,
|
"reward_fn_type": args.reward_type,
|
||||||
|
"response_format_tags": (
|
||||||
|
{
|
||||||
|
"think_start": {"text": "<think>", "num_occur": 1},
|
||||||
|
"think_end": {"text": "</think>", "num_occur": 1},
|
||||||
|
"answer_start": {"text": "<answer>", "num_occur": 1},
|
||||||
|
"answer_end": {"text": "</answer>", "num_occur": 1},
|
||||||
|
}
|
||||||
|
if args.reward_type == "think_answer_tags"
|
||||||
|
else None
|
||||||
|
),
|
||||||
}
|
}
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unsupported algorithm: {args.algo}")
|
raise ValueError(f"Unsupported algorithm: {args.algo}")
|
||||||
|
Loading…
Reference in New Issue
Block a user