Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-06 03:20:52 +00:00)
[legacy] clean up legacy code (#4743)
* [legacy] remove outdated codes of pipeline (#4692)
* [legacy] remove cli of benchmark and update optim (#4690)
* [legacy] remove cli of benchmark and update optim
* [doc] fix cli doc test
* [legacy] fix engine clip grad norm
* [legacy] remove outdated colo tensor (#4694)
* [legacy] remove outdated colo tensor
* [test] fix test import
* [legacy] move outdated zero to legacy (#4696)
* [legacy] clean up utils (#4700)
* [legacy] clean up utils
* [example] update examples
* [legacy] clean up amp
* [legacy] fix amp module
* [legacy] clean up gpc (#4742)
* [legacy] clean up context
* [legacy] clean core, constants and global vars
* [legacy] refactor initialize
* [example] fix examples ci
* [example] fix examples ci
* [legacy] fix tests
* [example] fix gpt example
* [example] fix examples ci
* [devops] fix ci installation
* [example] fix examples ci
colossalai/legacy/tensor/op_wrapper.py (new file, 50 lines)
@@ -0,0 +1,50 @@
import functools
from typing import Callable, Dict

# Custom sharded ops
_COLOSSAL_OPS: Dict[str, Callable] = {}


def _register_colo_op(op, func):
    global _COLOSSAL_OPS
    _COLOSSAL_OPS[op] = func


def colo_op_impl(func):
    """
    Provides a way for users to write their own custom operator. This
    can be used to override existing ColoTensor operators or write a new
    one not supported by ColoTensor. If the operator in question is covered
    by ``__torch_function__`` dispatch and has a ColoTensor as any of its
    parameters, the function provided will be invoked for that operator.

    Example:
        >>> @colo_op_impl(torch.nn.functional.linear)
        >>> def my_custom_linear(types, args, kwargs, process_group):
        >>>     ....
        >>>
        >>> input = torch.rand(10, 32)
        >>> weight = ColoTensor(torch.rand(32, 16))
        >>> bias = ColoTensor(torch.rand(16))
        >>> # This will call `my_custom_linear` instead of the default.
        >>> torch.nn.functional.linear(input, weight, bias)

    The types, args and kwargs parameters are the same parameters that are
    passed to ``__torch_function__`` dispatch API
    (https://pytorch.org/docs/stable/notes/extending.html#extending-torch).

    Args:
        func(Callable): Torch function for which we want to provide a sharded
            implementation (ex: torch.nn.functional.linear)
    """

    def decorator_sharded_func(wrapped_func):
        _register_colo_op(func, wrapped_func)

        @functools.wraps(wrapped_func)
        def wrapper(*args, **kwargs):
            return wrapped_func(*args, **kwargs)

        return wrapper

    return decorator_sharded_func
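For context, a minimal sketch of how this registry can be exercised, assuming the import path of the file added in this diff. The registration and the direct call follow the docstring above; the `_DemoColoTensor` subclass and `my_custom_linear` override are hypothetical names used only to illustrate how a `__torch_function__` hook could consult `_COLOSSAL_OPS`, and are not the actual ColoTensor implementation (which is not part of this diff).

import torch

# Import path assumed from the file added in this diff.
from colossalai.legacy.tensor.op_wrapper import _COLOSSAL_OPS, colo_op_impl


@colo_op_impl(torch.nn.functional.linear)
def my_custom_linear(types, args, kwargs, process_group):
    # Toy override: a plain dense linear, just to show that the hook fires.
    input_, weight = args[0], args[1]
    bias = args[2] if len(args) > 2 else None
    out = torch.matmul(input_, weight.t())
    return (out + bias) if bias is not None else out


# The decorator registered the override keyed by the original torch function.
assert torch.nn.functional.linear in _COLOSSAL_OPS


class _DemoColoTensor(torch.Tensor):
    """Hypothetical stand-in for ColoTensor; only the dispatch hook is sketched."""

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        if func in _COLOSSAL_OPS:
            # Route to the registered implementation (no process group in this sketch).
            return _COLOSSAL_OPS[func](types, args, kwargs, None)
        return super().__torch_function__(func, types, args, kwargs)


x = torch.rand(10, 32)
w = torch.rand(16, 32).as_subclass(_DemoColoTensor)
b = torch.rand(16).as_subclass(_DemoColoTensor)
y = torch.nn.functional.linear(x, w, b)  # dispatches to my_custom_linear
print(y.shape)  # torch.Size([10, 16])

Note that `colo_op_impl` registers the undecorated function and returns a thin `functools.wraps` wrapper, so the decorated name remains callable directly with the `(types, args, kwargs, process_group)` signature.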