diff --git a/applications/ColossalChat/coati/distributed/consumer.py b/applications/ColossalChat/coati/distributed/consumer.py
index 54ca611de..ed6b991c3 100644
--- a/applications/ColossalChat/coati/distributed/consumer.py
+++ b/applications/ColossalChat/coati/distributed/consumer.py
@@ -128,9 +128,8 @@ class BaseConsumer:
                                     k: eval_statistics[k] + local_eval_result[k] for k in eval_statistics
                                 }
                         eval_statistics = {"eval/" + k: (v[0] / v[1]).item() for k, v in eval_statistics.items()}
-                        if dist.get_rank() == 0:
-                            if hasattr(self, "wandb_run"):
-                                self.wandb_run.log(eval_statistics, step=eval_global_step)
+                        if hasattr(self, "wandb_run"):
+                            self.wandb_run.log(eval_statistics, step=eval_global_step)
                         print(f"Eval statistics: {eval_statistics}")
                     for _ in range(self.num_recv_per_update):
                         # receive data from producers
diff --git a/applications/ColossalChat/coati/distributed/grpo_consumer.py b/applications/ColossalChat/coati/distributed/grpo_consumer.py
index b4175fc26..5dde66435 100644
--- a/applications/ColossalChat/coati/distributed/grpo_consumer.py
+++ b/applications/ColossalChat/coati/distributed/grpo_consumer.py
@@ -149,9 +149,13 @@ class GRPOConsumer(BaseConsumer):
 
     def setup(self):
         super().setup()
-        if self.use_wandb and (
-            (not self.plugin.pp_size > 1 and self.rank == 0)
-            or (self.plugin.pp_size > 1 and self.booster.plugin.stage_manager.is_last_stage() and self.tp_rank == 0)
+        if (
+            self.use_wandb
+            and self.dp_rank == 0
+            and (
+                (not self.plugin.pp_size > 1 and self.rank == 0)
+                or (self.plugin.pp_size > 1 and self.booster.plugin.stage_manager.is_last_stage() and self.tp_rank == 0)
+            )
         ):
             # Initialize wandb.
             name = f"{self.generate_config['backend']}_bs_{self.batch_size*self.dp_size}_temp_{self.generate_config['temperature']:.01f}_top_p_{self.generate_config['top_p']:.02f}"
@@ -482,8 +486,13 @@ class GRPOConsumer(BaseConsumer):
             if not self.plugin.pp_size > 1 or (
                 self.plugin.pp_size > 1 and self.booster.plugin.stage_manager.is_last_stage() and self.tp_rank == 0
             ):
-                if (not self.plugin.pp_size > 1 and self.rank == 0) or (
-                    self.plugin.pp_size > 1 and self.booster.plugin.stage_manager.is_last_stage() and self.tp_rank == 0
+                if self.dp_rank == 0 and (
+                    (not self.plugin.pp_size > 1 and self.rank == 0)
+                    or (
+                        self.plugin.pp_size > 1
+                        and self.booster.plugin.stage_manager.is_last_stage()
+                        and self.tp_rank == 0
+                    )
                 ):
                     to_log_msg = [
                         f"Loss: {self.accum_loss.item() / self.accum_count:.4f}",
diff --git a/applications/ColossalChat/coati/distributed/reward/reward_fn.py b/applications/ColossalChat/coati/distributed/reward/reward_fn.py
index 1255124b9..607b5eefc 100644
--- a/applications/ColossalChat/coati/distributed/reward/reward_fn.py
+++ b/applications/ColossalChat/coati/distributed/reward/reward_fn.py
@@ -148,7 +148,7 @@ def boxed_math_reward_fn(input_ids, gt_answer, response_idx, **kwargs):
     gt_answer = tokenizer.decode(gt_answer.squeeze(0), skip_special_tokens=True)
     final_answer = extract_boxed_solution(decoded_final_answer)
     format_valid = final_answer is not None
-    if "tags" in kwargs:
+    if "tags" in kwargs and kwargs["tags"]:
         tags = kwargs["tags"]
         format_valid = format_valid and all(
             [decoded_final_answer.count(tags[tag]["text"]) == tags[tag]["num_occur"] for tag in tags]
diff --git a/applications/ColossalChat/rl_example.py b/applications/ColossalChat/rl_example.py
index 6611835f2..dc503459e 100644
--- a/applications/ColossalChat/rl_example.py
+++ b/applications/ColossalChat/rl_example.py
@@ -6,13 +6,6 @@ import ray
 import torch
 from coati.distributed.launch import launch_distributed
 
-DEFAULT_RESPONSE_FORMAT_TAGS = {
-    "think_start": {"text": "<think>", "num_occur": 1},
-    "think_end": {"text": "</think>", "num_occur": 1},
-    "answer_start": {"text": "<answer>", "num_occur": 1},
-    "answer_end": {"text": "</answer>", "num_occur": 1},
-}
-
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     parser.add_argument("-m", "--model", type=str, default="Qwen/Qwen2.5-7B")
@@ -248,18 +241,18 @@ if __name__ == "__main__":
         num_generations=args.num_generations,
         train_model_config=train_model_config,
         grpo_config=grpo_config,
-        plugin_config={
-            "zero_stage": 2,
-        },  # for zero
         # plugin_config={
-        #     "tp_size": 4,
-        #     "pp_size": 2,
-        #     "microbatch_size": max(
-        #         1, args.train_microbatch_size // 2
-        #     ),  # microbatch size should be set to train_microbatch_size // pp_size
-        #     "zero_stage": 1,
-        #     "max_norm": 1.0,
-        # },  # for pp, tp
+        #     "zero_stage": 2,
+        # },  # for zero
+        plugin_config={
+            "tp_size": 4,
+            "pp_size": 2,
+            "microbatch_size": max(
+                1, args.train_microbatch_size // 2
+            ),  # microbatch size should be set to train_microbatch_size // pp_size
+            "zero_stage": 1,
+            "max_norm": 1.0,
+        },  # for pp, tp
         inference_backend=args.backend,
         master_port=args.torch_ddp_master_port,
         core_algo=args.algo,
@@ -272,9 +265,5 @@ if __name__ == "__main__":
         },
         eval_interval=args.eval_interval,
         eval_save_dir=os.path.join(args.eval_save_dir, args.project.replace(" ", "_")),
-        response_format_tags=(
-            json.loads(args.response_format_tags)
-            if args.response_format_tags is not None
-            else DEFAULT_RESPONSE_FORMAT_TAGS
-        ),
+        response_format_tags=(json.loads(args.response_format_tags) if args.response_format_tags is not None else {}),
     )
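
With the defaults removed, tag checking in boxed_math_reward_fn is strictly opt-in: an absent or empty tags dict now skips the check entirely. A minimal sketch of a payload a caller could still pass through args.response_format_tags (the values below simply mirror the removed DEFAULT_RESPONSE_FORMAT_TAGS; the script decodes the CLI string with json.loads, so any dict following the same schema works):

import json

# Example payload mirroring the removed defaults: each entry pairs the literal
# tag text with the number of times it must occur in a valid response.
tags = {
    "think_start": {"text": "<think>", "num_occur": 1},
    "think_end": {"text": "</think>", "num_occur": 1},
    "answer_start": {"text": "<answer>", "num_occur": 1},
    "answer_end": {"text": "</answer>", "num_occur": 1},
}

# rl_example.py expects a JSON string for response_format_tags:
print(json.dumps(tags))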