mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-15 22:19:38 +00:00
add autotune (#4822)
This commit is contained in:
@@ -3,7 +3,8 @@
|
||||
import torch
|
||||
import triton
|
||||
import triton.language as tl
|
||||
from auto_gptq.nn_modules.triton_utils import custom_autotune
|
||||
|
||||
from .custom_autotune import autotune, matmul248_kernel_config_pruner
|
||||
|
||||
|
||||
@triton.jit
|
||||
@@ -94,7 +95,7 @@ def silu(x):
|
||||
return x * tl.sigmoid(x)
|
||||
|
||||
|
||||
@custom_autotune.autotune(
|
||||
@autotune(
|
||||
configs=[
|
||||
triton.Config(
|
||||
{"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 8}, num_stages=4, num_warps=4
|
||||
@@ -124,7 +125,7 @@ def silu(x):
|
||||
key=["M", "N", "K"],
|
||||
nearest_power_of_two=True,
|
||||
prune_configs_by={
|
||||
"early_config_prune": custom_autotune.matmul248_kernel_config_pruner,
|
||||
"early_config_prune": matmul248_kernel_config_pruner,
|
||||
"perf_model": None,
|
||||
"top_k": None,
|
||||
},
|
||||
@@ -266,7 +267,7 @@ def cai_gptq_matmul_248_kernel(
|
||||
tl.store(c_ptrs, accumulator, mask=c_mask)
|
||||
|
||||
|
||||
@custom_autotune.autotune(
|
||||
@autotune(
|
||||
configs=[
|
||||
triton.Config(
|
||||
{"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 8}, num_stages=4, num_warps=4
|
||||
@@ -296,7 +297,7 @@ def cai_gptq_matmul_248_kernel(
|
||||
key=["M", "N", "K"],
|
||||
nearest_power_of_two=True,
|
||||
prune_configs_by={
|
||||
"early_config_prune": custom_autotune.matmul248_kernel_config_pruner,
|
||||
"early_config_prune": matmul248_kernel_config_pruner,
|
||||
"perf_model": None,
|
||||
"top_k": None,
|
||||
},
|
||||
|
Reference in New Issue
Block a user