[refactor] remove old zero code (#517)

Jiarui Fang
2022-03-25 14:54:39 +08:00
committed by GitHub
parent 6a3f9fda83
commit 4d322b79da
28 changed files with 33 additions and 2978 deletions
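The tests deleted here exercised the original ShardedModel wrapper; the imports that survive in this diff point at its replacement, ShardedModelV2 driven by a shard strategy such as TensorShardStrategy. A minimal sketch of that replacement API follows; it is hedged because the constructor arguments are assumed for illustration rather than shown in this commit.

# Sketch of the ZeRO path this refactor keeps (ShardedModelV2 plus a shard strategy).
# Assumes an already-initialized colossalai distributed context; the
# (module, shard_strategy) constructor order is an assumption, not taken from this diff.
import torch.nn as nn
from colossalai.zero.shard_utils.tensor_shard_strategy import TensorShardStrategy
from colossalai.zero.sharded_model.sharded_model_v2 import ShardedModelV2

model = nn.Linear(5, 5).cuda()
shard_strategy = TensorShardStrategy()               # shards each parameter tensor across ranks
zero_model = ShardedModelV2(model, shard_strategy)   # replaces the removed ShardedModel wrapper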

View File

@@ -7,9 +7,7 @@ import colossalai
import torch
from functools import partial
import torch.multiprocessing as mp
import pytest
def run_tensor_move(rank):
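The hunk above keeps only the imports and the run_tensor_move signature; its body is not shown. A hypothetical sketch of what such a per-rank tensor-move check can look like, built from the launch and spawn pattern that appears later in this diff; the extra port parameter, the tensor shape, and the assertions are illustrative, not the actual test.

# Illustrative only: move a tensor CPU -> GPU and back inside a spawned rank.
from functools import partial

import torch
import torch.multiprocessing as mp

import colossalai
from colossalai.utils import free_port


def run_tensor_move(rank, port):
    colossalai.launch(config={}, rank=rank, world_size=1, host='localhost', port=port, backend='nccl')
    t = torch.randn(2, 2)               # starts on CPU
    t_gpu = t.cuda()                    # move to the current GPU
    assert t_gpu.is_cuda
    assert torch.equal(t_gpu.cpu(), t)  # round trip preserves values


if __name__ == '__main__':
    mp.spawn(partial(run_tensor_move, port=free_port()), nprocs=1)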

View File

@@ -2,11 +2,9 @@
# -*- encoding: utf-8 -*-
import copy
import operator as op
from functools import partial, reduce
from typing import List
import colossalai
from colossalai.zero.sharded_model.sharded_model_v2 import ShardedModelV2
import pytest
import torch
import torch.distributed as dist
@@ -14,10 +12,11 @@ import torch.multiprocessing as mp
import torch.nn as nn
from colossalai.logging import disable_existing_loggers
from colossalai.utils import checkpoint, clip_grad_norm_fp32, free_port
from colossalai.zero.sharded_model import ShardedModel
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.nn.utils import clip_grad_norm_
from colossalai.testing import parameterize
from colossalai.zero.shard_utils.tensor_shard_strategy import TensorShardStrategy
from functools import partial
def checkpoint_wrapper(module, enable=True):
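This hunk ends at the checkpoint_wrapper signature without its body. A hypothetical completion is sketched below against torch.utils.checkpoint for illustration; the test itself imports colossalai.utils.checkpoint, whose exact call signature is not shown in this diff.

# Sketch: re-run the wrapped forward under activation checkpointing so activations
# are recomputed during backward instead of stored; a no-op when disabled.
from functools import partial

from torch.utils.checkpoint import checkpoint


def checkpoint_wrapper(module, enable=True):
    if enable:
        # partial captures the original bound forward, so there is no recursion.
        module.forward = partial(checkpoint, module.forward)
    return module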
@@ -97,41 +96,9 @@ def check_params(model, zero_model, loose=False):
        assert allclose(p, zero_p, loose=loose)
@parameterize('checkpoint', [False, True])
@parameterize('fp16', [False, True])
@parameterize('offload', [False, True])
@parameterize('norm_type', [1.0, 2.0, float('inf')])
def check_config(checkpoint=False, fp16=False, offload=False, norm_type=2.0):
    model = Net(checkpoint=checkpoint).cuda()
    zero_model = copy.deepcopy(model)
    ddp_model = DDP(model)
    offload_config = {}
    if offload:
        offload_config['device'] = 'cpu'
        zero_model = zero_model.cpu()
    zero_model = ShardedModel(zero_model, mixed_precision=fp16, offload_config=offload_config)
    optimizer = torch.optim.Adam(ddp_model.parameters(), lr=1e-3)
    zero_optimizer = torch.optim.Adam(zero_model.parameters(), lr=1e-3)
    for _ in range(5):
        x = torch.rand(2, 5).cuda()
        run_step(ddp_model, optimizer, x, enable_autocast=fp16, norm_type=norm_type)
        run_step(zero_model, zero_optimizer, x, enable_autocast=fp16, norm_type=norm_type)
        check_grads(ddp_model, zero_model)
        check_params(ddp_model, zero_model)
    for _ in range(5):
        x = torch.rand(2, 5).cuda()
        run_step(ddp_model, optimizer, x, enable_autocast=False, norm_type=norm_type)
        run_step(zero_model, zero_optimizer, x, enable_autocast=False, norm_type=norm_type)
        check_grads(ddp_model, zero_model, loose=True)
        check_params(ddp_model, zero_model, loose=True)
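check_config drives both models through a run_step helper that is referenced but not shown in this hunk. A hedged reconstruction of what such a step typically does, given the arguments used above (an autocast flag plus a gradient-clipping norm); the loss function and clipping threshold are assumptions, and only the plain DDP path is shown; the ShardedModel side presumably swaps in the imported clip_grad_norm_fp32.

# Illustrative step: forward under optional fp16 autocast, backward,
# clip gradients with the requested norm, then update parameters.
import torch
from torch.nn.utils import clip_grad_norm_


def run_step(model, optimizer, x, enable_autocast=False, norm_type=2.0):
    model.train()
    optimizer.zero_grad()
    with torch.cuda.amp.autocast(enabled=enable_autocast):
        y = model(x)
        loss = y.sum()      # assumed loss; the real test's criterion is not shown
    loss.backward()
    clip_grad_norm_(model.parameters(), max_norm=1.0, norm_type=norm_type)  # assumed max_norm
    optimizer.step()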
def run_dist(rank, world_size, port):
    disable_existing_loggers()
    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    check_config()
@pytest.mark.dist
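The hunk is truncated at the @pytest.mark.dist marker. In these tests, the decorated entry point typically spawns run_dist across worker processes; a sketch of that pattern follows, with the test name and the two-process world size assumed rather than taken from this diff.

# Hypothetical test entry point spawning run_dist per rank.
from functools import partial

import pytest
import torch.multiprocessing as mp

from colossalai.utils import free_port


@pytest.mark.dist
def test_zero_shard_model():        # hypothetical name
    world_size = 2                  # assumed; not shown in this hunk
    run_func = partial(run_dist, world_size=world_size, port=free_port())
    mp.spawn(run_func, nprocs=world_size)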