GPU Inference Server (#1112)

* feat: local inference server

* fix: source to use bash + vars

* chore: isort and black

* fix: make file + inference mode

* chore: logging

* refactor: remove old links

* fix: add new env vars

* feat: hf inference server

* refactor: remove old links

* test: batch and single response

* chore: black + isort

* separate gpu and cpu dockerfiles

* moved gpu to separate dockerfile

* Fixed test endpoints

* Edits to API; server won't start due to a failed instantiation error

* Method signature

* fix: gpu_infer

* tests: fix tests

---------

Co-authored-by: Andriy Mulyar <andriy.mulyar@gmail.com>
Author: Zach Nussbaum
Date: 2023-07-21 14:13:29 -05:00 (committed by GitHub)
Commit: 8aba2c9009 (parent 58f0fcab57)
14 changed files with 271 additions and 112 deletions


@@ -2,30 +2,22 @@
 Use the OpenAI python API to test gpt4all models.
 """
 import openai
-
 openai.api_base = "http://localhost:4891/v1"
 openai.api_key = "not needed for a local LLM"
 
 
 def test_completion():
-    model = "gpt4all-j-v1.3-groovy"
+    model = "ggml-mpt-7b-chat.bin"
     prompt = "Who is Michael Jordan?"
     response = openai.Completion.create(
-        model=model,
-        prompt=prompt,
-        max_tokens=50,
-        temperature=0.28,
-        top_p=0.95,
-        n=1,
-        echo=True,
-        stream=False
+        model=model, prompt=prompt, max_tokens=50, temperature=0.28, top_p=0.95, n=1, echo=True, stream=False
     )
     assert len(response['choices'][0]['text']) > len(prompt)
     print(response)
 
 
 def test_streaming_completion():
-    model = "gpt4all-j-v1.3-groovy"
+    model = "ggml-mpt-7b-chat.bin"
     prompt = "Who is Michael Jordan?"
     tokens = []
     for resp in openai.Completion.create(
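
The hunk above ends mid-function, so the body of the streaming loop is not shown in this diff. As a rough sketch of how such a loop typically accumulates tokens with the legacy openai client (the chunk field names here are assumptions following the standard OpenAI completion-chunk shape, not lines taken from the PR):

    # Sketch only: the loop body is elided by the hunk above.
    # Assumes each streamed chunk follows the standard OpenAI completion shape.
    tokens = []
    for resp in openai.Completion.create(
        model="ggml-mpt-7b-chat.bin",
        prompt="Who is Michael Jordan?",
        max_tokens=50,
        stream=True,
    ):
        tokens.append(resp["choices"][0]["text"])  # one text fragment per chunk
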
@@ -42,10 +34,12 @@ def test_streaming_completion():
     assert (len(tokens) > 0)
     assert (len("".join(tokens)) > len(prompt))
 
-# def test_chat_completions():
-#     model = "gpt4all-j-v1.3-groovy"
-#     prompt = "Who is Michael Jordan?"
-#     response = openai.ChatCompletion.create(
-#         model=model,
-#         messages=[]
-#     )
+
+def test_batched_completion():
+    model = "ggml-mpt-7b-chat.bin"
+    prompt = "Who is Michael Jordan?"
+    response = openai.Completion.create(
+        model=model, prompt=[prompt] * 3, max_tokens=50, temperature=0.28, top_p=0.95, n=1, echo=True, stream=False
+    )
+    assert len(response['choices'][0]['text']) > len(prompt)
+    assert len(response['choices']) == 3
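
Passing a list of prompts in a single Completion.create call is what the new batched test exercises: the server is expected to return one choice per prompt, hence the assertion that len(response['choices']) == 3. A minimal client-side sketch against the same local endpoint (URL and model name are taken from the tests above; the per-choice index field is an assumption from the standard OpenAI response shape):

    import openai

    openai.api_base = "http://localhost:4891/v1"  # local gpt4all API server
    openai.api_key = "not needed for a local LLM"

    # One request, three prompts: expect one completion choice per prompt.
    response = openai.Completion.create(
        model="ggml-mpt-7b-chat.bin",
        prompt=["Who is Michael Jordan?"] * 3,
        max_tokens=50,
        temperature=0.28,
    )
    for choice in response["choices"]:
        print(choice["index"], choice["text"])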