Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-08 12:30:42 +00:00)
[shardformer] support pipeline for deepseek v3 and optimize lora save (#6188)
* [shardformer] support pipeline for deepseek v3
* [checkpointio] fix lora save
* [devops] update ci env
* [booster] optimize lora
* fix test
* fix test
@@ -10,6 +10,7 @@ from typing import Any, Callable, Dict, Iterator, List, Optional, OrderedDict, T
 import numpy as np
 import torch
 import torch.distributed as dist
+from peft import PeftModel
 from torch import Tensor, inf
 from torch.distributed import ProcessGroup, get_world_size
 from torch.nn import Module, SyncBatchNorm
@@ -219,11 +220,13 @@ class HybridParallelModule(ModelWrapper, AMPModelMixin):
         with self._hook_context():
             return super().forward(*args, **kwargs)
 
-    def unwrap(self):
-        module = super().unwrap()
-        if isinstance(module, DDP):
-            module = module.module
-        return module
+    def unwrap(self, unwrap_peft: bool = True):
+        model = self.module
+        if isinstance(model, DDP):
+            model = model.module
+        if unwrap_peft and isinstance(model, PeftModel):
+            model = model.get_base_model()
+        return model
 
     def _force_wait_all_gather(self):
         for p in self.module.parameters():
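The reworked unwrap() lets callers such as the checkpoint IO choose between the bare base module and the still-wrapped PeftModel. Below is a minimal sketch of how the new flag might be used, assuming wrapped_model is a boosted HybridParallelModule whose inner module is a peft PeftModel; the helper functions are illustrative, not ColossalAI API.

# Hedged illustration of the new unwrap(unwrap_peft=...) flag.
# `wrapped_model` is assumed to be a boosted HybridParallelModule wrapping a
# peft.PeftModel; the helper names here are hypothetical.
from peft import PeftModel

def save_lora_adapter(wrapped_model, out_dir: str):
    # Keep the PeftModel wrapper so PEFT's own adapter-saving path can be reused.
    peft_model = wrapped_model.unwrap(unwrap_peft=False)
    assert isinstance(peft_model, PeftModel)
    peft_model.save_pretrained(out_dir)  # writes only the LoRA adapter weights

def get_base_module(wrapped_model):
    # Default behaviour: strip DDP and the PeftModel wrapper, returning the
    # underlying nn.Module (injected LoRA layers remain inside it).
    return wrapped_model.unwrap()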
@@ -1509,7 +1512,7 @@ class HybridParallelPlugin(PipelinePluginBase):
         from peft import PeftModel, get_peft_model
 
         assert not isinstance(model, HybridParallelModule), "Lora should be enabled before boosting the model."
-        assert self.pp_size == 1 and self.tp_size == 1
+        assert self.tp_size == 1
         self.lora_enabled = True
         self.logger.warning("You have enabled LoRa training. Please check the hyperparameters such as lr", ranks=[0])