Mirror of https://github.com/hpcaitech/ColossalAI.git
[inference] Optimize the usage of the mid tensors space in flash attn (#5304)
* opt flash attn
* opt tmp tensor
* fix benchmark_llama
* fix code style
* fix None logic for output tensor (see the sketch below)
* fix adapted to get_xine_cache
* add comment
* fix ci bugs
* fix some codes
* rm duplicated codes
* rm duplicated codes
* fix code style
* add _get_dtype in config.py
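The recurring theme in these items is reusing pre-allocated intermediate ("mid") buffers across decoding steps instead of allocating them inside the attention call on every iteration. Below is a minimal sketch of that output-buffer reuse pattern; the function name, signature, and the plain-PyTorch attention stand-in are illustrative assumptions, not ColossalAI's actual kernel API.

```python
import torch


def flash_decoding_attention(q, k, v, output=None):
    # Allocate the output tensor only if the caller did not supply one
    # (the "fix None logic for output tensor" item in the commit message).
    if output is None:
        output = torch.empty_like(q)
    # Stand-in for the fused flash-attention kernel: plain attention written into `output`.
    attn = torch.softmax(q @ k.transpose(-1, -2) / q.shape[-1] ** 0.5, dim=-1)
    output.copy_(attn @ v)
    return output


bsz, num_heads, seq_len, head_dim = 2, 4, 16, 32
q = torch.randn(bsz, num_heads, 1, head_dim)        # one query token per sequence (decoding step)
k = torch.randn(bsz, num_heads, seq_len, head_dim)  # cached keys
v = torch.randn(bsz, num_heads, seq_len, head_dim)  # cached values

# Pre-allocate the output buffer once and reuse it on every decoding step,
# instead of allocating a fresh intermediate tensor inside each attention call.
out_buf = torch.empty(bsz, num_heads, 1, head_dim)
for _ in range(8):
    flash_decoding_attention(q, k, v, output=out_buf)
```

With this pattern the temporary output tensor is allocated once per batch shape rather than once per generated token, which cuts allocator traffic during generation.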
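The last bullet adds a `_get_dtype` helper in `config.py`. As a rough illustration of what such a helper does (a hypothetical reconstruction; the actual ColossalAI version may differ in names and supported values), it resolves the config's dtype string into a `torch.dtype`:

```python
import torch

# Hypothetical reconstruction of the helper; the actual ColossalAI version may differ.
_DTYPE_MAP = {
    "fp16": torch.float16,
    "bf16": torch.bfloat16,
    "fp32": torch.float32,
}


def _get_dtype(dtype: str) -> torch.dtype:
    """Resolve a dtype string from the inference config into a torch.dtype."""
    if dtype not in _DTYPE_MAP:
        raise ValueError(f"unsupported dtype {dtype!r}, expected one of {list(_DTYPE_MAP)}")
    return _DTYPE_MAP[dtype]


assert _get_dtype("fp16") is torch.float16
```

The diff below is the accompanying test cleanup: the test now imports `LlamaConfig` and `LlamaForCausalLM` directly instead of reaching through the `transformers` module namespace.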
```diff
@@ -3,8 +3,7 @@ import random
 import numpy as np
 import pytest
 import torch
-import transformers
-from transformers import AutoTokenizer, GenerationConfig
+from transformers import AutoTokenizer, GenerationConfig, LlamaConfig, LlamaForCausalLM

 import colossalai
 from colossalai.inference.config import InferenceConfig
@@ -22,8 +21,8 @@ def setup_seed(seed):
 def check_inference_engine(test_cai=False):
     setup_seed(20)
     tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")
-    model = transformers.LlamaForCausalLM(
-        transformers.LlamaConfig(
+    model = LlamaForCausalLM(
+        LlamaConfig(
             vocab_size=50000, hidden_size=512, intermediate_size=1536, num_attention_heads=4, num_hidden_layers=16
         )
     ).cuda()
@@ -81,4 +80,4 @@ def test_inference_engine():


 if __name__ == "__main__":
-    test_inference_engine()
+    test_inference_engine()
```
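For context, here is a self-contained sketch built around the tiny Llama model the test constructs (the model and tokenizer lines are taken from the hunk above). The generation call at the end is an assumed plain-`transformers` usage; it stands in for the ColossalAI engine path that `check_inference_engine(test_cai=True)` presumably exercises, which is not shown in these hunks.

```python
from transformers import AutoTokenizer, GenerationConfig, LlamaConfig, LlamaForCausalLM

tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")
model = LlamaForCausalLM(
    LlamaConfig(
        vocab_size=50000, hidden_size=512, intermediate_size=1536, num_attention_heads=4, num_hidden_layers=16
    )
).cuda()

# Assumed plain-transformers generation pass with the randomly initialized model;
# the output text is gibberish, but it exercises the decoding path end to end.
inputs = tokenizer(["Introduce some landmarks in Paris:"], return_tensors="pt").to("cuda")
generation_config = GenerationConfig(max_new_tokens=16, do_sample=False, pad_token_id=tokenizer.eos_token_id)
outputs = model.generate(**inputs, generation_config=generation_config)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True))
```

A model this small (16 layers, 4 heads, hidden size 512) keeps the CI test lightweight while still running the full decoding path.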