[autoparallel]add embedding handler (#2089)

* [autoparallel] add embedding handler

* fix bugs
This commit is contained in:
YuliangLiu0306
2022-12-07 09:41:46 +08:00
committed by GitHub
parent 1fca5d79ea
commit 7f72eb0510
6 changed files with 844 additions and 8 deletions

View File

@@ -0,0 +1,286 @@
from functools import partial
import pytest
import torch
import torch.multiprocessing as mp
import torch.nn as nn
from colossalai.auto_parallel.tensor_shard.node_handler.embedding_handler import (
EmbeddingFunctionHandler,
EmbeddingModuleHandler,
)
from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector
from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx import ColoGraphModule, ColoTracer
from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers
from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use
from colossalai.testing.pytest_wrapper import run_on_environment_flag
from colossalai.utils import free_port
from tests.test_auto_parallel.test_tensor_shard.test_node_handler.utils import numerical_test_for_node_strategy
NUM_EMBEDDINGS = 16
EMBEDDING_DIMS = 32
class EmbeddingModule(nn.Module):
def __init__(self, num_embeddings, embedding_dims):
super().__init__()
self.embedding = nn.Embedding(num_embeddings, embedding_dims)
def forward(self, input):
x = self.embedding(input)
return x
def check_embedding_module_handler(rank, world_size, port):
disable_existing_loggers()
launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
model = EmbeddingModule(num_embeddings=NUM_EMBEDDINGS, embedding_dims=EMBEDDING_DIMS).cuda()
# graph():
# %input_1 : torch.Tensor [#users=1] = placeholder[target=input]
# %embedding : [#users=1] = call_module[target=embedding](args = (%input_1,), kwargs = {})
# return embedding
input = torch.rand(4, 16, 16) * NUM_EMBEDDINGS
input = input.to(torch.int64).cuda()
physical_mesh_id = torch.arange(0, 4)
mesh_shape = (2, 2)
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)
# index of embedding node in computation graph
node_index = 1
# total number of embedding strategies
strategy_number = 19
numerical_test_for_node_strategy(model=model,
device_mesh=device_mesh,
node_index=node_index,
strategy_number=strategy_number,
input_args=[input],
meta_arg_names=['input'])
tracer = ColoTracer()
graph = tracer.trace(model, meta_args={"input": torch.rand(4, 16, 16).to('meta')})
gm = ColoGraphModule(model, graph)
embedding_node = list(graph.nodes)[1]
strategies_vector = StrategiesVector(embedding_node)
# build handler
handler = EmbeddingModuleHandler(node=embedding_node, device_mesh=device_mesh, strategies_vector=strategies_vector)
# check operation data mapping
mapping = handler.get_operation_data_mapping()
for name, op_data in mapping.items():
op_data: OperationData
# make sure they have valid values
assert op_data.logical_shape is not None
assert op_data.data is not None
assert mapping['input'].name == "input_1"
# assert mapping['input'].data.is_meta
assert mapping['input'].data.shape == torch.Size([4, 16, 16])
assert mapping['input'].type == OperationDataType.ARG
assert mapping['input'].logical_shape == torch.Size([1024])
assert mapping['other'].name == "weight"
assert mapping['other'].data.shape == torch.Size([NUM_EMBEDDINGS, EMBEDDING_DIMS])
assert mapping['other'].type == OperationDataType.PARAM
assert mapping['other'].logical_shape == torch.Size([NUM_EMBEDDINGS, EMBEDDING_DIMS])
assert mapping['output'].name == "embedding"
assert mapping['output'].data.shape == torch.Size([4, 16, 16, EMBEDDING_DIMS])
assert mapping['output'].type == OperationDataType.OUTPUT
assert mapping['output'].logical_shape == torch.Size([1024, EMBEDDING_DIMS])
strategies_vector = handler.register_strategy(compute_resharding_cost=False)
strategy_name_list = [val.name for val in strategies_vector]
# RR = RR x RR
assert 'RR = R x RR' in strategy_name_list
# SR = SR x RR
assert 'S0R = S0 x RR_0' in strategy_name_list
assert 'S0R = S0 x RR_1' in strategy_name_list
assert 'S0R = S0 x RR_2' in strategy_name_list
assert 'S1R = S1 x RR_0' in strategy_name_list
assert 'S1R = S1 x RR_1' in strategy_name_list
assert 'S1R = S1 x RR_2' in strategy_name_list
# SS = SR x RS
assert 'S0S1 = S0 x RS1_0' in strategy_name_list
assert 'S0S1 = S0 x RS1_1' in strategy_name_list
assert 'S0S1 = S0 x RS1_2' in strategy_name_list
assert 'S1S0 = S1 x RS0_0' in strategy_name_list
assert 'S1S0 = S1 x RS0_1' in strategy_name_list
assert 'S1S0 = S1 x RS0_2' in strategy_name_list
# RS= RR x RS
assert 'RS0 = R x RS0' in strategy_name_list
assert 'RS1 = R x RS1' in strategy_name_list
# S01R = S01R x RR
assert 'S01R = S01 x RR_0' in strategy_name_list
assert 'S01R = S01 x RR_1' in strategy_name_list
assert 'S01R = S01 x RR_2' in strategy_name_list
# RS01 = RR x RS01
assert 'RS01 = R x RS01' in strategy_name_list
for strategy in strategies_vector:
input_sharding_spec = strategy.get_sharding_spec_by_name('input_1')
weight_sharding_spec = strategy.get_sharding_spec_by_name('weight')
output_sharding_spec = strategy.get_sharding_spec_by_name('embedding')
# make sure the sharding matches across different operation data
assert output_sharding_spec.sharding_sequence[-1] == weight_sharding_spec.sharding_sequence[-1]
assert input_sharding_spec.sharding_sequence == output_sharding_spec.sharding_sequence[:-1]
class EmbeddingFunction(nn.Module):
def __init__(self):
super().__init__()
def forward(self, input, others):
x = nn.functional.embedding(input, others)
return x
def check_embedding_function_handler(rank, world_size, port):
disable_existing_loggers()
launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
model = EmbeddingFunction().cuda()
physical_mesh_id = torch.arange(0, 4)
mesh_shape = (2, 2)
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)
input = torch.rand(4, 16, 16) * NUM_EMBEDDINGS
input = input.to(torch.int64).cuda()
others = torch.rand(NUM_EMBEDDINGS, EMBEDDING_DIMS).cuda()
input_args = [input, others]
meta_arg_names = ['input', 'others']
input_kwargs = {}
# total number of embedding strategies
strategy_number = 19
node_index = 2
numerical_test_for_node_strategy(model=model,
device_mesh=device_mesh,
node_index=node_index,
strategy_number=strategy_number,
input_args=input_args,
meta_arg_names=meta_arg_names,
input_kwargs=input_kwargs)
tracer = ColoTracer()
# graph():
# %input_1 : torch.Tensor [#users=1] = placeholder[target=input]
# %others : torch.Tensor [#users=1] = placeholder[target=others]
# %embedding : [#users=1] = call_function[target=torch.nn.functional.embedding](args = (%input_1, %others), kwargs = {padding_idx: None, max_norm: None, norm_type: 2.0, scale_grad_by_freq: False, sparse: False})
# return embedding
meta_args = {
"input": torch.rand(4, 16, 16).to('meta'),
"others": torch.rand(NUM_EMBEDDINGS, EMBEDDING_DIMS).to('meta')
}
graph = tracer.trace(model, meta_args=meta_args)
gm = ColoGraphModule(model, graph)
embedding_node = list(graph.nodes)[2]
strategies_vector = StrategiesVector(embedding_node)
# build handler
handler = EmbeddingFunctionHandler(node=embedding_node,
device_mesh=device_mesh,
strategies_vector=strategies_vector)
# check operation data mapping
mapping = handler.get_operation_data_mapping()
for name, op_data in mapping.items():
op_data: OperationData
# make sure they have valid values
assert op_data.logical_shape is not None
assert op_data.data is not None
assert mapping['input'].name == "input_1"
assert mapping['input'].data.is_meta
assert mapping['input'].data.shape == torch.Size([4, 16, 16])
assert mapping['input'].type == OperationDataType.ARG
assert mapping['input'].logical_shape == torch.Size([1024])
assert mapping['other'].name == "others"
assert mapping['other'].data.is_meta
assert mapping['other'].data.shape == torch.Size([NUM_EMBEDDINGS, EMBEDDING_DIMS])
assert mapping['other'].type == OperationDataType.ARG
assert mapping['other'].logical_shape == torch.Size([NUM_EMBEDDINGS, EMBEDDING_DIMS])
assert mapping['output'].name == "embedding"
assert mapping['output'].data.is_meta
assert mapping['output'].data.shape == torch.Size([4, 16, 16, EMBEDDING_DIMS])
assert mapping['output'].type == OperationDataType.OUTPUT
assert mapping['output'].logical_shape == torch.Size([1024, EMBEDDING_DIMS])
handler.register_strategy(compute_resharding_cost=False)
strategy_name_list = [val.name for val in strategies_vector]
# RR = RR x RR
assert 'RR = R x RR' in strategy_name_list
# SR = SR x RR
assert 'S0R = S0 x RR_0' in strategy_name_list
assert 'S0R = S0 x RR_1' in strategy_name_list
assert 'S0R = S0 x RR_2' in strategy_name_list
assert 'S1R = S1 x RR_0' in strategy_name_list
assert 'S1R = S1 x RR_1' in strategy_name_list
assert 'S1R = S1 x RR_2' in strategy_name_list
# SS = SR x RS
assert 'S0S1 = S0 x RS1_0' in strategy_name_list
assert 'S0S1 = S0 x RS1_1' in strategy_name_list
assert 'S0S1 = S0 x RS1_2' in strategy_name_list
assert 'S1S0 = S1 x RS0_0' in strategy_name_list
assert 'S1S0 = S1 x RS0_1' in strategy_name_list
assert 'S1S0 = S1 x RS0_2' in strategy_name_list
# RS= RR x RS
assert 'RS0 = R x RS0' in strategy_name_list
assert 'RS1 = R x RS1' in strategy_name_list
# S01R = S01R x RR
assert 'S01R = S01 x RR_0' in strategy_name_list
assert 'S01R = S01 x RR_1' in strategy_name_list
assert 'S01R = S01 x RR_2' in strategy_name_list
# RS01 = RR x RS01
assert 'RS01 = R x RS01' in strategy_name_list
for strategy in strategies_vector:
input_sharding_spec = strategy.get_sharding_spec_by_name('input_1')
weight_sharding_spec = strategy.get_sharding_spec_by_name('others')
output_sharding_spec = strategy.get_sharding_spec_by_name('embedding')
# make sure the sharding matches across different operation data
assert output_sharding_spec.sharding_sequence[-1] == weight_sharding_spec.sharding_sequence[-1]
assert input_sharding_spec.sharding_sequence == output_sharding_spec.sharding_sequence[:-1]
@run_on_environment_flag(name='AUTO_PARALLEL')
@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_embedding_module_handler():
world_size = 4
run_func = partial(check_embedding_module_handler, world_size=world_size, port=free_port())
mp.spawn(run_func, nprocs=world_size)
@run_on_environment_flag(name='AUTO_PARALLEL')
@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_embedding_function_handler():
world_size = 4
run_func = partial(check_embedding_function_handler, world_size=world_size, port=free_port())
mp.spawn(run_func, nprocs=world_size)
if __name__ == '__main__':
test_embedding_module_handler()
test_embedding_function_handler()

View File

@@ -13,7 +13,7 @@ from colossalai.auto_parallel.tensor_shard.solver.solver import Solver
from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx.tracer.tracer import ColoTracer
from colossalai.tensor.shape_consistency import to_global
from colossalai.testing.comparison import assert_close, assert_close_loose
from colossalai.testing.comparison import assert_close
def _build_model_to_compare(model: torch.nn.Module, input_args: List[torch.Tensor],
@@ -32,8 +32,12 @@ def _build_model_to_compare(model: torch.nn.Module, input_args: List[torch.Tenso
param.register_hook(hook_fn)
arg_to_compare = copy.deepcopy(input_tensor)
arg_to_compare.requires_grad = True
wrapper(arg_to_compare, arg_index)
# only Tensors of floating point and complex dtype can require gradients
if arg_to_compare.dtype != torch.int64:
arg_to_compare.requires_grad = True
wrapper(arg_to_compare, arg_index)
args_to_compare.append(arg_to_compare)
for name, input_kwarg in input_kwargs.items():
@@ -46,8 +50,12 @@ def _build_model_to_compare(model: torch.nn.Module, input_args: List[torch.Tenso
param.register_hook(hook_fn)
kwarg_to_compare = copy.deepcopy(input_kwarg)
kwarg_to_compare.requires_grad = True
wrapper(kwarg_to_compare, name)
# only Tensors of floating point and complex dtype can require gradients
if kwarg_to_compare.dtype != torch.int64:
kwarg_to_compare.requires_grad = True
wrapper(kwarg_to_compare, name)
kwargs_to_compare[name] = kwarg_to_compare
return model_to_compare, args_to_compare, kwargs_to_compare
@@ -160,7 +168,6 @@ def assert_close_helper(first: torch.Tensor,
"""
This method is used to check whether the average difference between two tensors is as close as expected.
"""
# average_diff_tensor = ((first - second)/(second+0.1)).sum()/second.numel()
try:
if isinstance(first, (tuple, list)):
for first_element, second_element in zip(first, second):