GPU Inference Server (#1112)

* feat: local inference server

* fix: source to use bash + vars

* chore: isort and black

* fix: make file + inference mode

* chore: logging

* refactor: remove old links

* fix: add new env vars

* feat: hf inference server

* refactor: remove old links

* test: batch and single response

* chore: black + isort

* separate gpu and cpu dockerfiles

* moved gpu to separate dockerfile

* Fixed test endpoints

* Edits to API; server won't start due to failed instantiation error

* Method signature

* fix: gpu_infer

* tests: fix tests

---------

Co-authored-by: Andriy Mulyar <andriy.mulyar@gmail.com>
Committed by Zach Nussbaum on 2023-07-21 14:13:29 -05:00 (via GitHub).
Parent 58f0fcab57, commit 8aba2c9009. 14 changed files with 271 additions and 112 deletions.
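The API in this PR speaks the OpenAI completions protocol, so a local client can exercise it the same way the tests below do. A minimal sketch, assuming the server is listening on localhost:4891 (the address used in the test file) and the legacy openai 0.x client pinned in requirements.txt:

# Minimal sketch: point the legacy (0.x) openai client at the local GPT4All API.
# Host/port and model name are taken from the test file in this PR; adjust to your deployment.
import openai

openai.api_base = "http://localhost:4891/v1"
openai.api_key = "not needed for a local LLM"

response = openai.Completion.create(
    model="ggml-mpt-7b-chat.bin",
    prompt="Who is Michael Jordan?",
    max_tokens=50,
    temperature=0.28,
)
print(response["choices"][0]["text"])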

View File

@@ -1,8 +1,10 @@
import logging

from api_v1.settings import settings
from fastapi import HTTPException
from fastapi.responses import JSONResponse
from starlette.requests import Request

log = logging.getLogger(__name__)
@@ -19,8 +21,9 @@ async def on_startup(app):
    startup_msg = startup_msg_fmt.format(settings=settings)
    log.info(startup_msg)

def startup_event_handler(app):
    async def start_app() -> None:
        await on_startup(app)

    return start_app

View File

@@ -1,9 +1,10 @@
import logging
import time
from typing import Dict, List

from api_v1.settings import settings
from fastapi import APIRouter, Depends, Response, Security, status
from pydantic import BaseModel, Field

logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
@@ -11,11 +12,11 @@ logger.setLevel(logging.DEBUG)
### This should follow https://github.com/openai/openai-openapi/blob/master/openapi.yaml
class ChatCompletionMessage(BaseModel):
    role: str
    content: str


class ChatCompletionRequest(BaseModel):
    model: str = Field(..., description='The model to generate a completion from.')
    messages: List[ChatCompletionMessage] = Field(..., description='The messages to generate a chat completion from.')
@@ -26,11 +27,13 @@ class ChatCompletionChoice(BaseModel):
    index: int
    finish_reason: str


class ChatCompletionUsage(BaseModel):
    prompt_tokens: int
    completion_tokens: int
    total_tokens: int


class ChatCompletionResponse(BaseModel):
    id: str
    object: str = 'text_completion'
@@ -42,6 +45,7 @@ class ChatCompletionResponse(BaseModel):
router = APIRouter(prefix="/chat", tags=["Completions Endpoints"])
@router.post("/completions", response_model=ChatCompletionResponse)
async def chat_completion(request: ChatCompletionRequest):
    '''
@@ -53,11 +57,5 @@ async def chat_completion(request: ChatCompletionRequest):
        created=time.time(),
        model=request.model,
        choices=[{}],
        usage={'prompt_tokens': 0, 'completion_tokens': 0, 'total_tokens': 0},
    )
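The chat endpoint above is still a stub (it returns a placeholder choice and zeroed usage), but its request/response models already mirror OpenAI's schema. A small sketch of a request body that ChatCompletionRequest would accept, assuming the same localhost:4891 binding the tests use:

# Sketch only: the body matches ChatCompletionRequest above; the endpoint does not
# yet generate real chat completions in this PR.
import requests  # already pinned in requirements.txt

body = {
    "model": "ggml-mpt-7b-chat.bin",
    "messages": [{"role": "user", "content": "Who is Michael Jordan?"}],
}
resp = requests.post("http://localhost:4891/v1/chat/completions", json=body)
print(resp.json())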

View File

@@ -1,14 +1,16 @@
import asyncio
import json
import logging
import time
from typing import Dict, Iterable, List, Union
from uuid import uuid4

import aiohttp
from api_v1.settings import settings
from fastapi import APIRouter, Depends, Response, Security, status, HTTPException
from fastapi.responses import StreamingResponse
from gpt4all import GPT4All
from pydantic import BaseModel, Field

logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
@@ -16,14 +18,17 @@ logger.setLevel(logging.DEBUG)
### This should follow https://github.com/openai/openai-openapi/blob/master/openapi.yaml
class CompletionRequest(BaseModel):
    model: str = Field(settings.model, description='The model to generate a completion from.')
    prompt: Union[List[str], str] = Field(..., description='The prompt to begin completing from.')
    max_tokens: int = Field(None, description='Max tokens to generate')
    temperature: float = Field(settings.temp, description='Model temperature')
    top_p: float = Field(settings.top_p, description='top_p')
    top_k: int = Field(settings.top_k, description='top_k')
    n: int = Field(1, description='How many completions to generate for each prompt')
    stream: bool = Field(False, description='Stream responses')
    repeat_penalty: float = Field(settings.repeat_penalty, description='Repeat penalty')

class CompletionChoice(BaseModel):
@@ -58,7 +63,6 @@ class CompletionStreamResponse(BaseModel):
router = APIRouter(prefix="/completions", tags=["Completion Endpoints"])
def stream_completion(output: Iterable, base_response: CompletionStreamResponse):
"""
Streams a GPT4All output to the client.
@@ -80,6 +84,27 @@ def stream_completion(output: Iterable, base_response: CompletionStreamResponse)
        ))]
        yield f"data: {json.dumps(dict(chunk))}\n\n"
async def gpu_infer(payload, header):
    async with aiohttp.ClientSession() as session:
        try:
            async with session.post(
                settings.hf_inference_server_host, headers=header, data=json.dumps(payload)
            ) as response:
                resp = await response.json()
                return resp
        except aiohttp.ClientError as e:
            # Handle client- and server-side connection errors (connection refused, invalid URL, bad responses)
            logger.error(f"Client error: {e}")
        except json.JSONDecodeError as e:
            # Handle JSON decoding errors
            logger.error(f"JSON decoding error: {e}")
        except Exception as e:
            # Handle other unexpected exceptions
            logger.error(f"Unexpected error: {e}")
@router.post("/", response_model=CompletionResponse)
async def completions(request: CompletionRequest):
@@ -87,42 +112,104 @@ async def completions(request: CompletionRequest):
    Completes a GPT4All model response.
    '''
    if request.model != settings.model:
        raise HTTPException(status_code=400,
                            detail=f"The GPT4All inference server is booted to only infer: `{settings.model}`")

    if settings.inference_mode == "gpu":
        params = request.dict(exclude={'model', 'prompt', 'max_tokens', 'n'})
        params["max_new_tokens"] = request.max_tokens
        params["num_return_sequences"] = request.n

        header = {"Content-Type": "application/json"}
        payload = {"parameters": params}

        if isinstance(request.prompt, list):
            tasks = []
            for prompt in request.prompt:
                payload["inputs"] = prompt
                task = gpu_infer(payload, header)
                tasks.append(task)

            results = await asyncio.gather(*tasks)

            choices = []
            for response in results:
                scores = response["scores"] if "scores" in response else -1.0
                choices.append(
                    dict(
                        CompletionChoice(
                            text=response["generated_text"], index=0, logprobs=scores, finish_reason='stop'
                        )
                    )
                )

            return CompletionResponse(
                id=str(uuid4()),
                created=time.time(),
                model=request.model,
                choices=choices,
                usage={'prompt_tokens': 0, 'completion_tokens': 0, 'total_tokens': 0},
            )
        else:
            payload["inputs"] = request.prompt

            resp = await gpu_infer(payload, header)

            output = resp["generated_text"]
            # this returns all logprobs
            scores = resp["scores"] if "scores" in resp else -1.0

            return CompletionResponse(
                id=str(uuid4()),
                created=time.time(),
                model=request.model,
                choices=[dict(CompletionChoice(text=output, index=0, logprobs=scores, finish_reason='stop'))],
                usage={'prompt_tokens': 0, 'completion_tokens': 0, 'total_tokens': 0},
            )

    # CPU inference path
    if isinstance(request.prompt, list):
        if len(request.prompt) > 1:
            raise HTTPException(status_code=400, detail="Can only infer one inference per request in CPU mode.")
        else:
            request.prompt = request.prompt[0]

    model = GPT4All(model_name=settings.model, model_path=settings.gpt4all_path)

    output = model.generate(prompt=request.prompt,
                            max_tokens=request.max_tokens,
                            streaming=request.stream,
                            top_k=request.top_k,
                            top_p=request.top_p,
                            temp=request.temperature)

    # If streaming, we need to return a StreamingResponse
    if request.stream:
        base_chunk = CompletionStreamResponse(
            id=str(uuid4()),
            created=time.time(),
            model=request.model,
            choices=[]
        )
        return StreamingResponse((response for response in stream_completion(output, base_chunk)),
                                 media_type="text/event-stream")
    else:
        return CompletionResponse(
            id=str(uuid4()),
            created=time.time(),
            model=request.model,
            choices=[dict(CompletionChoice(
                text=output,
                index=0,
                logprobs=-1,
                finish_reason='stop'
            ))],
            usage={
                'prompt_tokens': 0,  # TODO how to compute this?
                'completion_tokens': 0,
                'total_tokens': 0
            }
        )
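In GPU mode, gpu_infer POSTs a JSON body of the form {"inputs": ..., "parameters": {...}} to settings.hf_inference_server_host and expects a JSON reply containing generated_text (and optionally scores). The sketch below is a hypothetical stand-in server that mimics that contract, inferred from this diff rather than from the HF inference server's own documentation; it is handy for exercising the GPU code path without a GPU:

# Hypothetical stub that mimics the contract gpu_infer assumes:
# POST /generate with {"inputs": ..., "parameters": {...}} -> {"generated_text": ...}.
from fastapi import FastAPI
from pydantic import BaseModel

stub = FastAPI()


class GenerateRequest(BaseModel):
    inputs: str
    parameters: dict = {}


@stub.post("/generate")
async def generate(req: GenerateRequest):
    # Echo a canned completion; a real inference server would run the model here.
    return {"generated_text": f"{req.inputs} ..."}

Pointing hf_inference_server_host at such a stub (for example via an environment variable, given the pydantic BaseSettings shown below) should let both the batched and single-prompt GPU branches be tested locally.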

View File

@@ -1,22 +1,27 @@
import logging
from typing import Dict, List

from api_v1.settings import settings
from fastapi import APIRouter, Depends, Response, Security, status
from pydantic import BaseModel, Field

logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
### This should follow https://github.com/openai/openai-openapi/blob/master/openapi.yaml
class ListEnginesResponse(BaseModel):
    data: List[Dict] = Field(..., description="All available models.")


class EngineResponse(BaseModel):
    data: List[Dict] = Field(..., description="All available models.")

router = APIRouter(prefix="/engines", tags=["Search Endpoints"])
@router.get("/", response_model=ListEnginesResponse)
async def list_engines():
    '''
@@ -29,10 +34,7 @@ async def list_engines():
@router.get("/{engine_id}", response_model=EngineResponse)
async def retrieve_engine(engine_id: str):
    ''' '''
    raise NotImplementedError()
    return EngineResponse()

View File

@@ -1,6 +1,7 @@
import logging
from fastapi import APIRouter
from fastapi.responses import JSONResponse
log = logging.getLogger(__name__)
router = APIRouter(prefix="/health", tags=["Health"])

View File

@@ -5,6 +5,14 @@ class Settings(BaseSettings):
    app_environment = 'dev'
    model: str = 'ggml-mpt-7b-chat.bin'
    gpt4all_path: str = '/models'
    inference_mode: str = "cpu"
    hf_inference_server_host: str = "http://gpt4all_gpu:80/generate"

    temp: float = 0.18
    top_p: float = 1.0
    top_k: int = 50
    repeat_penalty: float = 1.18

settings = Settings()
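Since Settings extends pydantic's BaseSettings (see the hunk context above), each field can be overridden from the environment, which is presumably how the new GPU mode is selected in the separate GPU Dockerfile. A sketch, assuming pydantic v1's default case-insensitive environment mapping and that the api_v1 package is importable:

# Sketch: override the new settings via environment variables before the module is imported.
import os

os.environ["INFERENCE_MODE"] = "gpu"
os.environ["HF_INFERENCE_SERVER_HOST"] = "http://localhost:8080/generate"

from api_v1.settings import Settings

print(Settings().inference_mode)  # -> "gpu"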

View File

@@ -1,19 +1,19 @@
import logging
import os

import docs
from api_v1 import events
from api_v1.api import router as v1_router
from api_v1.settings import settings
from fastapi import FastAPI, HTTPException, Request
from fastapi.logger import logger as fastapi_logger
from starlette.middleware.cors import CORSMiddleware

logger = logging.getLogger(__name__)
app = FastAPI(title='GPT4All API', description=docs.desc)
# CORS configuration (in case you want to deploy)
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
@@ -29,14 +29,23 @@ app.include_router(v1_router, prefix='/v1')
app.add_event_handler('startup', events.startup_event_handler(app))
app.add_exception_handler(HTTPException, events.on_http_error)
@app.on_event("startup")
async def startup():
    global model
    if settings.inference_mode == "cpu":
        logger.info(f"Downloading/fetching model: {os.path.join(settings.gpt4all_path, settings.model)}")
        from gpt4all import GPT4All

        model = GPT4All(model_name=settings.model, model_path=settings.gpt4all_path)

        logger.info(f"GPT4All API is ready to infer from {settings.model} on CPU.")
    else:
        # TODO: block until the HF inference server is up (can this be checked once the server is running?)
        logger.info("GPT4All API is ready.")
@app.on_event("shutdown")
async def shutdown():
@@ -57,5 +66,7 @@ if "gunicorn" in os.environ.get("SERVER_SOFTWARE", ""):
    uvicorn_logger.handlers = gunicorn_error_logger.handlers
else:
    # https://github.com/tiangolo/fastapi/issues/2019
    LOG_FORMAT2 = (
        "[%(asctime)s %(process)d:%(threadName)s] %(name)s - %(levelname)s - %(message)s | %(filename)s:%(lineno)d"
    )
    logging.basicConfig(level=logging.INFO, format=LOG_FORMAT2)
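The GPU startup branch above leaves a TODO to block until the HF inference server is reachable. One possible shape for that check, sketched with aiohttp (already a dependency); wait_for_hf_server is hypothetical and not part of this PR:

# Hypothetical helper for the startup TODO: poll the HF inference server until it
# answers, then let the API report ready.
import asyncio

import aiohttp
from api_v1.settings import settings


async def wait_for_hf_server(retries: int = 30, interval: float = 2.0) -> bool:
    async with aiohttp.ClientSession() as session:
        for _ in range(retries):
            try:
                # Any HTTP response (even an error status) means the server process is up.
                async with session.post(settings.hf_inference_server_host,
                                        json={"inputs": "ping", "parameters": {}}):
                    return True
            except aiohttp.ClientError:
                await asyncio.sleep(interval)
    return False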

View File

@@ -2,30 +2,22 @@
Use the OpenAI python API to test gpt4all models.
"""
import openai
openai.api_base = "http://localhost:4891/v1"
openai.api_key = "not needed for a local LLM"
def test_completion():
model = "gpt4all-j-v1.3-groovy"
model = "ggml-mpt-7b-chat.bin"
prompt = "Who is Michael Jordan?"
response = openai.Completion.create(
model=model,
prompt=prompt,
max_tokens=50,
temperature=0.28,
top_p=0.95,
n=1,
echo=True,
stream=False
model=model, prompt=prompt, max_tokens=50, temperature=0.28, top_p=0.95, n=1, echo=True, stream=False
)
assert len(response['choices'][0]['text']) > len(prompt)
print(response)
def test_streaming_completion():
model = "gpt4all-j-v1.3-groovy"
model = "ggml-mpt-7b-chat.bin"
prompt = "Who is Michael Jordan?"
tokens = []
for resp in openai.Completion.create(
@@ -42,10 +34,12 @@ def test_streaming_completion():
    assert (len(tokens) > 0)
    assert (len("".join(tokens)) > len(prompt))

# def test_chat_completions():
# model = "gpt4all-j-v1.3-groovy"
# prompt = "Who is Michael Jordan?"
# response = openai.ChatCompletion.create(
# model=model,
# messages=[]
# )
def test_batched_completion():
model = "ggml-mpt-7b-chat.bin"
prompt = "Who is Michael Jordan?"
response = openai.Completion.create(
model=model, prompt=[prompt] * 3, max_tokens=50, temperature=0.28, top_p=0.95, n=1, echo=True, stream=False
)
assert len(response['choices'][0]['text']) > len(prompt)
assert len(response['choices']) == 3
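The streaming test above relies on the openai client to parse the event stream. For completeness, a sketch of reading the data: {...} server-sent events emitted by stream_completion directly with requests (already in requirements.txt); the chunk layout follows CompletionStreamResponse from the completions route:

# Sketch: consume the SSE stream from /v1/completions without the openai client.
import json

import requests

body = {"model": "ggml-mpt-7b-chat.bin", "prompt": "Who is Michael Jordan?", "max_tokens": 50, "stream": True}
with requests.post("http://localhost:4891/v1/completions", json=body, stream=True) as resp:
    for line in resp.iter_lines():
        if line.startswith(b"data: "):
            chunk = json.loads(line[len(b"data: "):])
            print(chunk["choices"][0]["text"], end="", flush=True)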

View File

@@ -1,10 +1,12 @@
aiohttp>=3.6.2
aiofiles
pydantic>=1.4.0,<2.0.0
requests>=2.24.0
ujson>=2.0.2
fastapi>=0.95.0
Jinja2>=3.0
gpt4all>=1.0.0
pytest
openai
black
isort