[tensor] ColoTensor supports ZeRo (#1015)

* impl chunk manager

* impl param op hook

* add reduce_chunk

* add zero hook v2

* add zero dp

* fix TensorInfo

* impl load balancing when using zero without chunk

* fix zero hook

* polish chunk

* fix bugs

* ddp ok

* zero ok

* polish code

* fix bugs about load balancing

* polish code

* polish code

* add end-to-end test

* polish code

* polish code

* polish code

* fix typo

* add test_chunk

* fix bugs

* fix bugs

* polish code
ver217 authored on 2022-05-31 12:00:12 +08:00, committed by GitHub
parent cfa6c1b46b
commit 9492a561c3
8 changed files with 618 additions and 4 deletions


@@ -3,8 +3,10 @@ import torch.distributed as dist
from colossalai.core import global_context as gpc
from colossalai.context import ParallelMode
from functools import partial
from colossalai.zero.utils.zero_hook_v2 import ZeROHookV2
from colossalai.tensor import ChunkManager, use_param_op_hooks, TensorState
__all__ = ['ColoDDP']
__all__ = ['ColoDDP', 'ColoDDPV2']
def free_storage(data: torch.Tensor) -> None:
@@ -76,3 +78,54 @@ class ColoDDP(torch.nn.Module):
else:
p._saved_grad.requires_grad_(False)
p._saved_grad.zero_()
class ColoDDPV2(ColoDDP):
def __init__(self, module: torch.nn.Module, chunk_manager: ChunkManager) -> None:
super().__init__(module)
self.chunk_manager = chunk_manager
self.param_op_hook = ZeROHookV2(chunk_manager)
self.fp32_params = []
# TODO: get param order and filter unused params
for p in module.parameters():
assert p.dtype == torch.half
fp32_p = p.float()
self.chunk_manager.append_tensor(p, 'fp16_param')
self.chunk_manager.append_tensor(fp32_p, 'fp32_param')
self.fp32_params.append(fp32_p)
def forward(self, *args, **kwargs):
self.module.zero_grad(set_to_none=True)
for p, fp32_p in zip(self.module.parameters(), self.fp32_params):
if not self.chunk_manager.is_chunk_free(p):
self.chunk_manager.copy_tensor_to_chunk_slice(p, fp32_p)
with use_param_op_hooks(self.param_op_hook):
outputs = self.module(*args, **kwargs)
self.chunk_manager.exec_lazy_release()
return outputs
def backward(self, loss: torch.Tensor):
with self.param_op_hook.switch_to_backward(), use_param_op_hooks(self.param_op_hook):
loss.backward()
self.chunk_manager.exec_lazy_release()
for p in self.module.parameters():
if self.chunk_manager.is_chunk_free(p) or not p.requires_grad:
p.grad = None
else:
p.grad = p.data
def grad_handle(self, p, grad):
empty_grad = torch.empty_like(grad)
free_storage(empty_grad)
with torch._C.DisableTorchFunction():
self.chunk_manager.trans_tensor_state(p, TensorState.READY_FOR_REDUCE)
if self.dp_world_size > 1:
grad = grad / self.dp_world_size
self.chunk_manager.copy_tensor_to_chunk_slice(p, grad)
self.chunk_manager.reduce_chunk(p)
self.chunk_manager.release_chunk(p)
return empty_grad
def zero_grad(self, set_to_none: bool = False) -> None:
self.module.zero_grad(set_to_none=True)
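
For context, a minimal usage sketch of the new ColoDDPV2 wrapper. The chunk size and the model/data/criterion names below are illustrative assumptions, not part of this commit; it also assumes colossalai.launch(...) has initialized the DATA parallel group and the model was built under ColoInitContext so its parameters are ColoParameters.

# Sketch: wrap an fp16 model with ColoDDPV2 backed by a ChunkManager (illustrative values).
from colossalai.tensor import ChunkManager
from colossalai.nn.parallel import ColoDDPV2

chunk_manager = ChunkManager(chunk_size=64 * 1024**2,          # None disables chunking
                             enable_distributed_storage=True)  # shard chunks across DATA ranks (ZeRO)
model = ColoDDPV2(model.half().cuda(), chunk_manager)          # __init__ asserts params are torch.half

logits = model(input_ids, attn_mask)    # forward runs with ZeROHookV2 active via use_param_op_hooks
loss = criterion(logits, input_ids)
model.backward(loss)                    # call model.backward(...) rather than loss.backward()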


@@ -8,12 +8,14 @@ from ._ops import *
from .optim.colo_optimizer import ColoOptimizer
from . import distspec
from .dist_spec_mgr import DistSpecManager
from .param_op_hook import ParamOpHook, use_param_op_hooks
from .chunk import ChunkManager, TensorState
from .module_utils import register_colo_module, is_colo_module, get_colo_module, init_colo_module, check_colo_module
from .modules import ColoLinear, ColoEmbedding
__all__ = [
'ColoTensor', 'convert_parameter', 'colo_op_impl', 'ComputePattern', 'TensorSpec', 'ParallelAction',
'named_params_with_colotensor', 'ColoOptimizer', 'ColoParameter', 'distspec', 'DistSpecManager',
'register_colo_module', 'is_colo_module', 'get_colo_module', 'init_colo_module', 'check_colo_module',
'ColoLinear', 'ColoEmbedding'
'register_colo_module', 'is_colo_module', 'get_colo_module', 'init_colo_module', 'check_colo_module', 'ColoLinear',
'ColoEmbedding', 'ParamOpHook', 'use_param_op_hooks', 'ChunkManager', 'TensorState'
]

colossalai/tensor/chunk.py (new file, 264 lines)

@@ -0,0 +1,264 @@
import torch
import torch.distributed as dist
from dataclasses import dataclass
from enum import Enum
from typing import Optional, Dict, Deque, Set, List
from collections import deque
from colossalai.core import global_context as gpc
from colossalai.context import ParallelMode
from colossalai.utils import get_current_device
class TensorState(Enum):
FREE = 0
COMPUTE = 1
HOLD = 2
HOLD_AFTER_BWD = 3
READY_FOR_REDUCE = 4
STATE_TRANS = ((TensorState.FREE, TensorState.HOLD), (TensorState.FREE, TensorState.COMPUTE),
(TensorState.HOLD, TensorState.FREE), (TensorState.HOLD, TensorState.COMPUTE),
(TensorState.COMPUTE, TensorState.HOLD), (TensorState.COMPUTE, TensorState.HOLD_AFTER_BWD),
(TensorState.COMPUTE, TensorState.READY_FOR_REDUCE), (TensorState.HOLD_AFTER_BWD, TensorState.COMPUTE),
(TensorState.HOLD_AFTER_BWD, TensorState.READY_FOR_REDUCE), (TensorState.READY_FOR_REDUCE,
TensorState.HOLD))
@dataclass
class TensorInfo:
state: TensorState
offset: int
end: int
class ChunkFullError(Exception):
pass
class Chunk:
def __init__(self,
chunk_size: int,
src_rank: int,
dtype: torch.dtype,
init_device: Optional[torch.device] = None) -> None:
self.size = chunk_size
self.utilized_size = 0
self.src_rank = src_rank
self.is_src_rank = gpc.get_local_rank(ParallelMode.DATA) == src_rank
self.dtype = dtype
self.device = init_device or get_current_device()
self.data = torch.empty(chunk_size, dtype=dtype, device=self.device)
if not self.is_src_rank:
self.data.storage().resize_(0)
self.tensors_info: Dict[torch.Tensor, TensorInfo] = {}
def append(self, tensor: torch.Tensor) -> None:
assert tensor.dtype == self.dtype
new_utilized_size = self.utilized_size + tensor.numel()
if new_utilized_size > self.size:
raise ChunkFullError
tensor_state = TensorState.FREE
if self.is_src_rank:
self.data[self.utilized_size:new_utilized_size].copy_(tensor.view(-1))
tensor_state = TensorState.HOLD
tensor.data = self.data[self.utilized_size:new_utilized_size].view(tensor.shape)
else:
tensor.storage().resize_(0)
self.tensors_info[tensor] = TensorInfo(tensor_state, self.utilized_size, new_utilized_size)
self.utilized_size = new_utilized_size
def release(self) -> None:
if not self.is_src_rank:
self.data.storage().resize_(0)
self._update_tensors_state(TensorState.FREE)
def _update_tensors_ptr(self) -> None:
for tensor, tensor_info in self.tensors_info.items():
tensor.data = self.data[tensor_info.offset:tensor_info.end].view(tensor.shape)
def _update_tensors_state(self, next_state: TensorState, prev_state: Optional[TensorState] = None):
for tensor_info in self.tensors_info.values():
if prev_state is None or tensor_info.state == prev_state:
tensor_info.state = next_state
def access(self) -> None:
if not self.is_src_rank:
self.data.storage().resize_(self.size)
self.data.data = self.data.to(get_current_device())
dist.broadcast(self.data, self.src_rank, group=gpc.get_group(ParallelMode.DATA))
self._update_tensors_ptr()
if not self.is_src_rank:
self._update_tensors_state(TensorState.HOLD, prev_state=TensorState.FREE)
def move_device(self, device: torch.device) -> None:
self.data.data = self.data.to(device)
self._update_tensors_ptr()
def reduce(self, is_all_reduce: bool = False) -> None:
self.data.data = self.data.to(get_current_device())
if is_all_reduce:
dist.all_reduce(self.data, group=gpc.get_group(ParallelMode.DATA))
else:
dist.reduce(self.data, self.src_rank, group=gpc.get_group(ParallelMode.DATA))
self._update_tensors_ptr()
self._update_tensors_state(TensorState.HOLD)
def tensor_trans_state(self, tensor: torch.Tensor, tensor_state: TensorState) -> None:
assert tensor_state != TensorState.FREE, 'Can only set a chunk of tensors to FREE'
# Since the gradient hook can fire either before or after the post-backward hook,
# a tensor's state may go compute -> hold_after_bwd -> ready_for_reduce
# or compute -> ready_for_reduce -> hold_after_bwd.
# The second path is invalid, so ready_for_reduce -> hold_after_bwd is simply ignored.
# This function only applies valid state transitions;
# invalid calls are ignored and nothing changes.
if (self.tensors_info[tensor].state, tensor_state) not in STATE_TRANS:
# print(
# f'WARNING: Rank{gpc.get_global_rank()} apply invalid state trans: {self.tensors_info[tensor].state} to {tensor_state}'
# )
return
self.tensors_info[tensor].state = tensor_state
def copy_tensor_to_chunk_slice(self, tensor: torch.Tensor, data_slice: torch.Tensor) -> None:
tensor_info = self.tensors_info[tensor]
self.data[tensor_info.offset:tensor_info.end].copy_(data_slice.view(-1))
tensor.data = self.data[tensor_info.offset:tensor_info.end].view(tensor.shape)
@property
def can_release(self) -> bool:
for tensor_info in self.tensors_info.values():
if tensor_info.state != TensorState.HOLD:
return False
return True
@property
def can_move_device(self) -> bool:
for tensor_info in self.tensors_info.values():
if tensor_info.state in (TensorState.COMPUTE, TensorState.READY_FOR_REDUCE):
return False
return True
@property
def can_reduce(self) -> bool:
for tensor_info in self.tensors_info.values():
if tensor_info.state != TensorState.READY_FOR_REDUCE:
return False
return True
@property
def is_free(self) -> bool:
return self.data.storage().size() == 0
def __repr__(self) -> str:
return f'Chunk: src rank={self.src_rank}, size={self.size}, utilization={self.utilized_size/self.size*100:.2f}%, freed={self.is_free}, tensor states={[info.state.name for info in self.tensors_info.values()]}'
class ChunkManager:
def __init__(self,
chunk_size: Optional[int],
enable_distributed_storage: bool = False,
init_device: Optional[torch.device] = None) -> None:
assert chunk_size is None or chunk_size > 0
self.chunk_size = chunk_size
self.enable_distributed_storage = enable_distributed_storage
self.device = init_device or get_current_device()
self.chunk_groups: Dict[str, Deque[Chunk]] = {}
self.tensor_chunk_map: Dict[torch.Tensor, Chunk] = {}
self.accessed_chunks: Set[Chunk] = set()
self.lazy_release_tensors: List[torch.Tensor] = []
if enable_distributed_storage and chunk_size is None:
self.rank_load: Dict[str, torch.Tensor] = {}
def append_tensor(self, tensor: torch.Tensor, group_name: str) -> None:
assert tensor not in self.tensor_chunk_map
if self.chunk_size is not None and tensor.numel() > self.chunk_size:
raise ValueError(
f'Cannot create chunk, got tensor numel ({tensor.numel()}) > chunk size ({self.chunk_size})')
if group_name not in self.chunk_groups:
self.chunk_groups[group_name] = deque()
try:
self.chunk_groups[group_name][-1].append(tensor)
except (IndexError, ChunkFullError):
chunk_size = self.chunk_size or tensor.numel()
src_rank = self._get_next_src_rank(group_name)
chunk = Chunk(chunk_size, src_rank, tensor.dtype, self.device)
if self.enable_distributed_storage and self.chunk_size is None:
self.rank_load[group_name][src_rank] += chunk_size
self.chunk_groups[group_name].append(chunk)
chunk.append(tensor)
self.tensor_chunk_map[tensor] = self.chunk_groups[group_name][-1]
if not self.enable_distributed_storage:
self.accessed_chunks.add(self.chunk_groups[group_name][-1])
def _get_next_src_rank(self, group_name: str) -> int:
if not self.enable_distributed_storage:
return gpc.get_local_rank(ParallelMode.DATA)
if self.chunk_size is None:
if group_name not in self.rank_load:
self.rank_load[group_name] = torch.zeros(gpc.get_world_size(ParallelMode.DATA), dtype=torch.int64)
src_rank = torch.argmin(self.rank_load[group_name]).item()
else:
chunk_idx = len(self.chunk_groups[group_name])
src_rank = chunk_idx % gpc.get_world_size(ParallelMode.DATA)
return src_rank
def access_chunk(self, tensor: torch.Tensor) -> None:
chunk = self.tensor_chunk_map[tensor]
if chunk in self.accessed_chunks:
return
chunk.access()
self.accessed_chunks.add(chunk)
def release_chunk(self, tensor: torch.Tensor) -> None:
if not self.enable_distributed_storage:
return
chunk = self.tensor_chunk_map[tensor]
if chunk not in self.accessed_chunks:
return
if chunk.can_release:
chunk.release()
self.accessed_chunks.remove(chunk)
def move_chunk(self, tensor: torch.Tensor, device: torch.device) -> None:
chunk = self.tensor_chunk_map[tensor]
if chunk.can_move_device:
chunk.move_device(device)
def trans_tensor_state(self, tensor: torch.Tensor, state: TensorState) -> None:
chunk = self.tensor_chunk_map[tensor]
chunk.tensor_trans_state(tensor, state)
def reduce_chunk(self, tensor: torch.Tensor) -> None:
chunk = self.tensor_chunk_map[tensor]
if not chunk.can_reduce:
return
chunk.reduce(is_all_reduce=not self.enable_distributed_storage)
def copy_tensor_to_chunk_slice(self, tensor: torch.Tensor, data: torch.Tensor) -> None:
chunk = self.tensor_chunk_map[tensor]
chunk.copy_tensor_to_chunk_slice(tensor, data)
def is_chunk_free(self, tensor: torch.Tensor) -> bool:
chunk = self.tensor_chunk_map[tensor]
return chunk.is_free
def get_chunk(self, tensor: torch.Tensor) -> Chunk:
return self.tensor_chunk_map[tensor]
def add_lazy_release_tensors(self, tensors: List[torch.Tensor]) -> None:
self.lazy_release_tensors.extend(tensors)
def exec_lazy_release(self) -> None:
for tensor in self.lazy_release_tensors:
self.release_chunk(tensor)
self.lazy_release_tensors.clear()
def __repr__(self) -> str:
msg = f'Rank {gpc.get_local_rank(ParallelMode.DATA)}:\n'
for group_name, group in self.chunk_groups.items():
msg += f'Group {group_name}:\n'
for i, chunk in enumerate(group):
msg += f'[{i}] {chunk}\n'
return msg
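
A minimal sketch of the chunk lifecycle implemented above. The chunk size, tensor shapes and group name are illustrative; it assumes the DATA parallel group has been initialized (e.g. via colossalai.launch).

# Sketch: drive ChunkManager by hand to see the tensor-state machine in action.
import torch
from colossalai.tensor import ChunkManager, TensorState

chunk_manager = ChunkManager(chunk_size=2048, enable_distributed_storage=True)
params = [torch.rand(32, 32) for _ in range(3)]
for p in params:
    chunk_manager.append_tensor(p, 'param')    # on non-src ranks the tensor's storage is freed

for p in params:
    chunk_manager.access_chunk(p)              # broadcast chunk from its src rank; FREE -> HOLD
    chunk_manager.trans_tensor_state(p, TensorState.COMPUTE)

for p in params:
    chunk_manager.trans_tensor_state(p, TensorState.HOLD)
    chunk_manager.release_chunk(p)             # frees the chunk once every tensor in it is HOLD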


@@ -3,8 +3,10 @@ from .const import TensorType
import torch
from colossalai.tensor import TensorSpec, distspec
from copy import copy
from .param_op_hook import _ParamOpHookWrapper, PreFwdPostBwd, PostFwdPreBwd
from typing import Optional
class ColoParameter(ColoTensor, torch.nn.Parameter):
r"""A kind of ColoTensor to be considered as a module parameter.
@@ -44,6 +46,22 @@ class ColoParameter(ColoTensor, torch.nn.Parameter):
def __repr__(self):
return f'ColoParameter: {torch.Tensor.__repr__(self)}'
@classmethod
def __torch_function__(cls, func, types, args=..., kwargs=None):
if len(_ParamOpHookWrapper.hooks) > 0:
if not func.__name__.startswith('__'):
params = list(filter(lambda arg: isinstance(arg, ColoParameter), args))
if kwargs is not None:
params.extend(list(filter(lambda arg: isinstance(arg, ColoParameter), kwargs.values())))
if len(params) > 0:
with torch._C.DisableTorchFunction():
args = PreFwdPostBwd.apply(params, *args)
ret = super().__torch_function__(func, types, args, kwargs)
with torch._C.DisableTorchFunction():
ret = PostFwdPreBwd.apply(params, ret)
return ret
return super().__torch_function__(func, types, args, kwargs)
def __deepcopy__(self, memo):
if id(self) in memo:
return memo[id(self)]
@@ -69,4 +87,3 @@ class ColoParameter(ColoTensor, torch.nn.Parameter):
# TODO(jzy) we don't support object reflection now.
# distspec cannot be pickled or rebuilt because it's tightly connected to runtime attribute `process_group`.
raise NotImplementedError


@@ -0,0 +1,71 @@
import torch
from contextlib import contextmanager
from abc import ABC, abstractmethod
from typing import List, Tuple
class ParamOpHook(ABC):
@abstractmethod
def pre_forward(self, params: List[torch.Tensor]) -> None:
pass
@abstractmethod
def post_forward(self, params: List[torch.Tensor]) -> None:
pass
@abstractmethod
def pre_backward(self, params: List[torch.Tensor]) -> None:
pass
@abstractmethod
def post_backward(self, params: List[torch.Tensor]) -> None:
pass
class _ParamOpHookWrapper:
hooks: Tuple[ParamOpHook, ...] = tuple()
class PreFwdPostBwd(torch.autograd.Function):
@staticmethod
def forward(ctx, params, *args):
ctx.params = params
for hook in _ParamOpHookWrapper.hooks:
hook.pre_forward(ctx.params)
if len(args) == 1:
return args[0]
return args
@staticmethod
def backward(ctx, *grads):
for hook in _ParamOpHookWrapper.hooks:
hook.post_backward(ctx.params)
return (None,) + grads
class PostFwdPreBwd(torch.autograd.Function):
@staticmethod
def forward(ctx, params, args):
ctx.params = params
for hook in _ParamOpHookWrapper.hooks:
hook.post_forward(params)
return args
@staticmethod
def backward(ctx, *grads):
for hook in _ParamOpHookWrapper.hooks:
hook.pre_backward(ctx.params)
return (None,) + grads
@contextmanager
def use_param_op_hooks(*hooks: ParamOpHook):
try:
old_param_op_hooks = _ParamOpHookWrapper.hooks
_ParamOpHookWrapper.hooks = hooks
yield
finally:
_ParamOpHookWrapper.hooks = old_param_op_hooks
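
A minimal sketch of a custom ParamOpHook. LoggingHook and the `linear`/`x` names below are hypothetical, purely to illustrate how use_param_op_hooks and ColoParameter.__torch_function__ interact; the real consumer in this commit is ZeROHookV2.

# Sketch: a hypothetical hook that just reports which parameters each op touches.
import torch
from typing import List
from colossalai.tensor import ParamOpHook, use_param_op_hooks

class LoggingHook(ParamOpHook):

    def pre_forward(self, params: List[torch.Tensor]) -> None:
        print(f'pre-forward: {len(params)} param(s)')

    def post_forward(self, params: List[torch.Tensor]) -> None:
        print(f'post-forward: {len(params)} param(s)')

    def pre_backward(self, params: List[torch.Tensor]) -> None:
        print(f'pre-backward: {len(params)} param(s)')

    def post_backward(self, params: List[torch.Tensor]) -> None:
        print(f'post-backward: {len(params)} param(s)')

# The hook only fires for ops whose arguments contain a ColoParameter
# (see ColoParameter.__torch_function__ added in this commit).
with use_param_op_hooks(LoggingHook()):
    out = linear(x)    # assumed: `linear` holds ColoParameter weights, `x` is an input tensor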


@@ -0,0 +1,57 @@
import torch
from colossalai.tensor import ParamOpHook, ChunkManager, TensorState
from enum import Enum
from typing import List
from contextlib import contextmanager
from functools import partialmethod
class TrainingPhase(Enum):
FORWARD = 0
BACKWARD = 1
class ZeROHookV2(ParamOpHook):
def __init__(self, chunk_manager: ChunkManager) -> None:
super().__init__()
self._chunk_manager = chunk_manager
self._training_phase = TrainingPhase.FORWARD
def pre_op(self, params):
for p in params:
self._chunk_manager.trans_tensor_state(p, TensorState.COMPUTE)
self._chunk_manager.exec_lazy_release()
# TODO: evict chunks
for p in params:
self._chunk_manager.access_chunk(p)
def post_op(self, params):
for p in params:
tensor_state = TensorState.HOLD if self._training_phase == TrainingPhase.FORWARD or not p.requires_grad else TensorState.HOLD_AFTER_BWD
self._chunk_manager.trans_tensor_state(p, tensor_state)
self._chunk_manager.add_lazy_release_tensors(params)
def pre_forward(self, params: List[torch.Tensor]) -> None:
self.pre_op(params)
def post_forward(self, params: List[torch.Tensor]) -> None:
self.post_op(params)
def pre_backward(self, params: List[torch.Tensor]) -> None:
self.pre_op(params)
def post_backward(self, params: List[torch.Tensor]) -> None:
self.post_op(params)
@contextmanager
def switch_training_phase(self, training_phase: TrainingPhase = TrainingPhase.BACKWARD):
try:
old_training_phase = self._training_phase
self._training_phase = training_phase
yield
finally:
self._training_phase = old_training_phase
switch_to_backward = switch_training_phase
switch_to_forward = partialmethod(switch_training_phase, training_phase=TrainingPhase.FORWARD)    # partialmethod keeps the `self` binding
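
ColoDDPV2 above is the intended driver of this hook; condensed from the ColoDDPV2 code added in this commit, the flow is roughly:

# Sketch: how ZeROHookV2 is activated around forward and backward (names as in ColoDDPV2).
hook = ZeROHookV2(chunk_manager)

with use_param_op_hooks(hook):                               # forward phase: chunks gathered on access
    outputs = module(*args, **kwargs)
chunk_manager.exec_lazy_release()                            # release chunks queued by post_op

with hook.switch_to_backward(), use_param_op_hooks(hook):    # backward phase: params go to HOLD_AFTER_BWD
    loss.backward()
chunk_manager.exec_lazy_release()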


@@ -0,0 +1,70 @@
import torch
import colossalai
import pytest
import torch.multiprocessing as mp
from typing import List
from functools import partial
from colossalai.tensor import ChunkManager
from colossalai.testing import rerun_if_address_is_in_use, parameterize
from colossalai.utils import free_port
from colossalai.core import global_context as gpc
from colossalai.context import ParallelMode
def check_has_params(params: List[torch.Tensor], has_tensors: List[bool]):
for p, has_tensor in zip(params, has_tensors):
if has_tensor:
assert p.storage().size() > 0
assert p.device.type == 'cuda'
else:
assert p.storage().size() == 0
# HAS_TENSORS[use_chunk][use_zero]
HAS_TENSORS = {
True: {
True: [[True, True, False], [False, False, True]],
False: [[True, True, True], [True, True, True]]
},
False: {
True: [[True, False, True], [False, True, False]],
False: [[True, True, True], [True, True, True]]
}
}
@parameterize('use_chunk', [False, True])
@parameterize('use_zero', [False, True])
def run_chunk_zero(use_chunk, use_zero):
rank = gpc.get_local_rank(ParallelMode.DATA)
if rank == 0:
print(f'use_chunk={use_chunk}, use_zero={use_zero}')
params = [torch.rand(32, 32) for _ in range(3)]
chunk_size = 2048 if use_chunk else None
chunk_manager = ChunkManager(chunk_size, enable_distributed_storage=use_zero)
for p in params:
chunk_manager.append_tensor(p, 'param')
check_has_params(params, HAS_TENSORS[use_chunk][use_zero][rank])
for p in params:
chunk_manager.access_chunk(p)
check_has_params(params, [True, True, True])
for p in params:
chunk_manager.release_chunk(p)
check_has_params(params, HAS_TENSORS[use_chunk][use_zero][rank])
def run_dist(rank, world_size, port):
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
run_chunk_zero()
@pytest.mark.dist
@pytest.mark.parametrize('world_size', [2])
@rerun_if_address_is_in_use()
def test_chunk_mapping(world_size):
run_func = partial(run_dist, world_size=world_size, port=free_port())
mp.spawn(run_func, nprocs=world_size)
if __name__ == '__main__':
test_chunk_mapping(2)


@@ -0,0 +1,80 @@
import pytest
import colossalai
from colossalai.context.parallel_mode import ParallelMode
import torch.multiprocessing as mp
from colossalai.testing import rerun_if_address_is_in_use
from colossalai.utils.cuda import get_current_device
from colossalai.utils import free_port
from colossalai.utils import ColoInitContext
from colossalai.tensor import TensorSpec, ComputePattern, ParallelAction, DistSpecManager, distspec, ColoParameter, ChunkManager
from colossalai.core import global_context as gpc
from functools import partial
from _utils import tensor_equal, tensor_shard_equal, set_seed
from tests.components_to_test.registry import non_distributed_component_funcs
from torch.nn.parallel import DistributedDataParallel as DDP
from colossalai.nn.parallel import ColoDDP, ColoDDPV2
from colossalai.testing import parameterize
def check_param_equal(model, torch_model):
for p, torch_p in zip(model.parameters(), torch_model.parameters()):
if p.storage().size() > 0:
assert tensor_equal(torch_p, p.float()), f'{torch_p} vs {p}'
def check_grad_equal(model, torch_model):
for p, torch_p in zip(model.parameters(), torch_model.parameters()):
if p.grad is not None:
assert tensor_equal(torch_p.grad, p.grad.float())
@parameterize('use_chunk', [False, True])
@parameterize('use_zero', [False, True])
def run_gpt(use_chunk, use_zero):
set_seed(42)
get_components_func = non_distributed_component_funcs.get_callable('gpt2')
model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
with ColoInitContext(device=get_current_device()):
model = model_builder(checkpoint=True)
model = model.cuda()
torch_model = model_builder().cuda()
for torch_p, p in zip(torch_model.parameters(), model.parameters()):
torch_p.data.copy_(p)
model = model.half()
chunk_size = 38 * 1024**2 if use_chunk else None
chunk_manager = ChunkManager(chunk_size, enable_distributed_storage=use_zero)
model = ColoDDPV2(model, chunk_manager)
torch_model = DDP(torch_model, device_ids=[gpc.get_global_rank()], process_group=gpc.get_group(ParallelMode.DATA))
print(chunk_manager)
check_param_equal(model, torch_model)
model.train()
torch_model.train()
set_seed(gpc.get_local_rank(ParallelMode.DATA))
for i, (input_ids, attn_mask) in enumerate(train_dataloader):
logits = model(input_ids, attn_mask)
torch_logits = torch_model(input_ids, attn_mask)
assert tensor_equal(torch_logits, logits.float())
loss = criterion(logits, input_ids)
torch_loss = criterion(torch_logits, input_ids)
model.backward(loss)
torch_loss.backward()
check_grad_equal(model, torch_model)
break
def run_dist(rank, world_size, port):
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
run_gpt()
@pytest.mark.dist
@pytest.mark.parametrize('world_size', [1, 4])
@rerun_if_address_is_in_use()
def test_gpt(world_size):
run_func = partial(run_dist, world_size=world_size, port=free_port())
mp.spawn(run_func, nprocs=world_size)
if __name__ == '__main__':
test_gpt(4)