fix lora ckpt save format (ColoTensor to Tensor)

Author: BurkeHulk, 2024-10-21 13:55:43 +08:00
parent 5ddad486ca
commit b10339df7c
3 changed files with 11 additions and 3 deletions
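
All three changed files apply the same fix: before the LoRA state dict reaches peft_model.save_pretrained, every tensor is passed through tree_map and replaced by its .data, so ColoTensor wrappers are written out as plain torch.Tensor objects. Below is a minimal, self-contained sketch of that unwrapping pattern; it uses torch.nn.Parameter as a stand-in for the wrapped tensor type, since ColoTensor itself is not constructed here, and the key names are only illustrative.

import torch
from torch.utils._pytree import tree_map

# Toy LoRA-style state dict; Parameter stands in for a wrapped tensor subclass.
state_dict = {
    "lora_A.weight": torch.nn.Parameter(torch.randn(4, 8)),
    "lora_B.weight": torch.randn(8, 4),
    "adapter_name": "default",  # non-tensor leaves pass through unchanged
}

# .data strips the wrapper and returns the plain tensor it holds.
plain_state_dict = tree_map(lambda x: x.data if torch.is_tensor(x) else x, state_dict)

assert type(plain_state_dict["lora_A.weight"]) is torch.Tensor
assert plain_state_dict["adapter_name"] == "default"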


@@ -290,7 +290,8 @@ class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):
         assert isinstance(
             peft_model, PeftModel
         ), "The model doesn't have lora adapters, please enable lora before saving."
-        return peft_model.save_pretrained(checkpoint, safe_serialization=use_safetensors)
+        return peft_model.save_pretrained(checkpoint, safe_serialization=use_safetensors,
+                                          state_dict=tree_map(lambda x: x.data if torch.is_tensor(x) else x, peft_model.state_dict()))


 class LowLevelZeroPlugin(DPPluginBase):


@ -1,10 +1,12 @@
from typing import Callable, Dict, Iterator, List, Optional, Tuple, Union from typing import Callable, Dict, Iterator, List, Optional, Tuple, Union
import torch
import torch.nn as nn import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP from torch.nn.parallel import DistributedDataParallel as DDP
from torch.optim import Optimizer from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler as LRScheduler from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
from torch.utils._pytree import tree_map
from colossalai.checkpoint_io import CheckpointIO, GeneralCheckpointIO from colossalai.checkpoint_io import CheckpointIO, GeneralCheckpointIO
from colossalai.cluster import DistCoordinator from colossalai.cluster import DistCoordinator
@ -134,7 +136,9 @@ class TorchDDPCheckpointIO(GeneralCheckpointIO):
assert isinstance( assert isinstance(
peft_model, PeftModel peft_model, PeftModel
), "The model doesn't have lora adapters, please enable lora before saving." ), "The model doesn't have lora adapters, please enable lora before saving."
peft_model.save_pretrained(save_directory=checkpoint, safe_serialization=use_safetensors) return peft_model.save_pretrained(checkpoint, safe_serialization=use_safetensors,
state_dict=tree_map(lambda x: x.data if torch.is_tensor(x) else x,
peft_model.state_dict()))
class TorchDDPModel(ModelWrapper): class TorchDDPModel(ModelWrapper):


@@ -11,6 +11,7 @@ import torch.distributed as dist
 import torch.nn as nn
 from torch.distributed import ProcessGroup
 from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
+from torch.utils._pytree import tree_map

 from colossalai.cluster import DistCoordinator
 from colossalai.interface import ModelWrapper, OptimizerWrapper
@@ -956,4 +957,6 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
         assert isinstance(
             peft_model, PeftModel
         ), "The model doesn't have lora adapters, please enable lora before saving."
-        return peft_model.save_pretrained(checkpoint, safe_serialization=use_safetensors)
+        return peft_model.save_pretrained(checkpoint, safe_serialization=use_safetensors,
+                                          state_dict=tree_map(lambda x: x.data if torch.is_tensor(x) else x,
+                                                              peft_model.state_dict()))
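
For completeness, a self-contained check that the unwrapped, plain tensors go through the safetensors writer that safe_serialization=True relies on (this assumes the safetensors package is installed; the tensor key and output file name are illustrative only):

import torch
from safetensors.torch import save_file
from torch.utils._pytree import tree_map

# Stand-in for a wrapped LoRA weight; .data yields the plain torch.Tensor inside.
wrapped = {"lora_A.weight": torch.nn.Parameter(torch.randn(4, 8))}
plain = tree_map(lambda x: x.data if torch.is_tensor(x) else x, wrapped)

# The unwrapped dict contains only plain, contiguous tensors, which save_file accepts.
save_file(plain, "adapter_model.safetensors")  # illustrative output path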