mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-08 20:40:34 +00:00
add colossalai kernel module (#55)
This commit is contained in:
24
colossalai/kernel/jit/bias_dropout_add.py
Normal file
24
colossalai/kernel/jit/bias_dropout_add.py
Normal file
@@ -0,0 +1,24 @@
|
||||
import torch
|
||||
|
||||
|
||||
def bias_dropout_add(x, bias, residual, prob, training):
|
||||
# type: (Tensor, Tensor, Tensor, float, bool) -> Tensor
|
||||
out = torch.nn.functional.dropout(x + bias, p=prob, training=training)
|
||||
out = residual + out
|
||||
return out
|
||||
|
||||
|
||||
@torch.jit.script
|
||||
def bias_dropout_add_fused_train(x: torch.Tensor,
|
||||
bias: torch.Tensor,
|
||||
residual: torch.Tensor,
|
||||
prob: float) -> torch.Tensor:
|
||||
return bias_dropout_add(x, bias, residual, prob, True)
|
||||
|
||||
|
||||
@torch.jit.script
|
||||
def bias_dropout_add_fused_inference(x: torch.Tensor,
|
||||
bias: torch.Tensor,
|
||||
residual: torch.Tensor,
|
||||
prob: float) -> torch.Tensor:
|
||||
return bias_dropout_add(x, bias, residual, prob, False)
|
Reference in New Issue
Block a user