[Inference] Fix bug in ChatGLM2 Tensor Parallelism (#5014)

* fix bug

* fix

* fix multiquery

* fix multiquery

---------

Co-authored-by: CjhHa1 <cjh18671720497outlook.com>
This commit is contained in:
Jianghai
2023-11-07 15:01:50 +08:00
committed by GitHub
parent c36e782d80
commit ef4c14a5e2
8 changed files with 21 additions and 19 deletions

View File

@@ -149,7 +149,6 @@ class Linear1D_Col(ParallelModule):
out_features = module.out_features
bias = module.bias is not None
device = module.weight.device
# ensure only one process group is passed
if isinstance(process_group, (list, tuple)):
assert len(process_group) == 1, f"Expected only one process group, got {len(process_group)}."

View File

@@ -400,7 +400,6 @@ class SelfAttention(torch.nn.Module):
)
self.core_attention = CoreAttention(config, self.layer_number)
# Output.
self.dense = nn.Linear(
self.projection_size,

View File

@@ -104,7 +104,6 @@ class ChatGLMPolicy(Policy):
),
],
)
# optimization configuration
self.append_or_create_submodule_replacement(
description=[

View File

@@ -180,7 +180,6 @@ class ModelSharder(object):
assert target_module is not None, "target_module should not be None"
native_sub_module = getattr_(org_layer, suffix, ignore=True)
# Skip replacement if submodule is not kept by current device when pipeline parallel is enabled.
if (include is not None) and (native_sub_module is not None) and (native_sub_module not in include):
continue