[autoparallel] added matmul handler (#1763)

* [autoparallel] added matmul handler

* polish code
This commit is contained in:
Frank Lee
2022-11-01 15:14:53 +08:00
committed by GitHub
parent 4df0194976
commit f3f19a5c47
7 changed files with 725 additions and 28 deletions

View File

@@ -1,6 +1,5 @@
import operator
from copy import deepcopy
from enum import Enum
from functools import reduce
import torch
@@ -175,6 +174,9 @@ class ShardingSpec:
dim_partition_dict=None,
sharding_sequence=None):
self.device_mesh = device_mesh
if isinstance(entire_shape, (list, tuple)):
entire_shape = torch.Size(entire_shape)
self.entire_shape = entire_shape
self.dim_partition_dict = dim_partition_dict
self.sharding_sequence = sharding_sequence