diff --git a/gpt4all-api/gpt4all_api/app/api_v1/routes/chat.py b/gpt4all-api/gpt4all_api/app/api_v1/routes/chat.py
index 381bf98f..aeccbd49 100644
--- a/gpt4all-api/gpt4all_api/app/api_v1/routes/chat.py
+++ b/gpt4all-api/gpt4all_api/app/api_v1/routes/chat.py
@@ -2,7 +2,8 @@ import logging
 import time
 from typing import List
 from uuid import uuid4
-from fastapi import APIRouter
+from fastapi import APIRouter, HTTPException
+from gpt4all import GPT4All
 from pydantic import BaseModel, Field
 from api_v1.settings import settings
 from fastapi.responses import StreamingResponse
@@ -18,6 +19,7 @@ class ChatCompletionMessage(BaseModel):
 class ChatCompletionRequest(BaseModel):
     model: str = Field(settings.model, description='The model to generate a completion from.')
     messages: List[ChatCompletionMessage] = Field(..., description='Messages for the chat completion.')
+    temperature: float = Field(settings.temp, description='Model temperature')
 
 class ChatCompletionChoice(BaseModel):
     message: ChatCompletionMessage
@@ -45,15 +47,41 @@ async def chat_completion(request: ChatCompletionRequest):
     '''
     Completes a GPT4All model response based on the last message in the chat.
     '''
-    # Example: Echo the last message content with some modification
+    # GPU inference is not implemented yet
+    if settings.inference_mode == "gpu":
+        raise HTTPException(status_code=400,
+                            detail="Not implemented yet: can only infer in CPU mode.")
+
+    # we only serve the configured model
+    if request.model != settings.model:
+        raise HTTPException(status_code=400,
+                            detail=f"The GPT4All inference server is configured to only serve: `{settings.model}`")
+
+    # run only if we have at least one message
     if request.messages:
-        last_message = request.messages[-1].content
-        response_content = f"Echo: {last_message}"
+        model = GPT4All(model_name=settings.model, model_path=settings.gpt4all_path)
+
+        # format the system message and conversation history as ChatML
+        formatted_messages = ""
+        for message in request.messages:
+            formatted_messages += f"<|im_start|>{message.role}\n{message.content}<|im_end|>\n"
+
+        # the LLM will complete the assistant's response
+        formatted_messages += "<|im_start|>assistant\n"
+        response = model.generate(
+            prompt=formatted_messages,
+            temp=request.temperature
+        )
+
+        # the LLM may continue to hallucinate the conversation, so keep only the first
+        # response: cut off everything after the first <|im_end|> (keep all if absent)
+        index = response.find("<|im_end|>")
+        response_content = response[:index].strip() if index != -1 else response.strip()
     else:
         response_content = "No messages received."
 
     # Create a chat message for the response
-    response_message = ChatCompletionMessage(role="system", content=response_content)
+    response_message = ChatCompletionMessage(role="assistant", content=response_content)
 
     # Create a choice object with the response message
     response_choice = ChatCompletionChoice(
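For reference, the route above assembles a ChatML-style prompt from the request messages and then truncates the model output at the first end-of-turn marker. A minimal standalone sketch of that flow; the message list and the model output here are illustrative, not part of the patch:

```python
# Illustrative messages, mirroring the ChatCompletionRequest body.
messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Knock knock."},
]

# Build the ChatML-style prompt; the trailing "<|im_start|>assistant\n"
# cues the model to generate the assistant's turn.
prompt = ""
for message in messages:
    prompt += f"<|im_start|>{message['role']}\n{message['content']}<|im_end|>\n"
prompt += "<|im_start|>assistant\n"

# The model may keep generating further turns, so only the text before the
# first <|im_end|> marker is returned to the client.
raw_output = "Who's there?<|im_end|>\n<|im_start|>user\nOrange.<|im_end|>"  # hypothetical output
index = raw_output.find("<|im_end|>")
reply = raw_output[:index].strip() if index != -1 else raw_output.strip()
print(reply)  # Who's there?
```
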
diff --git a/gpt4all-api/gpt4all_api/app/tests/test_endpoints.py b/gpt4all-api/gpt4all_api/app/tests/test_endpoints.py
index a310125a..c32b6220 100644
--- a/gpt4all-api/gpt4all_api/app/tests/test_endpoints.py
+++ b/gpt4all-api/gpt4all_api/app/tests/test_endpoints.py
@@ -51,7 +51,7 @@ def test_batched_completion():
     model = model_id # replace with your specific model ID
     prompt = "Who is Michael Jordan?"
     responses = []
-    
+
     # Loop to create completions one at a time
     for _ in range(3):
         response = openai.Completion.create(
@@ -62,7 +62,7 @@
     # Assertions to check the responses
     for response in responses:
         assert len(response['choices'][0]['text']) > len(prompt)
-    
+
     assert len(responses) == 3
 
 def test_embedding():
@@ -74,4 +74,20 @@
 
     assert response["model"] == model
     assert isinstance(output, list)
-    assert all(isinstance(x, args) for x in output)
\ No newline at end of file
+    assert all(isinstance(x, args) for x in output)
+
+def test_chat_completion():
+    model = model_id
+
+    response = openai.ChatCompletion.create(
+        model=model,
+        messages=[
+            {"role": "system", "content": "You are a helpful assistant."},
+            {"role": "user", "content": "Knock knock."},
+            {"role": "assistant", "content": "Who's there?"},
+            {"role": "user", "content": "Orange."},
+        ]
+    )
+
+    assert response.choices[0].message.role == "assistant"
+    assert len(response.choices[0].message.content) > 0
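Outside the test suite, the new endpoint can be exercised the same way the tests do, by pointing the legacy `openai` 0.x client at the local gpt4all-api server. A minimal sketch; the base URL, port, and model name are assumptions about the local deployment, not part of the patch:

```python
import openai

# Assumptions: the gpt4all-api server is listening on localhost:4891 and was
# booted with the model named below; no real API key is needed locally.
openai.api_base = "http://localhost:4891/v1"
openai.api_key = "not-needed-for-local-server"

response = openai.ChatCompletion.create(
    model="ggml-mpt-7b-chat.bin",  # hypothetical model id; must match the server's configured model
    messages=[
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Knock knock."},
    ],
)

print(response.choices[0].message.content)
```

Note that the route rejects any `model` other than `settings.model` with a 400, so the client must send exactly the model name the server was booted with.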