[tensor] refine linear and add gather for laynorm (#893)

* refine linear and add function to ColoTensor

* add gather for layernorm

* polish

* polish
This commit is contained in:
Ziyue Jiang
2022-04-28 10:55:40 +08:00
committed by GitHub
parent 26c49639d8
commit cb182da7c5
7 changed files with 225 additions and 123 deletions

View File

@@ -1,4 +1,4 @@
from .spec import ComputePattern, ParallelAction, TensorSpec
from .spec import ComputePattern, ParallelAction, TensorSpec, ShardPattern
from .op_wrapper import (
colo_op_impl,)
from .colo_tensor import ColoTensor
@@ -7,5 +7,5 @@ from ._ops import *
__all__ = [
'ColoTensor', 'convert_parameter', 'colo_op_impl', 'ComputePattern', 'TensorSpec', 'ParallelAction',
'named_params_with_colotensor'
'named_params_with_colotensor', 'ShardPattern'
]