update producer

This commit is contained in:
Tong Li 2025-04-30 11:59:44 +08:00
parent 3454b10884
commit e8cc51066b

View File

@ -39,6 +39,7 @@ class BaseProducer:
self.microbatch_size = microbatch_size
assert batch_size % microbatch_size == 0
self.num_microbatches = batch_size // microbatch_size
print("Num micro batches: ", self.num_microbatches)
self.dataset_config = dataset_config
self.model_config = model_config
@ -135,12 +136,12 @@ class BaseProducer:
torch.cuda.empty_cache()
if self.consumer_pp_size > 1:
for i in range(self.consumer_pp_size):
for pp_idx in range(self.consumer_pp_size):
print(
f"[P{self.producer_idx}] Sync model PP stage {i} episode {episode} step {(i + 1) // self.num_microbatches - 1}"
f"[P{self.producer_idx}] Sync model PP stage {pp_idx} episode {episode} step {(i + 1) // self.num_microbatches - 1}"
)
state_dict = ray_broadcast_tensor_dict(
None, self.num_producers, device=self.device, group_name=f"sync_model_{i}"
None, self.num_producers, device=self.device, group_name=f"sync_model_{pp_idx}"
)
self.load_state_dict(state_dict)
else: