[shardformer] fix emerged bugs after updating transformers (#4526)

This commit is contained in:
Baizhou Zhang
2023-08-29 11:25:05 +08:00
committed by GitHub
parent c554b7f559
commit 0387a47e63
2 changed files with 9 additions and 2 deletions

View File

@@ -123,7 +123,10 @@ def merge_batch(data: List[Any]) -> Any:
merged_data = []
for elem_batch in zip(*flattened_data):
if isinstance(elem_batch[0], torch.Tensor):
merged_data.append(torch.cat(elem_batch, dim=0))
if len(elem_batch[0].shape) == 0: # set loss to None in pipeline outputs
merged_data.append(None)
else:
merged_data.append(torch.cat(elem_batch, dim=0))
else:
merged_data.append(list(elem_batch))
return tree_unflatten(merged_data, tree_spec)