[shardformer] Unit test (#3928)

* fix bug in slicer, add slicer unit test

* add dropout test

* use pid as dropout seed

* update dropout test with local pattern

* add todo
Author:       FoolPlayer
Date:         2023-06-12 13:56:09 +08:00
Committed by: Frank Lee
Parent:       f1cb5ac6bf
Commit:       a73130482d

4 changed files with 139 additions and 10 deletions


@@ -1,5 +1,4 @@
 import os
-import time
 from contextlib import contextmanager
 
 import torch
@@ -14,7 +13,8 @@ class SeedManager:
 
     def __init__(self):
         original_state = torch.cuda.get_rng_state()
-        seed = int(f"{int(time.time())}{os.environ['RANK']}")
+        # TODO: unify this seed manager with the colossalai.context.random
+        seed = os.getpid()
         torch.cuda.manual_seed(int(seed))
         self.dropout_state = torch.cuda.get_rng_state()
         torch.cuda.set_rng_state(original_state)
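Note on the seed change: the manager saves the current CUDA RNG state, seeds a dedicated state from the process id (so each tensor-parallel rank draws its own dropout mask), and restores the original state so nothing else is perturbed. A minimal standalone sketch of that save/swap/restore pattern, assuming the class swaps `dropout_state` in around each dropout call; the `dropout_rng` helper below is illustrative, not the actual ColossalAI API:

    import os
    from contextlib import contextmanager

    import torch

    # Capture a pid-seeded RNG state without disturbing the ambient CUDA RNG,
    # mirroring __init__ above.
    _saved = torch.cuda.get_rng_state()
    torch.cuda.manual_seed(os.getpid())
    dropout_state = torch.cuda.get_rng_state()
    torch.cuda.set_rng_state(_saved)

    @contextmanager
    def dropout_rng():
        global dropout_state
        outer = torch.cuda.get_rng_state()
        torch.cuda.set_rng_state(dropout_state)    # swap in the pid-seeded state
        try:
            yield
        finally:
            dropout_state = torch.cuda.get_rng_state()    # advance the dropout stream
            torch.cuda.set_rng_state(outer)               # restore the outer stream

    # Masks differ across processes because each seed is that process's pid.
    with dropout_rng():
        out = torch.nn.functional.dropout(torch.ones(8, device='cuda'), p=0.5)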


@@ -3,7 +3,7 @@ import torch
 
 from ..policies.basepolicy import Col_Layer, Layer, Row_Layer
 from .shard_config import ShardConfig
 
-dim_mapping = {Col_Layer: 1, Row_Layer: 0}
+dim_mapping = {Col_Layer: 0, Row_Layer: 1}
 
 class Slicer():
@@ -40,7 +40,7 @@ class Slicer():
             # print(weight.shape, dim)
             if policy_layer_cls == Col_Layer:
                 weight = self.slice_tensor(weight, dim, False, n_cast)
-                bias = self.slice_tensor(bias, 0, True)
+                bias = self.slice_tensor(bias, 0, True, n_cast)
             elif policy_layer_cls == Row_Layer:
                 weight = self.slice_tensor(weight, dim, False, n_cast)
             else:
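The `dim_mapping` fix matches PyTorch's `nn.Linear`, which stores its weight as `[out_features, in_features]`: a column-parallel layer shards dim 0 (output features) along with its bias, while a row-parallel layer shards dim 1 (input features) and leaves the bias whole so it can be added after the all-reduce. Passing `n_cast` through to the bias slice keeps fused weights, e.g. a concatenated QKV projection, aligned with their biases. A quick standalone check of that invariant (hypothetical shapes; two ranks, three fused projections; `interleaved_slice` is an illustrative helper, not part of the Slicer API):

    import torch

    world_size, n_cast = 2, 3      # two ranks, three fused blocks (e.g. Q, K, V)
    weight = torch.randn(12, 4)    # [out_features, in_features]
    bias = torch.randn(12)

    def interleaved_slice(t, rank, dim=0):
        # chunk into world_size * n_cast pieces, keep every world_size-th one
        chunks = t.chunk(world_size * n_cast, dim=dim)
        return torch.cat(chunks[rank::world_size], dim=dim)

    for rank in range(world_size):
        w = interleaved_slice(weight, rank)
        b = interleaved_slice(bias, rank)
        assert w.shape[0] == b.shape[0]    # bias entries stay aligned with weight rows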
@@ -129,13 +129,13 @@ class Slicer():
         """
         if n_cast is None:
-            return tensor.chunk(self.shardconfig.world_size, dim=0)[self.shardconfig.rank].contiguous()
+            return tensor.chunk(self.shardconfig.world_size, dim=1)[self.shardconfig.rank].contiguous()
         else:
-            tensor_chunks = tensor.chunk(self.shardconfig.world_size * n_cast, dim=0)
+            tensor_chunks = tensor.chunk(self.shardconfig.world_size * n_cast, dim=1)
             chunk_list = [
                 tensor_chunks[i] for i in range(self.shardconfig.rank, len(tensor_chunks), self.shardconfig.world_size)
             ]
-            return torch.cat(chunk_list, dim=0).contiguous()
+            return torch.cat(chunk_list, dim=1).contiguous()
 
     def slice_row(
         self,
@@ -152,10 +152,10 @@ class Slicer():
             :class:`torch.Tensor`: The sliced tensor
         """
         if n_cast is None:
-            return tensor.chunk(self.shardconfig.world_size, dim=1)[self.shardconfig.rank].contiguous()
+            return tensor.chunk(self.shardconfig.world_size, dim=0)[self.shardconfig.rank].contiguous()
         else:
-            tensor_chunks = tensor.chunk(self.shardconfig.world_size * n_cast, dim=1)
+            tensor_chunks = tensor.chunk(self.shardconfig.world_size * n_cast, dim=0)
             chunk_list = [
                 tensor_chunks[i] for i in range(self.shardconfig.rank, len(tensor_chunks), self.shardconfig.world_size)
             ]
-            return torch.cat(chunk_list, dim=1).contiguous()
+            return torch.cat(chunk_list, dim=0).contiguous()
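To sanity-check the corrected dims, the interleaved `n_cast` path can be exercised on CPU: each rank should receive 1/world_size of the tensor, and the shards taken together should cover every chunk exactly once. A small sketch along these lines (hypothetical sizes; this mirrors the corrected logic above, not the unit test added by this commit):

    import torch

    world_size, n_cast = 4, 2
    tensor = torch.arange(world_size * n_cast * 2)    # 16 elements along dim 0

    def slice_interleaved(tensor, rank, dim=0):
        # chunk into world_size * n_cast pieces, take every world_size-th
        # chunk starting at this rank, then re-concatenate
        chunks = tensor.chunk(world_size * n_cast, dim=dim)
        picked = [chunks[i] for i in range(rank, len(chunks), world_size)]
        return torch.cat(picked, dim=dim).contiguous()

    shards = [slice_interleaved(tensor, r) for r in range(world_size)]
    assert all(s.numel() == tensor.numel() // world_size for s in shards)
    assert sorted(torch.cat(shards).tolist()) == tensor.tolist()    # full coverage, no overlap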