Mirror of https://github.com/nomic-ai/gpt4all.git, synced 2025-06-19 12:14:20 +00:00
GPU Inference Server (#1112)
* feat: local inference server
* fix: source to use bash + vars
* chore: isort and black
* fix: make file + inference mode
* chore: logging
* refactor: remove old links
* fix: add new env vars
* feat: hf inference server
* refactor: remove old links
* test: batch and single response
* chore: black + isort
* separate gpu and cpu dockerfiles
* moved gpu to separate dockerfile
* Fixed test endpoints
* Edits to API. server won't start due to failed instantiation error
* Method signature
* fix: gpu_infer
* tests: fix tests

Co-authored-by: Andriy Mulyar <andriy.mulyar@gmail.com>
This commit is contained in: parent 58f0fcab57, commit 8aba2c9009
gpt4all-api/.isort.cfg (new file, +7)
@@ -0,0 +1,7 @@
+[settings]
+known_third_party=geopy,nltk,np,numpy,pandas,pysbd,fire,torch
+
+line_length=120
+include_trailing_comma=True
+multi_line_output=3
+use_parentheses=True
@@ -17,6 +17,18 @@ Then, start the backend with:
 docker compose up --build
 ```
 
+This will run both the API and locally hosted GPU inference server. If you want to run the API without the GPU inference server, you can run:
+
+```bash
+docker compose up --build gpt4all_api
+```
+
+To run the API with the GPU inference server, you will need to include environment variables (like the `MODEL_ID`). Edit the `.env` file and run
+
+```bash
+docker compose --env-file .env up --build
+```
+
 #### Spinning up your app
 
 Run `docker compose up` to spin up the backend. Monitor the logs for errors in-case you forgot to set an environment variable above.
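For reference, a minimal sketch of what such a `.env` file might contain — the variable names are the ones referenced by docker-compose.gpu.yaml below; the values are placeholders and are not part of this commit:

```bash
# .env (illustrative only)
HUGGING_FACE_HUB_TOKEN=hf_xxx   # only needed for gated/private models
MODEL_ID=your-org/your-model    # Hugging Face model id served by text-generation-inference
NUM_SHARD=1                     # number of GPU shards to split the model across
```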
gpt4all-api/docker-compose.gpu.yaml (new file, +24)
@@ -0,0 +1,24 @@
+version: "3.8"
+
+services:
+  gpt4all_gpu:
+    image: ghcr.io/huggingface/text-generation-inference
+    container_name: gpt4all_gpu
+    restart: always #restart on error (usually code compilation from save during bad state)
+    environment:
+      - HUGGING_FACE_HUB_TOKEN=token
+      - USE_FLASH_ATTENTION=false
+      - MODEL_ID=''
+      - NUM_SHARD=1
+    command: --model-id $MODEL_ID --num-shard $NUM_SHARD
+    volumes:
+      - ./:/data
+    ports:
+      - "8080:80"
+    shm_size: 1g
+    deploy:
+      resources:
+        reservations:
+          devices:
+            - driver: nvidia
+              capabilities: [gpu]
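Once this container is up, it can be sanity-checked directly on the mapped port. A hedged sketch, assuming the default 8080:80 mapping above; the payload shape (`inputs` plus `parameters` such as `max_new_tokens`) mirrors what the API's `gpu_infer` helper sends to the `/generate` route:

```bash
curl -s http://localhost:8080/generate \
  -H 'Content-Type: application/json' \
  -d '{"inputs": "Who is Michael Jordan?", "parameters": {"max_new_tokens": 50, "temperature": 0.28}}'
```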
@@ -1,4 +1,4 @@
-version: "3.5"
+version: "3.8"
 
 services:
   gpt4all_api:
@@ -13,6 +13,7 @@ services:
       - LOGLEVEL=debug
       - PORT=4891
      - model=ggml-mpt-7b-chat.bin
+      - inference_mode=cpu
     volumes:
       - './gpt4all_api/app:/app'
     command: ["/start-reload.sh"]
@@ -1,8 +1,10 @@
 import logging
 
+from api_v1.settings import settings
 from fastapi import HTTPException
 from fastapi.responses import JSONResponse
 from starlette.requests import Request
-from api_v1.settings import settings
 
 log = logging.getLogger(__name__)
+
+
@@ -19,8 +21,9 @@ async def on_startup(app):
     startup_msg = startup_msg_fmt.format(settings=settings)
     log.info(startup_msg)
 
+
 def startup_event_handler(app):
     async def start_app() -> None:
         await on_startup(app)
 
     return start_app
@@ -1,9 +1,10 @@
-from fastapi import APIRouter, Depends, Response, Security, status
-from pydantic import BaseModel, Field
-from typing import List, Dict
 import logging
 import time
+from typing import Dict, List
+
 from api_v1.settings import settings
+from fastapi import APIRouter, Depends, Response, Security, status
+from pydantic import BaseModel, Field
 
 logger = logging.getLogger(__name__)
 logger.setLevel(logging.DEBUG)
@@ -11,11 +12,11 @@ logger.setLevel(logging.DEBUG)
 ### This should follow https://github.com/openai/openai-openapi/blob/master/openapi.yaml
 
 
 class ChatCompletionMessage(BaseModel):
     role: str
     content: str
 
 
 class ChatCompletionRequest(BaseModel):
     model: str = Field(..., description='The model to generate a completion from.')
     messages: List[ChatCompletionMessage] = Field(..., description='The model to generate a completion from.')
@@ -26,11 +27,13 @@ class ChatCompletionChoice(BaseModel):
     index: int
     finish_reason: str
 
+
 class ChatCompletionUsage(BaseModel):
     prompt_tokens: int
     completion_tokens: int
     total_tokens: int
 
+
 class ChatCompletionResponse(BaseModel):
     id: str
     object: str = 'text_completion'
@@ -42,6 +45,7 @@ class ChatCompletionResponse(BaseModel):
 
 router = APIRouter(prefix="/chat", tags=["Completions Endpoints"])
 
+
 @router.post("/completions", response_model=ChatCompletionResponse)
 async def chat_completion(request: ChatCompletionRequest):
     '''
@@ -53,11 +57,5 @@ async def chat_completion(request: ChatCompletionRequest):
         created=time.time(),
         model=request.model,
         choices=[{}],
-        usage={
-            'prompt_tokens': 0,
-            'completion_tokens': 0,
-            'total_tokens': 0
-        }
+        usage={'prompt_tokens': 0, 'completion_tokens': 0, 'total_tokens': 0},
     )
@@ -1,14 +1,16 @@
 import json
-
-from fastapi import APIRouter, Depends, Response, Security, status
-from fastapi.responses import StreamingResponse
-from pydantic import BaseModel, Field
 from typing import List, Dict, Iterable, AsyncIterable
 import logging
-from uuid import uuid4
-from api_v1.settings import settings
-from gpt4all import GPT4All
 import time
+from typing import Dict, List, Union
+from uuid import uuid4
+import aiohttp
+import asyncio
+from api_v1.settings import settings
+from fastapi import APIRouter, Depends, Response, Security, status, HTTPException
+from fastapi.responses import StreamingResponse
+from gpt4all import GPT4All
+from pydantic import BaseModel, Field
 
 logger = logging.getLogger(__name__)
 logger.setLevel(logging.DEBUG)
@@ -16,14 +18,17 @@ logger.setLevel(logging.DEBUG)
 
 ### This should follow https://github.com/openai/openai-openapi/blob/master/openapi.yaml
 
+
 class CompletionRequest(BaseModel):
-    model: str = Field(..., description='The model to generate a completion from.')
-    prompt: str = Field(..., description='The prompt to begin completing from.')
-    max_tokens: int = Field(7, description='Max tokens to generate')
-    temperature: float = Field(0, description='Model temperature')
-    top_p: float = Field(1.0, description='top_p')
-    n: int = Field(1, description='')
+    model: str = Field(settings.model, description='The model to generate a completion from.')
+    prompt: Union[List[str], str] = Field(..., description='The prompt to begin completing from.')
+    max_tokens: int = Field(None, description='Max tokens to generate')
+    temperature: float = Field(settings.temp, description='Model temperature')
+    top_p: float = Field(settings.top_k, description='top_p')
+    top_k: int = Field(settings.top_k, description='top_k')
+    n: int = Field(1, description='How many completions to generate for each prompt')
     stream: bool = Field(False, description='Stream responses')
+    repeat_penalty: float = Field(settings.repeat_penalty, description='Repeat penalty')
 
 
 class CompletionChoice(BaseModel):
@@ -58,7 +63,6 @@ class CompletionStreamResponse(BaseModel):
 
 router = APIRouter(prefix="/completions", tags=["Completion Endpoints"])
 
-
 def stream_completion(output: Iterable, base_response: CompletionStreamResponse):
     """
     Streams a GPT4All output to the client.
@@ -80,6 +84,27 @@ def stream_completion(output: Iterable, base_response: CompletionStreamResponse):
         ))]
         yield f"data: {json.dumps(dict(chunk))}\n\n"
 
+
+async def gpu_infer(payload, header):
+    async with aiohttp.ClientSession() as session:
+        try:
+            async with session.post(
+                settings.hf_inference_server_host, headers=header, data=json.dumps(payload)
+            ) as response:
+                resp = await response.json()
+                return resp
+
+        except aiohttp.ClientError as e:
+            # Handle client-side errors (e.g., connection error, invalid URL)
+            logger.error(f"Client error: {e}")
+        except aiohttp.ServerError as e:
+            # Handle server-side errors (e.g., internal server error)
+            logger.error(f"Server error: {e}")
+        except json.JSONDecodeError as e:
+            # Handle JSON decoding errors
+            logger.error(f"JSON decoding error: {e}")
+        except Exception as e:
+            # Handle other unexpected exceptions
+            logger.error(f"Unexpected error: {e}")
+
 
 @router.post("/", response_model=CompletionResponse)
 async def completions(request: CompletionRequest):
@@ -87,42 +112,104 @@ async def completions(request: CompletionRequest):
     Completes a GPT4All model response.
     '''
 
-    model = GPT4All(model_name=settings.model, model_path=settings.gpt4all_path)
-
-    output = model.generate(prompt=request.prompt,
-                            n_predict=request.max_tokens,
-                            streaming=request.stream,
-                            top_k=20,
-                            top_p=request.top_p,
-                            temp=request.temperature,
-                            n_batch=1024,
-                            repeat_penalty=1.2,
-                            repeat_last_n=10)
-
-    # If streaming, we need to return a StreamingResponse
-    if request.stream:
-        base_chunk = CompletionStreamResponse(
-            id=str(uuid4()),
-            created=time.time(),
-            model=request.model,
-            choices=[]
-        )
-        return StreamingResponse((response for response in stream_completion(output, base_chunk)),
-                                 media_type="text/event-stream")
-    else:
-        return CompletionResponse(
-            id=str(uuid4()),
-            created=time.time(),
-            model=request.model,
-            choices=[dict(CompletionChoice(
-                text=output,
-                index=0,
-                logprobs=-1,
-                finish_reason='stop'
-            ))],
-            usage={
-                'prompt_tokens': 0, #TODO how to compute this?
-                'completion_tokens': 0,
-                'total_tokens': 0
-            }
-        )
+    if request.model != settings.model:
+        raise HTTPException(status_code=400, detail=f"The GPT4All inference server is booted to only infer: `{settings.model}`")
+
+    if settings.inference_mode == "gpu":
+        params = request.dict(exclude={'model', 'prompt', 'max_tokens', 'n'})
+        params["max_new_tokens"] = request.max_tokens
+        params["num_return_sequences"] = request.n
+
+        header = {"Content-Type": "application/json"}
+        payload = {"parameters": params}
+        if isinstance(request.prompt, list):
+            tasks = []
+            for prompt in request.prompt:
+                payload["inputs"] = prompt
+                task = gpu_infer(payload, header)
+                tasks.append(task)
+
+            results = await asyncio.gather(*tasks)
+
+            choices = []
+            for response in results:
+                scores = response["scores"] if "scores" in response else -1.0
+                choices.append(
+                    dict(
+                        CompletionChoice(
+                            text=response["generated_text"], index=0, logprobs=scores, finish_reason='stop'
+                        )
+                    )
+                )
+
+            return CompletionResponse(
+                id=str(uuid4()),
+                created=time.time(),
+                model=request.model,
+                choices=choices,
+                usage={'prompt_tokens': 0, 'completion_tokens': 0, 'total_tokens': 0},
+            )
+
+        else:
+            # If streaming, we need to return a StreamingResponse
+            payload["inputs"] = request.prompt
+
+            resp = await gpu_infer(payload, header)
+
+            output = resp["generated_text"]
+            # this returns all logprobs
+            scores = resp["scores"] if "scores" in resp else -1.0
+
+            return CompletionResponse(
+                id=str(uuid4()),
+                created=time.time(),
+                model=request.model,
+                choices=[dict(CompletionChoice(text=output, index=0, logprobs=scores, finish_reason='stop'))],
+                usage={'prompt_tokens': 0, 'completion_tokens': 0, 'total_tokens': 0},
+            )
+
+    else:
+        if isinstance(request.prompt, list):
+            if len(request.prompt) > 1:
+                raise HTTPException(status_code=400, detail="Can only infer one inference per request in CPU mode.")
+            else:
+                request.prompt = request.prompt[0]
+
+        model = GPT4All(model_name=settings.model, model_path=settings.gpt4all_path)
+
+        output = model.generate(prompt=request.prompt,
+                                max_tokens=request.max_tokens,
+                                streaming=request.stream,
+                                top_k=request.top_k,
+                                top_p=request.top_p,
+                                temp=request.temperature,
+                                )
+
+        # If streaming, we need to return a StreamingResponse
+        if request.stream:
+            base_chunk = CompletionStreamResponse(
+                id=str(uuid4()),
+                created=time.time(),
+                model=request.model,
+                choices=[]
+            )
+            return StreamingResponse((response for response in stream_completion(output, base_chunk)),
+                                     media_type="text/event-stream")
+        else:
+            return CompletionResponse(
+                id=str(uuid4()),
+                created=time.time(),
+                model=request.model,
+                choices=[dict(CompletionChoice(
+                    text=output,
+                    index=0,
+                    logprobs=-1,
+                    finish_reason='stop'
+                ))],
+                usage={
+                    'prompt_tokens': 0,  # TODO how to compute this?
+                    'completion_tokens': 0,
+                    'total_tokens': 0
+                }
+            )
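As a usage sketch (not part of the commit), the reworked completions route can also be exercised with plain HTTP. The port and `/v1` prefix come from docker-compose.yaml and main.py; per the code above, a list-valued `prompt` with more than one entry is only accepted when `inference_mode` is `gpu`:

```bash
curl -s http://localhost:4891/v1/completions \
  -H 'Content-Type: application/json' \
  -d '{"model": "ggml-mpt-7b-chat.bin", "prompt": "Who is Michael Jordan?", "max_tokens": 50, "temperature": 0.28}'
```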
@@ -1,22 +1,27 @@
+import logging
+from typing import Dict, List
+
+from api_v1.settings import settings
 from fastapi import APIRouter, Depends, Response, Security, status
 from pydantic import BaseModel, Field
-from typing import List, Dict
-import logging
-from api_v1.settings import settings
 
 logger = logging.getLogger(__name__)
 logger.setLevel(logging.DEBUG)
 
 ### This should follow https://github.com/openai/openai-openapi/blob/master/openapi.yaml
 
+
 class ListEnginesResponse(BaseModel):
     data: List[Dict] = Field(..., description="All available models.")
 
+
 class EngineResponse(BaseModel):
     data: List[Dict] = Field(..., description="All available models.")
 
+
 router = APIRouter(prefix="/engines", tags=["Search Endpoints"])
 
+
 @router.get("/", response_model=ListEnginesResponse)
 async def list_engines():
     '''
@@ -29,10 +34,7 @@ async def list_engines():
 
 @router.get("/{engine_id}", response_model=EngineResponse)
 async def retrieve_engine(engine_id: str):
-    '''
-
-    '''
-
+    ''' '''
     raise NotImplementedError()
     return EngineResponse()
 
@@ -1,6 +1,7 @@
 import logging
+
 from fastapi import APIRouter
 from fastapi.responses import JSONResponse
 
 log = logging.getLogger(__name__)
 
 router = APIRouter(prefix="/health", tags=["Health"])
@@ -5,6 +5,14 @@ class Settings(BaseSettings):
     app_environment = 'dev'
     model: str = 'ggml-mpt-7b-chat.bin'
     gpt4all_path: str = '/models'
+    inference_mode: str = "cpu"
+    hf_inference_server_host: str = "http://gpt4all_gpu:80/generate"
+
+    temp: float = 0.18
+    top_p: float = 1.0
+    top_k: int = 50
+    repeat_penalty: float = 1.18
+
 
 
 settings = Settings()
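Because `Settings` is a pydantic `BaseSettings` class, each field above can presumably be overridden from the environment, which is how docker-compose.yaml already injects `inference_mode`. An illustrative sketch of the overrides a GPU deployment might export (names follow the fields above; values are examples only):

```bash
export inference_mode=gpu
export hf_inference_server_host="http://gpt4all_gpu:80/generate"
export model=ggml-mpt-7b-chat.bin
```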
@@ -1,19 +1,19 @@
-import os
-import docs
 import logging
-from fastapi import FastAPI, HTTPException, Request
-from starlette.middleware.cors import CORSMiddleware
-from fastapi.logger import logger as fastapi_logger
-from api_v1.settings import settings
-from api_v1.api import router as v1_router
-from api_v1 import events
 import os
+
+import docs
+from api_v1 import events
+from api_v1.api import router as v1_router
+from api_v1.settings import settings
+from fastapi import FastAPI, HTTPException, Request
+from fastapi.logger import logger as fastapi_logger
+from starlette.middleware.cors import CORSMiddleware
 
 logger = logging.getLogger(__name__)
 
 app = FastAPI(title='GPT4All API', description=docs.desc)
 
-#CORS Configuration (in-case you want to deploy)
+# CORS Configuration (in-case you want to deploy)
 app.add_middleware(
     CORSMiddleware,
     allow_origins=["*"],
@@ -29,14 +29,23 @@ app.include_router(v1_router, prefix='/v1')
 app.add_event_handler('startup', events.startup_event_handler(app))
 app.add_exception_handler(HTTPException, events.on_http_error)
 
+
 @app.on_event("startup")
 async def startup():
     global model
-    logger.info(f"Downloading/fetching model: {os.path.join(settings.gpt4all_path, settings.model)}")
-    from gpt4all import GPT4All
-    model = GPT4All(model_name=settings.model, model_path=settings.gpt4all_path)
+    if settings.inference_mode == "cpu":
+        logger.info(f"Downloading/fetching model: {os.path.join(settings.gpt4all_path, settings.model)}")
+        from gpt4all import GPT4All
 
-    logger.info("GPT4All API is ready.")
+        model = GPT4All(model_name=settings.model, model_path=settings.gpt4all_path)
+
+        logger.info(f"GPT4All API is ready to infer from {settings.model} on CPU.")
+
+    else:
+        # is it possible to do this once the server is up?
+        ## TODO block until HF inference server is up.
+        logger.info(f"GPT4All API is ready to infer from {settings.model} on CPU.")
+
 
 @app.on_event("shutdown")
 async def shutdown():
@@ -57,5 +66,7 @@ if "gunicorn" in os.environ.get("SERVER_SOFTWARE", ""):
     uvicorn_logger.handlers = gunicorn_error_logger.handlers
 else:
     # https://github.com/tiangolo/fastapi/issues/2019
-    LOG_FORMAT2 = "[%(asctime)s %(process)d:%(threadName)s] %(name)s - %(levelname)s - %(message)s | %(filename)s:%(lineno)d"
+    LOG_FORMAT2 = (
+        "[%(asctime)s %(process)d:%(threadName)s] %(name)s - %(levelname)s - %(message)s | %(filename)s:%(lineno)d"
+    )
     logging.basicConfig(level=logging.INFO, format=LOG_FORMAT2)
@@ -2,30 +2,22 @@
 Use the OpenAI python API to test gpt4all models.
 """
 import openai
 
 openai.api_base = "http://localhost:4891/v1"
 
 openai.api_key = "not needed for a local LLM"
 
 
 def test_completion():
-    model = "gpt4all-j-v1.3-groovy"
+    model = "ggml-mpt-7b-chat.bin"
     prompt = "Who is Michael Jordan?"
     response = openai.Completion.create(
-        model=model,
-        prompt=prompt,
-        max_tokens=50,
-        temperature=0.28,
-        top_p=0.95,
-        n=1,
-        echo=True,
-        stream=False
+        model=model, prompt=prompt, max_tokens=50, temperature=0.28, top_p=0.95, n=1, echo=True, stream=False
     )
     assert len(response['choices'][0]['text']) > len(prompt)
-    print(response)
 
 
 def test_streaming_completion():
-    model = "gpt4all-j-v1.3-groovy"
+    model = "ggml-mpt-7b-chat.bin"
     prompt = "Who is Michael Jordan?"
     tokens = []
     for resp in openai.Completion.create(
@@ -42,10 +34,12 @@ def test_streaming_completion():
     assert (len(tokens) > 0)
     assert (len("".join(tokens)) > len(prompt))
 
-# def test_chat_completions():
-#     model = "gpt4all-j-v1.3-groovy"
-#     prompt = "Who is Michael Jordan?"
-#     response = openai.ChatCompletion.create(
-#         model=model,
-#         messages=[]
-#     )
+
+def test_batched_completion():
+    model = "ggml-mpt-7b-chat.bin"
+    prompt = "Who is Michael Jordan?"
+    response = openai.Completion.create(
+        model=model, prompt=[prompt] * 3, max_tokens=50, temperature=0.28, top_p=0.95, n=1, echo=True, stream=False
+    )
+    assert len(response['choices'][0]['text']) > len(prompt)
+    assert len(response['choices']) == 3
@@ -1,10 +1,12 @@
 aiohttp>=3.6.2
 aiofiles
-pydantic>=1.4.0
+pydantic>=1.4.0,<2.0.0
 requests>=2.24.0
 ujson>=2.0.2
 fastapi>=0.95.0
 Jinja2>=3.0
-gpt4all==1.0.1
+gpt4all>=1.0.0
 pytest
 openai
+black
+isort
@@ -1,22 +1,26 @@
 ROOT_DIR:=$(shell dirname $(realpath $(lastword $(MAKEFILE_LIST))))
 APP_NAME:=gpt4all_api
 PYTHON:=python3.8
+SHELL := /bin/bash
 
 all: dependencies
 
 fresh: clean dependencies
 
 testenv: clean_testenv test_build
-	docker compose up --build
+	docker compose -f docker-compose.yaml up --build
+
+testenv_gpu: clean_testenv test_build
+	docker compose -f docker-compose.yaml -f docker-compose.gpu.yaml up --build
 
 testenv_d: clean_testenv test_build
 	docker compose up --build -d
 
 test:
-	docker compose exec gpt4all_api pytest -svv --disable-warnings -p no:cacheprovider /app/tests
+	docker compose exec $(APP_NAME) pytest -svv --disable-warnings -p no:cacheprovider /app/tests
 
 test_build:
-	DOCKER_BUILDKIT=1 docker build -t gpt4all_api --progress plain -f gpt4all_api/Dockerfile.buildkit .
+	DOCKER_BUILDKIT=1 docker build -t $(APP_NAME) --progress plain -f $(APP_NAME)/Dockerfile.buildkit .
 
 clean_testenv:
 	docker compose down -v
@@ -27,7 +31,7 @@ venv:
 	if [ ! -d $(ROOT_DIR)/env ]; then $(PYTHON) -m venv $(ROOT_DIR)/env; fi
 
 dependencies: venv
-	source $(ROOT_DIR)/env/bin/activate; yes w | python -m pip install -r $(ROOT_DIR)/gpt4all_api/requirements.txt
+	source $(ROOT_DIR)/env/bin/activate; $(PYTHON) -m pip install -r $(ROOT_DIR)/$(APP_NAME)/requirements.txt
 
 clean: clean_testenv
 	# Remove existing environment
@@ -35,3 +39,8 @@ clean: clean_testenv
 	rm -rf $(ROOT_DIR)/$(APP_NAME)/*.pyc;
 
 
+black:
+	source $(ROOT_DIR)/env/bin/activate; black -l 120 -S --target-version py38 $(APP_NAME)
+
+isort:
+	source $(ROOT_DIR)/env/bin/activate; isort --ignore-whitespace --atomic -w 120 $(APP_NAME)
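Taken together, the new targets suggest a rough local workflow; this is only a sketch and assumes a populated `.env` and an NVIDIA container runtime available to Docker:

```bash
make testenv_gpu   # build and start the API together with the text-generation-inference container
make test          # run the pytest suite inside the gpt4all_api container
make black isort   # apply the formatting and import ordering used in this commit
```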