Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-04 10:34:41 +00:00)
[shardformer] adapted T5 and LLaMa test to use kit (#4049)
* [shardformer] adapted T5 and LLaMa test to use kit

* polish code
tests/test_shardformer/test_model/_utils.py (new file, 38 lines)
@@ -0,0 +1,38 @@
import copy

from colossalai.shardformer import ShardConfig, ShardFormer


def build_model(world_size, model_fn):
    # create new model
    org_model = model_fn().cuda()

    # shard model
    shard_config = ShardConfig(tensor_parallel_size=world_size)
    model_copy = copy.deepcopy(org_model)
    shard_former = ShardFormer(shard_config=shard_config)
    shard_former.init_distributed()
    sharded_model = shard_former.shard_model(model_copy)

    return org_model, sharded_model


def run_forward(original_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn):
    # prepare input
    data = data_gen_fn()
    data = {k: v.cuda() for k, v in data.items()}

    # switch to train mode
    original_model.train()
    sharded_model.train()

    # run forward
    org_output = original_model(**data)
    org_output = output_transform_fn(org_output)
    org_loss = loss_fn(org_output)

    shard_output = sharded_model(**data)
    shard_output = output_transform_fn(shard_output)
    shard_loss = loss_fn(shard_output)

    return org_output, org_loss, shard_output, shard_loss
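For context, a minimal sketch of how a test might drive these helpers. The toy nn.Linear model, the generator functions, and check_forward below are illustrative assumptions, not part of this commit; the actual T5 and LLaMa tests pull model_fn, data_gen_fn, output_transform_fn, and loss_fn from the test kit's model zoo, and everything runs inside a distributed launch context (with CUDA available) so that ShardFormer.init_distributed() can set up its process groups.

import torch
import torch.nn as nn

from tests.test_shardformer.test_model._utils import build_model, run_forward


def model_fn():
    # stand-in for a model-zoo builder such as a T5 or LLaMa entry
    return nn.Linear(8, 8)


def data_gen_fn():
    # stand-in for the kit's data generator; returns keyword inputs
    return {"input": torch.rand(4, 8)}


def output_transform_fn(output):
    # the kit's transform normalizes model outputs into a dict;
    # a plain tensor just gets wrapped here
    return {"output": output}


def loss_fn(output):
    # reduce the transformed output to a scalar loss
    return output["output"].mean()


def check_forward(world_size):
    org_model, sharded_model = build_model(world_size, model_fn)
    org_output, org_loss, shard_output, shard_loss = run_forward(
        org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn)
    # with a correct sharding policy, the sharded loss should closely match the original
    assert torch.allclose(org_loss, shard_loss, atol=1e-5)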