Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-04 02:26:51 +00:00
[misc] update pre-commit and run all files (#4752)
* [misc] update pre-commit
* [misc] run pre-commit
* [misc] remove useless configuration files
* [misc] ignore cuda for clang-format
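All changes in this commit are mechanical reformatting produced by the updated hooks (single to double quotes, black-style signature wrapping, docstring consolidation); no behavior changes. The hook configuration itself is not part of this excerpt; a minimal `.pre-commit-config.yaml` consistent with the commit message might look like the sketch below. The hook revisions and the CUDA exclude pattern are assumptions, not taken from the commit:

    # Illustrative sketch only; revs and the exclude pattern are guesses.
    repos:
      - repo: https://github.com/psf/black
        rev: 23.9.1
        hooks:
          - id: black
      - repo: https://github.com/pre-commit/mirrors-clang-format
        rev: v16.0.6
        hooks:
          - id: clang-format
            exclude: \.(cu|cuh)$   # "ignore cuda for clang-format"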
colossalai/utils/checkpoint/__init__.py
@@ -1,3 +1,3 @@
 from .module_checkpoint import load_checkpoint, save_checkpoint
 
-__all__ = ['save_checkpoint', 'load_checkpoint']
+__all__ = ["save_checkpoint", "load_checkpoint"]
colossalai/utils/checkpoint/module_checkpoint.py
@@ -9,13 +9,15 @@ from colossalai.tensor import ColoTensor
 from .utils import gather_tensor, scatter_tensor
 
 
-def save_checkpoint(path: str,
-                    epoch: int,
-                    model: torch.nn.Module,
-                    optimizer: Optional[OptimizerWrapper] = None,
-                    lr_scheduler: torch.optim.lr_scheduler._LRScheduler = None,
-                    *args,
-                    **kwargs):
+def save_checkpoint(
+    path: str,
+    epoch: int,
+    model: torch.nn.Module,
+    optimizer: Optional[OptimizerWrapper] = None,
+    lr_scheduler: torch.optim.lr_scheduler._LRScheduler = None,
+    *args,
+    **kwargs,
+):
     """save_checkpoint
     save a model, whose parameters are `ColoTensor`s.
     Args:
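The signature reflow above is purely mechanical (black's trailing-comma call style); the API is unchanged. A hypothetical call, with the directory and training objects as placeholders (not taken from the commit):

    # Hypothetical usage sketch; "ckpt_dir", model and optim are placeholders.
    # Per the torch.save calls below, this writes ckpt_dir/epoch_10_model.pth
    # and, because an optimizer is passed, ckpt_dir/epoch_10_optim.pth.
    save_checkpoint("ckpt_dir", epoch=10, model=model, optimizer=optim)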
@@ -30,7 +32,7 @@ def save_checkpoint(path: str,
     # save the dist context about the tensors in a new dict, while still maintain the original dict.
     for k, v in model_state.items():
         if isinstance(v, ColoTensor):
-            gather_tensor(v)    # gather shared tensors to rank0
+            gather_tensor(v)  # gather shared tensors to rank0
             # don't recover tensors in rank0, since the dict is only a copy of model
 
     if rank == 0:
@@ -39,10 +41,10 @@ def save_checkpoint(path: str,
             if isinstance(v, ColoTensor):
                 assert v.save_ready
                 assert v.is_replicate()
-                delattr(v, 'save_ready')
+                delattr(v, "save_ready")
         # model saving
-        save_state = {'epoch': epoch, 'model': model_state}
-        torch.save(save_state, path + '/epoch_{}_model.pth'.format(epoch), *args, **kwargs)
+        save_state = {"epoch": epoch, "model": model_state}
+        torch.save(save_state, path + "/epoch_{}_model.pth".format(epoch), *args, **kwargs)
 
     # delete old dicts
     del model_state
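The `save_ready` attribute asserted and deleted above is a transient marker: `gather_tensor` (see utils.py further down) tags a tensor once it has been replicated onto rank 0, and the saving code strips the tag when done. The pattern in isolation, as a minimal standalone sketch (PyTorch tensors accept ad-hoc Python attributes):

    import torch

    t = torch.zeros(4)
    setattr(t, "save_ready", True)  # mark: gathered, safe to serialize
    assert hasattr(t, "save_ready")
    delattr(t, "save_ready")        # strip the marker once saving is done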
@@ -52,35 +54,37 @@ def save_checkpoint(path: str,
     if optimizer is not None:
         mapping = dict()
         optim_state = optimizer.state_dict()
-        for k, v in optim_state['state'].items():
+        for k, v in optim_state["state"].items():
             for n, t in v.items():
                 if isinstance(t, ColoTensor):
                     mapping[(k, n)] = t.dist_spec
                     gather_tensor(t)
 
         if rank == 0:
-            save_state = {'epoch': epoch, 'optim': optim_state}
-            torch.save(save_state, path + '/epoch_{}_optim.pth'.format(epoch), *args, **kwargs)
+            save_state = {"epoch": epoch, "optim": optim_state}
+            torch.save(save_state, path + "/epoch_{}_optim.pth".format(epoch), *args, **kwargs)
             # recover colo tensors in rank0
-            for k, v in optimizer.state_dict()['state'].items():
+            for k, v in optimizer.state_dict()["state"].items():
                 for n, t in v.items():
                     if isinstance(t, ColoTensor):
-                        assert hasattr(t, 'save_ready')
+                        assert hasattr(t, "save_ready")
                         t.set_dist_spec(mapping[(k, n)])
-                        delattr(t, 'save_ready')
+                        delattr(t, "save_ready")
 
         del optim_state
         del mapping
     dist.barrier()
 
 
-def load_checkpoint(path: str,
-                    epoch: int,
-                    model: torch.nn.Module,
-                    optimizer: Optional[OptimizerWrapper] = None,
-                    lr_scheduler: torch.optim.lr_scheduler._LRScheduler = None,
-                    torch_load_kwargs: Optional[Dict] = None,
-                    load_state_dict_kwargs: Optional[Dict] = None):
+def load_checkpoint(
+    path: str,
+    epoch: int,
+    model: torch.nn.Module,
+    optimizer: Optional[OptimizerWrapper] = None,
+    lr_scheduler: torch.optim.lr_scheduler._LRScheduler = None,
+    torch_load_kwargs: Optional[Dict] = None,
+    load_state_dict_kwargs: Optional[Dict] = None,
+):
     """load_checkpoint
     load a model, whose parameters are `ColoTensor`s.
     Args:
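As with `save_checkpoint`, only quoting and signature layout change here. A hypothetical restore call mirroring the save example (`map_location` is a standard `torch.load` keyword; the other names are placeholders):

    load_checkpoint(
        "ckpt_dir",
        epoch=10,
        model=model,
        optimizer=optim,
        torch_load_kwargs={"map_location": "cpu"},
    )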
@@ -106,8 +110,8 @@ def load_checkpoint(path: str,
             gather_tensor(p)
 
     if rank == 0:
-        load_state = torch.load(path + '/epoch_{}_model.pth'.format(epoch), **torch_load_kwargs)
-        model.load_state_dict(load_state['model'], **load_state_dict_kwargs)
+        load_state = torch.load(path + "/epoch_{}_model.pth".format(epoch), **torch_load_kwargs)
+        model.load_state_dict(load_state["model"], **load_state_dict_kwargs)
     dist.barrier()
 
     # scatter loaded parameters
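This hunk keeps the load-on-rank-0 pattern intact: only rank 0 reads the file, and the barrier stops the other ranks from racing ahead before parameters are scattered back out. Distilled into a standalone sketch (assumes an initialized process group; the path is a placeholder):

    import torch
    import torch.distributed as dist

    model = torch.nn.Linear(4, 4)  # stand-in for the real ColoTensor model

    if dist.get_rank() == 0:
        state = torch.load("ckpt_dir/epoch_10_model.pth")
        model.load_state_dict(state["model"])
    dist.barrier()  # other ranks wait here until rank 0 has loaded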
@@ -115,24 +119,24 @@ def load_checkpoint(path: str,
         if isinstance(p, ColoTensor):
             scatter_tensor(p, mapping[n])
             if rank == 0:
-                assert hasattr(p, 'save_ready')
-                delattr(p, 'save_ready')
+                assert hasattr(p, "save_ready")
+                delattr(p, "save_ready")
     del mapping
 
     if optimizer is not None:
         mapping = dict()
-        for k, v in optimizer.state_dict()['state'].items():
+        for k, v in optimizer.state_dict()["state"].items():
             for n, t in v.items():
                 if isinstance(t, ColoTensor):
                     mapping[(k, n)] = t.dist_spec
                     gather_tensor(t)
 
         if rank == 0:
-            colo_checkpoint = torch.load(path + '/epoch_{}_optim.pth'.format(epoch), **torch_load_kwargs)
-            optimizer.load_state_dict(colo_checkpoint['optim'], **load_state_dict_kwargs)
+            colo_checkpoint = torch.load(path + "/epoch_{}_optim.pth".format(epoch), **torch_load_kwargs)
+            optimizer.load_state_dict(colo_checkpoint["optim"], **load_state_dict_kwargs)
         dist.barrier()
 
-        for k, v in optimizer.state_dict()['state'].items():
+        for k, v in optimizer.state_dict()["state"].items():
             for n, t in v.items():
                 if isinstance(t, ColoTensor):
                     scatter_tensor(t, mapping[(k, n)])
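No logic changes in this hunk either, but the optimizer branch is worth reading as a record / gather / load / scatter round trip. A distilled, illustrative restatement (the flow above flattened, not new code):

    mapping = {}
    for k, v in optimizer.state_dict()["state"].items():
        for n, t in v.items():
            if isinstance(t, ColoTensor):
                mapping[(k, n)] = t.dist_spec  # 1. record sharding layout
                gather_tensor(t)               # 2. replicate onto rank 0
    # 3. rank 0 loads the full optimizer state, then all ranks barrier
    for (k, n), spec in mapping.items():
        t = optimizer.state_dict()["state"][k][n]
        scatter_tensor(t, spec)                # 4. restore original layout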
colossalai/utils/checkpoint/utils.py
@@ -8,7 +8,7 @@ from colossalai.tensor import ColoTensor
 
 def robust_broadcast(tensor):
     with torch.no_grad():
-        is_cpu_ten = tensor.device.type == 'cpu'
+        is_cpu_ten = tensor.device.type == "cpu"
         if is_cpu_ten:
            b_data = tensor.cuda()
         else:
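`robust_broadcast` exists because NCCL collectives only operate on CUDA tensors: a CPU tensor is staged through the GPU and the result copied back. A standalone sketch of the same trick (the name `broadcast_any` is mine, not the library's; assumes an initialized NCCL process group):

    import torch
    import torch.distributed as dist

    def broadcast_any(tensor: torch.Tensor, src: int = 0) -> None:
        with torch.no_grad():
            if tensor.device.type == "cpu":
                buf = tensor.cuda()            # stage on GPU for NCCL
                dist.broadcast(buf, src=src)
                tensor.copy_(buf.cpu())        # write the result back to CPU
            else:
                dist.broadcast(tensor, src=src)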
@@ -21,8 +21,7 @@ def robust_broadcast(tensor):
 
 
 def gather_tensor(colo_tensor: ColoTensor) -> None:
-    """Make colo_tensor replicated when the rank is 0
-    """
+    """Make colo_tensor replicated when the rank is 0"""
     if not colo_tensor.is_replicate():
         pg = colo_tensor.get_process_group()
         # for the group which contains rank 0
@@ -36,12 +35,11 @@ def gather_tensor(colo_tensor: ColoTensor) -> None:
         dist.barrier()
 
     if dist.get_rank() == 0:
-        setattr(colo_tensor, 'save_ready', True)  # set saving signature
+        setattr(colo_tensor, "save_ready", True)  # set saving signature
 
 
 def scatter_tensor(colo_tensor: ColoTensor, dist_spec: _DistSpec) -> None:
-    """Reversal operation of `gather_tensor`.
-    """
+    """Reversal operation of `gather_tensor`."""
     if dist_spec.placement == DistPlacementPattern.REPLICATE:
         robust_broadcast(colo_tensor.data)
     else:
@@ -57,7 +55,8 @@ def scatter_tensor(colo_tensor: ColoTensor, dist_spec: _DistSpec) -> None:
             colo_tensor.set_dist_spec(dist_spec)
         else:
             rep_tensor = ColoTensor(
-                entire_data, ColoTensorSpec(pg=colo_tensor.get_process_group(), compute_attr=colo_tensor.compute_spec))
+                entire_data, ColoTensorSpec(pg=colo_tensor.get_process_group(), compute_attr=colo_tensor.compute_spec)
+            )
             rep_tensor.set_dist_spec(dist_spec)
             with torch.no_grad():
                 colo_tensor.data.copy_(rep_tensor.data)
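Taken together, the two utilities are inverses around a rank-0 serialization step. A hypothetical round trip (assumes an initialized ColossalAI process group; names and the path are placeholders, not taken from the commit):

    orig_spec = colo_tensor.dist_spec       # remember the sharded layout
    gather_tensor(colo_tensor)              # rank 0 now holds a full replica
    if dist.get_rank() == 0:
        torch.save(colo_tensor, "t.pth")    # serialize on rank 0 only
    dist.barrier()
    scatter_tensor(colo_tensor, orig_spec)  # re-shard to the original layout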