[hotfix] fix shape error in backward when using ColoTensor (#1298)
@@ -11,16 +11,16 @@ def colo_addmm_1Drow(input_tensor: ColoTensor, mat1: ColoTensor, mat2: ColoTenso
     # mat1:S[1] x mat2:S[0] = Output:P
     # beta * input + alpha * All-Reduce(Output) = res

-    mat1 = mat1.redistribute(ShardSpec([-1], [mat2.get_tp_world_size()]))
+    mat1 = mat1.redistribute(ShardSpec([-1], [mat2.get_tp_world_size()]), mat2.get_process_group())

     # Output:P
     partial_output = torch.mm(mat1, mat2)
     # Reduce(Output)
-    output = reduce_input(partial_output, mat1.get_process_group())
+    output = reduce_input(partial_output, mat2.get_process_group())
     # input
     assert not input_tensor.has_compute_spec(), 'Invalid input spec for 1Drow addmm op'
     output = beta * input_tensor + alpha * output
-    output = ColoTensor.from_torch_tensor(output, spec=ColoTensorSpec(ReplicaSpec()))
+    output = ColoTensor.from_torch_tensor(output, spec=ColoTensorSpec(input_tensor.get_process_group()))
     return output

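The comments in this hunk spell out the row-parallel addmm recipe: mat1 is sharded along its last dim (S[1]), mat2 along its first dim (S[0]), each rank's torch.mm produces a partial output (P), and an All-Reduce plus beta * input + alpha * (...) gives the final result. Below is a minimal, framework-free sketch of that arithmetic in plain PyTorch; the function name and the manual chunking stand in for the sharded tensors and the collective, and are not part of the ColossalAI API.

import torch

def rowwise_addmm_reference(input_tensor, mat1, mat2, beta=1.0, alpha=1.0, world_size=2):
    # Shard mat1 column-wise (S[1]) and mat2 row-wise (S[0]), as a
    # tensor-parallel group would.
    mat1_shards = mat1.chunk(world_size, dim=1)
    mat2_shards = mat2.chunk(world_size, dim=0)
    # Each "rank" computes a partial output of full shape (Output:P).
    partial_outputs = [torch.mm(a, b) for a, b in zip(mat1_shards, mat2_shards)]
    # The All-Reduce is a sum over ranks.
    reduced = torch.stack(partial_outputs).sum(dim=0)
    return beta * input_tensor + alpha * reduced

# Sanity check against the unsharded computation.
m1, m2, inp = torch.randn(4, 6), torch.randn(6, 8), torch.randn(4, 8)
assert torch.allclose(rowwise_addmm_reference(inp, m1, m2),
                      torch.addmm(inp, m1, m2), atol=1e-5)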
@@ -3,15 +3,15 @@ from typing import Optional
 from ._utils import GeneralTensor, convert_to_colo_tensor
 from colossalai.tensor.op_wrapper import colo_op_impl
 from ._utils import reduce_input, reduce_grad
-from colossalai.tensor import ComputePattern, ComputePattern, ComputeSpec, ColoTensor, ShardSpec, ReplicaSpec, ColoTensorSpec
+from colossalai.tensor import ComputePattern, ComputeSpec, ColoTensor, ShardSpec, ReplicaSpec, ColoTensorSpec


-def colo_linear_1Drow(input_tensor: ColoTensor, weight: ColoTensor, bias: Optional[ColoTensor]) -> 'ColoTensor':
+def colo_linear_1drow(input_tensor: ColoTensor, weight: ColoTensor, bias: Optional[ColoTensor]) -> 'ColoTensor':
     # Input:S[1] x Weight:S[0] = Output:P
     # All-Reduce(Output) + bias = res
     # Input:S[1]
     pg = weight.get_process_group()
-    input_tensor = input_tensor.redistribute(ShardSpec([-1], [weight.get_tp_world_size()]))
+    input_tensor = input_tensor.redistribute(ShardSpec([-1], [weight.get_tp_world_size()]), pg)

     # Output:P
     partial_output = F.linear(input_tensor, weight)
@@ -27,7 +27,7 @@ def colo_linear_1Drow(input_tensor: ColoTensor, weight: ColoTensor, bias: Option
     return output


-def colo_linear_1Dcol(input_tensor: ColoTensor, weight: ColoTensor, bias: Optional[ColoTensor]) -> 'ColoTensor':
+def colo_linear_1dcol(input_tensor: ColoTensor, weight: ColoTensor, bias: Optional[ColoTensor]) -> 'ColoTensor':
     # Input:B x Weight:S[1] + Bias:S[1] = Output:S[1]
     # All-Gather(Output)
     # Input:B
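For reference, the two sharding patterns named in the comments above (row-parallel, where partial outputs are summed by an All-Reduce, and column-parallel, where output shards are joined by an All-Gather) can be checked against plain F.linear. This is only an illustrative single-process sketch: chunking and concatenation stand in for the sharded tensors and the collectives, and the chunk dims are chosen to make the arithmetic work for F.linear's (out_features, in_features) weight layout rather than to mirror the library's S[0]/S[1] spec notation exactly.

import torch
import torch.nn.functional as F

def linear_1d_row_reference(x, weight, bias, world_size=2):
    # Row-parallel: input split on its last dim, weight split on in_features;
    # partial outputs are summed (the All-Reduce), then bias is added once.
    x_shards = x.chunk(world_size, dim=-1)
    w_shards = weight.chunk(world_size, dim=1)
    partial = [F.linear(xs, ws) for xs, ws in zip(x_shards, w_shards)]
    return torch.stack(partial).sum(dim=0) + bias

def linear_1d_col_reference(x, weight, bias, world_size=2):
    # Column-parallel: weight and bias split on out_features; each rank holds a
    # slice of the output, and concatenation plays the role of the All-Gather.
    w_shards = weight.chunk(world_size, dim=0)
    b_shards = bias.chunk(world_size, dim=0)
    return torch.cat([F.linear(x, ws, bs) for ws, bs in zip(w_shards, b_shards)], dim=-1)

x, w, b = torch.randn(4, 6), torch.randn(8, 6), torch.randn(8)
expected = F.linear(x, w, b)
assert torch.allclose(linear_1d_row_reference(x, w, b), expected, atol=1e-5)
assert torch.allclose(linear_1d_col_reference(x, w, b), expected, atol=1e-5)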
@@ -48,7 +48,7 @@ def colo_linear_1Dcol(input_tensor: ColoTensor, weight: ColoTensor, bias: Option

 def colo_linear_1d(mode: str, input_tensor: ColoTensor, weight: ColoTensor, bias: Optional[ColoTensor]) -> 'ColoTensor':
     assert mode in ('row', 'col')
-    funcs = {'row': colo_linear_1Drow, 'col': colo_linear_1Dcol}
+    funcs = {'row': colo_linear_1drow, 'col': colo_linear_1dcol}
     return funcs[mode](input_tensor, weight, bias)

@@ -204,12 +204,14 @@ class ColoTensor(torch.Tensor):
             ColoTensor: a redistributed colotensor
         """
         if pg is not None and pg != self.get_process_group():
-            print('here _redistribute')
             # if the pg is not equal, convert the current tensor to replicated
-            self._redistribute(ReplicaSpec())
-            self.process_group = pg
-        ret = DistSpecManager.handle_trans_spec(self, self.dist_spec, dist_spec, self.process_group)
-        return ColoTensor.from_torch_tensor(ret, ColoTensorSpec(self.process_group, dist_attr=dist_spec))
+            handled = self.redistribute(ReplicaSpec())
+        else:
+            handled = self
+            pg = self.process_group
+
+        ret = DistSpecManager.handle_trans_spec(handled, handled.dist_spec, dist_spec, pg)
+        return ColoTensor.from_torch_tensor(ret, ColoTensorSpec(pg=pg, dist_attr=dist_spec))

     def to_replicate_(self):
         """to_replicate_
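Note on the control flow above: when a target process group is supplied and differs from the tensor's current one, the tensor is first converted to a replicated copy and only then re-sharded under the target group; otherwise the current group is reused. The op hunks earlier in this commit rely on this by passing the weight's process group explicitly. A hedged usage sketch follows; the helper name is made up for illustration, an initialized ColossalAI distributed context is assumed, and only calls that appear in this diff are used.

from colossalai.tensor import ColoTensor, ShardSpec

def reshard_activation_like_weight(activation: ColoTensor, weight: ColoTensor) -> ColoTensor:
    # Hypothetical helper: shard the activation's last dim across the
    # tensor-parallel ranks of the weight's process group. redistribute()
    # now replicates first if `activation` lives on a different group.
    pg = weight.get_process_group()
    return activation.redistribute(ShardSpec([-1], [weight.get_tp_world_size()]), pg)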
@@ -11,42 +11,13 @@ from colossalai.testing import rerun_if_address_is_in_use
 from colossalai.utils.cuda import get_current_device
 from colossalai.utils import free_port
 from colossalai.utils.model.colo_init_context import ColoInitContext
-from colossalai.tensor import ShardSpec, ColoTensorSpec, ComputePattern, \
-    ComputeSpec, ColoTensor, DistSpecManager, ProcessGroup, ReplicaSpec
+from colossalai.tensor import ColoTensor, ProcessGroup
 from colossalai.nn.optimizer import ColoOptimizer

 from tests.components_to_test.registry import non_distributed_component_funcs
 from _utils import split_param_row_tp1d, split_param_col_tp1d


-def init_1d_row_linear(weight: ColoTensor, pg: ProcessGroup):
-    spec = (ShardSpec([-1], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
-    with DistSpecManager.no_grad():
-        weight.set_process_group(pg)
-        weight.set_tensor_spec(*spec)
-
-
-def init_1d_col_linear(weight, pg):
-    spec = (ShardSpec([0], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
-    with DistSpecManager.no_grad():
-        weight.set_process_group(pg)
-        weight.set_tensor_spec(*spec)
-
-
-def init_1d_row_embedding(weight, pg):
-    spec = (ShardSpec([0], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
-    with DistSpecManager.no_grad():
-        weight.set_process_group(pg)
-        weight.set_tensor_spec(*spec)
-
-
-def init_1d_col_embedding(weight, pg):
-    spec = (ShardSpec([-1], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
-    with DistSpecManager.no_grad():
-        weight.set_process_group(pg)
-        weight.set_tensor_spec(*spec)
-
-
 def run_1d_hybrid_tp(model_name):
     # A simple net with two stacked nn.Linear
     get_components_func = non_distributed_component_funcs.get_callable(model_name)
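The removed init_* helpers all follow the same pattern: build a (ShardSpec, ComputeSpec(TP1D)) pair and apply it to the parameter under DistSpecManager.no_grad(). They are replaced by split_param_row_tp1d / split_param_col_tp1d imported from the tests' _utils module. The sketch below shows a helper in that style, reconstructed only from the removed code above; the real split helpers may differ.

from colossalai.tensor import (ColoTensor, ComputePattern, ComputeSpec,
                               DistSpecManager, ProcessGroup, ShardSpec)

def split_param_tp1d_sketch(param: ColoTensor, pg: ProcessGroup, dim: int) -> None:
    # Shard `param` along `dim` across the tensor-parallel ranks of `pg`
    # and mark it for 1D tensor-parallel compute.
    spec = (ShardSpec([dim], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
    with DistSpecManager.no_grad():
        param.set_process_group(pg)
        param.set_tensor_spec(*spec)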
@@ -79,19 +50,16 @@ def run_1d_hybrid_tp(model_name):

             # num_class = type_vocab_size = 2 | (8, 2)
             if 'classifier' in name and 'weight' in name:
-                init_1d_row_linear(p, pg)
+                split_param_col_tp1d(p, pg)
             # num_class = vocab_size = 30524 | (30524, 8)
             elif 'word_embeddings' in name and 'weight' in name:
-                init_1d_row_embedding(p, pg)
+                split_param_row_tp1d(p, pg)
             # num_class = seq_len = 512 | (512, 8)
             elif 'position_embeddings' in name and 'weight' in name:
-                init_1d_row_embedding(p, pg)
+                split_param_row_tp1d(p, pg)
             # num_class = type_vocab_size = 2 | (2, 8)
             elif 'token_type_embeddings' in name and 'weight' in name:
-                init_1d_col_embedding(p, pg)
-            elif p.process_group.tp_world_size() == 1:
-                with DistSpecManager.no_grad():
-                    p.redistribute(ReplicaSpec(), pg)
+                split_param_col_tp1d(p, pg)

     elif "simple_net" == model_name:
         # A naive way to set spec for all weights in Linear
@@ -99,13 +67,13 @@ def run_1d_hybrid_tp(model_name):
             if not isinstance(p, ColoTensor):
                 continue
             if 'embed' in name and 'weight' in name:
-                init_1d_col_embedding(p, pg)
+                split_param_col_tp1d(p, pg)
             if 'proj1' in name and ('weight' in name or 'bias' in name):
-                init_1d_col_linear(p, pg)
+                split_param_row_tp1d(p, pg)
             if 'proj2' in name and 'weight' in name:
-                init_1d_row_linear(p, pg)
+                split_param_col_tp1d(p, pg)
             if 'classifier' in name and ('weight' in name or 'bias' in name):
-                init_1d_col_linear(p, pg)
+                split_param_row_tp1d(p, pg)

     model = model.cuda()
     model.train()
@@ -327,9 +295,9 @@ def _run_pretrain_load():

 def run_model_dist(rank, world_size, port):
     colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
-    for name in ['bert']:
+    for name in ['bert', 'simple_net']:
         run_1d_row_tp(name)
-    for name in ['bert']:
+    for name in ['bert', 'simple_net']:
         run_1d_hybrid_tp(name)

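run_model_dist is the per-process worker: each rank calls colossalai.launch and then runs the row-parallel and hybrid-parallel checks, now for both 'bert' and 'simple_net'. For context, a worker of this shape is typically driven by torch.multiprocessing.spawn on a free port, roughly as sketched below; this is an illustrative guess at the launcher, not the actual pytest wrapper in the file.

from functools import partial

import torch.multiprocessing as mp

from colossalai.utils import free_port

def launch_test(world_size: int = 4) -> None:
    # Each spawned process receives its rank as the first positional argument,
    # matching run_model_dist(rank, world_size, port).
    run_func = partial(run_model_dist, world_size=world_size, port=free_port())
    mp.spawn(run_func, nprocs=world_size)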