diff --git a/applications/ColossalChat/coati/distributed/consumer.py b/applications/ColossalChat/coati/distributed/consumer.py
index 54ca611de..ed6b991c3 100644
--- a/applications/ColossalChat/coati/distributed/consumer.py
+++ b/applications/ColossalChat/coati/distributed/consumer.py
@@ -128,9 +128,8 @@ class BaseConsumer:
k: eval_statistics[k] + local_eval_result[k] for k in eval_statistics
}
eval_statistics = {"eval/" + k: (v[0] / v[1]).item() for k, v in eval_statistics.items()}
- if dist.get_rank() == 0:
- if hasattr(self, "wandb_run"):
- self.wandb_run.log(eval_statistics, step=eval_global_step)
+ if hasattr(self, "wandb_run"):
+ self.wandb_run.log(eval_statistics, step=eval_global_step)
print(f"Eval statistics: {eval_statistics}")
for _ in range(self.num_recv_per_update):
# receive data from producers
diff --git a/applications/ColossalChat/coati/distributed/grpo_consumer.py b/applications/ColossalChat/coati/distributed/grpo_consumer.py
index b4175fc26..5dde66435 100644
--- a/applications/ColossalChat/coati/distributed/grpo_consumer.py
+++ b/applications/ColossalChat/coati/distributed/grpo_consumer.py
@@ -149,9 +149,13 @@ class GRPOConsumer(BaseConsumer):

    def setup(self):
super().setup()
- if self.use_wandb and (
- (not self.plugin.pp_size > 1 and self.rank == 0)
- or (self.plugin.pp_size > 1 and self.booster.plugin.stage_manager.is_last_stage() and self.tp_rank == 0)
+ if (
+ self.use_wandb
+ and self.dp_rank == 0
+ and (
+ (not self.plugin.pp_size > 1 and self.rank == 0)
+ or (self.plugin.pp_size > 1 and self.booster.plugin.stage_manager.is_last_stage() and self.tp_rank == 0)
+ )
):
# Initialize wandb.
name = f"{self.generate_config['backend']}_bs_{self.batch_size*self.dp_size}_temp_{self.generate_config['temperature']:.01f}_top_p_{self.generate_config['top_p']:.02f}"
@@ -482,8 +486,13 @@ class GRPOConsumer(BaseConsumer):
if not self.plugin.pp_size > 1 or (
self.plugin.pp_size > 1 and self.booster.plugin.stage_manager.is_last_stage() and self.tp_rank == 0
):
- if (not self.plugin.pp_size > 1 and self.rank == 0) or (
- self.plugin.pp_size > 1 and self.booster.plugin.stage_manager.is_last_stage() and self.tp_rank == 0
+ if self.dp_rank == 0 and (
+ (not self.plugin.pp_size > 1 and self.rank == 0)
+ or (
+ self.plugin.pp_size > 1
+ and self.booster.plugin.stage_manager.is_last_stage()
+ and self.tp_rank == 0
+ )
):
to_log_msg = [
f"Loss: {self.accum_loss.item() / self.accum_count:.4f}",
diff --git a/applications/ColossalChat/coati/distributed/reward/reward_fn.py b/applications/ColossalChat/coati/distributed/reward/reward_fn.py
index 1255124b9..607b5eefc 100644
--- a/applications/ColossalChat/coati/distributed/reward/reward_fn.py
+++ b/applications/ColossalChat/coati/distributed/reward/reward_fn.py
@@ -148,7 +148,7 @@ def boxed_math_reward_fn(input_ids, gt_answer, response_idx, **kwargs):
gt_answer = tokenizer.decode(gt_answer.squeeze(0), skip_special_tokens=True)
final_answer = extract_boxed_solution(decoded_final_answer)
format_valid = final_answer is not None
- if "tags" in kwargs:
+ if "tags" in kwargs and kwargs["tags"]:
tags = kwargs["tags"]
format_valid = format_valid and all(
[decoded_final_answer.count(tags[tag]["text"]) == tags[tag]["num_occur"] for tag in tags]
diff --git a/applications/ColossalChat/rl_example.py b/applications/ColossalChat/rl_example.py
index 6611835f2..dc503459e 100644
--- a/applications/ColossalChat/rl_example.py
+++ b/applications/ColossalChat/rl_example.py
@@ -6,13 +6,6 @@ import ray
import torch
from coati.distributed.launch import launch_distributed

-DEFAULT_RESPONSE_FORMAT_TAGS = {
-    "think_start": {"text": "<think>", "num_occur": 1},
-    "think_end": {"text": "</think>", "num_occur": 1},
-    "answer_start": {"text": "<answer>", "num_occur": 1},
-    "answer_end": {"text": "</answer>", "num_occur": 1},
-}
-
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("-m", "--model", type=str, default="Qwen/Qwen2.5-7B")
@@ -248,18 +241,18 @@ if __name__ == "__main__":
num_generations=args.num_generations,
train_model_config=train_model_config,
grpo_config=grpo_config,
- plugin_config={
- "zero_stage": 2,
- }, # for zero
# plugin_config={
- # "tp_size": 4,
- # "pp_size": 2,
- # "microbatch_size": max(
- # 1, args.train_microbatch_size // 2
- # ), # microbatch size should be set to train_microbatch_size // pp_size
- # "zero_stage": 1,
- # "max_norm": 1.0,
- # }, # for pp, tp
+ # "zero_stage": 2,
+ # }, # for zero
+ plugin_config={
+ "tp_size": 4,
+ "pp_size": 2,
+ "microbatch_size": max(
+ 1, args.train_microbatch_size // 2
+ ), # microbatch size should be set to train_microbatch_size // pp_size
+ "zero_stage": 1,
+ "max_norm": 1.0,
+ }, # for pp, tp
inference_backend=args.backend,
master_port=args.torch_ddp_master_port,
core_algo=args.algo,
@@ -272,9 +265,5 @@ if __name__ == "__main__":
},
eval_interval=args.eval_interval,
eval_save_dir=os.path.join(args.eval_save_dir, args.project.replace(" ", "_")),
- response_format_tags=(
- json.loads(args.response_format_tags)
- if args.response_format_tags is not None
- else DEFAULT_RESPONSE_FORMAT_TAGS
- ),
+ response_format_tags=(json.loads(args.response_format_tags) if args.response_format_tags is not None else {}),
)