ColossalAI/applications/ColossalChat/coati/distributed/producer.py

import copy
import json
import os
from typing import Any, Dict, Optional

import ray
import ray.util.collective as cc
import torch
import tqdm
import wandb
from coati.dataset.loader import RawConversationDataset, collate_fn_grpo
from coati.distributed.profiling_utils import CustomProfiler
from coati.distributed.reward.reward_fn import boxed_math_reward_fn, code_reward_fn, math_reward_fn
from coati.distributed.reward.verifiable_reward import VerifiableReward
from ray.util.collective import allreduce
from ray.util.collective.types import Backend, ReduceOp
from torch.utils.data import DataLoader, DistributedSampler
from transformers import AutoTokenizer

from colossalai.utils import get_current_device

from .comm import ray_broadcast_tensor_dict
from .inference_backend import BACKEND_MAP
from .utils import pre_send, safe_append_to_jsonl_file

try:
    from vllm import SamplingParams
except ImportError:
    # vLLM is optional; define the name so the import failure only surfaces
    # when the vllm backend is actually used.
    SamplingParams = None


class BaseProducer:
    def __init__(
        self,
        producer_idx: int,
        num_producers: int,
        num_consumer_procs: int,
        num_episodes: int,
        batch_size: int,
        train_dataset_config: Dict[str, Any],
        model_config: Dict[str, Any],
        generate_config: Dict[str, Any],
        tokenizer_config: Optional[Dict[str, Any]] = None,
        microbatch_size: int = 1,
        backend: str = "transformers",
        consumer_plugin_config: Dict[str, Any] = None,
        eval_dataset_config=None,
        eval_interval=-1,  # disable evaluation
        grpo_config: Dict[str, Any] = None,
        eval_save_dir: str = "./eval",
        project_name: str = None,
        run_name: str = None,
        wandb_group_name: str = None,
        log_rollout_interval: int = 20,
        rollout_log_file: str = "./rollout_log.jsonl",
        enable_profiling: bool = False,
        n_behind: int = 0,
    ):
        self.producer_idx = producer_idx
        self.num_producers = num_producers
        self.num_consumer_procs = num_consumer_procs
        self.num_episodes = num_episodes
        self.batch_size = batch_size
        self.microbatch_size = microbatch_size
        assert batch_size % microbatch_size == 0
        self.num_microbatches = batch_size // microbatch_size
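        # Worked example (illustrative values, not defaults): with batch_size=32
        # and microbatch_size=8, num_microbatches == 4, so the producer emits 4
        # microbatches per consumer update and the weight-sync gate in loop()
        # fires on the same multiple.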
        self.latest_eval_step = -1
        self.profiler = CustomProfiler(f"P{self.producer_idx}", disabled=not enable_profiling)
        self.train_dataset_config = train_dataset_config
        self.model_config = model_config
        self.generate_config = generate_config
        self.tokenizer_config = tokenizer_config
        self.consumer_plugin_config = consumer_plugin_config
        self.eval_interval = eval_interval
        self.eval_save_dir = eval_save_dir
        self.consumer_global_step = 0
        self.eval_mode = False
        self.log_rollout_interval = log_rollout_interval
        self.latest_rollout_log_step = -1
        self.grpo_config = grpo_config
        self.n_behind = n_behind
        reward_model_kwargs = {
            k: v
            for k, v in grpo_config.items()
            if k in ["soft_over_length_punishment", "max_new_tokens", "cache_length"]
        }
        self.response_format_tags = grpo_config.get("response_format_tags", None)
        if producer_idx == 0:
            if os.path.exists(rollout_log_file):
                raise ValueError(
                    f"Rollout log file {rollout_log_file} already exists. Please delete it or change the name."
                )
            else:
                os.makedirs(os.path.dirname(rollout_log_file), exist_ok=True)
                self.rollout_log_file = open(rollout_log_file, "w", encoding="utf8")
        if self.producer_idx == 0:
            self.wandb_run = wandb.init(
                project=project_name,
                sync_tensorboard=False,
                dir="./wandb",
                name=run_name + "_eval",
                group=wandb_group_name,
            )
        if os.path.exists(self.eval_save_dir) and self.eval_interval > 0:
            raise ValueError(f"Eval save dir {self.eval_save_dir} already exists. Please delete it or change the name.")

        # init tokenizer
        if tokenizer_config is None:
            tokenizer_path = model_config["path"]
            self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
        else:
            tokenizer_path = tokenizer_config.pop("path")
            self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, **tokenizer_config)
        self.tokenizer.padding_side = "left"

        # init dataloader
        train_dataset_path = train_dataset_config.pop("path")
        self.train_dataset = RawConversationDataset(self.tokenizer, train_dataset_path, **train_dataset_config)
        self.train_dataloader = DataLoader(
            self.train_dataset,
            batch_size=microbatch_size,
            sampler=DistributedSampler(
                self.train_dataset,
                num_replicas=num_producers,
                rank=producer_idx,
                shuffle=True,
                drop_last=True,
                seed=42,
            ),
            num_workers=4,
            drop_last=True,
            collate_fn=collate_fn_grpo,
        )
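        # The DistributedSampler shards the dataset across producers
        # (num_replicas=num_producers, rank=producer_idx), so each producer
        # iterates a disjoint split; drop_last on both the sampler and the
        # loader keeps every microbatch full.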
if grpo_config["reward_fn_type"] == "think_answer_tags":
self.evaluation_function = math_reward_fn
elif grpo_config["reward_fn_type"] == "boxed":
self.evaluation_function = boxed_math_reward_fn
elif grpo_config["reward_fn_type"] == "code":
self.evaluation_function = code_reward_fn
else:
raise ValueError(f"Unknown evaluation function type {grpo_config['reward_fn_type']}")
        self.eval_dataset_config = eval_dataset_config
        if self.eval_dataset_config is not None:
            self.eval_dataloaders = {}
            for eval_task_name in self.eval_dataset_config:
                eval_dataset_path = eval_dataset_config[eval_task_name].pop("path")
                eval_dataset = RawConversationDataset(
                    self.tokenizer, eval_dataset_path, **eval_dataset_config[eval_task_name]
                )
                print(f"[P{self.producer_idx}] eval dataset {eval_task_name} size: {len(eval_dataset)}")
                self.eval_dataloaders[eval_task_name] = DataLoader(
                    eval_dataset,
                    batch_size=microbatch_size,
                    sampler=DistributedSampler(
                        eval_dataset,
                        num_replicas=num_producers,
                        rank=producer_idx,
                        shuffle=False,
                        drop_last=False,
                        seed=42,
                    ),
                    collate_fn=collate_fn_grpo,
                )
        else:
            print("No eval dataset provided, skipping evaluation")
        self.device = get_current_device()
        self.reward_model = VerifiableReward(
            reward_fns=[self.evaluation_function],  # multiple reward functions can be added here
            tokenizer=self.tokenizer,
            tags=self.response_format_tags,
            **reward_model_kwargs,
        )
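        # Each reward_model call returns one (reward, format_acc, ans_acc)
        # triple per response; loop() unpacks these into (bs, num_gen, 1)
        # tensors before broadcasting them to the consumers.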

        # init backend
        if backend in BACKEND_MAP:
            self.backend_cls = BACKEND_MAP[backend]
        else:
            raise ValueError(f"Unexpected backend {backend}")
        self.consumer_pp_size = consumer_plugin_config.get("pp_size", 1)  # consumer pp size

    def setup(self) -> None:
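        # Collective-group topology:
        #   "producer_group": all producers; used to allreduce eval statistics.
        #   f"sync_data_{producer_idx}": this producer (rank 0) plus every
        #       consumer process; carries rollout batches to the consumers.
        #   "sync_model" / f"sync_model_{pp_idx}": all producers plus the
        #       consumer rank that broadcasts updated weights
        #       (src=num_producers), one group per pipeline stage when the
        #       consumer runs with pp_size > 1.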
        cc.init_collective_group(
            world_size=self.num_producers,
            rank=self.producer_idx,
            backend=Backend.NCCL,
            group_name="producer_group",
        )
        cc.init_collective_group(1 + self.num_consumer_procs, 0, group_name=f"sync_data_{self.producer_idx}")
        if self.consumer_pp_size > 1:
            for i in range(self.consumer_pp_size):
                cc.init_collective_group(self.num_producers + 1, self.producer_idx, group_name=f"sync_model_{i}")
        else:
            cc.init_collective_group(self.num_producers + 1, self.producer_idx, group_name="sync_model")

    def rollout(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, **kwargs) -> Dict[str, torch.Tensor]:
        raise NotImplementedError

    def load_state_dict(self, state_dict: Dict[str, torch.Tensor]) -> None:
        raise NotImplementedError

    def loop(self) -> None:
        num_update_per_episode = len(self.train_dataloader) // self.num_microbatches
        num_valid_microbatches = num_update_per_episode * self.num_microbatches
        print(
            f"[P{self.producer_idx}] num_valid_microbatches {num_valid_microbatches}, "
            f"num_microbatches: {self.num_microbatches}, dataloader length: {len(self.train_dataloader)}"
        )
        for episode in range(self.num_episodes):
            self.train_dataloader.sampler.set_epoch(episode)
            for i, batch in enumerate(self.train_dataloader):
                if i >= num_valid_microbatches:
                    break
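                # Periodic evaluation: run once at startup (latest_eval_step == -1),
                # then again whenever the consumer has advanced at least
                # eval_interval optimizer steps since the last evaluation.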
                if self.eval_interval > 0 and self.eval_dataset_config is not None:
                    if (
                        self.consumer_global_step - self.latest_eval_step >= self.eval_interval
                        and self.consumer_global_step > self.latest_eval_step
                    ) or self.latest_eval_step == -1:
                        to_log_msg = {}
                        self.eval_mode = True
                        for eval_task_name in self.eval_dataloaders:
                            if self.producer_idx == 0:
                                print(
                                    f"[P{self.producer_idx}] Evaluate model at training step {self.consumer_global_step} on task {eval_task_name}"
                                )
                            eval_results = []
                            eval_statistics_tensor = torch.zeros((2,), dtype=torch.float32).to(self.device)
                            for eval_batch in tqdm.tqdm(
                                self.eval_dataloaders[eval_task_name], disable=self.producer_idx != 0
                            ):
                                eval_outputs = self.rollout(**eval_batch, sample_params=self.eval_sample_params)
                                eval_results = eval_results + [
                                    self.evaluation_function(
                                        eval_outputs["input_ids"][m][n],
                                        eval_outputs[
                                            (
                                                "test_cases"
                                                if self.grpo_config["reward_fn_type"] == "code"
                                                else "gt_answer"
                                            )
                                        ][m],
                                        eval_outputs["response_idx"][m][n],
                                        tokenizer=self.tokenizer,
                                        eval_mode=True,
                                        tags=self.response_format_tags,
                                    )
                                    for m in range(eval_outputs["input_ids"].size(0))
                                    for n in range(eval_outputs["input_ids"].size(1))
                                ]
                            eval_statistics_tensor[0] += len([res for res in eval_results if res["ans_valid"] == 1])
                            eval_statistics_tensor[1] += len(eval_results)
                            allreduce(eval_statistics_tensor, op=ReduceOp.SUM, group_name="producer_group")
                            to_log_msg[f"eval/{eval_task_name}"] = (
                                eval_statistics_tensor[0].item() / eval_statistics_tensor[1].item()
                            )
                            if self.producer_idx == 0:
                                print(
                                    f"[P{self.producer_idx}]: Accuracy on {eval_task_name}: {to_log_msg[f'eval/{eval_task_name}']}"
                                )
                            # save eval results
                            safe_append_to_jsonl_file(
                                os.path.join(
                                    self.eval_save_dir,
                                    f"{eval_task_name}_training_step_{self.consumer_global_step}.jsonl",
                                ),
                                eval_results,
                            )
                        if self.producer_idx == 0:
                            self.wandb_run.log(to_log_msg, step=self.consumer_global_step)
                        self.eval_mode = False
                        self.latest_eval_step = self.consumer_global_step
self.profiler.enter("rollout")
outputs = self.rollout(**batch)
self.profiler.exit("rollout")
outputs["temperature"] = torch.tensor(
[self.model.generate_config["temperature"]] * outputs["input_ids"].size(0)
).to(outputs["input_ids"].device)
bs, num_gen = outputs["input_ids"].size(0), outputs["input_ids"].size(1)
self.profiler.enter("calculate_reward")
if self.grpo_config["reward_fn_type"] == "code":
test_cases = []
for prompt_id in range(bs):
test_cases.extend([outputs["test_cases"][prompt_id]] * num_gen)
reward_model_output = self.reward_model(
outputs["input_ids"].view((-1, outputs["input_ids"].size(-1))),
test_cases=test_cases,
response_idx=outputs["response_idx"].view((-1, 2)),
)
else:
gt_answer = []
for prompt_id in range(bs):
gt_answer.extend([outputs["gt_answer"][prompt_id]] * num_gen)
reward_model_output = self.reward_model(
outputs["input_ids"].view((-1, outputs["input_ids"].size(-1))),
gt_answer=gt_answer,
response_idx=outputs["response_idx"].view((-1, 2)),
)
outputs["reward"] = (
torch.tensor([value[0] for value in reward_model_output])
.to(outputs["input_ids"].device)
.view((bs, num_gen, 1))
)
outputs["format_acc"] = (
torch.tensor([value[1] for value in reward_model_output])
.to(outputs["input_ids"].device)
.view((bs, num_gen, 1))
)
outputs["ans_acc"] = (
torch.tensor([value[2] for value in reward_model_output])
.to(outputs["input_ids"].device)
.view((bs, num_gen, 1))
)
if "gt_answer" in outputs:
outputs.pop("gt_answer")
if "test_cases" in outputs:
outputs.pop("test_cases")
self.profiler.exit("calculate_reward")
print(f"[P{self.producer_idx}] Send data {[(k, v.shape) for k, v in outputs.items()]}")
outputs = pre_send(outputs)
self.profiler.enter("send_broadcast_data")
ray_broadcast_tensor_dict(
outputs, src=0, device=self.device, group_name=f"sync_data_{self.producer_idx}"
)
self.profiler.exit("send_broadcast_data")
                # don't sync model for the last iteration of the run
                if (
                    (i + 1) % self.num_microbatches == 0
                    and (episode != self.num_episodes - 1 or i != num_valid_microbatches - 1)
                    and (episode != 0 or (i + 1) > self.n_behind * self.num_microbatches)
                ):
                    if isinstance(self.model, BACKEND_MAP["vllm"]) and self.model.model_config.get(
                        "enable_sleep_mode", False
                    ):
                        self.model.llm.sleep()  # evict the KV cache to avoid OOM
                    torch.cuda.empty_cache()
                    self.profiler.enter("sync_model")
                    if self.consumer_pp_size > 1:
                        for pp_idx in range(self.consumer_pp_size):
                            print(
                                f"[P{self.producer_idx}] Sync model PP stage {pp_idx} episode {episode} step {(i + 1) // self.num_microbatches - 1}"
                            )
                            state_dict = ray_broadcast_tensor_dict(
                                None, self.num_producers, device=self.device, group_name=f"sync_model_{pp_idx}"
                            )
                            if "consumer_global_step" in state_dict:
                                self.consumer_global_step = state_dict.pop("consumer_global_step").item()
                            self.load_state_dict(state_dict)
                    else:
                        print(
                            f"[P{self.producer_idx}] Sync model episode {episode} step {(i + 1) // self.num_microbatches - 1}"
                        )
                        state_dict = ray_broadcast_tensor_dict(
                            None, self.num_producers, device=self.device, group_name="sync_model"
                        )
                        if "consumer_global_step" in state_dict:
                            self.consumer_global_step = state_dict.pop("consumer_global_step").item()
                        self.load_state_dict(state_dict)
                    self.profiler.exit("sync_model")
                    del state_dict
                    torch.cuda.empty_cache()
                    if isinstance(self.model, BACKEND_MAP["vllm"]) and self.model.model_config.get(
                        "enable_sleep_mode", False
                    ):
                        self.model.llm.wake_up()
                # linearly anneal the sampling temperature from its initial value to 0.9 over the first episode
                if episode <= 0:
                    ratio = 1 - (len(self.train_dataloader) - i) / len(self.train_dataloader)
                    self.model.generate_config["temperature"] = (1 - ratio) * self.generate_config[
                        "temperature"
                    ] + ratio * 0.9
                    if isinstance(self.model, BACKEND_MAP["vllm"]):
                        self.model.sample_params.temperature = (1 - ratio) * self.generate_config[
                            "temperature"
                        ] + ratio * 0.9
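                    # Example: with an initial temperature of 1.0 and a 100-batch
                    # dataloader, batch i=50 gives ratio=0.5, so the temperature
                    # becomes 0.5 * 1.0 + 0.5 * 0.9 = 0.95.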

    def __del__(self):
        self.profiler.close()
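

# SimpleProducer is the concrete Ray actor used at launch time: it builds the
# inference backend selected by `backend` (see BACKEND_MAP), delegates
# rollout() to that backend's generate(), and periodically writes decoded
# rollouts from producer rank 0 to a JSONL log.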
@ray.remote
class SimpleProducer(BaseProducer):
    def __init__(
        self,
        producer_idx,
        num_producers,
        num_consumer_procs,
        num_episodes,
        batch_size,
        train_dataset_config,
        model_config,
        generate_config,
        tokenizer_config=None,
        microbatch_size=1,
        backend="transformers",
        num_generations: int = 8,
        consumer_plugin_config=None,
        eval_dataset_config=None,
        eval_interval=-1,  # disable evaluation
        grpo_config: Dict[str, Any] = None,
        eval_save_dir: str = "./eval",
        eval_generation_config={},
        project_name: str = None,
        run_name: str = None,
        wandb_group_name: str = None,
        log_rollout_interval: int = 20,
        rollout_log_file: str = "./rollout_log.jsonl",
        enable_profiling: bool = False,
        n_behind: int = 0,
    ):
        super().__init__(
            producer_idx,
            num_producers,
            num_consumer_procs,
            num_episodes,
            batch_size,
            train_dataset_config,
            model_config,
            generate_config,
            tokenizer_config,
            microbatch_size,
            backend,
            consumer_plugin_config,
            eval_dataset_config=eval_dataset_config,
            eval_interval=eval_interval,
            grpo_config=grpo_config,
            eval_save_dir=eval_save_dir,
            project_name=project_name,
            run_name=run_name,
            wandb_group_name=wandb_group_name,
            log_rollout_interval=log_rollout_interval,
            rollout_log_file=rollout_log_file,
            enable_profiling=enable_profiling,
            n_behind=n_behind,
        )
        self.model = self.backend_cls(model_config, generate_config, self.tokenizer, num_generations)
        self.eval_generation_config = copy.deepcopy(self.model.generate_config)
        self.eval_generation_config["n"] = 1  # use 1 generation for evaluation
        self.eval_generation_config.update(eval_generation_config)
        self.eval_sample_params = SamplingParams(**self.eval_generation_config)

    @torch.no_grad()
    def rollout(self, input_ids, attention_mask, **kwargs):
        rollouts = self.model.generate(input_ids, attention_mask, **kwargs)
        if self.producer_idx == 0 and not self.eval_mode:
            if (
                self.consumer_global_step - self.latest_rollout_log_step >= self.log_rollout_interval
                or self.latest_rollout_log_step == -1
            ):
                new_record = (
                    json.dumps(
                        {
                            "train_step": self.consumer_global_step,
                            "rollout": self.tokenizer.batch_decode(
                                rollouts["input_ids"][:, 0], skip_special_tokens=True
                            ),
                        }
                    )
                    + "\n"
                )
                self.rollout_log_file.write(new_record)
                self.rollout_log_file.flush()
                self.latest_rollout_log_step = self.consumer_global_step
        return rollouts

    def __del__(self):
        if self.producer_idx == 0:
            self.wandb_run.finish()
        if hasattr(self, "rollout_log_file"):
            self.rollout_log_file.close()

    def load_state_dict(self, state_dict):
        self.model.load_state_dict(state_dict)
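

# --- Usage sketch (illustrative, not the repository's entrypoint) ---
# A minimal, hypothetical way to launch a single producer. Every path and
# config value below is an assumption for illustration only; the real launcher
# in ColossalChat also constructs the matching consumer actors that join the
# "sync_data_*" and "sync_model" collective groups, without which setup()
# would block waiting for peers.
#
#     import ray
#
#     ray.init()
#     producer = SimpleProducer.options(num_gpus=1).remote(
#         producer_idx=0,
#         num_producers=1,
#         num_consumer_procs=1,
#         num_episodes=1,
#         batch_size=8,
#         train_dataset_config={"path": "train.jsonl", "max_length": 1024},  # assumed dataset keys
#         model_config={"path": "Qwen/Qwen2.5-7B"},  # assumed model path
#         generate_config={"temperature": 1.0},
#         microbatch_size=4,
#         backend="vllm",
#         consumer_plugin_config={},
#         grpo_config={"reward_fn_type": "boxed"},
#         project_name="grpo-demo",  # assumed wandb project/run names
#         run_name="demo-run",
#     )
#     ray.get(producer.setup.remote())  # join the collective groups
#     producer.loop.remote()            # start the generate -> reward -> send loop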