[hotfix] fixed bugs in ShardStrategy and PcieProfiler (#394)

Author: HELSON
Committed by: GitHub
Date: 2022-03-11 18:12:46 +08:00
Parent: 1e4bf85cdb
Commit: 7c079d9c33
2 changed files with 4 additions and 3 deletions

@@ -5,6 +5,7 @@ import torch.distributed as dist
 from colossalai.zero.shard_utils import BaseShardStrategy
 from colossalai.zero.sharded_model._zero3_utils import get_shard
 from colossalai.zero.sharded_param.sharded_tensor import ShardedTensor
+from colossalai.utils import get_current_device


 class TensorShardStrategy(BaseShardStrategy):
@@ -35,9 +36,9 @@ class TensorShardStrategy(BaseShardStrategy):
         payload_numel = t.payload.numel()
         for i in range(self.world_size):
             if i == self.local_rank:
-                buffer_list.append(t.payload.cuda())
+                buffer_list.append(t.payload.cuda(get_current_device()))
             else:
-                buffer_list.append(torch.zeros(payload_numel, dtype=t.dtype).cuda())
+                buffer_list.append(torch.zeros(payload_numel, dtype=t.dtype, device=get_current_device()))
         torch.distributed.all_gather(buffer_list,
                                      buffer_list[self.local_rank],
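The bug this hunk fixes: a bare `.cuda()` (and `torch.zeros(...).cuda()`) moves the tensor to whatever CUDA device is current, which is `cuda:0` unless the process has explicitly set its device, so in a multi-GPU run every rank's gather buffers can pile up on one GPU. Passing the device from `get_current_device()` pins each buffer to the rank's own GPU before `all_gather`. Below is a minimal, self-contained sketch of the corrected pattern; `current_device()` is a hypothetical stand-in for `colossalai.utils.get_current_device` so the snippet runs without ColossalAI installed.

```python
import torch


def current_device() -> torch.device:
    # Hypothetical stand-in for colossalai.utils.get_current_device:
    # resolve the CUDA device this process should be using.
    return torch.device(f"cuda:{torch.cuda.current_device()}")


def gather_buffers(payload: torch.Tensor, world_size: int, local_rank: int):
    """Build the buffer list for all_gather, pinning every buffer to this
    rank's device instead of relying on the implicit CUDA default."""
    payload_numel = payload.numel()
    buffer_list = []
    for i in range(world_size):
        if i == local_rank:
            # Before the fix this was payload.cuda(), which lands on the
            # *current* device (cuda:0 unless set_device was called).
            buffer_list.append(payload.cuda(current_device()))
        else:
            # Allocate the receive buffer directly on the right device
            # rather than allocating on the default device and moving it.
            buffer_list.append(
                torch.zeros(payload_numel, dtype=payload.dtype,
                            device=current_device())
            )
    return buffer_list


if torch.cuda.is_available():
    bufs = gather_buffers(torch.randn(4), world_size=2, local_rank=0)
    print([b.device for b in bufs])  # every buffer on this rank's device
```

Allocating with `device=...` up front also avoids the transient host/default-device allocation that `torch.zeros(...).cuda()` incurs, which is why the fix rewrites the zeros call rather than just appending `.cuda(get_current_device())`.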