Mirror of https://github.com/nomic-ai/gpt4all.git, synced 2025-10-10 12:23:42 +00:00
Update to gpt4all version 1.0.1. Implement the streaming version of the completions endpoint and add an OpenAI Python client test for the new streaming functionality. (#1129)
Co-authored-by: Brandon <bbeiler@ridgelineintl.com>
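The commit message above mentions an OpenAI Python client test for the new streaming functionality. A minimal sketch of what such a test can look like, assuming the legacy (pre-1.0) openai Python package and a locally running instance of this API; the base URL, port, and model name below are placeholders rather than values taken from this commit:

import openai

# Point the legacy (<1.0) openai client at the local server; port 4891 and the
# model name are assumptions, not values taken from this commit.
openai.api_base = "http://localhost:4891/v1"
openai.api_key = "not needed for a local server"


def test_streaming_completion():
    pieces = []
    # With stream=True the legacy client yields one chunk per server-sent event.
    for chunk in openai.Completion.create(
        model="ggml-gpt4all-j-v1.3-groovy",
        prompt="Who is Michael Jordan?",
        max_tokens=50,
        stream=True,
    ):
        pieces.append(chunk.choices[0].text)

    # The streamed pieces should join into a non-empty completion.
    assert len(pieces) > 0
    assert len("".join(pieces)) > 0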
@@ -1,6 +1,9 @@
import json
from fastapi import APIRouter, Depends, Response, Security, status
from fastapi.responses import StreamingResponse
from pydantic import BaseModel, Field
from typing import List, Dict
from typing import List, Dict, Iterable, AsyncIterable
import logging
from uuid import uuid4
from api_v1.settings import settings
@@ -10,6 +13,7 @@ import time
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)

### This should follow https://github.com/openai/openai-openapi/blob/master/openapi.yaml


class CompletionRequest(BaseModel):
@@ -28,10 +32,13 @@ class CompletionChoice(BaseModel):
    logprobs: float
    finish_reason: str


class CompletionUsage(BaseModel):
    prompt_tokens: int
    completion_tokens: int
    total_tokens: int


class CompletionResponse(BaseModel):
    id: str
    object: str = 'text_completion'
@@ -41,46 +48,81 @@ class CompletionResponse(BaseModel):
    usage: CompletionUsage


class CompletionStreamResponse(BaseModel):
    id: str
    object: str = 'text_completion'
    created: int
    model: str
    choices: List[CompletionChoice]


router = APIRouter(prefix="/completions", tags=["Completion Endpoints"])


def stream_completion(output: Iterable, base_response: CompletionStreamResponse):
    """
    Streams a GPT4All output to the client.

    Args:
        output: The output of GPT4All.generate(), which is an iterable of tokens.
        base_response: The base response object, which is cloned and modified for each token.

    Returns:
        A Generator of CompletionStreamResponse objects, which are serialized to JSON Event Stream format.
    """
    for token in output:
        chunk = base_response.copy()
        chunk.choices = [dict(CompletionChoice(
            text=token,
            index=0,
            logprobs=-1,
            finish_reason=''
        ))]
        yield f"data: {json.dumps(dict(chunk))}\n\n"


@router.post("/", response_model=CompletionResponse)
async def completions(request: CompletionRequest):
    '''
    Completes a GPT4All model response.
    '''

    # global model
    if request.stream:
        raise NotImplementedError("Streaming is not yet implements")

    model = GPT4All(model_name=settings.model, model_path=settings.gpt4all_path)

    output = model.generate(prompt=request.prompt,
                            n_predict = request.max_tokens,
                            top_k = 20,
                            top_p = request.top_p,
                            temp=request.temperature,
                            n_batch = 1024,
                            repeat_penalty = 1.2,
                            repeat_last_n = 10,
                            context_erase = 0)

    return CompletionResponse(
        id=str(uuid4()),
        created=time.time(),
        model=request.model,
        choices=[dict(CompletionChoice(
            text=output,
            index=0,
            logprobs=-1,
            finish_reason='stop'
        ))],
        usage={
            'prompt_tokens': 0, #TODO how to compute this?
            'completion_tokens': 0,
            'total_tokens': 0
        }
    )

                            n_predict=request.max_tokens,
                            streaming=request.stream,
                            top_k=20,
                            top_p=request.top_p,
                            temp=request.temperature,
                            n_batch=1024,
                            repeat_penalty=1.2,
                            repeat_last_n=10)

    # If streaming, we need to return a StreamingResponse
    if request.stream:
        base_chunk = CompletionStreamResponse(
            id=str(uuid4()),
            created=time.time(),
            model=request.model,
            choices=[]
        )
        return StreamingResponse((response for response in stream_completion(output, base_chunk)),
                                 media_type="text/event-stream")
    else:
        return CompletionResponse(
            id=str(uuid4()),
            created=time.time(),
            model=request.model,
            choices=[dict(CompletionChoice(
                text=output,
                index=0,
                logprobs=-1,
                finish_reason='stop'
            ))],
            usage={
                'prompt_tokens': 0, #TODO how to compute this?
                'completion_tokens': 0,
                'total_tokens': 0
            }
        )
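For reference, the StreamingResponse branch above emits plain server-sent events: one "data: {...}" line per generated token, each followed by a blank line. A small sketch of reading that stream directly with the requests library, again with a placeholder URL, port, and model name:

import json
import requests

# Placeholder endpoint and model; adjust to whatever the server is actually running.
payload = {
    "model": "ggml-gpt4all-j-v1.3-groovy",
    "prompt": "Who is Michael Jordan?",
    "max_tokens": 50,
    "stream": True,
}

with requests.post("http://localhost:4891/v1/completions", json=payload, stream=True) as resp:
    for raw_line in resp.iter_lines():
        if not raw_line:
            continue  # skip the blank separator between events
        line = raw_line.decode("utf-8")
        if line.startswith("data: "):
            chunk = json.loads(line[len("data: "):])
            # Each chunk carries a single token in choices[0]["text"].
            print(chunk["choices"][0]["text"], end="", flush=True)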
|