ColossalAI/colossalai/inference/modeling/layers/baichuan_tp_linear.py

from typing import List, Union

import torch.nn as nn
from torch.distributed import ProcessGroup

from colossalai.shardformer.layer import Linear1D_Col
from colossalai.shardformer.layer.parallel_module import ParallelModule


class BaichuanLMHeadLinear1D_Col(Linear1D_Col):
    @staticmethod
    def from_native_module(
        module: nn.Module, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, **kwargs
    ) -> ParallelModule:
        # Baichuan2's NormHead stores only a weight parameter, so expose the
        # Linear-like attributes that Linear1D_Col.from_native_module expects.
        module.in_features = module.weight.size(1)
        module.out_features = module.weight.size(0)
        module.bias = None
        # NormHead L2-normalizes each vocab row of the weight at forward time;
        # bake that normalization into the weight once before sharding.
        module.weight.data = nn.functional.normalize(module.weight)

        return Linear1D_Col.from_native_module(
            module,
            process_group,
            *args,
            **kwargs,
        )
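

# A minimal sanity check (a sketch, not part of the original module):
# nn.functional.normalize defaults to the L2 norm over dim=1, so every vocab
# row of the lm_head weight becomes unit-length, matching what Baichuan2's
# NormHead computes on the fly. Uncomment to verify on a toy weight:
#
#   import torch
#   w = torch.randn(10, 4)                        # toy [vocab, hidden] weight
#   row_norms = nn.functional.normalize(w).norm(dim=1)
#   assert torch.allclose(row_norms, torch.ones(10))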


class BaichuanWpackLinear1D_Col(Linear1D_Col):
    @staticmethod
    def from_native_module(
        module: nn.Module, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, **kwargs
    ) -> ParallelModule:
        # W_pack fuses the Q, K and V projections into one weight of shape
        # [3 * hidden_size, hidden_size]. Rearrange it so that every row holds
        # the matching (q, k, v) rows side by side; a tensor-parallel split
        # along dim 0 then gives each rank aligned q/k/v slices instead of,
        # say, rank 0 receiving only a piece of Q.
        in_features = module.in_features * 3
        out_features = module.out_features // 3
        module.weight.data = module.weight.view(3, out_features, -1).transpose(0, 1).reshape(out_features, in_features)
        module.bias = None

        return Linear1D_Col.from_native_module(
            module,
            process_group,
            *args,
            **kwargs,
        )
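

# --- Illustration (a minimal sketch, not part of the original module) ---
# The W_pack rearrangement above can be sanity-checked on a toy fused weight:
# mark the Q, K and V blocks with distinct values and confirm that each chunk
# of the rearranged weight still contains all three projections. Chunking
# along dim 0 here only illustrates the row alignment; the actual sharding is
# performed by Linear1D_Col.
if __name__ == "__main__":
    import torch

    hidden, tp_size = 8, 2
    # Toy W_pack weight of shape [3 * hidden, hidden]: Q rows are all 0.0,
    # K rows all 1.0, V rows all 2.0.
    fused = torch.cat([torch.full((hidden, hidden), float(v)) for v in range(3)])
    # Same view/transpose/reshape as in BaichuanWpackLinear1D_Col above.
    rearranged = fused.view(3, hidden, -1).transpose(0, 1).reshape(hidden, 3 * hidden)
    for rank, shard in enumerate(rearranged.chunk(tp_size, dim=0)):
        # Every shard keeps matching q/k/v slices, so each rank sees {0, 1, 2}.
        print(f"rank {rank}: values {sorted(shard.unique().tolist())}")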