[Inference] Add ALiBi to flash attn function (#5678)

* Add ALiBi to the flash attention function

* Remove redundant modifications
Author: yuehuayingxueluo
Date: 2024-04-30 19:35:05 +08:00
Committed by: GitHub
Parent: ef8e4ffe31
Commit: f79963199c
2 changed files with 6 additions and 13 deletions


@@ -121,9 +121,7 @@ class InferenceEngine:
                     casuallm = _supported_models[arch](hf_config)
                     if isinstance(casuallm, AutoModelForCausalLM):
                         # NOTE(caidi) It's necessary to add half() here, otherwise baichuan13B will overflow the memory.
-                        model = (
-                            AutoModelForCausalLM.from_pretrained(model_or_path, trust_remote_code=True).half().cuda()
-                        )
+                        model = AutoModelForCausalLM.from_pretrained(model_or_path, trust_remote_code=True).half()
                     else:
                         model = _supported_models[arch](hf_config)
                 else:
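
For reference, the ALiBi bias named in the commit title adds a per-head, distance-proportional penalty to the raw attention scores. The sketch below is a minimal plain-PyTorch illustration of that idea, not the kernel touched by this PR; the helper names build_alibi_slopes / alibi_bias and the power-of-two head-count assumption are illustrative only.

import torch

def build_alibi_slopes(num_heads: int) -> torch.Tensor:
    # Illustrative helper (not from this PR): per-head ALiBi slopes following the
    # geometric sequence 2^(-8/n), 2^(-16/n), ..., assuming a power-of-two head count.
    start = 2 ** (-8.0 / num_heads)
    return torch.tensor([start ** (i + 1) for i in range(num_heads)])

def alibi_bias(num_heads: int, seq_len: int) -> torch.Tensor:
    # Bias added to raw attention scores: slope * (key_pos - query_pos),
    # i.e. a penalty that grows the further a key lies in the past.
    slopes = build_alibi_slopes(num_heads)               # (H,)
    pos = torch.arange(seq_len)
    distance = pos[None, :] - pos[:, None]               # (S, S)
    return slopes[:, None, None] * distance[None, :, :]  # (H, S, S)

Flash-attention-style kernels typically take only the (H,)-shaped slopes (e.g. an alibi_slopes argument) and apply the distance term inside the kernel, so the full (H, S, S) bias tensor never has to be materialized.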