* init
* rename and remove useless func
* basic chunk
* add evoformer
* align evoformer
* add meta
* basic chunk
* basic memory
* finish basic inference memory estimation
* finish memory estimation
* fix bug
* finish memory estimation
* add part of index tracer
* finish basic index tracer
* add doc string
* add doc str
* polish code
* polish code
* update active log
* polish code
* add possible region search
* finish region search loop
* finish chunk define
* support new op
* rename index tracer
* finish codegen on msa
* redesign index tracer, add source and change compute
* pass outerproduct mean
* code format
* code format
* work with outerproductmean and msa
* code style
* code style
* code style
* code style
* change threshold
* support check_index_duplicate
* support index duplicate and update loop
* support output
* update memory estimate
* optimise search
* fix layernorm
* move flow tracer
* refactor flow tracer
* format code
* refactor flow search
* code style
* adapt codegen to prepose node
* code style
* remove abandoned function
* remove flow tracer
* code style
* code style
* reorder nodes
* finish node reorder
* update run
* code style
* add chunk select class
* add chunk select
* code style
* add chunksize in emit, fix bug in reassign shape
* code style
* turn off print mem
* add evoformer openfold init
* init openfold
* add benchmark
* add print
* code style
* code style
* init openfold
* update openfold
* align openfold
* use max_mem to control strategy
* update source add
* add reorder in mem estimator
* improve reorder efficiency
* support ones_like, add prompt if fit mode search fail
* fix a bug in ones_like, don't gen chunk if dim size is 1
* fix bug again
* update min memory strategy, reduce mem usage by 30%
* last version of benchmark
* refactor structure
* restructure dir
* update test
* rename
* take apart chunk code gen
* close mem and code print
* code format
* rename ambiguous variable
* separate flow tracer
* separate input node dim search
* separate prepose_nodes
* separate non chunk input
* separate reorder
* rename
* add reorder graph
* separate trace flow
* code style
* code style
* fix typo
* set benchmark
* rename test
* update codegen test
* Fix state_dict key missing issue of the ZeroDDP (#2363)
* Fix state_dict output for ZeroDDP duplicated parameters
* Rewrite state_dict based on get_static_torch_model
* Modify get_static_torch_model to be compatible with the lower version (ZeroDDP)
* update codegen test
* update codegen test
* add chunk search test
* code style
* add available
* [hotfix] fix gpt gemini example (#2404)
* [hotfix] fix gpt gemini example
* [example] add new assertions
* remove autochunk_available
* [workflow] added nightly release to pypi (#2403)
* add comments
* code style
* add doc for search chunk
* [doc] updated readme regarding pypi installation (#2406)
* add doc for search
* [doc] updated kernel-related optimisers' docstring (#2385)
* [doc] updated kernel-related optimisers' docstring
* polish doc
* rename trace_index to trace_indice
* rename function from index to indice
* rename
* rename in doc
* [polish] polish code for get_static_torch_model (#2405)
* [gemini] polish code
* [testing] remove code
* [gemini] make more robust
* rename
* rename
* remove useless function
* [workflow] added coverage test (#2399)
* [workflow] added coverage test
* polish code
* polish code
* polish code
* polish code
* polish code
* polish code
* polish code
* polish code
* add doc for trace indice
* [docker] updated Dockerfile and release workflow (#2410)
* add doc
* update doc
* add available
* change imports
* add test in import
* [workflow] refactored the example check workflow (#2411)
* [workflow] refactored the example check workflow
* polish code
* polish code
* polish code
* polish code
* polish code
* polish code
* polish code
* polish code
* polish code
* polish code
* polish code
* Update parallel_context.py (#2408)
* [hotfix] add DISTPAN argument for benchmark (#2412)
* change the benchmark config file
* change config
* revert config file
* rename distpan to distplan
* [workflow] added precommit check for code consistency (#2401)
* [workflow] added precommit check for code consistency
* polish code
* polish code
* polish code
* polish code
* polish code
* polish code
* polish code
* adapt new fx
* [workflow] added translation for non-english comments (#2414)
* [setup] refactored setup.py for dependency graph (#2413)
* change import
* update doc
* [workflow] auto comment if precommit check fails (#2417)
* [hotfix] add norm clearing for the overflow step (#2416)
* [examples] adding tflops to PaLM (#2365)
* [workflow] auto comment with test coverage report (#2419)
* [workflow] auto comment with test coverage report
* polish code
* polish yaml
* [doc] added documentation for CI/CD (#2420)
* [doc] added documentation for CI/CD
* polish markdown
* polish markdown
* polish markdown
* [example] removed duplicated stable diffusion example (#2424)
* [zero] add inference mode and its unit test (#2418)
* [workflow] report test coverage even if below threshold (#2431)
* [example] improved the clarity of the example readme (#2427)
* [example] improved the clarity of the example readme
* polish workflow
* polish workflow
* polish workflow
* polish workflow
* polish workflow
* polish workflow
* [ddp] add is_ddp_ignored (#2434)
[ddp] rename to is_ddp_ignored
* [workflow] make test coverage report collapsable (#2436)
* [autoparallel] add shard option (#2423)
* [fx] allow native ckpt trace and codegen. (#2438)
* [cli] provided more details if colossalai run fail (#2442)
* [autoparallel] integrate device mesh initialization into autoparallelize (#2393)
* [autoparallel] integrate device mesh initialization into autoparallelize
* add megatron solution
* update gpt autoparallel examples with latest api
* adapt beta value to fit the current computation cost
* [zero] fix state_dict and load_state_dict for ddp ignored parameters (#2443)
* [ddp] add is_ddp_ignored
[ddp] rename to is_ddp_ignored
* [zero] fix state_dict and load_state_dict
* fix bugs
* [zero] update unit test for ZeroDDP
* [example] updated the hybrid parallel tutorial (#2444)
* [example] updated the hybrid parallel tutorial
* polish code
* [zero] add warning for ignored parameters (#2446)
* [example] updated large-batch optimizer tutorial (#2448)
* [example] updated large-batch optimizer tutorial
* polish code
* polish code
* [example] fixed seed error in train_dreambooth_colossalai.py (#2445)
* [workflow] fixed the on-merge condition check (#2452)
* [workflow] automated the compatibility test (#2453)
* [workflow] automated the compatibility test
* polish code
* [autoparallel] update binary elementwise handler (#2451)
* [autoparallel] update binary elementwise handler
* polish
* [workflow] automated bdist wheel build (#2459)
* [workflow] automated bdist wheel build
* polish workflow
* polish readme
* polish readme
* Fix False warning in initialize.py (#2456)
* Update initialize.py
* pre-commit run check
* [examples] update autoparallel tutorial demo (#2449)
* [examples] update autoparallel tutorial demo
* add test_ci.sh
* polish
* add conda yaml
* [cli] fixed hostname mismatch error (#2465)
* [example] integrate autoparallel demo with CI (#2466)
* [example] integrate autoparallel demo with CI
* polish code
* polish code
* polish code
* polish code
* [zero] low level optim supports ProcessGroup (#2464)
* [example] update vit ci script (#2469)
* [example] update vit ci script
* [example] update requirements
* [example] update requirements
* [example] integrate seq-parallel tutorial with CI (#2463)
* [zero] polish low level optimizer (#2473)
* polish pp middleware (#2476)
Co-authored-by: Ziyue Jiang <ziyue.jiang@gmail.com>
* [example] update gpt gemini example ci test (#2477)
* [zero] add unit test for low-level zero init (#2474)
* [workflow] fixed the skip condition of example weekly check workflow (#2481)
* [example] stable diffusion add roadmap
* add dummy test_ci.sh
* [example] stable diffusion add roadmap (#2482)
* [CI] add test_ci.sh for palm, opt and gpt (#2475)
* polish code
* [example] titans for gpt
* polish readme
* remove license
* polish code
* update readme
* [example] titans for gpt (#2484)
* [autoparallel] support origin activation ckpt on autoprallel system (#2468)
* [autochunk] support evoformer tracer (#2485)
support the full evoformer tracer, which is a main module of AlphaFold. Previously we only supported a simplified version of it.
1. support some evoformer's op in fx
2. support evoformer test
3. add repos for test code
* [example] fix requirements (#2488)
* [zero] add unit testings for hybrid parallelism (#2486)
* [hotfix] gpt example titans bug #2493
* polish code and fix dataloader bugs
* [hotfix] gpt example titans bug #2493 (#2494)
* [fx] allow control of ckpt_codegen init (#2498)
* [fx] allow control of ckpt_codegen init
Currently in ColoGraphModule, ActivationCheckpointCodeGen is set automatically in __init__, so no other codegen can be set.
So I add an arg to control whether to set ActivationCheckpointCodeGen in __init__.
* code style
* [example] dreambooth example
* add test_ci.sh to dreambooth
* [autochunk] support autochunk on evoformer (#2497)
* Revert "Update parallel_context.py (#2408)"
This reverts commit 7d5640b9db.
* add avg partition (#2483)
Co-authored-by: Ziyue Jiang <ziyue.jiang@gmail.com>
* [auto-chunk] support extramsa (#3) (#2504)
* [utils] lazy init. (#2148)
* [utils] lazy init.
* [utils] remove description.
* [utils] complete.
* [utils] finalize.
* [utils] fix names.
* [autochunk] support parsing blocks (#2506)
* [zero] add strict ddp mode (#2508)
* [zero] add strict ddp mode
* [polish] add comments for strict ddp mode
* [zero] fix test error
* [doc] update opt and tutorial links (#2509)
* [workflow] fixed changed file detection (#2515)
Co-authored-by: oahzxl <xuanlei.zhao@gmail.com>
Co-authored-by: eric8607242 <e0928021388@gmail.com>
Co-authored-by: HELSON <c2h214748@gmail.com>
Co-authored-by: Frank Lee <somerlee.9@gmail.com>
Co-authored-by: Haofan Wang <haofanwang.ai@gmail.com>
Co-authored-by: Jiarui Fang <fangjiarui123@gmail.com>
Co-authored-by: ZijianYY <119492445+ZijianYY@users.noreply.github.com>
Co-authored-by: YuliangLiu0306 <72588413+YuliangLiu0306@users.noreply.github.com>
Co-authored-by: Super Daniel <78588128+super-dainiu@users.noreply.github.com>
Co-authored-by: ver217 <lhx0217@gmail.com>
Co-authored-by: Ziyue Jiang <ziyue.jiang97@gmail.com>
Co-authored-by: Ziyue Jiang <ziyue.jiang@gmail.com>
Co-authored-by: oahzxl <43881818+oahzxl@users.noreply.github.com>
Co-authored-by: binmakeswell <binmakeswell@gmail.com>
Co-authored-by: Fazzie-Maqianli <55798671+Fazziekey@users.noreply.github.com>
Co-authored-by: アマデウス <kurisusnowdeng@users.noreply.github.com>
320 lines
12 KiB
Python
import copy
from typing import Dict, List, Tuple

from torch.fx.node import Node

from .estimate_memory import EstimateMemory
from .reorder_graph import ReorderGraph
from .select_chunk import SelectChunk
from .trace_flow import TraceFlow
from .trace_indice import TraceIndice
from .utils import get_logger, get_node_shape, is_non_compute_node, is_non_compute_node_except_placeholder


class SearchChunk(object):
    """
    This is the core class for AutoChunk.

    It defines the framework of the AutoChunk strategy.
    Chunks will be selected one by one until the search stops.

    The chunk search is as follows:
    1. find the peak memory node
    2. find the max chunk region according to the peak memory node
    3. find all possible chunk regions in the max chunk region
    4. find the best chunk region for the current status
    5. goto 1

    Attributes:
        gm: graph model
        print_mem (bool): print estimated memory
        trace_indice: trace the flow of every dim of every node to find all free dims
        trace_flow: determine the region chunk strategy
        reorder_graph: reorder nodes to improve chunk efficiency
        estimate_memory: estimate memory with chunk
        select_chunk: select the best chunk region

    Args:
        gm: graph model
        max_memory (int): max memory in MB
        print_mem (bool): print estimated memory
    """

    def __init__(self, gm, max_memory=None, print_mem=False, print_progress=False) -> None:
        self.print_mem = print_mem
        self.print_progress = print_progress
        self.trace_indice = TraceIndice(list(gm.graph.nodes))
        self.estimate_memory = EstimateMemory()
        self._init_trace()
        self.trace_flow = TraceFlow(self.trace_indice)
        self.reorder_graph = ReorderGraph(self.trace_indice)
        self.select_chunk = SelectChunk(
            self.trace_indice,
            self.estimate_memory,
            self.reorder_graph,
            max_memory=max_memory,
        )

    def _init_trace(self) -> None:
        """
        Find the max trace range for every node
        to reduce the computation complexity of trace_indice.
        """
        # find all max ranges
        active_nodes = self.estimate_memory.get_active_nodes(self.trace_indice.node_list)
        cur_node_idx = len(self._get_free_var_idx())
        max_chunk_region_list = []
        while True:
            max_chunk_region = self._search_max_chunk_region(active_nodes, cur_node_idx)
            cur_node_idx = max_chunk_region[1]
            if cur_node_idx == len(active_nodes) - 1:
                break
            max_chunk_region_list.append(max_chunk_region)

        # nothing to limit for the first range
        max_chunk_region_list = max_chunk_region_list[1:]
        max_chunk_region_list[0] = (0, max_chunk_region_list[0][1])

        # set trace range and do the trace
        if self.print_progress:
            get_logger().info("AutoChunk start tracing indice")
        self.trace_indice.set_trace_range(max_chunk_region_list, active_nodes)
        self.trace_indice.trace_indice()

    def _find_peak_node(self, mem_peak: List) -> int:
        max_value = max(mem_peak)
        max_idx = mem_peak.index(max_value)
        return max_idx

    def _get_free_var_idx(self) -> List:
        """
        Get free var index

        Returns:
            free_var_idx (List): all indices of free vars
        """
        free_var_idx = []
        for idx, n in enumerate(self.trace_indice.node_list):
            if n.op == "placeholder" and get_node_shape(n) is not None:
                free_var_idx.append(idx)
        return free_var_idx

    def _search_max_chunk_region(self, active_node: List, peak_node_idx: int, chunk_regions: List = None) -> Tuple:
        """
        Search the max chunk region according to the peak memory node.

        The chunk region starts extending from the peak node and stops where the free var num is min.

        Args:
            active_node (List): active node status for every node
            peak_node_idx (int): peak memory node idx
            chunk_regions (List): chunk region infos

        Returns:
            chunk_region_start (int)
            chunk_region_end (int)
        """
        free_vars = self._get_free_var_idx()
        free_var_num = len(free_vars)
        active_node_num = [len(i) for i in active_node]
        min_active_node_num = min(active_node_num[free_var_num:])
        threshold = max(free_var_num, min_active_node_num)

        # from peak_node to free_var
        inside_flag = False
        chunk_region_start = free_var_num
        for i in range(peak_node_idx, -1, -1):
            if active_node_num[i] <= threshold:
                inside_flag = True
            if inside_flag and active_node_num[i] > threshold:
                chunk_region_start = i + 1
                break

        # from peak_node to len-2
        inside_flag = False
        chunk_region_end = len(active_node) - 1
        for i in range(peak_node_idx, len(active_node)):
            if active_node_num[i] <= threshold:
                inside_flag = True
            if inside_flag and active_node_num[i] > threshold:
                chunk_region_end = i
                break

        # avoid chunk regions overlap
        if chunk_regions is not None:
            for i in chunk_regions:
                region = i["region"]
                if chunk_region_start >= region[0] and chunk_region_end <= region[1]:
                    return None
                elif (region[0] <= chunk_region_start <= region[1] and chunk_region_end > region[1]):
                    chunk_region_start = region[1] + 1
                elif (region[0] <= chunk_region_end <= region[1] and chunk_region_start < region[0]):
                    chunk_region_end = region[0] - 1
        return chunk_region_start, chunk_region_end

    def _find_chunk_info(self, input_trace, output_trace, start_idx, end_idx) -> List:
        """
        Find chunk info for a region.

        We are given the region start and region end, and need to find out all chunk info for it.
        We first loop over every dim of the start node and the end node to see if we can find a dim pair
        which is linked in a flow and not computed.
        If found, we then search the flow in the whole region to find out all chunk infos.

        Args:
            input_trace (List): node's input trace in region
            output_trace (List): node's output trace in region
            start_idx (int): region start node index
            end_idx (int): region end node index

        Returns:
            chunk_infos: possible regions found
        """
        start_traces = input_trace[start_idx]
        end_trace = output_trace[end_idx]
        end_node = self.trace_indice.node_list[end_idx]
        chunk_infos = []
        for end_dim, _ in enumerate(end_trace["indice"]):
            if len(start_traces) > 1:
                continue
            for start_node, start_trace in start_traces.items():
                for start_dim, _ in enumerate(start_trace["indice"]):
                    # dim size cannot be 1
                    if (get_node_shape(end_node)[end_dim] == 1 or get_node_shape(start_node)[start_dim] == 1):
                        continue
                    # must have users
                    if len(end_node.users) == 0:
                        continue
                    # check index source align
                    if not self.trace_flow.check_index_source(start_dim, start_node, start_idx, end_dim, end_node):
                        continue
                    # check index compute
                    if not self.trace_flow.check_index_compute(start_idx, end_dim, end_node, end_idx):
                        continue
                    # flow search
                    chunk_info = self.trace_flow.flow_search(start_idx, start_dim, end_idx, end_dim)
                    if chunk_info is None:
                        continue
                    # check index duplicate
                    if not self.trace_flow.check_index_duplicate(chunk_info):
                        continue
                    chunk_infos.append(chunk_info)
        return chunk_infos

    def _search_possible_chunk_regions(self, max_chunk_region: Tuple, peak_node: Node) -> List:
        """
        Search every possible region within the max chunk region.

        Args:
            max_chunk_region (Tuple)
            peak_node (Node): peak memory node

        Returns:
            possible_chunk_region (List)
        """
        possible_chunk_region = []
        output_trace = copy.deepcopy(self.trace_indice.indice_trace_list)
        input_trace = []    # trace of a node's input nodes
        for _, n in enumerate(self.trace_indice.node_list):
            cur_trace = {}
            for arg in n.args:
                if type(arg) == type(n) and not is_non_compute_node_except_placeholder(arg):
                    cur_trace[arg] = self.trace_indice._find_trace_from_node(arg)
            input_trace.append(cur_trace)

        for start_idx in range(max_chunk_region[0], peak_node + 1):
            for end_idx in range(peak_node, max_chunk_region[1] + 1):
                # skip non compute nodes
                if is_non_compute_node(self.trace_indice.node_list[start_idx]) or is_non_compute_node(
                        self.trace_indice.node_list[end_idx]):
                    continue

                # select free dim
                chunk_info = self._find_chunk_info(input_trace, output_trace, start_idx, end_idx)
                if len(chunk_info) > 0:
                    possible_chunk_region.extend(chunk_info)
        return possible_chunk_region

    def _step_search(
        self,
        mem_peak: List[float],
        active_node: List[List[Node]],
        chunk_infos: List[Dict],
    ) -> Dict:
        """
        Find one chunk region

        The chunk search is as follows:
        1. find the peak memory node
        2. find the max chunk region according to the peak memory node
        3. find all possible chunk regions in the max chunk region
        4. find the best chunk region for the current status

        Args:
            mem_peak (List): peak memory for every node
            active_node (List[List[Node]]): active node for every node
            chunk_infos (List[Dict]): all chunk info

        Returns:
            best_chunk_region (Dict)
        """
        peak_node = self._find_peak_node(mem_peak)
        max_chunk_region = self._search_max_chunk_region(active_node, peak_node, chunk_infos)
        if max_chunk_region is None:
            return None
        possible_chunk_regions = self._search_possible_chunk_regions(max_chunk_region, peak_node)
        best_chunk_region = self.select_chunk._select_best_chunk_region(possible_chunk_regions, chunk_infos, peak_node,
                                                                        max_chunk_region, mem_peak)
        best_chunk_region = self.reorder_graph.reorder_all(best_chunk_region)
        return best_chunk_region

    def _stop_search(self, init_mem_peak, mem_peak):
        sorted_init_mem_peak = sorted(init_mem_peak)
        if max(mem_peak) < sorted_init_mem_peak[int(len(sorted_init_mem_peak) * 0.5)]:
            return True
        return False

    def search_region(self) -> Dict:
        """
        Search all chunk regions:
        1. Estimate current memory
        2. Find best chunk for current memory
        3. goto 1

        Returns:
            chunk_infos (Dict)
        """
        if self.print_progress:
            get_logger().info("AutoChunk start searching chunk regions")

        chunk_infos = []
        (
            init_mem_peak,
            _,
            active_node,
        ) = self.estimate_memory.estimate_chunk_inference_mem(self.trace_indice.node_list)
        mem_peak = init_mem_peak

        while True:
            chunk_info = self._step_search(mem_peak, active_node, chunk_infos)
            if chunk_info is None:
                break
            chunk_infos.append(chunk_info)

            (
                mem_peak,
                _,
                active_node,
            ) = self.estimate_memory.estimate_chunk_inference_mem(self.trace_indice.node_list, chunk_infos)

            if self.print_progress:
                get_logger().info("AutoChunk find chunk region %d = (%d, %d)" %
                                  (len(chunk_infos), chunk_info["region"][0], chunk_info["region"][1]))

            if self._stop_search(init_mem_peak, mem_peak):
                break
        if self.print_mem:
            self.print_mem = False
            self.estimate_memory.estimate_chunk_inference_mem(self.trace_indice.node_list, chunk_infos, print_mem=True)
        return chunk_infos
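

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the original module).
# Assumptions: `ToyModel`, the input shape, and the 2048 MB budget are made up
# for illustration; plain torch.fx symbolic tracing plus ShapeProp is assumed
# to provide enough shape metadata for the tracers, and a graph this small may
# simply yield no chunk regions (AutoChunk targets Evoformer-scale graphs, and
# its real entry point also wraps code generation). This only shows how
# SearchChunk itself is driven: construct it with a graph module and call
# search_region() to get the list of chunk infos.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import torch
    from torch.fx import symbolic_trace
    from torch.fx.passes.shape_prop import ShapeProp

    class ToyModel(torch.nn.Module):    # hypothetical stand-in for a large model

        def forward(self, x):
            attn = torch.softmax(x @ x.transpose(-1, -2), dim=-1)
            return attn @ x

    gm = symbolic_trace(ToyModel())
    # propagate tensor shapes into node.meta so node shapes are available
    ShapeProp(gm).propagate(torch.randn(1, 64, 64))

    # search chunk regions under a 2048 MB budget and print what was found
    chunk_infos = SearchChunk(gm, max_memory=2048, print_progress=True).search_region()
    for info in chunk_infos:
        print(info["region"])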