[fx] fix offload codegen test (#1648)

* [fx] fix offload codegen test

* [fx] modify typing
This commit is contained in:
Boyuan Yao
2022-09-27 10:25:27 +08:00
committed by GitHub
parent 45b39a692a
commit 5d0fdb9cb4
2 changed files with 12 additions and 12 deletions

View File

@@ -83,13 +83,13 @@ def _run_offload_codegen(rank):
# of input offload
for node in graph.nodes:
if node.name == "linear0":
setattr(node, "activation_offload", (0, True, False))
setattr(node, "activation_offload", [0, True, False])
if node.name == "linear1":
setattr(node, "activation_offload", (0, True, False))
setattr(node, "activation_offload", [0, True, False])
if node.name == "linear2":
setattr(node, "activation_offload", (1, True, True))
setattr(node, "activation_offload", [1, True, True])
if node.name == "linear4":
setattr(node, "activation_offload", (2, False, True))
setattr(node, "activation_offload", [2, False, True])
if node.name == "linear5":
setattr(node, "activation_checkpoint", [0])
setattr(node, "activation_offload", True)
@@ -138,13 +138,13 @@ def _run_offload_codegen_torch11(rank):
# of input offload
for node in graph.nodes:
if node.name == "linear0":
setattr(node, "activation_offload", (0, True, False))
setattr(node, "activation_offload", [0, True, False])
if node.name == "linear1":
setattr(node, "activation_offload", (0, True, False))
setattr(node, "activation_offload", [0, True, False])
if node.name == "linear2":
setattr(node, "activation_offload", (1, True, True))
setattr(node, "activation_offload", [1, True, True])
if node.name == "linear4":
setattr(node, "activation_offload", (2, False, True))
setattr(node, "activation_offload", [2, False, True])
if node.name == "linear5":
setattr(node, "activation_checkpoint", [0])
setattr(node, "activation_offload", True)