fix behind

This commit is contained in:
Tong Li 2025-06-26 10:27:00 +08:00
parent db8baeeaf2
commit 8abf186ce2
3 changed files with 52 additions and 42 deletions

View File

@ -131,10 +131,10 @@ class BaseConsumer:
batch = post_recv(batch) batch = post_recv(batch)
return batch, raw_mini_batches_metric_dict return batch, raw_mini_batches_metric_dict
def calculate_effective_group_to_raw_group_mapping(self): def calculate_effective_group_to_raw_group_mapping(self, step):
effective_group_to_raw_group_mapping = {} effective_group_to_raw_group_mapping = {}
for buffer_idx in range(len(self.buffer)): for buffer_idx in range(len(self.buffer)):
if self.buffer[buffer_idx][0] is not None: if self.buffer[buffer_idx][0] is not None and (self.buffer[buffer_idx][-1] <= step - self.n_behind):
effective_group_to_raw_group_mapping[len(effective_group_to_raw_group_mapping)] = buffer_idx effective_group_to_raw_group_mapping[len(effective_group_to_raw_group_mapping)] = buffer_idx
return effective_group_to_raw_group_mapping return effective_group_to_raw_group_mapping
@ -152,37 +152,40 @@ class BaseConsumer:
torch.cuda.reset_peak_memory_stats() torch.cuda.reset_peak_memory_stats()
i = 0 i = 0
for _ in range(self.num_recv_per_update): for _ in range(self.num_recv_per_update):
# after sync model, do not wait for more data to arrive as rollout takes time, use buffered data
effective_group_to_raw_group_mapping = self.calculate_effective_group_to_raw_group_mapping() if self.n_behind > 0:
while len(effective_group_to_raw_group_mapping) > max( # after sync model, do not wait for more data to arrive as rollout takes time, use buffered data
self.dp_size * self.batch_size effective_group_to_raw_group_mapping = self.calculate_effective_group_to_raw_group_mapping(
- self.dp_size step=step
* self.minibatch_size
* self.grpo_config.get("num_minibatch_during_rollout", 1),
self.dp_size * self.minibatch_size,
):
self.profiler.log(
f"Still have {len(effective_group_to_raw_group_mapping)} effective groups, greater than {self.dp_size * self.minibatch_size}, start training"
) )
batch, raw_mini_batches_metric_dict = self.prepare_mini_batch( while len(effective_group_to_raw_group_mapping) >= self.dp_size * self.minibatch_size:
effective_group_to_raw_group_mapping self.profiler.log(
) f"Still have {len(effective_group_to_raw_group_mapping)} effective groups, greater than {self.dp_size * self.minibatch_size}, start training"
self.profiler.enter("step") )
loss = self.step(i, pbar, **batch, **raw_mini_batches_metric_dict) batch, raw_mini_batches_metric_dict = self.prepare_mini_batch(
self.profiler.exit("step") effective_group_to_raw_group_mapping
self.buffer = self.buffer[ )
effective_group_to_raw_group_mapping[self.dp_size * self.minibatch_size - 1] + 1 : self.profiler.enter("step")
] loss = self.step(i, pbar, **batch, **raw_mini_batches_metric_dict)
# recalculate the effective group to raw group mapping self.profiler.exit("step")
effective_group_to_raw_group_mapping_size_before = len(effective_group_to_raw_group_mapping) self.buffer = self.buffer[
effective_group_to_raw_group_mapping = self.calculate_effective_group_to_raw_group_mapping() effective_group_to_raw_group_mapping[self.dp_size * self.minibatch_size - 1] + 1 :
assert ( ]
len(effective_group_to_raw_group_mapping) # recalculate the effective group to raw group mapping
== effective_group_to_raw_group_mapping_size_before - self.dp_size * self.minibatch_size effective_group_to_raw_group_mapping_size_before = len(
) effective_group_to_raw_group_mapping
if loss is not None: )
pbar.set_postfix({"loss": loss}) effective_group_to_raw_group_mapping = (
i += 1 self.calculate_effective_group_to_raw_group_mapping(step=step)
)
assert (
len(effective_group_to_raw_group_mapping)
== effective_group_to_raw_group_mapping_size_before
- self.dp_size * self.minibatch_size
)
if loss is not None:
pbar.set_postfix({"loss": loss})
i += 1
# receive data from producers # receive data from producers
for r in range(self.num_producers): for r in range(self.num_producers):
@ -226,6 +229,7 @@ class BaseConsumer:
format_acc[group_idx], format_acc[group_idx],
ans_acc[group_idx], ans_acc[group_idx],
response_len[group_idx], response_len[group_idx],
step,
] ]
) )
if effective_group_mask is not None: if effective_group_mask is not None:
@ -233,17 +237,22 @@ class BaseConsumer:
f"[T{dist.get_rank()}] Filter recv data: {len(raw_batch)} -> {torch.sum(effective_group_mask).cpu().item()} effective groups" f"[T{dist.get_rank()}] Filter recv data: {len(raw_batch)} -> {torch.sum(effective_group_mask).cpu().item()} effective groups"
) )
# mapping the effective group to the raw group for indexing # mapping the effective group to the raw group for indexing
effective_group_to_raw_group_mapping = self.calculate_effective_group_to_raw_group_mapping() effective_group_to_raw_group_mapping = self.calculate_effective_group_to_raw_group_mapping(
step=step
)
print( print(
f"[T{dist.get_rank()}] Collect Effective Prompt: {len(effective_group_to_raw_group_mapping)}/{self.dp_size * self.minibatch_size}" f"[T{dist.get_rank()}] Collect Effective Prompt: {len(effective_group_to_raw_group_mapping)}/{self.dp_size * self.minibatch_size}"
) )
while len(effective_group_to_raw_group_mapping) > self.dp_size * self.batch_size: if self.n_behind == 0:
# If n_behind is 0, we start training after receiving data from producers.
effective_group_to_raw_group_mapping = self.calculate_effective_group_to_raw_group_mapping(
step=step
)
while len(effective_group_to_raw_group_mapping) >= self.dp_size * self.minibatch_size:
self.profiler.log( self.profiler.log(
f"Received {len(effective_group_to_raw_group_mapping)} effective groups, greater than {self.dp_size * self.batch_size}, start training after recv" f"Still have {len(effective_group_to_raw_group_mapping)} effective groups, greater than {self.dp_size * self.minibatch_size}, start training"
) )
# always keep at least dp_size * batch_size effective samples in the buffer for training during the rollout times after each sync model
# on each dp_rank, we use minibatch_size effective samples to form a batch
batch, raw_mini_batches_metric_dict = self.prepare_mini_batch( batch, raw_mini_batches_metric_dict = self.prepare_mini_batch(
effective_group_to_raw_group_mapping effective_group_to_raw_group_mapping
) )
@ -255,7 +264,9 @@ class BaseConsumer:
] ]
# recalculate the effective group to raw group mapping # recalculate the effective group to raw group mapping
effective_group_to_raw_group_mapping_size_before = len(effective_group_to_raw_group_mapping) effective_group_to_raw_group_mapping_size_before = len(effective_group_to_raw_group_mapping)
effective_group_to_raw_group_mapping = self.calculate_effective_group_to_raw_group_mapping() effective_group_to_raw_group_mapping = self.calculate_effective_group_to_raw_group_mapping(
step=step
)
assert ( assert (
len(effective_group_to_raw_group_mapping) len(effective_group_to_raw_group_mapping)
== effective_group_to_raw_group_mapping_size_before - self.dp_size * self.minibatch_size == effective_group_to_raw_group_mapping_size_before - self.dp_size * self.minibatch_size
@ -263,6 +274,7 @@ class BaseConsumer:
if loss is not None: if loss is not None:
pbar.set_postfix({"loss": loss}) pbar.set_postfix({"loss": loss})
i += 1 i += 1
if self.lr_scheduler is not None: if self.lr_scheduler is not None:
self.lr_scheduler.step() self.lr_scheduler.step()
if (step + 1) % self.save_interval == 0 or (step + 1) == self.num_update_per_episode: if (step + 1) % self.save_interval == 0 or (step + 1) == self.num_update_per_episode:

View File

@ -337,9 +337,8 @@ class BaseProducer:
): ):
self.model.llm.sleep() # revict KV_cache to avoid OOM self.model.llm.sleep() # revict KV_cache to avoid OOM
# don't sync model for last iteration # don't sync model for last iteration
self.profiler.enter("sync_model")
torch.cuda.empty_cache() torch.cuda.empty_cache()
self.profiler.enter("sync_model")
if self.consumer_pp_size > 1: if self.consumer_pp_size > 1:
for pp_idx in range(self.consumer_pp_size): for pp_idx in range(self.consumer_pp_size):
print( print(
@ -361,9 +360,9 @@ class BaseProducer:
if "consumer_global_step" in state_dict: if "consumer_global_step" in state_dict:
self.consumer_global_step = state_dict.pop("consumer_global_step").item() self.consumer_global_step = state_dict.pop("consumer_global_step").item()
self.load_state_dict(state_dict) self.load_state_dict(state_dict)
self.profiler.exit("sync_model")
del state_dict del state_dict
torch.cuda.empty_cache() torch.cuda.empty_cache()
self.profiler.exit("sync_model")
if isinstance(self.model, BACKEND_MAP["vllm"]) and self.model.model_config.get( if isinstance(self.model, BACKEND_MAP["vllm"]) and self.model.model_config.get(
"enable_sleep_mode", False "enable_sleep_mode", False
): ):

View File

@ -263,7 +263,6 @@ if __name__ == "__main__":
grpo_config = { grpo_config = {
"lr": args.learning_rate, "lr": args.learning_rate,
"train_microbatch_size": args.train_microbatch_size, "train_microbatch_size": args.train_microbatch_size,
"num_minibatch_during_rollout": 1, # number of mini batches to pop out from buffer and used for training during rollout of the producer after it syncs the model. Hint, set to a proper value close to the number of mini batches for training that takes roughly the same time as the rollout of the producer. A value that is too large or too small will cause bubble time on the trainer or the producer.
"beta": args.kl_coeff, # KL penalty coefficient "beta": args.kl_coeff, # KL penalty coefficient
"loss_variation": "sample_level", "loss_variation": "sample_level",
"reward_fn_type": args.reward_type, "reward_fn_type": args.reward_type,