Mirror of https://github.com/hpcaitech/ColossalAI.git
[autoparallel] support distributed dataloader option (#1906)
* [autoparallel] support distributed dataloader option
* update output handler to support ddp dataloader
* polish code
In the solver options module (`options.py`):

@@ -1,11 +1,30 @@
 from dataclasses import dataclass
+from enum import Enum
 
 __all__ = ['SolverOptions']
 
 
+class SolverPerference(Enum):
+    """
+    This enum class is to define the solver preference.
+    """
+    STANDARD = 0
+    DP = 1
+    TP = 2
+
+
+class DataloaderOption(Enum):
+    """
+    This enum class is to define the dataloader option.
+    """
+    REPLICATED = 0
+    DISTRIBUTED = 1
+
+
 @dataclass
 class SolverOptions:
     """
     SolverOptions is a dataclass used to configure the preferences for the parallel execution plan search.
     """
     fast: bool = False
+    solver_perference: SolverPerference = SolverPerference.STANDARD
+    dataloader_option: DataloaderOption = DataloaderOption.REPLICATED
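For context, a minimal usage sketch of the new option. The absolute import path below is an assumption (the constructor only uses the relative `from .options import ...`); everything else comes straight from the hunk above.

```python
# Sketch only: the absolute import path is an assumption based on the package
# layout implied by the diff, not something this commit defines.
from colossalai.auto_parallel.tensor_shard.solver.options import DataloaderOption, SolverOptions

# Default behaviour: inputs/outputs are assumed to be replicated on every device.
default_options = SolverOptions()
assert default_options.dataloader_option == DataloaderOption.REPLICATED

# New behaviour: tell the solver the dataloader already shards the batch (DDP-style),
# so placeholder and output nodes should be treated as distributed tensors.
ddp_options = SolverOptions(dataloader_option=DataloaderOption.DISTRIBUTED)
```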
In the `StrategiesConstructor` module:

@@ -6,15 +6,16 @@ from typing import Dict, List
 import torch
 from torch.fx import Graph, Node
 
-from colossalai.auto_parallel.tensor_shard.node_handler import OuputHandler, PlacehodlerHandler, operator_registry
-from colossalai.auto_parallel.tensor_shard.node_handler.getatrr_handler import GetattrHandler
-from colossalai.auto_parallel.tensor_shard.sharding_strategy import ShardingStrategy, StrategiesVector
-from colossalai.auto_parallel.tensor_shard.utils import generate_resharding_costs, generate_sharding_spec
+from colossalai.auto_parallel.tensor_shard.node_handler import (
+    GetattrHandler,
+    OuputHandler,
+    PlacehodlerHandler,
+    operator_registry,
+)
+from colossalai.auto_parallel.tensor_shard.sharding_strategy import StrategiesVector
 from colossalai.device.device_mesh import DeviceMesh
 from colossalai.tensor.shape_consistency import ShapeConsistencyManager
 from colossalai.tensor.sharding_spec import ShardingSpec
 
-from .options import SolverOptions
+from .options import DataloaderOption, SolverOptions
 
 __all__ = ['StrategiesConstructor']
@@ -67,7 +68,15 @@ class StrategiesConstructor:
             strategies_vector = StrategiesVector(node)
             # placeholder node
             if node.op == 'placeholder':
-                placeholder_handler = PlacehodlerHandler(node, self.device_mesh, strategies_vector)
+                if self.solver_options.dataloader_option == DataloaderOption.DISTRIBUTED:
+                    placeholder_option = 'distributed'
+                else:
+                    assert self.solver_options.dataloader_option == DataloaderOption.REPLICATED, f'placeholder_option {self.solver_options.dataloader_option} is not supported'
+                    placeholder_option = 'replicated'
+                placeholder_handler = PlacehodlerHandler(node,
+                                                         self.device_mesh,
+                                                         strategies_vector,
+                                                         placeholder_option=placeholder_option)
                 placeholder_handler.register_strategy()
 
             # get_attr node
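In both the placeholder and the output branches, the constructor merely translates the `DataloaderOption` enum into the string option ('distributed' or 'replicated') that the node handlers accept, and rejects anything else. A standalone sketch of that dispatch follows, using the hypothetical helper name `dataloader_option_to_str` (the commit itself repeats the branch inline):

```python
# Import path assumed, as in the earlier sketch.
from colossalai.auto_parallel.tensor_shard.solver.options import DataloaderOption


def dataloader_option_to_str(option: DataloaderOption) -> str:
    """Illustrative helper mirroring the inline branch above (not part of the commit)."""
    if option == DataloaderOption.DISTRIBUTED:
        return 'distributed'
    # Any value other than REPLICATED is rejected, matching the assert in the diff.
    assert option == DataloaderOption.REPLICATED, f'dataloader option {option} is not supported'
    return 'replicated'
```

Factoring the check out like this would also avoid the duplicated branch in the output-node case below, whose assert message still reads 'placeholder_option'.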
@@ -97,7 +106,12 @@ class StrategiesConstructor:
 
             # output node
             elif node.op == 'output':
-                output_handler = OuputHandler(node, self.device_mesh, strategies_vector)
+                if self.solver_options.dataloader_option == DataloaderOption.DISTRIBUTED:
+                    output_option = 'distributed'
+                else:
+                    assert self.solver_options.dataloader_option == DataloaderOption.REPLICATED, f'placeholder_option {self.solver_options.dataloader_option} is not supported'
+                    output_option = 'replicated'
+                output_handler = OuputHandler(node, self.device_mesh, strategies_vector, output_option=output_option)
                 output_handler.register_strategy()
 
             if len(strategies_vector) <= 0:
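Putting the pieces together, a hedged end-to-end sketch of how the new option reaches `StrategiesConstructor`. The `DeviceMesh` arguments, the constructor signature, the import paths, and `build_strategies_and_cost()` are assumptions based on how these classes are used elsewhere in the repository; only the `SolverOptions` wiring is defined by this commit.

```python
import torch
from torch.fx import symbolic_trace

from colossalai.device.device_mesh import DeviceMesh
# Import paths below are assumptions about the package layout, not part of this diff.
from colossalai.auto_parallel.tensor_shard.solver.options import DataloaderOption, SolverOptions
from colossalai.auto_parallel.tensor_shard.solver.strategies_constructor import StrategiesConstructor

model = torch.nn.Sequential(torch.nn.Linear(16, 16), torch.nn.ReLU())
graph = symbolic_trace(model).graph          # torch.fx Graph consumed by the constructor

# Four devices arranged as a 2x2 logical mesh (constructor arguments assumed).
device_mesh = DeviceMesh(torch.arange(4), (2, 2))

# With DISTRIBUTED, placeholder and output nodes receive the 'distributed' option,
# so the solver plans for a batch that the dataloader has already sharded.
solver_options = SolverOptions(dataloader_option=DataloaderOption.DISTRIBUTED)
constructor = StrategiesConstructor(graph, device_mesh, solver_options)
constructor.build_strategies_and_cost()      # method name assumed from the repo
```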