[shardformer] support pipeline for deepseek v3 and optimize lora save (#6188)

* [shardformer] support pipeline for deepseek v3

* [checkpointio] fix lora save

* [devops] update ci env

* [booster] optimize lora

* fix test

* fix test
Hongxin Liu authored 2025-02-14 14:48:54 +08:00, committed by GitHub
parent ec73f1b5e2
commit 014837e725
21 changed files with 478 additions and 91 deletions


@@ -1,4 +1,5 @@
 import torch.nn as nn
+from peft import PeftModel
 
 
 class ModelWrapper(nn.Module):
@@ -13,13 +14,17 @@ class ModelWrapper(nn.Module):
         super().__init__()
         self.module = module
 
-    def unwrap(self):
+    def unwrap(self, unwrap_peft: bool = True):
         """
         Unwrap the model to return the original model for checkpoint saving/loading.
         """
         if isinstance(self.module, ModelWrapper):
-            return self.module.unwrap()
-        return self.module
+            model = self.module.unwrap()
+        else:
+            model = self.module
+        if unwrap_peft and isinstance(model, PeftModel):
+            model = model.get_base_model()
+        return model
 
     def forward(self, *args, **kwargs):
         return self.module(*args, **kwargs)
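
For context, a minimal sketch of how the new unwrap_peft flag can be exercised. The colossalai.interface import path and the toy nn.Sequential LoRA setup are assumptions for illustration, not part of this diff; only ModelWrapper.unwrap and peft's PeftModel / get_peft_model / save_pretrained come from the change or from peft itself.

    import torch.nn as nn
    from peft import LoraConfig, PeftModel, get_peft_model

    from colossalai.interface import ModelWrapper  # assumed import path

    # Toy base model with a single Linear layer named "0" so LoRA can target it.
    base = nn.Sequential(nn.Linear(16, 16))
    peft_model = get_peft_model(base, LoraConfig(r=4, target_modules=["0"]))
    wrapped = ModelWrapper(peft_model)

    # Default behaviour: strip both ModelWrapper and PeftModel and return the
    # underlying model, matching the docstring's "original model" for plain
    # checkpoint saving/loading.
    assert not isinstance(wrapped.unwrap(), PeftModel)

    # unwrap_peft=False keeps the PeftModel, so LoRA checkpoint IO can call
    # save_pretrained() and persist only the adapter weights.
    lora_model = wrapped.unwrap(unwrap_peft=False)
    assert isinstance(lora_model, PeftModel)
    lora_model.save_pretrained("lora_ckpt")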