[refactor] memory utils (#577)

This commit is contained in:
Jiarui Fang
2022-04-01 09:22:33 +08:00
committed by GitHub
parent 104cbbb313
commit e956d93ac2
15 changed files with 261 additions and 202 deletions

View File

@@ -29,6 +29,7 @@ class MoeGradientHandler(BaseGradientHandler):
if global_data > 1:
epsize_param_dict = get_moe_epsize_param_dict(self._model)
# epsize is 1, indicating the params are replicated among processes in data parallelism
# use the ParallelMode.DATA to get data parallel group
# reduce gradients for all parameters in data parallelism

View File

@@ -10,8 +10,7 @@ from colossalai.zero.sharded_param.tensorful_state import TensorState
from ._base_ophook import BaseOpHook
-from colossalai.utils.memory_utils.utils import \
-    colo_model_data_tensor_move_inline
+from colossalai.zero.shard_utils.tensor_utils import colo_model_data_tensor_move_inline
@OPHOOKS.register_module