Mirror of https://github.com/hpcaitech/ColossalAI.git
[shardformer] Unit test (#3928)
* fix bug in slicer, add slicer unit test
* add dropout test
* use pid as dropout seed
* update dropout test with local pattern
* add todo
@@ -1,5 +1,4 @@
 import os
-import time
 from contextlib import contextmanager

 import torch
@@ -14,7 +13,8 @@ class SeedManager:

     def __init__(self):
         original_state = torch.cuda.get_rng_state()
-        seed = int(f"{int(time.time())}{os.environ['RANK']}")
+        # TODO: unify this seed manager with the colossalai.context.random
+        seed = os.getpid()
         torch.cuda.manual_seed(int(seed))
         self.dropout_state = torch.cuda.get_rng_state()
         torch.cuda.set_rng_state(original_state)
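The change above seeds the dropout-specific RNG state from the process id instead of a timestamp-plus-rank string, so each process gets its own dropout mask sequence while the default CUDA RNG state is left untouched. As a rough sketch of how such a manager is typically used (the names SeedManagerSketch and dropout_mode below are illustrative, not the repository's API, and a CUDA device is assumed), the saved dropout state is swapped in only while dropout runs:

import os
from contextlib import contextmanager

import torch


class SeedManagerSketch:
    """Keeps a dedicated CUDA RNG state for dropout, seeded from the pid."""

    def __init__(self):
        original_state = torch.cuda.get_rng_state()
        torch.cuda.manual_seed(os.getpid())           # per-process dropout seed
        self.dropout_state = torch.cuda.get_rng_state()
        torch.cuda.set_rng_state(original_state)      # leave the default state untouched

    @contextmanager
    def dropout_mode(self):
        # swap the dropout state in, run the wrapped ops, then swap it back out
        outer_state = torch.cuda.get_rng_state()
        torch.cuda.set_rng_state(self.dropout_state)
        try:
            yield
        finally:
            self.dropout_state = torch.cuda.get_rng_state()
            torch.cuda.set_rng_state(outer_state)

The remaining hunks in this commit touch the Slicer class.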
@@ -3,7 +3,7 @@ import torch
 from ..policies.basepolicy import Col_Layer, Layer, Row_Layer
 from .shard_config import ShardConfig

-dim_mapping = {Col_Layer: 1, Row_Layer: 0}
+dim_mapping = {Col_Layer: 0, Row_Layer: 1}


 class Slicer():
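The dim_mapping swap is consistent with PyTorch's Linear weight layout: nn.Linear stores its weight as (out_features, in_features), so an output-sharded (column-parallel) layer splits dim 0 and an input-sharded (row-parallel) layer splits dim 1. A quick check of that layout:

import torch

linear = torch.nn.Linear(in_features=16, out_features=32)
print(linear.weight.shape)   # torch.Size([32, 16]) -> dim 0 is out_features, dim 1 is in_features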
@@ -40,7 +40,7 @@ class Slicer():
         # print(weight.shape, dim)
         if policy_layer_cls == Col_Layer:
             weight = self.slice_tensor(weight, dim, False, n_cast)
-            bias = self.slice_tensor(bias, 0, True)
+            bias = self.slice_tensor(bias, 0, True, n_cast)
         elif policy_layer_cls == Row_Layer:
             weight = self.slice_tensor(weight, dim, False, n_cast)
         else:
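A note on the bias fix above: for a Col_Layer the bias is sliced along the output dimension, so when the weight rows are picked in the interleaved n_cast pattern the bias entries have to be picked in the same pattern, or they stop lining up with the kept rows. A minimal illustration of the mismatch (the shapes and the 2-rank, n_cast=2 setup are hypothetical):

import torch

# hypothetical fused layer: 8 output features, world_size = 2, n_cast = 2
weight = torch.arange(16).reshape(8, 2).float()   # (out_features, in_features)
bias = torch.arange(8).float()

# interleaved slice for rank 0 keeps output rows 0, 1, 4, 5
chunks = weight.chunk(2 * 2, dim=0)
weight_shard = torch.cat([chunks[i] for i in range(0, len(chunks), 2)], dim=0)

# slicing the bias without n_cast (the old call) keeps entries 0..3,
# which no longer match the weight rows kept above
bias_old = bias.chunk(2, dim=0)[0]                                       # tensor([0., 1., 2., 3.])
bias_new = torch.cat([bias.chunk(4, dim=0)[i] for i in (0, 2)], dim=0)   # tensor([0., 1., 4., 5.])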
@@ -129,13 +129,13 @@ class Slicer():

         """
         if n_cast is None:
-            return tensor.chunk(self.shardconfig.world_size, dim=0)[self.shardconfig.rank].contiguous()
+            return tensor.chunk(self.shardconfig.world_size, dim=1)[self.shardconfig.rank].contiguous()
         else:
-            tensor_chunks = tensor.chunk(self.shardconfig.world_size * n_cast, dim=0)
+            tensor_chunks = tensor.chunk(self.shardconfig.world_size * n_cast, dim=1)
             chunk_list = [
                 tensor_chunks[i] for i in range(self.shardconfig.rank, len(tensor_chunks), self.shardconfig.world_size)
             ]
-            return torch.cat(chunk_list, dim=0).contiguous()
+            return torch.cat(chunk_list, dim=1).contiguous()

     def slice_row(
         self,
@@ -152,10 +152,10 @@ class Slicer():
         :class:`torch.Tensor`: The sliced tensor
         """
         if n_cast is None:
-            return tensor.chunk(self.shardconfig.world_size, dim=1)[self.shardconfig.rank].contiguous()
+            return tensor.chunk(self.shardconfig.world_size, dim=0)[self.shardconfig.rank].contiguous()
         else:
-            tensor_chunks = tensor.chunk(self.shardconfig.world_size * n_cast, dim=1)
+            tensor_chunks = tensor.chunk(self.shardconfig.world_size * n_cast, dim=0)
             chunk_list = [
                 tensor_chunks[i] for i in range(self.shardconfig.rank, len(tensor_chunks), self.shardconfig.world_size)
             ]
-            return torch.cat(chunk_list, dim=1).contiguous()
+            return torch.cat(chunk_list, dim=0).contiguous()
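Taken together, the two methods above implement the same pattern on different dimensions: without n_cast the tensor is split into world_size chunks and each rank keeps its own chunk; with n_cast it is split into world_size * n_cast chunks and each rank keeps every world_size-th chunk starting at its rank index. A standalone sketch of that pattern (slice_interleaved is a hypothetical helper, not the repository's API):

from typing import Optional

import torch


def slice_interleaved(tensor: torch.Tensor, dim: int, rank: int,
                      world_size: int, n_cast: Optional[int] = None) -> torch.Tensor:
    # plain case: one contiguous chunk per rank
    if n_cast is None:
        return tensor.chunk(world_size, dim=dim)[rank].contiguous()
    # n_cast case: finer chunks, each rank keeps every world_size-th chunk
    chunks = tensor.chunk(world_size * n_cast, dim=dim)
    kept = [chunks[i] for i in range(rank, len(chunks), world_size)]
    return torch.cat(kept, dim=dim).contiguous()


# 2 ranks, n_cast = 2: rank 0 keeps chunks 0 and 2, rank 1 keeps chunks 1 and 3
x = torch.arange(8).float()
print(slice_interleaved(x, dim=0, rank=0, world_size=2, n_cast=2))  # tensor([0., 1., 4., 5.])
print(slice_interleaved(x, dim=0, rank=1, world_size=2, n_cast=2))  # tensor([2., 3., 6., 7.])

The slicer unit test mentioned in the commit message presumably exercises this chunking behaviour across ranks.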