support code generation tasks

This commit is contained in:
YeAnbang
2025-06-05 17:56:42 +08:00
parent ceb7065d6d
commit dc3033e68a
12 changed files with 1027 additions and 127 deletions

View File

@@ -8,8 +8,9 @@ import ray.util.collective as cc
import torch
import tqdm
import wandb
from coati.dataset.loader import RawConversationDataset
from coati.distributed.reward.reward_fn import boxed_math_reward_fn, math_reward_fn
from coati.dataset.loader import RawConversationDataset, collate_fn_grpo
from coati.distributed.reward.reward_fn import boxed_math_reward_fn, code_reward_fn, math_reward_fn
from coati.distributed.reward.verifiable_reward import VerifiableReward
from ray.util.collective import allreduce
from ray.util.collective.types import Backend, ReduceOp
from torch.utils.data import DataLoader, DistributedSampler
@@ -36,7 +37,6 @@ class BaseProducer:
num_episodes: int,
batch_size: int,
train_dataset_config: Dict[str, Any],
dataloaders_config: Dict[str, Any],
model_config: Dict[str, Any],
generate_config: Dict[str, Any],
tokenizer_config: Optional[Dict[str, Any]] = None,
@@ -45,8 +45,7 @@ class BaseProducer:
consumer_plugin_config: Dict[str, Any] = None,
eval_dataset_config=None,
eval_interval=-1, # disable evaluation
evaluation_function_type="think_answer_tags",
response_format_tags=None,
grpo_config: Dict[str, Any] = None,
eval_save_dir: str = "./eval",
project_name: str = None,
run_name: str = None,
@@ -75,6 +74,13 @@ class BaseProducer:
self.eval_mode = False
self.log_rollout_interval = log_rollout_interval
self.latest_rollout_log_step = -1
self.grpo_config = grpo_config
reward_model_kwargs = {
k: v
for k, v in grpo_config.items()
if k in ["soft_over_length_punishment", "max_new_tokens", "cache_length"]
}
self.response_format_tags = grpo_config.get("response_format_tags", None)
if producer_idx == 0:
if os.path.exists(rollout_log_file):
raise ValueError(
@@ -120,6 +126,7 @@ class BaseProducer:
),
num_workers=4,
drop_last=True,
collate_fn=collate_fn_grpo,
)
self.eval_dataset_config = eval_dataset_config
@@ -142,17 +149,25 @@ class BaseProducer:
drop_last=False,
seed=42,
),
collate_fn=collate_fn_grpo,
)
if evaluation_function_type == "think_answer_tags":
if grpo_config["reward_fn_type"] == "think_answer_tags":
self.evaluation_function = math_reward_fn
elif evaluation_function_type == "boxed":
elif grpo_config["reward_fn_type"] == "boxed":
self.evaluation_function = boxed_math_reward_fn
elif grpo_config["reward_fn_type"] == "code":
self.evaluation_function = code_reward_fn
else:
raise ValueError(f"Unknown evaluation function type {evaluation_function_type}")
self.response_format_tags = response_format_tags
raise ValueError(f"Unknown evaluation function type {grpo_config['reward_fn_type']}")
else:
print("No eval dataset provided, skip eval")
self.device = get_current_device()
self.reward_model = VerifiableReward(
reward_fns=[self.evaluation_function], # multiple reward functions can be added here
tokenizer=self.tokenizer,
tags=self.response_format_tags,
**reward_model_kwargs,
)
# init backend
if backend in BACKEND_MAP:
@@ -215,7 +230,13 @@ class BaseProducer:
eval_results = eval_results + [
self.evaluation_function(
eval_outputs["input_ids"][m][n],
eval_outputs["gt_answer"][m][n],
eval_outputs[
(
"test_cases"
if self.grpo_config["reward_fn_type"] == "code"
else "gt_answer"
)
][m],
eval_outputs["response_idx"][m][n],
tokenizer=self.tokenizer,
eval_mode=True,
@@ -251,8 +272,50 @@ class BaseProducer:
outputs["temperature"] = torch.tensor(
[self.model.generate_config["temperature"]] * outputs["input_ids"].size(0)
).to(outputs["input_ids"].device)
bs, num_gen = outputs["input_ids"].size(0), outputs["input_ids"].size(1)
# breakpoint()
if self.grpo_config["reward_fn_type"] == "code":
test_cases = []
for prompt_id in range(bs):
test_cases.extend([outputs["test_cases"][prompt_id]] * num_gen)
# assert all([isinstance(test_cases[i], str) for i in range(len(test_cases))]), f"{[type(test_cases[i]) for i in range(len(test_cases))]}, {test_cases}"
reward_model_output = self.reward_model(
outputs["input_ids"].view((-1, outputs["input_ids"].size(-1))),
test_cases=test_cases,
response_idx=outputs["response_idx"].view((-1, 2)),
)
else:
gt_answer = []
for prompt_id in range(bs):
gt_answer.extend([outputs["gt_answer"][prompt_id]] * num_gen)
reward_model_output = self.reward_model(
outputs["input_ids"].view((-1, outputs["input_ids"].size(-1))),
gt_answer=gt_answer,
response_idx=outputs["response_idx"],
)
outputs["reward"] = (
torch.tensor([value[0] for value in reward_model_output])
.to(outputs["input_ids"].device)
.view((bs, num_gen, 1))
)
outputs["format_acc"] = (
torch.tensor([value[1] for value in reward_model_output])
.to(outputs["input_ids"].device)
.view((bs, num_gen, 1))
)
outputs["ans_acc"] = (
torch.tensor([value[2] for value in reward_model_output])
.to(outputs["input_ids"].device)
.view((bs, num_gen, 1))
)
if "gt_answer" in outputs:
outputs.pop("gt_answer")
if "test_cases" in outputs:
outputs.pop("test_cases")
print(f"[P{self.producer_idx}] Send data {[(k, v.shape) for k, v in outputs.items()]}")
outputs = pre_send(outputs)
# breakpoint()
ray_broadcast_tensor_dict(
outputs, src=0, device=self.device, group_name=f"sync_data_{self.producer_idx}"
)
@@ -315,7 +378,6 @@ class SimpleProducer(BaseProducer):
num_episodes,
batch_size,
train_dataset_config,
dataloaders_config,
model_config,
generate_config,
tokenizer_config=None,
@@ -325,8 +387,7 @@ class SimpleProducer(BaseProducer):
consumer_plugin_config=None,
eval_dataset_config=None,
eval_interval=-1, # disable evaluation
evaluation_function_type="think_answer_tags",
response_format_tags=None,
grpo_config: Dict[str, Any] = None,
eval_save_dir: str = "./eval",
eval_generation_config={},
project_name: str = None,
@@ -342,7 +403,6 @@ class SimpleProducer(BaseProducer):
num_episodes,
batch_size,
train_dataset_config,
dataloaders_config,
model_config,
generate_config,
tokenizer_config,
@@ -351,8 +411,7 @@ class SimpleProducer(BaseProducer):
consumer_plugin_config,
eval_dataset_config=eval_dataset_config,
eval_interval=eval_interval,
evaluation_function_type=evaluation_function_type,
response_format_tags=response_format_tags,
grpo_config=grpo_config,
eval_save_dir=eval_save_dir,
project_name=project_name,
run_name=run_name,