mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-24 03:03:37 +00:00
Migrated project
This commit is contained in:
26
tests/test_layers/test_sequence/test_layer.py
Normal file
26
tests/test_layers/test_sequence/test_layer.py
Normal file
@@ -0,0 +1,26 @@
|
||||
import torch
|
||||
|
||||
from colossalai.context import ParallelMode
|
||||
from colossalai.core import global_context as gpc
|
||||
from colossalai.nn import TransformerSelfAttentionRing
|
||||
from colossalai.utils import get_current_device
|
||||
|
||||
|
||||
def check_selfattention():
    """Smoke-test ``TransformerSelfAttentionRing`` under sequence parallelism.

    Builds a small ring self-attention layer, feeds it random hidden states
    and a random binary attention mask, and runs one forward pass. The test
    passes if the forward executes without raising; output values are not
    checked.
    """
    WORLD_SIZE = gpc.get_world_size(ParallelMode.SEQUENCE)
    SUB_SEQ_LENGTH = 8
    BATCH = 4
    HIDDEN_SIZE = 16

    # NOTE(review): the original passed the literals (16, 8, 8, 0.1).
    # 16 is HIDDEN_SIZE (same value, so renaming it is behavior-identical);
    # the two 8s are presumably head-count / size hyperparameters and 0.1 a
    # dropout probability — confirm against TransformerSelfAttentionRing's
    # signature before naming them.
    layer = TransformerSelfAttentionRing(
        HIDDEN_SIZE,
        8,
        8,
        0.1
    )
    layer = layer.to(get_current_device())

    # hidden_states is (sub_seq_len, batch, hidden); the mask spans the full
    # sequence length aggregated across all sequence-parallel ranks.
    hidden_states = torch.rand(SUB_SEQ_LENGTH, BATCH, HIDDEN_SIZE).to(get_current_device())
    attention_mask = torch.randint(low=0, high=2, size=(BATCH, 1, 1, 1, SUB_SEQ_LENGTH * WORLD_SIZE)).to(
        get_current_device())
    out = layer(hidden_states, attention_mask)
    # TODO(review): add an assertion on `out` (e.g. shape or finiteness) so
    # this test can fail on silent numerical breakage, not only exceptions.
|
34
tests/test_layers/test_sequence/test_sequence.py
Normal file
34
tests/test_layers/test_sequence/test_sequence.py
Normal file
@@ -0,0 +1,34 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- encoding: utf-8 -*-
|
||||
|
||||
from colossalai.initialize import init_dist
|
||||
from colossalai.logging import get_global_dist_logger
|
||||
from test_layer import *
|
||||
|
||||
# Parallel configuration for the sequence-parallel test run:
# no pipeline splitting, tensor parallelism in 'sequence' mode over 4 ranks.
CONFIG = {
    'parallel': {
        'pipeline': 1,
        'tensor': {'mode': 'sequence', 'size': 4},
    },
}
|
||||
|
||||
|
||||
def check_layer():
    """Run all layer-level checks; currently only the ring self-attention."""
    check_selfattention()
|
||||
|
||||
|
||||
def _test_main():
    """Initialize the distributed environment, then run the layer checks.

    Intended to run once per rank; ``init_dist`` consumes ``CONFIG`` to set
    up the sequence-parallel groups before any layer is exercised. ``gpc``
    and ``torch`` come into scope via ``from test_layer import *``.
    """
    # init dist
    init_dist(CONFIG)
    logger = get_global_dist_logger()
    # Fixed typo in the log message: 'initialzied' -> 'initialized'.
    logger.info('Distributed environment is initialized.', ranks=[0])

    # Seed the parallel context for reproducible random inputs, and let
    # cuDNN autotune kernels for the fixed input shapes used by the checks.
    gpc.set_seed()
    torch.backends.cudnn.benchmark = True

    # check layers
    check_layer()
|
||||
|
||||
|
||||
# NOTE(review): presumably launched as one process per rank by an external
# distributed runner rather than collected by pytest directly — confirm
# against the project's test launch scripts.
if __name__ == '__main__':
    _test_main()
|
Reference in New Issue
Block a user