From c2561f826ad0d81c124a925997fd05e688242204 Mon Sep 17 00:00:00 2001 From: YeAnbang Date: Fri, 20 Jun 2025 15:44:13 +0800 Subject: [PATCH] fix bugs --- .../coati/distributed/consumer.py | 105 ++++++++++++------ .../coati/distributed/grpo_consumer.py | 2 +- .../coati/distributed/profiling_utils.py | 4 + applications/ColossalChat/profiling.sh | 12 +- applications/ColossalChat/rl_example.py | 1 + applications/ColossalChat/visualization.py | 4 +- 6 files changed, 90 insertions(+), 38 deletions(-) diff --git a/applications/ColossalChat/coati/distributed/consumer.py b/applications/ColossalChat/coati/distributed/consumer.py index 98aa3a3c4..4265ac7e2 100644 --- a/applications/ColossalChat/coati/distributed/consumer.py +++ b/applications/ColossalChat/coati/distributed/consumer.py @@ -107,6 +107,37 @@ class BaseConsumer: def step(self, step_idx: int, **kwargs) -> Optional[float]: raise NotImplementedError + def prepare_mini_batch(self, effective_group_to_raw_group_mapping: Dict[int, int]) -> Dict[str, torch.Tensor]: + """ + Prepare a mini-batch from the effective group to raw group mapping. + This method is used to create a mini-batch for training. 
+ """ + batches = [ + self.buffer[effective_group_to_raw_group_mapping[i]] + for i in range(self.dp_rank * self.minibatch_size, (self.dp_rank + 1) * self.minibatch_size) + ] + # every dp_rank will receive a complete mini-batch, no need to sync within step() later + # each mini-batch use the first self.dp_size * minibatch_size effective samples + raw_mini_batches = self.buffer[ + : effective_group_to_raw_group_mapping[self.dp_size * self.minibatch_size - 1] + 1 + ] # include the last effective sample + raw_mini_batches_metric_dict = { + "raw_train_mini_batch_reward": [t[1] for t in raw_mini_batches], + "raw_train_mini_batch_format_acc": [t[2] for t in raw_mini_batches], + "raw_train_mini_batch_ans_acc": [t[3] for t in raw_mini_batches], + "raw_train_mini_batch_response_len": [t[4] for t in raw_mini_batches], + } + batch = bind_batch([t[0] for t in batches]) + batch = post_recv(batch) + return batch, raw_mini_batches_metric_dict + + def calculate_effective_group_to_raw_group_mapping(self): + effective_group_to_raw_group_mapping = {} + for buffer_idx in range(len(self.buffer)): + if self.buffer[buffer_idx][0] is not None: + effective_group_to_raw_group_mapping[len(effective_group_to_raw_group_mapping)] = buffer_idx + return effective_group_to_raw_group_mapping + def loop(self) -> None: print( f"Consumer{self.rank} num_update: {self.num_update_per_episode}, num_recv: {self.num_recv_per_update}, nmb: {self.num_microbatches}" @@ -121,6 +152,38 @@ class BaseConsumer: torch.cuda.reset_peak_memory_stats() i = 0 for _ in range(self.num_recv_per_update): + # after sync model, do not wait for more data to arrive as rollout takes time, use buffered data + effective_group_to_raw_group_mapping = self.calculate_effective_group_to_raw_group_mapping() + while len(effective_group_to_raw_group_mapping) > max( + self.dp_size * self.batch_size + - self.dp_size + * self.minibatch_size + * self.grpo_config.get("num_minibatch_during_rollout", 1), + self.dp_size * self.minibatch_size, + ): 
+ self.profiler.log( + f"Still have {len(effective_group_to_raw_group_mapping)} effective groups, greater than {self.dp_size * self.minibatch_size}, start training" + ) + batch, raw_mini_batches_metric_dict = self.prepare_mini_batch( + effective_group_to_raw_group_mapping + ) + self.profiler.enter("step") + loss = self.step(i, pbar, **batch, **raw_mini_batches_metric_dict) + self.profiler.exit("step") + self.buffer = self.buffer[ + effective_group_to_raw_group_mapping[self.dp_size * self.minibatch_size - 1] + 1 : + ] + # recalculate the effective group to raw group mapping + effective_group_to_raw_group_mapping_size_before = len(effective_group_to_raw_group_mapping) + effective_group_to_raw_group_mapping = self.calculate_effective_group_to_raw_group_mapping() + assert ( + len(effective_group_to_raw_group_mapping) + == effective_group_to_raw_group_mapping_size_before - self.dp_size * self.minibatch_size + ) + if loss is not None: + pbar.set_postfix({"loss": loss}) + i += 1 + # receive data from producers for r in range(self.num_producers): print(f"[T{dist.get_rank()}] Recv data episode {episode} step {step} from {r}") @@ -170,37 +233,20 @@ class BaseConsumer: f"[T{dist.get_rank()}] Filter recv data: {len(raw_batch)} -> {torch.sum(effective_group_mask).cpu().item()} effective groups" ) # mapping the effective group to the raw group for indexing - effective_group_to_raw_group_mapping = {} - for buffer_idx in range(len(self.buffer)): - if self.buffer[buffer_idx][0] is not None: - effective_group_to_raw_group_mapping[len(effective_group_to_raw_group_mapping)] = ( - buffer_idx - ) + effective_group_to_raw_group_mapping = self.calculate_effective_group_to_raw_group_mapping() print( f"[T{dist.get_rank()}] Collect Effective Prompt: {len(effective_group_to_raw_group_mapping)}/{self.dp_size * self.minibatch_size}" ) - while len(effective_group_to_raw_group_mapping) >= self.dp_size * self.minibatch_size: + while len(effective_group_to_raw_group_mapping) > self.dp_size * 
self.batch_size: + self.profiler.log( + f"Received {len(effective_group_to_raw_group_mapping)} effective groups, greater than {self.dp_size * self.batch_size}, start training after recv" + ) + # always keep at least dp_size * batch_size effective samples in the buffer for training during the rollout times after each sync model # on each dp_rank, we use minibatch_size effective samples to form a batch - batches = [ - self.buffer[effective_group_to_raw_group_mapping[i]] - for i in range( - self.dp_rank * self.minibatch_size, (self.dp_rank + 1) * self.minibatch_size - ) - ] - # every dp_rank will receive a complete mini-batch, no need to sync within step() later - # each mini-batch use the first self.dp_size * minibatch_size effective samples - raw_mini_batches = self.buffer[ - : effective_group_to_raw_group_mapping[self.dp_size * self.minibatch_size - 1] + 1 - ] # include the last effective sample - raw_mini_batches_metric_dict = { - "raw_train_mini_batch_reward": [t[1] for t in raw_mini_batches], - "raw_train_mini_batch_format_acc": [t[2] for t in raw_mini_batches], - "raw_train_mini_batch_ans_acc": [t[3] for t in raw_mini_batches], - "raw_train_mini_batch_response_len": [t[4] for t in raw_mini_batches], - } - batch = bind_batch([t[0] for t in batches]) - batch = post_recv(batch) + batch, raw_mini_batches_metric_dict = self.prepare_mini_batch( + effective_group_to_raw_group_mapping + ) self.profiler.enter("step") loss = self.step(i, pbar, **batch, **raw_mini_batches_metric_dict) self.profiler.exit("step") @@ -209,12 +255,7 @@ class BaseConsumer: ] # recalculate the effective group to raw group mapping effective_group_to_raw_group_mapping_size_before = len(effective_group_to_raw_group_mapping) - effective_group_to_raw_group_mapping = {} - for buffer_idx in range(len(self.buffer)): - if self.buffer[buffer_idx][0] is not None: - effective_group_to_raw_group_mapping[len(effective_group_to_raw_group_mapping)] = ( - buffer_idx - ) + effective_group_to_raw_group_mapping = 
self.calculate_effective_group_to_raw_group_mapping() assert ( len(effective_group_to_raw_group_mapping) == effective_group_to_raw_group_mapping_size_before - self.dp_size * self.minibatch_size diff --git a/applications/ColossalChat/coati/distributed/grpo_consumer.py b/applications/ColossalChat/coati/distributed/grpo_consumer.py index 5d4b95fda..5dcf3e051 100644 --- a/applications/ColossalChat/coati/distributed/grpo_consumer.py +++ b/applications/ColossalChat/coati/distributed/grpo_consumer.py @@ -379,7 +379,7 @@ class GRPOConsumer(BaseConsumer): reference_model_logits / self.generate_config["temperature"], input_ids_forward_micro_batch, num_action, - self.plugin.shard_config, + shard_config=self.plugin.shard_config, ) per_token_kl = ( torch.exp(reference_action_log_probs - action_log_probs) diff --git a/applications/ColossalChat/coati/distributed/profiling_utils.py b/applications/ColossalChat/coati/distributed/profiling_utils.py index 4c4621a48..1c1169b50 100644 --- a/applications/ColossalChat/coati/distributed/profiling_utils.py +++ b/applications/ColossalChat/coati/distributed/profiling_utils.py @@ -1,3 +1,7 @@ +import os +import time + + class CustomProfiler: def __init__(self, name, disabled=True): self.disabled = disabled diff --git a/applications/ColossalChat/profiling.sh b/applications/ColossalChat/profiling.sh index 1123f97f2..d9f3d9a93 100755 --- a/applications/ColossalChat/profiling.sh +++ b/applications/ColossalChat/profiling.sh @@ -1,7 +1,13 @@ export VLLM_ALLOW_LONG_MAX_MODEL_LEN=1 # 8K context length +# rm -rf *.prof +# MAX_NEW_TOKENS=$((8192-512)) +# python rl_example.py --dataset /mnt/nfs/yeanbang/experiments/RLHF/grpo/train-alignment-samll.jsonl --model /home/share/data/model/Qwen2.5-Math-7B/ -t 4 -i 4 -b vllm -a GRPO -ibs 16 -tbs 16 -e 1 -rt boxed -si 100 -s "Please reason step by step, and put your final answer within \\boxed{}." 
-tmbs 4 -p GRPO-Math-Profile -ei -5 -zero 1 -pp 2 -mnt $MAX_NEW_TOKENS -nb 1 --enable_profiling 2>&1| tee ibs_64_tbs_32_tmbs_2_pp_2_ptp_2_8192_GRPO_profile.txt +# python profile_grpo.py --visualization actor_timelines_ibs_64_tbs_32_tmbs_2_pp_2_ptp_2_8192_GRPO_profile.png + +# 4K context length rm -rf *.prof -MAX_NEW_TOKENS=$((8192-512)) -python rl_example.py --dataset /mnt/nfs/yeanbang/experiments/RLHF/grpo/train-alignment-samll.jsonl --model /home/share/data/model/Qwen2.5-Math-7B/ -t 4 -i 4 -b vllm -a GRPO -ibs 16 -tbs 16 -e 1 -rt boxed -si 100 -s "Please reason step by step, and put your final answer within \\boxed{}." -tmbs 4 -p GRPO-Math-Profile -ei -5 -zero 1 -pp 2 -mnt $MAX_NEW_TOKENS -nb 1 --enable_profiling 2>&1| tee ibs_64_tbs_32_tmbs_2_pp_2_ptp_2_8192_GRPO_profile.txt -python profile_grpo.py --visualization actor_timelines_ibs_64_tbs_32_tmbs_2_pp_2_ptp_2_8192_GRPO_profile.png +MAX_NEW_TOKENS=$((4096-512)) +python rl_example.py --dataset /mnt/nfs/yeanbang/experiments/RLHF/grpo/train-alignment-samll.jsonl --model /home/share/data/model/Qwen2.5-Math-7B/ -t 4 -i 4 -b vllm -a GRPO -ibs 8 -tbs 8 -e 1 -rt boxed -si 100 -s "Please reason step by step, and put your final answer within \\boxed{}." 
-tMbs 4 -tmbs 4 -p GRPO-Math-Profile -ei -5 -zero 2 -mnt $MAX_NEW_TOKENS -nb 1 --enable_profiling 2>&1| tee ibs_32_tbs_32_tmbs_4_zero_2_4096_GRPO_profile.txt +python visualization.py --visualization actor_timelines_ibs_32_tbs_32_tmbs_4_zero_2_4096_GRPO_profile.png diff --git a/applications/ColossalChat/rl_example.py b/applications/ColossalChat/rl_example.py index da381f8a7..cb3766e44 100644 --- a/applications/ColossalChat/rl_example.py +++ b/applications/ColossalChat/rl_example.py @@ -263,6 +263,7 @@ if __name__ == "__main__": grpo_config = { "lr": args.learning_rate, "train_microbatch_size": args.train_microbatch_size, + "num_minibatch_during_rollout": 1, # number of mini batches to pop out from buffer and used for training during rollout of the producer after it syncs the model. Hint, set to a proper value close to the number of mini batches for training that takes roughly the same time as the rollout of the producer. A value that is too large or too small will cause bubble time on the trainer or the producer. "beta": args.kl_coeff, # KL penalty coefficient "loss_variation": "sample_level", "reward_fn_type": args.reward_type, diff --git a/applications/ColossalChat/visualization.py b/applications/ColossalChat/visualization.py index bde49303f..b3a706d69 100644 --- a/applications/ColossalChat/visualization.py +++ b/applications/ColossalChat/visualization.py @@ -74,8 +74,8 @@ for idx, (actor, func_dict) in enumerate(actors.items()): yticks.append(y_val) yticklabels.append(f"{actor}:{func}") for start, end in intervals: - if end - start < 100: - end = start + 100 # Ensure minimum length of 100ms + if end - start < 6: + end = start + 6 # Ensure a minimum plotted length of 6 (same units as start/end) ax.plot( [start, end], [y_val, y_val],