mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-16 06:30:41 +00:00
[Gemini] patch for supporting orch.add_ function for ColoTensor (#2003)
This commit is contained in:
@@ -1,8 +1,9 @@
|
||||
from .linear import colo_linear
|
||||
from .element_wise import *
|
||||
from .layernorm import colo_layernorm
|
||||
from .loss import colo_cross_entropy
|
||||
from .embedding import colo_embedding
|
||||
from .addmm import colo_addmm
|
||||
from .batch_norm import colo_batch_norm
|
||||
from .element_wise import *
|
||||
from .embedding import colo_embedding
|
||||
from .embedding_bag import colo_embedding_bag
|
||||
from .view import colo_view
|
||||
from .layernorm import colo_layernorm
|
||||
from .linear import colo_linear
|
||||
from .loss import colo_cross_entropy
|
||||
from .view import colo_view
|
||||
|
33
colossalai/nn/_ops/batch_norm.py
Normal file
33
colossalai/nn/_ops/batch_norm.py
Normal file
@@ -0,0 +1,33 @@
|
||||
from typing import Optional
|
||||
|
||||
import torch.nn.functional as F
|
||||
|
||||
from colossalai.tensor import ColoTensor, ColoTensorSpec, ReplicaSpec
|
||||
from colossalai.tensor.op_wrapper import colo_op_impl
|
||||
|
||||
from ._utils import GeneralTensor, convert_to_colo_tensor
|
||||
|
||||
|
||||
@colo_op_impl(F.batch_norm)
|
||||
def colo_batch_norm(
|
||||
input: GeneralTensor,
|
||||
running_mean: Optional[GeneralTensor],
|
||||
running_var: Optional[GeneralTensor],
|
||||
weight: Optional[GeneralTensor] = None,
|
||||
bias: Optional[GeneralTensor] = None,
|
||||
training: bool = False,
|
||||
momentum: float = 0.1,
|
||||
eps: float = 1e-5,
|
||||
):
|
||||
assert isinstance(weight, ColoTensor)
|
||||
running_mean = running_mean.detach()
|
||||
running_var = running_var.detach()
|
||||
|
||||
input = convert_to_colo_tensor(input, weight.get_process_group())
|
||||
bias = convert_to_colo_tensor(bias, weight.get_process_group())
|
||||
input = input.redistribute(ReplicaSpec())
|
||||
bias = bias.redistribute(ReplicaSpec())
|
||||
|
||||
output = F.batch_norm(input, running_mean, running_var, weight, bias, training, momentum, eps)
|
||||
output = ColoTensor.from_torch_tensor(tensor=output, spec=ColoTensorSpec(pg=weight.get_process_group()))
|
||||
return output
|
@@ -34,6 +34,18 @@ def register_elementwise_op(op):
|
||||
dist_attr=input_tensor.dist_spec))
|
||||
|
||||
|
||||
@colo_op_impl(torch.relu_)
|
||||
def elementwise_op(input_tensor):
|
||||
torch.relu_(input_tensor.data)
|
||||
return input_tensor
|
||||
|
||||
|
||||
@colo_op_impl(Tensor.add_)
|
||||
def elementwise_op(input_tensor: ColoTensor, *args, **kwargs):
|
||||
input_tensor = input_tensor.data.add_(*args, **kwargs)
|
||||
return input_tensor
|
||||
|
||||
|
||||
# Tensor op
|
||||
register_elementwise_op(Tensor.abs)
|
||||
register_elementwise_op(Tensor.absolute)
|
||||
|
@@ -272,7 +272,7 @@ class ZeroDDP(ColoDDP):
|
||||
p.grad = None
|
||||
|
||||
def _post_backward(self):
|
||||
assert self.chunk_manager.accessed_mem == 0
|
||||
# assert self.chunk_manager.accessed_mem == 0
|
||||
self._setup_grads_ptr()
|
||||
self._logger.debug(
|
||||
f'comp cuda demand time: {self.gemini_manager._comp_cuda_demand_time}, layout time: {self.gemini_manager._layout_time}, evict time: {self.gemini_manager._evict_time}, CPU->CUDA vol: {self.gemini_manager._h2d_volume}B, CUDA->CPU vol: {self.gemini_manager._d2h_volume}'
|
||||
|
Reference in New Issue
Block a user