use consumer global step

This commit is contained in:
YeAnbang 2025-05-15 14:15:40 +08:00
parent 094f119b3a
commit 4ec73298b2
3 changed files with 7 additions and 5 deletions

View File

@ -266,7 +266,6 @@ class GRPOConsumer(BaseConsumer):
total_samples = all_reduce_sum(torch.sum(torch.ones_like(loss_mask, device=loss_mask.device)), self.plugin)
self.effective_sample_count += effective_samples.item()
self.total_sample_count += total_samples.item()
pbar.set_postfix(
{
"Global Step": self.global_step,
@ -522,7 +521,6 @@ class GRPOConsumer(BaseConsumer):
# All gather excessive prompts index across DP ranks.
excessive_prompts_idx = [idx + self.dp_rank * self.minibatch_size for idx in excessive_prompts_idx]
excessive_prompts_idx = all_gather_tensors(excessive_prompts_idx, self.plugin)
return loss_scalar, excessive_prompts_idx
else:
return None, excessive_prompts_idx

View File

@ -56,7 +56,6 @@ def launch_distributed(
eval_save_dir: Optional[str] = None,
eval_generation_config: Optional[Dict[str, Any]] = None,
):
if core_algo not in ALGO_MAP:
raise NotImplementedError(f"{core_algo} is not supported yet.")
else:

View File

@ -58,6 +58,7 @@ class BaseProducer:
self.microbatch_size = microbatch_size
assert batch_size % microbatch_size == 0
self.num_microbatches = batch_size // microbatch_size
self.lastest_eval_step = -1
self.train_dataset_config = train_dataset_config
self.model_config = model_config
@ -178,12 +179,15 @@ class BaseProducer:
if i >= num_valid_microbatches:
break
if self.eval_interval > 0 and self.eval_dataset_config is not None:
if i % self.eval_interval == 0:
if (
self.consumer_global_step % self.eval_interval == 0
and self.consumer_global_step > self.lastest_eval_step
):
to_log_msg = {}
for eval_task_name in self.eval_dataloaders:
if self.producer_idx == 0:
print(
f"[P{self.producer_idx}] Evaluate episode {episode} step {i} on task {eval_task_name}"
f"[P{self.producer_idx}] Evaluate episode {episode} step {self.consumer_global_step} on task {eval_task_name}"
)
eval_results = []
eval_statistics_tensor = torch.zeros((2,), dtype=torch.float32).to(self.device)
@ -223,6 +227,7 @@ class BaseProducer:
if self.producer_idx == 0:
self.wandb_run.log(to_log_msg, step=self.consumer_global_step)
self.lastest_eval_step = self.consumer_global_step
outputs = self.rollout(**batch)
print(f"[P{self.producer_idx}] Send data {[(k, v.shape) for k, v in outputs.items()]}")