mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-05 19:13:01 +00:00
[fx] fix offload codegen test (#1648)
* [fx] fix offload codegen test * [fx] modify typing
This commit is contained in:
@@ -83,13 +83,13 @@ def _run_offload_codegen(rank):
|
||||
# of input offload
|
||||
for node in graph.nodes:
|
||||
if node.name == "linear0":
|
||||
setattr(node, "activation_offload", (0, True, False))
|
||||
setattr(node, "activation_offload", [0, True, False])
|
||||
if node.name == "linear1":
|
||||
setattr(node, "activation_offload", (0, True, False))
|
||||
setattr(node, "activation_offload", [0, True, False])
|
||||
if node.name == "linear2":
|
||||
setattr(node, "activation_offload", (1, True, True))
|
||||
setattr(node, "activation_offload", [1, True, True])
|
||||
if node.name == "linear4":
|
||||
setattr(node, "activation_offload", (2, False, True))
|
||||
setattr(node, "activation_offload", [2, False, True])
|
||||
if node.name == "linear5":
|
||||
setattr(node, "activation_checkpoint", [0])
|
||||
setattr(node, "activation_offload", True)
|
||||
@@ -138,13 +138,13 @@ def _run_offload_codegen_torch11(rank):
|
||||
# of input offload
|
||||
for node in graph.nodes:
|
||||
if node.name == "linear0":
|
||||
setattr(node, "activation_offload", (0, True, False))
|
||||
setattr(node, "activation_offload", [0, True, False])
|
||||
if node.name == "linear1":
|
||||
setattr(node, "activation_offload", (0, True, False))
|
||||
setattr(node, "activation_offload", [0, True, False])
|
||||
if node.name == "linear2":
|
||||
setattr(node, "activation_offload", (1, True, True))
|
||||
setattr(node, "activation_offload", [1, True, True])
|
||||
if node.name == "linear4":
|
||||
setattr(node, "activation_offload", (2, False, True))
|
||||
setattr(node, "activation_offload", [2, False, True])
|
||||
if node.name == "linear5":
|
||||
setattr(node, "activation_checkpoint", [0])
|
||||
setattr(node, "activation_offload", True)
|
||||
|
Reference in New Issue
Block a user