This commit is contained in:
oahzxl
2023-01-10 09:59:47 +08:00
parent 1be0ac3cbf
commit 7d4abaa525
3 changed files with 113 additions and 16 deletions

View File

@@ -15,6 +15,10 @@ from .utils import (
class EstimateMemory(object):
"""
Estimate memory with chunk
"""
def __init__(self) -> None:
pass
@@ -31,8 +35,6 @@ class EstimateMemory(object):
}
out_size = activation_size(fwd_out)
out_node = [n.name] if out_size > 0 else []
# if any(i in n.name for i in ['transpose', 'permute', 'view']):
# out_size = 0
return out_size, out_node
def _get_output_node_size(self, n):
@@ -184,10 +186,24 @@ class EstimateMemory(object):
def estimate_chunk_inference_mem(
self,
node_list,
node_list: List,
chunk_infos=None,
print_mem=False,
):
"""
Estimate inference memory with chunk
Args:
node_list (List): _description_
chunk_infos (Dict): Chunk information. Defaults to None.
print_mem (bool): Wether to print peak memory of every node. Defaults to False.
Returns:
act_memory_peak_log (List): peak memory of every node
act_memory_after_node_log (List): memory after excuting every node
active_node_list_log (List): active nodes of every node. active nodes refer to
nodes generated but not deleted.
"""
act_memory = 0.0
act_memory_peak_log = []
act_memory_after_node_log = []