mirror of https://github.com/hpcaitech/ColossalAI.git

[format] Run lint on colossalai.engine (#3367)
@@ -5,7 +5,7 @@ from abc import ABC, abstractmethod


 class BaseGradientHandler(ABC):
-    """A basic helper class to handle all-reduce operations of gradients across different parallel groups
+    """A basic helper class to handle all-reduce operations of gradients across different parallel groups
     before optimization.

     Args:
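The docstring above is essentially the whole contract: a handler receives the model and optimizer at construction time and synchronizes gradients in handle_gradient() before the optimizer step. A minimal sketch of a custom handler built on that contract follows; the class name, the naive per-parameter loop, and the self._model attribute are illustrative assumptions, not code from this commit.

    import torch.distributed as dist

    from colossalai.registry import GRADIENT_HANDLER
    from colossalai.engine.gradient_handler._base_gradient_handler import BaseGradientHandler


    @GRADIENT_HANDLER.register_module
    class NaiveAllReduceGradientHandler(BaseGradientHandler):
        """Hypothetical handler: averages every gradient over the default group."""

        def handle_gradient(self):
            # BaseGradientHandler is assumed to keep the model as self._model.
            # Each rank contributes its local gradient, then we average.
            world_size = dist.get_world_size()
            for param in self._model.parameters():
                if param.grad is not None:
                    dist.all_reduce(param.grad)
                    param.grad.div_(world_size)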
@@ -1,16 +1,17 @@
 from colossalai.core import global_context as gpc
 from colossalai.registry import GRADIENT_HANDLER
-from ._base_gradient_handler import BaseGradientHandler
+
 from ...context.parallel_mode import ParallelMode
+from ._base_gradient_handler import BaseGradientHandler
 from .utils import bucket_allreduce


 @GRADIENT_HANDLER.register_module
 class DataParallelGradientHandler(BaseGradientHandler):
     """A helper class to handle all-reduce operations in a data parallel group.
-    A all-reduce collective communication will be operated in
+    A all-reduce collective communication will be operated in
     :func:`handle_gradient` among a data parallel group.
-    For better performance, it bucketizes the gradients of all parameters that are
+    For better performance, it bucketizes the gradients of all parameters that are
     the same type to improve the efficiency of communication.

     Args:
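The bucketing the docstring promises is straightforward to sketch in plain PyTorch: collect same-typed gradients, flatten them into one contiguous buffer, run a single all-reduce, and scatter the results back. The sketch below illustrates the idea only; it is not the body of the bucket_allreduce imported above, and the function name and averaging order are assumptions.

    from collections import defaultdict

    import torch.distributed as dist
    from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors


    def bucket_allreduce_sketch(params, group=None):
        # Bucket gradients by tensor type so each flattened buffer is homogeneous.
        buckets = defaultdict(list)
        for param in params:
            if param.requires_grad and param.grad is not None:
                buckets[param.grad.data.type()].append(param.grad.data)

        # One flatten, one collective, and one scatter-back per bucket.
        for bucket in buckets.values():
            coalesced = _flatten_dense_tensors(bucket)
            coalesced /= dist.get_world_size(group=group)
            dist.all_reduce(coalesced, group=group)
            for grad, synced in zip(bucket, _unflatten_dense_tensors(coalesced, bucket)):
                grad.copy_(synced)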
@@ -4,9 +4,10 @@ from collections import defaultdict

 import torch
 import torch.distributed as dist
+from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors

 from colossalai.core import global_context as gpc
 from colossalai.registry import GRADIENT_HANDLER
-from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors

 from ._base_gradient_handler import BaseGradientHandler
@@ -14,9 +15,9 @@ from ._base_gradient_handler import BaseGradientHandler
 @GRADIENT_HANDLER.register_module
 class PipelineSharedModuleGradientHandler(BaseGradientHandler):
     """A helper class to handle all-reduce operations in sub parallel groups.
-    A all-reduce collective communication will be operated in
+    A all-reduce collective communication will be operated in
     :func:`handle_gradient` among all sub pipeline parallel groups.
-    For better performance, it bucketizes the gradients of all parameters that are
+    For better performance, it bucketizes the gradients of all parameters that are
     the same type to improve the efficiency of communication.

     Args:
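The shared-module handler covers the tied-weight case, for example an embedding layer shared by the first and last pipeline stage: each stage's copy accumulates only a partial gradient, and an all-reduce over the group holding the copies restores a consistent gradient everywhere. A minimal sketch of that idea, with the group handle shared_group as an assumed stand-in for however the shared module's process group is tracked:

    import torch.distributed as dist


    def sync_shared_module_grads(shared_module, shared_group):
        # Each pipeline stage holding a copy of the tied module has only a
        # partial gradient; summing over the group restores the full gradient
        # on every copy before the optimizer step.
        for param in shared_module.parameters():
            if param.grad is not None:
                dist.all_reduce(param.grad, op=dist.ReduceOp.SUM, group=shared_group)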
@@ -1,16 +1,17 @@
 from colossalai.core import global_context as gpc
 from colossalai.registry import GRADIENT_HANDLER
-from ._base_gradient_handler import BaseGradientHandler
+
 from ...context.parallel_mode import ParallelMode
+from ._base_gradient_handler import BaseGradientHandler
 from .utils import bucket_allreduce


 @GRADIENT_HANDLER.register_module
 class SequenceParallelGradientHandler(BaseGradientHandler):
     """A helper class to handle all-reduce operations in a data parallel group.
-    A all-reduce collective communication will be operated in
+    A all-reduce collective communication will be operated in
     :func:`handle_gradient` among a data parallel group.
-    For better performance, it bucketizes the gradients of all parameters that are
+    For better performance, it bucketizes the gradients of all parameters that are
     the same type to improve the efficiency of communication.

     Args:
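For context, @GRADIENT_HANDLER.register_module is what lets a handler be named in a user config: ColossalAI configs of this era select handlers by type string, roughly as sketched below. This is a hedged illustration of the convention, not something shown in this commit.

    # config.py (hypothetical user config selecting a registered handler by name)
    gradient_handler = [dict(type='SequenceParallelGradientHandler')]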
@@ -1,4 +1,5 @@
 from colossalai.registry import GRADIENT_HANDLER
+
 from ._base_gradient_handler import BaseGradientHandler
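Every hunk in this commit is a whitespace or import-order change; the handlers' behavior is untouched. For orientation, the call site the docstrings allude to ("before optimization") looks roughly like the following sketch. The loop shape and function signature are assumptions, not code from this diff.

    def train_step(model, criterion, optimizer, gradient_handlers, inputs, targets):
        optimizer.zero_grad()
        loss = criterion(model(inputs), targets)
        loss.backward()                    # each rank computes local gradients
        for handler in gradient_handlers:  # e.g. [DataParallelGradientHandler(model, optimizer)]
            handler.handle_gradient()      # synchronize gradients across ranks
        optimizer.step()                   # step on the now-consistent gradients
        return loss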