Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-15 22:19:38 +00:00
[hotfix] fixed bugs in ShardStrategy and PcieProfiler (#394)
@@ -5,6 +5,7 @@ import torch.distributed as dist
 from colossalai.zero.shard_utils import BaseShardStrategy
 from colossalai.zero.sharded_model._zero3_utils import get_shard
 from colossalai.zero.sharded_param.sharded_tensor import ShardedTensor
+from colossalai.utils import get_current_device


 class TensorShardStrategy(BaseShardStrategy):
@@ -35,9 +36,9 @@ class TensorShardStrategy(BaseShardStrategy):
         payload_numel = t.payload.numel()
         for i in range(self.world_size):
             if i == self.local_rank:
-                buffer_list.append(t.payload.cuda())
+                buffer_list.append(t.payload.cuda(get_current_device()))
             else:
-                buffer_list.append(torch.zeros(payload_numel, dtype=t.dtype).cuda())
+                buffer_list.append(torch.zeros(payload_numel, dtype=t.dtype, device=get_current_device()))

         torch.distributed.all_gather(buffer_list,
                                      buffer_list[self.local_rank],
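The change replaces the bare .cuda() calls with explicit device placement via get_current_device(), so each rank's gather buffers land on the GPU assigned to that rank rather than on the default cuda:0, and the receive buffers are allocated directly on the target device instead of being created on the CPU and copied over afterwards. The following is a minimal sketch of that pattern, not the actual TensorShardStrategy code; gather_payload is a hypothetical helper name, and it assumes torch.distributed is initialized with a CUDA backend, every rank holds a flat payload shard of identical numel, and get_current_device() (from colossalai.utils) returns the device bound to the current rank.

import torch
import torch.distributed as dist

from colossalai.utils import get_current_device


def gather_payload(payload: torch.Tensor, local_rank: int, world_size: int) -> list:
    """All-gather equally sized, flat payload shards onto the current device.

    Hypothetical helper illustrating the patched buffer-placement pattern.
    """
    device = get_current_device()  # the GPU assigned to this rank, not the default cuda:0
    payload_numel = payload.numel()

    buffer_list = []
    for i in range(world_size):
        if i == local_rank:
            # Move this rank's own shard to its assigned device explicitly.
            buffer_list.append(payload.cuda(device))
        else:
            # Allocate receive buffers directly on the device instead of
            # building them on the CPU and calling .cuda() afterwards.
            buffer_list.append(torch.zeros(payload_numel, dtype=payload.dtype, device=device))

    dist.all_gather(buffer_list, buffer_list[local_rank])
    return buffer_list

Placing every buffer on the rank's own device avoids both the silent pile-up of tensors on cuda:0 in multi-GPU runs and the extra host-to-device copy incurred by allocating on the CPU first.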