mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-09 04:50:17 +00:00
[shardformer] fix emerged bugs after updating transformers (#4526)
This commit is contained in:
@@ -123,7 +123,10 @@ def merge_batch(data: List[Any]) -> Any:
|
||||
merged_data = []
|
||||
for elem_batch in zip(*flattened_data):
|
||||
if isinstance(elem_batch[0], torch.Tensor):
|
||||
merged_data.append(torch.cat(elem_batch, dim=0))
|
||||
if len(elem_batch[0].shape) == 0: # set loss to None in pipeline outputs
|
||||
merged_data.append(None)
|
||||
else:
|
||||
merged_data.append(torch.cat(elem_batch, dim=0))
|
||||
else:
|
||||
merged_data.append(list(elem_batch))
|
||||
return tree_unflatten(merged_data, tree_spec)
|
||||
|
Reference in New Issue
Block a user