add autotune (#4822)

This commit is contained in:
Xu Kai
2023-09-28 13:47:35 +08:00
committed by GitHub
parent 822051d888
commit c3bef20478
2 changed files with 182 additions and 5 deletions

View File

@@ -3,7 +3,8 @@
import torch
import triton
import triton.language as tl
-from auto_gptq.nn_modules.triton_utils import custom_autotune
+from .custom_autotune import autotune, matmul248_kernel_config_pruner
@triton.jit
@@ -94,7 +95,7 @@ def silu(x):
return x * tl.sigmoid(x)
-@custom_autotune.autotune(
+@autotune(
configs=[
triton.Config(
{"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 8}, num_stages=4, num_warps=4
@@ -124,7 +125,7 @@ def silu(x):
key=["M", "N", "K"],
nearest_power_of_two=True,
prune_configs_by={
-"early_config_prune": custom_autotune.matmul248_kernel_config_pruner,
+"early_config_prune": matmul248_kernel_config_pruner,
"perf_model": None,
"top_k": None,
},
@@ -266,7 +267,7 @@ def cai_gptq_matmul_248_kernel(
tl.store(c_ptrs, accumulator, mask=c_mask)
-@custom_autotune.autotune(
+@autotune(
configs=[
triton.Config(
{"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 8}, num_stages=4, num_warps=4
@@ -296,7 +297,7 @@ def cai_gptq_matmul_248_kernel(
key=["M", "N", "K"],
nearest_power_of_two=True,
prune_configs_by={
-"early_config_prune": custom_autotune.matmul248_kernel_config_pruner,
+"early_config_prune": matmul248_kernel_config_pruner,
"perf_model": None,
"top_k": None,
},