[autoparallel] integrate_gpt_related_tests (#2134)

* [autoparallel] integrate_gpt_related_tests

* polish code

* polish code

* add GPT2Model into runtime test
Authored by YuliangLiu0306 on 2022-12-23 12:36:59 +08:00, committed by GitHub
parent 59e343328d
commit 550f8f8905
5 changed files with 217 additions and 207 deletions

View File

@@ -230,7 +230,12 @@ def _size_value_converting(gm: torch.fx.GraphModule, device_mesh: DeviceMesh):
         new_slice_items = []
         for slice_item in getitem_index:
+            if slice_item is None:
+                new_slice_items.append(None)
+                continue
             new_start, new_stop, new_step = slice_item.start, slice_item.stop, slice_item.step
             if slice_item.start in node_pairs:
                 new_start = node_pairs[slice_item.start]
             elif slice_item.stop in node_pairs:
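The added guard matters because a traced getitem index can contain bare None entries (from indexing patterns such as x[:, None]), which have no .start, .stop or .step attributes. A minimal standalone sketch of the remapping idea, assuming a node_pairs mapping from original size nodes to converted ones (simplified: it looks up start and stop independently rather than with the elif chain the pass uses):

node_pairs = {}    # hypothetical: original size node -> converted node

def remap_getitem_index(getitem_index):
    new_slice_items = []
    for slice_item in getitem_index:
        # None shows up for patterns like x[:, None] and must be passed through untouched
        if slice_item is None:
            new_slice_items.append(None)
            continue
        new_start = node_pairs.get(slice_item.start, slice_item.start)
        new_stop = node_pairs.get(slice_item.stop, slice_item.stop)
        new_slice_items.append(slice(new_start, new_stop, slice_item.step))
    return tuple(new_slice_items)

print(remap_getitem_index((slice(0, 8), None)))    # (slice(0, 8, None), None)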
@@ -355,7 +360,10 @@ def _module_params_sharding(gm: torch.fx.GraphModule, device_mesh: DeviceMesh):
     for node in nodes:
         if node.op == 'call_module':
             target_module = node.graph.owning_module.get_submodule(node.target)
+            # TODO: we need to do more actions to take care of the shared parameters.
+            if hasattr(target_module, 'processed') and target_module.processed:
+                continue
+            setattr(target_module, 'processed', True)
             for name, param in target_module.named_parameters():
                 target_sharding_spec = node.best_strategy.get_sharding_spec_by_name(name)
                 # apply the sharding spec of parameters
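The processed flag guards against sharding the same module twice when it is reached from more than one call_module node, for example tied or shared submodules. A minimal sketch of the idea, with a hypothetical shard_module_params helper standing in for the loop body above:

import torch.nn as nn

def shard_module_params(target_module: nn.Module):
    # skip modules whose parameters were already sharded via another node
    if getattr(target_module, 'processed', False):
        return
    target_module.processed = True
    # ... apply each parameter's sharding spec here ...

shared = nn.Linear(8, 8)
shard_module_params(shared)
shard_module_params(shared)    # second visit is a no-op, parameters are sharded only once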
@@ -404,7 +412,9 @@ def _module_params_sharding(gm: torch.fx.GraphModule, device_mesh: DeviceMesh):
                 target_module = root
                 target = getattr(root, atoms[0])
             else:
-                target_module = root.get_submodule(atoms[-2])
+                target_module = root
+                for atom in atoms[:-1]:
+                    target_module = getattr(target_module, atom)
                 target = getattr(target_module, atoms[-1])
             target_sharding_spec = node.sharding_spec
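The old lookup root.get_submodule(atoms[-2]) resolves only the second-to-last atom against the root, which breaks for targets nested more than one level deep; walking every atom reaches the real owner. A standalone illustration with a hypothetical toy module (not the pass itself):

import torch.nn as nn

root = nn.Module()
root.h = nn.ModuleList([nn.Linear(4, 4)])
atoms = 'h.0.weight'.split('.')

# old behaviour: get_submodule(atoms[-2]) asks root for a submodule named '0'
try:
    owner = root.get_submodule(atoms[-2])
except AttributeError as exc:
    print('get_submodule failed:', exc)

# new behaviour: walk every atom except the last to reach the true owner
owner = root
for atom in atoms[:-1]:
    owner = getattr(owner, atom)
print(getattr(owner, atoms[-1]).shape)    # torch.Size([4, 4])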

View File

@@ -2,32 +2,30 @@ from typing import Optional, Tuple, Union

 import torch
 import torch.nn as nn
-import transformers
-from torch.fx import GraphModule
-from transformers.models.gpt2.modeling_gpt2 import (
-    GPT2MLP,
-    BaseModelOutputWithPastAndCrossAttentions,
-    GPT2PreTrainedModel,
-)
+from transformers.activations import ACT2FN
+from transformers.models.gpt2.modeling_gpt2 import BaseModelOutputWithPastAndCrossAttentions, GPT2PreTrainedModel
 from transformers.pytorch_utils import Conv1D

-from colossalai.auto_parallel.tensor_shard.constants import BATCHNORM_MODULE_OP
-from colossalai.auto_parallel.tensor_shard.solver import (
-    CostGraph,
-    GraphAnalyser,
-    Solver,
-    SolverOptions,
-    StrategiesConstructor,
-)
-from colossalai.device.device_mesh import DeviceMesh
-from colossalai.fx.tracer.tracer import ColoTracer
-from colossalai.tensor.shape_consistency import ShapeConsistencyManager
-from colossalai.testing import parameterize
-from colossalai.testing.pytest_wrapper import run_on_environment_flag
-
-BATCH_SIZE = 1
-SEQ_LENGTH = 32
-HIDDEN_DIM = 768
+
+class GPT2MLP(nn.Module):
+
+    def __init__(self, intermediate_size, config):
+        super().__init__()
+        embed_dim = config.hidden_size
+        self.c_fc = Conv1D(intermediate_size, embed_dim)
+        self.c_proj = Conv1D(embed_dim, intermediate_size)
+        self.act = ACT2FN[config.activation_function]
+        # We temporarily banned the Dropout layer because the rng state need
+        # to process to get the correct result.
+        # self.dropout = nn.Dropout(config.resid_pdrop)
+
+    def forward(self, hidden_states: Optional[Tuple[torch.FloatTensor]]) -> torch.FloatTensor:
+        hidden_states = self.c_fc(hidden_states)
+        hidden_states = self.act(hidden_states)
+        hidden_states = self.c_proj(hidden_states)
+        # TODO: the rng state need to be fixed for distributed runtime
+        # hidden_states = self.dropout(hidden_states)
+        return hidden_states


 # The reason Why we don't import GPT2Attention from transformers directly is that:
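A note on the commented-out dropout layers: the runtime test added later in this commit compares a single-process reference run against a sharded run, bracketing both with saved RNG state. A sketch of that bracketing pattern follows (hypothetical helper, assuming a CUDA device is present); even with it, a dropout mask drawn over a sharded activation would not match element-for-element the mask the reference model draws over the full tensor, which is why dropout stays disabled for now:

import torch

def run_with_restored_rng(fn, *args):
    # snapshot both CPU and CUDA RNG state so a later call sees identical randomness
    cuda_state = torch.cuda.get_rng_state()
    cpu_state = torch.get_rng_state()
    out = fn(*args)
    torch.cuda.set_rng_state(cuda_state)
    torch.set_rng_state(cpu_state)
    return out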
@@ -89,7 +87,7 @@ class GPT2Attention(nn.Module):
         # Downcast (if necessary) back to V's dtype (if in mixed-precision) -- No-Op otherwise
         attn_weights = attn_weights.type(value.dtype)
-        attn_weights = self.attn_dropout(attn_weights)
+        # attn_weights = self.attn_dropout(attn_weights)

         # Mask heads if we want to
         if head_mask is not None:
@@ -125,15 +123,10 @@ class GPT2Attention(nn.Module):
             present = (key, value)

         attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask)
         attn_output = self._merge_heads(attn_output, self.num_heads, self.head_dim)
         attn_output = self.c_proj(attn_output)
-        attn_output = self.resid_dropout(attn_output)
-
-        outputs = (attn_output, present)
-        outputs += (attn_weights,)
-
-        return outputs    # a, present, (attentions)
+        # attn_output = self.resid_dropout(attn_output)
+
+        return attn_output


 class GPT2Block(nn.Module):
@@ -161,19 +154,15 @@ class GPT2Block(nn.Module):
             attention_mask=attention_mask,
             head_mask=head_mask,
         )
-        attn_output = attn_outputs[0]    # output_attn: a, present, (attentions)
-        outputs = attn_outputs[1:]
         # residual connection
-        hidden_states = attn_output + residual
+        hidden_states = attn_outputs + residual

         residual = hidden_states
         hidden_states = self.ln_2(hidden_states)
         feed_forward_hidden_states = self.mlp(hidden_states)
         # residual connection
         hidden_states = residual + feed_forward_hidden_states

-        outputs = (hidden_states,) + outputs[1:]
-
-        return outputs    # hidden_states, present, (attentions, cross_attentions)
+        return hidden_states


 class GPT2Model(GPT2PreTrainedModel):
@@ -228,103 +217,25 @@ class GPT2Model(GPT2PreTrainedModel):
         # attention_probs has shape bsz x n_heads x N x N
         # head_mask has shape n_layer x batch x n_heads x N x N
         head_mask = self.get_head_mask(head_mask, self.config.n_layer)

         inputs_embeds = self.wte(input_ids)
         position_embeds = self.wpe(position_ids)
         # add_2
         hidden_states = inputs_embeds + position_embeds

         token_type_embeds = self.wte(token_type_ids)
         hidden_states = hidden_states + token_type_embeds
-        # transformer_drop
-        hidden_states = self.drop(hidden_states)
         # comment to run pipeline
         # add_3
         output_shape = input_shape + (hidden_states.size(-1),)

-        presents = None
-        all_self_attentions = None
-        all_cross_attentions = None
-        all_hidden_states = None
         for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
             outputs = block(hidden_states, attention_mask=attention_mask, head_mask=head_mask[i])
-            hidden_states = outputs[0]
+            hidden_states = outputs

         hidden_states = self.ln_f(hidden_states)
         # comment to run pipeline
         hidden_states = hidden_states.view(output_shape)

-        return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions, all_cross_attentions]
-                     if v is not None)
-
-
-@run_on_environment_flag(name='AUTO_PARALLEL')
-@parameterize('model_cls', [GPT2Block, GPT2Attention, GPT2MLP, GPT2Model])
-def test_self_attention_block(model_cls):
-    config = transformers.GPT2Config(n_position=64, n_layer=4, n_head=16, n_embd=HIDDEN_DIM)
-    if model_cls == GPT2MLP:
-        model = model_cls(intermediate_size=4 * config.hidden_size, config=config)
-    else:
-        model = model_cls(config=config)
-    physical_mesh_id = torch.arange(0, 4)
-    mesh_shape = (2, 2)
-    # [[0, 1]
-    #  [2, 3]]
-    device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
-    shape_consistency_manager = ShapeConsistencyManager()
-
-    tracer = ColoTracer()
-    if model_cls == GPT2MLP:
-        input_sample = {
-            'hidden_states': torch.rand(BATCH_SIZE, SEQ_LENGTH, HIDDEN_DIM).to('meta'),
-        }
-    elif model_cls in (GPT2Attention, GPT2Block):
-        input_sample = {
-            'hidden_states': torch.rand(BATCH_SIZE, SEQ_LENGTH, HIDDEN_DIM).to('meta'),
-            'attention_mask': torch.rand(1, SEQ_LENGTH).to('meta'),
-        }
-    else:
-        input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
-        token_type_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
-        attention_mask = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
-        kwargs = dict(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask)
-        input_sample = {k: v.to('meta') for k, v in kwargs.items()}
-
-    graph = tracer.trace(root=model, meta_args=input_sample)
-    gm = GraphModule(model, graph, model.__class__.__name__)
-    print(gm.graph)
-    gm.recompile()
-
-    graph_analyser = GraphAnalyser(gm)
-    liveness_list = graph_analyser.liveness_analysis()
-    solver_options = SolverOptions()
-    strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options)
-    strategies_constructor.build_strategies_and_cost()
-
-    cost_graph = CostGraph(strategies_constructor.leaf_strategies)
-    cost_graph.simplify_graph()
-    solver = Solver(gm.graph, strategies_constructor, cost_graph, graph_analyser, memory_budget=-1)
-    ret = solver.call_solver_serialized_args()
-    strategies_list = solver.last_s_val
-    nodes = [strategies_vector.node for strategies_vector in strategies_constructor.leaf_strategies]
-
-    computation_cost = 0
-    communication_cost = 0
-    memory_cost = 0
-    for index, node in enumerate(nodes):
-        print(node.name, node.strategies_vector[strategies_list[index]].name)
-        computation_cost += node.strategies_vector[strategies_list[index]].compute_cost.total
-        communication_cost += node.strategies_vector[strategies_list[index]].communication_cost.total
-        node_memory_cost = node.strategies_vector[strategies_list[index]].memory_cost.total
-        if isinstance(node_memory_cost, tuple):
-            node_memory_cost = node_memory_cost[0]
-        memory_cost += node_memory_cost.activation + node_memory_cost.parameter
-
-    print(f'computation cost is {computation_cost}')
-    print(f'communication cost is {communication_cost}')
-    print(f'memory cost is {memory_cost}')
-
-
-if __name__ == '__main__':
-    test_self_attention_block()

View File

@@ -1,7 +1,7 @@
 import copy
 import random
 from functools import partial
-from typing import Optional, Tuple, Union
+from typing import Dict, Optional, Tuple, Union

 import numpy as np
 import pytest
@@ -10,13 +10,11 @@ import torch.multiprocessing as mp
 import torch.nn as nn
 import transformers
 from torch.fx import GraphModule
-from transformers.activations import ACT2FN
-from transformers.models.gpt2.modeling_gpt2 import GPT2MLP
-from transformers.pytorch_utils import Conv1D

 from colossalai.auto_parallel.passes.runtime_apply_pass import runtime_apply_pass
 from colossalai.auto_parallel.passes.runtime_preparation_pass import runtime_preparation_pass
 from colossalai.auto_parallel.tensor_shard.constants import BATCHNORM_MODULE_OP
+from colossalai.auto_parallel.tensor_shard.sharding_strategy import ShardingSpec
 from colossalai.auto_parallel.tensor_shard.solver import (
     CostGraph,
     GraphAnalyser,
@@ -32,6 +30,7 @@ from colossalai.tensor.shape_consistency import ShapeConsistencyManager, to_global
 from colossalai.testing import assert_close, assert_close_loose, parameterize, rerun_if_address_is_in_use
 from colossalai.testing.pytest_wrapper import run_on_environment_flag
 from colossalai.utils import free_port
+from tests.test_auto_parallel.test_tensor_shard.test_gpt.gpt_modules import GPT2MLP, GPT2Attention, GPT2Block, GPT2Model

 BATCH_SIZE = 1
 SEQ_LENGTH = 32
@@ -46,36 +45,73 @@ torch.backends.cudnn.deterministic = True
 torch.backends.cudnn.benchmark = False


-class GPT2MLP(nn.Module):
-
-    def __init__(self, intermediate_size, config):
-        super().__init__()
-        embed_dim = config.hidden_size
-        self.c_fc = Conv1D(intermediate_size, embed_dim)
-        self.c_proj = Conv1D(embed_dim, intermediate_size)
-        self.act = ACT2FN[config.activation_function]
-        # We temporarily banned the Dropout layer because the rng state need
-        # to process to get the correct result.
-        # self.dropout = nn.Dropout(config.resid_pdrop)
-
-    def forward(self, hidden_states: Optional[Tuple[torch.FloatTensor]]) -> torch.FloatTensor:
-        hidden_states = self.c_fc(hidden_states)
-        hidden_states = self.act(hidden_states)
-        hidden_states = self.c_proj(hidden_states)
-        # TODO: the rng state need to be fixed for distributed runtime
-        # hidden_states = self.dropout(hidden_states)
-        return hidden_states
-
-
-def check_mlp_layer(rank, model_cls, world_size, port):
+def _check_module_grad(module: torch.nn.Module, origin_param_dict: Dict[str, torch.Tensor],
+                       best_sharding_spec_dict: Dict[str, ShardingSpec]):
+    for name, param in module.named_parameters():
+        param_grad = param.grad
+        origin_param_grad = origin_param_dict[name].grad
+        atoms = name.split('.')
+        new_name = '_'.join(atoms)
+        if new_name in best_sharding_spec_dict:
+            param_sharding_spec = best_sharding_spec_dict[new_name]
+            grad_to_compare = copy.deepcopy(param_grad)
+            param_grad_global = to_global(grad_to_compare, param_sharding_spec)
+
+            try:
+                assert_close_loose(param_grad_global, origin_param_grad, rtol=1e-03, atol=1e-03)
+            except:
+                difference = param_grad_global - origin_param_grad
+                avg_diff = difference.abs().sum() / difference.numel()
+                assert avg_diff < 0.001
+                print(f'{name} param has {avg_diff} average difference')
+
+
+def check_attention_layer(rank, model_cls, world_size, port):
     disable_existing_loggers()
     launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
-    config = transformers.GPT2Config(n_position=64, n_layer=4, n_head=16, n_embd=HIDDEN_DIM)
-    model = model_cls(intermediate_size=4 * config.hidden_size, config=config).to('cuda')
-    input = torch.rand(BATCH_SIZE, SEQ_LENGTH, HIDDEN_DIM).to('cuda')
+    config = transformers.GPT2Config(n_position=64, n_layer=1, n_head=16, n_embd=HIDDEN_DIM)
+
+    if model_cls == GPT2MLP:
+        model = model_cls(intermediate_size=4 * config.hidden_size, config=config).to('cuda')
+    else:
+        model = model_cls(config=config).to('cuda')
     test_model = copy.deepcopy(model)
-    test_input = copy.deepcopy(input)
+
+    input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
+    token_type_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
+    attention_mask = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
+    hidden_states = torch.rand((BATCH_SIZE, SEQ_LENGTH, HIDDEN_DIM), dtype=torch.float32)
+
+    if model_cls == GPT2MLP:
+        input_sample = (hidden_states.to('cuda'),)
+        test_input_sample = copy.deepcopy(input_sample)
+        meta_input_sample = {
+            'hidden_states': hidden_states.to('meta'),
+        }
+    elif model_cls in (GPT2Attention, GPT2Block):
+        input_sample = (
+            hidden_states.to('cuda'),
+            attention_mask.to('cuda'),
+        )
+        test_input_sample = copy.deepcopy(input_sample)
+        meta_input_sample = {
+            'hidden_states': hidden_states.to('meta'),
+            'attention_mask': attention_mask.to('meta'),
+        }
+    else:
+        input_sample = (
+            input_ids.to('cuda'),
+            token_type_ids.to('cuda'),
+            attention_mask.to('cuda'),
+        )
+        test_input_sample = copy.deepcopy(input_sample)
+        meta_input_sample = {
+            'input_ids': input_ids.to('meta'),
+            'token_type_ids': token_type_ids.to('meta'),
+            'attention_mask': attention_mask.to('meta'),
+        }

     physical_mesh_id = torch.arange(0, 4)
     mesh_shape = (2, 2)
     # [[0, 1]
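A brief gloss on the three input variants built above (minimal sketch with illustrative shapes; assumes a CUDA device): the tuple of real CUDA tensors feeds the transformed GraphModule, an independent deep copy feeds the untouched reference model, and the meta-device dict only carries shapes and dtypes for ColoTracer.

import copy
import torch

hidden_states = torch.rand(1, 32, 768)

input_sample = (hidden_states.to('cuda'),)                          # consumed by the sharded gm
test_input_sample = copy.deepcopy(input_sample)                     # consumed by the eager reference
meta_input_sample = {'hidden_states': hidden_states.to('meta')}     # shape/dtype only, for tracing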
@@ -85,15 +121,10 @@ def check_mlp_layer(rank, model_cls, world_size, port):
     tracer = ColoTracer()

-    input_sample = {
-        'hidden_states': torch.rand(BATCH_SIZE, SEQ_LENGTH, HIDDEN_DIM).to('meta'),
-    }
-
-    graph = tracer.trace(root=model, meta_args=input_sample)
-    print(graph)
+    graph = tracer.trace(root=model, meta_args=meta_input_sample)
     gm = GraphModule(model, graph, model.__class__.__name__)
     gm.recompile()
-    print(gm)
     graph_analyser = GraphAnalyser(gm)
     liveness_list = graph_analyser.liveness_analysis()
     solver_options = SolverOptions()
@@ -110,71 +141,35 @@ def check_mlp_layer(rank, model_cls, world_size, port):
         gm, solution, device_mesh, strategies_constructor)
     gm = runtime_apply_pass(gm)
     gm.recompile()
+    nodes = [strategies_vector.node for strategies_vector in strategies_constructor.leaf_strategies]
+    best_sharding_spec_dict = {}
+    for index, node in enumerate(nodes):
+        best_sharding_spec_dict[node.name] = node.sharding_spec

     cuda_rng_state = torch.cuda.get_rng_state()
     cpu_rng_state = torch.get_rng_state()
-    origin_output = test_model(test_input)
+    origin_output = test_model(*test_input_sample)
     torch.cuda.set_rng_state(cuda_rng_state)
     torch.set_rng_state(cpu_rng_state)
-    output = gm(input, sharding_spec_dict, origin_spec_dict, comm_actions_dict)
-    assert_close(output, origin_output, rtol=1e-03, atol=1e-04)
+    output = gm(*input_sample, sharding_spec_dict, origin_spec_dict, comm_actions_dict)
+    assert_close(output, origin_output, rtol=1e-03, atol=1e-03)

     #*******************backward starting*******************
     cuda_rng_state = torch.cuda.get_rng_state()
+    cpu_rng_state = torch.get_rng_state()
     output.sum().backward()
+    torch.set_rng_state(cpu_rng_state)
     torch.cuda.set_rng_state(cuda_rng_state)
     origin_output.sum().backward()
     origin_param_dict = dict(test_model.named_parameters())

     if rank == 0:
         print("*******************backward starting*******************")
-        for name, param in model.named_parameters():
-            param_grad = param.grad
-            origin_param_grad = origin_param_dict[name].grad
-            origin_param_size = origin_param_grad.shape[-1]
-            print(name, param_grad, origin_param_grad)
-            if name == 'c_fc.bias':
-                assert_close_loose(param_grad,
-                                   origin_param_grad.narrow(0, 0, origin_param_size // 2),
-                                   rtol=1e-03,
-                                   atol=1e-03)
-            else:
-                assert_close_loose(param_grad, origin_param_grad, rtol=1e-03, atol=1e-03)
+
+    _check_module_grad(gm, origin_param_dict, best_sharding_spec_dict)
+
+    if rank == 0:
         print("*******************backward finished*******************")
-
-    if rank == 1:
-        for name, param in model.named_parameters():
-            param_grad = param.grad
-            origin_param_grad = origin_param_dict[name].grad
-            origin_param_size = origin_param_grad.shape[-1]
-            if name == 'c_fc.bias':
-                assert_close_loose(param_grad,
-                                   origin_param_grad.narrow(0, origin_param_size // 2, origin_param_size // 2),
-                                   rtol=1e-03,
-                                   atol=1e-03)
-            else:
-                assert_close_loose(param_grad, origin_param_grad, rtol=1e-03, atol=1e-03)
-
-    if rank == 2:
-        for name, param in model.named_parameters():
-            param_grad = param.grad
-            origin_param_grad = origin_param_dict[name].grad
-            origin_param_size = origin_param_grad.shape[-1]
-            if name == 'c_fc.bias':
-                assert_close_loose(param_grad,
-                                   origin_param_grad.narrow(0, 0, origin_param_size // 2),
-                                   rtol=1e-03,
-                                   atol=1e-03)
-            else:
-                assert_close_loose(param_grad, origin_param_grad, rtol=1e-03, atol=1e-03)
-
-    if rank == 3:
-        for name, param in model.named_parameters():
-            param_grad = param.grad
-            origin_param_grad = origin_param_dict[name].grad
-            origin_param_size = origin_param_grad.shape[-1]
-            if name == 'c_fc.bias':
-                assert_close_loose(param_grad,
-                                   origin_param_grad.narrow(0, origin_param_size // 2, origin_param_size // 2),
-                                   rtol=1e-03,
-                                   atol=1e-03)
-            else:
-                assert_close_loose(param_grad, origin_param_grad, rtol=1e-03, atol=1e-03)

     #*******************backward finished*******************
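The bridge between the solver output and the gradient check is purely name based: the keys of best_sharding_spec_dict are FX node names, which flatten the dotted parameter path with underscores, while named_parameters() keeps the dots, so _check_module_grad rejoins the atoms with '_' before looking up the sharding spec. A tiny sketch with illustrative values (not taken from a real run):

param_name = 'h.0.attn.c_attn.weight'              # dotted form from named_parameters()
node_name = '_'.join(param_name.split('.'))        # 'h_0_attn_c_attn_weight', the node-name form

best_sharding_spec_dict = {node_name: 'sharding spec recorded for this node'}
assert '_'.join(param_name.split('.')) in best_sharding_spec_dict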
@@ -202,11 +197,11 @@ def check_mlp_layer(rank, model_cls, world_size, port):

 @run_on_environment_flag(name='AUTO_PARALLEL')
 @pytest.mark.dist
-@parameterize('model_cls', [GPT2MLP])
+@parameterize('model_cls', [GPT2MLP, GPT2Block, GPT2Attention, GPT2Model])
 @rerun_if_address_is_in_use()
 def test_mlp_layer(model_cls):
     world_size = 4
-    run_func = partial(check_mlp_layer, model_cls=model_cls, world_size=world_size, port=free_port())
+    run_func = partial(check_attention_layer, model_cls=model_cls, world_size=world_size, port=free_port())
     mp.spawn(run_func, nprocs=world_size)
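For context, neither test runs by default: both are wrapped in run_on_environment_flag(name='AUTO_PARALLEL'), and the runtime test additionally spawns world_size=4 NCCL workers for the (2, 2) mesh. A hedged invocation sketch (the exact flag value expected by run_on_environment_flag is an assumption here):

import os
import torch

os.environ['AUTO_PARALLEL'] = '1'    # assumed truthy form for the AUTO_PARALLEL gate
assert torch.cuda.device_count() >= 4, 'the runtime test expects a 4-GPU node'
# test_mlp_layer()                   # parameterized over GPT2MLP, GPT2Block, GPT2Attention, GPT2Model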

View File

@@ -0,0 +1,94 @@
+import torch
+import torch.nn as nn
+import transformers
+from torch.fx import GraphModule
+
+from colossalai.auto_parallel.tensor_shard.constants import BATCHNORM_MODULE_OP
+from colossalai.auto_parallel.tensor_shard.solver import (
+    CostGraph,
+    GraphAnalyser,
+    Solver,
+    SolverOptions,
+    StrategiesConstructor,
+)
+from colossalai.device.device_mesh import DeviceMesh
+from colossalai.fx.tracer.tracer import ColoTracer
+from colossalai.tensor.shape_consistency import ShapeConsistencyManager
+from colossalai.testing import parameterize
+from colossalai.testing.pytest_wrapper import run_on_environment_flag
+from tests.test_auto_parallel.test_tensor_shard.test_gpt.gpt_modules import GPT2MLP, GPT2Attention, GPT2Block, GPT2Model
+
+BATCH_SIZE = 1
+SEQ_LENGTH = 32
+HIDDEN_DIM = 768
+
+
+@run_on_environment_flag(name='AUTO_PARALLEL')
+@parameterize('model_cls', [GPT2Block, GPT2Attention, GPT2MLP, GPT2Model])
+def test_self_attention_block(model_cls):
+    config = transformers.GPT2Config(n_position=64, n_layer=4, n_head=16, n_embd=HIDDEN_DIM)
+    if model_cls == GPT2MLP:
+        model = model_cls(intermediate_size=4 * config.hidden_size, config=config)
+    else:
+        model = model_cls(config=config)
+    physical_mesh_id = torch.arange(0, 4)
+    mesh_shape = (2, 2)
+    # [[0, 1]
+    #  [2, 3]]
+    device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
+    shape_consistency_manager = ShapeConsistencyManager()
+
+    tracer = ColoTracer()
+    if model_cls == GPT2MLP:
+        input_sample = {
+            'hidden_states': torch.rand(BATCH_SIZE, SEQ_LENGTH, HIDDEN_DIM).to('meta'),
+        }
+    elif model_cls in (GPT2Attention, GPT2Block):
+        input_sample = {
+            'hidden_states': torch.rand(BATCH_SIZE, SEQ_LENGTH, HIDDEN_DIM).to('meta'),
+            'attention_mask': torch.rand(1, SEQ_LENGTH).to('meta'),
+        }
+    else:
+        input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
+        token_type_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
+        attention_mask = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
+        kwargs = dict(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask)
+        input_sample = {k: v.to('meta') for k, v in kwargs.items()}
+
+    graph = tracer.trace(root=model, meta_args=input_sample)
+    gm = GraphModule(model, graph, model.__class__.__name__)
+    print(gm.graph)
+    gm.recompile()
+
+    graph_analyser = GraphAnalyser(gm)
+    liveness_list = graph_analyser.liveness_analysis()
+    solver_options = SolverOptions()
+    strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options)
+    strategies_constructor.build_strategies_and_cost()
+
+    cost_graph = CostGraph(strategies_constructor.leaf_strategies)
+    cost_graph.simplify_graph()
+    solver = Solver(gm.graph, strategies_constructor, cost_graph, graph_analyser, memory_budget=-1)
+    ret = solver.call_solver_serialized_args()
+    strategies_list = solver.last_s_val
+    nodes = [strategies_vector.node for strategies_vector in strategies_constructor.leaf_strategies]
+
+    computation_cost = 0
+    communication_cost = 0
+    memory_cost = 0
+    for index, node in enumerate(nodes):
+        print(node.name, node.strategies_vector[strategies_list[index]].name)
+        computation_cost += node.strategies_vector[strategies_list[index]].compute_cost.total
+        communication_cost += node.strategies_vector[strategies_list[index]].communication_cost.total
+        node_memory_cost = node.strategies_vector[strategies_list[index]].memory_cost.total
+        if isinstance(node_memory_cost, tuple):
+            node_memory_cost = node_memory_cost[0]
+        memory_cost += node_memory_cost.activation + node_memory_cost.parameter
+
+    print(f'computation cost is {computation_cost}')
+    print(f'communication cost is {communication_cost}')
+    print(f'memory cost is {memory_cost}')
+
+
+if __name__ == '__main__':
+    test_self_attention_block()