Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-06 11:32:10 +00:00)
[devops] remove post commit ci (#5566)
* [devops] remove post commit ci
* [misc] run pre-commit on all files
* [pre-commit.ci] auto fixes from pre-commit.com hooks

  for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
@@ -8,6 +8,7 @@ from torch.cuda.amp import custom_bwd, custom_fwd
 try:
     import triton
     import triton.language as tl
+
     HAS_TRITON = True
 except ImportError:
     HAS_TRITON = False
@@ -26,8 +27,8 @@ if HAS_TRITON:
         X_GATE2,
         X_UP,
         Y,
-        stride, # how much to increase the pointer when moving by 1 row
-        N, # number of columns in X
+        stride,  # how much to increase the pointer when moving by 1 row
+        N,  # number of columns in X
         BLOCK_SIZE: tl.constexpr,
     ):
         # Map the program id to the row of X and Y it should compute.
@@ -41,9 +42,9 @@ if HAS_TRITON:
         for off in range(0, N, BLOCK_SIZE):
             cols = off + tl.arange(0, BLOCK_SIZE)
             mask = cols < N
-            x_gate1 = tl.load(X_GATE1 + cols, mask=mask, other=0.)
-            x_gate2 = tl.load(X_GATE2 + cols, mask=mask, other=0.)
-            x_up = tl.load(X_UP + cols, mask=mask, other=0.)
+            x_gate1 = tl.load(X_GATE1 + cols, mask=mask, other=0.0)
+            x_gate2 = tl.load(X_GATE2 + cols, mask=mask, other=0.0)
+            x_up = tl.load(X_UP + cols, mask=mask, other=0.0)
             x_gate2_sigmoid = tl.sigmoid(x_gate2.to(tl.float32)).to(x_gate2.dtype)
             y = x_gate1 * x_gate2 * x_gate2_sigmoid * x_up
             # Write output
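For readers skimming the diff: the loop above is the fused activation itself. Each element computes y = x_gate1 * x_gate2 * sigmoid(x_gate2) * x_up (i.e. x_gate1 * silu(x_gate2) * x_up), with the sigmoid evaluated in float32 and cast back to the input dtype. A minimal eager-PyTorch sketch of the same math, assuming only what the kernel body shows (the helper name is ours, not repo code):

import torch

def llama_act_combine_ref(x_gate1, x_gate2, x_up):
    # Mirror of the kernel's inner loop: sigmoid in float32 for
    # accuracy, then cast back to the input dtype.
    x_gate2_sigmoid = torch.sigmoid(x_gate2.float()).to(x_gate2.dtype)
    return x_gate1 * x_gate2 * x_gate2_sigmoid * x_up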
@@ -58,8 +59,8 @@ if HAS_TRITON:
         X_GATE2_GRAD,
         X_UP_GRAD,
         Y_GRAD,
-        stride, # how much to increase the pointer when moving by 1 row
-        N, # number of columns in X
+        stride,  # how much to increase the pointer when moving by 1 row
+        N,  # number of columns in X
         BLOCK_SIZE: tl.constexpr,
     ):
         # Map the program id to the row of X and Y it should compute.
@@ -76,10 +77,10 @@ if HAS_TRITON:
         for off in range(0, N, BLOCK_SIZE):
            cols = off + tl.arange(0, BLOCK_SIZE)
            mask = cols < N
-            x_gate1 = tl.load(X_GATE1 + cols, mask=mask, other=0.)
-            x_gate2 = tl.load(X_GATE2 + cols, mask=mask, other=0.)
-            x_up = tl.load(X_UP + cols, mask=mask, other=0.)
-            y_grad = tl.load(Y_GRAD + cols, mask=mask, other=0.)
+            x_gate1 = tl.load(X_GATE1 + cols, mask=mask, other=0.0)
+            x_gate2 = tl.load(X_GATE2 + cols, mask=mask, other=0.0)
+            x_up = tl.load(X_UP + cols, mask=mask, other=0.0)
+            y_grad = tl.load(Y_GRAD + cols, mask=mask, other=0.0)

             # forward: y = x_gate1 * x_gate2 * tl.sigmoid(x_gate2) * x_up
             x_gate2_sigmoid = tl.sigmoid(x_gate2.to(tl.float32)).to(x_gate2.dtype)
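The backward kernel loads the same operands plus y_grad and, as the "# forward:" comment suggests, presumably applies the chain rule to y = g1 * g2 * sigmoid(g2) * u: dy/dg1 = g2 * s * u, dy/du = g1 * g2 * s, and dy/dg2 = g1 * u * s * (1 + g2 * (1 - s)), where s = sigmoid(g2). A hand-derived eager sketch of those formulas (our naming, for cross-checking only, not the repo's kernel):

import torch

def llama_act_combine_backward_ref(x_gate1, x_gate2, x_up, y_grad):
    # Hand-derived gradients of y = g1 * g2 * sigmoid(g2) * u.
    s = torch.sigmoid(x_gate2.float()).to(x_gate2.dtype)
    x_gate1_grad = y_grad * x_gate2 * s * x_up
    # d/dg2 [g2 * sigmoid(g2)] = s * (1 + g2 * (1 - s))
    x_gate2_grad = y_grad * x_gate1 * x_up * s * (1 + x_gate2 * (1 - s))
    x_up_grad = y_grad * x_gate1 * x_gate2 * s
    return x_gate1_grad, x_gate2_grad, x_up_grad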
@@ -147,14 +148,9 @@ if HAS_TRITON:
            # restore setting
            ctx.M, ctx.N, ctx.BLOCK_SIZE, ctx.num_warps = M, N, BLOCK_SIZE, num_warps
            # enqueue kernel
-            _llama_act_combine_forward[(M,)](x_gate1,
-                                             x_gate2,
-                                             x_up,
-                                             y,
-                                             x_up.stride(-2),
-                                             N,
-                                             BLOCK_SIZE=BLOCK_SIZE,
-                                             num_warps=num_warps)
+            _llama_act_combine_forward[(M,)](
+                x_gate1, x_gate2, x_up, y, x_up.stride(-2), N, BLOCK_SIZE=BLOCK_SIZE, num_warps=num_warps
+            )
            return y

        @staticmethod
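Both the old and new call sites launch with grid (M,), i.e. one Triton program per row; per the kernel's own comment, tl.program_id maps each program to its row, with base pointers advanced by stride (here x_up.stride(-2)) elements per row. A standalone toy kernel showing this row-per-program pattern (hypothetical illustration, not repo code):

import torch
import triton
import triton.language as tl

@triton.jit
def _rowwise_scale(X, Y, scale, stride, N, BLOCK_SIZE: tl.constexpr):
    # grid=(M,): each program instance handles one row.
    row = tl.program_id(0)
    X += row * stride
    Y += row * stride
    for off in range(0, N, BLOCK_SIZE):
        cols = off + tl.arange(0, BLOCK_SIZE)
        mask = cols < N
        x = tl.load(X + cols, mask=mask, other=0.0)
        tl.store(Y + cols, x * scale, mask=mask)

x = torch.randn(16, 128, device="cuda")
y = torch.empty_like(x)
_rowwise_scale[(x.shape[0],)](x, y, 2.0, x.stride(-2), x.shape[1], BLOCK_SIZE=64)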
@@ -166,20 +162,25 @@ if HAS_TRITON:

            # init grad
            y_grad = grad_outputs[0]
-            x_gate1_grad, x_gate2_grad, x_up_grad = torch.empty_like(x_gate1), torch.empty_like(
-                x_gate2), torch.empty_like(x_up)
+            x_gate1_grad, x_gate2_grad, x_up_grad = (
+                torch.empty_like(x_gate1),
+                torch.empty_like(x_gate2),
+                torch.empty_like(x_up),
+            )

            # enqueue kernel
-            _llama_act_combine_backward[(M,)](x_gate1,
-                                              x_gate2,
-                                              x_up,
-                                              x_gate1_grad,
-                                              x_gate2_grad,
-                                              x_up_grad,
-                                              y_grad,
-                                              x_up.stride(-2),
-                                              N,
-                                              BLOCK_SIZE=BLOCK_SIZE,
-                                              num_warps=num_warps)
+            _llama_act_combine_backward[(M,)](
+                x_gate1,
+                x_gate2,
+                x_up,
+                x_gate1_grad,
+                x_gate2_grad,
+                x_up_grad,
+                y_grad,
+                x_up.stride(-2),
+                N,
+                BLOCK_SIZE=BLOCK_SIZE,
+                num_warps=num_warps,
+            )
            x_gate_grad = torch.cat([x_gate1_grad, x_gate2_grad], dim=-1)
            return x_gate_grad, x_up_grad, None, None
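Since the refactor is formatting-only, behavior can be sanity-checked by comparing the hand-derived backward against autograd using the two eager helpers sketched above (illustrative only; this never touches the Triton path):

import torch

torch.manual_seed(0)
g1 = torch.randn(4, 64, requires_grad=True)
g2 = torch.randn(4, 64, requires_grad=True)
up = torch.randn(4, 64, requires_grad=True)

y = llama_act_combine_ref(g1, g2, up)
y.backward(torch.ones_like(y))

hand = llama_act_combine_backward_ref(g1.detach(), g2.detach(), up.detach(), torch.ones_like(y))
for auto, manual in zip((g1.grad, g2.grad, up.grad), hand):
    assert torch.allclose(auto, manual, atol=1e-6)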