handle empty index

This commit is contained in:
Tong Li 2025-05-15 18:30:27 +08:00 committed by YeAnbang
parent 957e3a521a
commit 1644adf684
2 changed files with 37 additions and 36 deletions

View File

@ -113,7 +113,6 @@ class BaseConsumer:
) as pbar: ) as pbar:
for step in pbar: for step in pbar:
i = 0 i = 0
allow_sync_model = False
for _ in range(self.num_recv_per_update): for _ in range(self.num_recv_per_update):
# receive data from producers # receive data from producers
for r in range(self.num_producers): for r in range(self.num_producers):
@ -139,7 +138,6 @@ class BaseConsumer:
else: else:
self.buffer = self.buffer[self.dp_size * self.minibatch_size :] self.buffer = self.buffer[self.dp_size * self.minibatch_size :]
if loss is not None: if loss is not None:
allow_sync_model = True
pbar.set_postfix({"loss": loss}) pbar.set_postfix({"loss": loss})
i += 1 i += 1
if self.lr_scheduler is not None: if self.lr_scheduler is not None:
@ -153,31 +151,29 @@ class BaseConsumer:
print(f"Saved model checkpoint at step {step + 1} in folder {save_path}") print(f"Saved model checkpoint at step {step + 1} in folder {save_path}")
if episode != self.num_episodes - 1 or step != self.num_update_per_episode - 1: if episode != self.num_episodes - 1 or step != self.num_update_per_episode - 1:
if allow_sync_model: if self.pp_size > 1:
if self.pp_size > 1: print(
print( f"[T{dist.get_rank()}] Sync model PP stage {self.pp_rank} episode {episode} step {step}"
f"[T{dist.get_rank()}] Sync model PP stage {self.pp_rank} episode {episode} step {step}" )
else:
print(f"[T{dist.get_rank()}] Sync model episode {episode} step {step}")
torch.cuda.empty_cache()
state_dict = self.state_dict()
if self.pp_size > 1:
if self.tp_rank == 0 and self.dp_rank == 0:
ray_broadcast_tensor_dict(
state_dict,
src=self.num_producers,
device=self.device,
group_name=f"sync_model_{self.pp_rank}",
) )
else: else:
print(f"[T{dist.get_rank()}] Sync model episode {episode} step {step}") if self.rank == 0:
torch.cuda.empty_cache() ray_broadcast_tensor_dict(
state_dict = self.state_dict() state_dict, src=self.num_producers, device=self.device, group_name="sync_model"
if self.pp_size > 1: )
if self.tp_rank == 0 and self.dp_rank == 0: del state_dict
ray_broadcast_tensor_dict( torch.cuda.empty_cache()
state_dict,
src=self.num_producers,
device=self.device,
group_name=f"sync_model_{self.pp_rank}",
)
else:
if self.rank == 0:
ray_broadcast_tensor_dict(
state_dict, src=self.num_producers, device=self.device, group_name="sync_model"
)
del state_dict
torch.cuda.empty_cache()
allow_sync_model = False
@ray.remote @ray.remote

View File

@ -245,17 +245,22 @@ class GRPOConsumer(BaseConsumer):
# TODO: customize excessive prompts calculation. # TODO: customize excessive prompts calculation.
if excessive_prompts_per_rank != 0: if excessive_prompts_per_rank != 0:
# Mask excessive prompts to False # Mask excessive prompts to False
true_indices = torch.nonzero(effective_prompts_mask).squeeze() true_indices = torch.nonzero(effective_prompts_mask)
if excessive_prompts_per_rank <= len(true_indices): # Make sure the indices are not empty.
excessive_prompts_idx = true_indices[-excessive_prompts_per_rank:] if true_indices.numel() > 0:
else: true_indices = true_indices.squeeze()
excessive_prompts_idx = true_indices if excessive_prompts_per_rank <= len(true_indices):
effective_prompts_mask[excessive_prompts_idx] = False excessive_prompts_idx = true_indices[-excessive_prompts_per_rank:]
else:
excessive_prompts_idx = true_indices
effective_prompts_mask[excessive_prompts_idx] = False
for mask_idx in range(len(effective_prompts_mask)): for mask_idx in range(len(effective_prompts_mask)):
if effective_prompts_mask[mask_idx] == False: if effective_prompts_mask[mask_idx] == False:
# Update loss mask. # Update loss mask.
loss_mask[mask_idx] = False loss_mask[mask_idx] = False
else:
excessive_prompts_idx = torch.empty([0])
else: else:
# If dynamic batching is disabled, we need to use all samples for training. # If dynamic batching is disabled, we need to use all samples for training.
need_update = (step_idx + 1) % self.num_microbatches == 0 need_update = (step_idx + 1) % self.num_microbatches == 0