mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-06-20 20:54:55 +00:00
handle empty index
This commit is contained in:
parent
957e3a521a
commit
1644adf684
@ -113,7 +113,6 @@ class BaseConsumer:
|
||||
) as pbar:
|
||||
for step in pbar:
|
||||
i = 0
|
||||
allow_sync_model = False
|
||||
for _ in range(self.num_recv_per_update):
|
||||
# receive data from producers
|
||||
for r in range(self.num_producers):
|
||||
@ -139,7 +138,6 @@ class BaseConsumer:
|
||||
else:
|
||||
self.buffer = self.buffer[self.dp_size * self.minibatch_size :]
|
||||
if loss is not None:
|
||||
allow_sync_model = True
|
||||
pbar.set_postfix({"loss": loss})
|
||||
i += 1
|
||||
if self.lr_scheduler is not None:
|
||||
@ -153,31 +151,29 @@ class BaseConsumer:
|
||||
print(f"Saved model checkpoint at step {step + 1} in folder {save_path}")
|
||||
|
||||
if episode != self.num_episodes - 1 or step != self.num_update_per_episode - 1:
|
||||
if allow_sync_model:
|
||||
if self.pp_size > 1:
|
||||
print(
|
||||
f"[T{dist.get_rank()}] Sync model PP stage {self.pp_rank} episode {episode} step {step}"
|
||||
if self.pp_size > 1:
|
||||
print(
|
||||
f"[T{dist.get_rank()}] Sync model PP stage {self.pp_rank} episode {episode} step {step}"
|
||||
)
|
||||
else:
|
||||
print(f"[T{dist.get_rank()}] Sync model episode {episode} step {step}")
|
||||
torch.cuda.empty_cache()
|
||||
state_dict = self.state_dict()
|
||||
if self.pp_size > 1:
|
||||
if self.tp_rank == 0 and self.dp_rank == 0:
|
||||
ray_broadcast_tensor_dict(
|
||||
state_dict,
|
||||
src=self.num_producers,
|
||||
device=self.device,
|
||||
group_name=f"sync_model_{self.pp_rank}",
|
||||
)
|
||||
else:
|
||||
print(f"[T{dist.get_rank()}] Sync model episode {episode} step {step}")
|
||||
torch.cuda.empty_cache()
|
||||
state_dict = self.state_dict()
|
||||
if self.pp_size > 1:
|
||||
if self.tp_rank == 0 and self.dp_rank == 0:
|
||||
ray_broadcast_tensor_dict(
|
||||
state_dict,
|
||||
src=self.num_producers,
|
||||
device=self.device,
|
||||
group_name=f"sync_model_{self.pp_rank}",
|
||||
)
|
||||
else:
|
||||
if self.rank == 0:
|
||||
ray_broadcast_tensor_dict(
|
||||
state_dict, src=self.num_producers, device=self.device, group_name="sync_model"
|
||||
)
|
||||
del state_dict
|
||||
torch.cuda.empty_cache()
|
||||
allow_sync_model = False
|
||||
else:
|
||||
if self.rank == 0:
|
||||
ray_broadcast_tensor_dict(
|
||||
state_dict, src=self.num_producers, device=self.device, group_name="sync_model"
|
||||
)
|
||||
del state_dict
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
|
||||
@ray.remote
|
||||
|
@ -245,17 +245,22 @@ class GRPOConsumer(BaseConsumer):
|
||||
# TODO: customize excessive prompts calculation.
|
||||
if excessive_prompts_per_rank != 0:
|
||||
# Mask excessive prompts to False
|
||||
true_indices = torch.nonzero(effective_prompts_mask).squeeze()
|
||||
if excessive_prompts_per_rank <= len(true_indices):
|
||||
excessive_prompts_idx = true_indices[-excessive_prompts_per_rank:]
|
||||
else:
|
||||
excessive_prompts_idx = true_indices
|
||||
effective_prompts_mask[excessive_prompts_idx] = False
|
||||
true_indices = torch.nonzero(effective_prompts_mask)
|
||||
# Make sure the indices are not empty.
|
||||
if true_indices.numel() > 0:
|
||||
true_indices = true_indices.squeeze()
|
||||
if excessive_prompts_per_rank <= len(true_indices):
|
||||
excessive_prompts_idx = true_indices[-excessive_prompts_per_rank:]
|
||||
else:
|
||||
excessive_prompts_idx = true_indices
|
||||
effective_prompts_mask[excessive_prompts_idx] = False
|
||||
|
||||
for mask_idx in range(len(effective_prompts_mask)):
|
||||
if effective_prompts_mask[mask_idx] == False:
|
||||
# Update loss mask.
|
||||
loss_mask[mask_idx] = False
|
||||
for mask_idx in range(len(effective_prompts_mask)):
|
||||
if effective_prompts_mask[mask_idx] == False:
|
||||
# Update loss mask.
|
||||
loss_mask[mask_idx] = False
|
||||
else:
|
||||
excessive_prompts_idx = torch.empty([0])
|
||||
else:
|
||||
# If dynamic batching is disabled, we need to use all samples for training.
|
||||
need_update = (step_idx + 1) % self.num_microbatches == 0
|
||||
|
Loading…
Reference in New Issue
Block a user