mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-07 03:52:01 +00:00
[shardformer] adapted llama to the new API (#4036)
This commit is contained in:
@@ -4,31 +4,28 @@ import random
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from transformers import AutoTokenizer, LlamaConfig, LlamaForCausalLM, LlamaModel, LlamaTokenizerFast
|
||||
from transformers import LlamaConfig, LlamaForCausalLM, LlamaForSequenceClassification, LlamaModel, LlamaTokenizerFast
|
||||
|
||||
import colossalai
|
||||
from colossalai.logging import disable_existing_loggers
|
||||
from colossalai.shardformer.shard import ShardConfig, shard_model
|
||||
from colossalai.shardformer import ShardConfig, ShardFormer
|
||||
from colossalai.testing import rerun_if_address_is_in_use, spawn
|
||||
|
||||
os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = 'true'
|
||||
CONFIG = dict(parallel=dict(data=1, pipeline=1, tensor=dict(size=4, mode='1d')),)
|
||||
tokenizer = LlamaTokenizerFast.from_pretrained("hf-internal-testing/llama-tokenizer")
|
||||
|
||||
|
||||
def build_model(rank, world_size):
|
||||
cfg = LlamaConfig(num_hidden_layers=16)
|
||||
org_model = LlamaForCausalLM(cfg)
|
||||
def build_model(world_size, model_fn):
|
||||
# create new model
|
||||
config = LlamaConfig(num_hidden_layers=8)
|
||||
org_model = model_fn(config).cuda()
|
||||
|
||||
shardconfig = ShardConfig(
|
||||
rank=rank,
|
||||
world_size=world_size,
|
||||
gather_output=True,
|
||||
)
|
||||
org_model = org_model.to('cuda')
|
||||
|
||||
org_model_forshard = copy.deepcopy(org_model)
|
||||
sharded_model = shard_model(org_model_forshard, shardconfig).to('cuda')
|
||||
# shard model
|
||||
shard_config = ShardConfig(tensor_parallel_size=world_size)
|
||||
model_copy = copy.deepcopy(org_model)
|
||||
shard_former = ShardFormer(shard_config=shard_config)
|
||||
shard_former.init_distributed()
|
||||
sharded_model = shard_former.shard_model(model_copy)
|
||||
|
||||
return org_model, sharded_model
|
||||
|
||||
@@ -38,6 +35,7 @@ def check_forward(org_model, sharded_model):
|
||||
inputs = tokenizer(input, return_tensors='pt').to('cuda')
|
||||
del inputs["token_type_ids"]
|
||||
del inputs["attention_mask"]
|
||||
|
||||
#orgin model
|
||||
org_model.eval()
|
||||
org_out = org_model(**inputs)
|
||||
@@ -87,11 +85,20 @@ def check_backward(org_model, sharded_model):
|
||||
|
||||
def check_llama(rank, world_size, port):
|
||||
disable_existing_loggers()
|
||||
colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
|
||||
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
|
||||
|
||||
org_model, sharded_model = build_model(rank, world_size)
|
||||
check_forward(org_model, sharded_model)
|
||||
check_backward(org_model, sharded_model)
|
||||
model_list = [
|
||||
LlamaForCausalLM,
|
||||
|
||||
# TODO: do not work yet
|
||||
# LlamaModel,
|
||||
# LlamaForSequenceClassification
|
||||
]
|
||||
|
||||
for model_fn in model_list:
|
||||
org_model, sharded_model = build_model(world_size, model_fn)
|
||||
check_forward(org_model, sharded_model)
|
||||
check_backward(org_model, sharded_model)
|
||||
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
|
Reference in New Issue
Block a user