[autoparallel] support distributed dataloader option (#1906)

* [autoparallel] support distributed dataloader option

* update output handler to support ddp dataloader

* polish code
YuliangLiu0306
2022-11-17 20:11:53 +08:00
committed by GitHub
parent 6630d45546
commit 0da1d00399
18 changed files with 257 additions and 61 deletions


@@ -1,11 +1,30 @@
from dataclasses import dataclass
from enum import Enum

__all__ = ['SolverOptions']


class SolverPerference(Enum):
    """
    This enum class defines the solver preference.
    """
    STANDARD = 0
    DP = 1
    TP = 2


class DataloaderOption(Enum):
    """
    This enum class defines the dataloader option.
    """
    REPLICATED = 0
    DISTRIBUTED = 1


@dataclass
class SolverOptions:
    """
    SolverOptions is a dataclass used to configure the preferences for the parallel execution plan search.
    """
    fast: bool = False
    solver_perference: SolverPerference = SolverPerference.STANDARD
    dataloader_option: DataloaderOption = DataloaderOption.REPLICATED
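
For orientation, here is a minimal usage sketch of the new option. The import path is assumed from the package layout implied by the second file in this commit (the options module itself does not state where it lives), so treat it as illustrative rather than authoritative.

from colossalai.auto_parallel.tensor_shard.solver.options import (  # assumed path
    DataloaderOption,
    SolverOptions,
)

# Default: the solver assumes the dataloader output is replicated on every device.
replicated_opts = SolverOptions()

# DDP-style dataloader: each rank already holds its own shard of the batch, so
# placeholder and output nodes should be treated as distributed by the solver.
distributed_opts = SolverOptions(dataloader_option=DataloaderOption.DISTRIBUTED)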


@@ -6,15 +6,16 @@ from typing import Dict, List
import torch
from torch.fx import Graph, Node
from colossalai.auto_parallel.tensor_shard.node_handler import OuputHandler, PlacehodlerHandler, operator_registry
from colossalai.auto_parallel.tensor_shard.node_handler.getatrr_handler import GetattrHandler
from colossalai.auto_parallel.tensor_shard.sharding_strategy import ShardingStrategy, StrategiesVector
from colossalai.auto_parallel.tensor_shard.utils import generate_resharding_costs, generate_sharding_spec
from colossalai.auto_parallel.tensor_shard.node_handler import (
    GetattrHandler,
    OuputHandler,
    PlacehodlerHandler,
    operator_registry,
)
from colossalai.auto_parallel.tensor_shard.sharding_strategy import StrategiesVector
from colossalai.device.device_mesh import DeviceMesh
from colossalai.tensor.shape_consistency import ShapeConsistencyManager
from colossalai.tensor.sharding_spec import ShardingSpec
from .options import SolverOptions
from .options import DataloaderOption, SolverOptions
__all__ = ['StrategiesConstructor']
@@ -67,7 +68,15 @@ class StrategiesConstructor:
            strategies_vector = StrategiesVector(node)

            # placeholder node
            if node.op == 'placeholder':
                placeholder_handler = PlacehodlerHandler(node, self.device_mesh, strategies_vector)
                if self.solver_options.dataloader_option == DataloaderOption.DISTRIBUTED:
                    placeholder_option = 'distributed'
                else:
                    assert self.solver_options.dataloader_option == DataloaderOption.REPLICATED, f'placeholder_option {self.solver_options.dataloader_option} is not supported'
                    placeholder_option = 'replicated'
                placeholder_handler = PlacehodlerHandler(node,
                                                         self.device_mesh,
                                                         strategies_vector,
                                                         placeholder_option=placeholder_option)

                placeholder_handler.register_strategy()

            # get_attr node
@@ -97,7 +106,12 @@ class StrategiesConstructor:
            # output node
            elif node.op == 'output':
                output_handler = OuputHandler(node, self.device_mesh, strategies_vector)
                if self.solver_options.dataloader_option == DataloaderOption.DISTRIBUTED:
                    output_option = 'distributed'
                else:
                    assert self.solver_options.dataloader_option == DataloaderOption.REPLICATED, f'output_option {self.solver_options.dataloader_option} is not supported'
                    output_option = 'replicated'
                output_handler = OuputHandler(node, self.device_mesh, strategies_vector, output_option=output_option)
                output_handler.register_strategy()

            if len(strategies_vector) <= 0:
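
Taken together, the pieces above might be wired up roughly as follows. This is a sketch under several assumptions: the import paths, the DeviceMesh construction, and the build_strategies_and_cost() entry point are inferred from the surrounding codebase and its tests rather than shown in this diff, and the real pipeline traces the model with ColossalAI's own tracer so that shape metadata is attached to each node.

import torch
from torch.fx import symbolic_trace

from colossalai.device.device_mesh import DeviceMesh
# Assumed import locations; the diff does not show the package paths.
from colossalai.auto_parallel.tensor_shard.solver.options import DataloaderOption, SolverOptions
from colossalai.auto_parallel.tensor_shard.solver.strategies_constructor import StrategiesConstructor


class MLP(torch.nn.Module):

    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(16, 16)

    def forward(self, x):
        return self.linear(x)


# Plain symbolic_trace keeps the sketch short; the real flow uses ColossalAI's
# tracer with meta arguments so that tensor shapes are recorded on the graph.
graph = symbolic_trace(MLP()).graph

# A 2x2 logical mesh over four devices, mirroring the layout used in the
# project's tests; no process group is initialised for this offline sketch.
device_mesh = DeviceMesh(torch.arange(4), (2, 2))

# DISTRIBUTED makes the placeholder and output handlers use the 'distributed'
# option, i.e. every rank is assumed to hold its own shard of the batch, as a
# DDP dataloader would provide.
solver_options = SolverOptions(dataloader_option=DataloaderOption.DISTRIBUTED)
constructor = StrategiesConstructor(graph, device_mesh, solver_options)
constructor.build_strategies_and_cost()  # assumed entry point on the surrounding class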