[fx] fix offload codegen test (#1648)

* [fx] fix offload codegen test

* [fx] modify typing
This commit is contained in:
Boyuan Yao
2022-09-27 10:25:27 +08:00
committed by GitHub
parent 45b39a692a
commit 5d0fdb9cb4
2 changed files with 12 additions and 12 deletions

View File

@@ -1,6 +1,6 @@
import colossalai
import torch
from typing import List, Callable, Any, Tuple, Dict
from typing import List, Callable, Any, Tuple, Dict, Iterable
try:
from torch.fx.node import Node, Argument, map_arg, _type_repr, _get_qualified_name
@@ -157,7 +157,7 @@ def _find_offload_regions(nodes: List[Node]):
current_region = None
for idx, node in enumerate(nodes):
if hasattr(node, 'activation_offload') and isinstance(getattr(node, 'activation_offload', None), list):
if hasattr(node, 'activation_offload') and isinstance(getattr(node, 'activation_offload', None), Iterable):
act_offload_label = node.activation_offload
if current_region == None:
@@ -796,7 +796,7 @@ if CODEGEN_AVAILABLE:
# if any node has a list of labels for activation_checkpoint, we
# will use nested type of activation checkpoint codegen
if any(isinstance(getattr(node, "activation_checkpoint", None), list) for node in nodes):
if any(isinstance(getattr(node, "activation_checkpoint", None), Iterable) for node in nodes):
emit_code_with_nested_activation_checkpoint(body, ckpt_func, nodes, emit_node, delete_unused_values)
else:
emit_code_with_activation_checkpoint(body, ckpt_func, nodes, emit_node, delete_unused_values)
@@ -999,7 +999,7 @@ else:
# if any node has a list of labels for activation_checkpoint, we
# will use nested type of activation checkpoint codegen
if any(isinstance(getattr(node, "activation_checkpoint", None), list) for node in self.nodes):
if any(isinstance(getattr(node, "activation_checkpoint", None), Iterable) for node in self.nodes):
emit_code_with_nested_activation_checkpoint(body, ckpt_func, self.nodes, emit_node, delete_unused_values)
else:
emit_code_with_activation_checkpoint(body, ckpt_func, self.nodes, emit_node, delete_unused_values)