mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-03 01:55:12 +00:00
[autoparallel] Patch meta information of torch.where
(#2822)
* [autoparallel] patch meta information of torch.where * [autoparallel] pre-commit modified
This commit is contained in:
@@ -6,3 +6,4 @@ from .linear import *
|
||||
from .norm import *
|
||||
from .pooling import *
|
||||
from .tensor import *
|
||||
from .where import *
|
||||
|
@@ -0,0 +1,60 @@
|
||||
from typing import List, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, OperationDataType, TrainCycleItem
|
||||
from colossalai.fx.profiler.memory_utils import activation_size
|
||||
from colossalai.fx.profiler.opcount import flop_mapping
|
||||
|
||||
from ..registry import meta_register
|
||||
|
||||
__all__ = ["where_meta_info"]
|
||||
|
||||
|
||||
@meta_register.register(torch.where)
def where_meta_info(
        *args, **kwargs
) -> Tuple[TrainCycleItem, TrainCycleItem, List[torch.Tensor], List[torch.Tensor], List[torch.Tensor]]:
    """torch.where meta information generator.

    Expects four operation-data arguments in ``args`` whose ``.data`` attributes are
    meta tensors, in order: condition, x, y, output.

    Returns:
        Tuple[TrainCycleItem, TrainCycleItem, List[torch.Tensor], List[torch.Tensor], List[torch.Tensor]]:
            compute cost, memory cost, forward inputs, forward buffers and forward outputs.
            (The original annotation declared a 3-tuple, but five values are returned.)
    """
    condition_tensor, x_tensor, y_tensor, output_tensor = [arg.data for arg in args]

    # compute cost
    # forward is a pure element-wise select, counted as free
    fwd_compute_cost = 0

    # if x or y was broadcast to the output shape, backward must reduce_sum the
    # gradient back to the original input shape
    bwd_compute_cost = 0
    if x_tensor.shape != output_tensor.shape:
        bwd_compute_cost += flop_mapping[torch.ops.aten.sum.dim_IntList]([output_tensor], [x_tensor])
    if y_tensor.shape != output_tensor.shape:
        bwd_compute_cost += flop_mapping[torch.ops.aten.sum.dim_IntList]([output_tensor], [y_tensor])

    compute_cost = TrainCycleItem(fwd=fwd_compute_cost, bwd=bwd_compute_cost, total=fwd_compute_cost + bwd_compute_cost)

    # memory cost
    # during the forward phase, torch.where will allocate memory for output tensor and condition tensor
    # during the backward phase, torch.where will allocate temp memory which is 3 times as output tensor, then generate
    # gradient matrix for input x and input y, remove the temp memory and condition tensor generated in forward phase
    # NOTE: currently in SPMD solver we always believe that there will be a new input tensor created in forward
    fwd_mem_cost = MemoryCost(activation=activation_size([condition_tensor, x_tensor, y_tensor, output_tensor]))
    bwd_mem_cost = MemoryCost(activation=activation_size([x_tensor, y_tensor]) - activation_size([condition_tensor]),
                              parameter=0,
                              temp=activation_size([output_tensor]) * 3 + activation_size([condition_tensor]) -
                              activation_size([x_tensor, y_tensor]),
                              buffer=0)

    total_mem_cost = MemoryCost(activation=fwd_mem_cost.activation + bwd_mem_cost.activation,
                                parameter=fwd_mem_cost.parameter + bwd_mem_cost.parameter,
                                temp=fwd_mem_cost.temp + bwd_mem_cost.temp,
                                buffer=fwd_mem_cost.buffer + bwd_mem_cost.buffer)

    memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost)

    # store fwd_in, fwd_buffer, fwd_out
    # only the condition is needed again in backward; x/y gradients are shape-derived
    fwd_in = [condition_tensor]
    fwd_buffer = []
    fwd_out = [output_tensor]

    return compute_cost, memory_cost, fwd_in, fwd_buffer, fwd_out
|
Reference in New Issue
Block a user