From 8abf186ce2db8367e895642cf9a61f7e886d7671 Mon Sep 17 00:00:00 2001
From: Tong Li
Date: Thu, 26 Jun 2025 10:27:00 +0800
Subject: [PATCH] fix n_behind buffered-rollout training

---
 .../coati/distributed/consumer.py       | 88 +++++++++++--------
 .../coati/distributed/producer.py       |  5 +-
 applications/ColossalChat/rl_example.py |  1 -
 3 files changed, 52 insertions(+), 42 deletions(-)

diff --git a/applications/ColossalChat/coati/distributed/consumer.py b/applications/ColossalChat/coati/distributed/consumer.py
index 4265ac7e2..8374814ee 100644
--- a/applications/ColossalChat/coati/distributed/consumer.py
+++ b/applications/ColossalChat/coati/distributed/consumer.py
@@ -131,10 +131,10 @@ class BaseConsumer:
             batch = post_recv(batch)
         return batch, raw_mini_batches_metric_dict
 
-    def calculate_effective_group_to_raw_group_mapping(self):
+    def calculate_effective_group_to_raw_group_mapping(self, step):
         effective_group_to_raw_group_mapping = {}
         for buffer_idx in range(len(self.buffer)):
-            if self.buffer[buffer_idx][0] is not None:
+            if self.buffer[buffer_idx][0] is not None and (self.buffer[buffer_idx][-1] <= step - self.n_behind):
                 effective_group_to_raw_group_mapping[len(effective_group_to_raw_group_mapping)] = buffer_idx
         return effective_group_to_raw_group_mapping
 
@@ -152,37 +152,40 @@ class BaseConsumer:
                 torch.cuda.reset_peak_memory_stats()
                 i = 0
                 for _ in range(self.num_recv_per_update):
-                    # after sync model, do not wait for more data to arrive as rollout takes time, use buffered data
-                    effective_group_to_raw_group_mapping = self.calculate_effective_group_to_raw_group_mapping()
-                    while len(effective_group_to_raw_group_mapping) > max(
-                        self.dp_size * self.batch_size
-                        - self.dp_size
-                        * self.minibatch_size
-                        * self.grpo_config.get("num_minibatch_during_rollout", 1),
-                        self.dp_size * self.minibatch_size,
-                    ):
-                        self.profiler.log(
-                            f"Still have {len(effective_group_to_raw_group_mapping)} effective groups, greater than {self.dp_size * self.minibatch_size}, start training"
+
+                    if self.n_behind > 0:
+                        # after the model sync, do not wait for new rollouts to arrive; rollout takes time, so train on buffered data
+                        effective_group_to_raw_group_mapping = self.calculate_effective_group_to_raw_group_mapping(
+                            step=step
                         )
-                        batch, raw_mini_batches_metric_dict = self.prepare_mini_batch(
-                            effective_group_to_raw_group_mapping
-                        )
-                        self.profiler.enter("step")
-                        loss = self.step(i, pbar, **batch, **raw_mini_batches_metric_dict)
-                        self.profiler.exit("step")
-                        self.buffer = self.buffer[
-                            effective_group_to_raw_group_mapping[self.dp_size * self.minibatch_size - 1] + 1 :
-                        ]
-                        # recalculate the effective group to raw group mapping
-                        effective_group_to_raw_group_mapping_size_before = len(effective_group_to_raw_group_mapping)
-                        effective_group_to_raw_group_mapping = self.calculate_effective_group_to_raw_group_mapping()
-                        assert (
-                            len(effective_group_to_raw_group_mapping)
-                            == effective_group_to_raw_group_mapping_size_before - self.dp_size * self.minibatch_size
-                        )
-                        if loss is not None:
-                            pbar.set_postfix({"loss": loss})
-                        i += 1
+                        while len(effective_group_to_raw_group_mapping) >= self.dp_size * self.minibatch_size:
+                            self.profiler.log(
+                                f"Still have {len(effective_group_to_raw_group_mapping)} effective groups, at least {self.dp_size * self.minibatch_size}, start training"
+                            )
+                            batch, raw_mini_batches_metric_dict = self.prepare_mini_batch(
+                                effective_group_to_raw_group_mapping
+                            )
+                            self.profiler.enter("step")
+                            loss = self.step(i, pbar, **batch, **raw_mini_batches_metric_dict)
+                            self.profiler.exit("step")
+                            self.buffer = self.buffer[
+                                effective_group_to_raw_group_mapping[self.dp_size * self.minibatch_size - 1] + 1 :
+                            ]
+                            # recalculate the effective group to raw group mapping
+                            effective_group_to_raw_group_mapping_size_before = len(
+                                effective_group_to_raw_group_mapping
+                            )
+                            effective_group_to_raw_group_mapping = (
+                                self.calculate_effective_group_to_raw_group_mapping(step=step)
+                            )
+                            assert (
+                                len(effective_group_to_raw_group_mapping)
+                                == effective_group_to_raw_group_mapping_size_before
+                                - self.dp_size * self.minibatch_size
+                            )
+                            if loss is not None:
+                                pbar.set_postfix({"loss": loss})
+                            i += 1
 
                     # receive data from producers
                     for r in range(self.num_producers):
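Note: the two hunks above replace the old heuristic, which tried to stay "num_minibatch_during_rollout" mini-batches ahead, with an explicit staleness gate. Each buffered group now records the step at which it was received (the `step,` field added in the next hunk), and a group only becomes trainable once it is at least `n_behind` steps old. A minimal self-contained sketch of the gate, assuming entries shaped like [data, reward, format_acc, ans_acc, response_len, step]; `effective_mapping` and the toy buffer are hypothetical names used for illustration only:

# Standalone sketch of calculate_effective_group_to_raw_group_mapping (hypothetical).
def effective_mapping(buffer, step, n_behind):
    """Map dense effective indices onto raw buffer indices."""
    mapping = {}
    for raw_idx, entry in enumerate(buffer):
        payload, received_at = entry[0], entry[-1]
        # A group qualifies when it survived filtering (payload is not None)
        # and was received at least n_behind steps before the current step.
        if payload is not None and received_at <= step - n_behind:
            mapping[len(mapping)] = raw_idx
    return mapping

# At step 5 with n_behind=2, only groups received at step <= 3 qualify.
buffer = [
    ["groupA", 0.5, 1.0, 1.0, 128, 3],  # old enough -> effective index 0
    [None, 0.0, 0.0, 0.0, 0, 3],        # filtered out -> skipped
    ["groupB", 0.2, 1.0, 0.0, 96, 5],   # too fresh -> not yet trainable
]
assert effective_mapping(buffer, step=5, n_behind=2) == {0: 0}

With n_behind == 0 the gate reduces to the previous `is not None` check, since every received group already satisfies received_at <= step.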
@@ -226,6 +229,7 @@ class BaseConsumer:
                                     format_acc[group_idx],
                                     ans_acc[group_idx],
                                     response_len[group_idx],
+                                    step,
                                 ]
                             )
                     if effective_group_mask is not None:
@@ -233,17 +237,22 @@ class BaseConsumer:
                                 f"[T{dist.get_rank()}] Filter recv data: {len(raw_batch)} -> {torch.sum(effective_group_mask).cpu().item()} effective groups"
                             )
                     # mapping the effective group to the raw group for indexing
-                    effective_group_to_raw_group_mapping = self.calculate_effective_group_to_raw_group_mapping()
+                    effective_group_to_raw_group_mapping = self.calculate_effective_group_to_raw_group_mapping(
+                        step=step
+                    )
                     print(
                         f"[T{dist.get_rank()}] Collect Effective Prompt: {len(effective_group_to_raw_group_mapping)}/{self.dp_size * self.minibatch_size}"
                     )
 
-                while len(effective_group_to_raw_group_mapping) > self.dp_size * self.batch_size:
+                if self.n_behind == 0:
+                    # If n_behind is 0, we start training right after receiving data from the producers.
+                    effective_group_to_raw_group_mapping = self.calculate_effective_group_to_raw_group_mapping(
+                        step=step
+                    )
+                while len(effective_group_to_raw_group_mapping) >= self.dp_size * self.minibatch_size:
                     self.profiler.log(
-                        f"Received {len(effective_group_to_raw_group_mapping)} effective groups, greater than {self.dp_size * self.batch_size}, start training after recv"
+                        f"Still have {len(effective_group_to_raw_group_mapping)} effective groups, at least {self.dp_size * self.minibatch_size}, start training"
                     )
-                    # always keep at least dp_size * batch_size effective samples in the buffer for training during the rollout times after each sync model
-                    # on each dp_rank, we use minibatch_size effective samples to form a batch
                     batch, raw_mini_batches_metric_dict = self.prepare_mini_batch(
                         effective_group_to_raw_group_mapping
                     )
@@ -255,7 +264,9 @@ class BaseConsumer:
                     ]
                     # recalculate the effective group to raw group mapping
                     effective_group_to_raw_group_mapping_size_before = len(effective_group_to_raw_group_mapping)
-                    effective_group_to_raw_group_mapping = self.calculate_effective_group_to_raw_group_mapping()
+                    effective_group_to_raw_group_mapping = self.calculate_effective_group_to_raw_group_mapping(
+                        step=step
+                    )
                     assert (
                         len(effective_group_to_raw_group_mapping)
                         == effective_group_to_raw_group_mapping_size_before - self.dp_size * self.minibatch_size
@@ -263,6 +274,7 @@ class BaseConsumer:
                     if loss is not None:
                         pbar.set_postfix({"loss": loss})
                     i += 1
+
             if self.lr_scheduler is not None:
                 self.lr_scheduler.step()
             if (step + 1) % self.save_interval == 0 or (step + 1) == self.num_update_per_episode:
diff --git a/applications/ColossalChat/coati/distributed/producer.py b/applications/ColossalChat/coati/distributed/producer.py
index 1b23c463d..2a3746391 100644
--- a/applications/ColossalChat/coati/distributed/producer.py
+++ b/applications/ColossalChat/coati/distributed/producer.py
@@ -337,9 +337,8 @@ class BaseProducer:
                 ):
                     self.model.llm.sleep()  # evict KV cache to avoid OOM
                 # don't sync model for last iteration
-                self.profiler.enter("sync_model")
                 torch.cuda.empty_cache()
-
+                self.profiler.enter("sync_model")
                 if self.consumer_pp_size > 1:
                     for pp_idx in range(self.consumer_pp_size):
                         print(
@@ -361,9 +360,9 @@ class BaseProducer:
                     if "consumer_global_step" in state_dict:
                         self.consumer_global_step = state_dict.pop("consumer_global_step").item()
                     self.load_state_dict(state_dict)
+                    self.profiler.exit("sync_model")
                     del state_dict
                     torch.cuda.empty_cache()
-                    self.profiler.exit("sync_model")
                 if isinstance(self.model, BACKEND_MAP["vllm"]) and self.model.model_config.get(
                     "enable_sleep_mode", False
                 ):
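Note: the producer-side hunks only move the profiling boundaries, so that the "sync_model" span measures the weight broadcast and load and excludes the torch.cuda.empty_cache() calls on both sides. A minimal sketch of the resulting window, assuming a hypothetical Profiler and hypothetical broadcast_state_dict/load_weights callables in place of the real ray broadcast and model-loading code:

import time

import torch

class Profiler:
    # Hypothetical stand-in for the producer's profiler.
    def __init__(self):
        self._start = {}

    def enter(self, name):
        self._start[name] = time.perf_counter()

    def exit(self, name):
        print(f"{name} took {time.perf_counter() - self._start.pop(name):.3f}s")

def sync_model(profiler, broadcast_state_dict, load_weights):
    torch.cuda.empty_cache()             # cleanup before the span, not timed
    profiler.enter("sync_model")
    state_dict = broadcast_state_dict()  # receive updated weights from the consumer
    load_weights(state_dict)
    profiler.exit("sync_model")          # close the span before freeing memory
    del state_dict
    torch.cuda.empty_cache()             # cleanup after the span, not timed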
diff --git a/applications/ColossalChat/rl_example.py b/applications/ColossalChat/rl_example.py
index cb3766e44..da381f8a7 100644
--- a/applications/ColossalChat/rl_example.py
+++ b/applications/ColossalChat/rl_example.py
@@ -263,7 +263,6 @@ if __name__ == "__main__":
     grpo_config = {
         "lr": args.learning_rate,
         "train_microbatch_size": args.train_microbatch_size,
-        "num_minibatch_during_rollout": 1,  # number of mini batches to pop out from buffer and used for training during rollout of the producer after it syncs the model. Hint, set to a proper value close to the number of mini batches for training that takes roughly the same time as the rollout of the producer. A value that is too large or too small will cause bubble time on the trainer or the producer.
         "beta": args.kl_coeff,  # KL penalty coefficient
         "loss_variation": "sample_level",
         "reward_fn_type": args.reward_type,
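Note: dropping "num_minibatch_during_rollout" from grpo_config follows from the consumer changes, since the overlap between training and rollout is now governed by n_behind alone. When n_behind > 0 the consumer trains on sufficiently old buffered groups right after the model sync, while producers keep generating fresh rollouts; when n_behind == 0 it trains only once the current step's data has arrived. A condensed sketch of that control flow, with hypothetical recv_rollouts/train_minibatch callables standing in for the real receive and optimizer-step code:

# Condensed consumer scheduling implied by this patch (hypothetical helpers injected).
def run_update(consumer, step, recv_rollouts, train_minibatch):
    for _ in range(consumer.num_recv_per_update):
        if consumer.n_behind > 0:
            # Overlap mode: consume groups that are already >= n_behind steps old.
            train_on_effective_groups(consumer, step, train_minibatch)
        recv_rollouts(consumer, step)  # newly buffered groups are tagged with `step`
    if consumer.n_behind == 0:
        # Synchronous mode: train only after this step's data has been received.
        train_on_effective_groups(consumer, step, train_minibatch)

def train_on_effective_groups(consumer, step, train_minibatch):
    need = consumer.dp_size * consumer.minibatch_size
    mapping = consumer.calculate_effective_group_to_raw_group_mapping(step=step)
    while len(mapping) >= need:
        train_minibatch(consumer, mapping)  # one optimizer step on `need` groups
        # Drop consumed raw groups, mirroring self.buffer = self.buffer[mapping[need - 1] + 1 :]
        consumer.buffer = consumer.buffer[mapping[need - 1] + 1 :]
        mapping = consumer.calculate_effective_group_to_raw_group_mapping(step=step)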