use consumer global step

This commit is contained in:
YeAnbang 2025-05-15 14:15:40 +08:00
parent 094f119b3a
commit 4ec73298b2
3 changed files with 7 additions and 5 deletions

View File

@ -266,7 +266,6 @@ class GRPOConsumer(BaseConsumer):
total_samples = all_reduce_sum(torch.sum(torch.ones_like(loss_mask, device=loss_mask.device)), self.plugin) total_samples = all_reduce_sum(torch.sum(torch.ones_like(loss_mask, device=loss_mask.device)), self.plugin)
self.effective_sample_count += effective_samples.item() self.effective_sample_count += effective_samples.item()
self.total_sample_count += total_samples.item() self.total_sample_count += total_samples.item()
pbar.set_postfix( pbar.set_postfix(
{ {
"Global Step": self.global_step, "Global Step": self.global_step,
@ -522,7 +521,6 @@ class GRPOConsumer(BaseConsumer):
# All gather excessive prompts index across DP ranks. # All gather excessive prompts index across DP ranks.
excessive_prompts_idx = [idx + self.dp_rank * self.minibatch_size for idx in excessive_prompts_idx] excessive_prompts_idx = [idx + self.dp_rank * self.minibatch_size for idx in excessive_prompts_idx]
excessive_prompts_idx = all_gather_tensors(excessive_prompts_idx, self.plugin) excessive_prompts_idx = all_gather_tensors(excessive_prompts_idx, self.plugin)
return loss_scalar, excessive_prompts_idx return loss_scalar, excessive_prompts_idx
else: else:
return None, excessive_prompts_idx return None, excessive_prompts_idx

View File

@ -56,7 +56,6 @@ def launch_distributed(
eval_save_dir: Optional[str] = None, eval_save_dir: Optional[str] = None,
eval_generation_config: Optional[Dict[str, Any]] = None, eval_generation_config: Optional[Dict[str, Any]] = None,
): ):
if core_algo not in ALGO_MAP: if core_algo not in ALGO_MAP:
raise NotImplementedError(f"{core_algo} is not supported yet.") raise NotImplementedError(f"{core_algo} is not supported yet.")
else: else:

View File

@ -58,6 +58,7 @@ class BaseProducer:
self.microbatch_size = microbatch_size self.microbatch_size = microbatch_size
assert batch_size % microbatch_size == 0 assert batch_size % microbatch_size == 0
self.num_microbatches = batch_size // microbatch_size self.num_microbatches = batch_size // microbatch_size
self.lastest_eval_step = -1
self.train_dataset_config = train_dataset_config self.train_dataset_config = train_dataset_config
self.model_config = model_config self.model_config = model_config
@ -178,12 +179,15 @@ class BaseProducer:
if i >= num_valid_microbatches: if i >= num_valid_microbatches:
break break
if self.eval_interval > 0 and self.eval_dataset_config is not None: if self.eval_interval > 0 and self.eval_dataset_config is not None:
if i % self.eval_interval == 0: if (
self.consumer_global_step % self.eval_interval == 0
and self.consumer_global_step > self.lastest_eval_step
):
to_log_msg = {} to_log_msg = {}
for eval_task_name in self.eval_dataloaders: for eval_task_name in self.eval_dataloaders:
if self.producer_idx == 0: if self.producer_idx == 0:
print( print(
f"[P{self.producer_idx}] Evaluate episode {episode} step {i} on task {eval_task_name}" f"[P{self.producer_idx}] Evaluate episode {episode} step {self.consumer_global_step} on task {eval_task_name}"
) )
eval_results = [] eval_results = []
eval_statistics_tensor = torch.zeros((2,), dtype=torch.float32).to(self.device) eval_statistics_tensor = torch.zeros((2,), dtype=torch.float32).to(self.device)
@ -223,6 +227,7 @@ class BaseProducer:
if self.producer_idx == 0: if self.producer_idx == 0:
self.wandb_run.log(to_log_msg, step=self.consumer_global_step) self.wandb_run.log(to_log_msg, step=self.consumer_global_step)
self.lastest_eval_step = self.consumer_global_step
outputs = self.rollout(**batch) outputs = self.rollout(**batch)
print(f"[P{self.producer_idx}] Send data {[(k, v.shape) for k, v in outputs.items()]}") print(f"[P{self.producer_idx}] Send data {[(k, v.shape) for k, v in outputs.items()]}")