Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-06 19:40:28 +00:00
[Inference]refactor baichuan (#5791)
* refactor baichuan
* remove unused code and add TODO for lazyinit
```diff
@@ -15,25 +15,10 @@ class BaichuanLMHeadLinear1D_Col(Linear1D_Col):
         module.in_features = module.weight.size(1)
         module.out_features = module.weight.size(0)
         module.bias = None
-        module.weight.data = nn.functional.normalize(module.weight)
-
-        return Linear1D_Col.from_native_module(
-            module,
-            process_group,
-            *args,
-            **kwargs,
-        )
-
-
-class BaichuanWpackLinear1D_Col(Linear1D_Col):
-    @staticmethod
-    def from_native_module(
-        module: nn.Module, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, **kwargs
-    ) -> ParallelModule:
-        in_features = module.in_features * 3
-        out_features = module.out_features // 3
-        module.weight.data = module.weight.view(3, out_features, -1).transpose(0, 1).reshape(out_features, in_features)
-        module.bias = None
+        module.weight.data = nn.functional.normalize(
+            module.weight
+        )  # TODO(lry89757) This behavior may not apply to lazy init. When we use lazy init, the weight of shardformer is not the real weight.
+        # So we should rewrite our own load_from_state_dict of `BaichuanLMHeadLinear1D_Col` to fix this potential issue.
 
         return Linear1D_Col.from_native_module(
             module,
```
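For context on the `normalize` call kept in `BaichuanLMHeadLinear1D_Col`: Baichuan2 replaces the usual lm_head with a NormHead that L2-normalizes each row of the head weight, and `nn.functional.normalize` with its defaults (`p=2, dim=1`) performs exactly that row-wise normalization. A minimal standalone check (shapes are illustrative, not the model's real sizes):

```python
import torch
import torch.nn as nn

# Illustrative lm_head weight of shape (vocab_size, hidden_size).
weight = torch.randn(8, 4)

# Defaults are p=2, dim=1: every row (one vocab logit direction)
# is rescaled to unit L2 norm, mirroring Baichuan2's NormHead.
normalized = nn.functional.normalize(weight)

print(normalized.norm(dim=1))  # all values ~1.0
```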
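The removed `BaichuanWpackLinear1D_Col` repacked Baichuan's fused QKV projection (`W_pack`, a Linear with `out_features = 3 * hidden`) so that each output row carries its q/k/v slices side by side before column-parallel sharding. A small sketch of what that `view/transpose/reshape` chain does, with toy sizes and constant-filled tensors standing in for the real weights:

```python
import torch

hidden = 4
# Stand-ins for the stacked Q, K, V blocks of W_pack's (3 * hidden, hidden) weight.
q = torch.zeros(hidden, hidden)
k = torch.ones(hidden, hidden)
v = torch.full((hidden, hidden), 2.0)
weight = torch.cat([q, k, v], dim=0)  # rows: q_0..q_3, k_0..k_3, v_0..v_3

out_features = hidden          # module.out_features // 3
in_features = hidden * 3       # module.in_features * 3
repacked = (
    weight.view(3, out_features, -1)      # (3, hidden, hidden): separate Q/K/V
    .transpose(0, 1)                      # (hidden, 3, hidden): group per output row
    .reshape(out_features, in_features)   # (hidden, 3 * hidden): row i = [q_i, k_i, v_i]
)

print(repacked[0])  # four 0s (q_0), then four 1s (k_0), then four 2s (v_0)
```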
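On the new TODO: under lazy init the tensor handed to `normalize` is a placeholder, so normalizing it before the checkpoint is materialized does not touch the real weight. One direction the comment hints at is deferring the normalization to the module's own state-dict loading hook. A hypothetical sketch only, not the repository's implementation (`NormalizedLMHead` is an invented name):

```python
import torch.nn as nn


class NormalizedLMHead(nn.Linear):
    """Hypothetical: apply NormHead normalization after real weights arrive."""

    def _load_from_state_dict(
        self, state_dict, prefix, local_metadata, strict,
        missing_keys, unexpected_keys, error_msgs,
    ):
        # Let nn.Module copy the checkpoint tensor into self.weight first.
        super()._load_from_state_dict(
            state_dict, prefix, local_metadata, strict,
            missing_keys, unexpected_keys, error_msgs,
        )
        # self.weight now holds materialized values, so the row-wise
        # L2 normalization acts on the real checkpoint weight.
        self.weight.data = nn.functional.normalize(self.weight)
```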