mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-02 01:28:31 +00:00
[misc] update pre-commit and run all files (#4752)
* [misc] update pre-commit * [misc] run pre-commit * [misc] remove useless configuration files * [misc] ignore cuda for clang-format
This commit is contained in:
@@ -4,7 +4,7 @@ import torch
|
||||
|
||||
from colossalai._analyzer._subclasses.flop_tensor import flop_mapping
|
||||
from colossalai._analyzer.fx.node_util import compute_size_in_bytes as activation_size
|
||||
from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, OperationDataType, TrainCycleItem
|
||||
from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, TrainCycleItem
|
||||
|
||||
from ..registry import meta_register
|
||||
|
||||
@@ -39,16 +39,21 @@ def where_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, Li
|
||||
# gradient matrix for input x and input y, remove the temp memory and condition tensor generated in forward phase
|
||||
# NOTE: currently in SPMD solver we always believe that there will be a new input tensor created in forward
|
||||
fwd_mem_cost = MemoryCost(activation=activation_size([condition_tensor, x_tensor, y_tensor, output_tensor]))
|
||||
bwd_mem_cost = MemoryCost(activation=activation_size([x_tensor, y_tensor]) - activation_size([condition_tensor]),
|
||||
parameter=0,
|
||||
temp=activation_size([output_tensor]) * 3 + activation_size([condition_tensor]) -
|
||||
activation_size([x_tensor, y_tensor]),
|
||||
buffer=0)
|
||||
bwd_mem_cost = MemoryCost(
|
||||
activation=activation_size([x_tensor, y_tensor]) - activation_size([condition_tensor]),
|
||||
parameter=0,
|
||||
temp=activation_size([output_tensor]) * 3
|
||||
+ activation_size([condition_tensor])
|
||||
- activation_size([x_tensor, y_tensor]),
|
||||
buffer=0,
|
||||
)
|
||||
|
||||
total_mem_cost = MemoryCost(activation=fwd_mem_cost.activation + bwd_mem_cost.activation,
|
||||
parameter=fwd_mem_cost.parameter + bwd_mem_cost.parameter,
|
||||
temp=fwd_mem_cost.temp + bwd_mem_cost.temp,
|
||||
buffer=fwd_mem_cost.buffer + bwd_mem_cost.buffer)
|
||||
total_mem_cost = MemoryCost(
|
||||
activation=fwd_mem_cost.activation + bwd_mem_cost.activation,
|
||||
parameter=fwd_mem_cost.parameter + bwd_mem_cost.parameter,
|
||||
temp=fwd_mem_cost.temp + bwd_mem_cost.temp,
|
||||
buffer=fwd_mem_cost.buffer + bwd_mem_cost.buffer,
|
||||
)
|
||||
|
||||
memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost)
|
||||
|
||||
|
Reference in New Issue
Block a user