Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-07-04 11:06:25 +00:00)
[shardformer] fix master param sync for hybrid plugin/rewrite unwrapping logic (#4758)
* fix master param sync for hybrid plugin
* rewrite unwrap for ddp/fsdp
* rewrite unwrap for zero/gemini
* rewrite unwrap for hybrid plugin
* fix geemini unwrap
* fix bugs
This commit is contained in:
parent 7b9b86441f
commit c0a033700c
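For orientation, here is a minimal usage sketch (not part of this commit, and assuming a distributed environment already launched via colossalai.launch) of the flow this change targets: checkpoint IO now receives the boosted (wrapped) model and optimizer, asserts that they were boosted, unwraps them internally, and, for mixed-precision plugins, refreshes fp32 master parameters after loading.

```python
# Hedged sketch only: illustrates the intended call pattern, not an exact test from the repo.
import torch
from colossalai.booster import Booster
from colossalai.booster.plugin import TorchDDPPlugin


def roundtrip_checkpoint(model: torch.nn.Module, optimizer: torch.optim.Optimizer, path: str):
    booster = Booster(plugin=TorchDDPPlugin())
    model, optimizer, *_ = booster.boost(model, optimizer)

    # The boosted (wrapped) objects are passed straight to the checkpoint APIs;
    # unwrapping now happens inside the plugin's CheckpointIO, not in user code.
    booster.save_model(model, path)
    booster.load_model(model, path)  # mixed-precision plugins also re-sync fp32 master params here
    return model, optimizer
```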
@@ -2,7 +2,7 @@ from typing import Dict, List
 
 import torch
 from torch import Tensor
-from torch.nn import Parameter
+from torch.nn import Module, Parameter
 from torch.optim import Optimizer
 
 from colossalai.interface import OptimizerWrapper
@@ -152,3 +152,18 @@ class MixedPrecisionOptimizer(OptimizerWrapper):
                 if p is working_param:
                     continue
                 working_param.data.copy_(p.data)
+
+    def update_master_params(self, model: Module):
+        # Update master params from working params
+        with torch.no_grad():
+            for p in model.parameters():
+                if (p is None) or (p not in self.working_to_master_map):
+                    continue
+                master_param = self.working_to_master_map[p]
+                master_param.data.copy_(p.data)
+
+    def get_working_to_master_map(self) -> Dict[int, torch.Tensor]:
+        return {id(working_p): master_p for working_p, master_p in self.working_to_master_map.items()}
+
+    def get_master_to_working_map(self) -> Dict[int, torch.Tensor]:
+        return {id(master_p): working_p for master_p, working_p in self.master_to_working_map.items()}
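The new accessors above return id-keyed views of the working (fp16/bf16) to fp32 master parameter mappings. A small illustrative sketch of how such a map is consumed; `sync_master_params` is a hypothetical helper, not code from the repository:

```python
# Illustrative sketch: consume an {id(working_param): master_param} map as
# returned by get_working_to_master_map().
from typing import Dict

import torch
from torch.nn import Module


def sync_master_params(model: Module, working_to_master: Dict[int, torch.Tensor]) -> None:
    """Copy working (low-precision) weights into their fp32 master copies."""
    with torch.no_grad():
        for p in model.parameters():
            master = working_to_master.get(id(p))
            if master is not None:
                master.data.copy_(p.data)

# usage (assuming `optimizer` is a boosted MixedPrecisionOptimizer):
# sync_master_params(model, optimizer.get_working_to_master_map())
```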
@@ -139,7 +139,7 @@ class Booster:
 
         if self.plugin and not self.plugin.control_device():
             # transform model for accelerator
-            model = self.accelerator.configure(model)
+            model = self.accelerator.configure_model(model)
 
         if self.mixed_precision and (self.plugin is None or self.plugin and not self.plugin.control_precision()):
             # transform model for mixed precision
@@ -44,6 +44,7 @@ class GeminiCheckpointIO(GeneralCheckpointIO):
         The model should be unwrapped in self.load_model via ModelWrapper.unwrap.
         As there is communication when getting state dict, model.state_dict() must be called on all processes.
         """
+        assert isinstance(model, GeminiDDP), "Please boost the model before saving!"
         state_dict = model.state_dict(only_rank_0=True)
         if self.coordinator.is_master():
             save_state_dict(state_dict, checkpoint, use_safetensors)
@@ -53,24 +54,27 @@ class GeminiCheckpointIO(GeneralCheckpointIO):
         Load model from checkpoint with automatic unwrapping.
         The model should be unwrapped in self.load_model via ModelWrapper.unwrap.
         """
+        assert isinstance(model, GeminiDDP), "Please boost the model before loading!"
         super().load_unsharded_model(model, checkpoint, strict=strict)
 
-    def save_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: str, gather_dtensor: bool):
+    def save_unsharded_optimizer(self, optimizer: GeminiOptimizer, checkpoint: str, gather_dtensor: bool):
         """
         Save unsharded optimizer state dict to checkpoint.
         After calling optimizer.state_dict(), the complete optimizer states will be collected on master rank.
         As there is communication when getting state dict, optimizer.state_dict() must be called on all processes.
         The saving process will only be executed by master rank.
         """
+        assert isinstance(optimizer, GeminiOptimizer), "Please boost the optimizer before saving!"
         state_dict = optimizer.state_dict()
         if self.coordinator.is_master():
             save_state_dict(state_dict, checkpoint, use_safetensors=False)
 
-    def load_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: str):
+    def load_unsharded_optimizer(self, optimizer: GeminiOptimizer, checkpoint: str):
         """
         Loading unsharded optimizer from checkpoint file.
         For each process, only loading optimizer states of parameters it controls.
         """
+        assert isinstance(optimizer, GeminiOptimizer), "Please boost the optimizer before loading!"
         super().load_unsharded_optimizer(optimizer, checkpoint)
 
     def save_sharded_model(
@@ -86,6 +90,7 @@ class GeminiCheckpointIO(GeneralCheckpointIO):
         Save sharded model.
         As there is communication when getting state dict, model.state_dict() must be called on all processes.
         """
+        assert isinstance(model, GeminiDDP), "Please boost the model before saving!"
         if os.path.isfile(checkpoint_path):
             logging.error(f"Provided path ({checkpoint_path}) should be a directory, not a file")
             return
@@ -111,7 +116,7 @@ class GeminiCheckpointIO(GeneralCheckpointIO):
         if self.coordinator.is_master():
             index_file.append_meta_data("total_size", total_size)
             index_file.write_index_file(save_index_file)
-            save_config_file(model.module, checkpoint_path)
+            save_config_file(model.unwrap(), checkpoint_path)
             logging.info(
                 f"The model is split into checkpoint shards. "
                 f"You can find where each parameters has been saved in the "
@@ -124,17 +129,17 @@ class GeminiCheckpointIO(GeneralCheckpointIO):
         """
         Load shard model, load model from multiple files.
         """
+        assert isinstance(model, GeminiDDP), "Please boost the model before loading!"
         return super().load_sharded_model(model, checkpoint_index_file, strict, use_safetensors, load_sub_module=False)
 
     def save_sharded_optimizer(
-        self, optimizer: Optimizer, checkpoint: Path, gather_dtensor: bool, prefix: str, size_per_shard: int
+        self, optimizer: GeminiOptimizer, checkpoint: Path, gather_dtensor: bool, prefix: str, size_per_shard: int
     ):
         """
         Save sharded optimizer state dict to checkpoint folder.
         As there is communication when getting state dict, this must be called on all processes.
         """
-        assert isinstance(optimizer, GeminiOptimizer)
+        assert isinstance(optimizer, GeminiOptimizer), "Please boost the optimizer before saving!"
 
         if os.path.isfile(checkpoint):
             logging.error(f"Provided path ({checkpoint}) should be a directory, not a file")
@@ -176,12 +181,12 @@ class GeminiCheckpointIO(GeneralCheckpointIO):
                 f"index located at {save_index_file}."
             )
 
-    def load_sharded_optimizer(self, optimizer: Optimizer, checkpoint_index_file: Path, prefix: str):
+    def load_sharded_optimizer(self, optimizer: GeminiOptimizer, checkpoint_index_file: Path, prefix: str):
         """
         Loading sharded optimizer from checkpoint folder, with index file given.
         For each process, only loading optimizer states of parameters it controls.
         """
+        assert isinstance(optimizer, GeminiOptimizer), "Please boost the optimizer before loading!"
         if not os.path.isfile(checkpoint_index_file):
             logging.error(f"Provided path ({checkpoint_index_file}) should be a file")
 
@@ -383,7 +388,7 @@ class GeminiPlugin(DPPluginBase):
 
         if optimizer is not None and not isinstance(optimizer, OptimizerWrapper):
             optimizer = GeminiOptimizer(
-                optimizer, model.unwrap(), **self.zero_optim_config, **self.optim_kwargs, verbose=self.verbose
+                optimizer, model, **self.zero_optim_config, **self.optim_kwargs, verbose=self.verbose
             )
 
         return model, optimizer, criterion, dataloader, lr_scheduler
@@ -1,6 +1,7 @@
 import random
 from contextlib import nullcontext
 from functools import partial
+from types import MethodType
 from typing import Any, Callable, Iterator, List, Optional, OrderedDict, Tuple, Union
 
 import numpy as np
@@ -165,6 +166,15 @@ class HybridParallelNaiveOptimizer(OptimizerWrapper):
             init_pipeline_optimizer(optim, model)
         super().__init__(optim)
 
+    def update_master_params(self, model: Module):
+        pass
+
+    def get_working_to_master_map(self):
+        return None
+
+    def get_master_to_working_map(self):
+        return None
+
 
 class HybridParallelAMPOptimizer(MixedPrecisionOptimizer):
     def __init__(
@@ -466,9 +476,6 @@ class HybridParallelPlugin(PipelinePluginBase):
                         max_norm=self.max_norm,
                         **self.amp_config,
                     )
-                    self.checkpoint_io.link_master_and_working_param(
-                        optimizer.working_to_master_map, optimizer.master_to_working_map
-                    )
                 else:
                     optimizer = HybridParallelNaiveOptimizer(
                         optimizer, model, use_pipeline=self.enable_pipeline_parallelism, param_info=param_info
@@ -488,10 +495,8 @@ class HybridParallelPlugin(PipelinePluginBase):
                     **self.zero_config,
                     **self.amp_config,
                 )
-                self.checkpoint_io.link_master_and_working_param(
-                    optimizer._param_store.working_to_master_param, optimizer._param_store.master_to_working_param
-                )
+                # inject update_master_params
+                model.update_master_params = MethodType(optimizer.update_master_params, model)
 
         return model, optimizer, criterion, dataloader, lr_scheduler
 
     def execute_pipeline(
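The `MethodType` injection above binds the ZeRO optimizer's `update_master_params` onto the model instance, so generic checkpoint code can simply call `model.update_master_params()`. A minimal self-contained sketch of that binding pattern, using toy classes rather than the real plugin types:

```python
# Toy illustration of the MethodType injection pattern used above.
from types import MethodType


class ToyModel:
    pass


class ToyOptimizer:
    def update_master_params(self, model):
        print(f"syncing master params for {type(model).__name__}")


model, optim = ToyModel(), ToyOptimizer()
# Bind the optimizer's (already bound) method onto the model instance:
model.update_master_params = MethodType(optim.update_master_params, model)
model.update_master_params()  # -> "syncing master params for ToyModel"
```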
@@ -567,8 +572,7 @@ class HybridParallelPlugin(PipelinePluginBase):
         )
 
     def get_checkpoint_io(self) -> CheckpointIO:
-        self.checkpoint_io = HybridParallelCheckpointIO(self.dp_group, self.pp_group, self.tp_group, self.zero_stage)
-        return self.checkpoint_io
+        return HybridParallelCheckpointIO(self.dp_group, self.pp_group, self.tp_group, self.zero_stage)
 
     def no_sync(self, model: Module) -> Iterator[None]:
         raise NotImplementedError
@@ -22,7 +22,6 @@ from colossalai.checkpoint_io.utils import (
     save_param_groups,
     save_state_dict,
     sharded_optimizer_loading_epilogue,
-    unwrap_optimizer,
 )
 from colossalai.interface import AMPModelMixin, ModelWrapper, OptimizerWrapper
 from colossalai.utils import get_current_device
@@ -65,10 +64,6 @@ class LowLevelZeroModel(ModelWrapper, AMPModelMixin):
         kwargs = tree_map(self.convert_fn, kwargs)
         return super().forward(*args, **kwargs)
 
-    def unwrap(self):
-        # TODO(ver217): this is a workaround for loading model
-        return self
-
 
 class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):
     def save_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str, gather_dtensor: bool = False):
@@ -79,7 +74,7 @@ class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):
             checkpoint (str): Path to save checkpoint
             gather_dtensor (bool): Whether to gather_dtensor, not used
         """
+        assert isinstance(optimizer, LowLevelZeroOptimizer), "Please boost the optimizer before saving!"
         # the `state_dict` in LowLevelZeroOptimizer has communication
         # if only the master rank collect state_dict and save,
         # the communication on each rank would not match
@@ -109,6 +104,7 @@ class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):
             prefix (str): Perfix of file to save
             size_per_shard (int): Max file size of each file that store state tensors
         """
+        assert isinstance(optimizer, LowLevelZeroOptimizer), "Please boost the optimizer before saving!"
         if os.path.isfile(checkpoint):
             logging.error(f"Provided path ({checkpoint}) should be a directory, not a file")
             return
@@ -160,9 +156,8 @@ class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):
             index_file_path (str): Path to the index file
             prefix (str): Not used.
         """
-        # If optimizer is wrapped, unwrap it.
-        if isinstance(optimizer, OptimizerWrapper):
-            optimizer = unwrap_optimizer(optimizer)
+        assert isinstance(optimizer, LowLevelZeroOptimizer), "Please boost the optimizer before Loading!"
+        optimizer = optimizer.unwrap()
 
         # Read checkpoint index file.
         ckpt_index_file = CheckpointIndexFile.from_file(index_file_path)
@@ -194,44 +189,23 @@ class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):
                     v_list = v.split(v.numel() // self.coordinator.world_size)
                     state_dict[param_idx][k] = v_list[self.coordinator.rank].detach().clone()
         load_states_into_optimizer(optimizer, state_dict, id_map)
 
         sharded_optimizer_loading_epilogue(optimizer)
 
-    def save_unsharded_model(
-        self, model: LowLevelZeroModel, checkpoint: str, gather_dtensor: bool, use_safetensors: bool
-    ):
-        assert isinstance(model, LowLevelZeroModel)
-        super().save_unsharded_model(model.module, checkpoint, gather_dtensor, use_safetensors)
-
-    def save_sharded_model(
-        self,
-        model: nn.Module,
-        checkpoint_path: str,
-        gather_dtensor: bool = True,
-        prefix: Optional[str] = None,
-        max_shard_size: int = 1024,
-        use_safetensors: bool = False,
-    ):
-        assert isinstance(model, LowLevelZeroModel)
-        super().save_sharded_model(
-            model.module, checkpoint_path, gather_dtensor, prefix, max_shard_size, use_safetensors
-        )
-
-    def load_unsharded_model(self, model: LowLevelZeroModel, checkpoint: str, strict: bool = True):
-        assert isinstance(model, LowLevelZeroModel)
-        super().load_unsharded_model(model.module, checkpoint, strict)
+    def load_unsharded_model(self, model: ModelWrapper, checkpoint: str, strict: bool = True):
+        assert isinstance(model, LowLevelZeroModel), "Please boost the model before loading!"
+        super().load_unsharded_model(model, checkpoint, strict)
         model.update_master_params()
 
     def load_sharded_model(
         self,
-        model: LowLevelZeroModel,
+        model: ModelWrapper,
         checkpoint_index_file: Path,
         strict: bool = False,
         use_safetensors: bool = False,
         load_sub_module: bool = True,
     ):
-        assert isinstance(model, LowLevelZeroModel)
-        super().load_sharded_model(model.module, checkpoint_index_file, strict, use_safetensors, load_sub_module)
+        assert isinstance(model, LowLevelZeroModel), "Please boost the model before loading!"
+        super().load_sharded_model(model, checkpoint_index_file, strict, use_safetensors, load_sub_module)
         model.update_master_params()
 
 
@@ -20,24 +20,33 @@ class TorchDDPCheckpointIO(GeneralCheckpointIO):
         super().__init__()
         self.coordinator = DistCoordinator()
 
-    def load_unsharded_model(self, model: nn.Module, checkpoint: str, strict: bool = True):
+    def load_unsharded_model(self, model: ModelWrapper, checkpoint: str, strict: bool = True):
         """
-        Load model from checkpoint with automatic unwrapping.
+        Load model from checkpoint.
         """
-        # the model should be unwrapped in self.load_model via ModelWrapper.unwrap
-        return super().load_unsharded_model(model, checkpoint, strict=strict)
+        assert isinstance(model, ModelWrapper), "Please boost the model before loading!"
+        super().load_unsharded_model(model.unwrap(), checkpoint, strict=strict)
 
-    def save_unsharded_model(self, model: nn.Module, checkpoint: str, gather_dtensor: bool, use_safetensors: bool):
+    def save_unsharded_model(self, model: ModelWrapper, checkpoint: str, gather_dtensor: bool, use_safetensors: bool):
         """
         Save model to checkpoint but only on master process.
         """
+        assert isinstance(model, ModelWrapper), "Please boost the model before saving!"
         if self.coordinator.is_master():
-            super().save_unsharded_model(model, checkpoint, gather_dtensor, use_safetensors)
+            super().save_unsharded_model(model.unwrap(), checkpoint, gather_dtensor, use_safetensors)
 
-    def save_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: str, gather_dtensor: bool):
+    def load_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str):
+        """
+        Load optimizer from checkpoint.
+        """
+        assert isinstance(optimizer, OptimizerWrapper), "Please boost the optimizer before loading!"
+        super().load_unsharded_optimizer(optimizer, checkpoint)
+
+    def save_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str, gather_dtensor: bool):
         """
         Save optimizer to checkpoint but only on master process.
         """
+        assert isinstance(optimizer, OptimizerWrapper), "Please boost the optimizer before saving!"
         if self.coordinator.is_master():
             super().save_unsharded_optimizer(optimizer, checkpoint, gather_dtensor)
 
@@ -50,7 +59,7 @@ class TorchDDPCheckpointIO(GeneralCheckpointIO):
 
     def save_sharded_model(
         self,
-        model: nn.Module,
+        model: ModelWrapper,
         checkpoint_path: str,
         gather_dtensor: bool = True,
         prefix: Optional[str] = None,
@@ -60,22 +69,52 @@ class TorchDDPCheckpointIO(GeneralCheckpointIO):
         """
         Save model to checkpoint but only on master process.
         """
+        assert isinstance(model, ModelWrapper), "Please boost the model before saving!"
         if self.coordinator.is_master():
-            super().save_sharded_model(model, checkpoint_path, gather_dtensor, prefix, max_shard_size, use_safetensors)
+            super().save_sharded_model(
+                model.unwrap(), checkpoint_path, gather_dtensor, prefix, max_shard_size, use_safetensors
+            )
+
+    def load_sharded_model(
+        self,
+        model: ModelWrapper,
+        checkpoint_index_file: str,
+        strict: bool = False,
+        use_safetensors: bool = False,
+        load_sub_module: bool = True,
+    ):
+        """
+        Load model from sharded checkpoint.
+        """
+        assert isinstance(model, ModelWrapper), "Please boost the model before loading!"
+        super().load_sharded_model(model.unwrap(), checkpoint_index_file, strict, use_safetensors, load_sub_module)
 
     def save_sharded_optimizer(
         self,
-        optimizer: Optimizer,
+        optimizer: OptimizerWrapper,
         checkpoint: str,
         gather_dtensor: bool = True,
         prefix: Optional[str] = None,
         size_per_shard: int = 1024,
     ):
         """
-        Save optimizer to checkpoint but only on master process.
+        Save optimizer to sharded checkpoint but only on master process.
         """
+        assert isinstance(optimizer, OptimizerWrapper), "Please boost the optimizer before saving!"
         if self.coordinator.is_master():
-            super().save_sharded_optimizer(optimizer, checkpoint, gather_dtensor, prefix, size_per_shard)
+            super().save_sharded_optimizer(optimizer.unwrap(), checkpoint, gather_dtensor, prefix, size_per_shard)
+
+    def load_sharded_optimizer(
+        self,
+        optimizer: Optimizer,
+        index_file_path: str,
+        prefix: Optional[str] = None,
+    ):
+        """
+        Load optimizer from sharded checkpoint.
+        """
+        assert isinstance(optimizer, OptimizerWrapper), "Please boost the optimizer before loading!"
+        super().load_sharded_optimizer(optimizer.unwrap(), index_file_path, prefix)
 
 
 class TorchDDPModel(ModelWrapper):
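The TorchDDP checkpoint IO now enforces a "boost first" contract: it asserts the incoming objects are still wrapped and performs the unwrap itself. A hedged sketch of that guard as a standalone helper (hypothetical function, not part of the commit):

```python
# Hedged sketch: the guard-then-unwrap pattern now used throughout TorchDDPCheckpointIO.
from colossalai.interface import ModelWrapper, OptimizerWrapper


def unwrap_for_checkpoint(model, optimizer):
    """Insist that callers pass boosted wrappers, then hand back the raw objects."""
    assert isinstance(model, ModelWrapper), "Please boost the model before saving!"
    assert isinstance(optimizer, OptimizerWrapper), "Please boost the optimizer before saving!"
    return model.unwrap(), optimizer.unwrap()
```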
@@ -39,31 +39,35 @@ class TorchFSDPCheckpointIO(GeneralCheckpointIO):
         super().__init__()
         self.coordinator = DistCoordinator()
 
-    def load_unsharded_model(self, model: nn.Module, checkpoint: str, strict: bool):
+    def load_unsharded_model(self, model: ModelWrapper, checkpoint: str, strict: bool):
+        assert isinstance(model, TorchFSDPModel), "Please boost the model before loading!"
+        model = model.unwrap()
         checkpoint = utils.load_state_dict(checkpoint)
         model.load_state_dict(checkpoint)
 
-    def load_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: Path):
+    def load_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: Path):
+        assert isinstance(optimizer, FSDPOptimizerWrapper), "Please boost the optimizer before loading!"
         checkpoint = utils.load_state_dict(checkpoint)
         fsdp_model = optimizer.unwrap_model()
         sharded_osd = FSDP.scatter_full_optim_state_dict(checkpoint, fsdp_model)
         optimizer.load_state_dict(sharded_osd)
 
-    def save_unsharded_model(self, model: nn.Module, checkpoint: str, gather_dtensor: bool, use_safetensors: bool):
+    def save_unsharded_model(self, model: ModelWrapper, checkpoint: str, gather_dtensor: bool, use_safetensors: bool):
         """
         Save model to checkpoint but only on master process.
         """
-        # the model should be unwrapped in self.load_model via ModelWrapper.unwrap
+        assert isinstance(model, TorchFSDPModel), "Please boost the model before saving!"
+        model = model.unwrap()
         cfg = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
         with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, cfg):
             full_model_state = model.state_dict()
         utils.save_state_dict(full_model_state, checkpoint_file_path=checkpoint, use_safetensors=use_safetensors)
 
-    def save_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: str, gather_dtensor: bool):
+    def save_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str, gather_dtensor: bool):
         """
         Save optimizer to checkpoint but only on master process.
         """
-        assert isinstance(optimizer, FSDPOptimizerWrapper)
+        assert isinstance(optimizer, FSDPOptimizerWrapper), "Please boost the optimizer before saving!"
         fsdp_model = optimizer.unwrap_model()
         full_optimizer_state = FSDP.full_optim_state_dict(fsdp_model, optim=optimizer, rank0_only=True)
         utils.save_state_dict(full_optimizer_state, checkpoint_file_path=checkpoint, use_safetensors=False)
@@ -87,9 +87,6 @@ class CheckpointIO(ABC):
         # return the origin model instead of the unwrapped model
         origin_model = model
 
-        if isinstance(model, ModelWrapper):
-            model = model.unwrap()
-
         if index_file_exists:
             self.load_sharded_model(model, index_file_path, strict)
         else:
@@ -134,9 +131,6 @@ class CheckpointIO(ABC):
             use_safetensors (bool): whether to use safe tensors. Default: False. If set to True, the checkpoint will be saved
         """
 
-        if isinstance(model, ModelWrapper):
-            model = model.unwrap()
-
         if shard:
             self.save_sharded_model(model, checkpoint, gather_dtensor, prefix, size_per_shard, use_safetensors)
         else:
@@ -8,8 +8,6 @@ from typing import Optional
 import torch.nn as nn
 from torch.optim import Optimizer
 
-from colossalai.interface import OptimizerWrapper
-
 from .checkpoint_io_base import CheckpointIO
 from .index_file import CheckpointIndexFile
 from .utils import (
@@ -28,7 +26,6 @@ from .utils import (
     shard_model_checkpoint,
     shard_optimizer_checkpoint,
     sharded_optimizer_loading_epilogue,
-    unwrap_optimizer,
 )
 
 __all__ = ["GeneralCheckpointIO"]
@@ -58,10 +55,6 @@ class GeneralCheckpointIO(CheckpointIO):
         Load sharded optimizer with the given path to index file.
         """
 
-        # If optimizer is wrapped, unwrap it.
-        if isinstance(optimizer, OptimizerWrapper):
-            optimizer = unwrap_optimizer(optimizer)
-
         # Read checkpoint index file.
         ckpt_index_file = CheckpointIndexFile.from_file(index_file_path)
 
@@ -98,10 +91,6 @@ class GeneralCheckpointIO(CheckpointIO):
             - Multiple files (pytorch_optim-000XX.bin) that store state tensors of optimizer in a sharding way
         """
 
-        # If optimizer is wrapped, unwrap it.
-        if isinstance(optimizer, OptimizerWrapper):
-            optimizer = unwrap_optimizer(optimizer)
-
         if os.path.isfile(checkpoint):
             logging.error(f"Provided path ({checkpoint}) should be a directory, not a file")
             return
@@ -3,7 +3,7 @@ import logging
 import os
 from pathlib import Path
 from shutil import rmtree
-from typing import Dict, Iterator, Optional, OrderedDict, Tuple, Union
+from typing import Dict, Iterator, Optional, OrderedDict, Tuple
 
 import torch
 import torch.distributed as dist
@@ -13,7 +13,7 @@ from torch.optim import Optimizer
 from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
 
 from colossalai.cluster import DistCoordinator
-from colossalai.interface import OptimizerWrapper
+from colossalai.interface import ModelWrapper, OptimizerWrapper
 
 from .general_checkpoint_io import GeneralCheckpointIO
 from .index_file import CheckpointIndexFile
@@ -71,8 +71,6 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
         self.tp_size = dist.get_world_size(tp_group)
         self.use_zero = zero_stage > 0
         self.verbose = verbose
-        self.working_to_master_map = None
-        self.master_to_working_map = None
         self.coordinator = DistCoordinator()
 
     @staticmethod
@@ -159,7 +157,7 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
 
     def save_sharded_model(
         self,
-        model: nn.Module,
+        model: ModelWrapper,
         checkpoint: str,
         gather_dtensor: bool = True,
         prefix: Optional[str] = None,
@@ -184,6 +182,9 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
             use_safetensors (bool, optional): Whether to use safe tensors. Defaults to False.
         """
 
+        assert isinstance(model, ModelWrapper), "Please boost the model before saving!"
+        model = model.unwrap()
+
         if os.path.isfile(checkpoint):
             logging.error(f"Provided path ({checkpoint}) should be a directory, not a file")
             return
@@ -279,7 +280,7 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
                 f"index located at {final_index_file_path}."
             )
 
-    def load_sharded_model(self, model: nn.Module, checkpoint_index_file: Path, strict: bool = False):
+    def load_sharded_model(self, model: ModelWrapper, checkpoint_index_file: Path, strict: bool = False):
         """
         Load sharded model with the given path to index file of checkpoint folder.
 
@@ -289,6 +290,9 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
             strict (bool, optional): For name matching during loading state_dict. Defaults to False.
                 This argument should be manually set to False since params on same device might be stored in different files.
         """
+        assert isinstance(model, ModelWrapper), "Please boost the model before loading!"
+        model_before_wrapping = model  # backup for model before wrapping
+        model = model.unwrap()
 
         # Check whether the checkpoint uses safetensors.
         use_safetensors = False
@@ -347,23 +351,7 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
             _load(extra_state_key)
 
         # Update master params if mixed-precision training is enabled.
-        with torch.no_grad():
-            if self.working_to_master_map is not None:
-                for param in model.parameters():
-                    if (param is None) or (id(param) not in self.working_to_master_map):
-                        continue
-                    master_param = self.working_to_master_map[id(param)]
-                    if self.use_zero:
-                        # master_param is sharded under Zero setting
-                        padding_size = (self.dp_size - param.numel() % self.dp_size) % self.dp_size
-                        if padding_size > 0:
-                            padded_param = torch.nn.functional.pad(param.data.view(-1), [0, padding_size])
-                        else:
-                            padded_param = param.data.view(-1)
-                        sharded_param = padded_param.split(padded_param.numel() // self.dp_size)[self.dp_rank]
-                        master_param.data.copy_(sharded_param.data)
-                    else:
-                        master_param.data.copy_(param.data)
+        model_before_wrapping.update_master_params()
 
         if self.verbose:
             logging.info(f"The model has been successfully loaded from sharded checkpoint: {ckpt_root_path}.")
@@ -392,6 +380,7 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
             prefix (str): Perfix of file to save
             size_per_shard (int): Max file size of each file shard that store state tensors
         """
+        assert isinstance(optimizer, OptimizerWrapper), "Please boost the optimizer before saving!"
         if os.path.isfile(checkpoint):
             logging.error(f"Provided path ({checkpoint}) should be a directory, not a file")
             return
@@ -410,7 +399,7 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
             use_zero=self.use_zero,
             dp_group=self.dp_group,
             tp_group=self.tp_group,
-            master_to_working_map=self.master_to_working_map,
+            master_to_working_map=optimizer.get_master_to_working_map(),
             size_per_shard=size_per_shard,
         )
         states_name, save_index_file, param_group_file = get_optimizer_base_filenames(prefix)
@@ -511,6 +500,7 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
             checkpoint_index_file (str): Path to the index file of checkpointing folder.
             prefix (str): Not used.
         """
+        assert isinstance(optimizer, OptimizerWrapper), "Please boost the optimizer before loading!"
 
         def _get_param_id_from_optimizer_param(
             param: torch.Tensor, master_to_working_map: Optional[Dict[int, torch.Tensor]] = None
@@ -525,9 +515,10 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
         # When Zero is used, the mapped parameter objects should be fp32 master parameters.
         # IDs should be obtained through saved param2id mapping earlier saved in optimizer.param_info.
         id_map = {}
+        master_to_working_map = optimizer.get_master_to_working_map()
         for pg in optimizer.optim.param_groups:
             for param in pg["params"]:
-                param_id = _get_param_id_from_optimizer_param(param, self.master_to_working_map)
+                param_id = _get_param_id_from_optimizer_param(param, master_to_working_map)
                 id_map[param_id] = param
 
         # Read checkpoint index file.
@@ -560,7 +551,7 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
             for param in pg["params"]:
                 if param is None:
                     continue
-                param_id = _get_param_id_from_optimizer_param(param, self.master_to_working_map)
+                param_id = _get_param_id_from_optimizer_param(param, master_to_working_map)
                 if param_id not in weight_map:
                     continue
                 filename = weight_map[param_id]
@@ -577,8 +568,8 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
         # Then shard the loaded optimizer states if using tp/zero.
         for param, state in optimizer.optim.state.items():
             device = param.device
-            if self.master_to_working_map is not None:
-                working_param = self.master_to_working_map[id(param)]
+            if master_to_working_map is not None:
+                working_param = master_to_working_map[id(param)]
             else:
                 working_param = param
             original_shape = optimizer.param_info["param2shape"][id(working_param)]
@@ -614,42 +605,6 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
         if self.coordinator.is_master():
             super().save_lr_scheduler(lr_scheduler, checkpoint)
 
-    def link_master_and_working_param(
-        self,
-        working_to_master_map: Dict[Union[int, torch.Tensor], torch.Tensor],
-        master_to_working_map: Dict[Union[int, torch.Tensor], torch.Tensor],
-    ):
-        """
-        Create mappings between working params (for forward/backward) and master params (for optimizer update) with passed in mappings.
-        This mapping can only be created when mixied precision is used.
-        The created mappings should be mappings from integer parameter addresses to parameter objects.
-
-        Args:
-            working_to_master_map (Dict[Union[int, torch.Tensor], torch.Tensor]): A mapping from working parameters objects/addresses to master parameter objects.
-            master_to_working_map (Dict[Union[int, torch.Tensor], torch.Tensor]): A mapping from master parameters objects/addresses to working parameter objects.
-        """
-        self.working_to_master_map = dict()
-        for k, v in working_to_master_map.items():
-            if isinstance(k, torch.Tensor):
-                self.working_to_master_map[id(k)] = v
-            elif isinstance(k, int):
-                self.working_to_master_map[k] = v
-            else:
-                raise ValueError(
-                    f"The passed in mapping should have keys of type 'int' or 'torch.Tensor', but got {type(k)}!"
-                )
-
-        self.master_to_working_map = dict()
-        for k, v in master_to_working_map.items():
-            if isinstance(k, torch.Tensor):
-                self.master_to_working_map[id(k)] = v
-            elif isinstance(k, int):
-                self.master_to_working_map[k] = v
-            else:
-                raise ValueError(
-                    f"The passed in mapping should have keys of type 'int' or 'torch.Tensor', but got {type(k)}!"
-                )
-
     @staticmethod
     def gather_from_sharded_optimizer_state(
         state: OrderedDict,
@@ -11,7 +11,6 @@ import torch
 import torch.nn as nn
 from torch.optim import Optimizer
 
-from colossalai.interface import OptimizerWrapper
 from colossalai.tensor.d_tensor import (
     is_customized_distributed_tensor,
     is_distributed_tensor,
@@ -122,14 +121,6 @@ def search_tp_partition_dim(current_shape: torch.Size, original_shape: torch.Siz
 # ======================================
 # Helper classes and functions for saving shard file
 # ======================================
-def unwrap_optimizer(optimizer: OptimizerWrapper):
-    """
-    Unwrap a wrapped optimizer.
-    This method should be used before saving/loading it to/from sharded checkpoints.
-    """
-
-    unwrapped_optim = optimizer.optim
-    return unwrapped_optim
 
 
 class StateDictSharder:
@@ -186,10 +186,6 @@ class GeminiDDP(ModelWrapper):
             for p in params_to_ignore:
                 p._ddp_to_ignore = True
 
-    def unwrap(self):
-        # as save/load state dict is overwrited, only return self
-        return self
-
     def _get_non_persistent_buffers_set(
         self, module, memo: Optional[Set[nn.Module]] = None, prefix: str = "", remove_duplicate: bool = True
     ):
@@ -648,3 +648,9 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
                 if padding_size > 0:
                     working_param = torch.nn.functional.pad(working_param, [0, padding_size])
                 master_param.copy_(working_param.chunk(self._world_size)[self._local_rank])
+
+    def get_working_to_master_map(self) -> Dict[int, torch.Tensor]:
+        return self._param_store.working_to_master_param
+
+    def get_master_to_working_map(self) -> Dict[int, torch.Tensor]:
+        return self._param_store.master_to_working_param
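For context on the `chunk(self._world_size)[self._local_rank]` line above: under ZeRO each rank keeps only its slice of the padded, flattened fp32 master weights. A small self-contained sketch of that padding-and-slicing arithmetic (a generic helper written for illustration, not the optimizer's internal code):

```python
# Illustrative sketch of ZeRO-style master-param sharding: pad the flattened
# working tensor to a multiple of world_size, then keep this rank's chunk.
import torch


def shard_for_rank(working: torch.Tensor, world_size: int, rank: int) -> torch.Tensor:
    flat = working.detach().reshape(-1)
    pad = (world_size - flat.numel() % world_size) % world_size
    if pad > 0:
        flat = torch.nn.functional.pad(flat, [0, pad])
    return flat.chunk(world_size)[rank].clone()


# A 10-element tensor split across 4 ranks is padded to 12; rank 1 keeps values 3..5.
print(shard_for_rank(torch.arange(10.0), world_size=4, rank=1))  # tensor([3., 4., 5.])
```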
@@ -61,9 +61,7 @@ def exam_from_pretrained(plugin_type: str, model_name: str, shard=True, size_per
     new_model, new_optimizer, criterion, _, _ = booster.boost(new_model, new_optimizer, criterion)
 
     if plugin_type == "gemini":
-        check_state_dict_equal(
-            model.unwrap().state_dict(only_rank_0=False), new_model.unwrap().state_dict(only_rank_0=False), False
-        )
+        check_state_dict_equal(model.state_dict(only_rank_0=False), new_model.state_dict(only_rank_0=False), False)
     else:
         check_state_dict_equal(model.unwrap().state_dict(), new_model.unwrap().state_dict(), False)
     dist.barrier()