mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2026-01-29 21:49:54 +00:00
[feature] new zero implementation (#1623)
This commit is contained in:
@@ -3,7 +3,7 @@ import colossalai
|
||||
import pytest
|
||||
import torch.multiprocessing as mp
|
||||
from functools import partial
|
||||
from colossalai.gemini.update import ChunkManagerV2
|
||||
from colossalai.gemini.chunk import ChunkManager
|
||||
from colossalai.testing import rerun_if_address_is_in_use, parameterize
|
||||
from colossalai.utils import free_port
|
||||
from colossalai.tensor import ProcessGroup, ColoTensor, ColoTensorSpec
|
||||
@@ -19,23 +19,17 @@ CPU_MEM = {True: {True: 0, False: 0}, False: {True: 512, False: 0}}
|
||||
def exam_chunk_memory(keep_gathered, pin_memory):
|
||||
pg = ProcessGroup()
|
||||
|
||||
debug_print([0], "keep_gathered: {}, pin_memory: {}".format(
|
||||
keep_gathered, pin_memory))
|
||||
debug_print([0], "keep_gathered: {}, pin_memory: {}".format(keep_gathered, pin_memory))
|
||||
|
||||
params = [ColoTensor(torch.rand(8, 8), spec=ColoTensorSpec(pg)) for _ in range(3)]
|
||||
config = {
|
||||
2: dict(
|
||||
chunk_size=128,
|
||||
keep_gathered=keep_gathered
|
||||
)
|
||||
}
|
||||
config = {2: dict(chunk_size=128, keep_gathered=keep_gathered)}
|
||||
|
||||
chunk_manager = ChunkManagerV2(config, pin_memory=pin_memory)
|
||||
chunk_manager = ChunkManager(config)
|
||||
assert chunk_manager.total_mem['cpu'] == 0
|
||||
assert chunk_manager.total_mem['cuda'] == 0
|
||||
|
||||
for p in params:
|
||||
chunk_manager.append_tensor(p, 'param', 2)
|
||||
chunk_manager.append_tensor(p, 'param', 2, pin_memory=pin_memory)
|
||||
chunk_manager.close_all_groups()
|
||||
assert chunk_manager.total_mem['cpu'] == CPU_MEM[keep_gathered][pin_memory]
|
||||
assert chunk_manager.total_mem['cuda'] == CUDA_MEM_0[keep_gathered]
|
||||
|
||||
@@ -9,7 +9,7 @@ from colossalai.utils import free_port, get_current_device
|
||||
from colossalai.tensor import ProcessGroup as ColoProcessGroup
|
||||
from colossalai.tensor import ColoParameter
|
||||
from colossalai.gemini import TensorState
|
||||
from colossalai.gemini.update import ChunkV2
|
||||
from colossalai.gemini.chunk import Chunk
|
||||
|
||||
|
||||
def dist_sum(x):
|
||||
@@ -38,14 +38,12 @@ def check_euqal(param, param_cp):
|
||||
def exam_chunk_basic(init_device, keep_gathered, pin_memory):
|
||||
world_size = torch.distributed.get_world_size()
|
||||
pg = ColoProcessGroup()
|
||||
my_chunk = ChunkV2(
|
||||
chunk_size=1024,
|
||||
process_group=pg,
|
||||
dtype=torch.float32,
|
||||
init_device=init_device,
|
||||
keep_gathered=keep_gathered,
|
||||
pin_memory=pin_memory
|
||||
)
|
||||
my_chunk = Chunk(chunk_size=1024,
|
||||
process_group=pg,
|
||||
dtype=torch.float32,
|
||||
init_device=init_device,
|
||||
keep_gathered=keep_gathered,
|
||||
pin_memory=pin_memory)
|
||||
|
||||
param_list = []
|
||||
param_cp_list = []
|
||||
|
||||
109
tests/test_gemini/update/test_fwd_bwd.py
Normal file
109
tests/test_gemini/update/test_fwd_bwd.py
Normal file
@@ -0,0 +1,109 @@
|
||||
import pytest
|
||||
import colossalai
|
||||
import torch
|
||||
import torch.multiprocessing as mp
|
||||
from colossalai.testing import rerun_if_address_is_in_use
|
||||
from colossalai.utils.cuda import get_current_device
|
||||
from colossalai.utils import free_port
|
||||
from colossalai.utils.model.colo_init_context import ColoInitContext
|
||||
|
||||
from functools import partial
|
||||
from tests.test_tensor.common_utils import tensor_equal, set_seed, tensor_shard_equal
|
||||
from tests.components_to_test.registry import non_distributed_component_funcs
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
from colossalai.nn.parallel import ZeroDDP
|
||||
from colossalai.nn.optimizer import HybridAdam
|
||||
from colossalai.zero import ZeroOptimizer
|
||||
from colossalai.testing import parameterize
|
||||
from colossalai.amp import convert_to_apex_amp
|
||||
from colossalai.gemini.gemini_mgr import GeminiManager
|
||||
from colossalai.tensor import ColoTensorSpec, ShardSpec, ComputePattern, ComputeSpec, ProcessGroup, ColoTensor
|
||||
from tests.test_tensor.common_utils import debug_print
|
||||
|
||||
from time import time
|
||||
from colossalai.gemini.chunk import search_chunk_configuration, ChunkManager
|
||||
|
||||
|
||||
def check_grad(model: ZeroDDP, torch_model: torch.nn.Module):
|
||||
chunk_manager = model.chunk_manager
|
||||
param_list = [p for p in model.parameters()]
|
||||
chunk_list = chunk_manager.get_chunks(param_list)
|
||||
for chunk in chunk_list:
|
||||
chunk_manager.access_chunk(chunk)
|
||||
|
||||
for (p0, p1) in zip(model.parameters(), torch_model.parameters()):
|
||||
assert torch.allclose(p0, p1.grad, atol=1e-3, rtol=1e-5), "{}".format(torch.max(torch.abs(p0 - p1.grad)).item())
|
||||
|
||||
|
||||
def run_fwd_bwd(model, criterion, optimizer, input_ids, attn_mask):
|
||||
optimizer.zero_grad()
|
||||
logits = model(input_ids, attn_mask)
|
||||
logits = logits.float()
|
||||
loss = criterion(logits, input_ids)
|
||||
optimizer.backward(loss)
|
||||
return logits
|
||||
|
||||
|
||||
@parameterize('placement_policy', ['cuda', 'cpu', 'auto'])
|
||||
def exam_gpt_fwd_bwd(placement_policy):
|
||||
set_seed(42)
|
||||
get_components_func = non_distributed_component_funcs.get_callable('gpt2')
|
||||
model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
|
||||
|
||||
with ColoInitContext(device=get_current_device()):
|
||||
model = model_builder()
|
||||
|
||||
torch_model = model_builder().cuda()
|
||||
for torch_p, p in zip(torch_model.parameters(), model.parameters()):
|
||||
torch_p.data.copy_(p.data)
|
||||
|
||||
world_size = torch.distributed.get_world_size()
|
||||
config_dict = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100)
|
||||
config_dict[world_size]['chunk_size'] = 5000
|
||||
config_dict[world_size]['keep_gathered'] = False
|
||||
chunk_manager = ChunkManager(config_dict)
|
||||
gemini_manager = GeminiManager(placement_policy, chunk_manager)
|
||||
model = ZeroDDP(model, gemini_manager, pin_memory=True)
|
||||
|
||||
pg = ProcessGroup()
|
||||
amp_config = dict(opt_level='O2', keep_batchnorm_fp32=False, loss_scale=1)
|
||||
torch_optim = torch.optim.Adam(torch_model.parameters(), lr=1e-3)
|
||||
torch_model, torch_optim = convert_to_apex_amp(torch_model, torch_optim, amp_config)
|
||||
torch_model = DDP(torch_model, device_ids=[pg.rank()], process_group=pg.dp_process_group())
|
||||
|
||||
model.eval()
|
||||
torch_model.eval()
|
||||
|
||||
set_seed(pg.dp_local_rank())
|
||||
for i, (input_ids, attn_mask) in enumerate(train_dataloader):
|
||||
if i > 0:
|
||||
break
|
||||
|
||||
logits = model(input_ids, attn_mask)
|
||||
logits = logits.float()
|
||||
loss = criterion(logits, input_ids)
|
||||
model.backward(loss)
|
||||
|
||||
torch_logits = run_fwd_bwd(torch_model, criterion, torch_optim, input_ids, attn_mask)
|
||||
assert torch.allclose(logits, torch_logits, rtol=0), "{} {} {}".format(
|
||||
torch.max(torch.abs(logits - torch_logits)).item(), logits, torch_logits)
|
||||
|
||||
check_grad(model, torch_model)
|
||||
|
||||
|
||||
def run_dist(rank, world_size, port):
|
||||
config = {}
|
||||
colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
|
||||
exam_gpt_fwd_bwd()
|
||||
|
||||
|
||||
@pytest.mark.dist
|
||||
@pytest.mark.parametrize('world_size', [1, 4])
|
||||
@rerun_if_address_is_in_use()
|
||||
def test_gpt(world_size):
|
||||
run_func = partial(run_dist, world_size=world_size, port=free_port())
|
||||
mp.spawn(run_func, nprocs=world_size)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_gpt(1)
|
||||
118
tests/test_gemini/update/test_optim.py
Normal file
118
tests/test_gemini/update/test_optim.py
Normal file
@@ -0,0 +1,118 @@
|
||||
import pytest
|
||||
import colossalai
|
||||
import torch
|
||||
import torch.multiprocessing as mp
|
||||
import torch.distributed as dist
|
||||
from colossalai.testing import rerun_if_address_is_in_use
|
||||
from colossalai.utils.cuda import get_current_device
|
||||
from colossalai.utils import free_port
|
||||
from colossalai.utils.model.colo_init_context import ColoInitContext
|
||||
|
||||
from functools import partial
|
||||
from tests.test_tensor.common_utils import tensor_equal, set_seed, tensor_shard_equal
|
||||
from tests.components_to_test.registry import non_distributed_component_funcs
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
from colossalai.nn.parallel import ZeroDDP
|
||||
from colossalai.nn.optimizer import HybridAdam
|
||||
from colossalai.zero import ZeroOptimizer
|
||||
from colossalai.testing import parameterize
|
||||
from colossalai.amp import convert_to_apex_amp
|
||||
from colossalai.gemini.gemini_mgr import GeminiManager
|
||||
from tests.test_tensor.common_utils import debug_print
|
||||
|
||||
from time import time
|
||||
from colossalai.gemini.chunk import search_chunk_configuration, ChunkManager
|
||||
|
||||
|
||||
def check_param(model: ZeroDDP, torch_model: torch.nn.Module):
|
||||
zero_dict = model.state_dict(only_rank_0=False)
|
||||
torch_dict = torch_model.state_dict()
|
||||
|
||||
for key, value in torch_dict.items():
|
||||
# key is 'module.model.PARAMETER', so we truncate it
|
||||
key = key[7:]
|
||||
if key == 'model.lm_head.weight':
|
||||
continue
|
||||
assert key in zero_dict, "{} not in ZeRO dictionary.".format(key)
|
||||
temp_zero_value = zero_dict[key].to(device=value.device, dtype=value.dtype)
|
||||
# debug_print([0], "max range: ", key, torch.max(torch.abs(value - temp_zero_value)))
|
||||
assert torch.allclose(value, temp_zero_value, rtol=1e-3, atol=1e-2), "parameter '{}' has problem.".format(key)
|
||||
|
||||
|
||||
def run_fwd_bwd(model, criterion, optimizer, input_ids, attn_mask):
|
||||
optimizer.zero_grad()
|
||||
logits = model(input_ids, attn_mask)
|
||||
logits = logits.float()
|
||||
loss = criterion(logits, input_ids)
|
||||
optimizer.backward(loss)
|
||||
return logits
|
||||
|
||||
|
||||
@parameterize('placement_policy', ['cuda', 'cpu', 'auto'])
|
||||
def exam_gpt_fwd_bwd(placement_policy):
|
||||
set_seed(42)
|
||||
get_components_func = non_distributed_component_funcs.get_callable('gpt2')
|
||||
model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
|
||||
|
||||
with ColoInitContext(device=get_current_device()):
|
||||
model = model_builder()
|
||||
|
||||
torch_model = model_builder().cuda()
|
||||
for torch_p, p in zip(torch_model.parameters(), model.parameters()):
|
||||
torch_p.data.copy_(p.data)
|
||||
|
||||
world_size = torch.distributed.get_world_size()
|
||||
config_dict = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100)
|
||||
config_dict[world_size]['chunk_size'] = 5000
|
||||
config_dict[world_size]['keep_gathered'] = False
|
||||
if placement_policy != 'cuda':
|
||||
init_device = torch.device('cpu')
|
||||
else:
|
||||
init_device = None
|
||||
chunk_manager = ChunkManager(config_dict, init_device=init_device)
|
||||
gemini_manager = GeminiManager(placement_policy, chunk_manager)
|
||||
model = ZeroDDP(model, gemini_manager, pin_memory=True)
|
||||
|
||||
optimizer = HybridAdam(model.parameters(), lr=1e-3)
|
||||
zero_optim = ZeroOptimizer(optimizer, model, initial_scale=2)
|
||||
|
||||
amp_config = dict(opt_level='O2', keep_batchnorm_fp32=False, loss_scale=1)
|
||||
torch_optim = torch.optim.Adam(torch_model.parameters(), lr=1e-3)
|
||||
torch_model, torch_optim = convert_to_apex_amp(torch_model, torch_optim, amp_config)
|
||||
torch_model = DDP(torch_model, device_ids=[dist.get_rank()])
|
||||
|
||||
model.eval()
|
||||
torch_model.eval()
|
||||
|
||||
set_seed(dist.get_rank() * 3 + 128)
|
||||
for i, (input_ids, attn_mask) in enumerate(train_dataloader):
|
||||
if i > 2:
|
||||
break
|
||||
|
||||
zero_logits = run_fwd_bwd(model, criterion, zero_optim, input_ids, attn_mask)
|
||||
torch_logits = run_fwd_bwd(torch_model, criterion, torch_optim, input_ids, attn_mask)
|
||||
assert torch.allclose(zero_logits, torch_logits, rtol=1e-3, atol=1e-2)
|
||||
# debug_print([0], zero_logits, torch_logits)
|
||||
|
||||
zero_optim.step()
|
||||
torch_optim.step()
|
||||
|
||||
check_param(model, torch_model)
|
||||
|
||||
|
||||
def run_dist(rank, world_size, port):
|
||||
config = {}
|
||||
colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
|
||||
exam_gpt_fwd_bwd()
|
||||
|
||||
|
||||
@pytest.mark.dist
|
||||
@pytest.mark.parametrize('world_size', [1, 4])
|
||||
@rerun_if_address_is_in_use()
|
||||
def test_gpt(world_size):
|
||||
run_func = partial(run_dist, world_size=world_size, port=free_port())
|
||||
mp.spawn(run_func, nprocs=world_size)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_gpt(1)
|
||||
@@ -8,7 +8,7 @@ import torch.distributed as dist
|
||||
|
||||
import colossalai
|
||||
from colossalai.testing import rerun_if_address_is_in_use
|
||||
from colossalai.gemini.update import search_chunk_configuration
|
||||
from colossalai.gemini.chunk import search_chunk_configuration
|
||||
from colossalai.utils import free_port, get_current_device
|
||||
from colossalai.utils.model.colo_init_context import ColoInitContext
|
||||
from colossalai.tensor import ShardSpec, ComputePattern, ComputeSpec, ProcessGroup
|
||||
@@ -35,12 +35,11 @@ def exam_search_chunk_size():
|
||||
with ColoInitContext(device=get_current_device()):
|
||||
model = model_builder()
|
||||
init_1d_row_spec(model, pg_tp)
|
||||
config_dict = search_chunk_configuration(
|
||||
model,
|
||||
search_range_mb=1,
|
||||
search_interval_byte=16,
|
||||
min_chunk_size_mb=0,
|
||||
filter_exlarge_params=True)
|
||||
config_dict = search_chunk_configuration(model,
|
||||
search_range_mb=1,
|
||||
search_interval_byte=16,
|
||||
min_chunk_size_mb=0,
|
||||
filter_exlarge_params=True)
|
||||
|
||||
for key in config_dict:
|
||||
chunk_size = config_dict[key]['chunk_size']
|
||||
|
||||
114
tests/test_gemini/update/test_zeroddp_state_dict.py
Normal file
114
tests/test_gemini/update/test_zeroddp_state_dict.py
Normal file
@@ -0,0 +1,114 @@
|
||||
import pytest
|
||||
import colossalai
|
||||
import torch
|
||||
import torch.multiprocessing as mp
|
||||
import torch.distributed as dist
|
||||
from colossalai.testing import rerun_if_address_is_in_use
|
||||
from colossalai.utils.cuda import get_current_device
|
||||
from colossalai.utils import free_port
|
||||
from colossalai.utils.model.colo_init_context import ColoInitContext
|
||||
|
||||
from functools import partial
|
||||
from tests.test_tensor.common_utils import set_seed
|
||||
from tests.components_to_test.registry import non_distributed_component_funcs
|
||||
from colossalai.nn.parallel import ZeroDDP
|
||||
from colossalai.zero import ZeroOptimizer
|
||||
from colossalai.testing import parameterize
|
||||
from colossalai.gemini.gemini_mgr import GeminiManager
|
||||
from tests.test_tensor.common_utils import debug_print
|
||||
|
||||
from colossalai.gemini.chunk import search_chunk_configuration, ChunkManager
|
||||
|
||||
|
||||
@parameterize('placement_policy', ['cuda', 'cpu', 'auto'])
|
||||
@parameterize('keep_gathered', [True, False])
|
||||
def exam_state_dict(placement_policy, keep_gathered):
|
||||
set_seed(431)
|
||||
get_components_func = non_distributed_component_funcs.get_callable('gpt2')
|
||||
model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
|
||||
|
||||
with ColoInitContext(device=get_current_device()):
|
||||
model = model_builder()
|
||||
|
||||
torch_model = model_builder()
|
||||
for torch_p, p in zip(torch_model.parameters(), model.parameters()):
|
||||
torch_p.data.copy_(p.data)
|
||||
|
||||
world_size = torch.distributed.get_world_size()
|
||||
config_dict = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100)
|
||||
config_dict[world_size]['chunk_size'] = 5000
|
||||
config_dict[world_size]['keep_gathered'] = keep_gathered
|
||||
chunk_manager = ChunkManager(config_dict)
|
||||
gemini_manager = GeminiManager(placement_policy, chunk_manager)
|
||||
model = ZeroDDP(model, gemini_manager, pin_memory=True)
|
||||
model.train()
|
||||
|
||||
zero_dict = model.state_dict(only_rank_0=False)
|
||||
torch_dict = torch_model.state_dict()
|
||||
|
||||
for key, value in torch_dict.items():
|
||||
if key == 'model.lm_head.weight':
|
||||
continue
|
||||
assert key in zero_dict, "{} not in ZeRO dictionary.".format(key)
|
||||
temp_zero_value = zero_dict[key].to(device=value.device, dtype=value.dtype)
|
||||
assert torch.equal(value, temp_zero_value), "parameter '{}' has problem.".format(key)
|
||||
|
||||
|
||||
@parameterize('placement_policy', ['cuda', 'cpu', 'auto'])
|
||||
@parameterize('keep_gathered', [True, False])
|
||||
def exam_load_state_dict(placement_policy, keep_gathered):
|
||||
set_seed(431)
|
||||
get_components_func = non_distributed_component_funcs.get_callable('gpt2')
|
||||
model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
|
||||
|
||||
with ColoInitContext(device=get_current_device()):
|
||||
model = model_builder()
|
||||
|
||||
set_seed(451)
|
||||
torch_model = model_builder() # get a different model
|
||||
|
||||
world_size = torch.distributed.get_world_size()
|
||||
config_dict = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100)
|
||||
config_dict[world_size]['chunk_size'] = 5000
|
||||
config_dict[world_size]['keep_gathered'] = keep_gathered
|
||||
|
||||
if placement_policy != 'cuda':
|
||||
init_device = torch.device('cpu')
|
||||
else:
|
||||
init_device = None
|
||||
chunk_manager = ChunkManager(config_dict, init_device=init_device)
|
||||
gemini_manager = GeminiManager(placement_policy, chunk_manager)
|
||||
model = ZeroDDP(model, gemini_manager, pin_memory=True)
|
||||
|
||||
optimizer = torch.optim.Adam(model.parameters())
|
||||
optim = ZeroOptimizer(optimizer, model) # initialize the link between chunk16 and chunk32
|
||||
|
||||
torch_dict = torch_model.state_dict()
|
||||
model.load_state_dict(torch_dict, strict=False)
|
||||
zero_dict = model.state_dict(only_rank_0=False)
|
||||
|
||||
for key, value in torch_dict.items():
|
||||
if key == 'model.lm_head.weight':
|
||||
continue
|
||||
assert key in zero_dict, "{} not in ZeRO dictionary.".format(key)
|
||||
temp_zero_value = zero_dict[key].to(device=value.device, dtype=value.dtype)
|
||||
assert torch.equal(value, temp_zero_value), "parameter '{}' has problem.".format(key)
|
||||
|
||||
|
||||
def run_dist(rank, world_size, port):
|
||||
config = {}
|
||||
colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
|
||||
exam_state_dict()
|
||||
exam_load_state_dict()
|
||||
|
||||
|
||||
@pytest.mark.dist
|
||||
@pytest.mark.parametrize('world_size', [1, 4])
|
||||
@rerun_if_address_is_in_use()
|
||||
def test_zero_ddp(world_size):
|
||||
run_func = partial(run_dist, world_size=world_size, port=free_port())
|
||||
mp.spawn(run_func, nprocs=world_size)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_zero_ddp(1)
|
||||
81
tests/test_gemini/update/test_zerooptim_state_dict.py
Normal file
81
tests/test_gemini/update/test_zerooptim_state_dict.py
Normal file
@@ -0,0 +1,81 @@
|
||||
import pytest
|
||||
import colossalai
|
||||
import torch
|
||||
import torch.multiprocessing as mp
|
||||
import torch.distributed as dist
|
||||
from colossalai.testing import rerun_if_address_is_in_use
|
||||
from colossalai.utils.cuda import get_current_device
|
||||
from colossalai.utils import free_port
|
||||
from colossalai.utils.model.colo_init_context import ColoInitContext
|
||||
|
||||
from functools import partial
|
||||
from tests.test_tensor.common_utils import set_seed
|
||||
from tests.components_to_test.registry import non_distributed_component_funcs
|
||||
from colossalai.nn.parallel import ZeroDDP
|
||||
from colossalai.zero import ZeroOptimizer
|
||||
from colossalai.testing import parameterize
|
||||
from colossalai.gemini.gemini_mgr import GeminiManager
|
||||
from tests.test_tensor.common_utils import debug_print
|
||||
|
||||
from colossalai.gemini.chunk import search_chunk_configuration, ChunkManager
|
||||
|
||||
|
||||
@parameterize('placement_policy', ['cuda', 'cpu', 'auto'])
|
||||
@parameterize('keep_gathered', [True, False])
|
||||
def exam_zero_optim_state_dict(placement_policy, keep_gathered):
|
||||
set_seed(431)
|
||||
get_components_func = non_distributed_component_funcs.get_callable('gpt2')
|
||||
model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
|
||||
|
||||
with ColoInitContext(device=get_current_device()):
|
||||
model = model_builder()
|
||||
|
||||
set_seed(451)
|
||||
torch_model = model_builder() # get a different model
|
||||
|
||||
world_size = torch.distributed.get_world_size()
|
||||
config_dict = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100)
|
||||
config_dict[world_size]['chunk_size'] = 5000
|
||||
config_dict[world_size]['keep_gathered'] = keep_gathered
|
||||
|
||||
if placement_policy != 'cuda':
|
||||
init_device = torch.device('cpu')
|
||||
else:
|
||||
init_device = None
|
||||
chunk_manager = ChunkManager(config_dict, init_device=init_device)
|
||||
gemini_manager = GeminiManager(placement_policy, chunk_manager)
|
||||
model = ZeroDDP(model, gemini_manager, pin_memory=True)
|
||||
|
||||
optimizer = torch.optim.Adam(model.parameters())
|
||||
optim = ZeroOptimizer(optimizer, model) # initialize the link between chunk16 and chunk32
|
||||
|
||||
set_seed(dist.get_rank() * 3 + 128)
|
||||
for i, (input_ids, attn_mask) in enumerate(train_dataloader):
|
||||
if i > 0:
|
||||
break
|
||||
optim.zero_grad()
|
||||
logits = model(input_ids, attn_mask)
|
||||
logits = logits.float()
|
||||
loss = criterion(logits, input_ids)
|
||||
optim.backward(loss)
|
||||
|
||||
optim_state_dict = optim.state_dict()
|
||||
optim.load_state_dict(optim_state_dict)
|
||||
|
||||
|
||||
def run_dist(rank, world_size, port):
|
||||
config = {}
|
||||
colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
|
||||
exam_zero_optim_state_dict()
|
||||
|
||||
|
||||
@pytest.mark.dist
|
||||
@pytest.mark.parametrize('world_size', [1, 4])
|
||||
@rerun_if_address_is_in_use()
|
||||
def test_zero_optim(world_size):
|
||||
run_func = partial(run_dist, world_size=world_size, port=free_port())
|
||||
mp.spawn(run_func, nprocs=world_size)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_zero_optim(1)
|
||||
Reference in New Issue
Block a user