fix loop issue

Tong Li 2025-06-26 15:08:27 +08:00
parent 8abf186ce2
commit 71ef6b32c6


@@ -151,8 +151,9 @@ class BaseConsumer:
             for step in pbar:
                 torch.cuda.reset_peak_memory_stats()
                 i = 0
+                self.profiler.enter(f"rollout_episode_{episode}_step_{step}")
                 for _ in range(self.num_recv_per_update):
                     if self.n_behind > 0:
                         # after sync model, do not wait for more data to arrive as rollout takes time, use buffered data
                         effective_group_to_raw_group_mapping = self.calculate_effective_group_to_raw_group_mapping(
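Note: the profiler.enter added above is paired with a profiler.exit added at the end of the step in the last hunk of this commit, so the whole receive-and-train step is measured as one scope. A minimal, self-contained sketch of that pairing follows; ToyProfiler and the loop bounds are assumptions for illustration, since the repository's actual profiler class is not part of this diff.

import time


class ToyProfiler:
    # Stand-in for the consumer's profiler; only the methods used in this diff.
    def __init__(self):
        self._starts = {}

    def enter(self, name):
        # open a named timing scope
        self._starts[name] = time.perf_counter()

    def exit(self, name):
        # close the scope opened under the same name and report elapsed time
        elapsed = time.perf_counter() - self._starts.pop(name)
        print(f"[profiler] {name} took {elapsed:.4f}s")

    def log(self, msg):
        print(f"[profiler] {msg}")


profiler = ToyProfiler()
num_recv_per_update = 2  # hypothetical value for illustration

for step in range(3):
    # scope opens once per step, before the receive loop ...
    profiler.enter(f"rollout_episode_0_step_{step}")
    for _ in range(num_recv_per_update):
        time.sleep(0.01)  # stand-in for receiving rollout data and training
    # ... and closes once per step, after model sync and logging
    profiler.exit(f"rollout_episode_0_step_{step}")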
@@ -246,12 +247,9 @@ class BaseConsumer:
                     if self.n_behind == 0:
                         # If n_behind is 0, we start training after receiving data from producers.
-                        effective_group_to_raw_group_mapping = self.calculate_effective_group_to_raw_group_mapping(
-                            step=step
-                        )
                         while len(effective_group_to_raw_group_mapping) >= self.dp_size * self.minibatch_size:
                             self.profiler.log(
-                                f"Still have {len(effective_group_to_raw_group_mapping)} effective groups, greater than {self.dp_size * self.minibatch_size}, start training"
+                                f"Collect {len(effective_group_to_raw_group_mapping)} effective groups, greater than {self.dp_size * self.minibatch_size}, start training"
                             )
                             batch, raw_mini_batches_metric_dict = self.prepare_mini_batch(
                                 effective_group_to_raw_group_mapping
@@ -263,13 +261,16 @@ class BaseConsumer:
                                 effective_group_to_raw_group_mapping[self.dp_size * self.minibatch_size - 1] + 1 :
                             ]
                             # recalculate the effective group to raw group mapping
-                            effective_group_to_raw_group_mapping_size_before = len(effective_group_to_raw_group_mapping)
-                            effective_group_to_raw_group_mapping = self.calculate_effective_group_to_raw_group_mapping(
-                                step=step
+                            effective_group_to_raw_group_mapping_size_before = len(
+                                effective_group_to_raw_group_mapping
+                            )
+                            effective_group_to_raw_group_mapping = (
+                                self.calculate_effective_group_to_raw_group_mapping(step=step)
                             )
                             assert (
                                 len(effective_group_to_raw_group_mapping)
-                                == effective_group_to_raw_group_mapping_size_before - self.dp_size * self.minibatch_size
+                                == effective_group_to_raw_group_mapping_size_before
+                                - self.dp_size * self.minibatch_size
                             )
                             if loss is not None:
                                 pbar.set_postfix({"loss": loss})
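The assert kept in this hunk expresses the invariant the training loop relies on: each pass through the while loop consumes exactly dp_size * minibatch_size effective groups, so the recomputed mapping must shrink by that amount. A rough, runnable sketch of that bookkeeping follows; the list-based buffer and the concrete sizes are assumptions for illustration, not the class's real data structures.

dp_size = 2          # hypothetical values for illustration
minibatch_size = 4
consume = dp_size * minibatch_size

# toy "effective group" buffer: just indices into a raw-group list
effective_group_to_raw_group_mapping = list(range(20))

while len(effective_group_to_raw_group_mapping) >= consume:
    size_before = len(effective_group_to_raw_group_mapping)
    # take the first dp_size * minibatch_size effective groups as one mini-batch
    mini_batch = effective_group_to_raw_group_mapping[:consume]
    # drop what was consumed, keep the rest buffered
    effective_group_to_raw_group_mapping = effective_group_to_raw_group_mapping[consume:]
    # the invariant the diff's assert expresses
    assert len(effective_group_to_raw_group_mapping) == size_before - consume
    print(f"trained on {len(mini_batch)} groups, {len(effective_group_to_raw_group_mapping)} left buffered")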
@@ -314,6 +315,7 @@ class BaseConsumer:
                         torch.cuda.empty_cache()
                     self.profiler.exit("sync_model")
                 self.profiler.log(f"Peak memory usage: {torch.cuda.max_memory_allocated() / 1024 ** 2:.2f} MB")
+                self.profiler.exit(f"rollout_episode_{episode}_step_{step}")

     def __del__(self):
         if hasattr(self, "profiler"):
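For reference, the peak-memory figure logged just before the new exit call uses the usual PyTorch pattern: reset the peak counter at the start of the step (torch.cuda.reset_peak_memory_stats() in the first hunk) and read it at the end (torch.cuda.max_memory_allocated()). A minimal standalone sketch, assuming a CUDA device is available:

import torch

if torch.cuda.is_available():
    torch.cuda.reset_peak_memory_stats()           # start of step, as in the first hunk
    x = torch.randn(1024, 1024, device="cuda")     # stand-in for the step's actual work
    y = x @ x
    del x, y
    peak_mb = torch.cuda.max_memory_allocated() / 1024**2  # end of step
    print(f"Peak memory usage: {peak_mb:.2f} MB")
else:
    print("CUDA not available; peak-memory stats only track CUDA allocations.")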