[feat] support no_tp Linear for sharderformer.llama

This commit is contained in:
duanjunwen
2024-11-05 05:55:42 +00:00
parent 8e40087633
commit 4fc92aa77d
5 changed files with 140 additions and 42 deletions

View File

@@ -64,10 +64,11 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
# P2PMeta cache
self.enable_metadata_cache = enable_metadata_cache
self.send_tensor_metadata = True
self.send_grad_metadata = True
self.tensor_metadata_recv = None
self.grad_metadata_recv = None
self.send_tensor_metadata = [True, True]
self.send_grad_metadata = [True, True]
# meta cache buffer
self.tensor_metadata_recv = [None, None] # [chunk 0 meta, chunk 1 meta]
self.grad_metadata_recv = [None, None]
# P2P communication
self.comm = PipelineP2PCommunication(stage_manager, overlap_p2p=overlap_p2p)
@@ -235,10 +236,10 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
else:
prev_rank = self.stage_manager.get_prev_rank()
input_tensor, wait_handles = self.comm.recv_forward(
prev_rank=prev_rank, metadata_recv=self.tensor_metadata_recv
prev_rank=prev_rank, metadata_recv=self.tensor_metadata_recv[model_chunk_id]
)
if self.enable_metadata_cache and self.tensor_metadata_recv is None:
self.tensor_metadata_recv = create_send_metadata(input_tensor)
if self.enable_metadata_cache and self.tensor_metadata_recv[model_chunk_id] is None:
self.tensor_metadata_recv[model_chunk_id] = create_send_metadata(input_tensor)
self.recv_forward_buffer[model_chunk_id].append(input_tensor)
# return input_tensor, wait_handles
return wait_handles
@@ -259,10 +260,10 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
else:
next_rank = self.stage_manager.get_next_rank()
input_tensor, wait_handles = self.comm.recv_forward(
next_rank, metadata_recv=self.tensor_metadata_recv
next_rank, metadata_recv=self.tensor_metadata_recv[model_chunk_id]
)
if self.enable_metadata_cache and self.tensor_metadata_recv is None:
self.tensor_metadata_recv = create_send_metadata(input_tensor)
if self.enable_metadata_cache and self.tensor_metadata_recv[model_chunk_id] is None:
self.tensor_metadata_recv[model_chunk_id] = create_send_metadata(input_tensor)
self.recv_forward_buffer[model_chunk_id].append(input_tensor)
# return input_tensor, wait_handles
return wait_handles
@@ -297,10 +298,10 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
else:
next_rank = self.stage_manager.get_next_rank()
output_tensor_grad, wait_handles = self.comm.recv_backward(
next_rank, metadata_recv=self.grad_metadata_recv
next_rank, metadata_recv=self.grad_metadata_recv[model_chunk_id]
)
if self.enable_metadata_cache and self.grad_metadata_recv is None:
self.grad_metadata_recv = create_send_metadata(output_tensor_grad)
if self.enable_metadata_cache and self.grad_metadata_recv[model_chunk_id] is None:
self.grad_metadata_recv[model_chunk_id] = create_send_metadata(output_tensor_grad)
self.recv_backward_buffer[model_chunk_id].append(output_tensor_grad)
# return output_tensor_grad, wait_handles
return wait_handles
@@ -322,10 +323,10 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
else:
prev_rank = self.stage_manager.get_prev_rank()
output_tensor_grad, wait_handles = self.comm.recv_backward(
next_rank=prev_rank, metadata_recv=self.grad_metadata_recv
next_rank=prev_rank, metadata_recv=self.grad_metadata_recv[model_chunk_id]
)
if self.enable_metadata_cache and self.grad_metadata_recv is None:
self.grad_metadata_recv = create_send_metadata(output_tensor_grad)
if self.enable_metadata_cache and self.grad_metadata_recv[model_chunk_id] is None:
self.grad_metadata_recv[model_chunk_id] = create_send_metadata(output_tensor_grad)
self.recv_backward_buffer[model_chunk_id].append(output_tensor_grad)
# return output_tensor_grad, wait_handles
return wait_handles
@@ -359,9 +360,11 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
next_rank = self.stage_manager.get_next_rank()
output_tensor = self.send_forward_buffer[model_chunk_id].pop(0)
send_handles = self.comm.send_forward(
output_object=output_tensor, next_rank=next_rank, send_metadata=self.send_tensor_metadata
output_object=output_tensor,
next_rank=next_rank,
send_metadata=self.send_tensor_metadata[model_chunk_id],
)
self.send_tensor_metadata = not self.enable_metadata_cache
self.send_tensor_metadata[model_chunk_id] = not self.enable_metadata_cache
return send_handles
else:
@@ -380,9 +383,9 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
prev_rank = self.stage_manager.get_prev_rank()
output_tensor = self.send_forward_buffer[model_chunk_id].pop(0)
send_handles = self.comm.send_forward(
output_tensor, prev_rank, send_metadata=self.send_tensor_metadata
output_tensor, prev_rank, send_metadata=self.send_tensor_metadata[model_chunk_id]
)
self.send_tensor_metadata = not self.enable_metadata_cache
self.send_tensor_metadata[model_chunk_id] = not self.enable_metadata_cache
return send_handles
def send_backward(self, model_chunk_id: int, prev_rank: int = None) -> List:
@@ -415,9 +418,9 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
prev_rank = self.stage_manager.get_prev_rank()
input_tensor_grad = self.send_backward_buffer[model_chunk_id].pop(0)
send_handles = self.comm.send_backward(
input_tensor_grad, prev_rank, send_metadata=self.send_grad_metadata
input_tensor_grad, prev_rank, send_metadata=self.send_grad_metadata[model_chunk_id]
)
self.send_grad_metadata = not self.enable_metadata_cache
self.send_grad_metadata[model_chunk_id] = not self.enable_metadata_cache
return send_handles
# bwd chunk1 is left V;
@@ -437,9 +440,9 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
next_rank = self.stage_manager.get_next_rank()
input_tensor_grad = self.send_backward_buffer[model_chunk_id].pop(0)
send_handles = self.comm.send_backward(
input_tensor_grad, next_rank, send_metadata=self.send_grad_metadata
input_tensor_grad, next_rank, send_metadata=self.send_grad_metadata[model_chunk_id]
)
self.send_grad_metadata = not self.enable_metadata_cache
self.send_grad_metadata[model_chunk_id] = not self.enable_metadata_cache
return send_handles
def forward_step(
@@ -662,6 +665,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
accum_loss=accum_loss,
outputs=outputs,
)
# print(f"stage {self.stage_manager.stage}; model_chunk_id {model_chunk_id}; output_obj {output_obj};")
# Step3:
# 3-1:detach output; detach output for send fwd;
@@ -886,6 +890,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
schedule = self.schedules[self.stage_manager.stage] # get schedule by stage (rank)
for it in range(len(schedule)):
scheduled_node = schedule[it]
# print(f"rank {torch.distributed.get_rank()}; stage {self.stage_manager.stage}; scheduled_node {scheduled_node};")
if scheduled_node.type in AUTO_SCHEDULE_COMMUNICATION_TYPES:
# communication
communication_func = self.communication_map[scheduled_node.type]

View File

@@ -191,7 +191,6 @@ class LlamaPipelineForwards:
num_model_chunks=stage_manager.num_model_chunks,
)
assert num_ckpt_layers <= end_idx - start_idx
for idx, decoder_layer in enumerate(self.layers[start_idx:end_idx], start=start_idx):
if output_hidden_states:
all_hidden_states += (hidden_states,)

View File

@@ -9,6 +9,7 @@ from colossalai.shardformer.layer import (
FusedRMSNorm,
Linear1D_Col,
Linear1D_Row,
LinearWithGradAccum,
PaddingEmbedding,
PaddingLMHead,
RMSNorm,
@@ -104,7 +105,7 @@ class LlamaPolicy(Policy):
policy=policy,
target_key=LlamaModel,
)
# enable tp, replace layer to tp Linear1D_Col,Linear1D_Row,
if self.shard_config.enable_tensor_parallelism:
assert (
num_q_heads % tp_size == 0
@@ -191,6 +192,84 @@ class LlamaPolicy(Policy):
],
)
# not enable tp, replace layer to LinearWithGradAccum
else:
decoder_attribute_replacement = {
"self_attn.hidden_size": self.model.config.hidden_size // tp_size,
"self_attn.num_heads": num_q_heads,
}
if getattr(self.model.config, "num_key_value_heads", False):
decoder_attribute_replacement["self_attn.num_key_value_heads"] = num_kv_heads
policy[LlamaDecoderLayer] = ModulePolicyDescription(
attribute_replacement=decoder_attribute_replacement,
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="self_attn.q_proj",
target_module=LinearWithGradAccum,
kwargs=dict(
seq_parallel_mode=sp_mode,
fp8_communication=self.shard_config.fp8_communication,
use_zbv=use_zbv,
),
),
SubModuleReplacementDescription(
suffix="self_attn.k_proj",
target_module=LinearWithGradAccum,
kwargs=dict(
seq_parallel_mode=sp_mode,
fp8_communication=self.shard_config.fp8_communication,
use_zbv=use_zbv,
),
),
SubModuleReplacementDescription(
suffix="self_attn.v_proj",
target_module=LinearWithGradAccum,
kwargs=dict(
seq_parallel_mode=sp_mode,
fp8_communication=self.shard_config.fp8_communication,
use_zbv=use_zbv,
),
),
SubModuleReplacementDescription(
suffix="self_attn.o_proj",
target_module=LinearWithGradAccum,
kwargs=dict(
seq_parallel_mode=sp_mode,
fp8_communication=self.shard_config.fp8_communication,
use_zbv=use_zbv,
),
),
SubModuleReplacementDescription(
suffix="mlp.gate_proj",
target_module=LinearWithGradAccum,
kwargs=dict(
seq_parallel_mode=sp_mode,
fp8_communication=self.shard_config.fp8_communication,
use_zbv=use_zbv,
),
),
SubModuleReplacementDescription(
suffix="mlp.up_proj",
target_module=LinearWithGradAccum,
kwargs=dict(
seq_parallel_mode=sp_mode,
fp8_communication=self.shard_config.fp8_communication,
use_zbv=use_zbv,
),
),
SubModuleReplacementDescription(
suffix="mlp.down_proj",
target_module=LinearWithGradAccum,
kwargs=dict(
seq_parallel_mode=sp_mode,
fp8_communication=self.shard_config.fp8_communication,
use_zbv=use_zbv,
),
),
],
)
if embedding_cls is not None:
self.append_or_create_submodule_replacement(
description=SubModuleReplacementDescription(
@@ -416,6 +495,7 @@ class LlamaForSequenceClassificationPolicy(LlamaPolicy):
policy = super().module_policy()
use_zbv = self.pipeline_stage_manager is not None and self.pipeline_stage_manager.use_zbv
# enable tp, replace layer to tp Linear1D_Col,Linear1D_Row,
if self.shard_config.enable_tensor_parallelism:
# add a new item for sequence classification
new_item = {
@@ -434,6 +514,25 @@ class LlamaForSequenceClassificationPolicy(LlamaPolicy):
)
}
policy.update(new_item)
# enable tp, replace layer to LinearWithGradAccum
else:
# add a new item for sequence classification
new_item = {
LlamaForSequenceClassification: ModulePolicyDescription(
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="score",
target_module=LinearWithGradAccum,
kwargs=dict(
fp8_communication=self.shard_config.fp8_communication,
use_zbv=use_zbv,
),
)
]
)
}
policy.update(new_item)
# to be confirmed
if self.pipeline_stage_manager:
# set None as default