Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-07-06 12:07:00 +00:00

[fix] MixtralForCausalLMPolicy get_held_layers supports zbv

This commit is contained in:
parent 3f5bec8dc4
commit 9ee80fc828

@@ -268,6 +268,7 @@ class MixtralPipelineForwards:
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict

         # retrieve input_ids and inputs_embeds
+        print(f"model_chunk_id {stage_manager.model_chunk_id} stage_manager {stage_manager.stage}")
         if stage_manager.is_first_stage():
             # retrieve input_ids and inputs_embeds
             if input_ids is not None and inputs_embeds is not None:

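For context, the line added above is a debug trace: under interleaved (and ZBV) scheduling each pipeline rank runs more than one model chunk, so both the chunk id and the stage are needed to tell which piece of the model is executing this forward. A minimal stand-in sketch of what the print reports; only the stage and model_chunk_id fields come from the diff, the class itself is hypothetical:

from dataclasses import dataclass

@dataclass
class FakeStageManager:
    stage: int           # pipeline rank running this forward pass
    model_chunk_id: int  # which of the rank's model chunks is active (interleaved/ZBV)

stage_manager = FakeStageManager(stage=0, model_chunk_id=1)
print(f"model_chunk_id {stage_manager.model_chunk_id} stage_manager {stage_manager.stage}")
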
@@ -343,8 +343,18 @@ class MixtralForCausalLMPolicy(MixtralPolicy):
         """Get pipeline layers for current stage."""
         stage_manager = self.pipeline_stage_manager
         held_layers = super().get_held_layers()
-        if stage_manager.is_last_stage():
-            held_layers.append(self.model.lm_head)
+        if stage_manager.is_interleave:
+            if stage_manager.use_zbv:
+                if stage_manager.is_first_stage(ignore_chunk=True):
+                    held_layers.append(self.model.lm_head)
+            else:
+                if stage_manager.is_last_stage(ignore_chunk=True):
+                    held_layers.append(self.model.lm_head)
+        else:
+            if stage_manager.is_last_stage():
+                held_layers.append(self.model.lm_head)
+        # if stage_manager.is_last_stage():
+        #     held_layers.append(self.model.lm_head)
         return held_layers

     def get_shared_params(self) -> List[Dict[int, Tensor]]:

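This hunk is the core of the fix: under the zero-bubble V (ZBV) schedule the two model chunks on each rank are laid out in a V shape, so the final chunk of the model, and with it lm_head, lives on the first pipeline rank rather than the last. The branching above therefore checks is_first_stage(ignore_chunk=True) for ZBV, is_last_stage(ignore_chunk=True) for a plain interleaved schedule, and is_last_stage() for a non-interleaved pipeline. A toy stand-in that mirrors that logic (not ColossalAI's API; the rank arithmetic is an assumption used only for illustration):

def holds_lm_head(stage: int, num_stages: int, is_interleave: bool, use_zbv: bool) -> bool:
    # Mirrors the branching in get_held_layers() above.
    if is_interleave:
        if use_zbv:
            # ZBV: the V-shaped chunk layout folds back, so the last model
            # chunk (and lm_head) ends up on the first pipeline rank.
            return stage == 0
        # plain interleaved 1F1B: lm_head stays on the last rank
        return stage == num_stages - 1
    # non-interleaved pipeline: lm_head on the last rank
    return stage == num_stages - 1

# Example: 4 pipeline ranks under ZBV -> only rank 0 holds lm_head.
print([holds_lm_head(s, 4, is_interleave=True, use_zbv=True) for s in range(4)])
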
@@ -167,6 +167,7 @@ def main():
             enable_fused_normalization=torch.cuda.is_available(),
             enable_flash_attention=args.xformers,
             microbatch_size=args.mbs,
+            num_microbatches=args.batch_size // args.mbs,
             precision="bf16",
             enable_metadata_cache=not args.no_cache,
             overlap_allgather=args.overlap_allgather,

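The added num_microbatches argument ties the pipeline schedule to the benchmark's batch settings: the global batch is split into batch_size // mbs microbatches per step, which implicitly assumes batch_size is divisible by mbs. A small sketch of that arithmetic (argument names follow the benchmark's CLI flags; the explicit divisibility check is an addition):

def num_microbatches(batch_size: int, mbs: int) -> int:
    # The plugin receives batch_size // mbs microbatches per pipeline step.
    assert batch_size % mbs == 0, "batch_size must be a multiple of the microbatch size"
    return batch_size // mbs

print(num_microbatches(batch_size=32, mbs=4))  # -> 8
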
@@ -208,8 +209,10 @@ def main():
     with init_ctx:
         model = MixtralForCausalLM(config=config).to(torch.bfloat16)

+    # if args.grad_checkpoint:
+    #     model.gradient_checkpointing_enable()
     if args.grad_checkpoint:
-        model.gradient_checkpointing_enable()
+        model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})

     model_numel = get_model_numel(model)
     coordinator.print_on_master(f"Model params: {format_numel_str(model_numel)}")

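The gradient-checkpointing change switches to the non-reentrant implementation: gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False}) is the Hugging Face Transformers API, and the kwargs are forwarded to torch.utils.checkpoint. The non-reentrant path is the one PyTorch recommends and tends to interact better with pipeline schedules that drive backward passes themselves. A minimal plain-PyTorch sketch showing the same flag:

import torch
from torch.utils.checkpoint import checkpoint

layer = torch.nn.Linear(16, 16)
x = torch.randn(2, 16, requires_grad=True)

# use_reentrant=False recomputes activations without the reentrant autograd
# Function; gradients still flow back to x as usual.
y = checkpoint(layer, x, use_reentrant=False)
y.sum().backward()
print(x.grad.shape)  # torch.Size([2, 16])
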
@@ -224,6 +227,7 @@ def main():
         )

     optimizer = HybridAdam(model.parameters())
+    # optimizer = torch.optim.SGD(model.parameters(), lr=1)
     torch.set_default_dtype(torch.bfloat16)
     model, optimizer, _, dataloader, _ = booster.boost(model, optimizer, dataloader=dataloader)

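The only change in this hunk is a commented-out torch.optim.SGD line, apparently kept as a debugging fallback for HybridAdam; the surrounding context also shows torch.set_default_dtype(torch.bfloat16) being set before booster.boost. A small sketch of what that default-dtype call does (plain PyTorch, independent of the benchmark):

import torch

torch.set_default_dtype(torch.bfloat16)
print(torch.empty(2).dtype)              # torch.bfloat16: new float tensors default to bf16
torch.set_default_dtype(torch.float32)   # restore PyTorch's usual default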