GPT4All API Scaffolding. Matches OpenAI OpenAPI spec for chats and completions (#839)

* GPT4All API Scaffolding. Matches OpenAI OpenAPI spec for engines, chats and completions

* Edits for docker building

* FastAPI app builds and pydantic models are accurate

* Added groovy download into dockerfile

* improved dockerfile

* Chat completions endpoint edits

* API unit test sketch

* Working example of groovy inference with the OpenAI API

* Added lines to test

* Set default to mpt
Andriy Mulyar
2023-06-28 14:28:52 -04:00
committed by GitHub
parent 6b8456bf99
commit 633e2a2137
21 changed files with 603 additions and 2 deletions

View File

@@ -0,0 +1,23 @@
# syntax=docker/dockerfile:1.0.0-experimental
FROM tiangolo/uvicorn-gunicorn:python3.11
ARG MODEL_BIN=ggml-mpt-7b-chat.bin
# Copied first so that any change to requirements.txt invalidates the later cached layers.
COPY gpt4all_api/requirements.txt /requirements.txt
RUN pip install --upgrade pip
# Run various pip install commands with ssh keys from host machine.
RUN --mount=type=ssh pip install -r /requirements.txt && \
rm -Rf /root/.cache && rm -Rf /tmp/pip-install*
# Finally, copy app and client.
COPY gpt4all_api/app /app
RUN mkdir -p /models
# Include the following line to bake a model into the image so it doesn't have to be downloaded on API start.
RUN wget -q --show-progress=off https://gpt4all.io/models/${MODEL_BIN} -P /models \
&& md5sum /models/${MODEL_BIN}

View File

@@ -0,0 +1 @@
# FastAPI app for serving GPT4All models

View File

View File

@@ -0,0 +1,8 @@
from api_v1.routes import chat, completions, engines
from fastapi import APIRouter
router = APIRouter()
router.include_router(chat.router)
router.include_router(completions.router)
router.include_router(engines.router)

View File

@@ -0,0 +1,26 @@
import logging
from fastapi import HTTPException
from fastapi.responses import JSONResponse
from starlette.requests import Request
from api_v1.settings import settings
log = logging.getLogger(__name__)
startup_msg_fmt = """
Starting up GPT4All API
{settings}
"""
async def on_http_error(request: Request, exc: HTTPException):
return JSONResponse({'detail': exc.detail}, status_code=exc.status_code)
async def on_startup(app):
startup_msg = startup_msg_fmt.format(settings=settings)
log.info(startup_msg)
def startup_event_handler(app):
async def start_app() -> None:
await on_startup(app)
return start_app

View File

@@ -0,0 +1,63 @@
from fastapi import APIRouter, Depends, Response, Security, status
from pydantic import BaseModel, Field
from typing import List, Dict
import logging
import time
from api_v1.settings import settings
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
### This should follow https://github.com/openai/openai-openapi/blob/master/openapi.yaml
class ChatCompletionMessage(BaseModel):
role: str
content: str
class ChatCompletionRequest(BaseModel):
model: str = Field(..., description='The model to generate a completion from.')
    messages: List[ChatCompletionMessage] = Field(..., description='The messages to generate a chat completion from.')
class ChatCompletionChoice(BaseModel):
message: ChatCompletionMessage
index: int
finish_reason: str
class ChatCompletionUsage(BaseModel):
prompt_tokens: int
completion_tokens: int
total_tokens: int
class ChatCompletionResponse(BaseModel):
id: str
    object: str = 'chat.completion'
created: int
model: str
choices: List[ChatCompletionChoice]
usage: ChatCompletionUsage
router = APIRouter(prefix="/chat", tags=["Completions Endpoints"])
@router.post("/completions", response_model=ChatCompletionResponse)
async def chat_completion(request: ChatCompletionRequest):
    '''
    Completes a GPT4All chat-style model response.
    '''
    # Stub response until chat inference is wired up; all values are placeholders,
    # but the choice must be a valid ChatCompletionChoice or pydantic will reject it.
    return ChatCompletionResponse(
        id='asdf',
        created=int(time.time()),
        model=request.model,
        choices=[{
            'message': {'role': 'assistant', 'content': ''},
            'index': 0,
            'finish_reason': 'stop'
        }],
        usage={
            'prompt_tokens': 0,
            'completion_tokens': 0,
            'total_tokens': 0
        }
    )
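
Once the API is up, this stub can be exercised with a raw HTTP call. A minimal sketch using the requests library (already in requirements.txt); the host/port mirror the test file below, and the model name is an assumption:

import requests

# Assumes the API is reachable on localhost:4891, as in the tests.
resp = requests.post(
    "http://localhost:4891/v1/chat/completions",
    json={
        "model": "ggml-mpt-7b-chat.bin",  # assumed model name; use any configured model
        "messages": [{"role": "user", "content": "Who is Michael Jordan?"}],
    },
)
print(resp.json())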

View File

@@ -0,0 +1,86 @@
from fastapi import APIRouter, Depends, Response, Security, status
from pydantic import BaseModel, Field
from typing import List, Dict
import logging
from uuid import uuid4
from api_v1.settings import settings
from gpt4all import GPT4All
import time
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
### This should follow https://github.com/openai/openai-openapi/blob/master/openapi.yaml
class CompletionRequest(BaseModel):
model: str = Field(..., description='The model to generate a completion from.')
prompt: str = Field(..., description='The prompt to begin completing from.')
max_tokens: int = Field(7, description='Max tokens to generate')
temperature: float = Field(0, description='Model temperature')
    top_p: float = Field(1.0, description='Nucleus sampling: consider only the top_p probability mass')
    n: int = Field(1, description='How many completions to generate for each prompt')
stream: bool = Field(False, description='Stream responses')
class CompletionChoice(BaseModel):
text: str
index: int
logprobs: float
finish_reason: str
class CompletionUsage(BaseModel):
prompt_tokens: int
completion_tokens: int
total_tokens: int
class CompletionResponse(BaseModel):
id: str
object: str = 'text_completion'
created: int
model: str
choices: List[CompletionChoice]
usage: CompletionUsage
router = APIRouter(prefix="/completions", tags=["Completion Endpoints"])
@router.post("/", response_model=CompletionResponse)
async def completions(request: CompletionRequest):
'''
Completes a GPT4All model response.
'''
# global model
if request.stream:
        raise NotImplementedError("Streaming is not yet implemented")
model = GPT4All(model_name=settings.model, model_path=settings.gpt4all_path)
output = model.generate(prompt=request.prompt,
n_predict = request.max_tokens,
top_k = 20,
top_p = request.top_p,
temp=request.temperature,
n_batch = 1024,
repeat_penalty = 1.2,
repeat_last_n = 10,
context_erase = 0)
return CompletionResponse(
id=str(uuid4()),
created=time.time(),
model=request.model,
choices=[dict(CompletionChoice(
text=output,
index=0,
logprobs=-1,
finish_reason='stop'
))],
usage={
'prompt_tokens': 0, #TODO how to compute this?
'completion_tokens': 0,
'total_tokens': 0
}
)

View File

@@ -0,0 +1,38 @@
from fastapi import APIRouter, Depends, Response, Security, status
from pydantic import BaseModel, Field
from typing import List, Dict
import logging
from api_v1.settings import settings
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
### This should follow https://github.com/openai/openai-openapi/blob/master/openapi.yaml
class ListEnginesResponse(BaseModel):
data: List[Dict] = Field(..., description="All available models.")
class EngineResponse(BaseModel):
data: List[Dict] = Field(..., description="All available models.")
router = APIRouter(prefix="/engines", tags=["Search Endpoints"])
@router.get("/", response_model=ListEnginesResponse)
async def list_engines():
'''
List all available GPT4All models from
https://raw.githubusercontent.com/nomic-ai/gpt4all/main/gpt4all-chat/metadata/models.json
'''
raise NotImplementedError()
return ListEnginesResponse(data=[])
@router.get("/{engine_id}", response_model=EngineResponse)
async def retrieve_engine(engine_id: str):
    '''
    Retrieves details about a specific engine.
    '''
raise NotImplementedError()
return EngineResponse()

View File

@@ -0,0 +1,12 @@
import logging
from fastapi import APIRouter
from fastapi.responses import JSONResponse
log = logging.getLogger(__name__)
router = APIRouter(prefix="/health", tags=["Health"])
@router.get('/', response_class=JSONResponse)
async def health_check():
"""Runs a health check on this instance of the API."""
return JSONResponse({'status': 'ok'}, headers={'Access-Control-Allow-Origin': '*'})

View File

@@ -0,0 +1,10 @@
from pydantic import BaseSettings
class Settings(BaseSettings):
app_environment = 'dev'
model: str = 'ggml-mpt-7b-chat.bin'
gpt4all_path: str = '/models'
settings = Settings()
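
Because Settings extends pydantic's BaseSettings, each field can also be supplied through an environment variable of the same name. A minimal sketch (the alternative model filename is an assumption):

import os

# BaseSettings matches environment variables to field names case-insensitively,
# so MODEL overrides the default without code changes. Set it before the app
# imports api_v1.settings, since the module-level instance is built at import time.
os.environ["MODEL"] = "ggml-gpt4all-j-v1.3-groovy.bin"  # assumed filename

from api_v1.settings import Settings
assert Settings().model == "ggml-gpt4all-j-v1.3-groovy.bin"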

View File

@@ -0,0 +1,3 @@
desc = 'GPT4All API'
endpoint_paths = {'health': '/health'}

View File

@@ -0,0 +1,61 @@
import os
import docs
import logging
from fastapi import FastAPI, HTTPException, Request
from starlette.middleware.cors import CORSMiddleware
from fastapi.logger import logger as fastapi_logger
from api_v1.settings import settings
from api_v1.api import router as v1_router
from api_v1 import events
logger = logging.getLogger(__name__)
app = FastAPI(title='GPT4All API', description=docs.desc)
# CORS configuration (in case you want to deploy)
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["GET", "POST", "OPTIONS"],
allow_headers=["*"],
)
logger.info('Adding v1 endpoints...')
# add v1
app.include_router(v1_router, prefix='/v1')
app.add_event_handler('startup', events.startup_event_handler(app))
app.add_exception_handler(HTTPException, events.on_http_error)
@app.on_event("startup")
async def startup():
global model
logger.info(f"Downloading/fetching model: {os.path.join(settings.gpt4all_path, settings.model)}")
from gpt4all import GPT4All
model = GPT4All(model_name=settings.model, model_path=settings.gpt4all_path)
logger.info("GPT4All API is ready.")
@app.on_event("shutdown")
async def shutdown():
logger.info("Shutting down API")
# This is needed to get logs to show up in the app
if "gunicorn" in os.environ.get("SERVER_SOFTWARE", ""):
gunicorn_error_logger = logging.getLogger("gunicorn.error")
gunicorn_logger = logging.getLogger("gunicorn")
root_logger = logging.getLogger()
fastapi_logger.setLevel(gunicorn_logger.level)
fastapi_logger.handlers = gunicorn_error_logger.handlers
root_logger.setLevel(gunicorn_logger.level)
uvicorn_logger = logging.getLogger("uvicorn.access")
uvicorn_logger.handlers = gunicorn_error_logger.handlers
else:
# https://github.com/tiangolo/fastapi/issues/2019
LOG_FORMAT2 = "[%(asctime)s %(process)d:%(threadName)s] %(name)s - %(levelname)s - %(message)s | %(filename)s:%(lineno)d"
logging.basicConfig(level=logging.INFO, format=LOG_FORMAT2)
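
For local development outside the container, the app can be served with uvicorn directly. A sketch, assuming it is run from the gpt4all_api/app directory and that uvicorn is installed (the Docker base image already provides it):

import uvicorn

# Port 4891 matches what the OpenAI-client tests below expect.
if __name__ == "__main__":
    uvicorn.run("main:app", host="0.0.0.0", port=4891)

Inside the container, the tiangolo/uvicorn-gunicorn base image serves on port 80, so mapping that to host port 4891 keeps the same tests working unchanged.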

View File

@@ -0,0 +1,35 @@
"""
Use the OpenAI python API to test gpt4all models.
"""
import openai
openai.api_base = "http://localhost:4891/v1"
openai.api_key = "not needed for a local LLM"
def test_completion():
model = "gpt4all-j-v1.3-groovy"
prompt = "Who is Michael Jordan?"
response = openai.Completion.create(
model=model,
prompt=prompt,
max_tokens=50,
temperature=0.28,
top_p=0.95,
n=1,
echo=True,
stream=False
)
assert len(response['choices'][0]['text']) > len(prompt)
print(response)
# def test_chat_completions():
# model = "gpt4all-j-v1.3-groovy"
# prompt = "Who is Michael Jordan?"
# response = openai.ChatCompletion.create(
# model=model,
# messages=[]
# )

View File

@@ -0,0 +1,10 @@
aiohttp>=3.6.2
aiofiles
pydantic>=1.4.0
requests>=2.24.0
ujson>=2.0.2
fastapi>=0.95.0
Jinja2>=3.0
gpt4all==0.2.3
pytest
openai