diff --git a/colossalai/autochunk/autochunk_codegen.py b/colossalai/autochunk/autochunk_codegen.py index dcc6bba9e..fbd5d5e36 100644 --- a/colossalai/autochunk/autochunk_codegen.py +++ b/colossalai/autochunk/autochunk_codegen.py @@ -220,7 +220,9 @@ if CODEGEN_AVAILABLE: self.max_memory = max_memory self.meta_node = list(meta_graph.graph.nodes) # find the chunk regions - self.chunk_region_search = ChunkRegionSearch(meta_graph, max_memory, print_mem) + self.chunk_region_search = ChunkRegionSearch( + meta_graph, max_memory, print_mem + ) self.chunk_infos = self.chunk_region_search.search_region() def _gen_python_code( diff --git a/colossalai/autochunk/chunk_region_search.py b/colossalai/autochunk/chunk_region_search.py index 76b02cade..7a0e8a36c 100644 --- a/colossalai/autochunk/chunk_region_search.py +++ b/colossalai/autochunk/chunk_region_search.py @@ -1,8 +1,13 @@ +import copy + +from .chunk_selector import ChunkSelector from .index_tracer import IndexTracer from .memory_estiamtor import MemoryEstimator -from .chunk_selector import ChunkSelector -import copy -from .utils import is_non_compute_node, is_non_compute_node_except_placeholder, get_node_shape +from .utils import ( + get_node_shape, + is_non_compute_node, + is_non_compute_node_except_placeholder, +) class ChunkRegionSearch(object): @@ -11,7 +16,7 @@ class ChunkRegionSearch(object): self.print_mem = print_mem self.index_tracer = IndexTracer(list(gm.graph.nodes)) self.index_tracer.trace_index() - self.memory_estimator = MemoryEstimator(self.index_tracer) + self.memory_estimator = MemoryEstimator() self.chunk_selector = ChunkSelector( self.index_tracer, self.memory_estimator, max_memory=max_memory ) @@ -211,4 +216,3 @@ class ChunkRegionSearch(object): self.index_tracer.node_list, chunk_infos, print_mem=True ) return chunk_infos - diff --git a/colossalai/autochunk/memory_estiamtor.py b/colossalai/autochunk/memory_estiamtor.py index c3d8b1803..034f59e52 100644 --- a/colossalai/autochunk/memory_estiamtor.py +++ b/colossalai/autochunk/memory_estiamtor.py @@ -16,7 +16,7 @@ from .utils import ( class MemoryEstimator(object): - def __init__(self, index_tracer: IndexTracer) -> None: + def __init__(self) -> None: pass def _get_meta_node_size(self, x):