Mirror of https://github.com/hpcaitech/ColossalAI.git
Synced 2025-09-09 21:09:18 +00:00
[shardformer] made tensor parallelism configurable (#4144)
* [shardformer] made tensor parallelism configurable
* polish code
tests/test_shardformer/test_model/_utils.py

@@ -3,12 +3,13 @@ import copy
 
 from colossalai.shardformer import ShardConfig, ShardFormer
 
 
-def build_model(model_fn):
+def build_model(model_fn, enable_fused_normalization=True, enable_tensor_parallelism=True):
     # create new model
     org_model = model_fn().cuda()
 
     # shard model
-    shard_config = ShardConfig(enable_fused_normalization=True)
+    shard_config = ShardConfig(enable_fused_normalization=enable_fused_normalization,
+                               enable_tensor_parallelism=enable_tensor_parallelism)
     model_copy = copy.deepcopy(org_model)
     shard_former = ShardFormer(shard_config=shard_config)
     sharded_model = shard_former.optimize(model_copy).cuda()
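The two new keyword arguments flow straight into ShardConfig, so a caller can now opt out of tensor parallelism entirely. A minimal usage sketch, assuming only the ShardFormer API shown in the hunk above:

import copy

from colossalai.shardformer import ShardConfig, ShardFormer

# Hypothetical caller: fused normalization on, tensor parallelism off,
# mirroring the new build_model signature.
def shard_without_tp(model_fn):
    org_model = model_fn().cuda()
    shard_config = ShardConfig(enable_fused_normalization=True,
                               enable_tensor_parallelism=False)
    shard_former = ShardFormer(shard_config=shard_config)
    sharded_model = shard_former.optimize(copy.deepcopy(org_model)).cuda()
    return org_model, sharded_model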
tests/test_shardformer/test_model/test_bert.py

@@ -3,7 +3,14 @@ import torch
 
 import colossalai
 from colossalai.logging import disable_existing_loggers
-from colossalai.testing import assert_hf_output_close, clear_cache_before_run, rerun_if_address_is_in_use, spawn
 from colossalai.tensor.d_tensor.api import is_customized_distributed_tensor, is_distributed_tensor
+from colossalai.testing import (
+    assert_hf_output_close,
+    clear_cache_before_run,
+    parameterize,
+    rerun_if_address_is_in_use,
+    spawn,
+)
 from tests.kit.model_zoo import model_zoo
 from tests.test_shardformer.test_model._utils import build_model, run_forward
@@ -33,34 +40,48 @@ def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn)
     # compare self attention grad
     org_grad = bert.encoder.layer[0].attention.self.query.weight.grad
     shard_grad = sharded_bert.encoder.layer[0].attention.self.query.weight.grad
+    shard_weight = sharded_bert.encoder.layer[0].attention.self.query.weight
 
-    shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(2)]
-    shard_grad = torch.distributed.all_gather(shard_grad_list, shard_grad)
-    all_shard_grad = torch.cat(shard_grad_list, dim=0)
+    if is_distributed_tensor(shard_weight) or is_customized_distributed_tensor(shard_weight):
+        shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(2)]
+        shard_grad = torch.distributed.all_gather(shard_grad_list, shard_grad)
+        all_shard_grad = torch.cat(shard_grad_list, dim=0)
+    else:
+        all_shard_grad = shard_grad
     assert torch.allclose(org_grad, all_shard_grad,
                           atol=1e-5), f"shard model grad is not equal to origin model grad\n{org_grad}\n{all_shard_grad}"
 
     # compare embedding grad
     org_grad = bert.embeddings.word_embeddings.weight.grad
     shard_grad = sharded_bert.embeddings.word_embeddings.weight.grad
+    shard_weight = sharded_bert.embeddings.word_embeddings.weight
 
-    shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(2)]
-    shard_grad = torch.distributed.all_gather(shard_grad_list, shard_grad)
-    all_shard_grad = torch.cat(shard_grad_list, dim=0)
+    if is_distributed_tensor(shard_weight) or is_customized_distributed_tensor(shard_weight):
+        shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(2)]
+        shard_grad = torch.distributed.all_gather(shard_grad_list, shard_grad)
+        all_shard_grad = torch.cat(shard_grad_list, dim=0)
+    else:
+        all_shard_grad = shard_grad
     assert torch.allclose(org_grad, all_shard_grad,
                           atol=1e-5), f"shard model grad is not equal to origin model grad\n{org_grad}\n{all_shard_grad}"
 
 
+@parameterize('enable_fused_normalization', [True, False])
+@parameterize('enable_tensor_parallelism', [True, False])
+def run_bert_test(enable_fused_normalization, enable_tensor_parallelism):
+    sub_model_zoo = model_zoo.get_sub_registry('transformers_bert')
+    for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
+        org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism)
+        check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn)
+
+    torch.cuda.empty_cache()
+
+
 def check_bert(rank, world_size, port):
     disable_existing_loggers()
     colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
 
-    sub_model_zoo = model_zoo.get_sub_registry('transformers_bert')
-    for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
-        org_model, sharded_model = build_model(model_fn)
-        check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn)
-
-    torch.cuda.empty_cache()
+    run_bert_test()
 
 
 @pytest.mark.dist
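The same guard now appears in every model test: a weight that ShardFormer left unsharded has a full-size gradient, so blindly all-gathering it would duplicate rows. A hypothetical helper capturing the pattern (the function name and dim parameter are illustrative, not part of the repo):

import torch
import torch.distributed as dist

from colossalai.tensor.d_tensor.api import is_customized_distributed_tensor, is_distributed_tensor

def gather_grad_if_sharded(weight, grad, world_size, dim=0):
    """Gather a sharded gradient across ranks; pass unsharded gradients through."""
    if is_distributed_tensor(weight) or is_customized_distributed_tensor(weight):
        shards = [torch.zeros_like(grad) for _ in range(world_size)]
        dist.all_gather(shards, grad)    # fills `shards` in place
        return torch.cat(shards, dim=dim)
    return grad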
tests/test_shardformer/test_model/test_bloom.py

@@ -3,7 +3,14 @@ import torch
 
 import colossalai
 from colossalai.logging import disable_existing_loggers
-from colossalai.testing import assert_hf_output_close, clear_cache_before_run, rerun_if_address_is_in_use, spawn
 from colossalai.tensor.d_tensor.api import is_customized_distributed_tensor, is_distributed_tensor
+from colossalai.testing import (
+    assert_hf_output_close,
+    clear_cache_before_run,
+    parameterize,
+    rerun_if_address_is_in_use,
+    spawn,
+)
 from tests.kit.model_zoo import model_zoo
 from tests.test_shardformer.test_model._utils import build_model, run_forward
@@ -32,10 +39,14 @@ def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn)
     # check attention grad
     org_grad = bloom.h[0].self_attention.query_key_value.weight.grad
     shard_grad = sharded_bloom.h[0].self_attention.query_key_value.weight.grad
+    shard_weight = sharded_bloom.h[0].self_attention.query_key_value.weight
 
-    shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(2)]
-    torch.distributed.all_gather(shard_grad_list, shard_grad)
-    all_shard_grad = torch.cat(shard_grad_list, dim=0)
+    if is_distributed_tensor(shard_weight) or is_customized_distributed_tensor(shard_weight):
+        shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(2)]
+        torch.distributed.all_gather(shard_grad_list, shard_grad)
+        all_shard_grad = torch.cat(shard_grad_list, dim=0)
+    else:
+        all_shard_grad = shard_grad
 
     assert torch.allclose(org_grad, all_shard_grad,
                           atol=1e-5), f"shard model grad is not equal to origin model grad\n{org_grad}\n{all_shard_grad}"
@@ -43,25 +54,33 @@ def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn)
     # check embedding weights
     org_grad = bloom.word_embeddings.weight.grad
     shard_grad = sharded_bloom.word_embeddings.weight.grad
+    shard_weight = sharded_bloom.word_embeddings.weight
 
-    shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(2)]
-    torch.distributed.all_gather(shard_grad_list, shard_grad)
-    all_shard_grad = torch.cat(shard_grad_list, dim=0)
+    if is_distributed_tensor(shard_weight) or is_customized_distributed_tensor(shard_weight):
+        shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(2)]
+        torch.distributed.all_gather(shard_grad_list, shard_grad)
+        all_shard_grad = torch.cat(shard_grad_list, dim=0)
+    else:
+        all_shard_grad = shard_grad
 
     assert torch.allclose(org_grad, all_shard_grad,
                           atol=1e-5), f"shard model grad is not equal to origin model grad\n{org_grad}\n{all_shard_grad}"
 
 
+@parameterize('enable_fused_normalization', [True, False])
+@parameterize('enable_tensor_parallelism', [True, False])
+def run_bloom_test(enable_fused_normalization, enable_tensor_parallelism):
+    sub_model_zoo = model_zoo.get_sub_registry('transformers_bloom')
+    for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
+        org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism)
+        check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn)
+    torch.cuda.empty_cache()
+
+
 def check_bloom(rank, world_size, port):
     disable_existing_loggers()
     colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
 
-    sub_model_zoo = model_zoo.get_sub_registry('transformers_bloom')
-    for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
-        org_model, sharded_model = build_model(model_fn)
-        check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn)
-
-    torch.cuda.empty_cache()
+    run_bloom_test()
 
 
 @pytest.mark.dist
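Stacking the two @parameterize decorators runs each test body over the full 2x2 grid of flags, so every model is exercised with and without tensor parallelism. A rough sketch of the mechanism, assuming colossalai.testing.parameterize simply re-invokes the wrapped function once per value (the real implementation may differ):

import functools

def parameterize(arg_name, values):
    """Toy stand-in: call the wrapped function once per value of `arg_name`."""
    def decorator(fn):
        @functools.wraps(fn)
        def wrapper(*args, **kwargs):
            for value in values:
                kwargs[arg_name] = value
                fn(*args, **kwargs)
        return wrapper
    return decorator

# Two stacked decorators therefore yield four invocations:
# (True, True), (True, False), (False, True), (False, False).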
tests/test_shardformer/test_model/test_gpt2.py

@@ -3,7 +3,14 @@ import torch
 
 import colossalai
 from colossalai.logging import disable_existing_loggers
-from colossalai.testing import assert_hf_output_close, clear_cache_before_run, rerun_if_address_is_in_use, spawn
 from colossalai.tensor.d_tensor.api import is_customized_distributed_tensor, is_distributed_tensor
+from colossalai.testing import (
+    assert_hf_output_close,
+    clear_cache_before_run,
+    parameterize,
+    rerun_if_address_is_in_use,
+    spawn,
+)
 from tests.kit.model_zoo import model_zoo
 from tests.test_shardformer.test_model._utils import build_model, run_forward
@@ -32,11 +39,14 @@ def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn)
     # check mlp grad
     org_grad = org_model.h[0].mlp.c_fc.weight.grad
     shard_grad = sharded_model.h[0].mlp.c_fc.weight.grad
+    shard_weight = sharded_model.h[0].mlp.c_fc.weight
 
-    shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(2)]
-    shard_grad = torch.distributed.all_gather(shard_grad_list, shard_grad)
-    all_shard_grad = torch.cat(shard_grad_list, dim=1)
-
+    if is_distributed_tensor(shard_weight) or is_customized_distributed_tensor(shard_weight):
+        shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(2)]
+        shard_grad = torch.distributed.all_gather(shard_grad_list, shard_grad)
+        all_shard_grad = torch.cat(shard_grad_list, dim=1)
+    else:
+        all_shard_grad = shard_grad
     assert torch.allclose(
         org_grad, all_shard_grad,
         atol=1e-5), f"shard model grad is not equal to origin model grad\n{org_grad}\n{all_shard_grad}"
@@ -44,25 +54,33 @@ def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn)
     # check embedding weights
     org_grad = org_model.wte.weight.grad
     shard_grad = sharded_model.wte.weight.grad
+    shard_weight = sharded_model.wte.weight
 
-    shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(2)]
-    shard_grad = torch.distributed.all_gather(shard_grad_list, shard_grad)
-    all_shard_grad = torch.cat(shard_grad_list, dim=0)
+    if is_distributed_tensor(shard_weight) or is_customized_distributed_tensor(shard_weight):
+        shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(2)]
+        shard_grad = torch.distributed.all_gather(shard_grad_list, shard_grad)
+        all_shard_grad = torch.cat(shard_grad_list, dim=0)
+    else:
+        all_shard_grad = shard_grad
     assert torch.allclose(
         org_grad, all_shard_grad,
         atol=1e-5), f"shard model grad is not equal to origin model grad\n{org_grad}\n{all_shard_grad}"
 
 
+@parameterize('enable_fused_normalization', [True, False])
+@parameterize('enable_tensor_parallelism', [True, False])
+def run_gpt2_test(enable_fused_normalization, enable_tensor_parallelism):
+    sub_model_zoo = model_zoo.get_sub_registry('transformers_gpt')
+    for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
+        org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism)
+        check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn)
+    torch.cuda.empty_cache()
+
+
 def check_gpt2(rank, world_size, port):
     disable_existing_loggers()
     colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
 
-    sub_model_zoo = model_zoo.get_sub_registry('transformers_gpt')
-    for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
-        org_model, sharded_model = build_model(model_fn)
-        check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn)
-
-    torch.cuda.empty_cache()
+    run_gpt2_test()
 
 
 @pytest.mark.dist
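Note the concatenation axes in this file: the column-parallel c_fc weight is split along its output dimension, so its gradient shards are re-joined with dim=1, while the vocabulary-sharded wte embedding is re-joined with dim=0. A small self-contained illustration (shapes are made up for the example):

import torch

full = torch.arange(24.).reshape(4, 6)

# Column-parallel split: each of 2 ranks holds a [4, 3] shard.
col_shards = list(full.chunk(2, dim=1))
assert torch.equal(torch.cat(col_shards, dim=1), full)

# Row (vocabulary) split: each rank holds a [2, 6] shard.
row_shards = list(full.chunk(2, dim=0))
assert torch.equal(torch.cat(row_shards, dim=0), full)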
tests/test_shardformer/test_model/test_llama.py

@@ -5,7 +5,14 @@ import torch
 
 import colossalai
 from colossalai.logging import disable_existing_loggers
-from colossalai.testing import assert_hf_output_close, clear_cache_before_run, rerun_if_address_is_in_use, spawn
 from colossalai.tensor.d_tensor.api import is_customized_distributed_tensor, is_distributed_tensor
+from colossalai.testing import (
+    assert_hf_output_close,
+    clear_cache_before_run,
+    parameterize,
+    rerun_if_address_is_in_use,
+    spawn,
+)
 from tests.kit.model_zoo import model_zoo
 from tests.test_shardformer.test_model._utils import build_model, run_forward
@@ -37,33 +44,46 @@ def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn)
     # check attention grad
     org_grad = llama_model.layers[0].self_attn.q_proj.weight.grad
     shard_grad = shard_llama_model.layers[0].self_attn.q_proj.weight.grad
-    shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(4)]
-    shard_grad = torch.distributed.all_gather(shard_grad_list, shard_grad)
-    all_shard_grad = torch.cat(shard_grad_list, dim=0)
+    shard_weight = shard_llama_model.layers[0].self_attn.q_proj.weight
 
+    if is_distributed_tensor(shard_weight) or is_customized_distributed_tensor(shard_weight):
+        shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(4)]
+        shard_grad = torch.distributed.all_gather(shard_grad_list, shard_grad)
+        all_shard_grad = torch.cat(shard_grad_list, dim=0)
+    else:
+        all_shard_grad = shard_grad
     assert torch.allclose(org_grad, all_shard_grad,
                           atol=1e-5), f"shard model grad is not equal to origin model grad\n{org_grad}\n{all_shard_grad}"
 
     # check embedding grad
     org_grad = llama_model.embed_tokens.weight.grad
     shard_grad = shard_llama_model.embed_tokens.weight.grad
-    shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(4)]
-    shard_grad = torch.distributed.all_gather(shard_grad_list, shard_grad)
-    all_shard_grad = torch.cat(shard_grad_list, dim=0)
+    shard_weight = shard_llama_model.embed_tokens.weight
 
+    if is_distributed_tensor(shard_weight) or is_customized_distributed_tensor(shard_weight):
+        shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(4)]
+        shard_grad = torch.distributed.all_gather(shard_grad_list, shard_grad)
+        all_shard_grad = torch.cat(shard_grad_list, dim=0)
+    else:
+        all_shard_grad = shard_grad
     assert torch.allclose(org_grad, all_shard_grad,
                           atol=1e-5), f"shard model grad is not equal to origin model grad\n{org_grad}\n{all_shard_grad}"
 
 
+@parameterize('enable_fused_normalization', [True, False])
+@parameterize('enable_tensor_parallelism', [True, False])
+def run_gpt2_llama(enable_fused_normalization, enable_tensor_parallelism):
+    sub_model_zoo = model_zoo.get_sub_registry('transformers_llama')
+    for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
+        org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism)
+        check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn)
+    torch.cuda.empty_cache()
+
+
 def check_llama(rank, world_size, port):
     disable_existing_loggers()
     colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
 
-    sub_model_zoo = model_zoo.get_sub_registry('transformers_llama')
-
-    for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
-        org_model, sharded_model = build_model(model_fn)
-        check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn)
-
-    torch.cuda.empty_cache()
+    run_gpt2_llama()
 
 
 @pytest.mark.dist
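The llama checks gather into 4 buffers rather than 2, so this test expects a tensor-parallel group of size 4. A hedged sketch of the pytest entry point that would drive it, modeled on the spawn helper imported above (the actual call in the file sits below this hunk and is not shown):

import pytest
from colossalai.testing import rerun_if_address_is_in_use, spawn

@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_llama():
    # Presumed entry point: launch 4 ranks so the range(4) gathers above line up;
    # spawn supplies (rank, world_size, port) to check_llama.
    spawn(check_llama, 4)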
tests/test_shardformer/test_model/test_opt.py

@@ -6,10 +6,11 @@ import torch
 
 import colossalai
 from colossalai.logging import disable_existing_loggers
 from colossalai.tensor.d_tensor.api import is_customized_distributed_tensor, is_distributed_tensor
 from colossalai.testing import (
     assert_hf_output_close,
     check_state_dict_equal,
     clear_cache_before_run,
+    parameterize,
     rerun_if_address_is_in_use,
     spawn,
 )
@@ -42,32 +43,46 @@ def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn)
     # check attention grad
     org_grad = opt_model.decoder.layers[0].self_attn.q_proj.weight.grad
     shard_grad = shard_opt_model.decoder.layers[0].self_attn.q_proj.weight.grad
-    shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(4)]
-    torch.distributed.all_gather(shard_grad_list, shard_grad)
-    all_shard_grad = torch.cat(shard_grad_list, dim=0)
+    shard_weight = shard_opt_model.decoder.layers[0].self_attn.q_proj.weight
 
+    if is_distributed_tensor(shard_weight) or is_customized_distributed_tensor(shard_weight):
+        shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(4)]
+        torch.distributed.all_gather(shard_grad_list, shard_grad)
+        all_shard_grad = torch.cat(shard_grad_list, dim=0)
+    else:
+        all_shard_grad = shard_grad
     assert torch.allclose(org_grad, all_shard_grad,
                           atol=1e-5), f"shard model grad is not equal to origin model grad\n{org_grad}\n{all_shard_grad}"
 
     # check embedding grad
     org_grad = opt_model.decoder.embed_tokens.weight.grad
     shard_grad = shard_opt_model.decoder.embed_tokens.weight.grad
-    shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(4)]
-    torch.distributed.all_gather(shard_grad_list, shard_grad)
-    all_shard_grad = torch.cat(shard_grad_list, dim=0)
+    shard_weight = shard_opt_model.decoder.embed_tokens.weight
 
+    if is_distributed_tensor(shard_weight) or is_customized_distributed_tensor(shard_weight):
+        shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(4)]
+        torch.distributed.all_gather(shard_grad_list, shard_grad)
+        all_shard_grad = torch.cat(shard_grad_list, dim=0)
+    else:
+        all_shard_grad = shard_grad
     assert torch.allclose(org_grad, all_shard_grad,
                           atol=1e-5), f"shard model grad is not equal to origin model grad\n{org_grad}\n{all_shard_grad}"
 
 
+@parameterize('enable_fused_normalization', [True, False])
+@parameterize('enable_tensor_parallelism', [True, False])
+def run_t5_test(enable_fused_normalization, enable_tensor_parallelism):
+    sub_model_zoo = model_zoo.get_sub_registry('transformers_opt')
+    for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
+        org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism)
+        check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn)
+    torch.cuda.empty_cache()
+
+
 def check_OPTModel(rank, world_size, port):
     disable_existing_loggers()
     colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
 
-    sub_model_zoo = model_zoo.get_sub_registry('transformers_opt')
-    for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
-        org_model, sharded_model = build_model(model_fn)
-        check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn)
-
-    torch.cuda.empty_cache()
+    run_t5_test()
 
 
 @pytest.mark.dist
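One detail worth keeping from the bloom and opt variants: torch.distributed.all_gather fills the provided list in place and returns None, so the bare call is the cleaner idiom; assigning its result (as some of the other files here do) silently rebinds the variable to None. A minimal demonstration, assuming an already-initialized process group:

import torch
import torch.distributed as dist

def gather_full_grad(shard_grad, world_size):
    shards = [torch.zeros_like(shard_grad) for _ in range(world_size)]
    result = dist.all_gather(shards, shard_grad)    # in-place collective
    assert result is None                           # the return value carries nothing
    return torch.cat(shards, dim=0)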
tests/test_shardformer/test_model/test_t5.py

@@ -5,7 +5,14 @@ import torch
 
 import colossalai
 from colossalai.logging import disable_existing_loggers
-from colossalai.testing import assert_hf_output_close, clear_cache_before_run, rerun_if_address_is_in_use, spawn
 from colossalai.tensor.d_tensor.api import is_customized_distributed_tensor, is_distributed_tensor
+from colossalai.testing import (
+    assert_hf_output_close,
+    clear_cache_before_run,
+    parameterize,
+    rerun_if_address_is_in_use,
+    spawn,
+)
 from tests.kit.model_zoo import model_zoo
 from tests.test_shardformer.test_model._utils import build_model, run_forward
@@ -27,19 +34,28 @@ def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn)
     # check attention grad
     org_grad = org_model.encoder.block[0].layer[0].SelfAttention.q.weight.grad
     shard_grad = sharded_model.encoder.block[0].layer[0].SelfAttention.q.weight.grad
+    shard_weight = sharded_model.encoder.block[0].layer[0].SelfAttention.q.weight
 
-    shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(2)]
-    shard_grad = torch.distributed.all_gather(shard_grad_list, shard_grad)
-    all_shard_grad = torch.cat(shard_grad_list, dim=0)
+    if is_distributed_tensor(shard_weight) or is_customized_distributed_tensor(shard_weight):
+        shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(2)]
+        shard_grad = torch.distributed.all_gather(shard_grad_list, shard_grad)
+        all_shard_grad = torch.cat(shard_grad_list, dim=0)
+    else:
+        all_shard_grad = shard_grad
     assert torch.allclose(org_grad, all_shard_grad,
                           atol=1e-5), f"shard model grad is not equal to origin model grad\n{org_grad}\n{all_shard_grad}"
 
     # check self attention embed
     org_grad = org_model.encoder.block[0].layer[0].SelfAttention.relative_attention_bias.weight.grad
     shard_grad = sharded_model.encoder.block[0].layer[0].SelfAttention.relative_attention_bias.weight.grad
+    shard_weight = sharded_model.encoder.block[0].layer[0].SelfAttention.relative_attention_bias.weight
 
-    shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(2)]
-    torch.distributed.all_gather(shard_grad_list, shard_grad)
-    all_shard_grad = torch.cat(shard_grad_list, dim=1)
+    if is_distributed_tensor(shard_weight) or is_customized_distributed_tensor(shard_weight):
+        shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(2)]
+        torch.distributed.all_gather(shard_grad_list, shard_grad)
+        all_shard_grad = torch.cat(shard_grad_list, dim=1)
+    else:
+        all_shard_grad = shard_grad
     assert torch.allclose(org_grad, all_shard_grad,
                           atol=1e-5), f"shard model grad is not equal to origin model grad\n{org_grad}\n{all_shard_grad}"
 
@@ -52,23 +68,32 @@ def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn)
     assert sharded_model.shared.weight.data.data_ptr() == sharded_model.lm_head.weight.data.data_ptr()
 
     shard_grad = sharded_model.shared.weight.grad
+    shard_weight = sharded_model.shared.weight
 
-    shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(2)]
-    torch.distributed.all_gather(shard_grad_list, shard_grad)
-    all_shard_grad = torch.cat(shard_grad_list, dim=0)
+    if is_distributed_tensor(shard_weight) or is_customized_distributed_tensor(shard_weight):
+        shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(2)]
+        torch.distributed.all_gather(shard_grad_list, shard_grad)
+        all_shard_grad = torch.cat(shard_grad_list, dim=0)
+    else:
+        all_shard_grad = shard_grad
     assert torch.allclose(org_grad, all_shard_grad,
                           atol=1e-5), f"shard model grad is not equal to origin model grad\n{org_grad}\n{all_shard_grad}"
 
 
+@parameterize('enable_fused_normalization', [True, False])
+@parameterize('enable_tensor_parallelism', [True, False])
+def run_t5_test(enable_fused_normalization, enable_tensor_parallelism):
+    sub_model_zoo = model_zoo.get_sub_registry('transformers_t5')
+    for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
+        org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism)
+        check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn)
+    torch.cuda.empty_cache()
+
+
 def check_t5(rank, world_size, port):
     disable_existing_loggers()
     colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
 
-    sub_model_zoo = model_zoo.get_sub_registry('transformers_t5')
-
-    for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
-        org_model, sharded_model = build_model(model_fn)
-        check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn)
-    torch.cuda.empty_cache()
+    run_t5_test()
 
 
 @pytest.mark.dist
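The t5 hunk also asserts that sharding preserves weight tying between the shared embedding and lm_head by comparing storage pointers. A tiny standalone sketch of that check, using a toy module rather than the real T5 model:

import torch.nn as nn

embed = nn.Embedding(100, 16)
lm_head = nn.Linear(16, 100, bias=False)
lm_head.weight = embed.weight    # tie the two parameters

# Tied parameters share storage, so their data pointers match.
assert embed.weight.data.data_ptr() == lm_head.weight.data.data_ptr()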