[Tensor] get named parameters for model using ColoTensors (#874)

This commit is contained in:
Jiarui Fang
2022-04-26 14:08:01 +08:00
committed by GitHub
parent 2883040286
commit e43f83aa5c
3 changed files with 59 additions and 3 deletions

View File

@@ -2,8 +2,10 @@ from .spec import ComputePattern, ParallelAction, TensorSpec
from .op_wrapper import (
colo_op_impl,)
from .colo_tensor import ColoTensor
from .utils import convert_parameter
from .utils import convert_parameter, named_params_with_colotensor
from ._ops import *
__all__ = ['ColoTensor', 'convert_parameter', 'colo_op_impl', 'ComputePattern',
'TensorSpec', 'ParallelAction']
__all__ = [
'ColoTensor', 'convert_parameter', 'colo_op_impl', 'ComputePattern', 'TensorSpec', 'ParallelAction',
'named_params_with_colotensor'
]