Mirror of https://github.com/hpcaitech/ColossalAI.git
handle empty index
This commit is contained in:
parent 88f49ddc5e
commit 6ebd813b5f
@@ -113,7 +113,6 @@ class BaseConsumer:
             ) as pbar:
                 for step in pbar:
                     i = 0
-                    allow_sync_model = True
                     for _ in range(self.num_recv_per_update):
                         # receive data from producers
                         for r in range(self.num_producers):
@@ -140,7 +139,6 @@ class BaseConsumer:
                         loss = self.step(i, pbar, **batch)
                         self.buffer = self.buffer[self.dp_size * self.minibatch_size :]
                         if loss is not None:
-                            allow_sync_model = True
                             pbar.set_postfix({"loss": loss})
                         i += 1
                 if self.lr_scheduler is not None:
@@ -154,7 +152,6 @@ class BaseConsumer:
                     print(f"Saved model checkpoint at step {step + 1} in folder {save_path}")

                 if episode != self.num_episodes - 1 or step != self.num_update_per_episode - 1:
-                    if allow_sync_model:
                         if self.pp_size > 1:
                             print(
                                 f"[T{dist.get_rank()}] Sync model PP stage {self.pp_rank} episode {episode} step {step}"
@@ -178,7 +175,6 @@ class BaseConsumer:
                             )
                         del state_dict
                         torch.cuda.empty_cache()
-                        allow_sync_model = True


 @ray.remote
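Note on the four BaseConsumer hunks above: each one deletes the allow_sync_model flag, so after this commit the weight sync at the end of an update step runs unconditionally for every non-final (episode, step) pair. Below is a minimal, self-contained sketch of the resulting control flow; NUM_EPISODES, NUM_UPDATE_PER_EPISODE, and sync_model() are hypothetical stand-ins for the attributes and state_dict broadcast elided from the hunks, not repository code.

NUM_EPISODES = 2
NUM_UPDATE_PER_EPISODE = 3

def sync_model() -> None:
    # Hypothetical stand-in for broadcasting the consumer's state_dict to producers.
    print("sync model")

for episode in range(NUM_EPISODES):
    for step in range(NUM_UPDATE_PER_EPISODE):
        # ... receive rollouts from producers and run training minibatches ...
        if episode != NUM_EPISODES - 1 or step != NUM_UPDATE_PER_EPISODE - 1:
            sync_model()  # previously guarded by `if allow_sync_model:`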
@@ -218,7 +218,29 @@ class GRPOConsumer(BaseConsumer):
         if self.grpo_config.get("dynamic_batching", True):
             need_update = self.effective_prompt_count >= self.batch_size * self.dp_size
             excessive_prompts = self.effective_prompt_count - self.batch_size * self.dp_size
-            assert excessive_prompts <= 0, "Debug: Excessive prompts should always be less than 0. Bug!!!!"
+
+            if excessive_prompts > 0:
+                excessive_prompts_per_rank = excessive_prompts // self.dp_size
+                # Only count excessive prompts if they are greater than 1 per rank.
+                # TODO: customize excessive prompts calculation.
+                if excessive_prompts_per_rank != 0:
+                    # Mask excessive prompts to False
+                    true_indices = torch.nonzero(effective_prompts_mask)
+                    # Make sure the indices are not empty.
+                    if true_indices.numel() > 0:
+                        true_indices = true_indices.squeeze()
+                        if excessive_prompts_per_rank <= len(true_indices):
+                            excessive_prompts_idx = true_indices[-excessive_prompts_per_rank:]
+                        else:
+                            excessive_prompts_idx = true_indices
+                        effective_prompts_mask[excessive_prompts_idx] = False
+
+                        for mask_idx in range(len(effective_prompts_mask)):
+                            if effective_prompts_mask[mask_idx] == False:
+                                # Update loss mask.
+                                loss_mask[mask_idx] = False
+            else:
+                excessive_prompts_idx = torch.empty([0])
         else:
             # If dynamic batching is disabled, we need to use all samples for training.
             need_update = (step_idx + 1) % self.num_microbatches == 0
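The GRPOConsumer hunk is where the commit title applies: torch.nonzero on an all-False mask returns an empty index tensor, and the deleted assert offered no protection before the code squeezed and sliced it. Below is a standalone demonstration of the edge case the new true_indices.numel() > 0 guard covers; it is illustrative code with a hypothetical excessive_per_rank value, not the repository source.

import torch

# All prompts already masked out: nonzero() finds no True entries.
mask = torch.tensor([False, False, False, False])
idx = torch.nonzero(mask)  # shape (0, 1): an empty index tensor
print(idx.numel())         # 0 -> the numel() guard skips the masking branch

# With at least one True entry, the keep-last-k masking works as intended.
mask = torch.tensor([True, False, True, True])
idx = torch.nonzero(mask).squeeze()  # tensor([0, 2, 3])
excessive_per_rank = 2               # hypothetical value for illustration
drop = idx[-excessive_per_rank:]     # last k kept prompts: tensor([2, 3])
mask[drop] = False
print(mask)                          # tensor([ True, False, False, False])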