mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-07-04 02:57:20 +00:00
use consumer global step
This commit is contained in:
parent
094f119b3a
commit
4ec73298b2
@ -266,7 +266,6 @@ class GRPOConsumer(BaseConsumer):
|
||||
total_samples = all_reduce_sum(torch.sum(torch.ones_like(loss_mask, device=loss_mask.device)), self.plugin)
|
||||
self.effective_sample_count += effective_samples.item()
|
||||
self.total_sample_count += total_samples.item()
|
||||
|
||||
pbar.set_postfix(
|
||||
{
|
||||
"Global Step": self.global_step,
|
||||
@ -522,7 +521,6 @@ class GRPOConsumer(BaseConsumer):
|
||||
# All gather excessive prompts index across DP ranks.
|
||||
excessive_prompts_idx = [idx + self.dp_rank * self.minibatch_size for idx in excessive_prompts_idx]
|
||||
excessive_prompts_idx = all_gather_tensors(excessive_prompts_idx, self.plugin)
|
||||
|
||||
return loss_scalar, excessive_prompts_idx
|
||||
else:
|
||||
return None, excessive_prompts_idx
|
||||
|
@ -56,7 +56,6 @@ def launch_distributed(
|
||||
eval_save_dir: Optional[str] = None,
|
||||
eval_generation_config: Optional[Dict[str, Any]] = None,
|
||||
):
|
||||
|
||||
if core_algo not in ALGO_MAP:
|
||||
raise NotImplementedError(f"{core_algo} is not supported yet.")
|
||||
else:
|
||||
|
@ -58,6 +58,7 @@ class BaseProducer:
|
||||
self.microbatch_size = microbatch_size
|
||||
assert batch_size % microbatch_size == 0
|
||||
self.num_microbatches = batch_size // microbatch_size
|
||||
self.lastest_eval_step = -1
|
||||
|
||||
self.train_dataset_config = train_dataset_config
|
||||
self.model_config = model_config
|
||||
@ -178,12 +179,15 @@ class BaseProducer:
|
||||
if i >= num_valid_microbatches:
|
||||
break
|
||||
if self.eval_interval > 0 and self.eval_dataset_config is not None:
|
||||
if i % self.eval_interval == 0:
|
||||
if (
|
||||
self.consumer_global_step % self.eval_interval == 0
|
||||
and self.consumer_global_step > self.lastest_eval_step
|
||||
):
|
||||
to_log_msg = {}
|
||||
for eval_task_name in self.eval_dataloaders:
|
||||
if self.producer_idx == 0:
|
||||
print(
|
||||
f"[P{self.producer_idx}] Evaluate episode {episode} step {i} on task {eval_task_name}"
|
||||
f"[P{self.producer_idx}] Evaluate episode {episode} step {self.consumer_global_step} on task {eval_task_name}"
|
||||
)
|
||||
eval_results = []
|
||||
eval_statistics_tensor = torch.zeros((2,), dtype=torch.float32).to(self.device)
|
||||
@ -223,6 +227,7 @@ class BaseProducer:
|
||||
|
||||
if self.producer_idx == 0:
|
||||
self.wandb_run.log(to_log_msg, step=self.consumer_global_step)
|
||||
self.lastest_eval_step = self.consumer_global_step
|
||||
outputs = self.rollout(**batch)
|
||||
|
||||
print(f"[P{self.producer_idx}] Send data {[(k, v.shape) for k, v in outputs.items()]}")
|
||||
|
Loading…
Reference in New Issue
Block a user