mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-06-23 22:19:47 +00:00)
[fix] fix mixtral modeling & policy; update wait handles; doing benchmarking for llama hybrid;
parent 014afbdb59
commit 5c2ebbfd48
@@ -46,7 +46,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
         num_microbatch: Optional[int] = None,
         microbatch_size: Optional[int] = None,
         enable_metadata_cache: bool = True,
-        overlap_p2p: bool = False,
+        overlap_p2p: bool = True,
     ):
         super().__init__(stage_manager)
         # Not support overlap_p2p so far
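
Flipping overlap_p2p to True means the scheduler now defaults to issuing its pipeline sends and receives asynchronously and collecting wait handles instead of blocking, so communication can overlap with compute. Below is a minimal, generic sketch of that pattern with torch.distributed; it illustrates the idea only, it is not the scheduler's actual communication code, and the tensors and peer ranks are placeholders.

import torch.distributed as dist

def send_recv_overlapped(send_tensor, recv_tensor, peer_next, peer_prev):
    # Describe the non-blocking ops; nothing is transferred until batch_isend_irecv.
    ops = [
        dist.P2POp(dist.isend, send_tensor, peer_next),
        dist.P2POp(dist.irecv, recv_tensor, peer_prev),
    ]
    # Kick off the transfers and return their work handles so the caller can
    # keep computing and call handle.wait() only when recv_tensor is needed.
    handles = dist.batch_isend_irecv(ops)
    return handles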
@@ -879,12 +879,16 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
         schedule = self.schedules[self.stage_manager.stage]  # get schedule by stage (rank)
         for it in range(len(schedule)):
             scheduled_node = schedule[it]
+            # print(f"stage {self.stage_manager.stage} {scheduled_node.type}")
             if scheduled_node.type in AUTO_SCHEDULE_COMMUNICATION_TYPES:
                 # communication
                 communication_func = self.communication_map[scheduled_node.type]
                 wait_handle = communication_func(scheduled_node.chunk)
                 self.wait_handles.append(wait_handle)
             elif scheduled_node.type == "F":
+                for h in self.wait_handles:
+                    for hh in h:
+                        hh.wait()
                 self.schedule_f(
                     scheduled_node=scheduled_node,
                     model_chunk=model_chunk,
@@ -894,6 +898,9 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
                     outputs=outputs,
                 )
             elif scheduled_node.type == "B":
+                for h in self.wait_handles:
+                    for hh in h:
+                        hh.wait()
                 self.schedule_b(
                     scheduled_node=scheduled_node,
                     model_chunk=model_chunk,
@@ -907,7 +914,9 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
                     model_chunk_id=scheduled_node.chunk,
                     optimizer=optimizer,
                 )
-        # print(f"stage {self.stage_manager.stage}; self.tensor_metadata_recv[0] {self.tensor_metadata_recv[0]}; self.tensor_metadata_recv[1] {self.tensor_metadata_recv[1]}; self.grad_metadata_recv[0] {self.grad_metadata_recv[0]}; self.grad_metadata_recv[1] {self.grad_metadata_recv[1]}")
+        for h in self.wait_handles:
+            for hh in h:
+                hh.wait()
         # return loss & output
         if outputs is not None:
             outputs = merge_batch(outputs)
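
The three added blocks share one pattern: self.wait_handles accumulates one list of work handles per batched P2P call, and before running forward ("F") or backward ("B") compute, and once more before returning, every outstanding handle is waited on so the tensors being consumed have actually arrived. A standalone sketch of that drain step (names are illustrative, not part of the scheduler's API):

from typing import Iterable, List

def drain_wait_handles(wait_handles: List[Iterable]) -> None:
    # Each entry is the list of handles returned by one batched/non-blocking
    # P2P call; wait on all of them before touching the received tensors.
    for handle_group in wait_handles:
        for handle in handle_group:
            handle.wait()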
@@ -381,7 +381,6 @@ class MixtralPipelineForwards:
                     output_router_logits,
                     use_cache,
                 )
-
                 hidden_states = layer_outputs[0]

                 if use_cache:
@@ -214,7 +214,6 @@ class MixtralPolicy(Policy):
                     suffix="block_sparse_moe.gate",
                     target_module=LinearWithGradAccum,
                     kwargs={
-                        "gather_output": True,
                         "fp8_communication": self.shard_config.fp8_communication,
                         "use_zbv": use_zbv,
                     },
@@ -414,7 +413,6 @@ class MixtralForCausalLMPolicy(MixtralPolicy):
                     suffix="lm_head",
                     target_module=LinearWithGradAccum,
                     kwargs=dict(
-                        gather_output=True,
                         fp8_communication=self.shard_config.fp8_communication,
                         use_zbv=use_zbv,
                     ),
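
Both Mixtral policy hunks drop gather_output from the kwargs handed to LinearWithGradAccum, presumably because that target module keeps the full-width linear rather than a column-sharded one, so there is no partial output to all-gather. A sketch of the resulting lm_head entry, assuming the SubModuleReplacementDescription and LinearWithGradAccum imports these policies already use:

def lm_head_replacement(shard_config, use_zbv):
    # Illustrative helper, not code from the repository.
    from colossalai.shardformer.layer import LinearWithGradAccum
    from colossalai.shardformer.policies.base_policy import SubModuleReplacementDescription

    return SubModuleReplacementDescription(
        suffix="lm_head",
        target_module=LinearWithGradAccum,
        kwargs=dict(
            # gather_output removed: nothing is sharded, so nothing to gather.
            fp8_communication=shard_config.fp8_communication,
            use_zbv=use_zbv,
        ),
    )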
@@ -122,7 +122,7 @@ def main():
                 num_ckpt_layers_per_stage=[19, 19, 19, 13],
             ),
             "num_layers_per_stage": [19, 20, 20, 21],
-            # "pp_style": "interleaved",
+            "pp_style": "interleaved",
         }
         if args.custom_ckpt
         else {}
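
The benchmark change un-comments "pp_style": "interleaved", so the custom-checkpoint branch now also switches the pipeline schedule to interleaved. Below is a sketch of how such overrides are typically merged into the plugin; the flag names outside the dict (tp/pp sizes, chunk count) are assumptions for illustration, not taken from the benchmark.

import argparse
from colossalai.booster.plugin import HybridParallelPlugin

def build_plugin(args: argparse.Namespace) -> HybridParallelPlugin:
    custom_overrides = (
        {
            "num_layers_per_stage": [19, 20, 20, 21],
            "pp_style": "interleaved",  # now enabled alongside the custom layer split
        }
        if args.custom_ckpt
        else {}
    )
    return HybridParallelPlugin(
        tp_size=args.tp,         # hypothetical CLI flags
        pp_size=args.pp,
        num_model_chunks=2,      # interleaved pp needs more than one chunk per rank
        **custom_overrides,
    )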