Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-07-19 09:51:18 +00:00)
[autoparallel] record parameter attribute in colotracer (#2217)
* [autoparallel] record parameter attribute in colotracer
* [autoparallel] fix construct_meta_info bug
parent 92de90dfb3
commit 3b1b91eaf4
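
The motivation for the tracer change below: PyTorch ops such as transpose, reshape, and Tensor.view return an ordinary Tensor even when their input is an nn.Parameter, so the parameter attribute is lost from the meta data the tracer records. A minimal, self-contained sketch of that behavior (illustrative only, not part of the commit):

    import torch

    # transpose/reshape/view return a plain Tensor even for an nn.Parameter input,
    # so the "this tensor is a parameter" information disappears on meta tensors too.
    weight = torch.nn.Parameter(torch.empty(8, 16, device="meta"))
    print(isinstance(torch.transpose(weight, 0, 1), torch.nn.Parameter))  # False
    print(isinstance(weight.view(16, 8), torch.nn.Parameter))             # False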
@@ -174,8 +174,8 @@ def _shape_consistency_apply(gm: torch.fx.GraphModule):
                                             runtime_apply,
                                             args=(node, origin_dict_node, input_dict_node,
                                                   node_to_index_dict[node], user_node_index))
-            meta_info = construct_meta_info(node, user_node)
-            setattr(shape_consistency_node, 'best_metainfo', meta_info)
+            # meta_info = construct_meta_info(node, user_node)
+            # setattr(shape_consistency_node, 'best_metainfo', meta_info)
 
             new_args = list(user_node.args)
             new_kwargs = dict(user_node.kwargs)
@@ -229,6 +229,15 @@ class ColoTracer(Tracer):
         args_metas, kwargs_metas = extract_meta(*args, **kwargs)
 
         if kind == "call_function":
+            # Our meta data will not record the nn.parameter.Parameter attribute.
+            # It works fine in most cases, but it may cause some problems after
+            # the bias addition manipulation.
+            # Therefore, I need to record the nn.parameter.Parameter attribute for the operation
+            # added by the bias addition manipulation following the get_attr node.
+            convert_to_parameter = False
+            if target in (torch.transpose, torch.reshape) and isinstance(args_metas[0],
+                                                                          torch.nn.parameter.Parameter):
+                convert_to_parameter = True
             # fetch patched function
             if meta_patched_function.has(target):
                 meta_target = meta_patched_function.get(target)
@@ -241,7 +250,18 @@ class ColoTracer(Tracer):
             meta_out = meta_target(*args_metas, **kwargs_metas)
             if isinstance(meta_out, torch.Tensor):
                 meta_out = meta_out.to(device="meta")
+                if convert_to_parameter:
+                    meta_out = torch.nn.Parameter(meta_out)
+
         elif kind == "call_method":
+            # Our meta data will not record the nn.parameter.Parameter attribute.
+            # It works fine in most cases, but it may cause some problems after
+            # the bias addition manipulation.
+            # Therefore, I need to record the nn.parameter.Parameter attribute for the operation
+            # added by the bias addition manipulation following the get_attr node.
+            convert_to_parameter = False
+            if target in (torch.Tensor.view,) and isinstance(args_metas[0], torch.nn.parameter.Parameter):
+                convert_to_parameter = True
             method = getattr(args_metas[0].__class__, target)
 
             # fetch patched method
@@ -251,6 +271,8 @@ class ColoTracer(Tracer):
                 meta_target = method
 
             meta_out = meta_target(*args_metas, **kwargs_metas)
+            if convert_to_parameter:
+                meta_out = torch.nn.Parameter(meta_out)
         elif kind == "call_module":
             if not hasattr(self, "orig_forward"):
                 raise AttributeError(f"{self} does not have an attribute called orig_forward")
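
Taken together, both branches follow the same pattern: decide up front whether the result should keep its nn.Parameter attribute, run the target on meta tensors, then re-wrap the meta output. A condensed, runnable sketch of that pattern, simplified and with illustrative names (it is not the actual ColoTracer implementation):

    import torch

    # Ops for which an nn.Parameter input should yield an nn.Parameter meta output.
    PARAM_PRESERVING_FUNCTIONS = (torch.transpose, torch.reshape)
    PARAM_PRESERVING_METHOD_NAMES = ("view",)

    def run_target_on_meta(kind, target, args_metas, kwargs_metas):
        convert_to_parameter = False
        if kind == "call_function":
            if target in PARAM_PRESERVING_FUNCTIONS and isinstance(args_metas[0], torch.nn.Parameter):
                convert_to_parameter = True
            meta_out = target(*args_metas, **kwargs_metas)
        elif kind == "call_method":
            if target in PARAM_PRESERVING_METHOD_NAMES and isinstance(args_metas[0], torch.nn.Parameter):
                convert_to_parameter = True
            method = getattr(args_metas[0].__class__, target)
            meta_out = method(*args_metas, **kwargs_metas)
        else:
            raise ValueError(f"unsupported node kind: {kind}")
        if isinstance(meta_out, torch.Tensor):
            meta_out = meta_out.to(device="meta")
            if convert_to_parameter:
                meta_out = torch.nn.Parameter(meta_out)
        return meta_out

    # A transposed weight keeps its Parameter attribute when traced on meta tensors.
    weight = torch.nn.Parameter(torch.empty(8, 16, device="meta"))
    out = run_target_on_meta("call_function", torch.transpose, (weight, 0, 1), {})
    print(type(out), out.shape)  # <class 'torch.nn.parameter.Parameter'> torch.Size([16, 8])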
@@ -35,13 +35,14 @@ from colossalai.testing.pytest_wrapper import run_on_environment_flag
 from colossalai.utils import free_port
 from tests.test_auto_parallel.test_tensor_shard.test_gpt.gpt_modules import GPT2LMHeadModel, GPTLMLoss
 
-BATCH_SIZE = 128
-SEQ_LENGTH = 128
-HIDDEN_DIM = 4096
-NUM_HEADS = 32
+BATCH_SIZE = 32
+SEQ_LENGTH = 256
+HIDDEN_DIM = 16384
+NUM_HEADS = 128
 NUM_LAYERS = 4
 VOCAB_SIZE = 50257
 NUM_STEPS = 10
+FP16 = True
 
 
 def get_cpu_mem():
@@ -57,7 +58,8 @@ def get_mem_info(prefix=''):
 
 
 def get_tflops(model_numel, batch_size, seq_len, step_time):
-    return model_numel * batch_size * seq_len * 8 / 1e12 / (step_time + 1e-12)
+    # Tflops_per_GPU = global_batch * global_numel * seq_len * 8 / #gpu
+    return model_numel * batch_size * seq_len * 8 / 1e12 / (step_time + 1e-12) / 4
 
 
 # Randomly Generated Data
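
The updated get_tflops follows the formula in the new comment and divides by 4, which appears to correspond to the number of GPUs in the benchmark's device mesh. A quick sanity check with the new config values; the 12 * n_layer * hidden**2 parameter estimate and the 4-GPU assumption are illustrative, not taken from the commit:

    # Sanity check of the updated get_tflops using the new config values above.
    HIDDEN_DIM, NUM_LAYERS, VOCAB_SIZE = 16384, 4, 50257
    BATCH_SIZE, SEQ_LENGTH = 32, 256

    def get_tflops(model_numel, batch_size, seq_len, step_time):
        # Tflops_per_GPU = global_batch * global_numel * seq_len * 8 / #gpu
        return model_numel * batch_size * seq_len * 8 / 1e12 / (step_time + 1e-12) / 4

    # Rough transformer parameter count: 12 * n_layer * hidden^2 plus the embedding table.
    approx_numel = 12 * NUM_LAYERS * HIDDEN_DIM ** 2 + VOCAB_SIZE * HIDDEN_DIM
    print(f"~{approx_numel / 1e9:.1f}B parameters")  # ~13.7B
    print(f"~{get_tflops(approx_numel, BATCH_SIZE, SEQ_LENGTH, 1.0):.0f} TFLOPS/GPU at 1 s/step")  # ~225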
@@ -72,8 +74,11 @@ def main():
     launch_from_torch(config={})
     logger = get_dist_logger()
     config = transformers.GPT2Config(n_position=SEQ_LENGTH, n_layer=NUM_LAYERS, n_head=NUM_HEADS, n_embd=HIDDEN_DIM)
-    model = GPT2LMHeadModel(config=config).to('cuda')
+    if FP16:
+        model = GPT2LMHeadModel(config=config).half().to('cuda')
+    else:
+        model = GPT2LMHeadModel(config=config).to('cuda')
+    global_numel = sum([p.numel() for p in model.parameters()])
 
     input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
     attention_mask = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
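
With FP16 enabled, the model is converted with .half() before being moved to the GPU, so weights are stored in 2 bytes instead of 4. A back-of-the-envelope check, assuming the ~13.7B-parameter estimate above (not a figure reported by the commit):

    # Rough weight-memory estimate; 13.7e9 is the assumed parameter count from above.
    approx_numel = 13.7e9
    print(f"fp32 weights: ~{approx_numel * 4 / 1e9:.0f} GB")  # ~55 GB
    print(f"fp16 weights: ~{approx_numel * 2 / 1e9:.0f} GB")  # ~27 GB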
@@ -108,6 +113,7 @@ def main():
     ret = solver.call_solver_serialized_args()
 
     solution = list(ret[0])
+    # solution = [0, 0, 0, 0, 0, 0, 0, 0, 0, 4, 2, 13, 8, 9, 0, 2, 0, 0, 0, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 6, 12, 8, 8, 8, 0, 0, 20, 12, 12, 12, 6, 6, 6, 6, 2, 6, 0, 0, 4, 0, 0, 0, 4, 0, 4, 3, 3, 12, 3, 3, 8, 8, 8, 8, 8, 8, 8, 8, 3, 8, 2, 2, 11, 4, 4, 0, 0, 2, 0, 0, 0, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 6, 12, 8, 8, 8, 0, 0, 20, 12, 12, 12, 6, 6, 6, 6, 2, 6, 0, 0, 4, 0, 0, 0, 4, 0, 4, 3, 3, 12, 3, 3, 8, 8, 8, 8, 8, 8, 8, 8, 3, 8, 2, 2, 11, 4, 4, 0, 0, 2, 0, 0, 0, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 6, 12, 8, 8, 8, 0, 0, 20, 12, 12, 12, 6, 6, 6, 6, 2, 6, 0, 0, 4, 0, 0, 0, 4, 0, 4, 3, 3, 12, 3, 3, 8, 8, 8, 8, 8, 8, 8, 8, 3, 8, 2, 2, 11, 4, 4, 0, 0, 2, 0, 0, 0, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 6, 12, 8, 8, 8, 0, 0, 20, 12, 12, 12, 6, 6, 6, 6, 2, 6, 0, 0, 4, 0, 0, 0, 4, 0, 4, 3, 3, 12, 3, 3, 8, 8, 8, 8, 8, 8, 8, 8, 3, 8, 2, 2, 11, 4, 4, 9, 0, 0, 8, 0]
     print(solution)
     gm, sharding_spec_dict, origin_spec_dict, comm_actions_dict = runtime_preparation_pass(
         gm, solution, device_mesh, strategies_constructor)
@@ -125,9 +131,8 @@ def main():
     criterion = GPTLMLoss()
 
     optimizer = torch.optim.Adam(gm.parameters(), lr=0.01)
-    numel = sum([p.numel() for p in model.parameters()])
     logger.info(get_mem_info(prefix='After init model, '), ranks=[0])
-    get_tflops_func = partial(get_tflops, numel, BATCH_SIZE, SEQ_LENGTH)
+    get_tflops_func = partial(get_tflops, global_numel, BATCH_SIZE, SEQ_LENGTH)
     torch.cuda.synchronize()
     model.train()
     # with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
@@ -102,13 +102,11 @@ def check_attention_layer(rank, model_cls, world_size, port):
     else:
         input_sample = (
             input_ids.to('cuda'),
-            token_type_ids.to('cuda'),
             attention_mask.to('cuda'),
         )
     test_input_sample = copy.deepcopy(input_sample)
     meta_input_sample = {
         'input_ids': input_ids.to('meta'),
-        'token_type_ids': token_type_ids.to('meta'),
         'attention_mask': attention_mask.to('meta'),
     }
 
@@ -50,9 +50,8 @@ def test_self_attention_block(model_cls):
         }
     else:
         input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
-        token_type_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
         attention_mask = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
-        kwargs = dict(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask)
+        kwargs = dict(input_ids=input_ids, attention_mask=attention_mask)
         input_sample = {k: v.to('meta') for k, v in kwargs.items()}
 
     graph = tracer.trace(root=model, meta_args=input_sample)
@@ -130,7 +130,10 @@ def check_addmm_function_handler(rank, input_shape, model_cls, world_size, port):
 
     assert mapping['other'].name == "transpose"
     assert mapping['other'].data.shape == torch.Size([16, 8])
-    assert mapping['other'].type == OperationDataType.ARG
+    if model_cls == AddmmModel:
+        assert mapping['other'].type == OperationDataType.ARG
+    else:
+        assert mapping['other'].type == OperationDataType.PARAM
     assert mapping['other'].logical_shape == torch.Size([8, 16])
 
     assert mapping['output'].name == "linear"
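
The adjusted assertion reflects the tracer change: when the model stores the weight as an nn.Parameter, its transposed meta output now stays a Parameter, so the handler can classify the operand as PARAM instead of ARG. A tiny sketch of that classification (the enum below is a stand-in defined here for illustration, not ColossalAI's OperationDataType):

    import enum

    import torch

    # Stand-in enum for illustration only; ColossalAI defines its own OperationDataType.
    class OpDataType(enum.Enum):
        ARG = "arg"
        PARAM = "param"

    def classify_operand(meta_tensor: torch.Tensor) -> OpDataType:
        # With the tracer change, meta outputs that originate from an nn.Parameter
        # (e.g. a transposed weight) remain Parameters and can be typed as PARAM.
        return OpDataType.PARAM if isinstance(meta_tensor, torch.nn.Parameter) else OpDataType.ARG

    plain_weight = torch.empty(16, 8, device="meta")
    param_weight = torch.nn.Parameter(torch.empty(16, 8, device="meta"))
    print(classify_operand(plain_weight))  # OpDataType.ARG
    print(classify_operand(param_weight))  # OpDataType.PARAM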