mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-08 12:30:42 +00:00
[dtensor] updated api and doc (#3845)
This commit is contained in:
@@ -3,10 +3,8 @@ from copy import deepcopy
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, List, Tuple
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, TrainCycleItem
|
||||
from colossalai.context.singleton_meta import SingletonMeta
|
||||
from colossalai.tensor.d_tensor.comm_spec import *
|
||||
from colossalai.tensor.d_tensor.layout import Layout
|
||||
@@ -28,13 +26,21 @@ class LayoutConverterOptions:
|
||||
pass
|
||||
|
||||
|
||||
def to_global(distributed_tensor: torch.Tensor, layout: Layout) -> torch.Tensor:
|
||||
def to_global(distributed_tensor: "DTensor", layout: Layout) -> torch.Tensor:
|
||||
"""
|
||||
Convert a distributed tensor to the global tensor with the given layout.
|
||||
This function returns a native `torch.Tensor` object.
|
||||
|
||||
|
||||
Args:
|
||||
distributed_tensor (`DTensor`): the distributed tensor to be converted.
|
||||
layout (`Layout`): the target layout specification.
|
||||
"""
|
||||
layout_converter = LayoutConverter()
|
||||
global_sharding_spec = ShardingSpec(distributed_tensor.dim(), {})
|
||||
global_layout = Layout(device_mesh=layout.device_mesh,
|
||||
device_type=layout.device_type,
|
||||
sharding_spec=global_sharding_spec,
|
||||
entire_shape=layout.entire_shape)
|
||||
global_shape=layout.global_shape)
|
||||
with torch.no_grad():
|
||||
global_tensor = layout_converter.apply(distributed_tensor, layout, global_layout)
|
||||
return global_tensor
|
||||
@@ -49,6 +55,9 @@ def set_layout_converting_options(options: LayoutConverterOptions):
|
||||
|
||||
|
||||
class LayoutConverter(metaclass=SingletonMeta):
|
||||
"""
|
||||
LayoutConverter is a singleton class which converts the layout of a distributed tensor.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self._options = None
|
||||
@@ -91,15 +100,14 @@ class LayoutConverter(metaclass=SingletonMeta):
|
||||
# [[0, 1,
|
||||
# [2, 3]]
|
||||
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)
|
||||
entire_shape = (4, 4, 4)
|
||||
global_shape = (4, 4, 4)
|
||||
dim_partition_dict = {0: [0], 1: [1]}
|
||||
|
||||
# [S0,S1,R]
|
||||
sharding_spec = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_dict)
|
||||
layout = Layout(device_mesh=device_mesh,
|
||||
device_type=torch.device('cuda'),
|
||||
sharding_spec=sharding_spec,
|
||||
entire_shape=entire_shape)
|
||||
global_shape=global_shape)
|
||||
|
||||
rst_dict = layout_converter.all_gather_transform_layouts(layout)
|
||||
for layout, comm_spec in rst_dict.items():
|
||||
@@ -112,7 +120,12 @@ class LayoutConverter(metaclass=SingletonMeta):
|
||||
valid_spec_dict = {}
|
||||
comm_pattern = CollectiveCommPattern.GATHER_FWD_SPLIT_BWD
|
||||
source_spec = source_layout.sharding_spec
|
||||
process_groups_dict = source_layout.device_mesh.process_groups_dict
|
||||
|
||||
# the key of the dict is the axis
|
||||
# the value is the process group
|
||||
current_rank = source_layout.device_mesh._global_rank_of_current_process
|
||||
process_group_dict = source_layout.device_mesh._process_group_dict[current_rank]
|
||||
|
||||
for target_pair in source_spec.dim_partition_dict.items():
|
||||
shard_list = all_gather_simulator(target_pair)
|
||||
index = target_pair[0]
|
||||
@@ -130,7 +143,7 @@ class LayoutConverter(metaclass=SingletonMeta):
|
||||
logical_process_axis = target_pair[1][-1]
|
||||
comm_spec = CommSpec(
|
||||
comm_pattern,
|
||||
process_groups_dict=process_groups_dict,
|
||||
process_group_dict=process_group_dict,
|
||||
gather_dim=gather_dim,
|
||||
# shard_dim will be used during backward
|
||||
shard_dim=gather_dim,
|
||||
@@ -141,8 +154,7 @@ class LayoutConverter(metaclass=SingletonMeta):
|
||||
new_sharding_spec = ShardingSpec(source_spec.dims, dim_partition_dict=new_dim_partition_dict)
|
||||
new_layout = Layout(device_mesh=source_layout.device_mesh,
|
||||
sharding_spec=new_sharding_spec,
|
||||
device_type=source_layout.device_type,
|
||||
entire_shape=source_layout.entire_shape)
|
||||
global_shape=source_layout.global_shape)
|
||||
|
||||
valid_spec_dict[new_layout] = comm_spec
|
||||
except LayoutException:
|
||||
@@ -167,15 +179,14 @@ class LayoutConverter(metaclass=SingletonMeta):
|
||||
# [[0, 1,
|
||||
# [2, 3]]
|
||||
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)
|
||||
entire_shape = (4, 4, 4)
|
||||
global_shape = (4, 4, 4)
|
||||
dim_partition_dict = {0: [0], 1: [1]}
|
||||
|
||||
# [S0,S1,R]
|
||||
sharding_spec = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_dict)
|
||||
layout = Layout(device_mesh=device_mesh,
|
||||
device_type=torch.device('cuda'),
|
||||
sharding_spec=sharding_spec,
|
||||
entire_shape=entire_shape)
|
||||
global_shape=global_shape)
|
||||
rst_dict = layout_converter.all_to_all_transform_layout(layout)
|
||||
|
||||
for layout, comm_spec in rst_dict.items():
|
||||
@@ -188,7 +199,12 @@ class LayoutConverter(metaclass=SingletonMeta):
|
||||
'''
|
||||
valid_spec_dict = {}
|
||||
comm_pattern = CollectiveCommPattern.ALL2ALL_FWD_ALL2ALL_BWD
|
||||
process_groups_dict = source_layout.device_mesh.process_groups_dict
|
||||
|
||||
# the key of the dict is the axis
|
||||
# the value is the process group
|
||||
current_rank = source_layout.device_mesh._global_rank_of_current_process
|
||||
process_group_dict = source_layout.device_mesh._process_group_dict[current_rank]
|
||||
|
||||
source_spec = source_layout.sharding_spec
|
||||
tensor_dims = source_spec.dims
|
||||
for f_index in range(tensor_dims - 1):
|
||||
@@ -229,7 +245,7 @@ class LayoutConverter(metaclass=SingletonMeta):
|
||||
shard_dim = f_index
|
||||
logical_process_axis = b_target_pair[1][-1]
|
||||
comm_spec = CommSpec(comm_pattern,
|
||||
process_groups_dict,
|
||||
process_group_dict=process_group_dict,
|
||||
gather_dim=gather_dim,
|
||||
shard_dim=shard_dim,
|
||||
logical_process_axis=logical_process_axis)
|
||||
@@ -252,8 +268,7 @@ class LayoutConverter(metaclass=SingletonMeta):
|
||||
new_sharding_spec = ShardingSpec(source_spec.dims, dim_partition_dict=new_dim_partition_dict)
|
||||
new_layout = Layout(device_mesh=source_layout.device_mesh,
|
||||
sharding_spec=new_sharding_spec,
|
||||
device_type=source_layout.device_type,
|
||||
entire_shape=source_layout.entire_shape)
|
||||
global_shape=source_layout.global_shape)
|
||||
valid_spec_dict[new_layout] = comm_spec
|
||||
except LayoutException:
|
||||
pass
|
||||
@@ -278,16 +293,15 @@ class LayoutConverter(metaclass=SingletonMeta):
|
||||
# [[0, 1,
|
||||
# [2, 3]]
|
||||
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)
|
||||
entire_shape = (4, 4, 4)
|
||||
global_shape = (4, 4, 4)
|
||||
|
||||
dim_partition_dict = {0: [0]}
|
||||
|
||||
# [S0,R,R]
|
||||
sharding_spec = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_dict)
|
||||
layout = Layout(device_mesh=device_mesh,
|
||||
device_type=torch.device('cuda'),
|
||||
sharding_spec=sharding_spec,
|
||||
entire_shape=entire_shape)
|
||||
global_shape=global_shape)
|
||||
rst_dict = layout_converter.shard_transform_layout(layout)
|
||||
|
||||
for layout, comm_spec in rst_dict.items():
|
||||
@@ -301,7 +315,11 @@ class LayoutConverter(metaclass=SingletonMeta):
|
||||
valid_spec_dict = {}
|
||||
comm_pattern = CollectiveCommPattern.SPLIT_FWD_GATHER_BWD
|
||||
source_spec = source_layout.sharding_spec
|
||||
process_groups_dict = source_layout.device_mesh.process_groups_dict
|
||||
|
||||
# the key of the dict is the axis
|
||||
# the value is the process group
|
||||
current_rank = source_layout.device_mesh._global_rank_of_current_process
|
||||
process_group_dict = source_layout.device_mesh._process_group_dict[current_rank]
|
||||
|
||||
# legal sharding dims means the mesh_id is still available to use.
|
||||
legal_sharding_dims = [i for i in range(len(source_layout.device_mesh.mesh_shape))]
|
||||
@@ -329,7 +347,7 @@ class LayoutConverter(metaclass=SingletonMeta):
|
||||
shard_dim = index
|
||||
logical_process_axis = shard_list[-1]
|
||||
comm_spec = CommSpec(comm_pattern,
|
||||
process_groups_dict,
|
||||
process_group_dict=process_group_dict,
|
||||
gather_dim=shard_dim,
|
||||
shard_dim=shard_dim,
|
||||
logical_process_axis=logical_process_axis)
|
||||
@@ -340,8 +358,7 @@ class LayoutConverter(metaclass=SingletonMeta):
|
||||
dim_partition_dict=new_dim_partition_dict)
|
||||
new_layout = Layout(device_mesh=source_layout.device_mesh,
|
||||
sharding_spec=new_sharding_spec,
|
||||
device_type=source_layout.device_type,
|
||||
entire_shape=source_layout.entire_shape)
|
||||
global_shape=source_layout.global_shape)
|
||||
valid_spec_dict[new_layout] = comm_spec
|
||||
except LayoutException:
|
||||
pass
|
||||
@@ -399,7 +416,7 @@ class LayoutConverter(metaclass=SingletonMeta):
|
||||
# [[0, 1,
|
||||
# [2, 3]]
|
||||
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)
|
||||
entire_shape = (4, 4, 4)
|
||||
global_shape = (4, 4, 4)
|
||||
|
||||
dim_partition_source = {1: [0, 1]}
|
||||
dim_partition_target = {0: [0, 1]}
|
||||
@@ -407,16 +424,14 @@ class LayoutConverter(metaclass=SingletonMeta):
|
||||
# [R,S01,R]
|
||||
sharding_spec_source = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_source)
|
||||
source_layout = Layout(device_mesh=device_mesh,
|
||||
device_type=torch.device('cuda'),
|
||||
sharding_spec=sharding_spec_source,
|
||||
entire_shape=entire_shape)
|
||||
global_shape=global_shape)
|
||||
|
||||
# [S01,R,R]
|
||||
sharding_spec_target = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_target)
|
||||
target_layout = Layout(device_mesh=device_mesh,
|
||||
device_type=torch.device('cuda'),
|
||||
sharding_spec=sharding_spec_target,
|
||||
entire_shape=entire_shape)
|
||||
global_shape=global_shape)
|
||||
|
||||
transform_path, comm_action_sequence = layout_converter.layout_converting(source_layout, target_layout)
|
||||
transform_path_str = '->'.join([str(layout.sharding_spec.sharding_sequence) for layout in transform_path])
|
||||
@@ -505,21 +520,19 @@ class LayoutConverter(metaclass=SingletonMeta):
|
||||
# [[0, 1,
|
||||
# [2, 3]]
|
||||
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)
|
||||
entire_shape = (4, 4, 4)
|
||||
global_shape = (4, 4, 4)
|
||||
|
||||
# [S0,R,R]
|
||||
sharding_spec_source = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_source)
|
||||
source_layout = Layout(device_mesh=device_mesh,
|
||||
device_type=torch.device('cuda'),
|
||||
sharding_spec=sharding_spec_source,
|
||||
entire_shape=entire_shape)
|
||||
global_shape=global_shape)
|
||||
|
||||
# [R,S0,R]
|
||||
sharding_spec_target = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_target)
|
||||
target_layout = Layout(device_mesh=device_mesh,
|
||||
device_type=torch.device('cuda'),
|
||||
sharding_spec=sharding_spec_target,
|
||||
entire_shape=entire_shape)
|
||||
global_shape=global_shape)
|
||||
|
||||
if rank in (0, 1):
|
||||
sharded_tensor_0 = torch.zeros(2, 1)
|
||||
@@ -554,3 +567,4 @@ class LayoutConverter(metaclass=SingletonMeta):
|
||||
for comm_spec in comm_action_sequence:
|
||||
tensor = comm_spec.covert_spec_to_action(tensor)
|
||||
return tensor
|
||||
return tensor
|
||||
|
Reference in New Issue
Block a user