From 50070c1e84a36f11d21efddd3c6828cd32b56f16 Mon Sep 17 00:00:00 2001
From: YeAnbang
Date: Wed, 14 May 2025 18:10:57 +0800
Subject: [PATCH] move logging to producer

---
 .../coati/distributed/consumer.py             | 25 ------
 .../coati/distributed/grpo_consumer.py        | 30 ++++---
 .../coati/distributed/inference_backend.py    |  2 +-
 .../ColossalChat/coati/distributed/launch.py  | 14 ++-
 .../coati/distributed/producer.py             | 86 +++++++++++++------
 .../ColossalChat/coati/distributed/utils.py   |  2 +-
 applications/ColossalChat/rl_example.py       |  3 +
 7 files changed, 92 insertions(+), 70 deletions(-)

diff --git a/applications/ColossalChat/coati/distributed/consumer.py b/applications/ColossalChat/coati/distributed/consumer.py
index 620ab6ff6..800a3c799 100644
--- a/applications/ColossalChat/coati/distributed/consumer.py
+++ b/applications/ColossalChat/coati/distributed/consumer.py
@@ -36,7 +36,6 @@ class BaseConsumer:
         minibatch_size: int = 1,
         save_interval: int = 100,
         save_dir: str = "./model",
-        eval_interval: int = -1,
     ):
         self.num_producers = num_producers
         self.num_episodes = num_episodes
@@ -52,7 +51,6 @@ class BaseConsumer:
         self.save_dir = save_dir
         assert batch_size % minibatch_size == 0, "batch_size should be divisible by microbatch_size"
         self.num_microbatches = batch_size // minibatch_size
-        self.eval_interval = eval_interval
 
         self.model_config = model_config
         self.plugin_config = plugin_config
@@ -94,9 +92,6 @@ class BaseConsumer:
 
         if self.rank == 0:
             cc.init_collective_group(self.num_producers + 1, self.num_producers, group_name="sync_model")
-        for i in range(self.num_producers):
-            cc.init_collective_group(self.world_size + 1, self.rank + 1, group_name=f"sync_eval_statistics_{i}")
-
         self.buffer = []
 
         self.recv_cnt = 0
@@ -114,24 +109,6 @@ class BaseConsumer:
             with tqdm(range(self.num_update_per_episode), desc=f"Episode {episode}", disable=self.rank != 0) as pbar:
                 for step in pbar:
                     i = 0
-                    if self.eval_interval > 0 and step % self.eval_interval == 0:
-                        eval_statistics = None
-                        for r in range(self.num_producers):
-                            print(f"[T{dist.get_rank()}] Recv eval result episode {episode} step {step} from {r}")
-                            local_eval_result = ray_broadcast_tensor_dict(
-                                None, src=0, device=self.device, group_name=f"sync_eval_statistics_{r}"
-                            )
-                            if eval_statistics is None:
-                                eval_statistics = local_eval_result
-                            else:
-                                eval_statistics = {
-                                    k: eval_statistics[k] + local_eval_result[k] for k in eval_statistics
-                                }
-                        eval_statistics = {k: (v[0] / v[1]).item() for k, v in eval_statistics.items()}
-                        if dist.get_rank() == 0:
-                            if hasattr(self, "wandb_run") and hasattr(self, "global_step"):
-                                self.wandb_run.log(eval_statistics, step=self.global_step)
-                        print(f"Eval statistics: {eval_statistics}")
                     for _ in range(self.num_recv_per_update):
                         # receive data from producers
                         for r in range(self.num_producers):
@@ -214,7 +191,6 @@ class SimpleConsumer(BaseConsumer):
         minibatch_size=1,
         save_interval: int = 100,
         save_dir="./model",
-        eval_interval: int = -1,
 
     ):
         super().__init__(
@@ -231,7 +207,6 @@ class SimpleConsumer(BaseConsumer):
             minibatch_size,
             save_interval,
             save_dir,
-            eval_interval,
         )
         path = model_config.pop("path")
         self.model = AutoModelForCausalLM.from_pretrained(path, **model_config)
diff --git a/applications/ColossalChat/coati/distributed/grpo_consumer.py b/applications/ColossalChat/coati/distributed/grpo_consumer.py
index 0336125d1..f301345b6 100644
--- a/applications/ColossalChat/coati/distributed/grpo_consumer.py
+++ b/applications/ColossalChat/coati/distributed/grpo_consumer.py
@@ -34,13 +34,13 @@ class GRPOConsumer(BaseConsumer):
         plugin_config,
         minibatch_size=1,
         num_generations=8,
-        use_wandb=True,
         generate_config=None,
         grpo_config={},
-        project_name=None,
         save_interval: int = 100,
         save_dir="./model",
-        eval_interval: int = -1,
+        project_name: str = None,
+        run_name: str = None,
+        wandb_group_name: str = None,
     ):
         print(f"Using GRPO config: {grpo_config}")
         if grpo_config.get("loss_variation", "sample_level") == "token_level":
@@ -73,7 +73,6 @@ class GRPOConsumer(BaseConsumer):
             minibatch_size,
             save_interval=save_interval,
             save_dir=save_dir,
-            eval_interval=eval_interval,
         )
         path = model_config.pop("path")
         self.policy_model = AutoModelForCausalLM.from_pretrained(path, **model_config)
@@ -93,6 +92,9 @@ class GRPOConsumer(BaseConsumer):
         self.project_name = project_name
         self.effective_sample_count = 0
         self.total_sample_count = 0
+        self.project_name = project_name
+        self.run_name = run_name
+        self.wandb_group_name = wandb_group_name
 
         self.policy_loss_fn = PolicyLoss(
             clip_eps_low=grpo_config.get("clip_eps_low", 0.2),
@@ -143,7 +145,6 @@ class GRPOConsumer(BaseConsumer):
             **reward_model_kwargs,
         )
         self.global_step = 0
-        self.use_wandb = use_wandb
 
         self.lr_scheduler = CosineAnnealingWarmupLR(
             optimizer=self.optimizer,
@@ -154,13 +155,16 @@ class GRPOConsumer(BaseConsumer):
 
     def setup(self):
         super().setup()
-        if self.use_wandb and (
-            (not self.plugin.pp_size > 1 and self.rank == 0)
-            or (self.plugin.pp_size > 1 and self.booster.plugin.stage_manager.is_last_stage() and self.tp_rank == 0)
+        if (not self.plugin.pp_size > 1 and self.rank == 0) or (
+            self.plugin.pp_size > 1 and self.booster.plugin.stage_manager.is_last_stage() and self.tp_rank == 0
         ):
-            # Initialize wandb.
-            name = f"{self.generate_config['backend']}_bs_{self.batch_size*self.dp_size}_temp_{self.generate_config['temperature']:.01f}_top_p_{self.generate_config['top_p']:.02f}"
-            self.wandb_run = wandb.init(project=self.project_name, sync_tensorboard=True, dir="./wandb", name=name)
+            self.wandb_run = wandb.init(
+                project=self.project_name,
+                sync_tensorboard=True,
+                dir="./wandb",
+                name=self.run_name,
+                group=self.wandb_group_name,
+            )
 
         self.policy_model, self.optimizer, _, _, self.lr_scheduler = self.booster.boost(
             self.policy_model, self.optimizer, lr_scheduler=self.lr_scheduler
@@ -512,8 +516,8 @@ class GRPOConsumer(BaseConsumer):
                 }
                 if self.policy_loss_fn.beta > 0:
                     metrics["train/kl"] = self.accum_kl.item() / self.accum_count
-
-                self.wandb_run.log(metrics)
+                if self.wandb_run is not None:
+                    self.wandb_run.log(metrics)
                 self.accum_loss.zero_()
                 self.accum_reward.zero_()
                 self.accum_ans_acc.zero_()
diff --git a/applications/ColossalChat/coati/distributed/inference_backend.py b/applications/ColossalChat/coati/distributed/inference_backend.py
index 64dfe5cfd..ce2a2f28c 100644
--- a/applications/ColossalChat/coati/distributed/inference_backend.py
+++ b/applications/ColossalChat/coati/distributed/inference_backend.py
@@ -238,7 +238,7 @@ class VLLMInferenceBackend(BaseInferenceBackend):
                     log_probs.append(p)
 
         # pad them
-        max_len = self.generate_config.max_tokens
+        max_len = self.sample_params.max_tokens
         action_mask = torch.ones(len(out_tokens), max_len, dtype=attention_mask.dtype)
 
         for i, new_token_ids in enumerate(out_tokens):
diff --git a/applications/ColossalChat/coati/distributed/launch.py b/applications/ColossalChat/coati/distributed/launch.py
index d1b6158e5..cb5b6e4e2 100644
--- a/applications/ColossalChat/coati/distributed/launch.py
+++ b/applications/ColossalChat/coati/distributed/launch.py
@@ -1,4 +1,5 @@
 import copy
+import uuid
 from typing import Any, Dict, Optional
 
 import ray
@@ -53,6 +54,7 @@ def launch_distributed(
     eval_dataset_config: Optional[Dict[str, Any]] = None,
     eval_interval: int = 100,
     eval_save_dir: Optional[str] = None,
+    eval_generation_config: Optional[Dict[str, Any]] = None,
 ):
 
     if core_algo not in ALGO_MAP:
@@ -69,6 +71,9 @@ def launch_distributed(
     num_update_per_episode = num_samples // global_inference_batch_size
     num_recv_per_update = inference_batch_size // inference_microbatch_size
 
+    run_name = f"{inference_backend}_bs_{train_batch_size * train_dp_size}_temp_{generate_config['temperature']:.01f}_top_p_{generate_config['top_p']:.02f}"
+    wandb_group_name = str(uuid.uuid4())
+
     procs = []
     for i in range(num_producers):
         producer = SimpleProducer.options(num_gpus=num_proc_per_producer).remote(
@@ -90,6 +95,10 @@ def launch_distributed(
             eval_interval=eval_interval,
             evaluation_function_type=grpo_config["reward_fn_type"],
             eval_save_dir=eval_save_dir,
+            eval_generation_config=eval_generation_config,
+            project_name=project_name,
+            run_name=run_name,
+            wandb_group_name=wandb_group_name,
         )
         procs.append(producer)
     generate_config_consumer = copy.deepcopy(generate_config)
@@ -115,10 +124,11 @@ def launch_distributed(
             generate_config=generate_config_consumer,
             grpo_config=grpo_config,
             num_generations=num_generations,
-            project_name=project_name,
             save_interval=save_interval,
             save_dir=save_dir,
-            eval_interval=eval_interval,
+            project_name=project_name,
+            run_name=run_name,
+            wandb_group_name=wandb_group_name,
         )
         procs.append(consumer)
     ray.get([p.setup.remote() for p in procs])
diff --git a/applications/ColossalChat/coati/distributed/producer.py b/applications/ColossalChat/coati/distributed/producer.py
index 483ad74b3..fa2d1cc8d 100644
--- a/applications/ColossalChat/coati/distributed/producer.py
+++ b/applications/ColossalChat/coati/distributed/producer.py
@@ -6,8 +6,11 @@ import ray
 import ray.util.collective as cc
 import torch
 import tqdm
+import wandb
 from coati.dataset.loader import RawConversationDataset
 from coati.distributed.reward.reward_fn import boxed_math_reward_fn, math_reward_fn
+from ray.util.collective import allreduce
+from ray.util.collective.types import Backend, ReduceOp
 from torch.utils.data import DataLoader, DistributedSampler
 from transformers import AutoTokenizer
 
@@ -15,7 +18,7 @@ from colossalai.utils import get_current_device
 
 from .comm import ray_broadcast_tensor_dict
 from .inference_backend import BACKEND_MAP
-from .utils import pre_send, safe_write_jsonl
+from .utils import pre_send, safe_append_to_jsonl_file
 
 try:
     from vllm import SamplingParams
@@ -43,6 +46,9 @@ class BaseProducer:
         eval_interval=-1,  # disable evaluation
         evaluation_function_type="think_answer_tags",
         eval_save_dir: str = "./eval",
+        project_name: str = None,
+        run_name: str = None,
+        wandb_group_name: str = None,
     ):
         self.producer_idx = producer_idx
         self.num_producers = num_producers
@@ -61,6 +67,14 @@ class BaseProducer:
         self.eval_interval = eval_interval
         self.eval_save_dir = eval_save_dir
         self.consumer_global_step = 0
+        if self.producer_idx == 0:
+            self.wandb_run = wandb.init(
+                project=project_name,
+                sync_tensorboard=True,
+                dir="./wandb",
+                name=run_name + "_eval",
+                group=wandb_group_name,
+            )
 
         if os.path.exists(self.eval_save_dir):
             raise ValueError(f"Eval save dir {self.eval_save_dir} already exists. Please delete it or change the name.")
@@ -132,13 +146,18 @@ class BaseProducer:
         self.consumer_pp_size = consumer_plugin_config.get("pp_size", 1)  # consumer pp size
 
     def setup(self) -> None:
+        cc.init_collective_group(
+            world_size=self.num_producers,
+            rank=self.producer_idx,
+            backend=Backend.NCCL,
+            group_name="producer_group",
+        )
         cc.init_collective_group(1 + self.num_consumer_procs, 0, group_name=f"sync_data_{self.producer_idx}")
         if self.consumer_pp_size > 1:
             for i in range(self.consumer_pp_size):
                 cc.init_collective_group(self.num_producers + 1, self.producer_idx, group_name=f"sync_model_{i}")
         else:
             cc.init_collective_group(self.num_producers + 1, self.producer_idx, group_name="sync_model")
-        cc.init_collective_group(1 + self.num_consumer_procs, 0, group_name=f"sync_eval_statistics_{self.producer_idx}")
 
     def rollout(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, **kwargs) -> Dict[str, torch.Tensor]:
         raise NotImplementedError
@@ -160,13 +179,14 @@ class BaseProducer:
                     break
                 if self.eval_interval > 0 and self.eval_dataset_config is not None:
                     if i % self.eval_interval == 0:
-                        eval_statistics = {}
+                        to_log_msg = {}
                         for eval_task_name in self.eval_dataloaders:
-                            print(
-                                f"[P{self.producer_idx}] Evaluate episode {episode} step {i} on task {eval_task_name}"
-                            )
+                            if self.producer_idx == 0:
+                                print(
+                                    f"[P{self.producer_idx}] Evaluate episode {episode} step {i} on task {eval_task_name}"
+                                )
                             eval_results = []
-                            eval_statistics[eval_task_name] = torch.zeros(2, device=self.device)
+                            eval_statistics_tensor = torch.zeros((2,), dtype=torch.float32).to(self.device)
                             for eval_batch in tqdm.tqdm(
                                 self.eval_dataloaders[eval_task_name], disable=self.producer_idx != 0
                             ):
@@ -182,24 +202,27 @@ class BaseProducer:
                                     for m in range(eval_outputs["input_ids"].size(0))
                                     for n in range(eval_outputs["input_ids"].size(1))
                                 ]
-                            eval_statistics[eval_task_name][0] += len(
-                                [res for res in eval_results if res["ans_valid"] == 1]
+                            eval_statistics_tensor[0] += len([res for res in eval_results if res["ans_valid"] == 1])
+                            eval_statistics_tensor[1] += len(eval_results)
+                            allreduce(eval_statistics_tensor, op=ReduceOp.SUM, group_name="producer_group")
+                            to_log_msg[f"eval/{eval_task_name}"] = (
+                                eval_statistics_tensor[0].item() / eval_statistics_tensor[1].item()
                             )
-                            eval_statistics[eval_task_name][1] += len(eval_results)
+                            if self.producer_idx == 0:
+                                print(
+                                    f"[P{self.producer_idx}]: Accuracy on {eval_task_name}: {to_log_msg[f'eval/{eval_task_name}']}"
+                                )
                             # save eval results
-                            result_file_name = os.path.join(
-                                self.eval_save_dir,
-                                f"{eval_task_name}_episode_{episode}_step_{self.consumer_global_step}.jsonl",
+                            safe_append_to_jsonl_file(
+                                os.path.join(
+                                    self.eval_save_dir,
+                                    f"{eval_task_name}_episode_{episode}_step_{self.consumer_global_step}.jsonl",
+                                ),
+                                eval_results,
                             )
-                            # delete the file if it exists
-                            safe_write_jsonl(result_file_name, eval_results)
-                        print(f"[P{self.producer_idx}] Send eval statistics episode {episode} step {i}")
-                        ray_broadcast_tensor_dict(
-                            eval_statistics,
-                            src=0,
-                            device=self.device,
-                            group_name=f"sync_eval_statistics_{self.producer_idx}",
-                        )
+
+                        if self.producer_idx == 0:
+                            self.wandb_run.log(to_log_msg, step=self.consumer_global_step)
 
                 outputs = self.rollout(**batch)
                 print(f"[P{self.producer_idx}] Send data {[(k, v.shape) for k, v in outputs.items()]}")
@@ -248,12 +271,11 @@ class BaseProducer:
                 # linear annealing for 1 episode, temperature from initial to 0.9
                 if episode <= 0:
                     ratio = 1 - (len(self.train_dataloader) - i) / len(self.train_dataloader)
-                    if isinstance(self.model.generate_config.temperature, dict):
-                        self.model.generate_config["temperature"] = (1 - ratio) * self.generate_config[
-                            "temperature"
-                        ] + ratio * 0.9
-                    else:
-                        self.model.generate_config.temperature = (1 - ratio) * self.generate_config[
+                    self.model.generate_config["temperature"] = (1 - ratio) * self.generate_config[
+                        "temperature"
+                    ] + ratio * 0.9
+                    if isinstance(self.model, BACKEND_MAP["vllm"]):
+                        self.model.sample_params.temperature = (1 - ratio) * self.generate_config[
                             "temperature"
                         ] + ratio * 0.9
 
@@ -280,6 +302,10 @@ class SimpleProducer(BaseProducer):
         eval_interval=-1,  # disable evaluation
         evaluation_function_type="think_answer_tags",
         eval_save_dir: str = "./eval",
+        eval_generation_config={},
+        project_name: str = None,
+        run_name: str = None,
+        wandb_group_name: str = None,
     ):
         super().__init__(
             producer_idx,
@@ -299,10 +325,14 @@ class SimpleProducer(BaseProducer):
             eval_interval=eval_interval,
             evaluation_function_type=evaluation_function_type,
             eval_save_dir=eval_save_dir,
+            project_name=project_name,
+            run_name=run_name,
+            wandb_group_name=wandb_group_name,
         )
         self.model = self.backend_cls(model_config, generate_config, self.tokenizer, num_generations)
         self.eval_generation_config = copy.deepcopy(self.model.generate_config)
         self.eval_generation_config["n"] = 1  # use 1 generation for evaluation
+        self.eval_generation_config.update(eval_generation_config)
         self.eval_sample_params = SamplingParams(**self.eval_generation_config)
 
     @torch.no_grad()
diff --git a/applications/ColossalChat/coati/distributed/utils.py b/applications/ColossalChat/coati/distributed/utils.py
index fb2a0dfd3..a40ebbcfb 100644
--- a/applications/ColossalChat/coati/distributed/utils.py
+++ b/applications/ColossalChat/coati/distributed/utils.py
@@ -135,7 +135,7 @@ def masked_sum(tensor: torch.Tensor, mask: torch.Tensor, dim: int = 1) -> torch.
     return tensor.sum(dim=dim)
 
 
-def safe_write_jsonl(file_path, data):
+def safe_append_to_jsonl_file(file_path, data):
     with FileLock(file_path + ".lock"):
         # Ensure file exists
         os.makedirs(os.path.dirname(file_path), exist_ok=True)
diff --git a/applications/ColossalChat/rl_example.py b/applications/ColossalChat/rl_example.py
index b2fafd42e..9fecdb52d 100644
--- a/applications/ColossalChat/rl_example.py
+++ b/applications/ColossalChat/rl_example.py
@@ -161,6 +161,7 @@ if __name__ == "__main__":
                 stop_strings=["</answer>"] if args.reward_type == "think_answer_tags" else None,
             )
         )
+        eval_generation_config = {"temperature": 0.6}  # used to update generation config for evaluation
     elif args.backend == "vllm":
         inference_model_config.update(
             dict(
@@ -179,6 +180,7 @@ if __name__ == "__main__":
                 stop=["</answer>"] if args.reward_type == "think_answer_tags" else None,
             )
         )
+        eval_generation_config = {"temperature": 0.6}  # used to update generation config for evaluation
     else:
         raise ValueError(f"Unsupported backend: {args.backend}")
 
@@ -257,4 +259,5 @@ if __name__ == "__main__":
         },
         eval_interval=args.eval_interval,
         eval_save_dir=os.path.join(args.eval_save_dir, args.project.replace(" ", "_")),
+        eval_generation_config=eval_generation_config,
     )
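Note on the evaluation statistics introduced in producer.py above: each producer fills a two-element tensor [num_correct, num_total] for its shard of an eval task, the tensors are summed across producers with an NCCL allreduce over the new "producer_group", and producer 0 logs correct/total to wandb. The sketch below mirrors that arithmetic with a plain torch sum standing in for the collective call, so it runs without a Ray cluster or GPUs; the helper names are illustrative, only the tensor layout and the SUM reduction follow the patch.

from typing import List

import torch


def local_eval_statistics(eval_results: List[dict]) -> torch.Tensor:
    # Same layout as the patch: index 0 counts correct answers, index 1 counts all samples.
    stats = torch.zeros((2,), dtype=torch.float32)
    stats[0] += len([res for res in eval_results if res["ans_valid"] == 1])
    stats[1] += len(eval_results)
    return stats


def simulated_allreduce_sum(per_producer_stats: List[torch.Tensor]) -> torch.Tensor:
    # Stand-in for allreduce(tensor, op=ReduceOp.SUM, group_name="producer_group").
    return torch.stack(per_producer_stats).sum(dim=0)


if __name__ == "__main__":
    # Two producers, each evaluating a different shard of the eval set.
    shard_0 = [{"ans_valid": 1}, {"ans_valid": 0}, {"ans_valid": 1}]
    shard_1 = [{"ans_valid": 1}, {"ans_valid": 1}]

    reduced = simulated_allreduce_sum([local_eval_statistics(s) for s in (shard_0, shard_1)])
    accuracy = reduced[0].item() / reduced[1].item()  # what producer 0 would log to wandb
    print(f"eval accuracy: {accuracy:.3f}")  # 4 correct / 5 total = 0.800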
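The wandb wiring now produces two runs per launch: the consumer opens a training run named run_name and producer 0 opens an evaluation run named run_name + "_eval", and both pass the same group (a fresh UUID created in launch_distributed) so they appear side by side in the wandb UI. A minimal sketch of that naming scheme, with placeholder config values, looks like this:

import uuid

# Placeholder values; launch_distributed() derives these from its arguments.
inference_backend = "vllm"
train_batch_size, train_dp_size = 32, 4
generate_config = {"temperature": 1.0, "top_p": 1.0}

run_name = (
    f"{inference_backend}_bs_{train_batch_size * train_dp_size}"
    f"_temp_{generate_config['temperature']:.01f}_top_p_{generate_config['top_p']:.02f}"
)
wandb_group_name = str(uuid.uuid4())  # one group per launch, shared by all runs

# Consumer (training metrics):  wandb.init(name=run_name, group=wandb_group_name, ...)
# Producer 0 (eval metrics):    wandb.init(name=run_name + "_eval", group=wandb_group_name, ...)
print(run_name, wandb_group_name)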
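The temperature annealing block rewritten in producer.py keeps the same linear schedule as before (and now also copies the value into vLLM's sample_params): over the first episode the rollout temperature moves from the configured value toward 0.9. A small worked example, assuming a starting temperature of 1.2 and a 10-step dataloader (both made-up values):

initial_temperature = 1.2  # assumed generate_config["temperature"]
num_steps = 10             # assumed len(self.train_dataloader)

for i in range(num_steps):
    # ratio ramps from 0 toward 1 across the first episode
    ratio = 1 - (num_steps - i) / num_steps
    temperature = (1 - ratio) * initial_temperature + ratio * 0.9
    print(f"step {i}: ratio={ratio:.1f}, temperature={temperature:.3f}")
# step 0 keeps 1.200; by step 9 the value is 0.930, approaching the 0.9 floor.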