mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-07-17 00:41:39 +00:00
polish code
This commit is contained in:
parent
54fd37f0e0
commit
63469c0f91
@ -23,6 +23,9 @@ class BucketTensorShardStrategy(TensorShardStrategy):
|
|||||||
for i in range(self.world_size):
|
for i in range(self.world_size):
|
||||||
if i == self.local_rank:
|
if i == self.local_rank:
|
||||||
buffer_list.append(flatten([t.payload for t in tensor_list]).cuda(get_current_device()))
|
buffer_list.append(flatten([t.payload for t in tensor_list]).cuda(get_current_device()))
|
||||||
|
# Release payload here, to decrease peak memory usage
|
||||||
|
for t in tensor_list:
|
||||||
|
t.reset_payload(None)
|
||||||
else:
|
else:
|
||||||
buffer_list.append(torch.zeros(buffer_size, dtype=dtype, device=get_current_device()))
|
buffer_list.append(torch.zeros(buffer_size, dtype=dtype, device=get_current_device()))
|
||||||
dist.all_gather(buffer_list, buffer_list[self.local_rank], group=self.process_group)
|
dist.all_gather(buffer_list, buffer_list[self.local_rank], group=self.process_group)
|
||||||
|
Loading…
Reference in New Issue
Block a user