mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-08-15 14:43:13 +00:00
update files
This commit is contained in:
parent
bb8d370b44
commit
3454b10884
@ -132,19 +132,21 @@ class BaseProducer:
|
|||||||
):
|
):
|
||||||
self.model.llm.sleep() # revict KV_cache to avoid OOM
|
self.model.llm.sleep() # revict KV_cache to avoid OOM
|
||||||
# don't sync model for last iteration
|
# don't sync model for last iteration
|
||||||
print(
|
|
||||||
f"[P{self.producer_idx}] Sync model episode {episode} step {(i + 1) // self.num_microbatches - 1}"
|
|
||||||
)
|
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
if self.consumer_pp_size > 1:
|
if self.consumer_pp_size > 1:
|
||||||
for i in range(self.consumer_pp_size):
|
for i in range(self.consumer_pp_size):
|
||||||
print(f"[P{self.producer_idx}] Sync model PP stage {i}")
|
print(
|
||||||
|
f"[P{self.producer_idx}] Sync model PP stage {i} episode {episode} step {(i + 1) // self.num_microbatches - 1}"
|
||||||
|
)
|
||||||
state_dict = ray_broadcast_tensor_dict(
|
state_dict = ray_broadcast_tensor_dict(
|
||||||
None, self.num_producers, device=self.device, group_name=f"sync_model_{i}"
|
None, self.num_producers, device=self.device, group_name=f"sync_model_{i}"
|
||||||
)
|
)
|
||||||
self.load_state_dict(state_dict)
|
self.load_state_dict(state_dict)
|
||||||
else:
|
else:
|
||||||
|
print(
|
||||||
|
f"[P{self.producer_idx}] Sync model episode {episode} step {(i + 1) // self.num_microbatches - 1}"
|
||||||
|
)
|
||||||
state_dict = ray_broadcast_tensor_dict(
|
state_dict = ray_broadcast_tensor_dict(
|
||||||
None, self.num_producers, device=self.device, group_name="sync_model"
|
None, self.num_producers, device=self.device, group_name="sync_model"
|
||||||
)
|
)
|
||||||
|
Loading…
Reference in New Issue
Block a user