Support evaluation during training

Author: YeAnbang
Date: 2025-04-30 18:13:40 +08:00
Parent: b920af427b
Commit: 47a7dc7142
9 changed files with 234 additions and 65 deletions


@@ -1,9 +1,13 @@
+import copy
+import os
from typing import Any, Dict, Optional
import ray
import ray.util.collective as cc
import torch
+import tqdm
from coati.dataset.loader import RawConversationDataset
+from coati.distributed.reward.reward_fn import boxed_math_reward_fn, math_reward_fn
from torch.utils.data import DataLoader, DistributedSampler
from transformers import AutoTokenizer
@@ -11,7 +15,12 @@ from colossalai.utils import get_current_device
from .comm import ray_broadcast_tensor_dict
from .inference_backend import BACKEND_MAP
-from .utils import pre_send
+from .utils import pre_send, safe_write_jsonl
+try:
+    from vllm import SamplingParams
+except ImportError:
+    SamplingParams = None  # vLLM is optional; evaluation with the vLLM backend requires it
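
safe_write_jsonl is a helper this commit adds to .utils alongside pre_send. A minimal sketch of such a helper, assuming it appends records as JSON lines under a file lock so concurrent producer processes do not interleave writes; the real implementation lives in this commit's utils.py hunk, which is not shown here:

import json
import os

from filelock import FileLock

def safe_write_jsonl(file_path: str, data: list) -> None:
    # Create the target directory on first use.
    os.makedirs(os.path.dirname(file_path), exist_ok=True)
    # Serialize concurrent writers (one producer process per data-parallel rank).
    with FileLock(file_path + ".lock"):
        with open(file_path, "a", encoding="utf8") as f:
            for entry in data:
                f.write(json.dumps(entry, ensure_ascii=False) + "\n")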
class BaseProducer:
@@ -22,7 +31,7 @@ class BaseProducer:
num_consumer_procs: int,
num_episodes: int,
batch_size: int,
-dataset_config: Dict[str, Any],
+train_dataset_config: Dict[str, Any],
dataloaders_config: Dict[str, Any],
model_config: Dict[str, Any],
generate_config: Dict[str, Any],
@@ -30,6 +39,10 @@ class BaseProducer:
microbatch_size: int = 1,
backend: str = "transformers",
consumer_plugin_config: Dict[str, Any] = None,
+eval_dataset_config=None,
+eval_interval=-1,  # disable evaluation
+evaluation_function_type="think_answer_tags",
+eval_save_dir: str = "./eval",
):
self.producer_idx = producer_idx
self.num_producers = num_producers
@@ -40,10 +53,17 @@ class BaseProducer:
assert batch_size % microbatch_size == 0
self.num_microbatches = batch_size // microbatch_size
-self.dataset_config = dataset_config
+self.train_dataset_config = train_dataset_config
self.model_config = model_config
self.generate_config = generate_config
self.tokenizer_config = tokenizer_config
self.consumer_plugin_config = consumer_plugin_config
+self.eval_interval = eval_interval
+self.eval_save_dir = eval_save_dir
+self.consumer_global_step = 0
+if os.path.exists(self.eval_save_dir):
+    raise ValueError(f"Eval save dir {self.eval_save_dir} already exists. Please delete it or change the name.")
# init tokenizer
if tokenizer_config is None:
@@ -55,13 +75,13 @@ class BaseProducer:
self.tokenizer.padding_side = "left"
# init dataloader
-dataset_path = dataset_config.pop("path")
-self.dataset = RawConversationDataset(self.tokenizer, dataset_path, **dataset_config)
-self.dataloader = DataLoader(
-    self.dataset,
+train_dataset_path = train_dataset_config.pop("path")
+self.train_dataset = RawConversationDataset(self.tokenizer, train_dataset_path, **train_dataset_config)
+self.train_dataloader = DataLoader(
+    self.train_dataset,
batch_size=microbatch_size,
sampler=DistributedSampler(
-self.dataset,
+self.train_dataset,
num_replicas=num_producers,
rank=producer_idx,
shuffle=True,
@@ -71,6 +91,36 @@ class BaseProducer:
num_workers=4,
drop_last=True,
)
+self.eval_dataset_config = eval_dataset_config
+if self.eval_dataset_config is not None:
+    self.eval_dataloaders = {}
+    for eval_task_name in self.eval_dataset_config:
+        eval_dataset_path = eval_dataset_config[eval_task_name].pop("path")
+        eval_dataset = RawConversationDataset(
+            self.tokenizer, eval_dataset_path, **eval_dataset_config[eval_task_name]
+        )
+        print(f"[P{self.producer_idx}] eval dataset {eval_task_name} size: {len(eval_dataset)}")
+        self.eval_dataloaders[eval_task_name] = DataLoader(
+            eval_dataset,
+            batch_size=microbatch_size,
+            sampler=DistributedSampler(
+                eval_dataset,
+                num_replicas=num_producers,
+                rank=producer_idx,
+                shuffle=False,
+                drop_last=False,
+                seed=42,
+            ),
+        )
+    if evaluation_function_type == "think_answer_tags":
+        self.evaluation_function = math_reward_fn
+    elif evaluation_function_type == "boxed":
+        self.evaluation_function = boxed_math_reward_fn
+    else:
+        raise ValueError(f"Unknown evaluation function type {evaluation_function_type}")
+else:
+    raise ValueError("eval_dataset_config is not defined")
self.device = get_current_device()
# init backend
@@ -88,6 +138,7 @@ class BaseProducer:
cc.init_collective_group(self.num_producers + 1, self.producer_idx, group_name=f"sync_model_{i}")
else:
cc.init_collective_group(self.num_producers + 1, self.producer_idx, group_name="sync_model")
+cc.init_collective_group(1 + self.num_consumer_procs, 0, group_name=f"sync_eval_statistics_{self.producer_idx}")
def rollout(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, **kwargs) -> Dict[str, torch.Tensor]:
raise NotImplementedError
@@ -96,29 +147,64 @@ class BaseProducer:
raise NotImplementedError
def loop(self) -> None:
-num_update_per_episode = len(self.dataloader) // self.num_microbatches
+num_update_per_episode = len(self.train_dataloader) // self.num_microbatches
num_valid_microbatches = num_update_per_episode * self.num_microbatches
print(
f"[P{self.producer_idx}] num_valid_microbatches {num_valid_microbatches}, nmb: {self.num_microbatches}, dl: {len(self.dataloader)}"
f"[P{self.producer_idx}] num_valid_microbatches {num_valid_microbatches}, nmb: {self.num_microbatches}, dl: {len(self.train_dataloader)}"
)
for episode in range(self.num_episodes):
-self.dataloader.sampler.set_epoch(episode)
-for i, batch in enumerate(self.dataloader):
+self.train_dataloader.sampler.set_epoch(episode)
+for i, batch in enumerate(self.train_dataloader):
if i >= num_valid_microbatches:
break
+if self.eval_interval > 0 and self.eval_dataset_config is not None:
+    if i % self.eval_interval == 0:
+        eval_statistics = {}
+        for eval_task_name in self.eval_dataloaders:
+            print(
+                f"[P{self.producer_idx}] Evaluate episode {episode} step {i} on task {eval_task_name}"
+            )
+            eval_results = []
+            eval_statistics[eval_task_name] = torch.zeros(2, device=self.device)
+            for eval_batch in tqdm.tqdm(
+                self.eval_dataloaders[eval_task_name], disable=self.producer_idx != 0
+            ):
+                eval_outputs = self.rollout(**eval_batch, sample_params=self.eval_sample_params)
+                eval_results = eval_results + [
+                    self.evaluation_function(
+                        eval_outputs["input_ids"][m][n],
+                        eval_outputs["gt_answer"][m][n],
+                        eval_outputs["response_idx"][m][n],
+                        tokenizer=self.tokenizer,
+                        eval_mode=True,
+                    )
+                    for m in range(eval_outputs["input_ids"].size(0))
+                    for n in range(eval_outputs["input_ids"].size(1))
+                ]
+            eval_statistics[eval_task_name][0] += len(
+                [res for res in eval_results if res["ans_valid"] == 1]
+            )
+            eval_statistics[eval_task_name][1] += len(eval_results)
+            # save eval results
+            result_file_name = os.path.join(
+                self.eval_save_dir,
+                f"{eval_task_name}_episode_{episode}_step_{self.consumer_global_step}.jsonl",
+            )
+            # append one JSON object per evaluated sample
+            safe_write_jsonl(result_file_name, eval_results)
print(f"[P{self.producer_idx}] Send eval statistics episode {episode} step {i}")
ray_broadcast_tensor_dict(
eval_statistics,
src=0,
device=self.device,
group_name=f"sync_eval_statistics_{self.producer_idx}",
)
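
Each task's statistics tensor holds [number correct, number evaluated], and the dict of tensors is broadcast on the producer's dedicated sync_eval_statistics_{idx} group (initialized in __init__ above with 1 + num_consumer_procs members). A sketch of the matching receive and accuracy computation on the consumer side; this code is not part of this file's diff and the names are illustrative:

# Illustrative consumer-side aggregation (assumed names; lives in the consumer code).
eval_statistics = ray_broadcast_tensor_dict(
    None, src=0, device=self.device, group_name=f"sync_eval_statistics_{producer_idx}"
)
for task_name, stat in eval_statistics.items():
    accuracy = (stat[0] / stat[1]).item()  # correct / total for this producer's shard
    print(f"{task_name}: accuracy {accuracy:.4f} ({int(stat[0])}/{int(stat[1])})")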
outputs = self.rollout(**batch)
print(f"[P{self.producer_idx}] Send data {[(k, v.shape) for k, v in outputs.items()]}")
outputs["temperature"] = torch.tensor(
-    [
-        (
-            self.model.generate_config["temperature"]
-            if isinstance(self.model.generate_config.temperature, dict)
-            else self.model.generate_config.temperature
-        )
-    ]
-    * outputs["input_ids"].size(0)
+    [self.model.generate_config["temperature"]] * outputs["input_ids"].size(0)
).to(outputs["input_ids"].device)
outputs = pre_send(outputs)
ray_broadcast_tensor_dict(
@@ -150,6 +236,8 @@ class BaseProducer:
state_dict = ray_broadcast_tensor_dict(
None, self.num_producers, device=self.device, group_name="sync_model"
)
if "consumer_global_step" in state_dict:
self.consumer_global_step = state_dict.pop("consumer_global_step").item()
self.load_state_dict(state_dict)
del state_dict
torch.cuda.empty_cache()
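
consumer_global_step rides along with the weight-sync broadcast so that eval result files are stamped with the consumer's optimizer step rather than the producer's local batch index. A sketch of the sender side, assuming the consumer packs the counter into the state dict it already broadcasts; the actual change is in this commit's consumer file, not shown here:

# Illustrative consumer-side send (assumed names; the real change is in consumer.py).
state_dict = self.state_dict()
state_dict["consumer_global_step"] = torch.tensor([self.global_step], device=self.device)
ray_broadcast_tensor_dict(
    state_dict, src=self.num_producers, device=self.device, group_name="sync_model"
)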
@@ -159,7 +247,7 @@ class BaseProducer:
self.model.llm.wake_up()
# linear annealing for 1 episode, temperature from initial to 0.9
if episode <= 0:
-ratio = 1 - (len(self.dataloader) - i) / len(self.dataloader)
+ratio = 1 - (len(self.train_dataloader) - i) / len(self.train_dataloader)
if isinstance(self.model.generate_config.temperature, dict):
self.model.generate_config["temperature"] = (1 - ratio) * self.generate_config[
"temperature"
@@ -179,7 +267,7 @@ class SimpleProducer(BaseProducer):
num_consumer_procs,
num_episodes,
batch_size,
-dataset_config,
+train_dataset_config,
dataloaders_config,
model_config,
generate_config,
@@ -188,6 +276,10 @@ class SimpleProducer(BaseProducer):
backend="transformers",
num_generations: int = 8,
consumer_plugin_config=None,
+eval_dataset_config=None,
+eval_interval=-1,  # disable evaluation
+evaluation_function_type="think_answer_tags",
+eval_save_dir: str = "./eval",
):
super().__init__(
producer_idx,
@@ -195,7 +287,7 @@ class SimpleProducer(BaseProducer):
num_consumer_procs,
num_episodes,
batch_size,
-dataset_config,
+train_dataset_config,
dataloaders_config,
model_config,
generate_config,
@@ -203,14 +295,21 @@ class SimpleProducer(BaseProducer):
microbatch_size,
backend,
consumer_plugin_config,
+eval_dataset_config=eval_dataset_config,
+eval_interval=eval_interval,
+evaluation_function_type=evaluation_function_type,
+eval_save_dir=eval_save_dir,
)
self.model = self.backend_cls(model_config, generate_config, self.tokenizer, num_generations)
+self.eval_generation_config = copy.deepcopy(self.model.generate_config)
+self.eval_generation_config["n"] = 1  # use 1 generation for evaluation
+self.eval_sample_params = SamplingParams(**self.eval_generation_config)
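
Evaluation clones the training generation config and overrides n to 1: a single completion per prompt is enough to score accuracy, while training rollouts keep num_generations samples for advantage estimation. A standalone illustration with hypothetical values (requires vLLM; only the n override comes from this commit):

# Hypothetical sampling options; only the n override mirrors the commit.
from copy import deepcopy

from vllm import SamplingParams

train_generate_config = {"temperature": 1.0, "top_p": 0.95, "max_tokens": 1024, "n": 8}
eval_generate_config = deepcopy(train_generate_config)
eval_generate_config["n"] = 1  # one sample per eval prompt
eval_sample_params = SamplingParams(**eval_generate_config)
print(eval_sample_params.n)  # 1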
@torch.no_grad()
def rollout(self, input_ids, attention_mask, **kwargs):
rollouts = self.model.generate(input_ids, attention_mask, **kwargs)
-if self.producer_idx == 1:
-    print("Rollout example:\n", self.tokenizer.decode(rollouts["input_ids"][0][0], skip_special_tokens=True))
+# if self.producer_idx == 1:
+#     print("Rollout example:\n", self.tokenizer.decode(rollouts["input_ids"][0][0], skip_special_tokens=True))
return rollouts