Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-11 05:49:55 +00:00
[shardformer] support pipeline for deepseek v3 and optimize lora save (#6188)
* [shardformer] support pipeline for deepseek v3
* [checkpointio] fix lora save
* [devops] update ci env
* [booster] optimize lora
* fix test
* fix test
@@ -60,9 +60,9 @@ MODEL_CONFIGS = {
         attn_implementation="flash_attention_2",
         trust_remote_code=True,
     ),
-    "v3-6b": AutoConfig.from_pretrained(
+    "v3-7b": AutoConfig.from_pretrained(
         "deepseek-ai/DeepSeek-V3",
-        num_hidden_layers=5,
+        num_hidden_layers=6,
         first_k_dense_replace=2,
         n_routed_experts=32,
         vocab_size=8192,
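For context on the hunk above: each MODEL_CONFIGS entry builds a scaled-down DeepSeek-V3 configuration by passing override kwargs to AutoConfig.from_pretrained, which is how the renamed "v3-7b" entry ends up with 6 hidden layers and 32 routed experts instead of the full model's sizes. A minimal sketch of the same pattern, assuming transformers can fetch the remote DeepSeek-V3 code from the Hub; only the kwargs shown in the hunk come from the diff, the rest is illustrative:

import torch
from transformers import AutoConfig, AutoModelForCausalLM

# Build a reduced DeepSeek-V3 config by overriding fields of the Hub config.
# The override values mirror the "v3-7b" entry in the hunk above.
config = AutoConfig.from_pretrained(
    "deepseek-ai/DeepSeek-V3",
    num_hidden_layers=6,
    first_k_dense_replace=2,   # the first 2 layers stay dense, the rest use MoE
    n_routed_experts=32,
    vocab_size=8192,
    trust_remote_code=True,
)

# Instantiate randomly initialized weights from the reduced config
# (no checkpoint download), as the benchmark does with from_config.
model = AutoModelForCausalLM.from_config(config, trust_remote_code=True, torch_dtype=torch.bfloat16)
print(f"{sum(p.numel() for p in model.parameters()) / 1e9:.2f}B parameters")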
@@ -210,14 +210,15 @@ def main():
             config, trust_remote_code=True, attn_implementation=attn_impl, torch_dtype=torch.bfloat16
         ).to(torch.bfloat16)
     if args.enable_lora:
-        booster.enable_lora(
+        model = booster.enable_lora(
             model,
             lora_config=LoraConfig(task_type="CAUSAL_LM", target_modules=["gate_proj", "up_proj", "down_proj"]),
         )

     if args.grad_checkpoint:
         model.gradient_checkpointing_enable()
-    if model.__class__.__name__.startswith("DeepseekV3"):
+    if config.__class__.__name__.startswith("DeepseekV3"):
         model.config.use_cache = False
         model.eval()
         # enable grad for moe layers
         for m in model.modules():
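The key change in this hunk is keeping the return value of booster.enable_lora: the booster returns the LoRA-wrapped model, and that wrapper is what later training and checkpointing must use (this pairs with the "optimize lora" and "fix lora save" items in the commit message). A minimal sketch of the surrounding flow, assuming ColossalAI's Booster API with a LoRA-capable plugin; the GPT-2 model, the plugin choice, and the save path are illustrative and not taken from this commit:

import torch
import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import LowLevelZeroPlugin
from peft import LoraConfig
from transformers import AutoModelForCausalLM

colossalai.launch_from_torch()  # expects the usual torchrun environment variables

plugin = LowLevelZeroPlugin(stage=1)  # any plugin that supports LoRA
booster = Booster(plugin=plugin)

model = AutoModelForCausalLM.from_pretrained("gpt2")
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

# Keep the return value: enable_lora hands back the LoRA-wrapped model.
model = booster.enable_lora(
    model,
    lora_config=LoraConfig(task_type="CAUSAL_LM", target_modules=["c_attn"]),
)
model, optimizer, *_ = booster.boost(model, optimizer)

# ... training steps ...

# Persist only the adapter weights; the method name is assumed from the
# Booster checkpoint IO and may differ across versions.
booster.save_lora_as_pretrained(model, "./lora_ckpt")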
@@ -257,40 +258,42 @@ def main():
     ) as prof:  # , distributed_debug_mode(10, enable=True):
         if isinstance(plugin, MoeHybridParallelPlugin) and args.pp > 1:
             data_iter = iter(dataloader)
-            for step in tqdm(range(len(dataloader)), desc="Step", disable=not coordinator.is_master()):
-                performance_evaluator.on_step_start(step)
-                outputs = booster.execute_pipeline(
-                    data_iter,
-                    model,
-                    criterion=lambda outputs, inputs: outputs[0],
-                    optimizer=optimizer,
-                    return_loss=True,
-                )
-                loss = outputs["loss"]
-                if dist.get_rank() == dist.get_world_size() - 1:
-                    print(f"Step {step} loss: {loss}")
-                optimizer.step()
-                optimizer.zero_grad()
+            with tqdm(
+                range(len(dataloader)), desc="Step", disable=dist.get_rank() != dist.get_world_size() - 1
+            ) as pbar:
+                for step in pbar:
+                    performance_evaluator.on_step_start(step)
+                    outputs = booster.execute_pipeline(
+                        data_iter,
+                        model,
+                        criterion=lambda outputs, inputs: outputs[0],
+                        optimizer=optimizer,
+                        return_loss=True,
+                    )
+                    loss = outputs["loss"]
+                    loss_scalar = loss.item() if loss is not None else None
+                    pbar.set_postfix({"loss": loss_scalar})
+                    optimizer.step()
+                    optimizer.zero_grad()

-                performance_evaluator.on_step_end(input_ids=torch.empty(args.batch_size, args.max_length))
-                prof.step()
-            print(f"rank {dist.get_rank()} step {step} passed")
+                    performance_evaluator.on_step_end(input_ids=torch.empty(args.batch_size, args.max_length))
+                    prof.step()
         else:
-            for step, batch in enumerate(tqdm(dataloader, desc="Step", disable=not coordinator.is_master())):
-                performance_evaluator.on_step_start(step)
-                outputs = model(**batch)
-                loss = outputs[0]
-                del outputs  # free memory
+            with tqdm(dataloader, desc="Step", disable=not coordinator.is_master()) as pbar:
+                for step, batch in enumerate(pbar):
+                    performance_evaluator.on_step_start(step)
+                    outputs = model(**batch)
+                    loss = outputs[0]
+                    del outputs  # free memory

-                if dist.get_rank() == dist.get_world_size() - 1:
-                    print(f"Step {step} loss: {loss}")
+                    pbar.set_postfix({"loss": loss.item()})

-                booster.backward(loss, optimizer)
-                optimizer.step()
-                optimizer.zero_grad()
+                    booster.backward(loss, optimizer)
+                    optimizer.step()
+                    optimizer.zero_grad()

-                performance_evaluator.on_step_end(**batch)
-                prof.step()
+                    performance_evaluator.on_step_end(**batch)
+                    prof.step()

     performance_evaluator.on_fit_end()
     coordinator.print_on_master(f"Max CUDA memory usage: {get_accelerator().max_memory_allocated()/1024**2:.2f} MB")
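In the pipeline branch above, booster.execute_pipeline drives the whole micro-batched forward and backward schedule internally, which is why the loop never calls booster.backward and why the returned loss can be None on ranks that do not hold the last pipeline stage (hence the new loss.item() if loss is not None else None guard). A rough sketch of the plugin and booster setup that branch assumes; the parallel sizes are illustrative, and only pp_size > 1 matters for selecting the execute_pipeline path:

import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import MoeHybridParallelPlugin

colossalai.launch_from_torch()  # run with torchrun and enough ranks for tp * pp * ep

plugin = MoeHybridParallelPlugin(
    tp_size=1,
    pp_size=2,          # pp > 1: training steps go through booster.execute_pipeline
    ep_size=2,          # expert parallelism for the MoE layers
    zero_stage=1,
    microbatch_size=1,  # pipeline schedules need a micro-batch size (or num_microbatches)
)
booster = Booster(plugin=plugin)

After booster.boost(model, optimizer, dataloader=dataloader) wraps the objects for this strategy, the loop shown in the diff consumes iter(dataloader) through execute_pipeline one step at a time.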