Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-07 20:10:17 +00:00
[legacy] move communication and nn to legacy and refactor logger (#4671)
* [legacy] move communication to legacy (#4640)
* [legacy] refactor logger and clean up legacy codes (#4654)
* [legacy] make logger independent of gpc
* [legacy] make optim independent of registry
* [legacy] move test engine to legacy
* [legacy] move nn to legacy (#4656)
* [legacy] move nn to legacy
* [checkpointio] fix save hf config
* [test] remove useless rpc pp test
* [legacy] fix nn init
* [example] skip tutorial hybrid parallel example
* [devops] test doc check
* [devops] test doc check
92  colossalai/legacy/nn/layer/utils/common.py  Normal file
@@ -0,0 +1,92 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-

import collections.abc
from itertools import repeat

import numpy as np
import torch
from torch import Tensor, nn

from colossalai.constants import IS_TENSOR_PARALLEL, NUM_PARTITIONS
from colossalai.global_variables import tensor_parallel_env as env
from colossalai.utils import checkpoint


class CheckpointModule(nn.Module):

    def __init__(self, checkpoint: bool = True, offload: bool = False):
        super().__init__()
        self.checkpoint = checkpoint
        self._use_checkpoint = checkpoint
        self._offload = offload

    def _forward(self, *args, **kwargs):
        raise NotImplementedError('CheckpointModule should implement _forward method instead of origin forward')

    def forward(self, *args, **kwargs):
        if self._use_checkpoint:
            return checkpoint(self._forward, self._offload, *args, **kwargs)
        else:
            return self._forward(*args, **kwargs)

    def train(self, mode: bool = True):
        self._use_checkpoint = self.checkpoint
        return super().train(mode=mode)

    def eval(self):
        self._use_checkpoint = False
        return super().eval()


def divide(numerator, denominator):
    """Only allow exact division.

    Args:
        numerator (int): Numerator of the division.
        denominator (int): Denominator of the division.

    Returns:
        int: the result of exact division.
    """
    assert denominator != 0, 'denominator can not be zero'
    assert numerator % denominator == 0, \
        '{} is not divisible by {}'.format(numerator, denominator)
    return numerator // denominator


def swish(x: Tensor) -> Tensor:
    return x * torch.sigmoid(x)


ACT2FN = {"gelu": torch.nn.functional.gelu, "relu": torch.nn.functional.relu, "swish": swish}


def set_tensor_parallel_attribute_by_size(param, size):
    setattr(param, IS_TENSOR_PARALLEL, True)
    setattr(param, NUM_PARTITIONS, size // np.prod(param.shape))


def set_tensor_parallel_attribute_by_partition(param, num_partitions):
    setattr(param, IS_TENSOR_PARALLEL, True)
    setattr(param, NUM_PARTITIONS, num_partitions)


def get_tensor_parallel_mode():
    return env.mode


# From PyTorch internals


def _ntuple(n):

    def parse(x):
        if isinstance(x, collections.abc.Iterable):
            return x
        return tuple(repeat(x, n))

    return parse


to_2tuple = _ntuple(2)
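Below is a minimal usage sketch, not part of the commit, showing how the helpers in this file fit together: a block subclasses CheckpointModule and implements only _forward, train()/eval() toggle activation checkpointing, and divide/to_2tuple/set_tensor_parallel_attribute_by_partition behave as defined above. The class name TinyMLP and the import path (derived from the new file's location) are assumptions for illustration only.

import torch
from torch import nn

# Import path assumed from the file added in this commit.
from colossalai.legacy.nn.layer.utils.common import (CheckpointModule, divide,
                                                      set_tensor_parallel_attribute_by_partition,
                                                      to_2tuple)


class TinyMLP(CheckpointModule):
    # Subclasses implement _forward; the base class decides whether to run it
    # under activation checkpointing (see forward() in the file above).

    def __init__(self, dim: int, checkpoint: bool = True):
        super().__init__(checkpoint=checkpoint)
        self.fc1 = nn.Linear(dim, 4 * dim)
        self.fc2 = nn.Linear(4 * dim, dim)

    def _forward(self, x):
        return self.fc2(torch.relu(self.fc1(x)))


block = TinyMLP(dim=64).train()                       # train() restores the configured checkpoint flag
y = block(torch.randn(2, 64, requires_grad=True))     # routed through colossalai.utils.checkpoint

block.eval()                                          # eval() disables checkpointing
y = block(torch.randn(2, 64))                         # plain _forward call

# Small helpers from the same file.
assert divide(256, 64) == 4                           # raises AssertionError if not exactly divisible
assert to_2tuple(3) == (3, 3)                         # scalars broadcast to pairs; iterables pass through

# Tag a parameter for tensor parallelism (sets IS_TENSOR_PARALLEL / NUM_PARTITIONS attributes).
p = nn.Parameter(torch.empty(8, 8))
set_tensor_parallel_attribute_by_partition(p, num_partitions=4)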