update files

This commit is contained in:
Tong Li 2025-04-30 11:33:23 +08:00
parent bb8d370b44
commit 3454b10884

View File

@@ -132,19 +132,21 @@ class BaseProducer:
):
self.model.llm.sleep() # revict KV_cache to avoid OOM
# don't sync model for last iteration
print(
f"[P{self.producer_idx}] Sync model episode {episode} step {(i + 1) // self.num_microbatches - 1}"
)
torch.cuda.empty_cache()
if self.consumer_pp_size > 1:
for i in range(self.consumer_pp_size):
print(f"[P{self.producer_idx}] Sync model PP stage {i}")
print(
f"[P{self.producer_idx}] Sync model PP stage {i} episode {episode} step {(i + 1) // self.num_microbatches - 1}"
)
state_dict = ray_broadcast_tensor_dict(
None, self.num_producers, device=self.device, group_name=f"sync_model_{i}"
)
self.load_state_dict(state_dict)
else:
print(
f"[P{self.producer_idx}] Sync model episode {episode} step {(i + 1) // self.num_microbatches - 1}"
)
state_dict = ray_broadcast_tensor_dict(
None, self.num_producers, device=self.device, group_name="sync_model"
)