Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-09 04:50:17 +00:00)

support code generation tasks
@@ -8,8 +8,9 @@ import ray.util.collective as cc
 import torch
 import tqdm
 import wandb
-from coati.dataset.loader import RawConversationDataset
-from coati.distributed.reward.reward_fn import boxed_math_reward_fn, math_reward_fn
+from coati.dataset.loader import RawConversationDataset, collate_fn_grpo
+from coati.distributed.reward.reward_fn import boxed_math_reward_fn, code_reward_fn, math_reward_fn
 from coati.distributed.reward.verifiable_reward import VerifiableReward
 from ray.util.collective import allreduce
 from ray.util.collective.types import Backend, ReduceOp
 from torch.utils.data import DataLoader, DistributedSampler
@@ -36,7 +37,6 @@ class BaseProducer:
         num_episodes: int,
         batch_size: int,
         train_dataset_config: Dict[str, Any],
         dataloaders_config: Dict[str, Any],
         model_config: Dict[str, Any],
         generate_config: Dict[str, Any],
         tokenizer_config: Optional[Dict[str, Any]] = None,
@@ -45,8 +45,7 @@ class BaseProducer:
         consumer_plugin_config: Dict[str, Any] = None,
         eval_dataset_config=None,
         eval_interval=-1,  # disable evaluation
-        evaluation_function_type="think_answer_tags",
-        response_format_tags=None,
+        grpo_config: Dict[str, Any] = None,
         eval_save_dir: str = "./eval",
         project_name: str = None,
         run_name: str = None,
@@ -75,6 +74,13 @@ class BaseProducer:
         self.eval_mode = False
         self.log_rollout_interval = log_rollout_interval
         self.latest_rollout_log_step = -1
+        self.grpo_config = grpo_config
+        reward_model_kwargs = {
+            k: v
+            for k, v in grpo_config.items()
+            if k in ["soft_over_length_punishment", "max_new_tokens", "cache_length"]
+        }
+        self.response_format_tags = grpo_config.get("response_format_tags", None)
         if producer_idx == 0:
             if os.path.exists(rollout_log_file):
                 raise ValueError(
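As a side note, the dict comprehension added above forwards only a whitelisted subset of grpo_config keys to the reward functions. A minimal standalone sketch of that pattern, with invented config values:

    # Hypothetical values; only the whitelisted keys survive the filtering.
    grpo_config = {"soft_over_length_punishment": True, "max_new_tokens": 2048, "lr": 1e-6}
    reward_model_kwargs = {
        k: v
        for k, v in grpo_config.items()
        if k in ["soft_over_length_punishment", "max_new_tokens", "cache_length"]
    }
    assert reward_model_kwargs == {"soft_over_length_punishment": True, "max_new_tokens": 2048}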
@@ -120,6 +126,7 @@ class BaseProducer:
             ),
             num_workers=4,
             drop_last=True,
+            collate_fn=collate_fn_grpo,
         )

         self.eval_dataset_config = eval_dataset_config
@@ -142,17 +149,25 @@ class BaseProducer:
                        drop_last=False,
                        seed=42,
                    ),
+                   collate_fn=collate_fn_grpo,
                )
-            if evaluation_function_type == "think_answer_tags":
+            if grpo_config["reward_fn_type"] == "think_answer_tags":
                 self.evaluation_function = math_reward_fn
-            elif evaluation_function_type == "boxed":
+            elif grpo_config["reward_fn_type"] == "boxed":
                 self.evaluation_function = boxed_math_reward_fn
+            elif grpo_config["reward_fn_type"] == "code":
+                self.evaluation_function = code_reward_fn
             else:
-                raise ValueError(f"Unknown evaluation function type {evaluation_function_type}")
-            self.response_format_tags = response_format_tags
+                raise ValueError(f"Unknown evaluation function type {grpo_config['reward_fn_type']}")
         else:
             print("No eval dataset provided, skip eval")
         self.device = get_current_device()
+        self.reward_model = VerifiableReward(
+            reward_fns=[self.evaluation_function],  # multiple reward functions can be added here
+            tokenizer=self.tokenizer,
+            tags=self.response_format_tags,
+            **reward_model_kwargs,
+        )

         # init backend
         if backend in BACKEND_MAP:
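The if/elif chain above maps grpo_config["reward_fn_type"] to one of three reward functions (math_reward_fn, boxed_math_reward_fn, code_reward_fn). A table-driven sketch of the same selection; the lookup dict and helper are illustrative, not part of the commit:

    from coati.distributed.reward.reward_fn import boxed_math_reward_fn, code_reward_fn, math_reward_fn

    # Illustrative dispatch table; names mirror the branches in the diff above.
    REWARD_FN_BY_TYPE = {
        "think_answer_tags": math_reward_fn,
        "boxed": boxed_math_reward_fn,
        "code": code_reward_fn,
    }

    def resolve_evaluation_function(grpo_config):
        reward_fn_type = grpo_config["reward_fn_type"]
        if reward_fn_type not in REWARD_FN_BY_TYPE:
            raise ValueError(f"Unknown evaluation function type {reward_fn_type}")
        return REWARD_FN_BY_TYPE[reward_fn_type]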
@@ -215,7 +230,13 @@ class BaseProducer:
                         eval_results = eval_results + [
                             self.evaluation_function(
                                 eval_outputs["input_ids"][m][n],
-                                eval_outputs["gt_answer"][m][n],
+                                eval_outputs[
+                                    (
+                                        "test_cases"
+                                        if self.grpo_config["reward_fn_type"] == "code"
+                                        else "gt_answer"
+                                    )
+                                ][m],
                                 eval_outputs["response_idx"][m][n],
                                 tokenizer=self.tokenizer,
                                 eval_mode=True,
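In the evaluation branch above, the label passed to the evaluation function now depends on the task type: code tasks read eval_outputs["test_cases"], everything else keeps eval_outputs["gt_answer"]. A small hedged sketch of that key selection (the sample dict and its values are invented):

    # Illustrative only; key names follow the diff, the sample data is made up.
    def select_label_key(reward_fn_type: str) -> str:
        return "test_cases" if reward_fn_type == "code" else "gt_answer"

    eval_outputs = {"gt_answer": ["42"], "test_cases": ["assert solution(1, 2) == 3"]}
    m = 0
    label = eval_outputs[select_label_key("code")][m]  # -> "assert solution(1, 2) == 3"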
@@ -251,8 +272,50 @@ class BaseProducer:
                 outputs["temperature"] = torch.tensor(
                     [self.model.generate_config["temperature"]] * outputs["input_ids"].size(0)
                 ).to(outputs["input_ids"].device)
+                bs, num_gen = outputs["input_ids"].size(0), outputs["input_ids"].size(1)
+                # breakpoint()
+                if self.grpo_config["reward_fn_type"] == "code":
+                    test_cases = []
+                    for prompt_id in range(bs):
+                        test_cases.extend([outputs["test_cases"][prompt_id]] * num_gen)
+                    # assert all([isinstance(test_cases[i], str) for i in range(len(test_cases))]), f"{[type(test_cases[i]) for i in range(len(test_cases))]}, {test_cases}"
+                    reward_model_output = self.reward_model(
+                        outputs["input_ids"].view((-1, outputs["input_ids"].size(-1))),
+                        test_cases=test_cases,
+                        response_idx=outputs["response_idx"].view((-1, 2)),
+                    )
+                else:
+                    gt_answer = []
+                    for prompt_id in range(bs):
+                        gt_answer.extend([outputs["gt_answer"][prompt_id]] * num_gen)
+                    reward_model_output = self.reward_model(
+                        outputs["input_ids"].view((-1, outputs["input_ids"].size(-1))),
+                        gt_answer=gt_answer,
+                        response_idx=outputs["response_idx"],
+                    )
+                outputs["reward"] = (
+                    torch.tensor([value[0] for value in reward_model_output])
+                    .to(outputs["input_ids"].device)
+                    .view((bs, num_gen, 1))
+                )
+                outputs["format_acc"] = (
+                    torch.tensor([value[1] for value in reward_model_output])
+                    .to(outputs["input_ids"].device)
+                    .view((bs, num_gen, 1))
+                )
+                outputs["ans_acc"] = (
+                    torch.tensor([value[2] for value in reward_model_output])
+                    .to(outputs["input_ids"].device)
+                    .view((bs, num_gen, 1))
+                )
+                if "gt_answer" in outputs:
+                    outputs.pop("gt_answer")
+                if "test_cases" in outputs:
+                    outputs.pop("test_cases")
+
                 print(f"[P{self.producer_idx}] Send data {[(k, v.shape) for k, v in outputs.items()]}")
                 outputs = pre_send(outputs)
+                # breakpoint()
                 ray_broadcast_tensor_dict(
                     outputs, src=0, device=self.device, group_name=f"sync_data_{self.producer_idx}"
                 )
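Two pieces of bookkeeping happen above: each prompt's test cases (or ground-truth answer) are repeated once per generated response so they line up with the flattened input_ids, and the flat per-sample reward triples are reshaped back to (bs, num_gen, 1). A self-contained sketch of that bookkeeping; the shapes mirror the diff, the reward values are dummy data:

    import torch

    bs, num_gen = 2, 4  # prompts per batch, generations per prompt
    per_prompt_test_cases = ["case_a", "case_b"]

    # Repeat each prompt's test cases once per generation.
    test_cases = []
    for prompt_id in range(bs):
        test_cases.extend([per_prompt_test_cases[prompt_id]] * num_gen)
    assert len(test_cases) == bs * num_gen

    # Pretend the reward model returned one (reward, format_acc, ans_acc) triple per sample.
    reward_model_output = [(1.0, 1.0, 0.0)] * (bs * num_gen)
    reward = torch.tensor([v[0] for v in reward_model_output]).view((bs, num_gen, 1))
    format_acc = torch.tensor([v[1] for v in reward_model_output]).view((bs, num_gen, 1))
    ans_acc = torch.tensor([v[2] for v in reward_model_output]).view((bs, num_gen, 1))
    print(reward.shape)  # torch.Size([2, 4, 1])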
@@ -315,7 +378,6 @@ class SimpleProducer(BaseProducer):
         num_episodes,
         batch_size,
         train_dataset_config,
         dataloaders_config,
         model_config,
         generate_config,
         tokenizer_config=None,
@@ -325,8 +387,7 @@ class SimpleProducer(BaseProducer):
         consumer_plugin_config=None,
         eval_dataset_config=None,
         eval_interval=-1,  # disable evaluation
-        evaluation_function_type="think_answer_tags",
-        response_format_tags=None,
+        grpo_config: Dict[str, Any] = None,
         eval_save_dir: str = "./eval",
         eval_generation_config={},
         project_name: str = None,
@@ -342,7 +403,6 @@ class SimpleProducer(BaseProducer):
             num_episodes,
             batch_size,
             train_dataset_config,
             dataloaders_config,
             model_config,
             generate_config,
             tokenizer_config,
@@ -351,8 +411,7 @@ class SimpleProducer(BaseProducer):
             consumer_plugin_config,
             eval_dataset_config=eval_dataset_config,
             eval_interval=eval_interval,
-            evaluation_function_type=evaluation_function_type,
-            response_format_tags=response_format_tags,
+            grpo_config=grpo_config,
             eval_save_dir=eval_save_dir,
             project_name=project_name,
             run_name=run_name,
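Taken together, a grpo_config along the following lines would exercise the new code-generation path; the key names appear in this diff, but the concrete values are only an assumed example:

    # Assumed example configuration; only the key names are taken from the diff.
    grpo_config = {
        "reward_fn_type": "code",             # selects code_reward_fn for rollout and eval
        "soft_over_length_punishment": True,  # forwarded via reward_model_kwargs
        "max_new_tokens": 2048,               # forwarded via reward_model_kwargs
        "cache_length": 512,                  # forwarded via reward_model_kwargs
        "response_format_tags": None,         # read with grpo_config.get("response_format_tags", None)
    }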