mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-03 10:06:44 +00:00
[autoparallel] added matmul handler (#1763)
* [autoparallel] added matmul handler * polish code
This commit is contained in:
@@ -1,6 +1,5 @@
|
||||
import operator
|
||||
from copy import deepcopy
|
||||
from enum import Enum
|
||||
from functools import reduce
|
||||
|
||||
import torch
|
||||
@@ -175,6 +174,9 @@ class ShardingSpec:
|
||||
dim_partition_dict=None,
|
||||
sharding_sequence=None):
|
||||
self.device_mesh = device_mesh
|
||||
|
||||
if isinstance(entire_shape, (list, tuple)):
|
||||
entire_shape = torch.Size(entire_shape)
|
||||
self.entire_shape = entire_shape
|
||||
self.dim_partition_dict = dim_partition_dict
|
||||
self.sharding_sequence = sharding_sequence
|
||||
|
Reference in New Issue
Block a user