Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-06-19 20:23:41 +00:00)
[Tensor] Add function to spec and update linear 1Drow and unit tests (#869)
This commit is contained in:
parent 11f54c7b6b
commit 26d4ab8b03
colossalai/tensor/__init__.py:

@@ -1,7 +1,9 @@
+from .spec import ComputePattern, ParallelAction, TensorSpec
 from .op_wrapper import (
     colo_op_impl,)
 from .colo_tensor import ColoTensor
 from .utils import convert_parameter
 from ._ops import *
 
-__all__ = ['ColoTensor', 'convert_parameter', 'colo_op_impl']
+__all__ = ['ColoTensor', 'convert_parameter', 'colo_op_impl', 'ComputePattern',
+           'TensorSpec', 'ParallelAction']
The F.linear op handler in colossalai/tensor/_ops:

@@ -6,8 +6,7 @@ from colossalai.nn.layer.parallel_1d._utils import split_forward_gather_backward
 from colossalai.nn.layer.utils import divide
 from colossalai.core import global_context as gpc
 from packaging import version
-from colossalai.utils.cuda import get_current_device
-
+from colossalai.tensor import TensorSpec, ComputePattern, ParallelAction
 
 @colo_op_impl(torch.nn.functional.linear)
 def colo_linear(types, args, kwargs, pg):
@@ -30,32 +29,36 @@ def colo_linear(types, args, kwargs, pg):
 
     # Add communication logic before and after linear call.
     if isinstance(weight, ColoTensor):
-        if weight.shard_spec == None:
+        if weight.shard_spec == None or weight.shard_spec.num_action == 0:
             if isinstance(input_tensor, ColoTensor):
                 input_tensor = input_tensor.torch_tensor()
             if isinstance(weight, ColoTensor):
                 weight = weight.torch_tensor()
             return torch.nn.functional.linear(input_tensor, weight, bias)
-        elif weight.shard_spec == '1Drow':
-            # Input:S[1] x Weight:S[0] = Output:P
-            # All-Reduce(Output) + bias = res
-            assert divide(input_tensor.shape[-1], gpc.tensor_parallel_size) == weight.size(-1), \
-                'Invalid shapes in 1Drow forward: input={}, weight={}. Expected last dim of input {}.'.format(
-                input_tensor.shape, weight.size, weight.size[-1] * gpc.tensor_parallel_size)
-            # Input:S[1]
-            input_per_partition = split_forward_gather_backward(input_tensor, ParallelMode.PARALLEL_1D, dim=-1)
-            # Output:P
-            device = get_current_device()    # TODO where to put to(deivce)?
-            weight_ = weight.torch_tensor().to(device)
-            partial_output = torch.nn.functional.linear(input_per_partition, weight_)
-            # Reduce(Output)
-            output = reduce_input(partial_output, ParallelMode.PARALLEL_1D)
-            # Bias
-            if bias is not None:
-                bias_ = bias.to(device)
-                output = output + bias_
-            return output
+        elif weight.shard_spec.num_action == 1:
+            if ComputePattern.TP1DRow in weight.shard_spec.compute_patterns:
+                # Input:S[1] x Weight:S[0] = Output:P
+                # All-Reduce(Output) + bias = res
+                assert divide(input_tensor.shape[-1], gpc.tensor_parallel_size) == weight.size(-1), \
+                    'Invalid shapes in 1Drow forward: input={}, weight={}. Expected last dim of input {}.'.format(
+                    input_tensor.shape, weight.size, weight.size(-1) * gpc.tensor_parallel_size)
+                # Input:S[1]
+                if isinstance(input_tensor, ColoTensor):
+                    input_tensor = input_tensor.torch_tensor()
+                parallel_action = weight.shard_spec.get_action_by_compute_pattern(ComputePattern.TP1DRow)
+                input_per_partition = split_forward_gather_backward(input_tensor, parallel_action.parallel_mode, dim=-1)
+                # Output:P
+                weight_ = weight.torch_tensor()
+                partial_output = torch.nn.functional.linear(input_per_partition, weight_)
+                # Reduce(Output)
+                output = reduce_input(partial_output, ParallelMode.PARALLEL_1D)
+                # Bias
+                if bias is not None:
+                    bias_ = bias
+                    output = output + bias_
+                return ColoTensor.init_from_torch_tensor(output)
+            else:
+                raise NotImplementedError
         else:
             raise NotImplementedError
     else:
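The comments in this branch describe the row-parallel pattern: the input is sharded along its last dimension (S[1]), the weight along its input dimension (S[0]), each rank produces a partial output (P), and an all-reduce plus the bias gives the final result. A minimal single-process sketch of that arithmetic in plain PyTorch (not code from this repository; world_size stands in for gpc.tensor_parallel_size):

import torch

world_size = 2                       # hypothetical tensor-parallel degree
x = torch.randn(4, 8)                # full input
w = torch.randn(6, 8)                # full weight, F.linear layout (out_features, in_features)
b = torch.randn(6)

full = torch.nn.functional.linear(x, w, b)

# Each "rank" holds a slice of the input columns and the matching weight columns.
x_shards = x.chunk(world_size, dim=-1)
w_shards = w.chunk(world_size, dim=-1)

# Partial outputs per rank; summing them plays the role of reduce_input / all-reduce.
partial = [torch.nn.functional.linear(xs, ws) for xs, ws in zip(x_shards, w_shards)]
output = sum(partial) + b            # bias is added once, after the reduction

assert torch.allclose(output, full, atol=1e-5)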
colossalai/tensor/colo_tensor.py:

@@ -1,13 +1,12 @@
+from colossalai.context import parallel_mode
 from .op_wrapper import _COLOSSAL_OPS
 
 import torch
 from typing import Tuple, Optional
 from numpy import product
 from colossalai.core import global_context as gpc
-from colossalai.context import ParallelMode
 from colossalai.nn.layer.utils import divide
-from colossalai.utils.cuda import get_current_device
+from colossalai.tensor import TensorSpec, ComputePattern, ParallelAction
 
 
 class ColoTensor(object):
     """ Data Structure for Tensor in Colossal-AI
@@ -28,7 +27,7 @@ class ColoTensor(object):
                  pin_memory=False,
                  device=None,
                  torch_tensor=torch.empty(0),
-                 shard_spec: str = None,
+                 shard_spec: TensorSpec = TensorSpec(),
                 ):
         self._size = size
         self._dtype = dtype
@@ -39,7 +38,7 @@ class ColoTensor(object):
         self._shard_spec = shard_spec
 
     @property
-    def shard_spec(self) -> Optional[str]:
+    def shard_spec(self) -> TensorSpec:
         return self._shard_spec
 
     @property
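With shard_spec now typed as TensorSpec and defaulting to TensorSpec(), an unsharded tensor is recognized by num_action == 0 rather than by a None spec, which is why colo_linear above checks both. A small sketch, assuming init_from_torch_tensor leaves the default spec in place:

import torch
from colossalai.tensor import ColoTensor

t = ColoTensor.init_from_torch_tensor(torch.randn(16, 16))
assert t.shard_spec.num_action == 0        # default TensorSpec() carries no ParallelAction
assert t.shard_spec.compute_patterns == [] # and therefore no compute pattern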
@@ -109,19 +108,20 @@ class ColoTensor(object):
                                                device=self._device)
         return self._torch_tensor
 
-    def set_spec(self, spec: str, lazy_shard: bool = False) -> None:
+    def set_spec(self, spec: TensorSpec, lazy_shard: bool = False) -> None:
         self._shard_spec = spec
         if lazy_shard == False:
             self._shard()
 
     def _shard(self):
         assert self._shard_spec is not None, 'You should call set_spec() before _shard() ColoTensor.'
-        if self._shard_spec == "1Drow":    # TODO It actually represents the sharding layout for Linear-1Drow-weight, but we make it simpler now.
-            num_partition = gpc.get_world_size(ParallelMode.TENSOR)
-            local_rank = gpc.get_local_rank(ParallelMode.TENSOR)
-            dim = -1
-            chunk_size = divide(self._size[dim], num_partition)
-            device = get_current_device()
+        if self._shard_spec.num_action == 1:
+            if ComputePattern.TP1DRow in self._shard_spec.compute_patterns:
+                parallel_action = self._shard_spec.get_action_by_compute_pattern(ComputePattern.TP1DRow)
+                num_partition = gpc.get_world_size(parallel_action.parallel_mode)
+                local_rank = gpc.get_local_rank(parallel_action.parallel_mode)
+                dim = -1
+                chunk_size = divide(self._size[dim], num_partition)
             # Reshape to get shard for this rank and we don't want autograd
             # recording here for the narrow op and 'local_shard' should be a
             # leaf variable in the autograd graph.
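_shard now resolves the partitioning from the ParallelAction rather than from ParallelMode.TENSOR, but the slicing itself is unchanged: divide the last dimension into num_partition chunks and keep the local rank's slice as a detached, contiguous tensor. A plain-torch sketch of that slicing (not repository code):

import torch

def shard_last_dim(t: torch.Tensor, num_partition: int, local_rank: int) -> torch.Tensor:
    dim = t.dim() - 1                            # shard along the last dimension
    assert t.size(dim) % num_partition == 0      # what divide() enforces
    chunk_size = t.size(dim) // num_partition
    # narrow + detach so the local shard is a leaf tensor, as the diff's comments describe
    return t.narrow(dim, local_rank * chunk_size, chunk_size).detach().contiguous()

w = torch.randn(16, 8)
print(shard_last_dim(w, num_partition=4, local_rank=1).shape)   # torch.Size([16, 2])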
@@ -129,7 +129,6 @@ class ColoTensor(object):
                                                   ).contiguous()    # TODO Shall we clone() here since detach() will point to the old tensor?
             self._torch_tensor.requires_grad = self._requires_grad
             self._size = self._torch_tensor.size()
-            self._device = device    # TODO A `fake` device now because torch_tensor.device always = cpu
 
     @classmethod
     def __torch_function__(cls, func, types, args=(), kwargs=None):
@@ -151,5 +150,5 @@ class ColoTensor(object):
         kwargs = {k: v.torch_tensor() if isinstance(v, ColoTensor) else v for k, v in kwargs.items()}
         return func(*args, **kwargs)
 
-    def backward(self, retain_graph: bool = False):
-        self._torch_tensor.backward(retain_graph=retain_graph)
+    def backward(self, gradient: Optional[torch.Tensor] = None, retain_graph: bool = False):
+        self._torch_tensor.backward(gradient=gradient, retain_graph=retain_graph)
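The backward change forwards an optional gradient to torch.Tensor.backward, which PyTorch requires whenever the tensor being differentiated is not a scalar. A plain-torch illustration:

import torch

x = torch.randn(3, requires_grad=True)
y = x * 2                                  # non-scalar output
y.backward(gradient=torch.ones_like(y))    # would raise without an explicit gradient
print(x.grad)                              # tensor([2., 2., 2.])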
colossalai/tensor/spec.py:

@@ -1,8 +1,6 @@
 from enum import Enum
 from typing import Tuple, List
 from colossalai.context.parallel_mode import ParallelMode
-from colossalai.core import global_context as gpc
 
 
 class ComputePattern(Enum):
     TP1DRow = 1
@@ -12,17 +10,13 @@
 
 
 class ParallelAction(object):
-    priority = 0
-    compute_pattern = ComputePattern.DP
-    process_group = gpc.get_group(ParallelMode.DATA)
-
-    def __init__(self, priority, compute_pattern, process_group) -> None:
+    def __init__(self, priority=0, compute_pattern=ComputePattern.DP, parallel_mode=ParallelMode.DATA) -> None:
         self.priority = priority
         self.compute_pattern = compute_pattern
-        self.process_group = process_group
+        self.parallel_mode = parallel_mode
 
 
-class TensorSpec(Enum):
+class TensorSpec(object):
     """
     It contains two aspects of information:
     First, How are tensors distributed in Heterougenous memory space.
@@ -44,4 +38,28 @@
     # Before Linear Op, we gather the tensors according to ZeRO.
     # We perform Linear Op according to compute pattern of TP1DRow.
     # After Linear Op, we split the tensors according to ZeRO.
-    parallel_action_list: List[ParallelAction] = []
+    def __init__(self, parallel_action_list: List[ParallelAction] = []):
+        self._parallel_action_list = parallel_action_list
+        self.sort()
+
+    @property
+    def parallel_action_list(self):
+        return self._parallel_action_list
+
+    @property
+    def num_action(self):
+        return len(self._parallel_action_list)
+
+    @property
+    def compute_patterns(self):
+        return [parallel_action.compute_pattern for parallel_action in self._parallel_action_list]
+
+    def sort(self):
+        if len(self._parallel_action_list) > 0:
+            self._parallel_action_list.sort(key=lambda parallel_action : parallel_action.priority)
+
+    def get_action_by_compute_pattern(self, compute_pattern: ComputePattern):
+        for parallel_action in self._parallel_action_list:
+            if parallel_action.compute_pattern == compute_pattern:
+                return parallel_action
+        return None
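Taken together, the new spec module lets callers describe how a tensor participates in parallel compute. A short sketch using only names defined in this diff (ParallelMode.PARALLEL_1D is the mode the 1D-row path expects; importing it from colossalai.context follows the other files in this commit):

from colossalai.context import ParallelMode
from colossalai.tensor import ComputePattern, ParallelAction, TensorSpec

spec = TensorSpec([ParallelAction(priority=1,
                                  compute_pattern=ComputePattern.TP1DRow,
                                  parallel_mode=ParallelMode.PARALLEL_1D)])
assert spec.num_action == 1
assert ComputePattern.TP1DRow in spec.compute_patterns
action = spec.get_action_by_compute_pattern(ComputePattern.TP1DRow)
assert action.parallel_mode == ParallelMode.PARALLEL_1D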
The 1D-row linear unit test:

@@ -12,6 +12,7 @@ from colossalai.testing import parameterize, rerun_if_address_is_in_use
 from colossalai.utils.cuda import get_current_device
 from colossalai.utils import free_port
 from colossalai.core import global_context as gpc
+from colossalai.tensor import TensorSpec, ComputePattern, ParallelAction
 
 from _utils import check_equal, replace_parameter_add_grad, broadcast_tensor_chunk
 
@@ -45,7 +46,11 @@ def run_linear_tp1d_row_test():
 
     # replace the torch nn.Parameters with ColoTensor
    sharded_weight = ColoTensor.init_from_torch_tensor(W)
-    sharded_weight.set_spec(spec="1Drow")    # reshard
+    parallel_action_list = [
+        ParallelAction(priority=1, compute_pattern=ComputePattern.TP1DRow, parallel_mode=ParallelMode.PARALLEL_1D)
+    ]
+    spec = TensorSpec(parallel_action_list)
+    sharded_weight.set_spec(spec=spec)    # reshard
     sharded_bias = ColoTensor.init_from_torch_tensor(B)
     replace_parameter_add_grad(layer, sharded_weight, sharded_bias)
     out = layer(A)
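For comparison, the API change this test exercises, as a before/after sketch (the random weight here is a stand-in for the test's W, and the set_spec call assumes a launched 1D tensor-parallel context, as in the test):

import torch
from colossalai.context import ParallelMode
from colossalai.tensor import ColoTensor, ComputePattern, ParallelAction, TensorSpec

sharded_weight = ColoTensor.init_from_torch_tensor(torch.randn(32, 32))
# Before this commit:  sharded_weight.set_spec(spec="1Drow")
# After this commit:
spec = TensorSpec([ParallelAction(priority=1,
                                  compute_pattern=ComputePattern.TP1DRow,
                                  parallel_mode=ParallelMode.PARALLEL_1D)])
sharded_weight.set_spec(spec=spec)    # triggers _shard() and keeps this rank's slice of the last dim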