Merge pull request #403 from ver217/feature/shard-strategy

[zero] Add bucket tensor shard strategy
Frank Lee 2022-03-14 16:29:28 +08:00 committed by GitHub
commit 2fe68b359a
9 changed files with 129 additions and 63 deletions
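
For context, a minimal sketch of how the new strategy is selected, mirroring the call signatures exercised by the updated tests below. The toy model, the empty config, and the single-process launch arguments are placeholders for a real distributed launch, not part of this commit:

```python
import torch
import torch.nn as nn

import colossalai
from colossalai.zero.init_ctx import ZeroInitContext
from colossalai.zero.shard_utils import BucketTensorShardStrategy
from colossalai.zero.sharded_model import ShardedModelV2


def main():
    # placeholder single-process launch; the tests drive this through mp.spawn
    colossalai.launch(config={}, rank=0, world_size=1, host='localhost', port=29500, backend='nccl')

    shard_strategy = BucketTensorShardStrategy()
    with ZeroInitContext(convert_fp16=True,
                         target_device=torch.device('cpu'),
                         shard_strategy=shard_strategy,
                         shard_param=True,
                         model_numel_tensor=torch.zeros(1, dtype=torch.int)):
        # toy model; any module built inside the context gets sharded parameters
        model = nn.Linear(64, 64)

    # the same strategy instance is handed to the sharded model wrapper
    zero_model = ShardedModelV2(model, shard_strategy)


if __name__ == '__main__':
    main()
```

Since BucketTensorShardStrategy subclasses TensorShardStrategy, it is a drop-in replacement; the tests' `__main__` blocks keep using TensorShardStrategy by default.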


@@ -1,7 +1,8 @@
 import torch
 from colossalai.registry import OPHOOKS
-from colossalai.zero.shard_utils import BaseShardStrategy
 from colossalai.utils import get_current_device
+from colossalai.zero.shard_utils import BaseShardStrategy
 from ._base_ophook import BaseOpHook
@@ -18,23 +19,32 @@ class ZeroHook(BaseOpHook):
         self.computing_device = torch.device(f'cuda:{get_current_device()}')
 
     def pre_fwd_exec(self, module: torch.nn.Module, *args):
+        tensor_list = []
         for param in module.parameters():
             assert hasattr(param, 'col_attr')
-            self.shard_strategy.gather([param.col_attr.data])
+            tensor_list.append(param.col_attr.data)
+        self.shard_strategy.gather(tensor_list)
+        for param in module.parameters():
             if param.col_attr.data.device != self.computing_device:
                 param.col_attr.data.to(self.computing_device)
             param.data = param.col_attr.data.payload
 
     def post_fwd_exec(self, module: torch.nn.Module, *args):
+        tensor_list = []
         for param in module.parameters():
             assert hasattr(param, 'col_attr')
-            self.shard_strategy.shard([param.col_attr.data])
-            param.data = torch.empty([], dtype=param.col_attr.data.dtype, device=param.col_attr.data.payload.device)
+            tensor_list.append(param.col_attr.data)
+        self.shard_strategy.shard(tensor_list)
+        for param in module.parameters():
+            param.col_attr.remove_torch_payload()
 
     def pre_bwd_exec(self, module: torch.nn.Module, input, output):
+        tensor_list = []
         for param in module.parameters():
             assert hasattr(param, 'col_attr')
-            self.shard_strategy.gather([param.col_attr.data])
+            tensor_list.append(param.col_attr.data)
+        self.shard_strategy.gather(tensor_list)
+        for param in module.parameters():
             if param.col_attr.data.device != self.computing_device:
                 param.col_attr.data.to(self.computing_device)
             param.data = param.col_attr.data.payload
@@ -52,10 +62,13 @@ class ZeroHook(BaseOpHook):
             param.col_attr.bwd_count += 1
 
     def post_bwd_exec(self, module: torch.nn.Module, input):
+        tensor_list = []
         for param in module.parameters():
             assert hasattr(param, 'col_attr')
-            self.shard_strategy.shard([param.col_attr.data])
-            param.data = torch.empty([], dtype=param.col_attr.data.dtype, device=param.col_attr.data.payload.device)
+            tensor_list.append(param.col_attr.data)
+        self.shard_strategy.shard(tensor_list)
+        for param in module.parameters():
+            param.col_attr.remove_torch_payload()
 
     def pre_iter(self):
         pass
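
The hook now collects every sharded tensor of a module and calls gather/shard once per module instead of once per parameter, which is what lets a bucketing strategy coalesce a module's parameters into a single collective. A tiny stand-alone sketch of that call-pattern change (ToyStrategy is a hypothetical stand-in, not the colossalai API):

```python
from typing import List

import torch
import torch.nn as nn


class ToyStrategy:
    """Hypothetical stand-in that only counts how often gather() is invoked."""

    def __init__(self):
        self.calls = 0

    def gather(self, tensors: List[torch.Tensor]):
        # a bucketing strategy could flatten `tensors` into one buffer
        # and issue a single collective for the whole list
        self.calls += 1


module = nn.Linear(8, 8)    # 2 parameters: weight and bias
strategy = ToyStrategy()

# old pattern: one gather per parameter -> one collective each
for p in module.parameters():
    strategy.gather([p.data])
assert strategy.calls == 2

# new pattern: collect first, then a single gather per module
strategy.calls = 0
tensor_list = [p.data for p in module.parameters()]
strategy.gather(tensor_list)
assert strategy.calls == 1
```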


@@ -1,4 +1,5 @@
-from colossalai.zero.shard_utils.base_shard_strategy import BaseShardStrategy
-from colossalai.zero.shard_utils.tensor_shard_strategy import TensorShardStrategy
+from .base_shard_strategy import BaseShardStrategy
+from .bucket_tensor_shard_strategy import BucketTensorShardStrategy
+from .tensor_shard_strategy import TensorShardStrategy
 
-__all__ = ['BaseShardStrategy', 'TensorShardStrategy']
+__all__ = ['BaseShardStrategy', 'TensorShardStrategy', 'BucketTensorShardStrategy']


@@ -0,0 +1,41 @@
+from typing import List
+
+import torch
+import torch.distributed as dist
+from colossalai.utils import get_current_device
+from colossalai.zero.sharded_param.sharded_tensor import ShardedTensor
+from torch._utils import _flatten_dense_tensors as flatten
+
+from .tensor_shard_strategy import TensorShardStrategy
+
+
+class BucketTensorShardStrategy(TensorShardStrategy):
+
+    def gather(self, tensor_list: List[ShardedTensor]):
+        tensor_list: List[ShardedTensor] = [t for t in tensor_list if t.is_sharded]
+        if len(tensor_list) == 0:
+            return
+        target_device = tensor_list[0].device
+        dtype = tensor_list[0].dtype
+        buffer_list: List[torch.Tensor] = []
+        tensor_numels = [t.payload.numel() for t in tensor_list]
+        buffer_size = sum(tensor_numels)
+        for i in range(self.world_size):
+            if i == self.local_rank:
+                buffer_list.append(flatten([t.payload for t in tensor_list]).cuda(get_current_device()))
+                # Release payload here, to decrease peak memory usage
+                for t in tensor_list:
+                    t.reset_payload(None)
+            else:
+                buffer_list.append(torch.zeros(buffer_size, dtype=dtype, device=get_current_device()))
+        dist.all_gather(buffer_list, buffer_list[self.local_rank], group=self.process_group)
+        # Move to target device before splitting buffer
+        # Ensure we utilize maximum PCIE bandwidth
+        buffer_list = [buffer.to(target_device) for buffer in buffer_list]
+        offset = 0
+        for i, t in enumerate(tensor_list):
+            gathered_payload = [buffer[offset:offset + tensor_numels[i]] for buffer in buffer_list]
+            gathered_payload = torch.cat(gathered_payload)[:t.origin_numel].view(t.origin_shape)
+            t.reset_payload(gathered_payload)
+            t.is_sharded = False
+            offset += tensor_numels[i]
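
The reassembly loop at the end of gather() is the subtle part: each rank's buffer holds that rank's shard of every tensor back to back, so a tensor is rebuilt by slicing the same offset out of every rank's buffer, concatenating across ranks, trimming the padding (the `[:t.origin_numel]` slice) and reshaping. A single-process sketch of just that arithmetic, with the per-rank buffers simulated in plain tensors and the sharding rule (flatten, pad to a multiple of world_size, split evenly) assumed from the trim above:

```python
import torch

world_size = 4
originals = [torch.randn(5, 3), torch.randn(7)]    # arbitrary shapes


def shard(t):
    # assumed sharding rule: flatten, pad to a multiple of world_size, split evenly
    flat = t.flatten()
    pad = (-flat.numel()) % world_size
    flat = torch.cat([flat, flat.new_zeros(pad)])
    return list(flat.chunk(world_size))


shards = [shard(t) for t in originals]             # shards[i][rank]

# each rank's bucket: its shard of every tensor, flattened into one buffer
# (this is what all_gather would exchange, one buffer per rank)
buffer_list = [torch.cat([shards[i][rank] for i in range(len(originals))])
               for rank in range(world_size)]

# reassembly, mirroring the offset loop in BucketTensorShardStrategy.gather()
tensor_numels = [s[0].numel() for s in shards]     # per-rank shard size of each tensor
offset = 0
for i, t in enumerate(originals):
    pieces = [buf[offset:offset + tensor_numels[i]] for buf in buffer_list]
    gathered = torch.cat(pieces)[:t.numel()].view(t.shape)
    assert torch.equal(gathered, t)
    offset += tensor_numels[i]
print("reassembly matches the original tensors")
```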


@@ -4,21 +4,20 @@
 from functools import partial
 
 import colossalai
-from colossalai.utils.cuda import get_current_device
 import pytest
 import torch
 import torch.multiprocessing as mp
 from colossalai.utils import free_port
+from colossalai.utils.cuda import get_current_device
+from colossalai.utils.memory_tracer.allocator import GLOBAL_MODEL_DATA_TRACER
 from colossalai.zero.init_ctx import ZeroInitContext
-from colossalai.zero.shard_utils.tensor_shard_strategy import \
-    TensorShardStrategy
+from colossalai.zero.shard_utils import (BucketTensorShardStrategy, TensorShardStrategy)
 from tests.components_to_test.registry import non_distributed_component_funcs
 from common import CONFIG
-from colossalai.utils.memory_tracer.allocator import GLOBAL_MODEL_DATA_TRACER
 
 
-def run_dist(rank, world_size, port, init_device):
+def run_dist(rank, world_size, port, init_device, shard_strategy):
     colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
     for get_components_func in non_distributed_component_funcs:
@@ -26,7 +25,7 @@ def run_dist(rank, world_size, port, init_device):
         model_numel_tensor = torch.zeros(1, dtype=torch.int)
         with ZeroInitContext(convert_fp16=True,
                              target_device=init_device,
-                             shard_strategy=TensorShardStrategy(),
+                             shard_strategy=shard_strategy(),
                              shard_param=True,
                              model_numel_tensor=model_numel_tensor):
             model = model_builder(checkpoint=True)
@@ -50,11 +49,16 @@ def run_dist(rank, world_size, port, init_device):
 @pytest.mark.dist
 @pytest.mark.parametrize("world_size", [1, 4])
 @pytest.mark.parametrize("init_device", [torch.device('cpu'), torch.device(f'cuda:{get_current_device()}')])
-def test_zero_init_context(world_size, init_device):
-    run_func = partial(run_dist, world_size=world_size, port=free_port(), init_device=init_device)
+@pytest.mark.parametrize("shard_strategy", [TensorShardStrategy, BucketTensorShardStrategy])
+def test_zero_init_context(world_size, init_device, shard_strategy):
+    run_func = partial(run_dist,
+                       world_size=world_size,
+                       port=free_port(),
+                       init_device=init_device,
+                       shard_strategy=shard_strategy)
     mp.spawn(run_func, nprocs=world_size)
 
 
 if __name__ == '__main__':
-    test_zero_init_context(2, torch.device('cpu'))
-    test_zero_init_context(2, torch.device(f'cuda:{get_current_device()}'))
+    test_zero_init_context(2, torch.device('cpu'), TensorShardStrategy)
+    test_zero_init_context(2, torch.device(f'cuda:{get_current_device()}'), TensorShardStrategy)


@@ -3,30 +3,28 @@
 import copy
 from functools import partial
 
-import pytest
-import torch
-import torch.multiprocessing as mp
-from torch.nn.parallel import DistributedDataParallel as DDP
 import colossalai
-from colossalai.zero.init_ctx import ZeroInitContext
+import pytest
+import torch
+import torch.multiprocessing as mp
 from colossalai.utils import free_port
-from colossalai.zero.shard_utils.tensor_shard_strategy import \
-    TensorShardStrategy
+from colossalai.zero.init_ctx import ZeroInitContext
+from colossalai.zero.shard_utils import (BucketTensorShardStrategy, TensorShardStrategy)
 from colossalai.zero.sharded_model import ShardedModelV2
 from colossalai.zero.sharded_model._zero3_utils import cast_tensor_to_fp16
-from tests.components_to_test.registry import non_distributed_component_funcs
-from common import CONFIG, check_grads_padding, run_fwd_bwd
 from colossalai.zero.sharded_model.utils import col_model_deepcopy
+from tests.components_to_test.registry import non_distributed_component_funcs
+from torch.nn.parallel import DistributedDataParallel as DDP
+from common import CONFIG, check_grads_padding, run_fwd_bwd
 
 
-def run_dist(rank, world_size, port, use_zero_init_ctx, enable_autocast):
+def run_dist(rank, world_size, port, use_zero_init_ctx, enable_autocast, shard_strategy):
     colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
     test_models = ['repeated_computed_layers', 'resnet18', 'bert']
-    shard_strategy = TensorShardStrategy()
+    shard_strategy = shard_strategy()
     for model_name in test_models:
         get_components_func = non_distributed_component_funcs.get_callable(model_name)
         model_builder, train_dataloader, _, _, criterion = get_components_func()
@@ -66,14 +64,16 @@ def run_dist(rank, world_size, port, use_zero_init_ctx, enable_autocast):
 @pytest.mark.parametrize("world_size", [1, 2])
 @pytest.mark.parametrize("enable_autocast", [True])
 @pytest.mark.parametrize("use_zero_init_ctx", [True])
-def test_shard_model_v2(world_size, use_zero_init_ctx, enable_autocast):
+@pytest.mark.parametrize("shard_strategy", [TensorShardStrategy, BucketTensorShardStrategy])
+def test_shard_model_v2(world_size, use_zero_init_ctx, enable_autocast, shard_strategy):
     run_func = partial(run_dist,
                        world_size=world_size,
                        port=free_port(),
                        use_zero_init_ctx=use_zero_init_ctx,
-                       enable_autocast=enable_autocast)
+                       enable_autocast=enable_autocast,
+                       shard_strategy=shard_strategy)
     mp.spawn(run_func, nprocs=world_size)
 
 
 if __name__ == '__main__':
-    test_shard_model_v2(world_size=2, use_zero_init_ctx=True, enable_autocast=True)
+    test_shard_model_v2(world_size=2, use_zero_init_ctx=True, enable_autocast=True, shard_strategy=TensorShardStrategy)


@@ -10,20 +10,20 @@ import torch
 import torch.multiprocessing as mp
 from colossalai.logging import disable_existing_loggers, get_dist_logger
 from colossalai.utils import free_port
-from colossalai.zero.shard_utils import TensorShardStrategy
+from colossalai.zero.shard_utils import (BucketTensorShardStrategy, TensorShardStrategy)
 from colossalai.zero.sharded_param import ShardedParam, ShardedTensor
 from colossalai.zero.sharded_param.sharded_param import ShardedParamV2
-from tests.test_zero_data_parallel.common import CONFIG, allclose
 from tests.components_to_test.registry import non_distributed_component_funcs
+from tests.test_zero_data_parallel.common import CONFIG, allclose
 
 
-def _run_shard_tensor(rank, world_size, port):
+def _run_shard_tensor(rank, world_size, port, shard_strategy):
     colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
     t = ShardedTensor(tensor=torch.randn(world_size * 2, 3))
     assert list(t.origin_shape) == [world_size * 2, 3]
     assert list(t.shape) == [world_size * 2, 3]
-    shard_strategy = TensorShardStrategy(process_group=None)
+    shard_strategy = shard_strategy(process_group=None)
 
     # test shard strategy
     shard_strategy.shard([t])
@@ -34,8 +34,9 @@ def _run_shard_tensor(rank, world_size, port):
 @pytest.mark.dist
 @pytest.mark.parametrize("world_size", [1, 2])
-def test_shard_tensor(world_size):
-    run_func = partial(_run_shard_tensor, world_size=world_size, port=free_port())
+@pytest.mark.parametrize("shard_strategy", [TensorShardStrategy, BucketTensorShardStrategy])
+def test_shard_tensor(world_size, shard_strategy):
+    run_func = partial(_run_shard_tensor, world_size=world_size, port=free_port(), shard_strategy=shard_strategy)
     mp.spawn(run_func, nprocs=world_size)
@@ -121,7 +122,7 @@ def test_init_shard_param(world_size):
 
 if __name__ == '__main__':
-    test_shard_tensor(2)
+    test_shard_tensor(2, TensorShardStrategy)
     test_shard_param(2)
     test_shard_param_v2(2)
     test_init_shard_param(4)


@@ -10,7 +10,7 @@ import torch
 import torch.distributed as dist
 import torch.multiprocessing as mp
 from colossalai.utils import free_port
-from colossalai.zero.shard_utils import TensorShardStrategy
+from colossalai.zero.shard_utils import (BucketTensorShardStrategy, TensorShardStrategy)
 from colossalai.zero.sharded_model import ShardedModelV2
 from colossalai.zero.sharded_optim import ShardedOptimizerV2
 from tests.components_to_test.registry import non_distributed_component_funcs
@@ -38,12 +38,12 @@ def run_step(model, optimizer, data, label, criterion, enable_autocast=False):
     optimizer.step()
 
 
-def run_dist(rank, world_size, port, cpu_offload):
+def run_dist(rank, world_size, port, cpu_offload, shard_strategy):
     colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
     test_models = ['repeated_computed_layers', 'resnet18', 'bert']
+    shard_strategy = shard_strategy()
     for model_name in test_models:
         get_components_func = non_distributed_component_funcs.get_callable(model_name)
-        shard_strategy = TensorShardStrategy()
         model, train_dataloader, test_dataloader, optimizer, criterion = get_components_func()
         model = model(checkpoint=True).cuda()
         zero_model = ShardedModelV2(copy.deepcopy(model),
@@ -69,10 +69,15 @@ def run_dist(rank, world_size, port, cpu_offload):
 @pytest.mark.dist
 @pytest.mark.parametrize("world_size", [1, 2])
 @pytest.mark.parametrize("cpu_offload", [True, False])
-def test_sharded_optim_v2(world_size, cpu_offload):
-    run_func = partial(run_dist, world_size=world_size, port=free_port(), cpu_offload=cpu_offload)
+@pytest.mark.parametrize("shard_strategy", [TensorShardStrategy, BucketTensorShardStrategy])
+def test_sharded_optim_v2(world_size, cpu_offload, shard_strategy):
+    run_func = partial(run_dist,
+                       world_size=world_size,
+                       port=free_port(),
+                       cpu_offload=cpu_offload,
+                       shard_strategy=shard_strategy)
     mp.spawn(run_func, nprocs=world_size)
 
 
 if __name__ == '__main__':
-    test_sharded_optim_v2(world_size=2, cpu_offload=True)
+    test_sharded_optim_v2(world_size=2, cpu_offload=True, shard_strategy=TensorShardStrategy)


@@ -11,7 +11,7 @@ import torch.distributed as dist
 import torch.multiprocessing as mp
 from colossalai.nn.optimizer import CPUAdam
 from colossalai.utils import free_port
-from colossalai.zero.shard_utils import TensorShardStrategy
+from colossalai.zero.shard_utils import (BucketTensorShardStrategy, TensorShardStrategy)
 from colossalai.zero.sharded_model import ShardedModelV2
 from colossalai.zero.sharded_optim import ShardedOptimizerV2
 from tests.components_to_test.registry import non_distributed_component_funcs
@@ -47,12 +47,12 @@ def run_step_no_criterion(model, optimizer, data, label, enable_autocast=False):
     optimizer.step()
 
 
-def run_dist(rank, world_size, port):
+def run_dist(rank, world_size, port, shard_strategy):
     colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
     test_models = ['repeated_computed_layers', 'resnet18', 'bert']
+    shard_strategy = shard_strategy()
     for model_name in test_models:
         get_components_func = non_distributed_component_funcs.get_callable(model_name)
-        shard_strategy = TensorShardStrategy()
         model, train_dataloader, test_dataloader, optimizer, criterion = get_components_func()
         model = model(checkpoint=True).cuda()
         zero_model = ShardedModelV2(copy.deepcopy(model), shard_strategy, offload_config={'device': 'cpu'})
@@ -79,10 +79,11 @@ def run_dist(rank, world_size, port):
 @pytest.mark.dist
 @pytest.mark.parametrize("world_size", [1, 2])
-def test_sharded_optim_v2(world_size):
-    run_func = partial(run_dist, world_size=world_size, port=free_port())
+@pytest.mark.parametrize("shard_strategy", [TensorShardStrategy, BucketTensorShardStrategy])
+def test_sharded_optim_v2(world_size, shard_strategy):
+    run_func = partial(run_dist, world_size=world_size, port=free_port(), shard_strategy=shard_strategy)
     mp.spawn(run_func, nprocs=world_size)
 
 
 if __name__ == '__main__':
-    test_sharded_optim_v2(world_size=2)
+    test_sharded_optim_v2(world_size=2, shard_strategy=TensorShardStrategy)


@@ -9,22 +9,21 @@ import pytest
 import torch
 import torch.multiprocessing as mp
 from colossalai.utils import free_port
-from colossalai.zero.shard_utils.tensor_shard_strategy import \
-    TensorShardStrategy
+from colossalai.zero.shard_utils import (BucketTensorShardStrategy, TensorShardStrategy)
 from colossalai.zero.sharded_model import ShardedModelV2
 from tests.components_to_test.registry import non_distributed_component_funcs
 from common import CONFIG
 
 
-def run_dist(rank, world_size, port):
+def run_dist(rank, world_size, port, shard_strategy):
     colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
     test_models = ['repeated_computed_layers', 'resnet18']
+    shard_strategy = shard_strategy()
     for model_name in test_models:
         get_components_func = non_distributed_component_funcs.get_callable(model_name)
         model_builder, train_dataloader, test_dataloader, optimizer, criterion = get_components_func()
         model = model_builder()
-        shard_strategy = TensorShardStrategy()
         model = model.half().cuda()
         zero_model = ShardedModelV2(deepcopy(model), shard_strategy)
         zero_state_dict = zero_model.state_dict()
@@ -33,11 +32,12 @@ def run_dist(rank, world_size, port):
 @pytest.mark.dist
-def test_zero_state_dict():
-    world_size = 2
-    run_func = partial(run_dist, world_size=world_size, port=free_port())
+@pytest.mark.parametrize("world_size", [1, 2])
+@pytest.mark.parametrize("shard_strategy", [TensorShardStrategy, BucketTensorShardStrategy])
+def test_zero_state_dict(world_size, shard_strategy):
+    run_func = partial(run_dist, world_size=world_size, port=free_port(), shard_strategy=shard_strategy)
     mp.spawn(run_func, nprocs=world_size)
 
 
 if __name__ == '__main__':
-    test_zero_state_dict()
+    test_zero_state_dict(2, TensorShardStrategy)