Mirror of https://github.com/hpcaitech/ColossalAI.git
use consumer global step
This commit is contained in:
parent 094f119b3a
commit 4ec73298b2
@@ -266,7 +266,6 @@ class GRPOConsumer(BaseConsumer):
             total_samples = all_reduce_sum(torch.sum(torch.ones_like(loss_mask, device=loss_mask.device)), self.plugin)
             self.effective_sample_count += effective_samples.item()
             self.total_sample_count += total_samples.item()

             pbar.set_postfix(
                 {
                     "Global Step": self.global_step,
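For context, the counters in this hunk accumulate effective (non-padding) and total sample counts summed over all data-parallel ranks. A minimal sketch of that counting, using plain torch.distributed in place of ColossalAI's all_reduce_sum(tensor, plugin) helper; the toy loss_mask and the single-process fallback are assumptions for illustration, not the source's implementation:

# Hedged sketch: cross-rank sample counting with a plain all-reduce.
import torch
import torch.distributed as dist

def all_reduce_sum_sketch(value: torch.Tensor) -> torch.Tensor:
    """Sum a tensor across all ranks (no-op when not running distributed)."""
    if dist.is_available() and dist.is_initialized():
        dist.all_reduce(value, op=dist.ReduceOp.SUM)
    return value

# loss_mask: 1.0 where a token contributes to the loss, 0.0 where it is padding.
loss_mask = torch.tensor([[1.0, 1.0, 0.0], [1.0, 0.0, 0.0]])
effective_samples = all_reduce_sum_sketch(torch.sum(loss_mask))
total_samples = all_reduce_sum_sketch(torch.sum(torch.ones_like(loss_mask)))
print(effective_samples.item(), total_samples.item())  # 3.0 6.0 on a single rank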
@@ -522,7 +521,6 @@ class GRPOConsumer(BaseConsumer):
             # All gather excessive prompts index across DP ranks.
             excessive_prompts_idx = [idx + self.dp_rank * self.minibatch_size for idx in excessive_prompts_idx]
             excessive_prompts_idx = all_gather_tensors(excessive_prompts_idx, self.plugin)

             return loss_scalar, excessive_prompts_idx
         else:
             return None, excessive_prompts_idx
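The index arithmetic in this hunk converts rank-local minibatch positions into positions in the globally concatenated batch before the gather. A toy sketch of that offset, with a plain dict standing in for the per-rank state and for all_gather_tensors (both stand-ins are assumptions for illustration):

# Hedged sketch: each DP rank holds indices local to its own minibatch,
# so a global index is local_idx + dp_rank * minibatch_size.
minibatch_size = 4
local_excessive_idx = {0: [1, 3], 1: [0]}  # toy per-rank local indices
gathered = []
for dp_rank, idxs in local_excessive_idx.items():  # stands in for the all-gather
    gathered.extend(idx + dp_rank * minibatch_size for idx in idxs)
print(gathered)  # [1, 3, 4] -- positions in the concatenated global batch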
@@ -56,7 +56,6 @@ def launch_distributed(
     eval_save_dir: Optional[str] = None,
     eval_generation_config: Optional[Dict[str, Any]] = None,
 ):

     if core_algo not in ALGO_MAP:
         raise NotImplementedError(f"{core_algo} is not supported yet.")
     else:
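The launch_distributed hunk keeps the existing fail-fast registry lookup: core_algo must be a key of ALGO_MAP. A hedged sketch of that pattern, where the map entries are toy stand-ins rather than the real consumer classes:

# Toy stand-ins; the real ALGO_MAP maps algorithm names to consumer classes.
ALGO_MAP = {"GRPO": object, "DAPO": object}

core_algo = "PPO"
if core_algo not in ALGO_MAP:
    raise NotImplementedError(f"{core_algo} is not supported yet.")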
@@ -58,6 +58,7 @@ class BaseProducer:
         self.microbatch_size = microbatch_size
         assert batch_size % microbatch_size == 0
         self.num_microbatches = batch_size // microbatch_size
+        self.lastest_eval_step = -1

         self.train_dataset_config = train_dataset_config
         self.model_config = model_config
@@ -178,12 +179,15 @@ class BaseProducer:
             if i >= num_valid_microbatches:
                 break
             if self.eval_interval > 0 and self.eval_dataset_config is not None:
-                if i % self.eval_interval == 0:
+                if (
+                    self.consumer_global_step % self.eval_interval == 0
+                    and self.consumer_global_step > self.lastest_eval_step
+                ):
                     to_log_msg = {}
                     for eval_task_name in self.eval_dataloaders:
                         if self.producer_idx == 0:
                             print(
-                                f"[P{self.producer_idx}] Evaluate episode {episode} step {i} on task {eval_task_name}"
+                                f"[P{self.producer_idx}] Evaluate episode {episode} step {self.consumer_global_step} on task {eval_task_name}"
                             )
                         eval_results = []
                         eval_statistics_tensor = torch.zeros((2,), dtype=torch.float32).to(self.device)
@@ -223,6 +227,7 @@ class BaseProducer:

                 if self.producer_idx == 0:
                     self.wandb_run.log(to_log_msg, step=self.consumer_global_step)
+                self.lastest_eval_step = self.consumer_global_step
             outputs = self.rollout(**batch)

             print(f"[P{self.producer_idx}] Send data {[(k, v.shape) for k, v in outputs.items()]}")
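Taken together, the producer hunks replace the old evaluation trigger (the producer-local microbatch index i) with the consumer's global optimizer step, and lastest_eval_step (initialized to -1 in the @@ -58 hunk, updated after logging in the hunk above) ensures each consumer step is evaluated at most once even when the producer loop iterates faster than the consumer step advances. A self-contained sketch of that gating; the class name and driver loop are illustrative, not from the source:

# Hedged sketch of the new eval gating keyed on the consumer's global step.
class EvalGate:
    def __init__(self, eval_interval: int):
        self.eval_interval = eval_interval
        self.lastest_eval_step = -1  # matches the field name in the diff

    def should_eval(self, consumer_global_step: int) -> bool:
        # Evaluate only on interval boundaries, and only if this step
        # has not been evaluated already.
        return (
            consumer_global_step % self.eval_interval == 0
            and consumer_global_step > self.lastest_eval_step
        )

    def mark_done(self, consumer_global_step: int) -> None:
        self.lastest_eval_step = consumer_global_step

gate = EvalGate(eval_interval=2)
# The consumer step may repeat across producer iterations; each boundary
# step (0, 2, 4) triggers evaluation exactly once.
for step in [0, 0, 1, 2, 2, 3, 4]:
    if gate.should_eval(step):
        print(f"evaluate at consumer step {step}")
        gate.mark_done(step)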