Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-06-27 07:47:05 +00:00)
[feat] support use_zbv in llama, mixtral modeling; only replace Linear1D_Col/Row policy;
This commit is contained in:
parent cfade4c36d
commit a11b4b50a7
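The commit threads a single flag, use_zbv, from the plugins (derived from pp_style == "zbv") through ShardConfig into the Llama and Mixtral shard policies, so the column/row-parallel linear layers can cooperate with the zero-bubble V (ZBV) pipeline schedule. Below is a minimal, hedged usage sketch of turning the schedule on through HybridParallelPlugin; the argument values are placeholders and the zbv schedule may also require a scheduler graph that is not shown here, so treat this as an assumption about this version of the API rather than part of the commit.

# Sketch: enabling the zero-bubble V schedule so the plugin sets use_zbv=True
# in its ShardConfig. Values are illustrative; the zbv schedule may also need
# scheduler nodes built from the pipeline graph (omitted here).
from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin

plugin = HybridParallelPlugin(
    tp_size=2,
    pp_size=2,
    pp_style="zbv",       # -> shard_config.use_zbv == True (see the hunks below)
    num_model_chunks=2,   # the V schedule runs two model chunks per stage
    precision="bf16",
)
booster = Booster(plugin=plugin)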
@@ -1217,6 +1217,7 @@ class HybridParallelPlugin(PipelinePluginBase):
             gradient_checkpoint_config=gradient_checkpoint_config,
             fp8_communication=fp8_communication,
             inner_ring_size=inner_ring_size,
+            use_zbv=(pp_style == "zbv"),
         )
         self.amp_config = dict(
             initial_scale=initial_scale,
@@ -373,6 +373,7 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
             make_vocab_size_divisible_by=make_vocab_size_divisible_by,
             gradient_checkpoint_config=gradient_checkpoint_config,
             fp8_communication=fp8_communication,
+            use_zbv=(pp_style == "zbv"),
         )
         self.amp_config = dict(
             initial_scale=initial_scale,
@@ -34,73 +34,3 @@ class WeightGradStore:
                     weight.grad = grad_weight
         else:
             raise Exception("Pop empty queue.")
-
-    # @classmethod
-    # def clear(cls, model, chunk=0):
-    #     weight_grad_tasks = []
-    #     while cls.weight_grad_queue[chunk].qsize() > 0:
-    #         stored_grads = cls.weight_grad_queue[chunk].get()
-    #         if len(weight_grad_tasks) == 0:
-    #             for _ in stored_grads:
-    #                 weight_grad_tasks.append([])
-    #         else:
-    #             assert len(weight_grad_tasks) == len(stored_grads)
-    #         for i, task in enumerate(stored_grads):
-    #             weight_grad_tasks[i].append(task)
-    #     weight_params = []
-    #     handles = []
-    #     if get_args().overlap_grad_reduce:
-    #         handles += model.async_reduce_grad()
-
-    #     output_layer_weight = None
-    #     if parallel_state.is_pipeline_last_stage():
-    #         assert len(weight_grad_tasks) > 0
-    #         output_layer_grads = weight_grad_tasks[0]
-    #         for j in range(len(output_layer_grads)):
-    #             total_input, grad_output, weight, func = output_layer_grads[j]
-    #             if output_layer_weight is None:
-    #                 output_layer_weight = weight
-    #             assert output_layer_weight is weight
-    #             func(total_input, grad_output, weight.main_grad)
-    #             output_layer_grads[j] = None  # release memory
-    #         weight_grad_tasks = weight_grad_tasks[1:]
-    #         if get_args().overlap_grad_reduce:
-    #             handles += model.async_reduce_grad(output_layer_weight)
-
-    #     if parallel_state.is_pipeline_first_stage() or parallel_state.is_pipeline_last_stage():
-    #         model_module = get_attr_wrapped_model(model, 'pre_process', return_model_obj=True)
-    #         if model_module.share_embeddings_and_output_weights:
-    #             # if share_embeddings_and_output_weights, wait all-reduce for embeddings
-    #             for handle in handles:
-    #                 if handle is not None:
-    #                     handle.wait()
-    #             handles = []
-
-    #     config = get_model_config(model)
-    #     # Do async all-reduce for embedding grads firstly, so that the rank 0 won't
-    #     # be blocked
-    #     embedding_handles = _allreduce_embedding_grads([model], config, async_op=True)
-    #     handles += embedding_handles
-
-    #     for i in range(len(weight_grad_tasks)):
-    #         tasks = weight_grad_tasks[i]
-    #         param = None
-    #         for j in range(len(tasks)):
-    #             total_input, grad_output, weight, func = tasks[j]
-    #             if param is None:
-    #                 param = weight
-    #             assert param is weight
-    #             assert not (weight is output_layer_weight)
-    #             func(total_input, grad_output, weight.main_grad)
-    #             tasks[j] = None  # release memory
-    #         weight_params.append(param)
-    #         if get_args().overlap_grad_reduce:
-    #             # All-reduce param grad here
-    #             handles += model.async_reduce_grad(param)
-    #         weight_grad_tasks[i] = None  # release memory
-
-    #     # timers('wait_all_reduce', log_level=1).start(barrier=False)
-    #     for handle in embedding_handles:
-    #         if handle is not None:
-    #             handle.wait()
-    #     # timers('wait_all_reduce').stop()
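This hunk only removes the long commented-out, Megatron-style clear() helper; the active put/flush/pop path of WeightGradStore is untouched. For orientation, here is a simplified sketch of the deferred weight-gradient pattern the class implements (illustrative only; the class and method names mirror the real store, but this is not the repository's exact code):

import queue

class TinyWeightGradStore:
    # Deferred dW computation: backward passes record what is needed to compute
    # the weight gradient, and the schedule pops and executes it during a bubble.
    cache = []
    weight_grad_queue = [queue.Queue(), queue.Queue()]  # one queue per model chunk

    @classmethod
    def put(cls, total_input, grad_output, weight, func):
        # Remember the inputs needed to compute dW later instead of doing it now.
        cls.cache.append((total_input, grad_output, weight, func))

    @classmethod
    def flush(cls, chunk=0):
        # Seal everything recorded for one micro-batch.
        cls.weight_grad_queue[chunk].put(cls.cache)
        cls.cache = []

    @classmethod
    def pop(cls, chunk=0):
        if cls.weight_grad_queue[chunk].qsize() > 0:
            for total_input, grad_output, weight, func in cls.weight_grad_queue[chunk].get():
                grad_weight = func(total_input, grad_output)  # compute dW now
                weight.grad = grad_weight if weight.grad is None else weight.grad + grad_weight
        else:
            raise Exception("Pop empty queue.")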
@@ -126,37 +126,65 @@ class LlamaPolicy(Policy):
                 SubModuleReplacementDescription(
                     suffix="self_attn.q_proj",
                     target_module=Linear1D_Col,
-                    kwargs=dict(seq_parallel_mode=sp_mode, fp8_communication=self.shard_config.fp8_communication),
+                    kwargs=dict(
+                        seq_parallel_mode=sp_mode,
+                        fp8_communication=self.shard_config.fp8_communication,
+                        use_zbv=self.shard_config.use_zbv,
+                    ),
                 ),
                 SubModuleReplacementDescription(
                     suffix="self_attn.k_proj",
                     target_module=Linear1D_Col,
-                    kwargs=dict(seq_parallel_mode=sp_mode, fp8_communication=self.shard_config.fp8_communication),
+                    kwargs=dict(
+                        seq_parallel_mode=sp_mode,
+                        fp8_communication=self.shard_config.fp8_communication,
+                        use_zbv=self.shard_config.use_zbv,
+                    ),
                 ),
                 SubModuleReplacementDescription(
                     suffix="self_attn.v_proj",
                     target_module=Linear1D_Col,
-                    kwargs=dict(seq_parallel_mode=sp_mode, fp8_communication=self.shard_config.fp8_communication),
+                    kwargs=dict(
+                        seq_parallel_mode=sp_mode,
+                        fp8_communication=self.shard_config.fp8_communication,
+                        use_zbv=self.shard_config.use_zbv,
+                    ),
                 ),
                 SubModuleReplacementDescription(
                     suffix="self_attn.o_proj",
                     target_module=Linear1D_Row,
-                    kwargs=dict(seq_parallel_mode=sp_mode, fp8_communication=self.shard_config.fp8_communication),
+                    kwargs=dict(
+                        seq_parallel_mode=sp_mode,
+                        fp8_communication=self.shard_config.fp8_communication,
+                        use_zbv=self.shard_config.use_zbv,
+                    ),
                 ),
                 SubModuleReplacementDescription(
                     suffix="mlp.gate_proj",
                     target_module=Linear1D_Col,
-                    kwargs=dict(seq_parallel_mode=sp_mode, fp8_communication=self.shard_config.fp8_communication),
+                    kwargs=dict(
+                        seq_parallel_mode=sp_mode,
+                        fp8_communication=self.shard_config.fp8_communication,
+                        use_zbv=self.shard_config.use_zbv,
+                    ),
                 ),
                 SubModuleReplacementDescription(
                     suffix="mlp.up_proj",
                     target_module=Linear1D_Col,
-                    kwargs=dict(seq_parallel_mode=sp_mode, fp8_communication=self.shard_config.fp8_communication),
+                    kwargs=dict(
+                        seq_parallel_mode=sp_mode,
+                        fp8_communication=self.shard_config.fp8_communication,
+                        use_zbv=self.shard_config.use_zbv,
+                    ),
                 ),
                 SubModuleReplacementDescription(
                     suffix="mlp.down_proj",
                     target_module=Linear1D_Row,
-                    kwargs=dict(seq_parallel_mode=sp_mode, fp8_communication=self.shard_config.fp8_communication),
+                    kwargs=dict(
+                        seq_parallel_mode=sp_mode,
+                        fp8_communication=self.shard_config.fp8_communication,
+                        use_zbv=self.shard_config.use_zbv,
+                    ),
                 ),
             ],
         )
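The pattern above is mechanical: every Linear1D_Col/Linear1D_Row replacement now receives use_zbv from the shard config. A hedged sketch of applying the same change in a custom policy follows; the helper function is illustrative and not part of this diff, and the two imports are assumed to match this version of shardformer.

from colossalai.shardformer.layer import Linear1D_Col
from colossalai.shardformer.policies.base_policy import SubModuleReplacementDescription

def column_parallel_replacement(shard_config, suffix):
    # Forward both the fp8 flag and the new zero-bubble flag to the sharded layer.
    return SubModuleReplacementDescription(
        suffix=suffix,
        target_module=Linear1D_Col,
        kwargs=dict(
            fp8_communication=shard_config.fp8_communication,
            use_zbv=shard_config.use_zbv,
        ),
    )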
@@ -124,27 +124,43 @@ class MixtralPolicy(Policy):
                 SubModuleReplacementDescription(
                     suffix="self_attn.q_proj",
                     target_module=Linear1D_Col,
-                    kwargs={"fp8_communication": self.shard_config.fp8_communication},
+                    kwargs={
+                        "fp8_communication": self.shard_config.fp8_communication,
+                        "use_zbv": self.shard_config.use_zbv,
+                    },
                 ),
                 SubModuleReplacementDescription(
                     suffix="self_attn.k_proj",
                     target_module=Linear1D_Col,
-                    kwargs={"fp8_communication": self.shard_config.fp8_communication},
+                    kwargs={
+                        "fp8_communication": self.shard_config.fp8_communication,
+                        "use_zbv": self.shard_config.use_zbv,
+                    },
                 ),
                 SubModuleReplacementDescription(
                     suffix="self_attn.v_proj",
                     target_module=Linear1D_Col,
-                    kwargs={"fp8_communication": self.shard_config.fp8_communication},
+                    kwargs={
+                        "fp8_communication": self.shard_config.fp8_communication,
+                        "use_zbv": self.shard_config.use_zbv,
+                    },
                 ),
                 SubModuleReplacementDescription(
                     suffix="self_attn.o_proj",
                     target_module=Linear1D_Row,
-                    kwargs={"fp8_communication": self.shard_config.fp8_communication},
+                    kwargs={
+                        "fp8_communication": self.shard_config.fp8_communication,
+                        "use_zbv": self.shard_config.use_zbv,
+                    },
                 ),
                 SubModuleReplacementDescription(
                     suffix="block_sparse_moe.gate",
                     target_module=Linear1D_Col,
-                    kwargs={"gather_output": True, "fp8_communication": self.shard_config.fp8_communication},
+                    kwargs={
+                        "gather_output": True,
+                        "fp8_communication": self.shard_config.fp8_communication,
+                        "use_zbv": self.shard_config.use_zbv,
+                    },
                 ),
             ],
         )
@@ -322,9 +338,13 @@ class MixtralForCausalLMPolicy(MixtralPolicy):
                         SubModuleReplacementDescription(
                             suffix="lm_head",
                             target_module=Linear1D_Col,
-                            kwargs=dict(gather_output=True, fp8_communication=self.shard_config.fp8_communication),
+                            kwargs=dict(
+                                gather_output=True,
+                                fp8_communication=self.shard_config.fp8_communication,
+                                use_zbv=self.shard_config.use_zbv,
+                            ),
                         )
-                    ]
+                    ],
                 )
             }
             policy.update(new_item)
@@ -380,7 +400,11 @@ class MixtralForSequenceClassificationPolicy(MixtralPolicy):
                         SubModuleReplacementDescription(
                             suffix="score",
                             target_module=Linear1D_Col,
-                            kwargs=dict(gather_output=True, fp8_communication=self.shard_config.fp8_communication),
+                            kwargs=dict(
+                                gather_output=True,
+                                fp8_communication=self.shard_config.fp8_communication,
+                                use_zbv=self.shard_config.use_zbv,
+                            ),
                         )
                     ]
                 )
@@ -49,6 +49,7 @@ class ShardConfig:
     make_vocab_size_divisible_by: int = 64
     gradient_checkpoint_config: Optional[GradientCheckpointConfig] = None
     extra_kwargs: Dict[str, Any] = field(default_factory=dict)
+    use_zbv: bool = False
 
     # For ring attention
     inner_ring_size: Optional[int] = None
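ShardConfig gains a plain boolean field with a default of False, which the plugins above set from pp_style and the policies read back as self.shard_config.use_zbv. A minimal construction sketch follows; tensor parallelism is disabled here only so the example does not need an initialized process group, and that argument is an assumption about this version of the dataclass.

from colossalai.shardformer import ShardConfig

config = ShardConfig(
    enable_tensor_parallelism=False,  # avoid needing a tensor-parallel process group
    use_zbv=True,                     # new field added by this commit; defaults to False
)
print(config.use_zbv)  # True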