[shardformer] support pipeline for deepseek v3 and optimize lora save (#6188)

* [shardformer] support pipeline for deepseek v3

* [checkpointio] fix lora save

* [devops] update ci env

* [booster] optimize lora

* fix test

* fix test

Author: Hongxin Liu
Date: 2025-02-14 14:48:54 +08:00
Committed by: GitHub
Parent: ec73f1b5e2
Commit: 014837e725
21 changed files with 478 additions and 91 deletions
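For context on the LoRA changes below: Booster.enable_lora wraps the model with PEFT and returns the wrapped instance, and that wrapped model is what the LoRA checkpoint path saves. A minimal sketch of the intended flow, assuming ColossalAI's Booster API (the save_lora_as_pretrained name is assumed from recent releases; verify it against your version):

# Hedged sketch of the LoRA flow this commit touches. Assumes `booster`,
# `model`, `optimizer`, and `dataloader` are already constructed.
from peft import LoraConfig

model = booster.enable_lora(  # returns the PEFT-wrapped model; rebind it
    model,
    lora_config=LoraConfig(task_type="CAUSAL_LM", target_modules=["gate_proj", "up_proj", "down_proj"]),
)
model, optimizer, _, dataloader, _ = booster.boost(model, optimizer, dataloader=dataloader)
# ... training loop ...
booster.save_lora_as_pretrained(model, "ckpt/lora")  # saves adapter weights only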


@@ -60,9 +60,9 @@ MODEL_CONFIGS = {
         attn_implementation="flash_attention_2",
         trust_remote_code=True,
     ),
-    "v3-6b": AutoConfig.from_pretrained(
+    "v3-7b": AutoConfig.from_pretrained(
         "deepseek-ai/DeepSeek-V3",
-        num_hidden_layers=5,
+        num_hidden_layers=6,
         first_k_dense_replace=2,
         n_routed_experts=32,
         vocab_size=8192,
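The renamed "v3-7b" entry keeps the published DeepSeek-V3 architecture but shrinks it so the benchmark fits on test hardware. For reference, a standalone sketch of how such a scaled-down config is constructed (assumes transformers is installed and the Hugging Face Hub is reachable):

# Scaled-down DeepSeek-V3 config for benchmarking; keyword arguments
# override the corresponding fields of the published config.
from transformers import AutoConfig

config = AutoConfig.from_pretrained(
    "deepseek-ai/DeepSeek-V3",
    num_hidden_layers=6,      # shallow stack so the model builds quickly
    first_k_dense_replace=2,  # first two layers stay dense MLPs, the rest use MoE
    n_routed_experts=32,      # far fewer routed experts than the released model
    vocab_size=8192,
    trust_remote_code=True,   # DeepSeek-V3 ships custom modeling code
)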
@@ -210,14 +210,15 @@ def main():
         config, trust_remote_code=True, attn_implementation=attn_impl, torch_dtype=torch.bfloat16
     ).to(torch.bfloat16)
     if args.enable_lora:
-        booster.enable_lora(
+        model = booster.enable_lora(
             model,
             lora_config=LoraConfig(task_type="CAUSAL_LM", target_modules=["gate_proj", "up_proj", "down_proj"]),
         )
     if args.grad_checkpoint:
         model.gradient_checkpointing_enable()
-    if model.__class__.__name__.startswith("DeepseekV3"):
+    if config.__class__.__name__.startswith("DeepseekV3"):
         model.config.use_cache = False
+        model.eval()
         # enable grad for moe layers
         for m in model.modules():
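Two related fixes land in this hunk. First, Booster.enable_lora returns the PEFT-wrapped model, so the result must be rebound; the old code discarded it and kept training the unwrapped model. Second, after wrapping, model.__class__.__name__ is a PEFT class name rather than "DeepseekV3...", so the architecture check has to inspect the config class instead. A condensed sketch of the corrected pattern (lora_config as in the diff above):

if args.enable_lora:
    # enable_lora returns a new, PEFT-wrapped module; rebind it.
    model = booster.enable_lora(model, lora_config=lora_config)

# The wrapped model's class name no longer starts with "DeepseekV3", but
# the config class is unaffected by wrapping, so test that instead.
if config.__class__.__name__.startswith("DeepseekV3"):
    model.config.use_cache = False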
@@ -257,40 +258,42 @@
         ) as prof: # , distributed_debug_mode(10, enable=True):
             if isinstance(plugin, MoeHybridParallelPlugin) and args.pp > 1:
                 data_iter = iter(dataloader)
-                for step in tqdm(range(len(dataloader)), desc="Step", disable=not coordinator.is_master()):
-                    performance_evaluator.on_step_start(step)
-                    outputs = booster.execute_pipeline(
-                        data_iter,
-                        model,
-                        criterion=lambda outputs, inputs: outputs[0],
-                        optimizer=optimizer,
-                        return_loss=True,
-                    )
-                    loss = outputs["loss"]
-                    if dist.get_rank() == dist.get_world_size() - 1:
-                        print(f"Step {step} loss: {loss}")
-                    optimizer.step()
-                    optimizer.zero_grad()
+                with tqdm(
+                    range(len(dataloader)), desc="Step", disable=dist.get_rank() != dist.get_world_size() - 1
+                ) as pbar:
+                    for step in pbar:
+                        performance_evaluator.on_step_start(step)
+                        outputs = booster.execute_pipeline(
+                            data_iter,
+                            model,
+                            criterion=lambda outputs, inputs: outputs[0],
+                            optimizer=optimizer,
+                            return_loss=True,
+                        )
+                        loss = outputs["loss"]
+                        loss_scalar = loss.item() if loss is not None else None
+                        pbar.set_postfix({"loss": loss_scalar})
+                        optimizer.step()
+                        optimizer.zero_grad()
-                    performance_evaluator.on_step_end(input_ids=torch.empty(args.batch_size, args.max_length))
-                    prof.step()
-                    print(f"rank {dist.get_rank()} step {step} passed")
+                        performance_evaluator.on_step_end(input_ids=torch.empty(args.batch_size, args.max_length))
+                        prof.step()
             else:
-                for step, batch in enumerate(tqdm(dataloader, desc="Step", disable=not coordinator.is_master())):
-                    performance_evaluator.on_step_start(step)
-                    outputs = model(**batch)
-                    loss = outputs[0]
-                    del outputs # free memory
+                with tqdm(dataloader, desc="Step", disable=not coordinator.is_master()) as pbar:
+                    for step, batch in enumerate(pbar):
+                        performance_evaluator.on_step_start(step)
+                        outputs = model(**batch)
+                        loss = outputs[0]
+                        del outputs # free memory
-                    if dist.get_rank() == dist.get_world_size() - 1:
-                        print(f"Step {step} loss: {loss}")
+                        pbar.set_postfix({"loss": loss.item()})
-                    booster.backward(loss, optimizer)
-                    optimizer.step()
-                    optimizer.zero_grad()
+                        booster.backward(loss, optimizer)
+                        optimizer.step()
+                        optimizer.zero_grad()
-                    performance_evaluator.on_step_end(**batch)
-                    prof.step()
+                        performance_evaluator.on_step_end(**batch)
+                        prof.step()
     performance_evaluator.on_fit_end()
     coordinator.print_on_master(f"Max CUDA memory usage: {get_accelerator().max_memory_allocated()/1024**2:.2f} MB")
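The loop rewrite swaps per-step print calls for a tqdm postfix. One subtlety the new code encodes: under pipeline parallelism, booster.execute_pipeline materializes the loss only on the last pipeline stage, so outputs["loss"] can be None on other ranks; hence the guarded loss.item() and the progress bar being enabled on the last rank only. A trimmed sketch of the pattern (data_iter, model, optimizer as in the diff):

# Trimmed sketch of the pipeline-parallel loop above. Only the last
# pipeline stage holds the loss, so guard against None elsewhere.
last_rank = dist.get_world_size() - 1
with tqdm(range(len(dataloader)), desc="Step", disable=dist.get_rank() != last_rank) as pbar:
    for step in pbar:
        outputs = booster.execute_pipeline(
            data_iter,
            model,
            criterion=lambda outputs, inputs: outputs[0],
            optimizer=optimizer,
            return_loss=True,
        )
        loss = outputs["loss"]  # None on all but the last pipeline stage
        pbar.set_postfix({"loss": loss.item() if loss is not None else None})
        optimizer.step()
        optimizer.zero_grad()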