mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-05 19:13:01 +00:00
move logging to producer
@@ -6,8 +6,11 @@ import ray
 import ray.util.collective as cc
 import torch
 import tqdm
+import wandb
 from coati.dataset.loader import RawConversationDataset
 from coati.distributed.reward.reward_fn import boxed_math_reward_fn, math_reward_fn
+from ray.util.collective import allreduce
+from ray.util.collective.types import Backend, ReduceOp
 from torch.utils.data import DataLoader, DistributedSampler
 from transformers import AutoTokenizer
 
@@ -15,7 +18,7 @@ from colossalai.utils import get_current_device
 
 from .comm import ray_broadcast_tensor_dict
 from .inference_backend import BACKEND_MAP
-from .utils import pre_send, safe_write_jsonl
+from .utils import pre_send, safe_append_to_jsonl_file
 
 try:
     from vllm import SamplingParams
@@ -43,6 +46,9 @@ class BaseProducer:
         eval_interval=-1,  # disable evaluation
         evaluation_function_type="think_answer_tags",
         eval_save_dir: str = "./eval",
+        project_name: str = None,
+        run_name: str = None,
+        wandb_group_name: str = None,
     ):
         self.producer_idx = producer_idx
         self.num_producers = num_producers
@@ -61,6 +67,14 @@ class BaseProducer:
         self.eval_interval = eval_interval
         self.eval_save_dir = eval_save_dir
+        self.consumer_global_step = 0
+        if self.producer_idx == 0:
+            self.wandb_run = wandb.init(
+                project=project_name,
+                sync_tensorboard=True,
+                dir="./wandb",
+                name=run_name + "_eval",
+                group=wandb_group_name,
+            )
 
         if os.path.exists(self.eval_save_dir):
             raise ValueError(f"Eval save dir {self.eval_save_dir} already exists. Please delete it or change the name.")
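The init block above follows the usual one-writer pattern for multi-actor training: only producer 0 creates a wandb run, so each metric is reported once rather than once per producer. A minimal standalone sketch of that pattern (project, run, and group names are placeholders, not values from this commit):

import wandb

producer_idx = 0  # assumed to be injected per actor; only rank 0 may log
wandb_run = None
if producer_idx == 0:
    wandb_run = wandb.init(
        project="demo-project",  # hypothetical names for illustration
        name="demo-run_eval",
        group="demo-group",
    )

metrics = {"eval/some_task": 0.42}  # hypothetical accuracy value
if wandb_run is not None:
    wandb_run.log(metrics, step=0)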
@@ -132,13 +146,18 @@ class BaseProducer:
         self.consumer_pp_size = consumer_plugin_config.get("pp_size", 1)  # consumer pp size
 
     def setup(self) -> None:
+        cc.init_collective_group(
+            world_size=self.num_producers,
+            rank=self.producer_idx,
+            backend=Backend.NCCL,
+            group_name="producer_group",
+        )
         cc.init_collective_group(1 + self.num_consumer_procs, 0, group_name=f"sync_data_{self.producer_idx}")
         if self.consumer_pp_size > 1:
             for i in range(self.consumer_pp_size):
                 cc.init_collective_group(self.num_producers + 1, self.producer_idx, group_name=f"sync_model_{i}")
         else:
             cc.init_collective_group(self.num_producers + 1, self.producer_idx, group_name="sync_model")
-        cc.init_collective_group(1 + self.num_consumer_procs, 0, group_name=f"sync_eval_statistics_{self.producer_idx}")
 
     def rollout(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, **kwargs) -> Dict[str, torch.Tensor]:
         raise NotImplementedError
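The new producer_group gives all producers a shared NCCL communicator; the eval loop later in this diff uses it to allreduce accuracy counters among producers instead of shipping them to the consumer, which is why the per-producer sync_eval_statistics group can be dropped. A self-contained sketch of that pattern with ray.util.collective, assuming one GPU per actor for the NCCL backend (actor structure and tensor values are illustrative):

import ray
import ray.util.collective as cc
import torch
from ray.util.collective.types import Backend, ReduceOp

@ray.remote(num_gpus=1)
class Producer:
    def __init__(self, rank: int, world_size: int):
        # Every producer joins the same named group, mirroring setup() above.
        cc.init_collective_group(
            world_size=world_size,
            rank=rank,
            backend=Backend.NCCL,
            group_name="producer_group",
        )

    def aggregate(self) -> float:
        # Each producer contributes [num_correct, num_total]; after the
        # allreduce every rank holds the global sums.
        stats = torch.tensor([1.0, 4.0], device="cuda")
        cc.allreduce(stats, group_name="producer_group", op=ReduceOp.SUM)
        return (stats[0] / stats[1]).item()

ray.init()
producers = [Producer.remote(rank, world_size=2) for rank in range(2)]
print(ray.get([p.aggregate.remote() for p in producers]))  # identical accuracy on both ranks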
@@ -160,13 +179,14 @@ class BaseProducer:
                     break
                 if self.eval_interval > 0 and self.eval_dataset_config is not None:
                     if i % self.eval_interval == 0:
-                        eval_statistics = {}
+                        to_log_msg = {}
                         for eval_task_name in self.eval_dataloaders:
-                            print(
-                                f"[P{self.producer_idx}] Evaluate episode {episode} step {i} on task {eval_task_name}"
-                            )
+                            if self.producer_idx == 0:
+                                print(
+                                    f"[P{self.producer_idx}] Evaluate episode {episode} step {i} on task {eval_task_name}"
+                                )
                             eval_results = []
-                            eval_statistics[eval_task_name] = torch.zeros(2, device=self.device)
+                            eval_statistics_tensor = torch.zeros((2,), dtype=torch.float32).to(self.device)
                             for eval_batch in tqdm.tqdm(
                                 self.eval_dataloaders[eval_task_name], disable=self.producer_idx != 0
                             ):
@@ -182,24 +202,27 @@ class BaseProducer:
                                     for m in range(eval_outputs["input_ids"].size(0))
                                     for n in range(eval_outputs["input_ids"].size(1))
                                 ]
-                            eval_statistics[eval_task_name][0] += len(
-                                [res for res in eval_results if res["ans_valid"] == 1]
-                            )
-                            eval_statistics[eval_task_name][1] += len(eval_results)
+                            eval_statistics_tensor[0] += len([res for res in eval_results if res["ans_valid"] == 1])
+                            eval_statistics_tensor[1] += len(eval_results)
+                            allreduce(eval_statistics_tensor, op=ReduceOp.SUM, group_name="producer_group")
+                            to_log_msg[f"eval/{eval_task_name}"] = (
+                                eval_statistics_tensor[0].item() / eval_statistics_tensor[1].item()
+                            )
+                            if self.producer_idx == 0:
+                                print(
+                                    f"[P{self.producer_idx}]: Accuracy on {eval_task_name}: {to_log_msg[f'eval/{eval_task_name}']}"
+                                )
                             # save eval results
-                            result_file_name = os.path.join(
-                                self.eval_save_dir,
-                                f"{eval_task_name}_episode_{episode}_step_{self.consumer_global_step}.jsonl",
-                            )
-                            # delete the file if it exists
-                            safe_write_jsonl(result_file_name, eval_results)
-                        print(f"[P{self.producer_idx}] Send eval statistics episode {episode} step {i}")
-                        ray_broadcast_tensor_dict(
-                            eval_statistics,
-                            src=0,
-                            device=self.device,
-                            group_name=f"sync_eval_statistics_{self.producer_idx}",
-                        )
+                            safe_append_to_jsonl_file(
+                                os.path.join(
+                                    self.eval_save_dir,
+                                    f"{eval_task_name}_episode_{episode}_step_{self.consumer_global_step}.jsonl",
+                                ),
+                                eval_results,
+                            )
+                        if self.producer_idx == 0:
+                            self.wandb_run.log(to_log_msg, step=self.consumer_global_step)
                 outputs = self.rollout(**batch)
 
                 print(f"[P{self.producer_idx}] Send data {[(k, v.shape) for k, v in outputs.items()]}")
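safe_append_to_jsonl_file replaces the delete-then-rewrite safe_write_jsonl, so results from repeated evaluations accumulate in one file per task and step. The helper's body is not part of this diff; the following is only a generic sketch of what an append-safe JSONL writer typically looks like when several processes may target the same path (fcntl makes it POSIX-only):

import fcntl
import json
from typing import Any, Dict, List

def append_jsonl(path: str, records: List[Dict[str, Any]]) -> None:
    # Hypothetical stand-in, not the coati implementation.
    with open(path, "a", encoding="utf8") as f:
        fcntl.flock(f, fcntl.LOCK_EX)  # serialize concurrent appenders
        try:
            for record in records:
                f.write(json.dumps(record, ensure_ascii=False) + "\n")
        finally:
            fcntl.flock(f, fcntl.LOCK_UN)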
@@ -248,12 +271,11 @@ class BaseProducer:
                 # linear annealing for 1 episode, temperature from initial to 0.9
                 if episode <= 0:
                     ratio = 1 - (len(self.train_dataloader) - i) / len(self.train_dataloader)
-                    if isinstance(self.model.generate_config.temperature, dict):
-                        self.model.generate_config["temperature"] = (1 - ratio) * self.generate_config[
-                            "temperature"
-                        ] + ratio * 0.9
-                    else:
-                        self.model.generate_config.temperature = (1 - ratio) * self.generate_config[
-                            "temperature"
-                        ] + ratio * 0.9
+                    self.model.generate_config["temperature"] = (1 - ratio) * self.generate_config[
+                        "temperature"
+                    ] + ratio * 0.9
+                    if isinstance(self.model, BACKEND_MAP["vllm"]):
+                        self.model.sample_params.temperature = (1 - ratio) * self.generate_config[
+                            "temperature"
+                        ] + ratio * 0.9
 
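The schedule itself is unchanged: ratio grows linearly from 0 to 1 across the first episode, moving the sampling temperature from its configured value toward 0.9. The change is about where the value lands; the added branch also writes the annealed temperature into sample_params for the vLLM backend, which suggests updating generate_config alone did not reach vLLM sampling. A small sketch of the schedule with made-up numbers:

def annealed_temperature(step: int, num_steps: int, initial: float) -> float:
    # Mirrors the formula above: ratio is 0 at step 0 and 1 at the last step.
    ratio = 1 - (num_steps - step) / num_steps
    return (1 - ratio) * initial + ratio * 0.9

# e.g. initial=1.2 over 100 steps: step 0 -> 1.2, step 50 -> 1.05, step 100 -> 0.9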
@@ -280,6 +302,10 @@ class SimpleProducer(BaseProducer):
         eval_interval=-1,  # disable evaluation
         evaluation_function_type="think_answer_tags",
         eval_save_dir: str = "./eval",
+        eval_generation_config={},
+        project_name: str = None,
+        run_name: str = None,
+        wandb_group_name: str = None,
     ):
         super().__init__(
             producer_idx,
@@ -299,10 +325,14 @@ class SimpleProducer(BaseProducer):
             eval_interval=eval_interval,
             evaluation_function_type=evaluation_function_type,
             eval_save_dir=eval_save_dir,
+            project_name=project_name,
+            run_name=run_name,
+            wandb_group_name=wandb_group_name,
         )
         self.model = self.backend_cls(model_config, generate_config, self.tokenizer, num_generations)
         self.eval_generation_config = copy.deepcopy(self.model.generate_config)
         self.eval_generation_config["n"] = 1  # use 1 generation for evaluation
+        self.eval_generation_config.update(eval_generation_config)
         self.eval_sample_params = SamplingParams(**self.eval_generation_config)
 
     @torch.no_grad()
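The override order matters here: the eval config starts as a deep copy of the training generate_config, n is pinned to 1, and the new eval_generation_config argument is applied last, so callers can override any sampling field (including n) for evaluation. A sketch with illustrative values; the keys are common vLLM sampling parameters, not values taken from this commit:

import copy
from vllm import SamplingParams

generate_config = {"temperature": 1.0, "top_p": 0.95, "n": 8, "max_tokens": 1024}
eval_generation_config = {"temperature": 0.1}  # hypothetical user override

cfg = copy.deepcopy(generate_config)
cfg["n"] = 1  # one generation per prompt for evaluation
cfg.update(eval_generation_config)  # user-supplied overrides win
eval_sample_params = SamplingParams(**cfg)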