[shardformer] support pipeline for deepseek v3 and optimize lora save (#6188)

* [shardformer] support pipeline for deepseek v3

* [checkpointio] fix lora save

* [devops] update ci env

* [booster] optimize lora

* fix test

* fix test
This commit is contained in:
Hongxin Liu
2025-02-14 14:48:54 +08:00
committed by GitHub
parent ec73f1b5e2
commit 014837e725
21 changed files with 478 additions and 91 deletions

View File

@@ -10,6 +10,7 @@ from typing import Any, Callable, Dict, Iterator, List, Optional, OrderedDict, T
import numpy as np
import torch
import torch.distributed as dist
from peft import PeftModel
from torch import Tensor, inf
from torch.distributed import ProcessGroup, get_world_size
from torch.nn import Module, SyncBatchNorm
@@ -219,11 +220,13 @@ class HybridParallelModule(ModelWrapper, AMPModelMixin):
with self._hook_context():
return super().forward(*args, **kwargs)
def unwrap(self):
module = super().unwrap()
if isinstance(module, DDP):
module = module.module
return module
def unwrap(self, unwrap_peft: bool = True):
model = self.module
if isinstance(model, DDP):
model = model.module
if unwrap_peft and isinstance(model, PeftModel):
model = model.get_base_model()
return model
def _force_wait_all_gather(self):
for p in self.module.parameters():
@@ -1509,7 +1512,7 @@ class HybridParallelPlugin(PipelinePluginBase):
from peft import PeftModel, get_peft_model
assert not isinstance(model, HybridParallelModule), "Lora should be enabled before boosting the model."
assert self.pp_size == 1 and self.tp_size == 1
assert self.tp_size == 1
self.lora_enabled = True
self.logger.warning("You have enabled LoRa training. Please check the hyperparameters such as lr", ranks=[0])