Mirror of https://github.com/hpcaitech/ColossalAI.git
[shardformer] Refactor shardformer api (#4001)
* fix an error in readme
* simplify code
* refactor shardformer
* add todo
* remove slicer
* resolve code review
@@ -4,11 +4,12 @@ import torch
 import torch.nn as nn
-from transformers.pytorch_utils import Conv1D
+
+from colossalai.cluster.process_group_manager import ProcessGroupManager
 
 from ..policies.autopolicy import get_autopolicy
-from ..policies.basepolicy import Col_Layer, Dropout_Layer, Policy, Row_Layer, Embedding_Layer
-from ..utils.utils import getattr_, hasattr_, setattr_
+from ..policies.basepolicy import Policy
+from ..utils.utils import setattr_
 from .shard_config import ShardConfig
-from .slicer import Slicer
 
 __all__ = ['ModelSharder', 'shard_model']
 
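The import changes summarize the refactor: the tensor-slicing `Slicer` and the per-layer description classes (`Col_Layer`, `Row_Layer`, `Dropout_Layer`, `Embedding_Layer`) are dropped, and a `ProcessGroupManager` from `colossalai.cluster` takes their place, so sharding is organized around distributed process groups rather than manual weight slicing. As a rough idea of what such a manager abstracts (the real `ProcessGroupManager` API may differ), a minimal version only needs to create and cache `torch.distributed` groups:

```python
# Minimal sketch of a process-group cache; the actual ProcessGroupManager
# in colossalai.cluster may expose a different interface.
import torch.distributed as dist


class TinyProcessGroupManager:
    """Create torch.distributed process groups once and reuse them."""

    def __init__(self):
        self._groups = {}

    def get(self, ranks):
        key = tuple(sorted(ranks))
        if key not in self._groups:
            # dist.new_group must be called collectively with identical arguments
            self._groups[key] = dist.new_group(ranks=list(key))
        return self._groups[key]
```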
@@ -28,20 +29,23 @@ class ModelSharder(object):
         model: nn.Module,
         policy: Policy,
         shard_config: ShardConfig = None,    # TODO
-    ) -> None:
+        pg_manager: ProcessGroupManager = None) -> None:
         self.model = model
         self.policy = get_autopolicy(self.model) if policy is None else policy
-        self.slicer = Slicer(shard_config)
         self.shard_config = shard_config
-        self.model_config = self.model.config
+        self.pg_manager = pg_manager
 
     def shard(self) -> None:
-        self.reshape_embedding()
-        self.inject_model(self.model)
-        self.replace_layer(self.model)
-        self.bind_layer(self.model)
+        r"""
+        Shard the model according to the policy
+        """
+        self.policy.set_model(self.model)
+        self.preprocess()
+        self.replace_model_class()
+        self.replace_module()
+        self.postprocess()
 
-    def reshape_embedding(self,) -> None:
+    def reshape_embedding(self) -> None:
         r"""
         Reshape the Embedding layer to make the embedding dimension divisible by world_size
         """
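`shard()` is now a fixed pipeline: hand the model to the policy, run the policy's `preprocess`, swap the model class, replace modules, then `postprocess`. The surviving `reshape_embedding` step pads the vocabulary so it divides evenly across ranks; the rule it relies on is just rounding up to the next multiple of `world_size` (the helper name below is ours, only `resize_token_embeddings` comes from `transformers`):

```python
# Sketch of the vocabulary-padding rule behind reshape_embedding; the helper
# name pad_vocab_size is illustrative, not part of the diff.
def pad_vocab_size(vocab_size: int, world_size: int) -> int:
    """Round vocab_size up to the next multiple of world_size."""
    remainder = vocab_size % world_size
    if remainder == 0:
        return vocab_size
    return vocab_size + world_size - remainder


# e.g. GPT-2's 50257-token vocabulary padded for 4-way tensor parallelism
assert pad_vocab_size(50257, 4) == 50260
assert pad_vocab_size(50260, 4) == 50260
```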
@@ -52,10 +56,13 @@ class ModelSharder(object):
             self.model.resize_token_embeddings(new_vocab_size)
             self.model_config = self.model.config
 
-    def inject_model(
-        self,
-        model: nn.Module,
-    ) -> None:
+    def preprocess(self) -> None:
+        self.model = self.policy.preprocess(self.shard_config)
+
+    def postprocess(self) -> None:
+        self.model = self.policy.postprocess()
+
+    def replace_model_class(self,) -> None:
         r"""
         Replace the model to policy defined model
         Mainly modify the forward and backward to fit distributed model
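`inject_model` is replaced by a symmetric pair of policy hooks: `preprocess` runs before any replacement and `postprocess` after, and both may return a new model object. A policy that, say, pads the vocabulary up front and re-ties shared weights afterwards would implement them roughly like this (a sketch assuming the policy keeps the model it received via `set_model`; the bodies are illustrative, only the hook names come from this diff):

```python
# Illustrative policy using the new hooks; Policy is the base class imported
# from ..policies.basepolicy above.
class ExamplePolicy(Policy):

    def preprocess(self, shard_config):
        # runs before any module replacement, may return a new model
        world_size = shard_config.world_size
        vocab_size = self.model.config.vocab_size
        if vocab_size % world_size != 0:
            new_vocab_size = vocab_size + world_size - vocab_size % world_size
            self.model.resize_token_embeddings(new_vocab_size)
        return self.model

    def postprocess(self):
        # runs after all replacements, e.g. to re-tie shared parameters
        return self.model
```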
@@ -64,49 +71,43 @@ class ModelSharder(object):
         ::
             BertForMaskedLM.forward -> BertForMaskedLM_.forward
         """
-        inject_policy = self.policy.inject_policy()
-        if inject_policy is None:
+        new_model_class = self.policy.new_model_class()
+        if new_model_class is None:
             return
-        if inject_policy is None:
-            return
-        org_model_cls = inject_policy[0]
-        shard_model_cls = inject_policy[1]
-
-        if model.__class__ == org_model_cls:
-            for key in shard_model_cls.__dict__.keys():
-                if hasattr(model.__class__, key):
-                    setattr(
-                        model.__class__,
-                        key,
-                        getattr(shard_model_cls, key),
-                    )
-        else:
-            raise NotImplementedError(f"{model.__class__} is not implemented so far")
+        for key in new_model_class.__dict__.keys():
+            if hasattr(self.model.__class__, key):
+                setattr(
+                    self.model.__class__,
+                    key,
+                    getattr(new_model_class, key),
+                )
 
-    def replace_layer(
-        self,
-        model: nn.Module,
-    ) -> None:
+    def replace_module(self,) -> None:
         r"""
-        Replace the layer according to the policy, and replace the layer one by one
+        Replace the module according to the policy, and replace the module one by one
 
         Args:
-            model (:class:`torch.nn.Module`): The layer to shard
+            model (:class:`torch.nn.Module`): The model to shard
         """
-        argument_policies = self.policy.argument_policy(self.model_config, self.shard_config.world_size)
-        for argument_policy in argument_policies.items():
-            origin_layer_cls = argument_policy[0]
-            attr_dict = argument_policy[1].attr_dict
-            param_funcs = argument_policy[1].param_funcs
-            self.traverse_replace_layer(model, origin_layer_cls, attr_dict, param_funcs)
+        print(self.policy)
+        module_descriptions = self.policy.module_policy(self.shard_config)
+        print(f"*******{module_descriptions}")
+        for module_description in module_descriptions.items():
+            origin_layer_cls = module_description[0]
+            attr_replacement = module_description[1].attribute_replacement
+            param_replacement = module_description[1].param_replacement
+            sub_module_replacement = module_description[1].sub_module_replacement
+            self._recursive_replace_layer(self.model, origin_layer_cls, attr_replacement, param_replacement,
+                                          sub_module_replacement)
 
-    def traverse_replace_layer(
+    def _recursive_replace_layer(
         self,
-        layer: nn.Module,
+        module: nn.Module,
         origin_cls: nn.Module,
-        attr_dict: Dict[str, Any],
-        param_funcs: List[Callable],
+        attr_replacement: Dict[str, Any],
+        param_replacement: List[Callable],
+        sub_module_replacement: List[Callable],
     ) -> None:
         r"""
         Reverse the replace layer operation
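`replace_module` now asks the policy for a table of module descriptions (origin class mapped to attribute, parameter, and sub-module replacements) and walks the model once per entry; `_recursive_replace_layer` is a plain pre-order traversal. The traversal pattern in isolation looks like this (a standalone sketch, not the ColossalAI code):

```python
# Standalone sketch of the pre-order traversal used by _recursive_replace_layer.
import torch.nn as nn


def recursive_apply(module: nn.Module, origin_cls: type, fn) -> None:
    """Call fn on module and every descendant whose class is exactly origin_cls."""
    if module.__class__ == origin_cls:
        fn(module)
    for child in module.children():
        recursive_apply(child, origin_cls, fn)


# usage: visit every nn.Linear in a model
model = nn.Sequential(nn.Linear(4, 4), nn.ReLU(), nn.Linear(4, 2))
recursive_apply(model, nn.Linear, lambda m: print(m))
```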
@@ -114,21 +115,52 @@ class ModelSharder(object):
         Args:
             layer (:class:`torch.nn.Module`): The object of layer to shard
             origin_cls (:class:`transformers.model`): The origin layer class
-            attr_dict (Dict): The attribute dict to modify
-            policy_cls (:class:`Policy`): The policy class
+            attr_replacement (Dict): The attribute dict to modify
+            param_replacement (List[Callable]): The function list to get parameter shard information in policy
+            sub_module_replacement (List[Callable]): The function list to get sub module shard information in policy
         """
-        if layer.__class__ == origin_cls:
-            for k, v in attr_dict.items():
-                setattr_(layer, k, v, ignore=True)
-            self.shard_one_layer(layer, param_funcs)
-        for name, child in layer.named_children():
-            self.traverse_replace_layer(child, origin_cls, attr_dict, param_funcs)
-        return layer
+        if module.__class__ == origin_cls:
+            self._replace_attr(module, attr_replacement)
+            self._replace_param(module, param_replacement)
+            self._replace_sub_module(module, sub_module_replacement)
+        for name, child in module.named_children():
+            self._recursive_replace_layer(child, origin_cls, attr_replacement, param_replacement,
+                                          sub_module_replacement)
 
-    def shard_one_layer(
+    def _replace_attr(
         self,
+        module: nn.Module,
+        attr_replacement: Dict[str, Any],
+    ) -> None:
+        r"""
+        Replace the attribute of the layer
+
+        Args:
+            layer (:class:`torch.nn.Module`): The object of layer to shard
+            attr_replacement (Dict): The attribute dict to modify
+        """
+        for k, v in attr_replacement.items():
+            setattr_(module, k, v, ignore=True)
+
+    def _replace_param(
+        self,
+        module: nn.Module,
+        param_replacement: List[Callable],
+    ) -> None:
+        r"""
+        Replace the parameter of the layer
+
+        Args:
+            layer (:class:`torch.nn.Module`): The object of layer to shard
+            param_replacement (List[Callable]): The function list to get parameter shard information in policy
+        """
+        # TODO: support parameter shard
+        pass
+
+    def _replace_sub_module(
+        self,
         org_layer: nn.Module,
-        param_funcs: List[Callable],
+        sub_module_replacement: List[Callable],
     ) -> None:
         r"""
         Shard one layer according to the policy, the layer should be the same class as the key in policy's argument_policy return dict
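`_replace_attr` delegates the dotted-path handling to `setattr_` from `..utils.utils`, so an attribute key like `"attention.num_heads"` reaches the nested sub-module. Such helpers are typically a reduce over `getattr` (a sketch of the idea; the actual `..utils.utils` implementation may differ):

```python
# Sketch of dotted-path attribute helpers in the spirit of getattr_/setattr_;
# not the actual ..utils.utils implementation.
import functools


def getattr_(obj, path: str, ignore: bool = False):
    """Resolve a dotted attribute path like 'attention.num_heads'."""
    try:
        return functools.reduce(getattr, path.split('.'), obj)
    except AttributeError:
        if ignore:
            return None
        raise


def setattr_(obj, path: str, value, ignore: bool = False):
    """Set a dotted attribute path, optionally ignoring missing parents."""
    prefix, _, last = path.rpartition('.')
    parent = getattr_(obj, prefix, ignore=ignore) if prefix else obj
    if parent is None:
        if ignore:
            return
        raise AttributeError(f"{obj.__class__.__name__} has no attribute {prefix}")
    setattr(parent, last, value)
```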
@@ -138,145 +170,14 @@ class ModelSharder(object):
-            param_funcs (:class:`List[typing.Callable]`): The function list to get shard information in policy class
 
         """
-        for func in param_funcs:
-            policy_layers = func()
-            for policy_layer in policy_layers:
-                suffix = policy_layer.suffix
-                replace_layer_cls = policy_layer.replace_layer
-                ignore = policy_layer.ignore
-                reversed = policy_layer.reversed
-                n_cast = policy_layer.n_cast
+        for description in sub_module_replacement:
+            suffix = description.suffix
+            target_module = description.target_module
+            kwargs = description.kwargs
 
-                assert replace_layer_cls is not None, 'replace_layer should not be None'
+            assert target_module is not None, 'target_module should not be None'
 
             # create new object to replace the origin layer
-                # Linear
-                suffix_layer = getattr_(org_layer, suffix, ignore=True)
-                assert suffix_layer is not None or ignore, f"Layer {org_layer.__class__.__qualname__} has no attribute {suffix}"
-                if suffix_layer is None and ignore:
-                    continue
-                if isinstance(policy_layer, (Col_Layer, Row_Layer, Embedding_Layer)):
-                    weight = None
-                    bias = None
-                    weight_attr = suffix + '.' + policy_layer.weight if policy_layer.weight is not None else None
-                    bias_attr = suffix + '.' + policy_layer.bias if hasattr(policy_layer, 'bias') and policy_layer.bias is not None else None
-
-                    if weight_attr is not None:
-                        if hasattr_(org_layer, weight_attr):
-                            weight = getattr_(org_layer, weight_attr)
-                        else:
-                            raise ValueError(f"Layer {org_layer.__class__.__qualname__} has no attribute {weight_attr}")
-
-                    if bias_attr is not None:
-                        if hasattr_(org_layer, bias_attr):
-                            bias = getattr_(org_layer, bias_attr)
-                        else:
-                            raise ValueError(f"Layer {org_layer.__class__.__qualname__} has no attribute {bias_attr}")
-
-                    # set the sliced weight and bias to the new nn_col layer
-                    assert weight is not None or bias is not None
-
-                    # slice weight and bias
-                    weight, bias = self.slicer.slice_weight_bias(weight, bias, policy_layer.__class__, n_cast, reversed)
-
-                    if replace_layer_cls.__name__ == "Linear1D_Row":
-                        replace_layer = replace_layer_cls(weight.shape[1],
-                                                          weight.shape[0],
-                                                          bias=False if bias is None else True)
-                    elif replace_layer_cls.__name__ == "Linear1D_Col":
-                        gather_output = policy_layer.gather_output and self.shard_config.gather_output
-                        replace_layer = replace_layer_cls(weight.shape[0],
-                                                          weight.shape[1],
-                                                          bias=False if bias is None else True,
-                                                          gather_output=gather_output)
-                    elif replace_layer_cls.__name__ == "Embedding1D":
-                        gather_output = policy_layer.gather_output
-                        replace_layer = replace_layer_cls(weight.shape[0],
-                                                          weight.shape[1],
-                                                          gather_output=gather_output)
-                    elif replace_layer_cls.__name__ == "VocabParallelEmbedding1D":
-                        replace_layer = replace_layer_cls(weight.shape[0], weight.shape[1],
-                                                          getattr_(org_layer, f"{suffix}.padding_idx", ignore=True))
-                        # setattr_(org_layer, suffix, replace_layer, ignore=ignore)
-                        # self.set_param(replace_layer, weight, bias)
-                    else:
-                        raise NotImplementedError(
-                            f"Replacing to {replace_layer_cls.__name__} is not implemented so far")
-                    setattr_(org_layer, suffix, replace_layer, ignore=ignore)
-                    self.set_param(replace_layer, weight, bias)
-                # dropout
-                elif isinstance(policy_layer, Dropout_Layer):
-                    p_attr = suffix + '.' + policy_layer.p
-                    p = getattr_(org_layer, p_attr, ignore=True)
-                    replace_layer = replace_layer_cls(p)
-                    setattr_(org_layer, suffix, replace_layer, ignore=ignore)
-                else:
-                    raise NotImplementedError(
-                        f"Replacing {getattr_(org_layer, suffix).__class__} is not implemented so far")
-
-    def set_param(self,
-                  layer: Any,
-                  weight: torch.Tensor = None,
-                  bias: torch.Tensor = None,
-                  layer_attr: str = "") -> None:
-        r"""
-        Reset the weight and bias of the layer object
-
-        Args:
-            layer (:class:`torch.nn.Module`): The layer object
-            layer_attr (str): The attribute name of the layer
-            weight (:class:`torch.Tensor`): The weight of the layer
-            bias (:class:`torch.Tensor`): The bias of the layer
-        """
-        assert weight is not None or bias is not None
-        if weight is not None:
-            setattr_(layer, "weight" if layer_attr == "" else layer_attr + ".weight", nn.Parameter(weight.contiguous()))
-            self.set_layer_size(layer, layer_attr, weight.shape)
-        if bias is not None:
-            setattr_(layer, "bias" if layer_attr == "" else layer_attr + ".bias", nn.Parameter(bias.contiguous()))
-
-    def set_layer_size(self, layer: nn.Module, layer_attr: str, size: torch.Size) -> None:
-        r"""
-        Set the layer attribute
-
-        Args:
-            layer (:class:`torch.nn.Module`): The layer object
-            layer_attr (str): The attribute name of the layer
-            size (:class:`torch.Size`): The size of the tensor
-        """
-        # Tensor.shape[0] -> out_features, Tensor.shape[1] -> in_features
-        attrs = ["out_features", "in_features"]
-        for i, attr in enumerate(attrs):
-            if hasattr_(layer, f"{layer_attr}.{attr}"):
-                setattr_(layer, f"{layer_attr}.{attr}", size[i])
-
-    def bind_layer(self, model: nn.Module) -> None:
-        r"""
-        Bind the layer according to the binding policy
-
-        Args:
-            model (:class:`torch.nn.Module`): The shard model
-        """
-        binding_map = self.policy.binding_policy()
-        if binding_map is None:
-            return
-        for k, v in binding_map.items():
-            param = getattr_(model, k)
-            param = nn.Parameter(param)
-            setattr_(model, k, param)
-            setattr_(model, v, param)
-
-
-def shard_model(model: nn.Module, shard_config: ShardConfig = None, policy: Policy = None):
-    r"""
-    The function is used to shard the PyTorch model.
-
-    Args:
-        model (`torch.nn.Model`): the origin huggingface model
-        shard_config (`ShardConfig`): the config for distribute information
-        policy (`Policy`): the custom policy for sharding
-    """
-    # TODO: init shard_config automatically
-    sharder = ModelSharder(model=model, shard_config=shard_config, policy=policy)
-    sharder.shard()
-    return model
+            # TODO: integrate with new layer
+            # replace_layer = target_module.from_native_layer(org_layer, self.pg_manager)
+            replace_layer = None
+
+            setattr_(org_layer, suffix, replace_layer)
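The old slice-and-reconstruct body goes away wholesale, along with `set_param`, `set_layer_size`, `bind_layer`, and the module-level `shard_model` helper; `_replace_sub_module` is left as a stub that assigns `replace_layer = None` until the new layer API lands, with the commented-out call hinting that each `target_module` will expose a `from_native_layer(org_layer, pg_manager)` constructor. A description entry consumed by this loop carries a suffix, a target class, and extra kwargs, roughly like this (the dataclass below is a stand-in; only the three field names appear in the diff, and `Linear1D_Col` in the usage note is just an example target):

```python
# Stand-in for the sub-module description consumed by _replace_sub_module;
# the field names suffix/target_module/kwargs come from the diff, the
# dataclass itself is illustrative.
from dataclasses import dataclass, field
from typing import Any, Dict, Type


@dataclass
class SubModuleDescription:
    suffix: str                    # dotted path of the child, e.g. "attention.query"
    target_module: Type            # parallel layer class expected to offer from_native_layer()
    kwargs: Dict[str, Any] = field(default_factory=dict)


# a policy would then return entries such as:
# SubModuleDescription(suffix="mlp.dense_h_to_4h",
#                      target_module=Linear1D_Col,
#                      kwargs={"gather_output": False})
```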