mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-08-09 11:58:06 +00:00
[Tensor] add embedding tp1d row (#904)
This commit is contained in:
parent
16122d5fac
commit
f593a5637e
@ -9,7 +9,7 @@ from packaging import version
|
|||||||
from colossalai.tensor import ComputePattern, TensorSpec, ComputePattern, ParallelAction, ColoTensor, ShardPattern
|
from colossalai.tensor import ComputePattern, TensorSpec, ComputePattern, ParallelAction, ColoTensor, ShardPattern
|
||||||
|
|
||||||
def colo_embedding_1Dcol(input_tensor: ColoTensor, weight: ColoTensor, args, kwargs) -> ColoTensor:
|
def colo_embedding_1Dcol(input_tensor: ColoTensor, weight: ColoTensor, args, kwargs) -> ColoTensor:
|
||||||
# embedding_1Dcol split the weight(lookup table)
|
# embedding_1Dcol split the weight(lookup table) to (num_embeddings, embedding_dim/P)
|
||||||
# Gather splitted lookup table
|
# Gather splitted lookup table
|
||||||
parallel_action = weight.shard_spec.get_action_by_compute_pattern(ComputePattern.TP1DCol_Embedding)
|
parallel_action = weight.shard_spec.get_action_by_compute_pattern(ComputePattern.TP1DCol_Embedding)
|
||||||
if not input_tensor.is_gathered():
|
if not input_tensor.is_gathered():
|
||||||
@ -25,6 +25,37 @@ def colo_embedding_1Dcol(input_tensor: ColoTensor, weight: ColoTensor, args, kwa
|
|||||||
output.gather()
|
output.gather()
|
||||||
return output
|
return output
|
||||||
|
|
||||||
|
def colo_embedding_1Drow(input_tensor: ColoTensor, weight: ColoTensor, args, kwargs) -> ColoTensor:
|
||||||
|
# embedding_1Drow split the weight(lookup table) to (num_embeddings/P, embedding_dim)
|
||||||
|
# Find index in this shard and mask those not here
|
||||||
|
# Reduce all
|
||||||
|
parallel_action = weight.shard_spec.get_action_by_compute_pattern(ComputePattern.TP1DRow_Embedding)
|
||||||
|
if not input_tensor.is_gathered():
|
||||||
|
input_tensor.gather()
|
||||||
|
|
||||||
|
tensor_parallel_rank = gpc.get_local_rank(parallel_action.parallel_mode)
|
||||||
|
num_embeddings_per_partition = weight.size(0)
|
||||||
|
vocab_start_index = tensor_parallel_rank * num_embeddings_per_partition
|
||||||
|
vocab_end_index = vocab_start_index + num_embeddings_per_partition
|
||||||
|
|
||||||
|
# Build the mask.
|
||||||
|
input_mask = (input_tensor.torch_tensor() < vocab_start_index) | \
|
||||||
|
(input_tensor.torch_tensor() >= vocab_end_index)
|
||||||
|
# Mask the input.
|
||||||
|
# TODO(jzy) masked_input may be an activation managed by ColoTensor.
|
||||||
|
masked_input = input_tensor.torch_tensor().clone() - vocab_start_index
|
||||||
|
masked_input[input_mask] = 0
|
||||||
|
|
||||||
|
partial_output = torch.nn.functional.embedding(masked_input, weight.torch_tensor(),
|
||||||
|
*args, **kwargs)
|
||||||
|
|
||||||
|
# Mask the output embedding.
|
||||||
|
partial_output[input_mask, :] = 0.
|
||||||
|
# Reduce across all the model parallel GPUs.
|
||||||
|
output = reduce_input(partial_output, parallel_action.parallel_mode)
|
||||||
|
output = ColoTensor.init_from_torch_tensor(output)
|
||||||
|
return output
|
||||||
|
|
||||||
@colo_op_impl(torch.nn.functional.embedding)
|
@colo_op_impl(torch.nn.functional.embedding)
|
||||||
def colo_embedding(types, args, kwargs, pg):
|
def colo_embedding(types, args, kwargs, pg):
|
||||||
"""Handles ``__torch_function__`` dispatch for ``torch.nn.functional.embedding``.
|
"""Handles ``__torch_function__`` dispatch for ``torch.nn.functional.embedding``.
|
||||||
@ -48,7 +79,9 @@ def colo_embedding(types, args, kwargs, pg):
|
|||||||
return ColoTensor.init_from_torch_tensor(output)
|
return ColoTensor.init_from_torch_tensor(output)
|
||||||
elif weight.shard_spec.num_action == 1: # Single Model Parallel Applied
|
elif weight.shard_spec.num_action == 1: # Single Model Parallel Applied
|
||||||
compute_patterns = weight.shard_spec.compute_patterns
|
compute_patterns = weight.shard_spec.compute_patterns
|
||||||
if ComputePattern.TP1DCol_Embedding in compute_patterns:
|
if ComputePattern.TP1DRow_Embedding in compute_patterns:
|
||||||
|
return colo_embedding_1Drow(input_tensor, weight, args, kwargs)
|
||||||
|
elif ComputePattern.TP1DCol_Embedding in compute_patterns:
|
||||||
return colo_embedding_1Dcol(input_tensor, weight, args, kwargs)
|
return colo_embedding_1Dcol(input_tensor, weight, args, kwargs)
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
@ -166,6 +166,7 @@ class ColoTensor(object):
|
|||||||
dim = -1
|
dim = -1
|
||||||
self._torch_tensor = gather_forward_split_backward(self._torch_tensor, parallel_action.parallel_mode, dim=dim)
|
self._torch_tensor = gather_forward_split_backward(self._torch_tensor, parallel_action.parallel_mode, dim=dim)
|
||||||
self._shard_pattern = ShardPattern.NA
|
self._shard_pattern = ShardPattern.NA
|
||||||
|
self._size = self._torch_tensor.size()
|
||||||
|
|
||||||
def is_gathered(self) -> bool:
|
def is_gathered(self) -> bool:
|
||||||
return self._shard_pattern == ShardPattern.NA
|
return self._shard_pattern == ShardPattern.NA
|
||||||
|
@ -5,7 +5,6 @@ from .utils.dummy_data_generator import DummyDataGenerator
|
|||||||
from .registry import non_distributed_component_funcs
|
from .registry import non_distributed_component_funcs
|
||||||
from colossalai.utils.cuda import get_current_device
|
from colossalai.utils.cuda import get_current_device
|
||||||
|
|
||||||
|
|
||||||
class SimpleNet(CheckpointModule):
|
class SimpleNet(CheckpointModule):
|
||||||
"""
|
"""
|
||||||
In this no-leaf module, it has subordinate nn.modules and a nn.Parameter.
|
In this no-leaf module, it has subordinate nn.modules and a nn.Parameter.
|
||||||
@ -13,12 +12,14 @@ class SimpleNet(CheckpointModule):
|
|||||||
|
|
||||||
def __init__(self, checkpoint=False) -> None:
|
def __init__(self, checkpoint=False) -> None:
|
||||||
super().__init__(checkpoint=checkpoint)
|
super().__init__(checkpoint=checkpoint)
|
||||||
|
self.embed = nn.Embedding(20, 4)
|
||||||
self.proj1 = nn.Linear(4, 8)
|
self.proj1 = nn.Linear(4, 8)
|
||||||
self.ln1 = nn.LayerNorm(8)
|
self.ln1 = nn.LayerNorm(8)
|
||||||
self.proj2 = nn.Linear(8, 4)
|
self.proj2 = nn.Linear(8, 4)
|
||||||
self.ln2 = nn.LayerNorm(4)
|
self.ln2 = nn.LayerNorm(4)
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
|
x = self.embed(x)
|
||||||
x = self.proj1(x)
|
x = self.proj1(x)
|
||||||
x = self.ln1(x)
|
x = self.ln1(x)
|
||||||
x = self.proj2(x)
|
x = self.proj2(x)
|
||||||
@ -26,11 +27,12 @@ class SimpleNet(CheckpointModule):
|
|||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class DummyDataLoader(DummyDataGenerator):
|
class DummyDataLoader(DummyDataGenerator):
|
||||||
|
|
||||||
def generate(self):
|
def generate(self):
|
||||||
data = torch.rand(16, 4, device=get_current_device())
|
data = torch.randint(low=0, high=20, size=(16,20), device=get_current_device())
|
||||||
label = torch.randint(low=0, high=2, size=(16,), device=get_current_device())
|
label = torch.randint(low=0, high=2, size=(16,4), device=get_current_device())
|
||||||
return data, label
|
return data, label
|
||||||
|
|
||||||
|
|
||||||
|
@ -65,10 +65,60 @@ def run_embedding_tp1d_col_test():
|
|||||||
W_grad = torch.chunk(W_grad, DEPTH, dim=-1)[local_rank]
|
W_grad = torch.chunk(W_grad, DEPTH, dim=-1)[local_rank]
|
||||||
check_equal(W_grad, layer.weight.grad)
|
check_equal(W_grad, layer.weight.grad)
|
||||||
|
|
||||||
|
def run_embedding_tp1d_row_test():
|
||||||
|
device = get_current_device()
|
||||||
|
dtype = torch.float32
|
||||||
|
DEPTH = gpc.get_world_size(ParallelMode.PARALLEL_1D)
|
||||||
|
num_embeddings = 12
|
||||||
|
embedding_dim = 32
|
||||||
|
|
||||||
|
local_rank = gpc.get_local_rank(ParallelMode.PARALLEL_1D)
|
||||||
|
|
||||||
|
layer_master = torch.nn.Embedding(num_embeddings, embedding_dim)
|
||||||
|
layer = torch.nn.Embedding(num_embeddings, embedding_dim)
|
||||||
|
|
||||||
|
A_master = torch.tensor((0,3,6,9), device=device)
|
||||||
|
A = broadcast_tensor_chunk(A_master, chunk_size=1)
|
||||||
|
|
||||||
|
W_shape = (num_embeddings, embedding_dim)
|
||||||
|
W_master = torch.randn(W_shape, dtype=dtype, device=device)
|
||||||
|
W = broadcast_tensor_chunk(W_master, chunk_size=1)
|
||||||
|
W.requires_grad = True
|
||||||
|
|
||||||
|
# replace the torch nn.Parameters with ColoTensor
|
||||||
|
sharded_weight = ColoTensor.init_from_torch_tensor(W)
|
||||||
|
parallel_action_list = [
|
||||||
|
ParallelAction(priority=1, compute_pattern=ComputePattern.TP1DRow_Embedding,
|
||||||
|
parallel_mode=ParallelMode.PARALLEL_1D)
|
||||||
|
]
|
||||||
|
spec = TensorSpec(parallel_action_list)
|
||||||
|
sharded_weight.set_spec(spec) # reshard
|
||||||
|
replace_parameter_add_grad(layer, sharded_weight)
|
||||||
|
out = layer(A)
|
||||||
|
|
||||||
|
replace_parameter_add_grad(layer_master, W_master)
|
||||||
|
C_master = layer_master(A_master)
|
||||||
|
C = C_master.clone()
|
||||||
|
|
||||||
|
check_equal(out, C)
|
||||||
|
|
||||||
|
grad_shape = C_master.shape
|
||||||
|
grad_master = torch.randn(grad_shape, dtype=dtype, device=get_current_device())
|
||||||
|
grad = broadcast_tensor_chunk(grad_master, chunk_size=1)
|
||||||
|
out.backward(grad)
|
||||||
|
|
||||||
|
grad_master = grad_master.clone()
|
||||||
|
C_master.backward(grad_master)
|
||||||
|
|
||||||
|
W_grad = W_master.grad
|
||||||
|
W_grad = torch.chunk(W_grad, DEPTH, dim=0)[local_rank]
|
||||||
|
check_equal(W_grad, layer.weight.grad)
|
||||||
|
|
||||||
def run_dist(rank, world_size, port):
|
def run_dist(rank, world_size, port):
|
||||||
config = dict(parallel=dict(tensor=dict(mode="1d", size=world_size),))
|
config = dict(parallel=dict(tensor=dict(mode="1d", size=world_size),))
|
||||||
colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
|
colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
|
||||||
run_embedding_tp1d_col_test()
|
run_embedding_tp1d_col_test()
|
||||||
|
run_embedding_tp1d_row_test()
|
||||||
|
|
||||||
@pytest.mark.dist
|
@pytest.mark.dist
|
||||||
@parameterize('world_size', [1, 4])
|
@parameterize('world_size', [1, 4])
|
||||||
|
@ -47,6 +47,11 @@ def run_1d_col_tp():
|
|||||||
]
|
]
|
||||||
spec_col = TensorSpec(parallel_action_list_col)
|
spec_col = TensorSpec(parallel_action_list_col)
|
||||||
|
|
||||||
|
parallel_action_list_embedding_col = [
|
||||||
|
ParallelAction(priority=1, compute_pattern=ComputePattern.TP1DCol_Embedding, parallel_mode=ParallelMode.PARALLEL_1D)
|
||||||
|
]
|
||||||
|
spec_embedding_col = TensorSpec(parallel_action_list_embedding_col)
|
||||||
|
|
||||||
set_seed(1)
|
set_seed(1)
|
||||||
if rank == 0:
|
if rank == 0:
|
||||||
model_torch = model_builder(checkpoint=True)
|
model_torch = model_builder(checkpoint=True)
|
||||||
@ -60,6 +65,8 @@ def run_1d_col_tp():
|
|||||||
p.set_spec(spec_col)
|
p.set_spec(spec_col)
|
||||||
if 'proj2' in name and 'weight' in name:
|
if 'proj2' in name and 'weight' in name:
|
||||||
p.set_spec(spec_row)
|
p.set_spec(spec_row)
|
||||||
|
if 'embed' in name and 'weight' in name:
|
||||||
|
p.set_spec(spec_embedding_col)
|
||||||
|
|
||||||
model = model.cuda()
|
model = model.cuda()
|
||||||
|
|
||||||
@ -172,6 +179,11 @@ def run_1d_row_tp():
|
|||||||
]
|
]
|
||||||
spec = TensorSpec(parallel_action_list)
|
spec = TensorSpec(parallel_action_list)
|
||||||
|
|
||||||
|
parallel_action_list_embedding_row = [
|
||||||
|
ParallelAction(priority=1, compute_pattern=ComputePattern.TP1DRow_Embedding, parallel_mode=ParallelMode.PARALLEL_1D)
|
||||||
|
]
|
||||||
|
spec_embedding_row = TensorSpec(parallel_action_list_embedding_row)
|
||||||
|
|
||||||
set_seed(1)
|
set_seed(1)
|
||||||
if rank == 0:
|
if rank == 0:
|
||||||
model_torch = model_builder(checkpoint=True)
|
model_torch = model_builder(checkpoint=True)
|
||||||
@ -183,6 +195,8 @@ def run_1d_row_tp():
|
|||||||
continue
|
continue
|
||||||
if 'weight' in name and 'LayerNorm' not in name and 'ln' not in name and 'embed' not in name:
|
if 'weight' in name and 'LayerNorm' not in name and 'ln' not in name and 'embed' not in name:
|
||||||
p.set_spec(spec)
|
p.set_spec(spec)
|
||||||
|
if 'embed' in name and 'weight' in name:
|
||||||
|
p.set_spec(spec_embedding_row)
|
||||||
|
|
||||||
model = model.cuda()
|
model = model.cuda()
|
||||||
|
|
||||||
@ -227,7 +241,7 @@ def run_dist(rank, world_size, port):
|
|||||||
config = dict(parallel=dict(tensor=dict(mode="1d", size=world_size),))
|
config = dict(parallel=dict(tensor=dict(mode="1d", size=world_size),))
|
||||||
colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
|
colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
|
||||||
run_1d_row_tp()
|
run_1d_row_tp()
|
||||||
|
run_1d_col_tp()
|
||||||
|
|
||||||
@pytest.mark.dist
|
@pytest.mark.dist
|
||||||
@parameterize('world_size', [1, 4])
|
@parameterize('world_size', [1, 4])
|
||||||
@ -238,6 +252,6 @@ def test_simple_net(world_size):
|
|||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
# test_simple_net()
|
test_simple_net()
|
||||||
# test_model_parameters()
|
# test_model_parameters()
|
||||||
test_colo_optimizer()
|
# test_colo_optimizer()
|
||||||
|
Loading…
Reference in New Issue
Block a user