[inference] Optimize the usage of the mid tensors space in flash attn (#5304)

* opt flash attn

* opt tmp tensor

* fix benchmark_llama

* fix code style

* fix None logic for output tensor

* fix adaptation to get_xine_cache

* add comment

* fix ci bugs

* fix some code

* rm duplicated code

* rm duplicated code

* fix code style

* add _get_dtype in config.py
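
The bullets above describe two ideas: reuse a preallocated intermediate ("mid") output buffer across flash-attention calls instead of allocating a temporary output tensor on every decoding step, and resolve the torch dtype once through a _get_dtype helper on the inference config. The sketch below is a minimal illustration of both under plain PyTorch; MidTensorPool, attention_into and the dtype table are hypothetical names for illustration, not the actual ColossalAI code.

import torch
import torch.nn.functional as F


def _get_dtype(name: str) -> torch.dtype:
    # Hypothetical stand-in for the "_get_dtype in config.py" bullet:
    # map a config string to a torch dtype, falling back to float16.
    return {"fp16": torch.float16, "bf16": torch.bfloat16, "fp32": torch.float32}.get(name, torch.float16)


class MidTensorPool:
    # Preallocate the attention output buffer once; every decoding step
    # reuses the same storage instead of creating a new temporary tensor.
    def __init__(self, max_tokens, num_heads, head_dim, dtype, device="cpu"):
        self._out = torch.empty(max_tokens, num_heads * head_dim, dtype=dtype, device=device)

    def output(self, num_tokens):
        # Hand back a view over the preallocated buffer.
        return self._out[:num_tokens]


def attention_into(q, k, v, out):
    # Compute attention and write the result into the caller-provided buffer,
    # so no fresh output tensor is created per call.
    attn = F.scaled_dot_product_attention(q, k, v)       # (B, H, S, D)
    out.copy_(attn.transpose(1, 2).reshape(out.shape))   # flattened to (B*S, H*D)
    return out


# Example: one pool shared across steps; q/k/v shapes are (batch, heads, seq, head_dim).
pool = MidTensorPool(max_tokens=4096, num_heads=4, head_dim=64, dtype=_get_dtype("fp32"))
q = k = v = torch.randn(2, 4, 8, 64)
out = attention_into(q, k, v, pool.output(2 * 8))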
Author: yuehuayingxueluo
Date: 2024-01-26 14:00:10 +08:00
Committed by: GitHub
Parent: af8359c430
Commit: 4f28cb43c0
16 changed files with 199 additions and 57 deletions


@@ -3,8 +3,7 @@ import random
 import numpy as np
 import pytest
 import torch
-import transformers
-from transformers import AutoTokenizer, GenerationConfig
+from transformers import AutoTokenizer, GenerationConfig, LlamaConfig, LlamaForCausalLM
 
 import colossalai
 from colossalai.inference.config import InferenceConfig
@@ -22,8 +21,8 @@ def setup_seed(seed):
 def check_inference_engine(test_cai=False):
     setup_seed(20)
     tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")
-    model = transformers.LlamaForCausalLM(
-        transformers.LlamaConfig(
+    model = LlamaForCausalLM(
+        LlamaConfig(
             vocab_size=50000, hidden_size=512, intermediate_size=1536, num_attention_heads=4, num_hidden_layers=16
         )
     ).cuda()
@@ -81,4 +80,4 @@ def test_inference_engine():
 if __name__ == "__main__":
-    test_inference_engine()
+    test_inference_engine()
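
For reference, the second hunk switches the test from the transformers module namespace to the directly imported classes and builds a tiny 16-layer Llama for CI. A standalone version of that construction, with .cuda() dropped so the snippet also runs without a GPU (an adaptation for illustration, not part of the diff):

from transformers import LlamaConfig, LlamaForCausalLM

# Tiny Llama used by the test; small enough to instantiate quickly in CI.
model = LlamaForCausalLM(
    LlamaConfig(
        vocab_size=50000, hidden_size=512, intermediate_size=1536, num_attention_heads=4, num_hidden_layers=16
    )
)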