diff --git a/colossalai/booster/plugin/low_level_zero_plugin.py b/colossalai/booster/plugin/low_level_zero_plugin.py
index 97fabe63a..f3a6901ad 100644
--- a/colossalai/booster/plugin/low_level_zero_plugin.py
+++ b/colossalai/booster/plugin/low_level_zero_plugin.py
@@ -290,8 +290,11 @@ class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):
         assert isinstance(
             peft_model, PeftModel
         ), "The model doesn't have lora adapters, please enable lora before saving."
-        return peft_model.save_pretrained(checkpoint, safe_serialization=use_safetensors,
-            state_dict = tree_map(lambda x: x.data if torch.is_tensor(x) else x, peft_model.state_dict()))
+        return peft_model.save_pretrained(
+            checkpoint,
+            safe_serialization=use_safetensors,
+            state_dict=tree_map(lambda x: x.data if torch.is_tensor(x) else x, peft_model.state_dict()),
+        )
 
 
 class LowLevelZeroPlugin(DPPluginBase):
diff --git a/colossalai/booster/plugin/torch_ddp_plugin.py b/colossalai/booster/plugin/torch_ddp_plugin.py
index aa4d35cd4..156a4acf9 100644
--- a/colossalai/booster/plugin/torch_ddp_plugin.py
+++ b/colossalai/booster/plugin/torch_ddp_plugin.py
@@ -5,8 +5,8 @@ import torch.nn as nn
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.optim import Optimizer
 from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
-from torch.utils.data import DataLoader
 from torch.utils._pytree import tree_map
+from torch.utils.data import DataLoader
 
 from colossalai.checkpoint_io import CheckpointIO, GeneralCheckpointIO
 from colossalai.cluster import DistCoordinator
@@ -136,9 +136,11 @@ class TorchDDPCheckpointIO(GeneralCheckpointIO):
         assert isinstance(
             peft_model, PeftModel
         ), "The model doesn't have lora adapters, please enable lora before saving."
-        return peft_model.save_pretrained(checkpoint, safe_serialization=use_safetensors,
-                                          state_dict=tree_map(lambda x: x.data if torch.is_tensor(x) else x,
-                                          peft_model.state_dict()))
+        return peft_model.save_pretrained(
+            checkpoint,
+            safe_serialization=use_safetensors,
+            state_dict=tree_map(lambda x: x.data if torch.is_tensor(x) else x, peft_model.state_dict()),
+        )
 
 
 class TorchDDPModel(ModelWrapper):
diff --git a/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py b/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py
index 4ca1353d8..e6abf59e3 100644
--- a/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py
+++ b/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py
@@ -957,6 +957,8 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
         assert isinstance(
             peft_model, PeftModel
         ), "The model doesn't have lora adapters, please enable lora before saving."
-        return peft_model.save_pretrained(checkpoint, safe_serialization=use_safetensors,
-            state_dict=tree_map(lambda x: x.data if torch.is_tensor(x) else x,
-            peft_model.state_dict()))
+        return peft_model.save_pretrained(
+            checkpoint,
+            safe_serialization=use_safetensors,
+            state_dict=tree_map(lambda x: x.data if torch.is_tensor(x) else x, peft_model.state_dict()),
+        )
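
All three hunks reformat the same call pattern: the PEFT state dict is passed through tree_map so that every tensor leaf is replaced by its underlying .data (a plain, detached tensor), presumably to drop parameter or tensor-subclass wrappers, before save_pretrained serializes it with safetensors. Below is a minimal standalone sketch of that pattern, assuming a peft.PeftModel with LoRA adapters is already constructed; the helper name unwrap_state_dict and the checkpoint path are illustrative, not part of the ColossalAI API.

import torch
from torch.utils._pytree import tree_map


def unwrap_state_dict(peft_model):
    # Replace every tensor leaf in the state dict with its plain .data tensor;
    # non-tensor entries, if any, pass through unchanged.
    return tree_map(lambda x: x.data if torch.is_tensor(x) else x, peft_model.state_dict())


# Usage sketch (assumes `peft_model` is a peft.PeftModel with LoRA adapters attached):
# peft_model.save_pretrained(
#     "lora_checkpoint",
#     safe_serialization=True,
#     state_dict=unwrap_state_dict(peft_model),
# )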