Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-06-18 19:58:17 +00:00)

fix lora ckpt save format (ColoTensor to Tensor)

commit b10339df7c
parent 5ddad486ca
@@ -290,7 +290,8 @@ class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):
         assert isinstance(
             peft_model, PeftModel
         ), "The model doesn't have lora adapters, please enable lora before saving."
-        return peft_model.save_pretrained(checkpoint, safe_serialization=use_safetensors)
+        return peft_model.save_pretrained(checkpoint, safe_serialization=use_safetensors,
+            state_dict=tree_map(lambda x: x.data if torch.is_tensor(x) else x, peft_model.state_dict()))
 
 
 class LowLevelZeroPlugin(DPPluginBase):
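The change above (repeated in the analogous hunks below) replaces every tensor leaf in the LoRA state dict with its plain torch.Tensor payload before handing it to save_pretrained, so ColoTensor wrappers never reach the serializer. A minimal standalone sketch of the same pattern, using nn.Parameter as a stand-in for a tensor subclass (no ColossalAI required):

# Standalone sketch of the pattern introduced above (my own toy model, not from the commit).
# nn.Parameter stands in for the ColoTensor wrapper: .data yields the plain
# torch.Tensor underneath, and tree_map rebuilds the dict around those leaves.
import torch
import torch.nn as nn
from torch.utils._pytree import tree_map

model = nn.Linear(4, 2)
# keep_vars=True keeps the Parameter wrappers in the dict, mimicking a state
# dict whose leaves are tensor subclasses rather than plain tensors.
wrapped = dict(model.state_dict(keep_vars=True))

plain = tree_map(lambda x: x.data if torch.is_tensor(x) else x, wrapped)

print(type(wrapped["weight"]))  # <class 'torch.nn.parameter.Parameter'>
print(type(plain["weight"]))    # <class 'torch.Tensor'>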
@@ -1,10 +1,12 @@
 from typing import Callable, Dict, Iterator, List, Optional, Tuple, Union
 
+import torch
 import torch.nn as nn
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.optim import Optimizer
 from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
 from torch.utils.data import DataLoader
+from torch.utils._pytree import tree_map
 
 from colossalai.checkpoint_io import CheckpointIO, GeneralCheckpointIO
 from colossalai.cluster import DistCoordinator
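The newly imported torch.utils._pytree.tree_map is what makes the one-line conversion possible: it applies a function to every leaf of a nested container and rebuilds the same structure around the results. A quick illustration on toy data of mine (not from the commit):

import torch
from torch.utils._pytree import tree_map

nested = {"a": torch.ones(2), "b": [torch.zeros(3), "not a tensor"]}
# Non-tensor leaves pass through untouched thanks to the is_tensor guard,
# mirroring the lambda used in the checkpoint code.
doubled = tree_map(lambda x: x * 2 if torch.is_tensor(x) else x, nested)

print(doubled["a"])  # tensor([2., 2.])
print(doubled["b"])  # [tensor([0., 0., 0.]), 'not a tensor']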
@@ -134,7 +136,9 @@ class TorchDDPCheckpointIO(GeneralCheckpointIO):
         assert isinstance(
             peft_model, PeftModel
         ), "The model doesn't have lora adapters, please enable lora before saving."
-        peft_model.save_pretrained(save_directory=checkpoint, safe_serialization=use_safetensors)
+        return peft_model.save_pretrained(checkpoint, safe_serialization=use_safetensors,
+            state_dict=tree_map(lambda x: x.data if torch.is_tensor(x) else x,
+                peft_model.state_dict()))
 
 
 class TorchDDPModel(ModelWrapper):
@@ -11,6 +11,7 @@ import torch.distributed as dist
 import torch.nn as nn
 from torch.distributed import ProcessGroup
 from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
+from torch.utils._pytree import tree_map
 
 from colossalai.cluster import DistCoordinator
 from colossalai.interface import ModelWrapper, OptimizerWrapper
@@ -956,4 +957,6 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
         assert isinstance(
             peft_model, PeftModel
         ), "The model doesn't have lora adapters, please enable lora before saving."
-        return peft_model.save_pretrained(checkpoint, safe_serialization=use_safetensors)
+        return peft_model.save_pretrained(checkpoint, safe_serialization=use_safetensors,
+            state_dict=tree_map(lambda x: x.data if torch.is_tensor(x) else x,
+                peft_model.state_dict()))
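Taken together, the hunks give every LoRA save path the same call shape: convert the state dict first, then let PeftModel.save_pretrained filter and serialize it. A hedged end-to-end sketch of that call shape outside ColossalAI (assumes peft and transformers are installed; the tiny base model and output directory are illustrative choices, not part of the commit):

import torch
from torch.utils._pytree import tree_map
from peft import LoraConfig, get_peft_model
from transformers import AutoModelForCausalLM

# Illustrative tiny base model; any causal LM with a known LoRA target mapping works.
base = AutoModelForCausalLM.from_pretrained("sshleifer/tiny-gpt2")
peft_model = get_peft_model(base, LoraConfig(task_type="CAUSAL_LM"))

# Same conversion as in the diff: strip tensor wrappers down to plain Tensors.
plain_state = tree_map(lambda x: x.data if torch.is_tensor(x) else x, dict(peft_model.state_dict()))

# save_pretrained accepts a pre-built state_dict, which is what the hunks above rely on.
peft_model.save_pretrained("lora_ckpt", safe_serialization=True, state_dict=plain_state)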