mirror of https://github.com/hpcaitech/ColossalAI.git
fix bert unit test
@@ -13,6 +13,7 @@ from colossalai.utils import free_port
 from colossalai.zero.shard_utils.tensor_shard_strategy import \
     TensorShardStrategy
 from colossalai.zero.sharded_model import ShardedModelV2
+from colossalai.zero.sharded_model._zero3_utils import cast_tensor_to_fp16
 from tests.components_to_test.registry import non_distributed_component_funcs
 from torch.nn.parallel import DistributedDataParallel as DDP

@@ -45,8 +46,7 @@ def run_fwd_bwd_no_criterion(model, data, label, enable_autocast=False):
 def run_dist(rank, world_size, port):
     colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')

-    test_models = ['bert']
-    # repeated_computed_layers resnet18
+    test_models = ['repeated_computed_layers', 'resnet18', 'bert']
     shard_strategy = TensorShardStrategy()
     for model_name in test_models:
         get_components_func = non_distributed_component_funcs.get_callable(model_name)
@@ -65,8 +65,7 @@ def run_dist(rank, world_size, port):
             run_fwd_bwd_no_criterion(model, data, label, False)
             run_fwd_bwd_no_criterion(zero_model, data, label, False)
         else:
-            # FIXME() data can be interger!
-            data, label = data.half().cuda(), label.cuda()
+            data, label = cast_tensor_to_fp16(data).cuda(), label.cuda()
             run_fwd_bwd(model, data, label, criterion, False)
             run_fwd_bwd(zero_model, data, label, criterion, False)

@@ -76,7 +75,6 @@ def run_dist(rank, world_size, port):
         check_grads(model, zero_model, loose=True)


-@pytest.mark.skip(reason="Under development")
 @pytest.mark.dist
 @pytest.mark.parametrize("world_size", [1, 2, 4])
 def test_shard_model_v2(world_size):
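
Why the change works: for the 'bert' component the batch contains integer token-id tensors, which is what the removed FIXME ("data can be interger!", i.e. integer) was warning about. Calling data.half() unconditionally would cast those token ids to fp16 and break the embedding lookup, so the fix routes the batch through cast_tensor_to_fp16 instead. The diff does not show that helper's implementation; below is a minimal sketch of the behaviour the test now relies on, under the assumption that only floating-point tensors are cast while integer tensors and container structure pass through unchanged (the name cast_tensor_to_fp16_sketch is hypothetical, not ColossalAI's actual code):

# Minimal sketch of the fp16-cast behaviour the fixed test depends on.
# Assumption: ColossalAI's real cast_tensor_to_fp16 may differ in detail.
from typing import Any

import torch


def cast_tensor_to_fp16_sketch(x: Any) -> Any:
    if isinstance(x, torch.Tensor):
        # Only float tensors are converted; int/bool tensors pass through,
        # so BERT token ids stay valid indices for the embedding layer.
        return x.half() if x.is_floating_point() else x
    if isinstance(x, (list, tuple)):
        # Recurse into containers so batched inputs like (input_ids, mask) work.
        return type(x)(cast_tensor_to_fp16_sketch(t) for t in x)
    if isinstance(x, dict):
        return {k: cast_tensor_to_fp16_sketch(v) for k, v in x.items()}
    return x


if __name__ == "__main__":
    input_ids = torch.randint(0, 30522, (2, 8))            # integer token ids
    attention_mask = torch.ones(2, 8, dtype=torch.float32)
    batch = {"input_ids": input_ids, "attention_mask": attention_mask}
    out = cast_tensor_to_fp16_sketch(batch)
    assert out["input_ids"].dtype == torch.int64           # left unchanged
    assert out["attention_mask"].dtype == torch.float16    # cast to fp16

This also explains why the commit can restore the full test_models list: 'repeated_computed_layers' and 'resnet18' take float inputs and are cast as before, while 'bert' now survives the same code path.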