mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-26 12:14:02 +00:00
[shardformer] add gpt2 policy and modify shard and slicer to support (#3883)
* add gpt2 policy and modify shard and slicer to support * remove unused code * polish code
This commit is contained in:
@@ -2,6 +2,7 @@ from typing import Any, Callable, Dict, List
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from transformers.pytorch_utils import Conv1D
|
||||
|
||||
from ..policies.autopolicy import get_autopolicy
|
||||
from ..policies.basepolicy import Policy
|
||||
@@ -35,10 +36,22 @@ class ModelSharder(object):
|
||||
self.model_config = self.model.config
|
||||
|
||||
def shard(self) -> None:
# Entry point of the sharder: runs the steps below on self.model in this
# fixed order -- reshape_embedding first, then inject_model, replace_layer
# and bind_layer. Returns nothing; the steps operate on self.model directly.
|
||||
# Runs before the other steps; it may resize the model's token embeddings.
self.reshape_embedding()
|
||||
self.inject_model(self.model)
|
||||
self.replace_layer(self.model)
|
||||
self.bind_layer(self.model)
|
||||
|
||||
def reshape_embedding(self,) -> None:
|
||||
r"""
|
||||
Pad the vocabulary size up to the next multiple of world_size by resizing the model's token embeddings, so that the vocabulary size is divisible by world_size. Nothing is resized when it is already divisible.
|
||||
"""
|
||||
vocab_size = self.model_config.vocab_size
|
||||
world_size = self.shard_config.world_size
|
||||
# Only act when vocab_size is not already a multiple of world_size.
if vocab_size % world_size != 0:
|
||||
# Round vocab_size up to the next multiple of world_size.
new_vocab_size = vocab_size + world_size - vocab_size % world_size
|
||||
self.model.resize_token_embeddings(new_vocab_size)
|
||||
# Re-read the config from the model; presumably so the cached config
# reflects the resized vocabulary -- TODO confirm.
self.model_config = self.model.config
|
||||
|
||||
def inject_model(
|
||||
self,
|
||||
model: nn.Module,
|
||||
@@ -53,6 +66,8 @@ class ModelSharder(object):
|
||||
"""
|
||||
inject_policy = self.policy.inject_policy()
|
||||
|
||||
if inject_policy is None:
|
||||
return
|
||||
org_model_cls = inject_policy[0]
|
||||
shard_model_cls = inject_policy[1]
|
||||
|
||||
@@ -82,9 +97,9 @@ class ModelSharder(object):
|
||||
origin_layer_cls = argument_policy[0]
|
||||
attr_dict = argument_policy[1].attr_dict
|
||||
param_funcs = argument_policy[1].param_funcs
|
||||
self.reverse_replace_layer(model, origin_layer_cls, attr_dict, param_funcs)
|
||||
self.traverse_replace_layer(model, origin_layer_cls, attr_dict, param_funcs)
|
||||
|
||||
def reverse_replace_layer(
|
||||
def traverse_replace_layer(
|
||||
self,
|
||||
layer: nn.Module,
|
||||
origin_cls: nn.Module,
|
||||
@@ -100,17 +115,12 @@ class ModelSharder(object):
|
||||
attr_dict (Dict): The attribute dict to modify
|
||||
policy_cls (:class:`Policy`): The policy class
|
||||
"""
|
||||
if layer.__class__ == origin_cls:
|
||||
for k, v in attr_dict.items():
|
||||
setattr_(layer, k, v, ignore=True)
|
||||
self.shard_one_layer(layer, param_funcs)
|
||||
for name, child in layer.named_children():
|
||||
if child.__class__ == origin_cls:
|
||||
# replac_layer = child
|
||||
for k, v in attr_dict.items():
|
||||
setattr_(child, k, v, ignore=True)
|
||||
# print(f"Sharding {name} layer", replac_layer.attention.self.__dict__)
|
||||
# setattr_(layer, name, self.shard_one_layer(child, policy_cls))
|
||||
self.shard_one_layer(child, param_funcs)
|
||||
continue
|
||||
|
||||
self.reverse_replace_layer(child, origin_cls, attr_dict, param_funcs)
|
||||
self.traverse_replace_layer(child, origin_cls, attr_dict, param_funcs)
|
||||
return layer
|
||||
|
||||
def shard_one_layer(
|
||||
@@ -126,7 +136,6 @@ class ModelSharder(object):
|
||||
param_funcs (:class:`List[typing.Callable]`): The function list to get shard information in policy class
|
||||
|
||||
"""
|
||||
# print(org_layer)
|
||||
for func in param_funcs:
|
||||
policy_layers = func()
|
||||
for policy_layer in policy_layers:
|
||||
@@ -136,9 +145,10 @@ class ModelSharder(object):
|
||||
bias_attr = policy_layer.bias
|
||||
replace_layer_cls = policy_layer.replace_layer
|
||||
ignore = policy_layer.ignore
|
||||
n_cast = policy_layer.n_cast
|
||||
reversed = policy_layer.reversed
|
||||
if policy_layer.__class__.__name__ == "Col_Layer":
|
||||
gather_output = policy_layer.gather_output
|
||||
# print(gather_output)
|
||||
|
||||
if weight_attr is not None:
|
||||
if hasattr_(org_layer, weight_attr):
|
||||
@@ -161,13 +171,11 @@ class ModelSharder(object):
|
||||
layer_attr = (lambda x: x[:x.rfind(".")])(weight_attr or bias_attr)
|
||||
|
||||
# slice weight and bias
|
||||
weight, bias = self.slicer.slice_weight_bias(weight, bias, policy_layer.__class__)
|
||||
# print(os.environ['RANK'], policy_layer.__class__, weight.shape, bias.shape if bias is not None else None)
|
||||
weight, bias = self.slicer.slice_weight_bias(weight, bias, policy_layer.__class__, n_cast, reversed)
|
||||
|
||||
# create new object to replace the origin layer
|
||||
if replace_layer_cls is not None:
|
||||
# print(f"RANK {os.environ['RANK']}: replace {getattr_(org_layer, layer_attr).__class__} to {replace_layer_cls}, shape is {weight.shape}")
|
||||
if isinstance(getattr_(org_layer, layer_attr), nn.Linear):
|
||||
if isinstance(getattr_(org_layer, layer_attr), (nn.Linear, Conv1D)):
|
||||
if replace_layer_cls.__name__ == "Linear1D_Row":
|
||||
replace_layer = replace_layer_cls(weight.shape[1],
|
||||
weight.shape[0],
|
||||
@@ -235,6 +243,8 @@ class ModelSharder(object):
|
||||
model (:class:`torch.nn.Module`): The shard model
|
||||
"""
|
||||
binding_map = self.policy.binding_policy()
|
||||
if binding_map is None:
|
||||
return
|
||||
for k, v in binding_map.items():
|
||||
param = getattr_(model, k)
|
||||
param = nn.Parameter(param)
|
||||
|
Reference in New Issue
Block a user