mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-08 20:40:34 +00:00)
[Doc] add more doc for ColoTensor. (#1458)
@@ -2,7 +2,7 @@ import torch
 import torch.nn as nn
 import operator
 from colossalai.tensor import ProcessGroup
-from colossalai.tensor.distspec import shard
+from colossalai.tensor.distspec import ShardSpec
 from colossalai.tensor.compute_spec import ComputePattern, ComputeSpec
 
 ELEMENTWISE_MODULE_OP = [torch.nn.Dropout, torch.nn.ReLU]
@@ -85,13 +85,13 @@ def transformer_mlp_pass(graph_module: torch.fx.GraphModule, process_group: Proc
     for shard_type, module in annotation_record.items():
         # add row sharding spec
         if shard_type == 'row':
-            dist_spec = shard(dims=[-1], num_partitions=[world_size])
+            dist_spec = ShardSpec(dims=[-1], num_partitions=[world_size])
             comp_spec = ComputeSpec(ComputePattern.TP1D)
             setattr(module.weight, 'pg', process_group)
             setattr(module.weight, 'dist_spec', dist_spec)
             setattr(module.weight, 'comp_spec', comp_spec)
         elif shard_type == 'col':
-            weight_dist_spec = shard(dims=[0], num_partitions=[world_size])
+            weight_dist_spec = ShardSpec(dims=[0], num_partitions=[world_size])
             weight_comp_spec = ComputeSpec(ComputePattern.TP1D)
             weight_comp_spec.output_replicate = False
             setattr(module.weight, 'pg', process_group)
||||
@@ -99,7 +99,7 @@ def transformer_mlp_pass(graph_module: torch.fx.GraphModule, process_group: Proc
             setattr(module.weight, 'comp_spec', weight_comp_spec)
 
             if module.bias is not None:
-                bias_dist_spec = shard(dims=[0], num_partitions=[world_size])
+                bias_dist_spec = ShardSpec(dims=[0], num_partitions=[world_size])
                 bias_comp_spec = ComputeSpec(ComputePattern.TP1D)
                 bias_comp_spec.output_replicate = False
                 setattr(module.bias, 'pg', process_group)
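The commit itself is a pure rename: each hunk replaces the old shard(...) helper with the ShardSpec class, and the arguments are unchanged. Below is a minimal sketch (not part of the commit) of the annotation pattern the pass applies, using the renamed class. It assumes a ColossalAI distributed context has already been launched (e.g. via colossalai.launch) with world_size ranks; `linear` is a hypothetical stand-in for one module from annotation_record, and the no-argument ProcessGroup() constructor (a group over all launched ranks) is an assumption, not taken from this diff.

    import torch.nn as nn

    from colossalai.tensor import ProcessGroup
    from colossalai.tensor.distspec import ShardSpec
    from colossalai.tensor.compute_spec import ComputePattern, ComputeSpec

    world_size = 4                    # assumed tensor-parallel world size
    process_group = ProcessGroup()    # assumed default group; needs torch.distributed initialized
    linear = nn.Linear(16, 16)        # hypothetical module from annotation_record

    # 'row' branch of the pass: shard the weight's last dimension across
    # world_size ranks and compute it with the 1D tensor-parallel pattern.
    dist_spec = ShardSpec(dims=[-1], num_partitions=[world_size])
    comp_spec = ComputeSpec(ComputePattern.TP1D)
    setattr(linear.weight, 'pg', process_group)
    setattr(linear.weight, 'dist_spec', dist_spec)
    setattr(linear.weight, 'comp_spec', comp_spec)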