[autoparallel] support distributed dataloader option (#1906)

* [autoparallel] support distributed dataloader option

* update output handler to support ddp dataloader

* polish code
YuliangLiu0306
2022-11-17 20:11:53 +08:00
committed by GitHub
parent 6630d45546
commit 0da1d00399
18 changed files with 257 additions and 61 deletions


@@ -43,8 +43,11 @@ class OperationData:
 
     def __post_init__(self):
         # if no logical shape is specified, use the data shape as the logical shape
-        if self.logical_shape is None and isinstance(self.data, torch.Tensor):
-            self.logical_shape = self.data.shape
+        if self.logical_shape is None:
+            if isinstance(self.data, torch.Tensor):
+                self.logical_shape = self.data.shape
+            elif isinstance(self.data, tuple):
+                self.logical_shape = tuple([getattr(d, 'shape', None) for d in self.data])
 
     def __repr__(self) -> str:
         return f'OperationData(name={self.name}, type={self.type})'
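The effect of the hunk above is that logical_shape is now derived not only from a single torch.Tensor but also from a tuple of tensors (tuple-shaped operation data, as produced for example by an output handler). Below is a minimal, self-contained sketch of that behavior; OperationDataSketch is a hypothetical stand-in used only for illustration, not the actual colossalai class, which carries more fields.

import torch
from dataclasses import dataclass
from typing import Any, Optional


@dataclass
class OperationDataSketch:
    # simplified stand-in: only the fields needed to show the logical_shape logic
    name: str
    data: Any
    logical_shape: Optional[Any] = None

    def __post_init__(self):
        # if no logical shape is specified, derive it from the data,
        # mirroring the change in the diff above
        if self.logical_shape is None:
            if isinstance(self.data, torch.Tensor):
                self.logical_shape = self.data.shape
            elif isinstance(self.data, tuple):
                # one shape per element; non-tensor entries map to None
                self.logical_shape = tuple(getattr(d, 'shape', None) for d in self.data)


single = OperationDataSketch(name='x', data=torch.empty(4, 8))
print(single.logical_shape)    # torch.Size([4, 8])

pair = OperationDataSketch(name='out', data=(torch.empty(4, 8), torch.empty(4)))
print(pair.logical_shape)      # (torch.Size([4, 8]), torch.Size([4]))

With the old code, the tuple case would have left logical_shape as None, since the single `isinstance(self.data, torch.Tensor)` check failed for tuples.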