mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-07-23 19:49:30 +00:00
fix loop issue
This commit is contained in:
parent
8abf186ce2
commit
71ef6b32c6
@ -151,8 +151,9 @@ class BaseConsumer:
|
||||
for step in pbar:
|
||||
torch.cuda.reset_peak_memory_stats()
|
||||
i = 0
|
||||
for _ in range(self.num_recv_per_update):
|
||||
|
||||
self.profiler.enter(f"rollout_episode_{episode}_step_{step}")
|
||||
for _ in range(self.num_recv_per_update):
|
||||
if self.n_behind > 0:
|
||||
# after sync model, do not wait for more data to arrive as rollout takes time, use buffered data
|
||||
effective_group_to_raw_group_mapping = self.calculate_effective_group_to_raw_group_mapping(
|
||||
@ -244,36 +245,36 @@ class BaseConsumer:
|
||||
f"[T{dist.get_rank()}] Collect Effective Prompt: {len(effective_group_to_raw_group_mapping)}/{self.dp_size * self.minibatch_size}"
|
||||
)
|
||||
|
||||
if self.n_behind == 0:
|
||||
# If n_behind is 0, we start training after receiving data from producers.
|
||||
effective_group_to_raw_group_mapping = self.calculate_effective_group_to_raw_group_mapping(
|
||||
step=step
|
||||
)
|
||||
while len(effective_group_to_raw_group_mapping) >= self.dp_size * self.minibatch_size:
|
||||
self.profiler.log(
|
||||
f"Still have {len(effective_group_to_raw_group_mapping)} effective groups, greater than {self.dp_size * self.minibatch_size}, start training"
|
||||
)
|
||||
batch, raw_mini_batches_metric_dict = self.prepare_mini_batch(
|
||||
effective_group_to_raw_group_mapping
|
||||
)
|
||||
self.profiler.enter("step")
|
||||
loss = self.step(i, pbar, **batch, **raw_mini_batches_metric_dict)
|
||||
self.profiler.exit("step")
|
||||
self.buffer = self.buffer[
|
||||
effective_group_to_raw_group_mapping[self.dp_size * self.minibatch_size - 1] + 1 :
|
||||
]
|
||||
# recalculate the effective group to raw group mapping
|
||||
effective_group_to_raw_group_mapping_size_before = len(effective_group_to_raw_group_mapping)
|
||||
effective_group_to_raw_group_mapping = self.calculate_effective_group_to_raw_group_mapping(
|
||||
step=step
|
||||
)
|
||||
assert (
|
||||
len(effective_group_to_raw_group_mapping)
|
||||
== effective_group_to_raw_group_mapping_size_before - self.dp_size * self.minibatch_size
|
||||
)
|
||||
if loss is not None:
|
||||
pbar.set_postfix({"loss": loss})
|
||||
i += 1
|
||||
if self.n_behind == 0:
|
||||
# If n_behind is 0, we start training after receiving data from producers.
|
||||
while len(effective_group_to_raw_group_mapping) >= self.dp_size * self.minibatch_size:
|
||||
self.profiler.log(
|
||||
f"Collect {len(effective_group_to_raw_group_mapping)} effective groups, greater than {self.dp_size * self.minibatch_size}, start training"
|
||||
)
|
||||
batch, raw_mini_batches_metric_dict = self.prepare_mini_batch(
|
||||
effective_group_to_raw_group_mapping
|
||||
)
|
||||
self.profiler.enter("step")
|
||||
loss = self.step(i, pbar, **batch, **raw_mini_batches_metric_dict)
|
||||
self.profiler.exit("step")
|
||||
self.buffer = self.buffer[
|
||||
effective_group_to_raw_group_mapping[self.dp_size * self.minibatch_size - 1] + 1 :
|
||||
]
|
||||
# recalculate the effective group to raw group mapping
|
||||
effective_group_to_raw_group_mapping_size_before = len(
|
||||
effective_group_to_raw_group_mapping
|
||||
)
|
||||
effective_group_to_raw_group_mapping = (
|
||||
self.calculate_effective_group_to_raw_group_mapping(step=step)
|
||||
)
|
||||
assert (
|
||||
len(effective_group_to_raw_group_mapping)
|
||||
== effective_group_to_raw_group_mapping_size_before
|
||||
- self.dp_size * self.minibatch_size
|
||||
)
|
||||
if loss is not None:
|
||||
pbar.set_postfix({"loss": loss})
|
||||
i += 1
|
||||
|
||||
if self.lr_scheduler is not None:
|
||||
self.lr_scheduler.step()
|
||||
@ -314,6 +315,7 @@ class BaseConsumer:
|
||||
torch.cuda.empty_cache()
|
||||
self.profiler.exit("sync_model")
|
||||
self.profiler.log(f"Peak memory usage: {torch.cuda.max_memory_allocated() / 1024 ** 2:.2f} MB")
|
||||
self.profiler.exit(f"rollout_episode_{episode}_step_{step}")
|
||||
|
||||
def __del__(self):
|
||||
if hasattr(self, "profiler"):
|
||||
|
Loading…
Reference in New Issue
Block a user