mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-05 02:51:59 +00:00)
[autoparallel] align the data_ptr with the old version of auto activation checkpoint pipeline (#2261)
@@ -8,9 +8,9 @@ from torch.fx import GraphModule
 from torch.fx.node import Node
 
 from colossalai.auto_parallel.meta_profiler import MetaInfo
+from colossalai.auto_parallel.passes.constants import OUTPUT_SAVED_MOD, OUTPUT_SAVED_OPS
 from colossalai.fx._compatibility import compatibility
 from colossalai.fx.profiler import GraphInfo
-from colossalai.fx.profiler.constants import OUTPUT_SAVED_MOD, OUTPUT_SAVED_OPS
 
 
 def _normalize_tuple(x):
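The only change in this hunk is the import source for OUTPUT_SAVED_MOD and OUTPUT_SAVED_OPS: they now come from colossalai.auto_parallel.passes.constants instead of colossalai.fx.profiler.constants. For orientation, a minimal sketch of what such a constants module holds; the exact membership below is an assumption for illustration, not the verbatim file:

# Illustrative sketch of the kind of constants module the new import points at:
# modules/functions whose *output* (rather than input) is saved for backward.
# The exact membership here is an assumption, not the verbatim file contents.
import torch

OUTPUT_SAVED_OPS = [
    torch.nn.functional.relu,
    torch.nn.functional.softmax,
]

OUTPUT_SAVED_MOD = [
    torch.nn.ReLU,
    torch.nn.Softmax,
]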
@@ -46,7 +46,7 @@ class MetaInfoProp:
         """
         Check if the node is inplace operation.
         """
-        if node.op == 'call_method':
+        if node.op == 'call_module':
             return node.graph.owning_module.get_submodule(node.target).__class__ in OUTPUT_SAVED_MOD
         elif node.op == "call_function":
             return node.target in OUTPUT_SAVED_OPS
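The fix swaps 'call_method' for 'call_module': a call_module node's target is a submodule path string that get_submodule can resolve, whereas a call_method target is a plain method name, so the old branch could never match a module class. A self-contained sketch of how the corrected dispatch behaves on a traced module; is_output_saved is a hypothetical free-function mirror of MetaInfoProp._is_inplace, and the OUTPUT_SAVED_* lists are illustrative:

# Standalone sketch of the node.op dispatch fixed above.
import torch
import torch.nn.functional as F
from torch.fx import symbolic_trace
from torch.fx.node import Node

OUTPUT_SAVED_OPS = [F.relu, F.softmax]    # illustrative membership
OUTPUT_SAVED_MOD = [torch.nn.ReLU]        # illustrative membership


def is_output_saved(node: Node) -> bool:
    if node.op == 'call_module':
        # a call_module target is a submodule path string ('act' below),
        # so it must be resolved to the module instance before the class check
        return node.graph.owning_module.get_submodule(node.target).__class__ in OUTPUT_SAVED_MOD
    elif node.op == 'call_function':
        # a call_function target is the callable itself
        return node.target in OUTPUT_SAVED_OPS
    return False


class Tiny(torch.nn.Module):

    def __init__(self):
        super().__init__()
        self.act = torch.nn.ReLU()

    def forward(self, x):
        return F.relu(self.act(x))


gm = symbolic_trace(Tiny())
for node in gm.graph.nodes:
    print(node.op, node.target, is_output_saved(node))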
@@ -102,56 +102,51 @@ class MetaInfoProp:
         meta_info: MetaInfo
 
         # set data_ptr for input_tensor in MetaInfo class
-        input_tensor: List[torch.Tensor] = meta_info.fwd_in
-        buffer_tensor: List[torch.Tensor] = meta_info.fwd_buffer
-        output_tensor: List[torch.Tensor] = meta_info.fwd_out
+        input_tensors: List[torch.Tensor] = meta_info.fwd_in
+        buffer_tensors: List[torch.Tensor] = meta_info.fwd_buffer
+        output_tensors: List[torch.Tensor] = meta_info.fwd_out
 
-        if len(input_tensor) > 0:
+        if self._is_inplace(node):
+            # inplace operation will not create new tensor, and it only has one parent node
+            # TODO: Verify this observation
+            # set data_ptr for input_tensor, buffer_tensor and output_tensor of current node
+            parent_node = list(node._input_nodes.keys())[0]
+            parent_tensor = parent_node.meta.get("fwd_out")[0]
+            parent_tensor: torch.Tensor
+            for tensor in input_tensors:
+                tensor.data_ptr = parent_tensor.data_ptr
+            for tensor in buffer_tensors:
+                tensor.data_ptr = parent_tensor.data_ptr
+            for tensor in output_tensors:
+                tensor.data_ptr = parent_tensor.data_ptr
+
+        else:
             for par in node._input_nodes:
-                if par.meta:
-                    if len(par.meta["fwd_out"]) > 0:
-                        # set data_ptr for the input_tensor of current node from the output_tensor of its parent node
-                        for tensor in par.meta["fwd_out"]:
-                            tensor: torch.Tensor
-                            target_tensor = next(
-                                (x for x in input_tensor if not x.data_ptr() and x.shape == tensor.shape), None)
-                            target_tensor.data_ptr = tensor.data_ptr
+                # set data_ptr for the input_tensor of current node from the output_tensor of its parent node
+                for tensor in par.meta.get("fwd_out", []):
+                    tensor: torch.Tensor
+                    target_input_tensor = next(
+                        (x for x in input_tensors if not x.data_ptr() and x.shape == tensor.shape), None)
+                    if target_input_tensor is not None:
+                        target_input_tensor.data_ptr = tensor.data_ptr
 
             # set data_ptr for tensor in input_tensor that is not set
-            for tensor in input_tensor:
+            for tensor in input_tensors:
                 if not tensor.data_ptr():
                     self._set_data_ptr(tensor)
 
-            # attach it to graph_info
-            graph_info.fwd_in = input_tensor
-
-        if self._is_inplace(node):
-            # inplace operation will not create new tensor
-            # set data_ptr for buffer_tensor and output_tensor of current node
-            for tensor in input_tensor:
-                tensor: torch.Tensor
-                target_buffer_tensor = next((x for x in buffer_tensor if not x.data_ptr() and x.shape == tensor.shape),
-                                            None)
-                target_output_tensor = next((x for x in output_tensor if not x.data_ptr() and x.shape == tensor.shape),
-                                            None)
-                target_buffer_tensor.data_ptr = tensor.data_ptr
-                target_output_tensor.data_ptr = tensor.data_ptr
-            # attach them to graph_info
-            graph_info.fwd_tmp = buffer_tensor
-            graph_info.fwd_out = output_tensor
-
-        else:
             # set data_ptr for buffer_tensor
-            for tensor in buffer_tensor:
+            for tensor in buffer_tensors:
                 self._set_data_ptr(tensor)
-            # attach it to graph_info
-            graph_info.fwd_tmp = buffer_tensor
 
             # set data_ptr for output_tensor
-            for tensor in output_tensor:
+            for tensor in output_tensors:
                 self._set_data_ptr(tensor)
-            # attach it to graph_info
-            graph_info.fwd_out = output_tensor
 
+        # attach them to graph_info
+        graph_info.fwd_in = input_tensors
+        graph_info.fwd_tmp = buffer_tensors
+        graph_info.fwd_out = output_tensors
+
         # fetch other memory informations
         memory_cost = meta_info.memory_cost
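The whole hunk rests on one convention: meta tensors report data_ptr() == 0, which the pass reads as "no address assigned yet", and aliasing is recorded by shadowing the data_ptr method on the tensor instance. A minimal sketch of both branches; the uuid-based set_data_ptr below is a stand-in for MetaInfoProp._set_data_ptr, which is outside this hunk, so its body is an assumption:

# Standalone sketch of the data_ptr bookkeeping used above.
import uuid

import torch


def set_data_ptr(x: torch.Tensor):
    # give a fake tensor a fresh unique "address" if it has none yet
    # (assumed stand-in for the real _set_data_ptr helper)
    if not x.data_ptr():
        fake_ptr = uuid.uuid4().int
        x.data_ptr = lambda: fake_ptr


parent_out = torch.empty(4, 8, device='meta')
set_data_ptr(parent_out)    # the parent's output gets a fake address

# inplace branch: inputs/buffers/outputs all alias the parent's storage,
# mirroring `tensor.data_ptr = parent_tensor.data_ptr` in the diff
inplace_in = torch.empty(4, 8, device='meta')
inplace_out = torch.empty(4, 8, device='meta')
for t in (inplace_in, inplace_out):
    t.data_ptr = parent_out.data_ptr
assert inplace_out.data_ptr() == parent_out.data_ptr()

# non-inplace branch: match each parent output to the first still-unset input
# of the same shape, exactly as the next(...) generator expression does
inputs = [torch.empty(4, 8, device='meta'), torch.empty(2, 2, device='meta')]
target_input_tensor = next((x for x in inputs if not x.data_ptr() and x.shape == parent_out.shape), None)
if target_input_tensor is not None:    # the guard this commit adds
    target_input_tensor.data_ptr = parent_out.data_ptr

# anything still unset gets its own fresh address
for t in inputs:
    set_data_ptr(t)

assert inputs[0].data_ptr() == parent_out.data_ptr()
assert inputs[1].data_ptr() not in (0, parent_out.data_ptr())

The `is not None` guard matters because a parent output can legitimately find no unset, shape-matched input, in which case the old code's unconditional attribute assignment would have raised AttributeError on None.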