[zero] add state dict for low level zero (#4179)

* add state dict for zero

* fix unit test

* polish
This commit is contained in:
LuGY 2023-07-06 17:20:04 +08:00 committed by Hongxin Liu
parent c668801d36
commit dd7cc58299
2 changed files with 188 additions and 1 deletions

View File

@ -1,4 +1,5 @@
# this code is inspired by the DeepSpeed library and implemented with our own design from scratch # this code is inspired by the DeepSpeed library and implemented with our own design from scratch
import copy
from contextlib import contextmanager from contextlib import contextmanager
from functools import partial from functools import partial
from typing import Optional from typing import Optional
@ -198,7 +199,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
params_current_rank = [] params_current_rank = []
device = 'cpu' if self._cpu_offload else get_current_device() device = 'cpu' if self._cpu_offload else get_current_device()
for param in reversed(param_list): for param in param_list:
padding_size = (self._world_size - param.numel() % self._world_size) % self._world_size padding_size = (self._world_size - param.numel() % self._world_size) % self._world_size
self._param_store.record_param_padding_size(param, padding_size) self._param_store.record_param_padding_size(param, padding_size)
@ -468,3 +469,68 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
yield yield
finally: finally:
self.require_grad_sync = old_require_grad_sync self.require_grad_sync = old_require_grad_sync
##############
# State Dict #
##############
def _pack_state(self, state: dict) -> dict:
# comes from pytorch optimizer.state_dict()
param_mappings = {}
start_index = 0
def pack_group(group):
nonlocal start_index
packed = {k: v for k, v in group.items() if k != 'params'}
param_mappings.update(
{id(p): i for i, p in enumerate(group['params'], start_index) if id(p) not in param_mappings})
packed['params'] = [param_mappings[id(p)] for p in group['params']]
start_index += len(packed['params'])
return packed
param_groups = [pack_group(g) for g in self.param_groups]
# Remap state to use order indices as keys
packed_state = {(param_mappings[id(k)] if isinstance(k, torch.Tensor) else k): v for k, v in state.items()}
return {'state': packed_state, 'param_groups': param_groups}
def state_dict(self) -> dict:
"""Return a state_dict same with DDP
Returns:
dict: the pytorch form state_dict
"""
zero_state = dict()
for param, state in self.optim.state.items():
zero_state[param] = copy.deepcopy(state)
for k, v in state.items():
if isinstance(v, torch.Tensor) and k != 'step':
working_param = self._param_store.master_to_working_param[id(param)]
gather_tensor = [torch.zeros_like(v) for _ in range(self._world_size)]
dist.all_gather(gather_tensor, v, group=self.dp_pg)
param_state = torch.stack(gather_tensor).view(-1)[:working_param.numel()].reshape_as(working_param)
zero_state[param][k] = param_state
states_dict = self._pack_state(zero_state)
return states_dict
def load_state_dict(self, state_dict: dict):
"""Load state dict, requires the state_dict be the pytorch form
Args:
state_dict (dict): A pytorch form state_dict
"""
zero_state_dict = copy.deepcopy(state_dict)
for param_idx, state in zero_state_dict['state'].items():
for k, v in state.items():
if isinstance(v, torch.Tensor) and k != 'step':
padding_size = (self._world_size - v.numel() % self._world_size) % self._world_size
with torch.no_grad():
v = v.flatten()
if padding_size > 0:
v = torch.nn.functional.pad(v, [0, padding_size])
v_list = v.split(v.numel() // self._world_size)
zero_state_dict['state'][param_idx][k] = v_list[self._local_rank].detach()
self.optim.load_state_dict(zero_state_dict)
zero_state_dict = dict()

View File

@ -0,0 +1,121 @@
import copy
import pytest
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.testing import assert_close
import colossalai
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
from colossalai.testing.random import seed_all
from colossalai.zero import LowLevelZeroOptimizer
class MlpModel(nn.Module):
def __init__(self):
super(MlpModel, self).__init__()
self.linear1 = nn.Linear(12, 24)
self.linear2 = nn.Linear(24, 12)
def forward(self, x):
x = self.linear1(x)
x = self.linear2(x)
return x
def loose_close(a, b, dtype: torch.dtype = torch.float32):
rtol = None
atol = None
if dtype is torch.float16:
rtol = 5e-2
atol = 5e-4
elif dtype is torch.bfloat16:
rtol = 4e-3
atol = 4e-3
a = a.detach().to(dtype)
b = b.detach().to(dtype)
assert_close(a, b, rtol=rtol, atol=atol)
def exam_zero_1_torch_ddp_ckpt():
"""
We examine the state_dict of zero and DDP.
Moreover, we examine the zero's loading checkpoint of a torch ckpt.
"""
local_rank = torch.distributed.get_rank()
seed_all(1453)
# create models
torch_model = MlpModel().cuda()
zero_model = copy.deepcopy(torch_model)
torch_model = DDP(torch_model.cuda(), static_graph=True).cuda()
# create optimizer
zero_optimizer = torch.optim.Adam(zero_model.parameters(), lr=1)
# we only test stage 1 here
# the state dicts of stage 1 and stage 2 are the same
zero_optimizer = LowLevelZeroOptimizer(zero_optimizer,
overlap_communication=True,
initial_scale=1,
reduce_bucket_size=262144)
torch_optimizer = torch.optim.Adam(torch_model.parameters(), lr=1)
seed_all(1453 + local_rank)
# create
input_data = torch.rand(4, 12).cuda()
# forward
zero_output = zero_model(input_data)
torch_output = torch_model(input_data)
# backward
zero_optimizer.backward(zero_output.mean().float())
torch_output.mean().backward()
# step
zero_optimizer.step()
torch_optimizer.step()
torch_state_dict = torch_optimizer.state_dict()
zero_state_dict = zero_optimizer.state_dict()
# examine the original state dict
for torch_state, zero_state in zip(torch_state_dict['state'].values(), zero_state_dict['state'].values()):
for t_v, z_v in zip(torch_state.values(), zero_state.values()):
loose_close(t_v, z_v)
# empty the optimzer state
zero_optimizer.optim.state = []
# zero load a torch checkpoint
zero_optimizer.load_state_dict(copy.deepcopy(torch_state_dict))
zero_state_dict = zero_optimizer.state_dict()
# examine the loaded state dict
for torch_state, zero_state in zip(torch_state_dict['state'].values(), zero_state_dict['state'].values()):
for t_v, z_v in zip(torch_state.values(), zero_state.values()):
loose_close(t_v, z_v)
def run_dist(rank, world_size, port):
colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host='localhost')
exam_zero_1_torch_ddp_ckpt()
@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_zero_ckpt():
spawn(run_dist, 2)
if __name__ == '__main__':
test_zero_ckpt()