mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-16 14:41:53 +00:00
[misc] update pre-commit and run all files (#4752)
* [misc] update pre-commit * [misc] run pre-commit * [misc] remove useless configuration files * [misc] ignore cuda for clang-format
This commit is contained in:
@@ -1,12 +1,15 @@
|
||||
from functools import reduce
|
||||
import operator
|
||||
from functools import reduce
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
from ..registry import meta_profiler_module
|
||||
from typing import Optional, Tuple, Union
|
||||
|
||||
|
||||
def _rnn_flops(flops: int, macs: int, module: torch.nn.RNNBase, w_ih: torch.Tensor,
|
||||
w_hh: torch.Tensor) -> Tuple[int, int]:
|
||||
def _rnn_flops(
|
||||
flops: int, macs: int, module: torch.nn.RNNBase, w_ih: torch.Tensor, w_hh: torch.Tensor
|
||||
) -> Tuple[int, int]:
|
||||
# copied from https://github.com/sovrasov/flops-counter.pytorch/blob/master/ptflops/pytorch_ops.py
|
||||
|
||||
# matrix matrix mult ih state and internal state
|
||||
@@ -42,12 +45,12 @@ def torch_nn_rnn(self: torch.nn.RNNBase, input: torch.Tensor, hx: Optional[torch
|
||||
flops = 0
|
||||
macs = 0
|
||||
for i in range(self.num_layers):
|
||||
w_ih = self.__getattr__('weight_ih_l' + str(i))
|
||||
w_hh = self.__getattr__('weight_hh_l' + str(i))
|
||||
w_ih = self.__getattr__("weight_ih_l" + str(i))
|
||||
w_hh = self.__getattr__("weight_hh_l" + str(i))
|
||||
flops, macs = _rnn_flops(flops, macs, self, w_ih, w_hh)
|
||||
if self.bias:
|
||||
b_ih = self.__getattr__('bias_ih_l' + str(i))
|
||||
b_hh = self.__getattr__('bias_hh_l' + str(i))
|
||||
b_ih = self.__getattr__("bias_ih_l" + str(i))
|
||||
b_hh = self.__getattr__("bias_hh_l" + str(i))
|
||||
flops += reduce(operator.mul, b_ih) + reduce(operator.mul, b_hh)
|
||||
flops *= reduce(operator.mul, input.shape[:2])
|
||||
macs *= reduce(operator.mul, input.shape[:2])
|
||||
@@ -63,12 +66,12 @@ def torch_nn_rnn(self: torch.nn.RNNBase, input: torch.Tensor, hx: Optional[torch
|
||||
def torch_nn_rnn(self: torch.nn.RNNCellBase, input: torch.Tensor, hx: Optional[torch.Tensor] = None) -> Tuple[int, int]:
|
||||
flops = 0
|
||||
macs = 0
|
||||
w_ih = self.__getattr__('weight_ih_l')
|
||||
w_hh = self.__getattr__('weight_hh_l')
|
||||
w_ih = self.__getattr__("weight_ih_l")
|
||||
w_hh = self.__getattr__("weight_hh_l")
|
||||
flops, macs = _rnn_flops(flops, macs, self, w_ih, w_hh)
|
||||
if self.bias:
|
||||
b_ih = self.__getattr__('bias_ih_l')
|
||||
b_hh = self.__getattr__('bias_hh_l')
|
||||
b_ih = self.__getattr__("bias_ih_l")
|
||||
b_hh = self.__getattr__("bias_hh_l")
|
||||
flops += reduce(operator.mul, b_ih) + reduce(operator.mul, b_hh)
|
||||
flops *= input.shape[0]
|
||||
macs *= input.shape[0]
|
||||
|
Reference in New Issue
Block a user