Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-05 19:13:01 +00:00
[fix] revert reward update and evaluation (#6295)
* Revert "rewrite reward fn"
  This reverts commit d06042b434.
* Revert "upgrade reward math verification"
  This reverts commit a6085ff676.
* Revert "fix bug"
  This reverts commit 01640ebd65.
* Revert "reuse comm-group"
  This reverts commit bd61918dcf.
* Revert "Support evaluation during training"
  This reverts commit 57a88395fe.
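Taken together, these reverts strip the in-training evaluation path out of the producer (eval dataloaders, reward-based eval functions, JSONL result dumps, and the consumer_global_step sync) and restore the pre-evaluation naming: dataset_config and self.dataloader instead of train_dataset_config and self.train_dataloader. As a rough sketch only, inferred from the hunks below, constructing the reverted producer might look like this; every value is a placeholder, not a tested configuration:

    # Hypothetical post-revert wiring of SimpleProducer (a Ray actor).
    # All values are illustrative; only the keyword names come from the diff.
    producer = SimpleProducer.remote(
        producer_idx=0,
        num_producers=4,
        num_consumer_procs=4,
        num_episodes=1,
        batch_size=32,
        dataset_config={"path": "data/train.jsonl"},  # "path" is popped inside __init__
        dataloaders_config={},
        model_config={"path": "some/model"},
        generate_config={"temperature": 1.0},
        microbatch_size=8,
        backend="transformers",
        num_generations=8,
        consumer_plugin_config={"pp_size": 1},  # required again; the .get() fallback is reverted
    )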
@@ -1,13 +1,9 @@
-import copy
-import os
 from typing import Any, Dict, Optional

 import ray
 import ray.util.collective as cc
 import torch
-import tqdm
 from coati.dataset.loader import RawConversationDataset
-from coati.distributed.reward.reward_fn import boxed_math_reward_fn, math_reward_fn
 from torch.utils.data import DataLoader, DistributedSampler
 from transformers import AutoTokenizer

@@ -15,12 +11,7 @@ from colossalai.utils import get_current_device

 from .comm import ray_broadcast_tensor_dict
 from .inference_backend import BACKEND_MAP
-from .utils import pre_send, safe_write_jsonl
-
-try:
-    from vllm import SamplingParams
-except ImportError:
-    LLM = None
+from .utils import pre_send


 class BaseProducer:
@@ -31,7 +22,7 @@ class BaseProducer:
         num_consumer_procs: int,
         num_episodes: int,
         batch_size: int,
-        train_dataset_config: Dict[str, Any],
+        dataset_config: Dict[str, Any],
         dataloaders_config: Dict[str, Any],
         model_config: Dict[str, Any],
         generate_config: Dict[str, Any],
@@ -39,10 +30,6 @@
         microbatch_size: int = 1,
         backend: str = "transformers",
         consumer_plugin_config: Dict[str, Any] = None,
-        eval_dataset_config=None,
-        eval_interval=-1,  # disable evaluation
-        evaluation_function_type="think_answer_tags",
-        eval_save_dir: str = "./eval",
     ):
         self.producer_idx = producer_idx
         self.num_producers = num_producers
@@ -53,17 +40,10 @@
         assert batch_size % microbatch_size == 0
         self.num_microbatches = batch_size // microbatch_size

-        self.train_dataset_config = train_dataset_config
+        self.dataset_config = dataset_config
         self.model_config = model_config
         self.generate_config = generate_config
         self.tokenizer_config = tokenizer_config
         self.consumer_plugin_config = consumer_plugin_config
-        self.eval_interval = eval_interval
-        self.eval_save_dir = eval_save_dir
-        self.consumer_global_step = 0
-
-        if os.path.exists(self.eval_save_dir):
-            raise ValueError(f"Eval save dir {self.eval_save_dir} already exists. Please delete it or change the name.")
-
         # init tokenizer
         if tokenizer_config is None:
@@ -75,13 +55,13 @@
         self.tokenizer.padding_side = "left"

         # init dataloader
-        train_dataset_path = train_dataset_config.pop("path")
-        self.train_dataset = RawConversationDataset(self.tokenizer, train_dataset_path, **train_dataset_config)
-        self.train_dataloader = DataLoader(
-            self.train_dataset,
+        dataset_path = dataset_config.pop("path")
+        self.dataset = RawConversationDataset(self.tokenizer, dataset_path, **dataset_config)
+        self.dataloader = DataLoader(
+            self.dataset,
             batch_size=microbatch_size,
             sampler=DistributedSampler(
-                self.train_dataset,
+                self.dataset,
                 num_replicas=num_producers,
                 rank=producer_idx,
                 shuffle=True,
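The sharding itself is untouched by the revert: each producer rank still draws a disjoint, per-episode-reshuffled slice of the dataset. A minimal self-contained sketch of the same DistributedSampler pattern, with a toy TensorDataset standing in for RawConversationDataset:

    import torch
    from torch.utils.data import DataLoader, DistributedSampler, TensorDataset

    num_producers, producer_idx = 4, 0  # mirrors num_replicas / rank above
    dataset = TensorDataset(torch.arange(100))
    sampler = DistributedSampler(dataset, num_replicas=num_producers, rank=producer_idx, shuffle=True)
    loader = DataLoader(dataset, batch_size=8, sampler=sampler, num_workers=4, drop_last=True)

    for episode in range(2):
        sampler.set_epoch(episode)  # the same call loop() makes; reshuffles the shard each episode
        for (batch,) in loader:
            pass  # the real producer calls self.rollout(**batch) here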
@@ -91,36 +71,6 @@
             num_workers=4,
             drop_last=True,
         )
-
-        self.eval_dataset_config = eval_dataset_config
-        if self.eval_dataset_config is not None:
-            self.eval_dataloaders = {}
-            for eval_task_name in self.eval_dataset_config:
-                eval_dataset_path = eval_dataset_config[eval_task_name].pop("path")
-                eval_dataset = RawConversationDataset(
-                    self.tokenizer, eval_dataset_path, **eval_dataset_config[eval_task_name]
-                )
-                print(f"[P{self.producer_idx}] eval dataset {eval_task_name} size: {len(eval_dataset)}")
-                self.eval_dataloaders[eval_task_name] = DataLoader(
-                    eval_dataset,
-                    batch_size=microbatch_size,
-                    sampler=DistributedSampler(
-                        eval_dataset,
-                        num_replicas=num_producers,
-                        rank=producer_idx,
-                        shuffle=False,
-                        drop_last=False,
-                        seed=42,
-                    ),
-                )
-            if evaluation_function_type == "think_answer_tags":
-                self.evaluation_function = math_reward_fn
-            elif evaluation_function_type == "boxed":
-                self.evaluation_function = boxed_math_reward_fn
-            else:
-                raise ValueError(f"Unknown evaluation function type {evaluation_function_type}")
-        else:
-            raise ValueError("eval_dataset_config is not defined")
         self.device = get_current_device()

         # init backend
@@ -129,7 +79,7 @@
         else:
             raise ValueError(f"Unexpected backend {backend}")

-        self.consumer_pp_size = consumer_plugin_config.get("pp_size", 1)  # consumer pp size
+        self.consumer_pp_size = consumer_plugin_config["pp_size"]  # consumer pp size

     def setup(self) -> None:
         cc.init_collective_group(1 + self.num_consumer_procs, 0, group_name=f"sync_data_{self.producer_idx}")
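One behavioral edge in this hunk: after the revert, consumer_plugin_config must carry an explicit "pp_size"; a config without it now raises KeyError instead of silently defaulting to 1, as in this sketch (values are illustrative):

    plugin_config = {"pp_size": 1, "tp_size": 2}  # "pp_size" is now mandatory
    consumer_pp_size = plugin_config["pp_size"]   # KeyError if the key is missing
    # the reverted fallback would have been: plugin_config.get("pp_size", 1)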
@@ -146,67 +96,29 @@
         raise NotImplementedError

     def loop(self) -> None:
-        num_update_per_episode = len(self.train_dataloader) // self.num_microbatches
+        num_update_per_episode = len(self.dataloader) // self.num_microbatches
         num_valid_microbatches = num_update_per_episode * self.num_microbatches

         print(
-            f"[P{self.producer_idx}] num_valid_microbatches {num_valid_microbatches}, nmb: {self.num_microbatches}, dl: {len(self.train_dataloader)}"
+            f"[P{self.producer_idx}] num_valid_microbatches {num_valid_microbatches}, nmb: {self.num_microbatches}, dl: {len(self.dataloader)}"
         )
         for episode in range(self.num_episodes):
-            self.train_dataloader.sampler.set_epoch(episode)
-            for i, batch in enumerate(self.train_dataloader):
+            self.dataloader.sampler.set_epoch(episode)
+            for i, batch in enumerate(self.dataloader):
                 if i >= num_valid_microbatches:
                     break
-                if self.eval_interval > 0 and self.eval_dataset_config is not None:
-                    if i % self.eval_interval == 0:
-                        eval_statistics = {}
-                        for eval_task_name in self.eval_dataloaders:
-                            print(
-                                f"[P{self.producer_idx}] Evaluate episode {episode} step {i} on task {eval_task_name}"
-                            )
-                            eval_results = []
-                            eval_statistics[eval_task_name] = torch.zeros(2, device=self.device)
-                            for eval_batch in tqdm.tqdm(
-                                self.eval_dataloaders[eval_task_name], disable=self.producer_idx != 0
-                            ):
-                                eval_outputs = self.rollout(**eval_batch, sample_params=self.eval_sample_params)
-                                eval_results = eval_results + [
-                                    self.evaluation_function(
-                                        eval_outputs["input_ids"][m][n],
-                                        eval_outputs["gt_answer"][m][n],
-                                        eval_outputs["response_idx"][m][n],
-                                        tokenizer=self.tokenizer,
-                                        eval_mode=True,
-                                    )
-                                    for m in range(eval_outputs["input_ids"].size(0))
-                                    for n in range(eval_outputs["input_ids"].size(1))
-                                ]
-                            eval_statistics[eval_task_name][0] += len(
-                                [res for res in eval_results if res["ans_valid"] == 1]
-                            )
-                            eval_statistics[eval_task_name][1] += len(eval_results)
-                            # save eval results
-                            result_file_name = os.path.join(
-                                self.eval_save_dir,
-                                f"{eval_task_name}_episode_{episode}_step_{self.consumer_global_step}.jsonl",
-                            )
-                            # delete the file if it exists
-                            safe_write_jsonl(result_file_name, eval_results)
-                        print(f"[P{self.producer_idx}] Send eval statistics episode {episode} step {i}")
-                        eval_statistics["consumer_global_step"] = torch.tensor(
-                            [self.consumer_global_step], device=self.device
-                        )
-                        ray_broadcast_tensor_dict(
-                            eval_statistics,
-                            src=0,
-                            device=self.device,
-                            group_name=f"sync_data_{self.producer_idx}",
-                        )
                 outputs = self.rollout(**batch)

                 print(f"[P{self.producer_idx}] Send data {[(k, v.shape) for k, v in outputs.items()]}")
                 outputs["temperature"] = torch.tensor(
-                    [self.model.generate_config["temperature"]] * outputs["input_ids"].size(0)
+                    [
+                        (
+                            self.model.generate_config["temperature"]
+                            if isinstance(self.model.generate_config.temperature, dict)
+                            else self.model.generate_config.temperature
+                        )
+                    ]
+                    * outputs["input_ids"].size(0)
                 ).to(outputs["input_ids"].device)
                 outputs = pre_send(outputs)
                 ray_broadcast_tensor_dict(
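For reference, the removed accounting reduces to a correct/total tally per task, kept in a 2-element tensor so it could be broadcast with the other statistics. A stripped-down sketch with hypothetical result dicts (only the "ans_valid" flag is taken from the removed code):

    import torch

    eval_results = [{"ans_valid": 1}, {"ans_valid": 0}, {"ans_valid": 1}]  # toy stand-ins

    stats = torch.zeros(2)  # [num_correct, num_total], as in the removed block
    stats[0] += len([res for res in eval_results if res["ans_valid"] == 1])
    stats[1] += len(eval_results)
    accuracy = (stats[0] / stats[1]).item()  # 2/3 for this toy input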
@@ -238,8 +150,6 @@
                 state_dict = ray_broadcast_tensor_dict(
                     None, self.num_producers, device=self.device, group_name="sync_model"
                 )
-                if "consumer_global_step" in state_dict:
-                    self.consumer_global_step = state_dict.pop("consumer_global_step").item()
                 self.load_state_dict(state_dict)
                 del state_dict
                 torch.cuda.empty_cache()
@@ -249,12 +159,15 @@
                     self.model.llm.wake_up()
                 # linear annealing for 1 episode, temperature from initial to 0.9
                 if episode <= 0:
-                    ratio = 1 - (len(self.train_dataloader) - i) / len(self.train_dataloader)
-                    self.model.generate_config["temperature"] = (1 - ratio) * self.generate_config[
-                        "temperature"
-                    ] + ratio * 0.9
-                    if hasattr(self.model, "sample_params"):
-                        self.model.sample_params.temperature = self.model.generate_config["temperature"]
+                    ratio = 1 - (len(self.dataloader) - i) / len(self.dataloader)
+                    if isinstance(self.model.generate_config.temperature, dict):
+                        self.model.generate_config["temperature"] = (1 - ratio) * self.generate_config[
+                            "temperature"
+                        ] + ratio * 0.9
+                    else:
+                        self.model.generate_config.temperature = (1 - ratio) * self.generate_config[
+                            "temperature"
+                        ] + ratio * 0.9


 @ray.remote
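The annealing that survives the revert is linear over the first episode: ratio runs from 0 to 1 as step i crosses the dataloader, moving the temperature from its configured value toward 0.9. A standalone numeric check (initial temperature and step count are made up):

    initial_temperature = 1.2  # hypothetical configured value
    steps = 100                # stand-in for len(self.dataloader)

    for i in (0, 50, 99):
        ratio = 1 - (steps - i) / steps
        temperature = (1 - ratio) * initial_temperature + ratio * 0.9
        print(i, round(temperature, 3))  # 0 -> 1.2, 50 -> 1.05, 99 -> 0.903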
@@ -266,7 +179,7 @@ class SimpleProducer(BaseProducer):
         num_consumer_procs,
         num_episodes,
         batch_size,
-        train_dataset_config,
+        dataset_config,
         dataloaders_config,
         model_config,
         generate_config,
@@ -275,10 +188,6 @@
         backend="transformers",
         num_generations: int = 8,
         consumer_plugin_config=None,
-        eval_dataset_config=None,
-        eval_interval=-1,  # disable evaluation
-        evaluation_function_type="think_answer_tags",
-        eval_save_dir: str = "./eval",
     ):
         super().__init__(
             producer_idx,
@@ -286,7 +195,7 @@
             num_consumer_procs,
             num_episodes,
             batch_size,
-            train_dataset_config,
+            dataset_config,
             dataloaders_config,
             model_config,
             generate_config,
@@ -294,15 +203,8 @@
             microbatch_size,
             backend,
             consumer_plugin_config,
-            eval_dataset_config=eval_dataset_config,
-            eval_interval=eval_interval,
-            evaluation_function_type=evaluation_function_type,
-            eval_save_dir=eval_save_dir,
         )
         self.model = self.backend_cls(model_config, generate_config, self.tokenizer, num_generations)
-        self.eval_generation_config = copy.deepcopy(self.model.generate_config)
-        self.eval_generation_config["n"] = 1  # use 1 generation for evaluation
-        self.eval_sample_params = SamplingParams(**self.eval_generation_config)

     @torch.no_grad()
     def rollout(self, input_ids, attention_mask, **kwargs):
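The three removed assignments were the only eval-specific vLLM state: a deep copy of the generation config with n forced to 1, so evaluation sampled a single completion per prompt. A minimal sketch of the same idea, assuming vllm is installed and using made-up config values:

    import copy

    from vllm import SamplingParams

    generate_config = {"temperature": 1.0, "n": 8, "max_tokens": 512}  # hypothetical training config
    eval_generation_config = copy.deepcopy(generate_config)
    eval_generation_config["n"] = 1  # one generation per prompt for evaluation
    eval_sample_params = SamplingParams(**eval_generation_config)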