Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-08 20:40:34 +00:00
[Inference] Fix bug in ChatGLM2 Tensor Parallelism (#5014)
* fix bug
* fix
* fix multiquery
* fix multiquery

Co-authored-by: CjhHa1 <cjh18671720497@outlook.com>
@@ -149,7 +149,6 @@ class Linear1D_Col(ParallelModule):
        out_features = module.out_features
        bias = module.bias is not None
        device = module.weight.device

        # ensure only one process group is passed
        if isinstance(process_group, (list, tuple)):
            assert len(process_group) == 1, f"Expected only one process group, got {len(process_group)}."
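The context above comes from Linear1D_Col.from_native_module, which rebuilds a plain nn.Linear as a column-parallel linear: the output dimension is split across the ranks of a single process group, which is why exactly one group is accepted. A minimal sketch of that split using plain PyTorch; the helper name split_out_features and its signature are illustrative, not ColossalAI's actual API:

import torch
import torch.nn as nn

def split_out_features(module: nn.Linear, rank: int, world_size: int) -> nn.Linear:
    """Illustrative column-parallel split: each rank keeps a 1/world_size slice
    of the output features (rows of the weight matrix)."""
    out_features, in_features = module.weight.shape
    assert out_features % world_size == 0, "out_features must divide evenly across ranks"
    chunk = out_features // world_size
    shard = nn.Linear(in_features, chunk, bias=module.bias is not None,
                      device=module.weight.device, dtype=module.weight.dtype)
    with torch.no_grad():
        shard.weight.copy_(module.weight[rank * chunk:(rank + 1) * chunk])
        if module.bias is not None:
            shard.bias.copy_(module.bias[rank * chunk:(rank + 1) * chunk])
    return shard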
@@ -400,7 +400,6 @@ class SelfAttention(torch.nn.Module):
        )

        self.core_attention = CoreAttention(config, self.layer_number)

        # Output.
        self.dense = nn.Linear(
            self.projection_size,
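The commit title points at multi-query attention under tensor parallelism. In ChatGLM2's SelfAttention, the query projection spans the full projection_size while key and value share a much smaller number of channels, so a sharder has to split the query slice and the KV slices of the fused weight with different sizes. A small sketch of those shapes in plain PyTorch; the hyperparameter values and variable names below are illustrative, not the exact ChatGLM2 config fields:

import torch
import torch.nn as nn

hidden_size = 4096
num_attention_heads = 32
multi_query_group_num = 2            # number of KV heads (multi-query groups), assumed value
head_dim = hidden_size // num_attention_heads

projection_size = num_attention_heads * head_dim          # query width
kv_projection_size = multi_query_group_num * head_dim     # key/value width, much smaller

# Fused QKV projection: Q gets the full projection_size, K and V each get the
# smaller kv_projection_size, so the fused output is not simply 3 * hidden_size.
query_key_value = nn.Linear(hidden_size, projection_size + 2 * kv_projection_size)

x = torch.randn(1, 8, hidden_size)
q, k, v = query_key_value(x).split([projection_size, kv_projection_size, kv_projection_size], dim=-1)
print(q.shape, k.shape, v.shape)  # Q keeps 4096 channels, K and V keep 256 each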
@@ -104,7 +104,6 @@ class ChatGLMPolicy(Policy):
                ),
            ],
        )

        # optimization configuration
        self.append_or_create_submodule_replacement(
            description=[
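This hunk sits where ChatGLMPolicy registers submodule replacement descriptions, which tell the sharder which native layers to swap for their parallel counterparts. A hedged sketch of one such entry, assuming ColossalAI's shardformer exposes SubModuleReplacementDescription and Linear1D_Col under the import paths shown; the suffix used here is illustrative rather than copied from the real policy:

# Import paths follow the repo layout as I understand it; treat them as assumptions.
from colossalai.shardformer.layer import Linear1D_Col
from colossalai.shardformer.policies.base_policy import SubModuleReplacementDescription

# One entry of the `description` list passed above: replace the attention's
# fused QKV projection with a column-parallel linear.
qkv_replacement = SubModuleReplacementDescription(
    suffix="self_attention.query_key_value",   # illustrative suffix
    target_module=Linear1D_Col,
)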
@@ -180,7 +180,6 @@ class ModelSharder(object):
            assert target_module is not None, "target_module should not be None"

            native_sub_module = getattr_(org_layer, suffix, ignore=True)

            # Skip replacement if submodule is not kept by current device when pipeline parallel is enabled.
            if (include is not None) and (native_sub_module is not None) and (native_sub_module not in include):
                continue
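The ModelSharder hunk guards the replacement loop: with pipeline parallelism, a device keeps only a subset of submodules, and anything outside the include set is skipped rather than replaced. A self-contained sketch of that filter, with an invented helper name and signature for illustration only:

import torch.nn as nn

def replace_kept_submodules(org_layer: nn.Module, suffixes, build_replacement, include=None):
    """Illustrative version of the skip logic: only submodules that are both
    present and held by the current pipeline stage get replaced."""
    for suffix in suffixes:
        native_sub_module = getattr(org_layer, suffix, None)
        # Skip replacement if the submodule is not kept by the current device
        # when pipeline parallelism is enabled.
        if include is not None and native_sub_module is not None and native_sub_module not in include:
            continue
        if native_sub_module is not None:
            setattr(org_layer, suffix, build_replacement(native_sub_module))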