[Doc] add more doc for ColoTensor. (#1458)

Author: Jiarui Fang
Date: 2022-08-16 10:38:41 +08:00
Committed by: GitHub
Parent: a1476ea882
Commit: 36824a304c
4 changed files with 46 additions and 18 deletions


@@ -2,7 +2,7 @@ import torch
 import torch.nn as nn
 import operator
 from colossalai.tensor import ProcessGroup
-from colossalai.tensor.distspec import shard
+from colossalai.tensor.distspec import ShardSpec
 from colossalai.tensor.compute_spec import ComputePattern, ComputeSpec

 ELEMENTWISE_MODULE_OP = [torch.nn.Dropout, torch.nn.ReLU]
@@ -85,13 +85,13 @@ def transformer_mlp_pass(graph_module: torch.fx.GraphModule, process_group: Proc
     for shard_type, module in annotation_record.items():
         # add row sharding spec
         if shard_type == 'row':
-            dist_spec = shard(dims=[-1], num_partitions=[world_size])
+            dist_spec = ShardSpec(dims=[-1], num_partitions=[world_size])
             comp_spec = ComputeSpec(ComputePattern.TP1D)
             setattr(module.weight, 'pg', process_group)
             setattr(module.weight, 'dist_spec', dist_spec)
             setattr(module.weight, 'comp_spec', comp_spec)
         elif shard_type == 'col':
-            weight_dist_spec = shard(dims=[0], num_partitions=[world_size])
+            weight_dist_spec = ShardSpec(dims=[0], num_partitions=[world_size])
             weight_comp_spec = ComputeSpec(ComputePattern.TP1D)
             weight_comp_spec.output_replicate = False
             setattr(module.weight, 'pg', process_group)
@@ -99,7 +99,7 @@ def transformer_mlp_pass(graph_module: torch.fx.GraphModule, process_group: Proc
             setattr(module.weight, 'comp_spec', weight_comp_spec)
             if module.bias is not None:
-                bias_dist_spec = shard(dims=[0], num_partitions=[world_size])
+                bias_dist_spec = ShardSpec(dims=[0], num_partitions=[world_size])
                 bias_comp_spec = ComputeSpec(ComputePattern.TP1D)
                 bias_comp_spec.output_replicate = False
                 setattr(module.bias, 'pg', process_group)
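
Taken together, the hunks above replace the old shard(...) helper with the ShardSpec class while leaving the annotation pattern of the MLP pass unchanged. Below is a minimal sketch of that pattern, not part of this commit: the helper name annotate_row_shard, the assumption that the caller already holds a ProcessGroup, and the world_size argument are illustrative only; the class names, constructor arguments, and attribute keys are those shown in the diff.

import torch.nn as nn
from colossalai.tensor import ProcessGroup
from colossalai.tensor.distspec import ShardSpec
from colossalai.tensor.compute_spec import ComputePattern, ComputeSpec

def annotate_row_shard(module: nn.Module, pg: ProcessGroup, world_size: int):
    # Mirror the 'row' branch of the pass: shard the weight along its last
    # dimension into `world_size` partitions and mark it for 1D tensor parallelism.
    dist_spec = ShardSpec(dims=[-1], num_partitions=[world_size])
    comp_spec = ComputeSpec(ComputePattern.TP1D)
    # The pass attaches the specs as attributes on the parameter itself,
    # under the keys 'pg', 'dist_spec', and 'comp_spec'.
    setattr(module.weight, 'pg', pg)
    setattr(module.weight, 'dist_spec', dist_spec)
    setattr(module.weight, 'comp_spec', comp_spec)

The 'col' branch follows the same shape with dims=[0] and output_replicate = False on the ComputeSpec, applied to both the weight and (if present) the bias.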