[shardformer] rewrite tests for opt/bloom/llama/vit/chatglm (#4395)

* rewrite opt tests

* rewrite llama tests

* rewrite bloom & vit tests

* rewrite chatglm tests

* fix LinearCol for classfiers

* add judge for other tp layers, fix lazy init in util
This commit is contained in:
Baizhou Zhang
2023-08-11 15:43:23 +08:00
committed by Hongxin Liu
parent 21e0a42fd1
commit 7711bd524a
19 changed files with 1064 additions and 1273 deletions

View File

@@ -53,7 +53,8 @@ def data_gen_for_question_answering():
# inputs = tokenizer(question, text, return_tensors="pt")
input_ids = torch.tensor(
[[57647, 1620, 23967, 620, 107373, 34, 91514, 620, 107373, 1620, 267, 35378, 48946, 18161, 48946, 18161]], dtype=torch.int64)
[[57647, 1620, 23967, 620, 107373, 34, 91514, 620, 107373, 1620, 267, 35378, 48946, 18161, 48946, 18161]],
dtype=torch.int64)
attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]], dtype=torch.int64)
start_positions = torch.tensor([1], dtype=torch.int64)
end_positions = torch.tensor([10], dtype=torch.int64)
@@ -73,12 +74,13 @@ loss_fn_for_causal_lm = lambda x: x.loss
loss_fn_for_classification = lambda x: x.loss
loss_fn_for_question_answering = lambda x: x.loss
config = transformers.BloomConfig(n_layer=1,
config = transformers.BloomConfig(n_layer=2,
n_head=4,
vocab_size=250880,
hidden_dropout=0,
attention_dropout=0,
hidden_size=64)
hidden_size=64,
pad_token_id=50256)
# register the following models
model_zoo.register(name='transformers_bloom',

View File

@@ -17,14 +17,24 @@ def data_gen():
return dict(input_ids=input_ids, attention_mask=attention_mask)
def data_gen_for_conditional_generation():
# token classification data gen
# `labels` is the type not the token id for token classification, 0 or 1
data = data_gen()
labels = data['input_ids'].clone()
data['labels'] = labels
return data
# define output transform function
output_transform_fn = lambda x: x
# define loss function
loss_fn_for_chatglm_model = lambda x: x.last_hidden_state.sum()
loss_fn = lambda x: x.logits.sum()
loss_fn_for_chatglm_model = lambda x: torch.nn.functional.mse_loss(x.last_hidden_state,
torch.ones_like(x.last_hidden_state))
loss_fn = lambda x: x.loss
config = ChatGLMConfig(num_layers=1,
config = ChatGLMConfig(num_layers=2,
padded_vocab_size=65024,
hidden_size=64,
num_attention_heads=8,
@@ -33,7 +43,6 @@ config = ChatGLMConfig(num_layers=1,
use_cache=True,
torch_dtype=torch.float32)
model_zoo.register(name='transformers_chatglm',
model_fn=lambda: ChatGLMModel(config, empty_init=False),
data_gen_fn=data_gen,
@@ -43,7 +52,7 @@ model_zoo.register(name='transformers_chatglm',
model_zoo.register(name="transformers_chatglm_for_conditional_generation",
model_fn=lambda: ChatGLMForConditionalGeneration(config, empty_init=False),
data_gen_fn=data_gen,
data_gen_fn=data_gen_for_conditional_generation,
output_transform_fn=output_transform_fn,
loss_fn=loss_fn,
model_attribute=ModelAttribute(has_control_flow=True))

View File

@@ -7,11 +7,7 @@ from ..registry import ModelAttribute, model_zoo
# Register single-sentence VIT
# ===============================
config = transformers.ViTConfig(
num_hidden_layers=4,
# hidden_size=128,
# intermediate_size=256,
num_attention_heads=4)
config = transformers.ViTConfig(num_hidden_layers=4, hidden_size=128, intermediate_size=256, num_attention_heads=4)
# define data gen function

View File

@@ -104,27 +104,22 @@ def build_model_from_hybrid_plugin(model_fn: Callable, loss_fn: Callable, test_c
if 'use_lazy_init' in test_config:
use_lazy_init = test_config.pop('use_lazy_init')
if use_lazy_init:
ctx = LazyInitContext()
else:
ctx = nullcontext()
plugin = HybridParallelPlugin(**test_config)
booster = Booster(plugin=plugin)
ctx = LazyInitContext() if use_lazy_init else nullcontext()
with ctx:
org_model = model_fn().cuda()
org_model = model_fn()
sharded_model = copy.deepcopy(org_model)
if use_lazy_init:
org_model = ctx.materialize(org_model)
ctx.materialize(org_model)
org_model = org_model.cuda()
org_optimizer = Adam(org_model.parameters(), lr=1e-3)
sharded_optimizer = Adam(sharded_model.parameters(), lr=1e-3)
criterion = loss_fn
sharded_model, sharded_optimizer, criterion, _, _ = booster.boost(sharded_model, sharded_optimizer, criterion)
plugin = HybridParallelPlugin(**test_config)
booster = Booster(plugin=plugin)
sharded_model, sharded_optimizer, criterion, _, _ = booster.boost(sharded_model, sharded_optimizer, criterion)
return org_model, org_optimizer, sharded_model, sharded_optimizer, criterion, booster
@@ -142,11 +137,12 @@ def run_forward_backward_with_hybrid_plugin(org_model: Module, sharded_model: Mo
data = data_gen_fn()
sharded_model.train()
if booster.plugin.stage_manager is not None:
data = {
k: v.to('cuda').repeat(*([4] + [1] *
(v.dim() - 1))) if torch.is_tensor(v) or 'Tensor' in v.__class__.__name__ else v
for k, v in data.items()
}
for k, v in data.items():
if torch.is_tensor(v) or 'Tensor' in v.__class__.__name__:
new_shape = [1] * v.dim()
new_shape[0] = 4
data[k] = v.to('cuda').repeat(*new_shape)
data_iter = iter([data])
sharded_output = booster.execute_pipeline(data_iter,
sharded_model,
@@ -176,7 +172,8 @@ def check_output_hidden_state(org_output: Tensor,
sharded_output: Tensor,
stage_manager: Optional[PipelineStageManager] = None,
atol: float = 1e-5,
rtol: float = 1e-3):
rtol: float = 1e-3,
dim: int = 0):
org_hidden_state = org_output.last_hidden_state
@@ -184,7 +181,7 @@ def check_output_hidden_state(org_output: Tensor,
sharded_hidden_state = sharded_output.last_hidden_state
if stage_manager and stage_manager.is_last_stage():
sharded_hidden_state = torch.cat([output.last_hidden_state for output in sharded_output['outputs']], dim=0)
sharded_hidden_state = torch.cat([output.last_hidden_state for output in sharded_output['outputs']], dim=dim)
assert torch.allclose(org_hidden_state.float(), sharded_hidden_state.float(), atol=atol, rtol=rtol), \
f"shard model's output hidden state is not equal to origin model's last hidden state\n{org_hidden_state}\n{sharded_hidden_state}"

View File

@@ -3,57 +3,101 @@ import torch
import colossalai
from colossalai.logging import disable_existing_loggers
from colossalai.testing import (
assert_hf_output_close,
clear_cache_before_run,
parameterize,
rerun_if_address_is_in_use,
spawn,
)
from colossalai.tensor.d_tensor.api import clear_layout_converter
from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
from tests.kit.model_zoo import model_zoo
from tests.test_shardformer.test_model._utils import build_model, check_grad, check_state_dict, run_forward
from tests.test_shardformer.test_model._utils import (
build_model_from_hybrid_plugin,
check_grad,
check_loss,
check_output_hidden_state,
check_weight,
run_forward_backward_with_hybrid_plugin,
)
def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn):
# check forward
org_output, org_loss, shard_output, shard_loss = run_forward(org_model, sharded_model, data_gen_fn,
output_transform_fn, loss_fn)
assert_hf_output_close(org_output, shard_output, ignore_keys=['past_key_values'])
def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config):
# do backward
org_loss.backward()
shard_loss.backward()
org_model, org_optimizer, sharded_model, sharded_optimizer, criterion, booster = \
build_model_from_hybrid_plugin(model_fn, loss_fn, test_config)
assert torch.allclose(org_loss, shard_loss,
atol=1e-6), f"shard model loss is not equal to orgin model loss\n{org_loss}\n{shard_loss}"
org_loss, org_output, sharded_loss, sharded_output = \
run_forward_backward_with_hybrid_plugin(
org_model,
sharded_model,
sharded_optimizer,
data_gen_fn,
output_transform_fn,
criterion,
booster)
stage_manager = booster.plugin.stage_manager
tp_group = booster.plugin.tp_group
# check last hidden state & loss
if stage_manager is None or stage_manager.is_last_stage():
if org_model.__class__.__name__ == 'BloomModel':
check_output_hidden_state(org_output, sharded_output, stage_manager, atol=1e-5, rtol=1e-3)
check_loss(org_loss, sharded_loss, atol=1e-6, rtol=1e-3)
# unwrap model
if org_model.__class__.__name__ == 'BloomModel':
bloom = org_model
sharded_bloom = sharded_model
sharded_bloom = sharded_model.unwrap()
else:
bloom = org_model.transformer
sharded_bloom = sharded_model.transformer
sharded_bloom = sharded_model.unwrap().transformer
# check grad
col_layer_for_check = ['h[0].self_attention.query_key_value']
row_layer_for_check = ['h[0].self_attention.dense']
check_grad(bloom, sharded_bloom, col_layer_for_check, atol=1e-6, rtol=1e-5, dim=0, verbose=False)
check_grad(bloom, sharded_bloom, row_layer_for_check, atol=1e-6, rtol=1e-5, dim=1, verbose=False)
row_layer_for_check = ['h[0].self_attention.query_key_value', 'word_embeddings']
col_layer_for_check = ['h[0].self_attention.dense']
if stage_manager is None or stage_manager.is_first_stage():
check_grad(bloom, sharded_bloom, row_layer_for_check, tp_group, atol=1e-6, rtol=1e-5, dim=0, verbose=False)
check_grad(bloom, sharded_bloom, col_layer_for_check, tp_group, atol=1e-6, rtol=1e-5, dim=1, verbose=False)
# check weights after optimizer.step()
org_optimizer.step()
sharded_optimizer.step()
if stage_manager is None or stage_manager.is_first_stage():
check_weight(bloom, sharded_bloom, col_layer_for_check, tp_group, atol=1e-4, rtol=1e-3, dim=1, verbose=False)
torch.cuda.empty_cache()
@parameterize('enable_fused_normalization', [True, False])
@parameterize('enable_tensor_parallelism', [True, False])
@parameterize('enable_flash_attention', [True, False])
@parameterize('enable_jit_fused', [True, False])
@parameterize('use_lazy_init', [False, True])
def run_bloom_test(enable_fused_normalization, enable_tensor_parallelism, enable_flash_attention, enable_jit_fused,
use_lazy_init):
@parameterize('test_config', [{
'tp_size': 2,
'pp_size': 2,
'num_microbatches': 4,
'enable_fused_normalization': True,
'use_lazy_init': True
}, {
'tp_size': 1,
'pp_size': 2,
'num_microbatches': 4,
'enable_fused_normalization': False,
'use_lazy_init': False
}, {
'tp_size': 4,
'pp_size': 1,
'enable_fused_normalization': True,
'use_lazy_init': False
}])
def run_bloom_test(test_config):
# TODO: add test_config for TP+DP after supporting & debugging it
# {'tp_size': 2, 'pp_size': 1, 'enable_fused_normalization': True}
# TODO: add test_config for flash attention & jit operator after supporting
sub_model_zoo = model_zoo.get_sub_registry('transformers_bloom')
test_config['precision'] = 'float' # Do not use fp16/bf16 in testing
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism,
enable_flash_attention, enable_jit_fused, use_lazy_init)
check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn)
check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)
clear_layout_converter()
torch.cuda.empty_cache()
@@ -67,7 +111,7 @@ def check_bloom(rank, world_size, port):
@rerun_if_address_is_in_use()
@clear_cache_before_run()
def test_bloom():
spawn(check_bloom, 2)
spawn(check_bloom, 4)
if __name__ == "__main__":

View File

@@ -1,90 +0,0 @@
import pytest
import torch
import colossalai
from colossalai.cluster import ProcessGroupMesh
from colossalai.logging import disable_existing_loggers
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer.policies.auto_policy import get_autopolicy
from colossalai.shardformer.policies.base_policy import Policy
from colossalai.shardformer.shard import ShardConfig
from colossalai.tensor.d_tensor.api import is_customized_distributed_tensor, is_distributed_tensor
from colossalai.testing import (
assert_hf_output_close,
clear_cache_before_run,
parameterize,
rerun_if_address_is_in_use,
spawn,
)
from tests.kit.model_zoo import model_zoo
from tests.test_shardformer.test_model._utils import build_model, build_pipeline_model, run_forward
def check_bloom_model_policy(name, model: torch.nn.Module, stage_manager: PipelineStageManager):
policy = get_autopolicy(model)
policy.set_model(model)
model_config = ShardConfig(pipeline_stage_manager=stage_manager, enable_tensor_parallelism=False)
policy.set_shard_config(model_config)
layers = policy.get_held_layers()
if stage_manager.is_first_stage():
assert len(layers) == 0 + 2
else:
if name == 'transformers_bloom':
assert len(layers) == 1 + 1
elif name == 'transformers_bloom_for_token_classification':
assert len(layers) == 1 + 3
else:
assert len(layers) == 1 + 2
def check_bloom_model_pipeline_forward(name, sharded_model, stage_manager: PipelineStageManager):
if stage_manager.stage == 0:
x = torch.randint(0, 1000, (1, 3)).cuda()
attention_mask = torch.ones_like(x).cuda()
output = sharded_model(input_ids=x, attention_mask=attention_mask)
assert output['hidden_states'].shape == (1, 3, 64)
else:
attention_mask = torch.ones((1, 3)).cuda()
hidden_states = torch.randint(0, 1000, (1, 3, 64)).to(torch.float32).cuda()
output = sharded_model(
hidden_states=hidden_states,
attention_mask=attention_mask,
)
assert output[0].shape[0] == 1
@parameterize('enable_fused_normalization', [False])
@parameterize('enable_tensor_parallelism', [False])
@parameterize('use_lazy_init', [False])
#TODO: merge this into test_shard_bloom
def run_bloom_test(enable_fused_normalization, enable_tensor_parallelism, use_lazy_init):
PP_DIM = 0
PP_SIZE = 2
pg_mesh = ProcessGroupMesh(PP_SIZE)
stage_manager = PipelineStageManager(pg_mesh, PP_DIM)
sub_model_zoo = model_zoo.get_sub_registry('transformers_bloom')
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
org_model, sharded_model = build_pipeline_model(model_fn, stage_manager, enable_fused_normalization,
enable_tensor_parallelism, use_lazy_init)
check_bloom_model_policy(name, org_model, stage_manager)
check_bloom_model_pipeline_forward(name, sharded_model, stage_manager)
torch.cuda.empty_cache()
def check_bloom(rank, world_size, port):
disable_existing_loggers()
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
run_bloom_test()
@pytest.mark.dist
@rerun_if_address_is_in_use()
@clear_cache_before_run()
def test_bloom():
spawn(check_bloom, 2)
if __name__ == "__main__":
test_bloom()

View File

@@ -1,99 +1,126 @@
import copy
import os
import pytest
import torch
from torch import distributed as dist
import colossalai
from colossalai.logging import disable_existing_loggers
from colossalai.shardformer import ShardConfig, ShardFormer
from colossalai.shardformer.policies.chatglm import ChatGLMForConditionalGenerationPolicy, ChatGLMModelPolicy
from colossalai.tensor.d_tensor.api import is_customized_distributed_tensor, is_distributed_tensor
from colossalai.testing import (
assert_hf_output_close,
clear_cache_before_run,
parameterize,
rerun_if_address_is_in_use,
spawn,
)
from colossalai.tensor.d_tensor.api import clear_layout_converter
from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
from tests.kit.model_zoo import model_zoo
from tests.test_shardformer.test_model._utils import build_model, run_forward
from tests.test_shardformer.test_model._utils import (
build_model_from_hybrid_plugin,
check_grad,
check_loss,
check_output_hidden_state,
check_weight,
run_forward_backward_with_hybrid_plugin,
)
def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn):
# check forward
org_output, org_loss, shard_output, shard_loss = run_forward(org_model, sharded_model, data_gen_fn,
output_transform_fn, loss_fn)
assert_hf_output_close(org_output, shard_output, ignore_keys=['past_key_values'])
# do backward
org_loss.backward()
shard_loss.backward()
def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config):
assert torch.allclose(org_loss, shard_loss,
atol=1e-5), f"shard model loss is not equal to orgin model loss\n{org_loss}\n{shard_loss}"
org_model, org_optimizer, sharded_model, sharded_optimizer, criterion, booster = \
build_model_from_hybrid_plugin(model_fn, loss_fn, test_config)
org_loss, org_output, sharded_loss, sharded_output = \
run_forward_backward_with_hybrid_plugin(
org_model,
sharded_model,
sharded_optimizer,
data_gen_fn,
output_transform_fn,
criterion,
booster)
stage_manager = booster.plugin.stage_manager
tp_group = booster.plugin.tp_group
# check last hidden state & loss
if stage_manager is None or stage_manager.is_last_stage():
if org_model.__class__.__name__ == 'ChatGLMModel':
check_output_hidden_state(org_output, sharded_output, stage_manager, atol=1e-5, rtol=1e-3, dim=1)
check_loss(org_loss, sharded_loss, atol=1e-6, rtol=1e-3)
# unwrap model
if org_model.__class__.__name__ == 'ChatGLMModel':
chatglm_model = org_model
shard_chatglm_model = sharded_model
shard_chatglm_model = sharded_model.unwrap()
else:
chatglm_model = org_model.transformer
shard_chatglm_model = sharded_model.transformer
shard_chatglm_model = sharded_model.unwrap().transformer
# check attention grad
org_grad = chatglm_model.encoder.layers[0].self_attention.query_key_value.weight.grad
shard_grad = shard_chatglm_model.encoder.layers[0].self_attention.query_key_value.weight.grad
shard_weight = shard_chatglm_model.encoder.layers[0].self_attention.query_key_value.weight
# check grad
row_layer_for_check = ['encoder.layers[0].self_attention.query_key_value', 'embedding.word_embeddings']
col_layer_for_check = ['encoder.layers[0].self_attention.dense']
if stage_manager is None or stage_manager.is_first_stage():
check_grad(chatglm_model,
shard_chatglm_model,
row_layer_for_check,
tp_group,
atol=1e-6,
rtol=1e-3,
dim=0,
verbose=False)
if is_distributed_tensor(shard_weight) or is_customized_distributed_tensor(shard_weight):
shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(2)]
shard_grad = torch.distributed.all_gather(shard_grad_list, shard_grad)
all_shard_grad = torch.cat(shard_grad_list, dim=0)
else:
all_shard_grad = shard_grad
assert torch.allclose(org_grad, all_shard_grad,
atol=1e-5), f"shard model grad is not equal to orgin model grad\n{org_grad}\n{shard_grad}"
check_grad(chatglm_model,
shard_chatglm_model,
col_layer_for_check,
tp_group,
atol=1e-6,
rtol=1e-3,
dim=1,
verbose=False)
# check embedding weights
org_grad = chatglm_model.embedding.word_embeddings.weight.grad
shard_grad = shard_chatglm_model.embedding.word_embeddings.weight.grad
shard_weight = shard_chatglm_model.embedding.word_embeddings.weight
# check weights after optimizer.step()
org_optimizer.step()
sharded_optimizer.step()
if stage_manager is None or stage_manager.is_first_stage():
check_weight(chatglm_model,
shard_chatglm_model,
col_layer_for_check,
tp_group,
atol=1e-4,
rtol=1e-3,
dim=1,
verbose=False)
if is_distributed_tensor(shard_weight) or is_customized_distributed_tensor(shard_weight):
shard_grad_list = [torch.zeros_like(shard_grad) for _ in range(2)]
torch.distributed.all_gather(shard_grad_list, shard_grad)
all_shard_grad = torch.cat(shard_grad_list, dim=0)
else:
all_shard_grad = shard_grad
assert torch.allclose(org_grad, all_shard_grad,
atol=1e-5), f"shard model grad is not equal to orgin model grad\n{org_grad}\n{all_shard_grad}"
torch.cuda.empty_cache()
@parameterize('enable_fused_normalization', [True, False])
@parameterize('enable_tensor_parallelism', [True, False])
@parameterize('enable_flash_attention', [True, False])
@parameterize('enable_jit_fused', [True, False])
def run_chatglm_test(enable_fused_normalization, enable_tensor_parallelism, enable_flash_attention, enable_jit_fused):
@parameterize('test_config', [{
'tp_size': 2,
'pp_size': 2,
'num_microbatches': 4,
'enable_fused_normalization': True,
'use_lazy_init': True
}, {
'tp_size': 1,
'pp_size': 2,
'num_microbatches': 4,
'enable_fused_normalization': False,
'use_lazy_init': False
}, {
'tp_size': 4,
'pp_size': 1,
'enable_fused_normalization': True,
'use_lazy_init': False
}])
def run_chatglm_test(test_config):
# TODO: add test_config for TP+DP after supporting & debugging it
# {'tp_size': 2, 'pp_size': 1, 'enable_fused_normalization': True}
# TODO: add test_config for flash attention & jit operator after supporting
sub_model_zoo = model_zoo.get_sub_registry('transformers_chatglm')
test_config['precision'] = 'float' # Do not use fp16/bf16 in testing
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
# create new model
org_model = model_fn().cuda()
check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)
# shard model
shard_config = ShardConfig(enable_fused_normalization=enable_fused_normalization,
enable_tensor_parallelism=enable_tensor_parallelism,
enable_flash_attention=enable_flash_attention,
enable_jit_fused=enable_jit_fused)
model_copy = copy.deepcopy(org_model)
shard_former = ShardFormer(shard_config=shard_config)
if name == "transformers_chatglm":
sharded_model, _ = shard_former.optimize(model_copy, ChatGLMModelPolicy())
else:
sharded_model, _ = shard_former.optimize(model_copy, ChatGLMForConditionalGenerationPolicy())
sharded_model = sharded_model.cuda()
check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn)
clear_layout_converter()
torch.cuda.empty_cache()
@@ -107,7 +134,7 @@ def check_chatglm(rank, world_size, port):
@rerun_if_address_is_in_use()
@clear_cache_before_run()
def test_chatglm():
spawn(check_chatglm, 2)
spawn(check_chatglm, 4)
if __name__ == "__main__":

View File

@@ -1,86 +0,0 @@
import copy
import os
import pytest
import torch
import colossalai
from colossalai.cluster import ProcessGroupMesh
from colossalai.logging import disable_existing_loggers
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer.policies.chatglm import ChatGLMForConditionalGenerationPolicy, ChatGLMModelPolicy
from colossalai.tensor.d_tensor.api import is_customized_distributed_tensor, is_distributed_tensor
from colossalai.testing import (
assert_hf_output_close,
clear_cache_before_run,
parameterize,
rerun_if_address_is_in_use,
spawn,
)
from tests.kit.model_zoo import model_zoo
from tests.test_shardformer.test_model._utils import build_model, build_pipeline_model, run_forward
@parameterize('enable_fused_normalization', [False])
@parameterize('enable_tensor_parallelism', [False])
@parameterize('use_lazy_init', [False])
def run_chatglm_test(enable_fused_normalization, enable_tensor_parallelism, use_lazy_init):
DP_DIM, PP_DIM = 0, 1
DP_SIZE, PP_SIZE = 2, 2
pg_mesh = ProcessGroupMesh(DP_SIZE, PP_SIZE)
stage_manager = PipelineStageManager(pg_mesh, PP_DIM)
sub_model_zoo = model_zoo.get_sub_registry('transformers_chatglm')
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
# create new model for test
inputs = data_gen_fn()
inputs = {k: v.cuda() for k, v in inputs.items()}
input_ids = inputs['input_ids']
hidden_size = 64
batch_size, seq_len = input_ids.shape
hidden_state_shape = (seq_len, batch_size, hidden_size)
if name == "transformers_chatglm":
_, sharded_model = build_pipeline_model(model_fn, stage_manager, enable_fused_normalization,
enable_tensor_parallelism, use_lazy_init, ChatGLMModelPolicy())
if stage_manager.is_last_stage():
hidden_states = torch.randn(*hidden_state_shape).cuda()
inputs['input_ids'] = None
inputs['hidden_states'] = hidden_states
outputs = sharded_model(**inputs)
if stage_manager.is_last_stage():
assert outputs[0].shape == hidden_state_shape
else:
assert outputs['hidden_states'].shape == hidden_state_shape
if name == "transformers_chatglm_for_conditional_generation":
_, sharded_model = build_pipeline_model(model_fn, stage_manager, enable_fused_normalization,
enable_tensor_parallelism, use_lazy_init,
ChatGLMForConditionalGenerationPolicy())
if stage_manager.is_last_stage():
hidden_states = torch.randn(*hidden_state_shape).cuda()
inputs['input_ids'] = None
inputs['hidden_states'] = hidden_states
outputs = sharded_model(**inputs)
if stage_manager.is_last_stage():
assert outputs[0].shape == (batch_size, seq_len, 65024)
else:
assert outputs['hidden_states'].shape == hidden_state_shape
torch.cuda.empty_cache()
def check_chatglm(rank, world_size, port):
disable_existing_loggers()
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
run_chatglm_test()
@pytest.mark.dist
@rerun_if_address_is_in_use()
@clear_cache_before_run()
def test_chatglm():
spawn(check_chatglm, 4)
if __name__ == "__main__":
test_chatglm()

View File

@@ -2,69 +2,139 @@ import os
import pytest
import torch
from torch import distributed as dist
import colossalai
from colossalai.logging import disable_existing_loggers
from colossalai.testing import (
assert_hf_output_close,
clear_cache_before_run,
parameterize,
rerun_if_address_is_in_use,
spawn,
)
from colossalai.tensor.d_tensor.api import clear_layout_converter
from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
from tests.kit.model_zoo import model_zoo
from tests.test_shardformer.test_model._utils import build_model, check_grad, check_state_dict, run_forward
from tests.test_shardformer.test_model._utils import (
build_model_from_hybrid_plugin,
check_grad,
check_loss,
check_output_hidden_state,
check_weight,
run_forward_backward_with_hybrid_plugin,
)
os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = 'true'
def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn):
org_output, org_loss, shard_output, shard_loss = run_forward(org_model, sharded_model, data_gen_fn,
output_transform_fn, loss_fn)
def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config):
# forward check
assert_hf_output_close(org_output, shard_output, ignore_keys=['past_key_values'], rtol=1e-5)
org_model, org_optimizer, sharded_model, sharded_optimizer, criterion, booster = \
build_model_from_hybrid_plugin(model_fn, loss_fn, test_config)
# run backward
org_loss.backward()
shard_loss.backward()
org_loss, org_output, sharded_loss, sharded_output = \
run_forward_backward_with_hybrid_plugin(
org_model,
sharded_model,
sharded_optimizer,
data_gen_fn,
output_transform_fn,
criterion,
booster)
assert torch.allclose(org_loss, shard_loss,
atol=1e-5), f"shard model loss is not equal to orgin model loss\n{org_loss}\n{shard_loss}"
stage_manager = booster.plugin.stage_manager
tp_group = booster.plugin.tp_group
# check last hidden state & loss
if stage_manager is None or stage_manager.is_last_stage():
if org_model.__class__.__name__ == 'LlamaModel':
check_output_hidden_state(org_output, sharded_output, stage_manager, atol=1e-5, rtol=1e-3)
check_loss(org_loss, sharded_loss, atol=1e-6, rtol=1e-3)
# unwrap model
if hasattr(org_model, 'model'):
llama_model = org_model.model
shard_llama_model = sharded_model.model
else:
if org_model.__class__.__name__ == 'LlamaModel':
llama_model = org_model
shard_llama_model = sharded_model
shard_llama_model = sharded_model.unwrap()
else:
llama_model = org_model.model
shard_llama_model = sharded_model.unwrap().model
# check grad
col_layer_for_check = ['layers[0].self_attn.q_proj', 'embed_tokens']
row_layer_for_check = ['layers[0].self_attn.o_proj']
check_grad(llama_model, shard_llama_model, col_layer_for_check, atol=1e-6, rtol=1e-4, dim=0, verbose=False)
check_grad(llama_model, shard_llama_model, row_layer_for_check, atol=1e-6, rtol=1e-4, dim=1, verbose=False)
row_layer_for_check = ['layers[0].self_attn.q_proj', 'embed_tokens']
col_layer_for_check = ['layers[0].self_attn.o_proj']
if stage_manager is None or stage_manager.is_first_stage():
check_grad(llama_model,
shard_llama_model,
row_layer_for_check,
tp_group,
atol=1e-6,
rtol=1e-4,
dim=0,
verbose=False)
check_grad(llama_model,
shard_llama_model,
col_layer_for_check,
tp_group,
atol=1e-6,
rtol=1e-4,
dim=1,
verbose=False)
# check weights after optimizer.step()
org_optimizer.step()
sharded_optimizer.step()
if stage_manager is None or stage_manager.is_first_stage():
check_weight(llama_model,
shard_llama_model,
col_layer_for_check,
tp_group,
atol=1e-4,
rtol=1e-3,
dim=1,
verbose=False)
torch.cuda.empty_cache()
@parameterize('enable_fused_normalization', [True, False])
@parameterize('enable_tensor_parallelism', [True, False])
@parameterize('enable_flash_attention', [True, False])
@parameterize('use_lazy_init', [False, True])
def run_gpt2_llama(enable_fused_normalization, enable_tensor_parallelism, enable_flash_attention, use_lazy_init):
@parameterize('test_config', [{
'tp_size': 2,
'pp_size': 2,
'num_microbatches': 2,
'enable_fused_normalization': True,
'use_lazy_init': True
}, {
'tp_size': 1,
'pp_size': 2,
'num_microbatches': 4,
'use_lazy_init': False
}, {
'tp_size': 4,
'pp_size': 1,
'enable_fused_normalization': True,
'use_lazy_init': False
}, {
'tp_size': 1,
'pp_size': 4,
'num_microbatches': 4,
'use_lazy_init': False
}])
def run_llama_test(test_config):
# TODO: add test_config for TP+DP after supporting & debugging it
# {'tp_size': 2, 'pp_size': 1, 'enable_fused_normalization': True}
# TODO: add test_config for flash attention & jit operator after supporting
sub_model_zoo = model_zoo.get_sub_registry('transformers_llama')
test_config['precision'] = 'float' # Do not use fp16/bf16 in testing
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism,
enable_flash_attention, use_lazy_init)
check_state_dict(org_model, sharded_model, name=name)
check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn)
check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)
clear_layout_converter()
torch.cuda.empty_cache()
def check_llama(rank, world_size, port):
disable_existing_loggers()
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
run_gpt2_llama()
run_llama_test()
@pytest.mark.dist

View File

@@ -1,89 +0,0 @@
import pytest
import torch
import colossalai
from colossalai.cluster import ProcessGroupMesh
from colossalai.logging import disable_existing_loggers
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer.policies.auto_policy import get_autopolicy
from colossalai.shardformer.policies.base_policy import Policy
from colossalai.shardformer.shard import ShardConfig
from colossalai.tensor.d_tensor.api import is_customized_distributed_tensor, is_distributed_tensor
from colossalai.testing import (
assert_hf_output_close,
clear_cache_before_run,
parameterize,
rerun_if_address_is_in_use,
spawn,
)
from tests.kit.model_zoo import model_zoo
from tests.test_shardformer.test_model._utils import build_model, build_pipeline_model, run_forward
def check_llama_model_policy(name, model: torch.nn.Module, stage_manager: PipelineStageManager):
policy = get_autopolicy(model)
policy.set_model(model)
model_config = ShardConfig(pipeline_stage_manager=stage_manager, enable_tensor_parallelism=False)
policy.set_shard_config(model_config)
layers = policy.get_held_layers()
if stage_manager.is_first_stage():
assert len(layers) == 2 + 1
else:
if name == "transformers_llama":
assert len(layers) == 2 + 1
else:
assert len(layers) == 2 + 2
def check_llama_model_pipeline_forward(name, sharded_model, stage_manager: PipelineStageManager):
x = torch.randint(0, 1000, (2, 3)).cuda()
if stage_manager.stage == 0:
attention_mask = torch.ones_like(x).cuda()
output = sharded_model(input_ids=x, attention_mask=attention_mask)
assert output['hidden_states'].shape == (2, 3, 128)
else:
hidden_states = torch.randint(0, 1000, (2, 3, 128)).to(torch.float32).cuda()
attention_mask = torch.ones((2, 3)).cuda()
output = sharded_model(
hidden_states=hidden_states,
attention_mask=attention_mask,
)
assert output[0] is not None
@parameterize('enable_fused_normalization', [False])
@parameterize('enable_tensor_parallelism', [False])
@parameterize('use_lazy_init', [False])
#TODO: merge this into test_shard_llama
def run_llama_test(enable_fused_normalization, enable_tensor_parallelism, use_lazy_init):
PP_DIM = 0
PP_SIZE = 2
pg_mesh = ProcessGroupMesh(PP_SIZE)
stage_manager = PipelineStageManager(pg_mesh, PP_DIM)
sub_model_zoo = model_zoo.get_sub_registry('transformers_llama')
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
org_model, sharded_model = build_pipeline_model(model_fn, stage_manager, enable_fused_normalization,
enable_tensor_parallelism, use_lazy_init)
check_llama_model_policy(name, org_model, stage_manager)
check_llama_model_pipeline_forward(name, sharded_model, stage_manager)
torch.cuda.empty_cache()
def check_llama(rank, world_size, port):
disable_existing_loggers()
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
run_llama_test()
@pytest.mark.dist
@rerun_if_address_is_in_use()
@clear_cache_before_run()
def test_llama():
spawn(check_llama, 2)
if __name__ == "__main__":
test_llama()

View File

@@ -1,64 +1,129 @@
import copy
import os
import pytest
import torch
from torch import distributed as dist
import colossalai
from colossalai.logging import disable_existing_loggers
from colossalai.testing import (
assert_hf_output_close,
clear_cache_before_run,
parameterize,
rerun_if_address_is_in_use,
spawn,
)
from colossalai.tensor.d_tensor.api import clear_layout_converter
from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
from tests.kit.model_zoo import model_zoo
from tests.test_shardformer.test_model._utils import build_model, check_grad, check_state_dict, run_forward
from tests.test_shardformer.test_model._utils import (
build_model_from_hybrid_plugin,
check_grad,
check_loss,
check_output_hidden_state,
check_weight,
run_forward_backward_with_hybrid_plugin,
)
os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = 'true'
def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn):
org_output, org_loss, shard_output, shard_loss = run_forward(org_model, sharded_model, data_gen_fn,
output_transform_fn, loss_fn)
assert_hf_output_close(org_output, shard_output, ignore_keys=['past_key_values'], rtol=1e-5)
def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config):
# run backward
org_loss.backward()
shard_loss.backward()
org_model, org_optimizer, sharded_model, sharded_optimizer, criterion, booster = \
build_model_from_hybrid_plugin(model_fn, loss_fn, test_config)
assert torch.allclose(org_loss, shard_loss,
atol=1e-5), f"shard model loss is not equal to orgin model loss\n{org_loss}\n{shard_loss}"
org_loss, org_output, sharded_loss, sharded_output = \
run_forward_backward_with_hybrid_plugin(
org_model,
sharded_model,
sharded_optimizer,
data_gen_fn,
output_transform_fn,
criterion,
booster)
stage_manager = booster.plugin.stage_manager
tp_group = booster.plugin.tp_group
# check last hidden state & loss
if stage_manager is None or stage_manager.is_last_stage():
if org_model.__class__.__name__ == 'OPTModel':
check_output_hidden_state(org_output, sharded_output, stage_manager, atol=1e-5, rtol=1e-3)
check_loss(org_loss, sharded_loss, atol=1e-6, rtol=1e-3)
# unwrap model
if hasattr(org_model, 'model'):
opt_model = org_model.model
shard_opt_model = sharded_model.model
else:
if org_model.__class__.__name__ == 'OPTModel':
opt_model = org_model
shard_opt_model = sharded_model
shard_opt_model = sharded_model.unwrap()
else:
opt_model = org_model.model
shard_opt_model = sharded_model.unwrap().model
# check grad
col_layer_for_check = ['decoder.layers[0].self_attn.q_proj', 'decoder.embed_tokens']
row_layer_for_check = ['decoder.layers[0].self_attn.out_proj']
check_grad(opt_model, shard_opt_model, col_layer_for_check, atol=1e-6, rtol=1e-3, dim=0, verbose=False)
check_grad(opt_model, shard_opt_model, row_layer_for_check, atol=1e-6, rtol=1e-3, dim=1, verbose=False)
row_layer_for_check = ['decoder.layers[0].self_attn.q_proj', 'decoder.embed_tokens']
col_layer_for_check = ['decoder.layers[0].self_attn.out_proj']
if stage_manager is None or stage_manager.is_first_stage():
check_grad(opt_model,
shard_opt_model,
row_layer_for_check,
tp_group,
atol=1e-6,
rtol=1e-3,
dim=0,
verbose=False)
check_grad(opt_model,
shard_opt_model,
col_layer_for_check,
tp_group,
atol=1e-6,
rtol=1e-3,
dim=1,
verbose=False)
# check weights after optimizer.step()
org_optimizer.step()
sharded_optimizer.step()
if stage_manager is None or stage_manager.is_first_stage():
check_weight(opt_model,
shard_opt_model,
col_layer_for_check,
tp_group,
atol=1e-3,
rtol=1e-3,
dim=1,
verbose=False)
torch.cuda.empty_cache()
@parameterize('use_lazy_init', [False, True])
@parameterize('enable_fused_normalization', [True, False])
@parameterize('enable_tensor_parallelism', [True, False])
@parameterize('enable_flash_attention', [True, False])
@parameterize('enable_jit_fused', [True, False])
def run_opt_test(use_lazy_init, enable_fused_normalization, enable_tensor_parallelism, enable_flash_attention,
enable_jit_fused):
@parameterize('test_config', [{
'tp_size': 2,
'pp_size': 2,
'num_microbatches': 4,
'enable_fused_normalization': True,
'use_lazy_init': True
}, {
'tp_size': 1,
'pp_size': 2,
'num_microbatches': 4,
'enable_fused_normalization': False,
'use_lazy_init': False
}, {
'tp_size': 4,
'pp_size': 1,
'enable_fused_normalization': True,
'use_lazy_init': False
}])
def run_opt_test(test_config):
# TODO: add test_config for TP+DP after supporting & debugging it
# {'tp_size': 2, 'pp_size': 1, 'enable_fused_normalization': True}
# TODO: add test_config for flash attention & jit operator after supporting
sub_model_zoo = model_zoo.get_sub_registry('transformers_opt')
test_config['precision'] = 'float' # Do not use fp16/bf16 in testing
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism,
enable_flash_attention, enable_jit_fused, use_lazy_init)
check_state_dict(org_model, sharded_model, name=name)
check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn)
check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)
clear_layout_converter()
torch.cuda.empty_cache()

View File

@@ -1,70 +0,0 @@
import pytest
import torch
import colossalai
from colossalai.cluster import ProcessGroupMesh
from colossalai.logging import disable_existing_loggers
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
from tests.kit.model_zoo import model_zoo
from tests.test_shardformer.test_model._utils import build_pipeline_model
def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn):
# TODO: add tests for forward/backward later
pass
@parameterize('enable_tensor_parallelism', [False])
@parameterize('enable_fused_normalization', [False])
@parameterize('use_lazy_init', [False])
#TODO: merge this into test_shard_opt
def run_opt_test(enable_fused_normalization, enable_tensor_parallelism, use_lazy_init):
DP_DIM, PP_DIM = 0, 1
DP_SIZE, PP_SIZE = 2, 2
pg_mesh = ProcessGroupMesh(DP_SIZE, PP_SIZE)
stage_manager = PipelineStageManager(pg_mesh, PP_DIM)
sub_model_zoo = model_zoo.get_sub_registry('transformers_opt')
for name, (model_fn, data_gen_fn, _, _, _) in sub_model_zoo.items():
inputs = data_gen_fn()
inputs = {k: v.cuda() for k, v in inputs.items()}
input_ids, _ = inputs['input_ids'], inputs['attention_mask']
batch_size, seq_len = input_ids.shape
hidden_size = 128
hidden_state_shape = (batch_size, seq_len, hidden_size)
if not stage_manager.is_first_stage():
# change inputs if not the first stage
hidden_states = torch.zeros(*hidden_state_shape).cuda()
inputs['input_ids'] = None
inputs['hidden_states'] = hidden_states
_, sharded_model = build_pipeline_model(model_fn, stage_manager, enable_fused_normalization,
enable_tensor_parallelism, use_lazy_init)
sharded_model.train()
output = sharded_model(**inputs)
if stage_manager.is_last_stage():
assert output[0] is not None
else:
assert output['hidden_states'].shape == hidden_state_shape
torch.cuda.empty_cache()
def check_opt(rank, world_size, port):
disable_existing_loggers()
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
run_opt_test()
@pytest.mark.dist
@rerun_if_address_is_in_use()
@clear_cache_before_run()
def test_opt():
spawn(check_opt, 4)
if __name__ == "__main__":
test_opt()

View File

@@ -1,60 +1,127 @@
import os
import pytest
import torch
import colossalai
from colossalai.logging import disable_existing_loggers
from colossalai.testing import (
assert_hf_output_close,
clear_cache_before_run,
parameterize,
rerun_if_address_is_in_use,
spawn,
)
from colossalai.shardformer.layer.utils import Randomizer
from colossalai.tensor.d_tensor.api import clear_layout_converter
from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
from tests.kit.model_zoo import model_zoo
from tests.test_shardformer.test_model._utils import build_model, check_grad, run_forward
from tests.test_shardformer.test_model._utils import (
build_model_from_hybrid_plugin,
check_grad,
check_loss,
check_output_hidden_state,
check_weight,
run_forward_backward_with_hybrid_plugin,
)
def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn):
# check forward
org_output, org_loss, shard_output, shard_loss = run_forward(org_model, sharded_model, data_gen_fn,
output_transform_fn, loss_fn)
def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config):
assert_hf_output_close(org_output, shard_output, atol=1e-3, rtol=1e-3)
org_model, org_optimizer, sharded_model, sharded_optimizer, criterion, booster = \
build_model_from_hybrid_plugin(model_fn, loss_fn, test_config)
# do backward
org_loss.backward()
shard_loss.backward()
org_loss, org_output, sharded_loss, sharded_output = \
run_forward_backward_with_hybrid_plugin(
org_model,
sharded_model,
sharded_optimizer,
data_gen_fn,
output_transform_fn,
criterion,
booster)
assert torch.allclose(org_loss, shard_loss,
atol=1e-5), f"shard model loss is not equal to orgin model loss\n{org_loss}\n{shard_loss}"
stage_manager = booster.plugin.stage_manager
tp_group = booster.plugin.tp_group
# check last hidden state & loss
if stage_manager is None or stage_manager.is_last_stage():
if org_model.__class__.__name__ == 'ViTModel':
check_output_hidden_state(org_output, sharded_output, stage_manager, atol=1e-5, rtol=1e-3)
check_loss(org_loss, sharded_loss, atol=1e-6, rtol=1e-3)
# unwrap model
if org_model.__class__.__name__ == 'ViTModel':
vit_model = org_model
shard_vit_model = sharded_model
shard_vit_model = sharded_model.unwrap()
else:
vit_model = org_model.vit
shard_vit_model = sharded_model.vit
shard_vit_model = sharded_model.unwrap().vit
# check grad
col_layer_for_check = ['encoder.layer[0].attention.attention.query']
row_layer_for_check = ['encoder.layer[0].attention.output.dense']
check_grad(vit_model, shard_vit_model, col_layer_for_check, atol=1e-5, rtol=1e-3, dim=0, verbose=False)
check_grad(vit_model, shard_vit_model, row_layer_for_check, atol=1e-5, rtol=1e-3, dim=1, verbose=False)
row_layer_for_check = ['encoder.layer[0].attention.attention.query', 'embeddings.patch_embeddings.projection']
col_layer_for_check = ['encoder.layer[0].attention.output.dense']
if stage_manager is None or stage_manager.is_first_stage():
check_grad(vit_model,
shard_vit_model,
row_layer_for_check,
tp_group,
atol=1e-5,
rtol=1e-3,
dim=0,
verbose=False)
check_grad(vit_model,
shard_vit_model,
col_layer_for_check,
tp_group,
atol=1e-5,
rtol=1e-3,
dim=1,
verbose=False)
# check weights after optimizer.step()
org_optimizer.step()
sharded_optimizer.step()
if stage_manager is None or stage_manager.is_first_stage():
check_weight(vit_model,
shard_vit_model,
col_layer_for_check,
tp_group,
atol=5e-3,
rtol=1e-3,
dim=1,
verbose=False)
torch.cuda.empty_cache()
@parameterize('enable_fused_normalization', [True, False])
@parameterize('enable_tensor_parallelism', [True, False])
@parameterize('enable_flash_attention', [True, False])
@parameterize('enable_jit_fused', [True, False])
def run_vit_test(enable_fused_normalization, enable_tensor_parallelism, enable_flash_attention, enable_jit_fused):
@parameterize('test_config', [{
'tp_size': 2,
'pp_size': 2,
'num_microbatches': 4,
'enable_fused_normalization': True,
'use_lazy_init': False
}, {
'tp_size': 1,
'pp_size': 2,
'num_microbatches': 4,
'enable_fused_normalization': False,
'use_lazy_init': False
}, {
'tp_size': 4,
'pp_size': 1,
'enable_fused_normalization': True,
'use_lazy_init': False
}])
def run_vit_test(test_config):
# TODO: add test_config for TP+DP after supporting & debugging it
# {'tp_size': 2, 'pp_size': 1, 'enable_fused_normalization': True}
# TODO: add test_config for flash attention & jit operator after supporting
# TODO: fix bug when settign lazy_init for Conv2D Layers in ViT models
sub_model_zoo = model_zoo.get_sub_registry('transformers_vit')
test_config['precision'] = 'float' # Do not use fp16/bf16 in testing
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism,
enable_flash_attention, enable_jit_fused)
check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn)
check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)
clear_layout_converter()
Randomizer.reset_index()
torch.cuda.empty_cache()
@@ -68,7 +135,7 @@ def check_vit(rank, world_size, port):
@rerun_if_address_is_in_use()
@clear_cache_before_run()
def test_vit():
spawn(check_vit, 2)
spawn(check_vit, 4)
if __name__ == "__main__":

View File

@@ -1,74 +0,0 @@
import pytest
import torch
import colossalai
from colossalai.cluster import ProcessGroupMesh
from colossalai.logging import disable_existing_loggers
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
from tests.kit.model_zoo import model_zoo
from tests.test_shardformer.test_model._utils import build_pipeline_model
def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn):
# TODO: add tests for forward/backward later
pass
@parameterize('enable_tensor_parallelism', [False])
@parameterize('enable_fused_normalization', [False])
@parameterize('use_lazy_init', [False])
#TODO: merge this into test_shard_vit
def run_vit_test(enable_fused_normalization, enable_tensor_parallelism, use_lazy_init):
DP_DIM, PP_DIM = 0, 1
DP_SIZE, PP_SIZE = 2, 2
pg_mesh = ProcessGroupMesh(DP_SIZE, PP_SIZE)
stage_manager = PipelineStageManager(pg_mesh, PP_DIM)
sub_model_zoo = model_zoo.get_sub_registry('transformers_vit')
for name, (model_fn, data_gen_fn, _, _, _) in sub_model_zoo.items():
inputs = data_gen_fn()
inputs = {k: v.cuda() for k, v in inputs.items()}
pixel_values = inputs['pixel_values']
batch_size = len(pixel_values)
hidden_size = 768
hidden_state_shape = (batch_size, 197, hidden_size)
if not stage_manager.is_first_stage():
# change inputs if not the first stage
hidden_states = torch.randn(*hidden_state_shape).cuda()
# inputs['pixel_values'] = None
inputs['hidden_states'] = hidden_states
_, sharded_model = build_pipeline_model(model_fn, stage_manager, enable_fused_normalization,
enable_tensor_parallelism, use_lazy_init)
sharded_model.train()
output = sharded_model(**inputs)
if stage_manager.is_last_stage():
if name != 'transformers_vit':
assert output.loss is not None
else:
assert output['hidden_states'].shape == hidden_state_shape, \
f'hidden_states shape is not correct, output:{output["hidden_states"].shape} is not equal to hidden_state:{hidden_state_shape}'
torch.cuda.empty_cache()
def check_vit(rank, world_size, port):
disable_existing_loggers()
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
run_vit_test()
@pytest.mark.dist
@rerun_if_address_is_in_use()
@clear_cache_before_run()
def test_vit():
spawn(check_vit, 4)
if __name__ == "__main__":
test_vit()