[Tensor] add ColoTensor TP1Dcol Embedding (#899)
This commit is contained in: parent e46e423c00, commit 2c0d19d755
@@ -2,3 +2,4 @@ from .linear import colo_linear
 from .element_wise import *
 from .layernorm import colo_layernorm
 from .loss import colo_cross_entropy
+from .embedding import colo_embedding
colossalai/tensor/_ops/embedding.py (new file, 56 lines)
@@ -0,0 +1,56 @@
import torch
from colossalai.tensor.op_wrapper import colo_op_impl
from colossalai.context import ParallelMode
from colossalai.nn.layer.parallel_1d._utils import split_forward_gather_backward, reduce_input, \
    gather_forward_split_backward, reduce_grad
from colossalai.nn.layer.utils import divide
from colossalai.core import global_context as gpc
from packaging import version
from colossalai.tensor import ComputePattern, TensorSpec, ComputePattern, ParallelAction, ColoTensor, ShardPattern


def colo_embedding_1Dcol(input_tensor: ColoTensor, weight: ColoTensor, args, kwargs) -> ColoTensor:
    # embedding_1Dcol splits the weight (lookup table) along the embedding dim.
    # Gather the split lookup results afterwards.
    parallel_action = weight.shard_spec.get_action_by_compute_pattern(ComputePattern.TP1DCol_Embedding)
    if not input_tensor.is_gathered():
        input_tensor.gather()

    output_parallel = torch.nn.functional.embedding(input_tensor.torch_tensor(), weight.torch_tensor(),
                                                    *args, **kwargs)
    output = ColoTensor.init_from_torch_tensor(output_parallel)
    out_parallel_action_list = [ParallelAction(priority=1, parallel_mode=parallel_action.parallel_mode)]
    output_spec = TensorSpec(out_parallel_action_list)
    output.set_spec(output_spec, shard=False)
    output.set_shard_pattern(ShardPattern.Col)
    output.gather()
    return output


@colo_op_impl(torch.nn.functional.embedding)
def colo_embedding(types, args, kwargs, pg):
    """Handles ``__torch_function__`` dispatch for ``torch.nn.functional.embedding``.
    This method looks up an embedding table.
    """
    input_tensor = args[0]
    weight = args[1]
    args = args[2:]

    if not isinstance(input_tensor, ColoTensor):
        input_tensor = ColoTensor.init_from_torch_tensor(input_tensor)

    if not isinstance(weight, ColoTensor):
        weight = ColoTensor.init_from_torch_tensor(weight)

    # Handle different parallel actions.
    if not weight.has_spec():    # No Model Parallel Applied
        input_tensor = input_tensor.torch_tensor()
        weight = weight.torch_tensor()
        output = torch.nn.functional.embedding(input_tensor, weight, *args, **kwargs)
        return ColoTensor.init_from_torch_tensor(output)
    elif weight.shard_spec.num_action == 1:    # Single Model Parallel Applied
        compute_patterns = weight.shard_spec.compute_patterns
        if ComputePattern.TP1DCol_Embedding in compute_patterns:
            return colo_embedding_1Dcol(input_tensor, weight, args, kwargs)
        else:
            raise NotImplementedError
    else:
        raise NotImplementedError
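Note: the column-split scheme implemented above can be sanity-checked on a single device with plain PyTorch. The sketch below is illustrative only; it uses neither ColoTensor nor a process group, and embedding_1dcol_reference is a made-up name. It slices the lookup table along the embedding dimension, looks up each slice, and concatenates the partial results along the last dimension, which reproduces the unsharded embedding exactly.

import torch
import torch.nn.functional as F

def embedding_1dcol_reference(ids: torch.Tensor, weight: torch.Tensor, world_size: int) -> torch.Tensor:
    """Simulate TP1DCol_Embedding on one device: shard the table along the
    embedding dim, look up each shard, then gather (concat) along dim -1."""
    shards = torch.chunk(weight, world_size, dim=-1)          # what each rank would hold
    partials = [F.embedding(ids, shard) for shard in shards]  # local lookups
    return torch.cat(partials, dim=-1)                        # the all-gather step

ids = torch.tensor([0, 3, 6, 9])
weight = torch.randn(12, 32)
assert torch.allclose(embedding_1dcol_reference(ids, weight, world_size=4),
                      F.embedding(ids, weight))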
@@ -27,7 +27,7 @@ def colo_layernorm(types, args=(), kwargs=None, pg=None):
     eps = kwargs['eps']

     if isinstance(input_tensor, ColoTensor):
-        if input_tensor.is_activation() and not input_tensor.is_gathered():
+        if not input_tensor.is_gathered():
             input_tensor.gather()
         input_tensor = input_tensor.torch_tensor()
     if isinstance(weight, ColoTensor):
@@ -9,8 +9,8 @@ from packaging import version
 from colossalai.tensor import ComputePattern, TensorSpec, ComputePattern, ParallelAction, ColoTensor, ShardPattern


-def colo_linear_1Drow(input_tensor: ColoTensor, weight: ColoTensor, bias: ColoTensor) -> ColoTensor:
-    parallel_action = weight.shard_spec.get_action_by_compute_pattern(ComputePattern.TP1DRow)
+def colo_linear_1Drow(input_tensor: ColoTensor, weight: ColoTensor, bias:ColoTensor) -> ColoTensor:
+    parallel_action = weight.shard_spec.get_action_by_compute_pattern(ComputePattern.TP1DRow_Linear)
     # Input:S[1] x Weight:S[0] = Output:P
     # All-Reduce(Output) + bias = res
     # Input:S[1]
@@ -47,7 +47,7 @@ def colo_linear_1Dcol(input_tensor: ColoTensor, weight: ColoTensor, bias: ColoTe
     # Input:B x Weight:S[1] + Bias:S[1] = Output:S[1]
     # All-Gather(Output)
     # Input:B
-    parallel_action = weight.shard_spec.get_action_by_compute_pattern(ComputePattern.TP1DCol)
+    parallel_action = weight.shard_spec.get_action_by_compute_pattern(ComputePattern.TP1DCol_Linear)
     if input_tensor.is_gathered():
         # Not splited yet.
         assert input_tensor.shape[-1] == weight.size(-1), \
@@ -108,9 +108,9 @@ def colo_linear(types, args, kwargs, pg):
         return ColoTensor.init_from_torch_tensor(torch.nn.functional.linear(input_tensor, weight, bias))
     elif weight.shard_spec.num_action == 1:    # Single Model Parallel Applied
         compute_patterns = weight.shard_spec.compute_patterns
-        if ComputePattern.TP1DRow in compute_patterns:
+        if ComputePattern.TP1DRow_Linear in compute_patterns:
             return colo_linear_1Drow(input_tensor, weight, bias)
-        elif ComputePattern.TP1DCol in compute_patterns:
+        elif ComputePattern.TP1DCol_Linear in compute_patterns:
             return colo_linear_1Dcol(input_tensor, weight, bias)
         else:
             raise NotImplementedError
@@ -142,14 +142,19 @@ class ColoTensor(object):
         if self._shard_pattern is not ShardPattern.NA:    # reshard
             self.gather()
         # Model Parameters
-        if ComputePattern.TP1DRow in self._shard_spec.compute_patterns:
-            parallel_action = self._shard_spec.get_action_by_compute_pattern(ComputePattern.TP1DRow)
-            self._shard_1d(parallel_action=parallel_action, dim=-1)
-            self._shard_pattern = ShardPattern.Col    # We bind our ComputePattern on weight, which has to be transposed when linear().
-        elif ComputePattern.TP1DCol in self._shard_spec.compute_patterns:
-            parallel_action = self._shard_spec.get_action_by_compute_pattern(ComputePattern.TP1DCol)
-            self._shard_1d(parallel_action=parallel_action, dim=0)
-            self._shard_pattern = ShardPattern.Row
+        if self._shard_spec.num_action == 1:
+            parallel_action = self._shard_spec.get_action_by_compute_pattern(
+                self._shard_spec.compute_patterns[0])
+            if parallel_action.compute_pattern in [ComputePattern.TP1DRow_Linear, \
+                ComputePattern.TP1DCol_Embedding]:
+                self._shard_1d(parallel_action=parallel_action, dim=-1)
+                self._shard_pattern = ShardPattern.Col    # We bind our ComputePattern on weight, which has to be transposed when linear().
+            elif parallel_action.compute_pattern in [ComputePattern.TP1DCol_Linear, \
+                ComputePattern.TP1DRow_Embedding]:
+                self._shard_1d(parallel_action=parallel_action, dim=0)
+                self._shard_pattern = ShardPattern.Row
+            else:
+                raise NotImplementedError

     def gather(self):
         assert self.is_activation(), 'Currently we only support gather Activation ColoTensor.'
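Note: the dim choices in the branch above follow from how torch stores the weights. nn.Linear keeps its weight as [out_features, in_features] (F.linear transposes it), so a row-parallel Linear splits the stored tensor along its last dim; nn.Embedding keeps [num_embeddings, embedding_dim], so a column-parallel Embedding also splits the last dim, while the remaining two patterns split dim 0. A small illustrative sketch follows (SHARD_DIM is a hypothetical name, not part of ColossalAI):

import torch

# compute pattern -> dim of the *stored* weight tensor that gets sharded
SHARD_DIM = {
    "TP1DRow_Linear": -1,       # ShardPattern.Col
    "TP1DCol_Embedding": -1,    # ShardPattern.Col
    "TP1DCol_Linear": 0,        # ShardPattern.Row
    "TP1DRow_Embedding": 0,     # ShardPattern.Row
}

linear_w = torch.nn.Linear(in_features=16, out_features=8).weight           # shape [8, 16]
embed_w = torch.nn.Embedding(num_embeddings=12, embedding_dim=32).weight    # shape [12, 32]
print(torch.chunk(linear_w, 2, dim=SHARD_DIM["TP1DRow_Linear"])[0].shape)       # [8, 8]
print(torch.chunk(embed_w, 2, dim=SHARD_DIM["TP1DCol_Embedding"])[0].shape)     # [12, 16]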
@@ -4,10 +4,12 @@ from colossalai.context.parallel_mode import ParallelMode


 class ComputePattern(Enum):
-    TP1DRow = 1
-    TP1DCol = 2
-    ZeRO = 3
-    DP = 4
+    TP1DRow_Linear = 1
+    TP1DCol_Linear = 2
+    TP1DRow_Embedding = 3
+    TP1DCol_Embedding = 4
+    ZeRO = 5
+    DP = 6


 class ShardPattern(Enum):
@@ -43,14 +45,14 @@ class TensorSpec(object):
     # using ZeRO with DP-degree = 4 and 1DRowTP with TP-degree = 2.
     # parallel_action_list = [
     #     ParallelAction(10, ComputePattern.ZeRO, gpc.get_group(ParallelMode.DATA)),
-    #     ParallelAction(1, ComputePattern.TP1DRow, gpc.get_group(ParallelMode.PARALLEL_1D))
+    #     ParallelAction(1, ComputePattern.TP1DRow_Linear, gpc.get_group(ParallelMode.PARALLEL_1D))
     # ]
     # When the ColoTensor is initialized,
     # we first splitting tensor according to ParallelAction of ZeRO,
-    # then splitting tensor according to ParallelAction of TP1DRow.
+    # then splitting tensor according to ParallelAction of TP1DRow_Linear.
     # During Linear computation
     # Before Linear Op, we gather the tensors according to ZeRO.
-    # We perform Linear Op according to compute pattern of TP1DRow.
+    # We perform Linear Op according to compute pattern of TP1DRow_Linear.
     # After Linear Op, we split the tensors according to ZeRO.

     def __init__(self, parallel_action_list: List[ParallelAction] = [], shard_pattern: ShardPattern = ShardPattern.NA):
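Note: a minimal instance of the spec described in this comment, using the keyword form that the tests added in this commit use. This sketch assumes an already-initialized 1D parallel context and a single TP action; the multi-action ZeRO + TP case above is only documented, not exercised here.

from colossalai.context.parallel_mode import ParallelMode
from colossalai.tensor import TensorSpec, ComputePattern, ParallelAction

# One-action spec for a row-parallel linear weight; priorities only matter
# once several actions (e.g. ZeRO + TP) are combined as described above.
spec = TensorSpec([
    ParallelAction(priority=1,
                   compute_pattern=ComputePattern.TP1DRow_Linear,
                   parallel_mode=ParallelMode.PARALLEL_1D)
])
# sharded_weight.set_spec(spec)    # reshards the ColoTensor in place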
tests/test_tensor/test_embedding_tp.py (new file, 82 lines)
@@ -0,0 +1,82 @@
import torch
from colossalai.context.parallel_mode import ParallelMode
from colossalai.tensor import ColoTensor

from functools import partial

import colossalai
import pytest
import torch
import torch.multiprocessing as mp
from colossalai.testing import parameterize, rerun_if_address_is_in_use
from colossalai.utils.cuda import get_current_device
from colossalai.utils import free_port
from colossalai.core import global_context as gpc
from colossalai.tensor import TensorSpec, ComputePattern, ParallelAction

from _utils import check_equal, replace_parameter_add_grad, broadcast_tensor_chunk


def run_embedding_tp1d_col_test():
    device = get_current_device()
    dtype = torch.float32
    DEPTH = gpc.get_world_size(ParallelMode.PARALLEL_1D)
    num_embeddings = 12
    embedding_dim = 32

    local_rank = gpc.get_local_rank(ParallelMode.PARALLEL_1D)

    layer_master = torch.nn.Embedding(num_embeddings, embedding_dim)
    layer = torch.nn.Embedding(num_embeddings, embedding_dim)

    A_master = torch.tensor((0, 3, 6, 9), device=device)
    A = broadcast_tensor_chunk(A_master, chunk_size=1)

    W_shape = (num_embeddings, embedding_dim)
    W_master = torch.randn(W_shape, dtype=dtype, device=device)
    W = broadcast_tensor_chunk(W_master, chunk_size=1)
    W.requires_grad = True

    # replace the torch nn.Parameters with ColoTensor
    sharded_weight = ColoTensor.init_from_torch_tensor(W)
    parallel_action_list = [
        ParallelAction(priority=1, compute_pattern=ComputePattern.TP1DCol_Embedding,
                       parallel_mode=ParallelMode.PARALLEL_1D)
    ]
    spec = TensorSpec(parallel_action_list)
    sharded_weight.set_spec(spec)    # reshard
    replace_parameter_add_grad(layer, sharded_weight)
    out = layer(A)

    replace_parameter_add_grad(layer_master, W_master)
    C_master = layer_master(A_master)
    C = C_master.clone()

    check_equal(out, C)

    grad_shape = C_master.shape
    grad_master = torch.randn(grad_shape, dtype=dtype, device=get_current_device())
    grad = broadcast_tensor_chunk(grad_master, chunk_size=1)
    out.backward(grad)

    grad_master = grad_master.clone()
    C_master.backward(grad_master)

    W_grad = W_master.grad
    W_grad = torch.chunk(W_grad, DEPTH, dim=-1)[local_rank]
    check_equal(W_grad, layer.weight.grad)


def run_dist(rank, world_size, port):
    config = dict(parallel=dict(tensor=dict(mode="1d", size=world_size),))
    colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    run_embedding_tp1d_col_test()


@pytest.mark.dist
@parameterize('world_size', [1, 4])
@rerun_if_address_is_in_use()
def test_embedding_1d(world_size):
    run_func = partial(run_dist, world_size=world_size, port=free_port())
    mp.spawn(run_func, nprocs=world_size)


if __name__ == '__main__':
    test_embedding_1d()
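Note: the final check in the test above relies on the identity that, for a column-sharded lookup table, each rank's weight gradient equals the matching column slice of the full table's gradient. A single-device sketch of that identity (illustrative only; no process groups or ColoTensor involved):

import torch
import torch.nn.functional as F

ids = torch.tensor([0, 3, 6, 9])
full = torch.randn(12, 32, requires_grad=True)
rank, world_size = 1, 4

# The column shard this rank would own, made into its own leaf tensor.
local = torch.chunk(full.detach(), world_size, dim=-1)[rank].clone().requires_grad_(True)

grad_out = torch.randn(4, 32)
F.embedding(ids, full).backward(grad_out)
F.embedding(ids, local).backward(torch.chunk(grad_out, world_size, dim=-1)[rank])

# The shard's gradient is exactly the matching column slice of the full gradient.
assert torch.allclose(local.grad, torch.chunk(full.grad, world_size, dim=-1)[rank])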
@@ -47,7 +47,7 @@ def run_linear_tp1d_col_test():
     sharded_weight = ColoTensor.init_from_torch_tensor(W)
     sharded_bias = ColoTensor.init_from_torch_tensor(B)
     parallel_action_list = [
-        ParallelAction(priority=1, compute_pattern=ComputePattern.TP1DCol, parallel_mode=ParallelMode.PARALLEL_1D)
+        ParallelAction(priority=1, compute_pattern=ComputePattern.TP1DCol_Linear, parallel_mode=ParallelMode.PARALLEL_1D)
     ]
     spec = TensorSpec(parallel_action_list)
     sharded_weight.set_spec(spec)    # reshard
@@ -110,7 +110,7 @@ def run_linear_tp1d_row_test():
     # replace the torch nn.Parameters with ColoTensor
     sharded_weight = ColoTensor.init_from_torch_tensor(W)
     parallel_action_list = [
-        ParallelAction(priority=1, compute_pattern=ComputePattern.TP1DRow, parallel_mode=ParallelMode.PARALLEL_1D)
+        ParallelAction(priority=1, compute_pattern=ComputePattern.TP1DRow_Linear, parallel_mode=ParallelMode.PARALLEL_1D)
     ]
     spec = TensorSpec(parallel_action_list)
     sharded_weight.set_spec(spec=spec)    # reshard
@@ -145,7 +145,7 @@ def run_linear_tp1d_row_test():
 def run_dist(rank, world_size, port):
     config = dict(parallel=dict(tensor=dict(mode="1d", size=world_size),))
     colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
-    #run_linear_tp1d_row_test()
+    run_linear_tp1d_row_test()
     run_linear_tp1d_col_test()

 @pytest.mark.dist
@@ -38,12 +38,12 @@ def run_1d_col_tp():
     model = model_builder(checkpoint=True)

     parallel_action_list_row = [
-        ParallelAction(priority=1, compute_pattern=ComputePattern.TP1DRow, parallel_mode=ParallelMode.PARALLEL_1D)
+        ParallelAction(priority=1, compute_pattern=ComputePattern.TP1DRow_Linear, parallel_mode=ParallelMode.PARALLEL_1D)
     ]
     spec_row = TensorSpec(parallel_action_list_row)

     parallel_action_list_col = [
-        ParallelAction(priority=1, compute_pattern=ComputePattern.TP1DCol, parallel_mode=ParallelMode.PARALLEL_1D)
+        ParallelAction(priority=1, compute_pattern=ComputePattern.TP1DCol_Linear, parallel_mode=ParallelMode.PARALLEL_1D)
     ]
     spec_col = TensorSpec(parallel_action_list_col)

@@ -168,7 +168,7 @@ def run_1d_row_tp():
     model = model_builder(checkpoint=True)

     parallel_action_list = [
-        ParallelAction(priority=1, compute_pattern=ComputePattern.TP1DRow, parallel_mode=ParallelMode.PARALLEL_1D)
+        ParallelAction(priority=1, compute_pattern=ComputePattern.TP1DRow_Linear, parallel_mode=ParallelMode.PARALLEL_1D)
     ]
     spec = TensorSpec(parallel_action_list)
