This commit is contained in:
wangbluo 2024-08-19 09:23:10 +00:00
parent 12b44012d9
commit 2eb36839c6

View File

@ -452,4 +452,4 @@ def all_to_all_uneven(
assert ( assert (
inputs.requires_grad inputs.requires_grad
), "Input must require grad to assure that backward is executed, otherwise it might hang the program." ), "Input must require grad to assure that backward is executed, otherwise it might hang the program."
return AllToAllUneven.apply(inputs, input_split_sizes, output_split_sizes, group, overlap) return AllToAllUneven.apply(inputs, input_split_sizes, output_split_sizes, group, overlap, fp8_communicatio)