This commit is contained in:
wangbluo 2024-08-19 10:15:16 +00:00
parent 88b3f0698c
commit 1f703e0ef4

View File

@ -452,4 +452,4 @@ def all_to_all_uneven(
assert ( assert (
inputs.requires_grad inputs.requires_grad
), "Input must require grad to assure that backward is executed, otherwise it might hang the program." ), "Input must require grad to assure that backward is executed, otherwise it might hang the program."
return AllToAllUneven.apply(inputs, input_split_sizes, output_split_sizes, group, overlap, fp8_communicatio) return AllToAllUneven.apply(inputs, input_split_sizes, output_split_sizes, group, overlap, fp8_communication)