mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-26 12:14:02 +00:00
[shardformer] Add dropout layer in shard model and refactor policy api (#3949)
* add dist dropout in model
* update docstring and bert policy with dropout
* refactor basepolicy and sharded, update bert
* update format
* update gpt2 policy
* update bert policy
* remove unused code
* update readme for new policy usage
This commit is contained in:
@@ -5,7 +5,7 @@ import torch.nn as nn
|
||||
from transformers.pytorch_utils import Conv1D
|
||||
|
||||
from ..policies.autopolicy import get_autopolicy
|
||||
from ..policies.basepolicy import Policy
|
||||
from ..policies.basepolicy import Col_Layer, Dropout_Layer, Policy, Row_Layer
|
||||
from ..utils.utils import getattr_, hasattr_, setattr_
|
||||
from .shard_config import ShardConfig
|
||||
from .slicer import Slicer
|
||||
@@ -141,65 +141,73 @@ class ModelSharder(object):
|
||||
for func in param_funcs:
|
||||
policy_layers = func()
|
||||
for policy_layer in policy_layers:
|
||||
weight = None
|
||||
bias = None
|
||||
weight_attr = policy_layer.weight
|
||||
bias_attr = policy_layer.bias
|
||||
suffix = policy_layer.suffix
|
||||
replace_layer_cls = policy_layer.replace_layer
|
||||
ignore = policy_layer.ignore
|
||||
n_cast = policy_layer.n_cast
|
||||
reversed = policy_layer.reversed
|
||||
if policy_layer.__class__.__name__ == "Col_Layer":
|
||||
gather_output = policy_layer.gather_output and self.shard_config.gather_output
|
||||
n_cast = policy_layer.n_cast
|
||||
|
||||
if weight_attr is not None:
|
||||
if hasattr_(org_layer, weight_attr):
|
||||
weight = getattr_(org_layer, weight_attr)
|
||||
elif not ignore:
|
||||
raise ValueError(f"Layer {org_layer.__class__.__qualname__} has no attribute {weight_attr}")
|
||||
|
||||
if bias_attr is not None:
|
||||
if hasattr_(org_layer, bias_attr):
|
||||
bias = getattr_(org_layer, bias_attr)
|
||||
elif not ignore:
|
||||
raise ValueError(f"Layer {org_layer.__class__.__qualname__} has no attribute {bias_attr}")
|
||||
|
||||
# dont have the attribute in policy, and ignore is true
|
||||
if weight is None and bias is None and ignore:
|
||||
continue
|
||||
|
||||
# set the sliced weight and bias to the new nn_col layer
|
||||
assert weight is not None or bias is not None
|
||||
layer_attr = (lambda x: x[:x.rfind(".")])(weight_attr or bias_attr)
|
||||
|
||||
# slice weight and bias
|
||||
weight, bias = self.slicer.slice_weight_bias(weight, bias, policy_layer.__class__, n_cast, reversed)
|
||||
assert replace_layer_cls is not None, 'replace_layer should not be None'
|
||||
|
||||
# create new object to replace the origin layer
|
||||
if replace_layer_cls is not None:
|
||||
if isinstance(getattr_(org_layer, layer_attr), (nn.Linear, Conv1D)):
|
||||
if replace_layer_cls.__name__ == "Linear1D_Row":
|
||||
replace_layer = replace_layer_cls(weight.shape[1],
|
||||
weight.shape[0],
|
||||
bias=False if bias is None else True)
|
||||
elif replace_layer_cls.__name__ == "Linear1D_Col":
|
||||
replace_layer = replace_layer_cls(weight.shape[0],
|
||||
weight.shape[1],
|
||||
bias=False if bias is None else True,
|
||||
gather_output=gather_output)
|
||||
setattr_(org_layer, layer_attr, replace_layer, ignore=ignore)
|
||||
self.set_param(replace_layer, weight, bias)
|
||||
elif isinstance(getattr_(org_layer, layer_attr), nn.Embedding):
|
||||
# Linear
|
||||
suffix_layer = getattr_(org_layer, suffix, ignore=True)
|
||||
assert suffix_layer is not None or ignore, f"Layer {org_layer.__class__.__qualname__} has no attribute {suffix}"
|
||||
if suffix_layer is None and ignore:
|
||||
continue
|
||||
if isinstance(policy_layer, (Col_Layer, Row_Layer)):
|
||||
weight = None
|
||||
bias = None
|
||||
weight_attr = suffix + '.' + policy_layer.weight if policy_layer.weight is not None else None
|
||||
bias_attr = suffix + '.' + policy_layer.bias if policy_layer.bias is not None else None
|
||||
|
||||
if weight_attr is not None:
|
||||
if hasattr_(org_layer, weight_attr):
|
||||
weight = getattr_(org_layer, weight_attr)
|
||||
else:
|
||||
raise ValueError(f"Layer {org_layer.__class__.__qualname__} has no attribute {weight_attr}")
|
||||
|
||||
if bias_attr is not None:
|
||||
if hasattr_(org_layer, bias_attr):
|
||||
bias = getattr_(org_layer, bias_attr)
|
||||
else:
|
||||
raise ValueError(f"Layer {org_layer.__class__.__qualname__} has no attribute {bias_attr}")
|
||||
|
||||
# set the sliced weight and bias to the new nn_col layer
|
||||
assert weight is not None or bias is not None
|
||||
|
||||
# slice weight and bias
|
||||
weight, bias = self.slicer.slice_weight_bias(weight, bias, policy_layer.__class__, n_cast, reversed)
|
||||
|
||||
if replace_layer_cls.__name__ == "Linear1D_Row":
|
||||
replace_layer = replace_layer_cls(weight.shape[1],
|
||||
weight.shape[0],
|
||||
bias=False if bias is None else True)
|
||||
elif replace_layer_cls.__name__ == "Linear1D_Col":
|
||||
gather_output = policy_layer.gather_output and self.shard_config.gather_output
|
||||
replace_layer = replace_layer_cls(weight.shape[0],
|
||||
weight.shape[1],
|
||||
bias=False if bias is None else True,
|
||||
gather_output=gather_output)
|
||||
elif replace_layer_cls.__name__ == "VocabParallelEmbedding1D":
|
||||
replace_layer = replace_layer_cls(weight.shape[0], weight.shape[1],
|
||||
getattr_(org_layer, f"{layer_attr}.padding_idx", ignore=True))
|
||||
setattr_(org_layer, layer_attr, replace_layer, ignore=ignore)
|
||||
self.set_param(replace_layer, weight, bias)
|
||||
getattr_(org_layer, f"{suffix}.padding_idx", ignore=True))
|
||||
# setattr_(org_layer, suffix, replace_layer, ignore=ignore)
|
||||
# self.set_param(replace_layer, weight, bias)
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
f"Replacing {getattr_(org_layer, layer_attr).__class__} is not implemented so far")
|
||||
# do not replace the layer object, just replace the weight and bias
|
||||
f"Replacing to {replace_layer_cls.__name__} is not implemented so far")
|
||||
setattr_(org_layer, suffix, replace_layer, ignore=ignore)
|
||||
self.set_param(replace_layer, weight, bias)
|
||||
# dropout
|
||||
elif isinstance(policy_layer, Dropout_Layer):
|
||||
p_attr = suffix + '.' + policy_layer.p
|
||||
p = getattr_(org_layer, p_attr, ignore=True)
|
||||
replace_layer = replace_layer_cls(p)
|
||||
setattr_(org_layer, suffix, replace_layer, ignore=ignore)
|
||||
else:
|
||||
self.set_param(org_layer, layer_attr, weight, bias)
|
||||
raise NotImplementedError(
|
||||
f"Replacing {getattr_(org_layer, suffix).__class__} is not implemented so far")
|
||||
|
||||
def set_param(self,
|
||||
layer: Any,
|
||||
|
Reference in New Issue
Block a user