[Feature] Zigzag Ring attention (#5905)
* halfway
* fix cross-PP-stage position id length diff bug
* fix typo
* fix typo
* [pre-commit.ci] auto fixes from pre-commit.com hooks
  for more information, see https://pre-commit.ci
* unified cross entropy func for all shardformer models
* remove redundant lines
* add basic ring attn; debug cross entropy
* fwd bwd logic complete
* fwd bwd logic complete; add experimental triton rescale
* precision tests passed
* precision tests passed
* fix typos and remove misc files
* [pre-commit.ci] auto fixes from pre-commit.com hooks
  for more information, see https://pre-commit.ci
* add sp_mode to benchmark; fix varlen interface
* update softmax_lse shape by new interface
* change tester name
* remove buffer clone; support packed seq layout
* add varlen tests
* fix typo
* all tests passed
* add dkv_group; fix mask
* remove debug statements

---------

Co-authored-by: Edenzzzz <wtan45@wisc.edu>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
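For context on the feature itself: a "zigzag" ring layout splits a causal sequence into 2 * sp_size chunks and assigns rank i the pair (i, 2*sp_size-1-i), so every sequence-parallel rank holds one cheap and one expensive slice of the causal mask and the ring stays load-balanced. The snippet below is only a minimal sketch of that partitioning idea; the function name and shapes are illustrative and it is not the RingAttention implementation added by this PR.

import torch

def zigzag_split(x: torch.Tensor, sp_size: int, seq_dim: int = 1) -> list:
    """Illustrative zigzag partition: rank i receives chunks i and 2*sp_size-1-i."""
    chunks = x.chunk(2 * sp_size, dim=seq_dim)
    return [
        torch.cat([chunks[i], chunks[2 * sp_size - 1 - i]], dim=seq_dim)
        for i in range(sp_size)
    ]

# Length-16 toy sequence split across sp_size=4 ranks.
tokens = torch.arange(16).unsqueeze(0)      # (batch=1, seq=16)
for rank, shard in enumerate(zigzag_split(tokens, sp_size=4)):
    print(rank, shard.tolist())             # rank 0 -> tokens [0, 1, 14, 15], rank 1 -> [2, 3, 12, 13], ...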
@@ -69,13 +69,20 @@ class LlamaPolicy(Policy):
         sp_size = self.shard_config.sequence_parallel_size or None
         sp_group = self.shard_config.sequence_parallel_process_group or None
         sp_partial_derived = sp_mode in ["split_gather", "ring"]
+        if sp_mode == "ring_attn" and not self.is_causal:
+            raise ValueError("Ring attention is only meant for causal language modeling.")
+
+        tp_size = self.shard_config.tensor_parallel_size
+        # Modified by SP and TP
+        num_q_heads = self.model.config.num_attention_heads
+        num_kv_heads = getattr(self.model.config, "num_key_value_heads", None)

         if sp_mode == "all_to_all":
-            decoder_attribute_replacement = {
-                "num_heads": self.model.config.num_attention_heads // sp_size,
-            }
-            if getattr(self.model.config, "num_key_value_heads", False):
-                decoder_attribute_replacement["num_key_value_heads"] = self.model.config.num_key_value_heads // sp_size
+            num_q_heads //= sp_size
+            decoder_attribute_replacement = {"num_heads": num_q_heads}
+            if num_kv_heads:
+                num_kv_heads //= sp_size
+                decoder_attribute_replacement["num_key_value_heads"] = num_kv_heads

             policy[attn_cls] = ModulePolicyDescription(
                 attribute_replacement=decoder_attribute_replacement,
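The new num_q_heads / num_kv_heads bookkeeping above replaces the inline config lookups: under the all_to_all sequence-parallel mode both head counts are simply divided by sp_size before being written into attribute_replacement. A quick worked example with made-up GQA numbers:

# Hypothetical config: 32 query heads, 8 KV heads (GQA), sequence-parallel size 4.
num_q_heads = 32
num_kv_heads = 8
sp_size = 4

num_q_heads //= sp_size                              # 8 query heads per SP rank
decoder_attribute_replacement = {"num_heads": num_q_heads}
if num_kv_heads:
    num_kv_heads //= sp_size                         # 2 KV heads per SP rank
    decoder_attribute_replacement["num_key_value_heads"] = num_kv_heads

print(decoder_attribute_replacement)                 # {'num_heads': 8, 'num_key_value_heads': 2}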
@@ -104,21 +111,20 @@ class LlamaPolicy(Policy):

         if self.shard_config.enable_tensor_parallelism:
             assert (
-                self.model.config.num_attention_heads % self.shard_config.tensor_parallel_size == 0
+                num_q_heads % tp_size == 0
             ), f"The number of attention heads must be divisible by tensor parallel size."
             if hasattr(self.model.config, "num_key_value_heads"):
                 assert (
-                    self.model.config.num_key_value_heads >= self.shard_config.tensor_parallel_size
-                    and self.model.config.num_key_value_heads % self.shard_config.tensor_parallel_size == 0
+                    num_kv_heads >= tp_size and num_kv_heads % tp_size == 0
                 ), f"The number of key_value heads must be divisible by, and must not be less than tensor parallel size."
+            num_q_heads //= tp_size
             decoder_attribute_replacement = {
-                "self_attn.hidden_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
-                "self_attn.num_heads": self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size,
+                "self_attn.hidden_size": self.model.config.hidden_size // tp_size,
+                "self_attn.num_heads": num_q_heads,
             }
             if getattr(self.model.config, "num_key_value_heads", False):
-                decoder_attribute_replacement["self_attn.num_key_value_heads"] = (
-                    self.model.config.num_key_value_heads // self.shard_config.tensor_parallel_size
-                )
+                num_kv_heads //= tp_size
+                decoder_attribute_replacement["self_attn.num_key_value_heads"] = num_kv_heads

             policy[LlamaDecoderLayer] = ModulePolicyDescription(
                 attribute_replacement=decoder_attribute_replacement,
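The tensor-parallel branch applies the same head arithmetic along TP instead of SP: hidden size and query heads are divided by tp_size, and GQA models must have at least tp_size KV heads so each rank keeps whole heads. A standalone sketch of those checks (a hypothetical helper mirroring the policy logic above, not part of the policy class):

def tp_attr_replacement(hidden_size, num_q_heads, num_kv_heads, tp_size):
    """Mirror of the divisibility checks and attribute replacement in the TP branch above."""
    assert num_q_heads % tp_size == 0, "The number of attention heads must be divisible by tensor parallel size."
    replacement = {
        "self_attn.hidden_size": hidden_size // tp_size,
        "self_attn.num_heads": num_q_heads // tp_size,
    }
    if num_kv_heads is not None:
        assert (
            num_kv_heads >= tp_size and num_kv_heads % tp_size == 0
        ), "The number of key_value heads must be divisible by, and not less than, tensor parallel size."
        replacement["self_attn.num_key_value_heads"] = num_kv_heads // tp_size
    return replacement

print(tp_attr_replacement(hidden_size=4096, num_q_heads=32, num_kv_heads=8, tp_size=4))
# {'self_attn.hidden_size': 1024, 'self_attn.num_heads': 8, 'self_attn.num_key_value_heads': 2}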
@@ -295,10 +301,11 @@ class LlamaForCausalLMPolicy(LlamaPolicy):
     def module_policy(self):
         from transformers import LlamaForCausalLM

+        self.is_causal = True
         policy = super().module_policy()

         if self.shard_config.enable_tensor_parallelism:
-            # add a new item for casual lm
+            # add a new item for causal lm
             new_item = {
                 LlamaForCausalLM: ModulePolicyDescription(
                     sub_module_replacement=[
@@ -313,10 +320,6 @@ class LlamaForCausalLMPolicy(LlamaPolicy):
                     ],
                 )
             }
-            if self.shard_config.parallel_output:
-                new_item[LlamaForCausalLM].method_replacement = {
-                    "forward": get_lm_forward_with_dist_cross_entropy(self.shard_config)
-                }
         else:
             new_item = {
                 LlamaForCausalLM: ModulePolicyDescription(
@@ -336,7 +339,11 @@ class LlamaForCausalLMPolicy(LlamaPolicy):
             self.set_pipeline_forward(
                 model_cls=LlamaForCausalLM, new_forward=LlamaPipelineForwards.llama_for_causal_lm_forward, policy=policy
             )
-
+        elif self.shard_config.enable_tensor_parallelism or self.shard_config.enable_sequence_parallelism:
+            # Compute loss distributedly along the sequence dimension
+            new_item[LlamaForCausalLM].method_replacement = {
+                "forward": get_lm_forward_with_dist_cross_entropy(self.shard_config)
+            }
         return policy

     def get_held_layers(self) -> List[Module]:
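get_lm_forward_with_dist_cross_entropy swaps in a forward that computes the loss on each rank's sequence shard and reduces it across the sequence-parallel group, rather than gathering full-sequence logits first. The real helper lives in the shardformer; the snippet below is only a sketch of that reduction idea (a token-count-weighted mean with the usual -100 ignore index), not the actual implementation.

import torch
import torch.distributed as dist
import torch.nn.functional as F

def sharded_causal_lm_loss(shift_logits: torch.Tensor, shift_labels: torch.Tensor, sp_group=None) -> torch.Tensor:
    """Mean cross entropy over tokens spread across a sequence-parallel group.

    Assumes logits/labels are already shifted for causal LM and that padded or
    duplicated positions are marked with -100.
    """
    local_sum = F.cross_entropy(
        shift_logits.view(-1, shift_logits.size(-1)),
        shift_labels.view(-1),
        ignore_index=-100,
        reduction="sum",
    )
    num_tokens = (shift_labels != -100).sum().to(local_sum.dtype)
    if sp_group is not None and dist.is_initialized():
        # Sum both the loss and the token count so every rank sees the global mean.
        dist.all_reduce(local_sum, group=sp_group)
        dist.all_reduce(num_tokens, group=sp_group)
    return local_sum / num_tokens.clamp(min=1)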