ColossalAI/colossalai/legacy/nn/layer/utils/common.py

#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import collections.abc
from itertools import repeat

import numpy as np
import torch
from torch import Tensor, nn

from colossalai.constants import IS_TENSOR_PARALLEL, NUM_PARTITIONS
from colossalai.global_variables import tensor_parallel_env as env
from colossalai.utils import checkpoint


class CheckpointModule(nn.Module):
    """Base module for layers that support activation checkpointing.

    Subclasses implement ``_forward``; ``forward`` wraps it with activation
    checkpointing (optionally offloading activations to CPU) while the module
    is in training mode, and calls it directly in eval mode.
    """

    def __init__(self, checkpoint: bool = True, offload: bool = False):
        super().__init__()
        self.checkpoint = checkpoint
        self._use_checkpoint = checkpoint
        self._offload = offload

    def _forward(self, *args, **kwargs):
        raise NotImplementedError('CheckpointModule should implement _forward method instead of the original forward')

    def forward(self, *args, **kwargs):
        if self._use_checkpoint:
            return checkpoint(self._forward, self._offload, *args, **kwargs)
        else:
            return self._forward(*args, **kwargs)

    def train(self, mode: bool = True):
        self._use_checkpoint = self.checkpoint
        return super().train(mode=mode)

    def eval(self):
        self._use_checkpoint = False
        return super().eval()
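
# Illustrative usage sketch (not part of the original file): subclasses are expected to
# implement `_forward` only; `forward` then routes through activation checkpointing while
# the module is in training mode. `MyBlock` and the sizes below are hypothetical.
#
#     class MyBlock(CheckpointModule):
#         def __init__(self, dim: int, checkpoint: bool = True):
#             super().__init__(checkpoint=checkpoint)
#             self.proj = nn.Linear(dim, dim)
#
#         def _forward(self, x):
#             return self.proj(x)
#
#     block = MyBlock(128).train()      # checkpointing active in train mode
#     out = block(torch.randn(4, 128))  # block.eval() would disable checkpointing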


def divide(numerator, denominator):
    """Only allow exact division.

    Args:
        numerator (int): Numerator of the division.
        denominator (int): Denominator of the division.

    Returns:
        int: The result of exact division.
    """
    assert denominator != 0, 'denominator can not be zero'
    assert numerator % denominator == 0, \
        '{} is not divisible by {}'.format(numerator, denominator)
    return numerator // denominator
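
# Example for clarity: divide(12, 4) returns 3, while divide(10, 4) raises an
# AssertionError because 10 is not exactly divisible by 4.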


def swish(x: Tensor) -> Tensor:
    return x * torch.sigmoid(x)


ACT2FN = {"gelu": torch.nn.functional.gelu, "relu": torch.nn.functional.relu, "swish": swish}
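
# Illustrative usage (the config field name below is hypothetical): activations are
# looked up by name, e.g. `act_fn = ACT2FN[config.hidden_act]`, and then applied
# element-wise, e.g. `ACT2FN["swish"](torch.randn(2, 8))`.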


def set_tensor_parallel_attribute_by_size(param, size):
    setattr(param, IS_TENSOR_PARALLEL, True)
    setattr(param, NUM_PARTITIONS, size // np.prod(param.shape))


def set_tensor_parallel_attribute_by_partition(param, num_partitions):
    setattr(param, IS_TENSOR_PARALLEL, True)
    setattr(param, NUM_PARTITIONS, num_partitions)
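
# Illustrative sketch (hypothetical shard sizes): both helpers tag a parameter so that
# downstream code can recognize tensor-parallel sharding via IS_TENSOR_PARALLEL and
# NUM_PARTITIONS.
#
#     weight = nn.Parameter(torch.empty(256, 1024))             # local shard
#     set_tensor_parallel_attribute_by_partition(weight, 4)     # 4-way split
#     # or, equivalently, via the full (unsharded) parameter size:
#     set_tensor_parallel_attribute_by_size(weight, 4 * 256 * 1024)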


def get_tensor_parallel_mode():
    return env.mode


# From PyTorch internals
def _ntuple(n):

    def parse(x):
        if isinstance(x, collections.abc.Iterable):
            return x
        return tuple(repeat(x, n))

    return parse


to_2tuple = _ntuple(2)
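
# Illustrative usage: to_2tuple normalizes scalar or iterable arguments, the usual
# pattern for kernel-size/stride handling in vision layers.
#
#     to_2tuple(3)        # -> (3, 3)
#     to_2tuple((3, 5))   # -> (3, 5)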