Merge pull request #6096 from BurkeHulk/hotfix/lora_ckpt

[hotfix] fix lora ckpt saving format
Hanks 2024-10-21 14:13:04 +08:00 committed by GitHub
commit dee63cc5ef
3 changed files with 18 additions and 3 deletions
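The fix is the same in all three checkpoint IO classes: instead of letting PEFT collect the state dict itself, an explicit state_dict is passed in which every tensor has been mapped to its plain .data via torch.utils._pytree.tree_map, so the LoRA adapter weights reach the safetensors serializer as ordinary torch.Tensor objects rather than nn.Parameter or other wrapper subclasses. Below is a minimal standalone sketch of that pattern (illustrative only, not part of the diff; the helper name and toy state dict are made up):

import torch
from torch.utils._pytree import tree_map

def plain_tensor_state_dict(state_dict):
    # Hypothetical helper (not in the commit): replace every tensor leaf with its .data
    # so Parameter / tensor subclasses are saved as plain torch.Tensor objects.
    return tree_map(lambda x: x.data if torch.is_tensor(x) else x, state_dict)

# Stand-in for peft_model.state_dict(); real keys come from the PEFT-wrapped model.
sd = {"base_model.model.lin.lora_A.default.weight": torch.nn.Parameter(torch.randn(8, 16))}
plain = plain_tensor_state_dict(sd)
assert not isinstance(plain["base_model.model.lin.lora_A.default.weight"], torch.nn.Parameter)

# The diff applies the same mapping inline:
# peft_model.save_pretrained(checkpoint, safe_serialization=use_safetensors,
#                            state_dict=tree_map(lambda x: x.data if torch.is_tensor(x) else x,
#                                                peft_model.state_dict()))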


@@ -290,7 +290,11 @@ class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):
         assert isinstance(
             peft_model, PeftModel
         ), "The model doesn't have lora adapters, please enable lora before saving."
-        return peft_model.save_pretrained(checkpoint, safe_serialization=use_safetensors)
+        return peft_model.save_pretrained(
+            checkpoint,
+            safe_serialization=use_safetensors,
+            state_dict=tree_map(lambda x: x.data if torch.is_tensor(x) else x, peft_model.state_dict()),
+        )
 
 
 class LowLevelZeroPlugin(DPPluginBase):


@@ -1,9 +1,11 @@
 from typing import Callable, Dict, Iterator, List, Optional, Tuple, Union
 
+import torch
 import torch.nn as nn
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.optim import Optimizer
 from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
+from torch.utils._pytree import tree_map
 from torch.utils.data import DataLoader
 
 from colossalai.checkpoint_io import CheckpointIO, GeneralCheckpointIO
@@ -134,7 +136,11 @@ class TorchDDPCheckpointIO(GeneralCheckpointIO):
         assert isinstance(
             peft_model, PeftModel
         ), "The model doesn't have lora adapters, please enable lora before saving."
-        peft_model.save_pretrained(save_directory=checkpoint, safe_serialization=use_safetensors)
+        return peft_model.save_pretrained(
+            checkpoint,
+            safe_serialization=use_safetensors,
+            state_dict=tree_map(lambda x: x.data if torch.is_tensor(x) else x, peft_model.state_dict()),
+        )
 
 
 class TorchDDPModel(ModelWrapper):


@@ -11,6 +11,7 @@ import torch.distributed as dist
 import torch.nn as nn
 from torch.distributed import ProcessGroup
 from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
+from torch.utils._pytree import tree_map
 
 from colossalai.cluster import DistCoordinator
 from colossalai.interface import ModelWrapper, OptimizerWrapper
@@ -956,4 +957,8 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
         assert isinstance(
             peft_model, PeftModel
         ), "The model doesn't have lora adapters, please enable lora before saving."
-        return peft_model.save_pretrained(checkpoint, safe_serialization=use_safetensors)
+        return peft_model.save_pretrained(
+            checkpoint,
+            safe_serialization=use_safetensors,
+            state_dict=tree_map(lambda x: x.data if torch.is_tensor(x) else x, peft_model.state_dict()),
+        )
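
For completeness, a hedged round-trip check one might run after saving (the directory name "lora_ckpt" and base_model are placeholders, not from this PR): the adapter file written with safe_serialization=True should load with safetensors, and PEFT should be able to re-attach the adapters to the original base model.

from safetensors.torch import load_file

# "lora_ckpt" is a placeholder for the checkpoint directory passed to the save call.
adapter_weights = load_file("lora_ckpt/adapter_model.safetensors")
print(sorted(adapter_weights)[:5])  # a few LoRA adapter keys, stored as plain tensors

# Re-attaching the adapters (base_model is the un-wrapped model used before enabling LoRA):
# from peft import PeftModel
# peft_model = PeftModel.from_pretrained(base_model, "lora_ckpt")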