[Inference] refactor baichuan (#5791)

* refactor baichuan

* remove unused code and add TODO for lazyinit
This commit is contained in:
Runyu Lu
2024-06-11 10:52:01 +08:00
committed by GitHub
parent 77a219a082
commit c0948aff97
3 changed files with 24 additions and 110 deletions

View File

@@ -15,25 +15,10 @@ class BaichuanLMHeadLinear1D_Col(Linear1D_Col):
module.in_features = module.weight.size(1)
module.out_features = module.weight.size(0)
module.bias = None
module.weight.data = nn.functional.normalize(module.weight)
return Linear1D_Col.from_native_module(
module,
process_group,
*args,
**kwargs,
)
class BaichuanWpackLinear1D_Col(Linear1D_Col):
@staticmethod
def from_native_module(
module: nn.Module, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, **kwargs
) -> ParallelModule:
in_features = module.in_features * 3
out_features = module.out_features // 3
module.weight.data = module.weight.view(3, out_features, -1).transpose(0, 1).reshape(out_features, in_features)
module.bias = None
module.weight.data = nn.functional.normalize(
module.weight
) # TODO(lry89757): Normalizing the weight here may not work with lazy init, because under lazy init the weight held by shardformer is not the real weight.
# To fix this potential issue, we should implement a custom `load_from_state_dict` for `BaichuanLMHeadLinear1D_Col`.
return Linear1D_Col.from_native_module(
module,