server: improve correctness of request parsing and responses (#2929)
Signed-off-by: Jared Van Bortel <jared@nomic.ai>
parent 1aae4ffe0a
commit 39005288c5
@@ -317,9 +317,9 @@ jobs:
 wget https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/x86_64/cuda-keyring_1.1-1_all.deb
 sudo dpkg -i cuda-keyring_1.1-1_all.deb
 packages=(
-  bison build-essential ccache cuda-compiler-11-8 flex gperf libcublas-dev-11-8 libfontconfig1 libfreetype6
-  libgl1-mesa-dev libmysqlclient21 libnvidia-compute-550-server libodbc2 libpq5 libwayland-dev libx11-6
-  libx11-xcb1 libxcb-cursor0 libxcb-glx0 libxcb-icccm4 libxcb-image0 libxcb-keysyms1 libxcb-randr0
+  bison build-essential ccache cuda-compiler-11-8 flex g++-12 gperf libcublas-dev-11-8 libfontconfig1
+  libfreetype6 libgl1-mesa-dev libmysqlclient21 libnvidia-compute-550-server libodbc2 libpq5 libwayland-dev
+  libx11-6 libx11-xcb1 libxcb-cursor0 libxcb-glx0 libxcb-icccm4 libxcb-image0 libxcb-keysyms1 libxcb-randr0
   libxcb-render-util0 libxcb-shape0 libxcb-shm0 libxcb-sync1 libxcb-util1 libxcb-xfixes0 libxcb-xinerama0
   libxcb-xkb1 libxcb1 libxext6 libxfixes3 libxi6 libxkbcommon-x11-0 libxkbcommon0 libxrender1 patchelf
   python3 vulkan-sdk
@@ -352,6 +352,8 @@ jobs:
 ~/Qt/Tools/CMake/bin/cmake \
   -S ../gpt4all-chat -B . \
   -DCMAKE_BUILD_TYPE=Release \
+  -DCMAKE_C_COMPILER=gcc-12 \
+  -DCMAKE_CXX_COMPILER=g++-12 \
   -DCMAKE_C_COMPILER_LAUNCHER=ccache \
   -DCMAKE_CXX_COMPILER_LAUNCHER=ccache \
   -DCMAKE_CUDA_COMPILER_LAUNCHER=ccache \
@@ -391,9 +393,9 @@ jobs:
 wget https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/x86_64/cuda-keyring_1.1-1_all.deb
 sudo dpkg -i cuda-keyring_1.1-1_all.deb
 packages=(
-  bison build-essential ccache cuda-compiler-11-8 flex gperf libcublas-dev-11-8 libfontconfig1 libfreetype6
-  libgl1-mesa-dev libmysqlclient21 libnvidia-compute-550-server libodbc2 libpq5 libwayland-dev libx11-6
-  libx11-xcb1 libxcb-cursor0 libxcb-glx0 libxcb-icccm4 libxcb-image0 libxcb-keysyms1 libxcb-randr0
+  bison build-essential ccache cuda-compiler-11-8 flex g++-12 gperf libcublas-dev-11-8 libfontconfig1
+  libfreetype6 libgl1-mesa-dev libmysqlclient21 libnvidia-compute-550-server libodbc2 libpq5 libwayland-dev
+  libx11-6 libx11-xcb1 libxcb-cursor0 libxcb-glx0 libxcb-icccm4 libxcb-image0 libxcb-keysyms1 libxcb-randr0
   libxcb-render-util0 libxcb-shape0 libxcb-shm0 libxcb-sync1 libxcb-util1 libxcb-xfixes0 libxcb-xinerama0
   libxcb-xkb1 libxcb1 libxext6 libxfixes3 libxi6 libxkbcommon-x11-0 libxkbcommon0 libxrender1 patchelf
   python3 vulkan-sdk
@@ -426,6 +428,8 @@ jobs:
 ~/Qt/Tools/CMake/bin/cmake \
   -S ../gpt4all-chat -B . \
   -DCMAKE_BUILD_TYPE=Release \
+  -DCMAKE_C_COMPILER=gcc-12 \
+  -DCMAKE_CXX_COMPILER=g++-12 \
   -DCMAKE_C_COMPILER_LAUNCHER=ccache \
   -DCMAKE_CXX_COMPILER_LAUNCHER=ccache \
   -DCMAKE_CUDA_COMPILER_LAUNCHER=ccache \
@@ -447,7 +451,7 @@ jobs:

   build-offline-chat-installer-windows:
     machine:
-      image: 'windows-server-2019-vs2019:2022.08.1'
+      image: windows-server-2022-gui:current
     resource_class: windows.large
     shell: powershell.exe -ExecutionPolicy Bypass
     steps:
@@ -538,7 +542,7 @@ jobs:

   sign-offline-chat-installer-windows:
     machine:
-      image: 'windows-server-2019-vs2019:2022.08.1'
+      image: windows-server-2022-gui:current
     resource_class: windows.large
     shell: powershell.exe -ExecutionPolicy Bypass
     steps:
@@ -568,7 +572,7 @@ jobs:

   build-online-chat-installer-windows:
     machine:
-      image: 'windows-server-2019-vs2019:2022.08.1'
+      image: windows-server-2022-gui:current
     resource_class: windows.large
     shell: powershell.exe -ExecutionPolicy Bypass
     steps:
@@ -666,7 +670,7 @@ jobs:

   sign-online-chat-installer-windows:
     machine:
-      image: 'windows-server-2019-vs2019:2022.08.1'
+      image: windows-server-2022-gui:current
     resource_class: windows.large
     shell: powershell.exe -ExecutionPolicy Bypass
     steps:
@@ -720,9 +724,9 @@ jobs:
 wget https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/x86_64/cuda-keyring_1.1-1_all.deb
 sudo dpkg -i cuda-keyring_1.1-1_all.deb
 packages=(
-  bison build-essential ccache cuda-compiler-11-8 flex gperf libcublas-dev-11-8 libfontconfig1 libfreetype6
-  libgl1-mesa-dev libmysqlclient21 libnvidia-compute-550-server libodbc2 libpq5 libwayland-dev libx11-6
-  libx11-xcb1 libxcb-cursor0 libxcb-glx0 libxcb-icccm4 libxcb-image0 libxcb-keysyms1 libxcb-randr0
+  bison build-essential ccache cuda-compiler-11-8 flex g++-12 gperf libcublas-dev-11-8 libfontconfig1
+  libfreetype6 libgl1-mesa-dev libmysqlclient21 libnvidia-compute-550-server libodbc2 libpq5 libwayland-dev
+  libx11-6 libx11-xcb1 libxcb-cursor0 libxcb-glx0 libxcb-icccm4 libxcb-image0 libxcb-keysyms1 libxcb-randr0
   libxcb-render-util0 libxcb-shape0 libxcb-shm0 libxcb-sync1 libxcb-util1 libxcb-xfixes0 libxcb-xinerama0
   libxcb-xkb1 libxcb1 libxext6 libxfixes3 libxi6 libxkbcommon-x11-0 libxkbcommon0 libxrender1 python3
   vulkan-sdk
@@ -744,6 +748,8 @@ jobs:
 ~/Qt/Tools/CMake/bin/cmake \
   -S gpt4all-chat -B build \
   -DCMAKE_BUILD_TYPE=Release \
+  -DCMAKE_C_COMPILER=gcc-12 \
+  -DCMAKE_CXX_COMPILER=g++-12 \
   -DCMAKE_C_COMPILER_LAUNCHER=ccache \
   -DCMAKE_CXX_COMPILER_LAUNCHER=ccache \
   -DCMAKE_CUDA_COMPILER_LAUNCHER=ccache \
@@ -758,7 +764,7 @@ jobs:

   build-gpt4all-chat-windows:
     machine:
-      image: 'windows-server-2019-vs2019:2022.08.1'
+      image: windows-server-2022-gui:current
     resource_class: windows.large
     shell: powershell.exe -ExecutionPolicy Bypass
     steps:
@@ -928,7 +934,8 @@ jobs:
 wget https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/x86_64/cuda-keyring_1.1-1_all.deb
 sudo dpkg -i cuda-keyring_1.1-1_all.deb
 packages=(
-  build-essential ccache cmake cuda-compiler-11-8 libcublas-dev-11-8 libnvidia-compute-550-server vulkan-sdk
+  build-essential ccache cmake cuda-compiler-11-8 g++-12 libcublas-dev-11-8 libnvidia-compute-550-server
+  vulkan-sdk
 )
 sudo apt-get update
 sudo apt-get install -y "${packages[@]}"
@@ -942,6 +949,8 @@ jobs:
 cd gpt4all-backend
 cmake -B build \
   -DCMAKE_BUILD_TYPE=Release \
+  -DCMAKE_C_COMPILER=gcc-12 \
+  -DCMAKE_CXX_COMPILER=g++-12 \
   -DCMAKE_C_COMPILER_LAUNCHER=ccache \
   -DCMAKE_CXX_COMPILER_LAUNCHER=ccache \
   -DCMAKE_CUDA_COMPILER_LAUNCHER=ccache \
@@ -1014,7 +1023,7 @@ jobs:

   build-py-windows:
     machine:
-      image: 'windows-server-2019-vs2019:2022.08.1'
+      image: windows-server-2022-gui:current
     resource_class: windows.large
     shell: powershell.exe -ExecutionPolicy Bypass
     steps:
@@ -1122,7 +1131,8 @@ jobs:
 wget https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/x86_64/cuda-keyring_1.1-1_all.deb
 sudo dpkg -i cuda-keyring_1.1-1_all.deb
 packages=(
-  build-essential ccache cmake cuda-compiler-11-8 libcublas-dev-11-8 libnvidia-compute-550-server vulkan-sdk
+  build-essential ccache cmake cuda-compiler-11-8 g++-12 libcublas-dev-11-8 libnvidia-compute-550-server
+  vulkan-sdk
 )
 sudo apt-get update
 sudo apt-get install -y "${packages[@]}"
@@ -1135,6 +1145,9 @@ jobs:
 mkdir -p runtimes/build
 cd runtimes/build
 cmake ../.. \
+  -DCMAKE_BUILD_TYPE=Release \
+  -DCMAKE_C_COMPILER=gcc-12 \
+  -DCMAKE_C_COMPILER=g++-12 \
   -DCMAKE_BUILD_TYPE=Release \
   -DCMAKE_C_COMPILER_LAUNCHER=ccache \
   -DCMAKE_CXX_COMPILER_LAUNCHER=ccache \
@@ -1204,7 +1217,7 @@ jobs:

   build-bindings-backend-windows:
     machine:
-      image: 'windows-server-2022-gui:2023.03.1'
+      image: windows-server-2022-gui:current
     resource_class: windows.large
     shell: powershell.exe -ExecutionPolicy Bypass
     steps:
.gitmodules (vendored, 3 changes)
@@ -8,3 +8,6 @@
 [submodule "gpt4all-chat/deps/SingleApplication"]
     path = gpt4all-chat/deps/SingleApplication
     url = https://github.com/nomic-ai/SingleApplication.git
+[submodule "gpt4all-chat/deps/fmt"]
+    path = gpt4all-chat/deps/fmt
+    url = https://github.com/fmtlib/fmt.git
@@ -33,7 +33,7 @@ set(LLMODEL_VERSION_PATCH 0)
 set(LLMODEL_VERSION "${LLMODEL_VERSION_MAJOR}.${LLMODEL_VERSION_MINOR}.${LLMODEL_VERSION_PATCH}")
 project(llmodel VERSION ${LLMODEL_VERSION} LANGUAGES CXX C)

-set(CMAKE_CXX_STANDARD 20)
+set(CMAKE_CXX_STANDARD 23)
 set(CMAKE_CXX_STANDARD_REQUIRED ON)
 set(CMAKE_LIBRARY_OUTPUT_DIRECTORY ${CMAKE_RUNTIME_OUTPUT_DIRECTORY})
 set(BUILD_SHARED_LIBS ON)
@@ -1 +1 @@
-Subproject commit 443665aec4721ecf57df8162e7e093a0cd674a76
+Subproject commit ced74fbad4b258507f3ec06e77eec9445583511a
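One thing the bump from C++20 to C++23 (here and in gpt4all-chat's CMakeLists.txt below) enables is std::optional::transform, a C++23 library feature this commit uses in chatllm.cpp to forward the spoofed reply. A minimal standalone sketch, not part of the diff:

    #include <functional>
    #include <optional>
    #include <string>
    #include <QString>

    std::optional<QString> fakeReply = QString("spoofed");
    // C++23: maps optional<QString> to optional<std::string>; empty stays empty
    auto s = fakeReply.transform(std::mem_fn(&QString::toStdString));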
@@ -162,7 +162,7 @@ public:
                         bool allowContextShift,
                         PromptContext &ctx,
                         bool special = false,
-                        std::string *fakeReply = nullptr);
+                        std::optional<std::string_view> fakeReply = {});

     using EmbedCancelCallback = bool(unsigned *batchSizes, unsigned nBatch, const char *backend);

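The hunk above replaces a nullable out parameter with an optional non-owning view. A hedged caller sketch; the callbacks and the "%1" template are illustrative assumptions, not taken from this commit:

    LLModel::PromptContext ctx;
    auto onPrompt   = [](int32_t) { return true; };
    auto onResponse = [](int32_t, const std::string &) { return true; };

    // Normal generation: fakeReply defaults to std::nullopt.
    model->prompt("Hello", "%1", onPrompt, onResponse,
                  /*allowContextShift*/ true, ctx);

    // Replay a stored assistant reply instead of generating one; a string
    // literal's storage outlives the call, so a string_view is safe here.
    model->prompt("Hello", "%1", onPrompt, onResponse,
                  /*allowContextShift*/ true, ctx, /*special*/ false,
                  std::string_view("Hi there!"));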
@@ -212,7 +212,7 @@ public:
 protected:
     // These are pure virtual because subclasses need to implement as the default implementation of
     // 'prompt' above calls these functions
-    virtual std::vector<Token> tokenize(PromptContext &ctx, const std::string &str, bool special = false) = 0;
+    virtual std::vector<Token> tokenize(PromptContext &ctx, std::string_view str, bool special = false) = 0;
     virtual bool isSpecialToken(Token id) const = 0;
     virtual std::string tokenToString(Token id) const = 0;
     virtual Token sampleToken(PromptContext &ctx) const = 0;
@@ -249,7 +249,8 @@ protected:
                       std::function<bool(int32_t, const std::string&)> responseCallback,
                       bool allowContextShift,
                       PromptContext &promptCtx,
-                      std::vector<Token> embd_inp);
+                      std::vector<Token> embd_inp,
+                      bool isResponse = false);
     void generateResponse(std::function<bool(int32_t, const std::string&)> responseCallback,
                           bool allowContextShift,
                           PromptContext &promptCtx);
@@ -536,13 +536,13 @@ size_t LLamaModel::restoreState(const uint8_t *src)
     return llama_set_state_data(d_ptr->ctx, const_cast<uint8_t*>(src));
 }

-std::vector<LLModel::Token> LLamaModel::tokenize(PromptContext &ctx, const std::string &str, bool special)
+std::vector<LLModel::Token> LLamaModel::tokenize(PromptContext &ctx, std::string_view str, bool special)
 {
     bool atStart = m_tokenize_last_token == -1;
     bool insertSpace = atStart || isSpecialToken(m_tokenize_last_token);
     std::vector<LLModel::Token> fres(str.length() + 4);
     int32_t fres_len = llama_tokenize_gpt4all(
-        d_ptr->model, str.c_str(), str.length(), fres.data(), fres.size(), /*add_special*/ atStart,
+        d_ptr->model, str.data(), str.length(), fres.data(), fres.size(), /*add_special*/ atStart,
         /*parse_special*/ special, /*insert_space*/ insertSpace
     );
     fres.resize(fres_len);
@@ -8,6 +8,7 @@

 #include <memory>
 #include <string>
+#include <string_view>
 #include <vector>

 struct LLamaPrivate;
@@ -52,7 +53,7 @@ private:
     bool m_supportsCompletion = false;

 protected:
-    std::vector<Token> tokenize(PromptContext &ctx, const std::string &str, bool special) override;
+    std::vector<Token> tokenize(PromptContext &ctx, std::string_view str, bool special) override;
     bool isSpecialToken(Token id) const override;
     std::string tokenToString(Token id) const override;
     Token sampleToken(PromptContext &ctx) const override;
@@ -12,6 +12,7 @@
 #include <memory>
 #include <optional>
 #include <string>
+#include <string_view>
 #include <vector>

 struct LLModelWrapper {
@@ -130,13 +131,10 @@ void llmodel_prompt(llmodel_model model, const char *prompt,
     wrapper->promptContext.repeat_last_n = ctx->repeat_last_n;
     wrapper->promptContext.contextErase = ctx->context_erase;

-    std::string fake_reply_str;
-    if (fake_reply) { fake_reply_str = fake_reply; }
-    auto *fake_reply_p = fake_reply ? &fake_reply_str : nullptr;
-
     // Call the C++ prompt method
     wrapper->llModel->prompt(prompt, prompt_template, prompt_callback, response_func, allow_context_shift,
-                             wrapper->promptContext, special, fake_reply_p);
+                             wrapper->promptContext, special,
+                             fake_reply ? std::make_optional<std::string_view>(fake_reply) : std::nullopt);

     // Update the C context by giving access to the wrappers raw pointers to std::vector data
     // which involves no copies
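The deleted lines copied the C string into a std::string solely to obtain a nullable pointer; the replacement builds a non-owning optional view instead. The idiom in isolation, as a standalone sketch:

    const char *fake_reply = /* nullable C string from the API */ nullptr;
    auto opt = fake_reply ? std::make_optional<std::string_view>(fake_reply)
                          : std::nullopt;
    // opt views the caller's buffer directly; no intermediate copy is made.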
@@ -11,6 +11,7 @@
 #include <sstream>
 #include <stdexcept>
 #include <string>
+#include <string_view>
 #include <vector>

 namespace ranges = std::ranges;
@@ -45,7 +46,7 @@ void LLModel::prompt(const std::string &prompt,
                      bool allowContextShift,
                      PromptContext &promptCtx,
                      bool special,
-                     std::string *fakeReply)
+                     std::optional<std::string_view> fakeReply)
 {
     if (!isModelLoaded()) {
         std::cerr << implementation().modelType() << " ERROR: prompt won't work with an unloaded model!\n";
@@ -129,11 +130,11 @@ void LLModel::prompt(const std::string &prompt,
         return; // error

     // decode the assistant's reply, either generated or spoofed
-    if (fakeReply == nullptr) {
+    if (!fakeReply) {
         generateResponse(responseCallback, allowContextShift, promptCtx);
     } else {
         embd_inp = tokenize(promptCtx, *fakeReply, false);
-        if (!decodePrompt(promptCallback, responseCallback, allowContextShift, promptCtx, embd_inp))
+        if (!decodePrompt(promptCallback, responseCallback, allowContextShift, promptCtx, embd_inp, true))
             return; // error
     }

@@ -157,7 +158,8 @@ bool LLModel::decodePrompt(std::function<bool(int32_t)> promptCallback,
                            std::function<bool(int32_t, const std::string&)> responseCallback,
                            bool allowContextShift,
                            PromptContext &promptCtx,
-                           std::vector<Token> embd_inp) {
+                           std::vector<Token> embd_inp,
+                           bool isResponse) {
     if ((int) embd_inp.size() > promptCtx.n_ctx - 4) {
         responseCallback(-1, "ERROR: The prompt size exceeds the context window size and cannot be processed.");
         std::cerr << implementation().modelType() << " ERROR: The prompt is " << embd_inp.size() <<
@@ -196,7 +198,9 @@ bool LLModel::decodePrompt(std::function<bool(int32_t)> promptCallback,
         for (size_t t = 0; t < tokens; ++t) {
             promptCtx.tokens.push_back(batch.at(t));
             promptCtx.n_past += 1;
-            if (!promptCallback(batch.at(t)))
+            Token tok = batch.at(t);
+            bool res = isResponse ? responseCallback(tok, tokenToString(tok)) : promptCallback(tok);
+            if (!res)
                 return false;
         }
         i = batch_end;
@@ -26,6 +26,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/).
 - Fix a typo in Model Settings (by [@3Simplex](https://github.com/3Simplex) in [#2916](https://github.com/nomic-ai/gpt4all/pull/2916))
 - Fix the antenna icon tooltip when using the local server ([#2922](https://github.com/nomic-ai/gpt4all/pull/2922))
 - Fix a few issues with locating files and handling errors when loading remote models on startup ([#2875](https://github.com/nomic-ai/gpt4all/pull/2875))
+- Significantly improve API server request parsing and response correctness ([#2929](https://github.com/nomic-ai/gpt4all/pull/2929))

 ## [3.2.1] - 2024-08-13

@@ -1,7 +1,7 @@
 cmake_minimum_required(VERSION 3.16)

 set(CMAKE_EXPORT_COMPILE_COMMANDS ON)
-set(CMAKE_CXX_STANDARD 20)
+set(CMAKE_CXX_STANDARD 23)
 set(CMAKE_CXX_STANDARD_REQUIRED ON)

 if(APPLE)
@@ -64,6 +64,12 @@ message(STATUS "Qt 6 root directory: ${Qt6_ROOT_DIR}")

 set (CMAKE_RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/bin)

+set(FMT_INSTALL OFF)
+set(BUILD_SHARED_LIBS_SAVED "${BUILD_SHARED_LIBS}")
+set(BUILD_SHARED_LIBS OFF)
+add_subdirectory(deps/fmt)
+set(BUILD_SHARED_LIBS "${BUILD_SHARED_LIBS_SAVED}")
+
 add_subdirectory(../gpt4all-backend llmodel)

 set(CHAT_EXE_RESOURCES)
@@ -240,7 +246,7 @@ else()
                       PRIVATE Qt6::Quick Qt6::Svg Qt6::HttpServer Qt6::Sql Qt6::Pdf)
 endif()
 target_link_libraries(chat
-    PRIVATE llmodel SingleApplication)
+    PRIVATE llmodel SingleApplication fmt::fmt)


 # -- install --
gpt4all-chat/deps/fmt (new submodule, 1 change)
@@ -0,0 +1 @@
+Subproject commit 0c9fce2ffefecfdce794e1859584e25877b7b592
@@ -239,16 +239,17 @@ void Chat::newPromptResponsePair(const QString &prompt)
     resetResponseState();
     m_chatModel->updateCurrentResponse(m_chatModel->count() - 1, false);
     m_chatModel->appendPrompt("Prompt: ", prompt);
-    m_chatModel->appendResponse("Response: ", prompt);
+    m_chatModel->appendResponse("Response: ", QString());
     emit resetResponseRequested();
 }

+// the server needs to block until response is reset, so it calls resetResponse on its own m_llmThread
 void Chat::serverNewPromptResponsePair(const QString &prompt)
 {
     resetResponseState();
     m_chatModel->updateCurrentResponse(m_chatModel->count() - 1, false);
     m_chatModel->appendPrompt("Prompt: ", prompt);
-    m_chatModel->appendResponse("Response: ", prompt);
+    m_chatModel->appendResponse("Response: ", QString());
 }

 bool Chat::restoringFromText() const
@@ -93,7 +93,7 @@ void ChatAPI::prompt(const std::string &prompt,
                      bool allowContextShift,
                      PromptContext &promptCtx,
                      bool special,
-                     std::string *fakeReply) {
+                     std::optional<std::string_view> fakeReply) {

     Q_UNUSED(promptCallback);
     Q_UNUSED(allowContextShift);
@@ -121,7 +121,7 @@ void ChatAPI::prompt(const std::string &prompt,
     if (fakeReply) {
         promptCtx.n_past += 1;
         m_context.append(formattedPrompt);
-        m_context.append(QString::fromStdString(*fakeReply));
+        m_context.append(QString::fromUtf8(fakeReply->data(), fakeReply->size()));
         return;
     }

@@ -12,9 +12,10 @@

 #include <cstddef>
 #include <cstdint>
-#include <stdexcept>
 #include <functional>
+#include <stdexcept>
 #include <string>
+#include <string_view>
 #include <vector>

 class QNetworkAccessManager;
@@ -72,7 +73,7 @@ public:
                 bool allowContextShift,
                 PromptContext &ctx,
                 bool special,
-                std::string *fakeReply) override;
+                std::optional<std::string_view> fakeReply) override;

     void setThreadCount(int32_t n_threads) override;
     int32_t threadCount() const override;
@@ -97,7 +98,7 @@ protected:
     // them as they are only called from the default implementation of 'prompt' which we override and
     // completely replace

-    std::vector<Token> tokenize(PromptContext &ctx, const std::string &str, bool special) override
+    std::vector<Token> tokenize(PromptContext &ctx, std::string_view str, bool special) override
     {
         (void)ctx;
         (void)str;
@@ -626,16 +626,16 @@ void ChatLLM::regenerateResponse()
     m_ctx.tokens.erase(m_ctx.tokens.end() - m_promptResponseTokens, m_ctx.tokens.end());
     m_promptResponseTokens = 0;
     m_promptTokens = 0;
-    m_response = std::string();
-    emit responseChanged(QString::fromStdString(m_response));
+    m_response = m_trimmedResponse = std::string();
+    emit responseChanged(QString::fromStdString(m_trimmedResponse));
 }

 void ChatLLM::resetResponse()
 {
     m_promptTokens = 0;
     m_promptResponseTokens = 0;
-    m_response = std::string();
-    emit responseChanged(QString::fromStdString(m_response));
+    m_response = m_trimmedResponse = std::string();
+    emit responseChanged(QString::fromStdString(m_trimmedResponse));
 }

 void ChatLLM::resetContext()
@@ -645,9 +645,12 @@ void ChatLLM::resetContext()
     m_ctx = LLModel::PromptContext();
 }

-QString ChatLLM::response() const
+QString ChatLLM::response(bool trim) const
 {
-    return QString::fromStdString(remove_leading_whitespace(m_response));
+    std::string resp = m_response;
+    if (trim)
+        resp = remove_leading_whitespace(resp);
+    return QString::fromStdString(resp);
 }

 ModelInfo ChatLLM::modelInfo() const
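The new trim flag fits the commit's theme: the OpenAI-compatible server should report the response exactly as generated, while the chat UI prefers it trimmed. A hypothetical usage sketch:

    QString shown = chatllm.response();      // default: leading whitespace removed (UI)
    QString raw   = chatllm.response(false); // untrimmed, suitable for the API server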
@@ -705,7 +708,8 @@ bool ChatLLM::handleResponse(int32_t token, const std::string &response)
     // check for error
     if (token < 0) {
         m_response.append(response);
-        emit responseChanged(QString::fromStdString(remove_leading_whitespace(m_response)));
+        m_trimmedResponse = remove_leading_whitespace(m_response);
+        emit responseChanged(QString::fromStdString(m_trimmedResponse));
         return false;
     }

@@ -715,7 +719,8 @@ bool ChatLLM::handleResponse(int32_t token, const std::string &response)
     m_timer->inc();
     Q_ASSERT(!response.empty());
     m_response.append(response);
-    emit responseChanged(QString::fromStdString(remove_leading_whitespace(m_response)));
+    m_trimmedResponse = remove_leading_whitespace(m_response);
+    emit responseChanged(QString::fromStdString(m_trimmedResponse));
     return !m_stopGenerating;
 }

@@ -741,7 +746,7 @@ bool ChatLLM::prompt(const QList<QString> &collectionList, const QString &prompt

 bool ChatLLM::promptInternal(const QList<QString> &collectionList, const QString &prompt, const QString &promptTemplate,
     int32_t n_predict, int32_t top_k, float top_p, float min_p, float temp, int32_t n_batch, float repeat_penalty,
-    int32_t repeat_penalty_tokens)
+    int32_t repeat_penalty_tokens, std::optional<QString> fakeReply)
 {
     if (!isModelLoaded())
         return false;
@@ -751,7 +756,7 @@ bool ChatLLM::promptInternal(const QList<QString> &collectionList, const QString

     QList<ResultInfo> databaseResults;
     const int retrievalSize = MySettings::globalInstance()->localDocsRetrievalSize();
-    if (!collectionList.isEmpty()) {
+    if (!fakeReply && !collectionList.isEmpty()) {
         emit requestRetrieveFromDB(collectionList, prompt, retrievalSize, &databaseResults); // blocks
         emit databaseResultsChanged(databaseResults);
     }
@@ -797,7 +802,8 @@ bool ChatLLM::promptInternal(const QList<QString> &collectionList, const QString
         m_ctx.n_predict = old_n_predict; // now we are ready for a response
     }
     m_llModelInfo.model->prompt(prompt.toStdString(), promptTemplate.toStdString(), promptFunc, responseFunc,
-                                /*allowContextShift*/ true, m_ctx);
+                                /*allowContextShift*/ true, m_ctx, false,
+                                fakeReply.transform(std::mem_fn(&QString::toStdString)));
 #if defined(DEBUG)
     printf("\n");
     fflush(stdout);
@@ -805,9 +811,9 @@ bool ChatLLM::promptInternal(const QList<QString> &collectionList, const QString
     m_timer->stop();
     qint64 elapsed = totalTime.elapsed();
     std::string trimmed = trim_whitespace(m_response);
-    if (trimmed != m_response) {
-        m_response = trimmed;
-        emit responseChanged(QString::fromStdString(m_response));
+    if (trimmed != m_trimmedResponse) {
+        m_trimmedResponse = trimmed;
+        emit responseChanged(QString::fromStdString(m_trimmedResponse));
     }

     SuggestionMode mode = MySettings::globalInstance()->suggestionMode();
@@ -1078,6 +1084,7 @@ bool ChatLLM::deserialize(QDataStream &stream, int version, bool deserializeKV,
     QString response;
     stream >> response;
     m_response = response.toStdString();
+    m_trimmedResponse = trim_whitespace(m_response);
     QString nameResponse;
     stream >> nameResponse;
     m_nameResponse = nameResponse.toStdString();
@@ -1306,10 +1313,9 @@ void ChatLLM::processRestoreStateFromText()

         auto &response = *it++;
         Q_ASSERT(response.first != "Prompt: ");
-        auto responseText = response.second.toStdString();

         m_llModelInfo.model->prompt(prompt.second.toStdString(), promptTemplate.toStdString(), promptFunc, nullptr,
-                                    /*allowContextShift*/ true, m_ctx, false, &responseText);
+                                    /*allowContextShift*/ true, m_ctx, false, response.second.toUtf8().constData());
     }

     if (!m_stopGenerating) {
@@ -116,7 +116,7 @@ public:
     void setForceUnloadModel(bool b) { m_forceUnloadModel = b; }
     void setMarkedForDeletion(bool b) { m_markedForDeletion = b; }

-    QString response() const;
+    QString response(bool trim = true) const;

     ModelInfo modelInfo() const;
     void setModelInfo(const ModelInfo &info);
@@ -198,7 +198,7 @@ Q_SIGNALS:
 protected:
     bool promptInternal(const QList<QString> &collectionList, const QString &prompt, const QString &promptTemplate,
         int32_t n_predict, int32_t top_k, float top_p, float min_p, float temp, int32_t n_batch, float repeat_penalty,
-        int32_t repeat_penalty_tokens);
+        int32_t repeat_penalty_tokens, std::optional<QString> fakeReply = {});
     bool handlePrompt(int32_t token);
     bool handleResponse(int32_t token, const std::string &response);
     bool handleNamePrompt(int32_t token);
@@ -221,6 +221,7 @@ private:
     bool loadNewModel(const ModelInfo &modelInfo, QVariantMap &modelLoadProps);

     std::string m_response;
+    std::string m_trimmedResponse;
     std::string m_nameResponse;
     QString m_questionResponse;
     LLModelInfo m_llModelInfo;
@@ -20,24 +20,25 @@ class LocalDocsCollectionsModel : public QSortFilterProxyModel
     Q_OBJECT
     Q_PROPERTY(int count READ count NOTIFY countChanged)
    Q_PROPERTY(int updatingCount READ updatingCount NOTIFY updatingCountChanged)

 public:
     explicit LocalDocsCollectionsModel(QObject *parent);
+    int count() const { return rowCount(); }
+    int updatingCount() const;

 public Q_SLOTS:
-    int count() const { return rowCount(); }
     void setCollections(const QList<QString> &collections);
-    int updatingCount() const;

 Q_SIGNALS:
     void countChanged();
     void updatingCountChanged();

-private Q_SLOT:
-    void maybeTriggerUpdatingCountChanged();
-
 protected:
     bool filterAcceptsRow(int sourceRow, const QModelIndex &sourceParent) const override;

+private Q_SLOTS:
+    void maybeTriggerUpdatingCountChanged();
+
 private:
     QList<QString> m_collections;
     int m_updatingCount = 0;
@@ -18,10 +18,12 @@
 #include <QVector>
 #include <Qt>
 #include <QtGlobal>
-#include <QtQml>
+
+#include <utility>

 using namespace Qt::Literals::StringLiterals;


 struct ModelInfo {
     Q_GADGET
     Q_PROPERTY(QString id READ id WRITE setId)
@@ -523,7 +525,7 @@ private:

 protected:
     explicit ModelList();
-    ~ModelList() { for (auto *model: m_models) { delete model; } }
+    ~ModelList() override { for (auto *model: std::as_const(m_models)) { delete model; } }
     friend class MyModelList;
 };

@@ -8,6 +8,7 @@
 #include <QSettings>
 #include <QString>
 #include <QStringList>
+#include <QTranslator>
 #include <QVector>

 #include <cstdint>
@@ -4,7 +4,13 @@
 #include "modellist.h"
 #include "mysettings.h"

+#include <fmt/base.h>
+#include <fmt/format.h>
+
 #include <QByteArray>
+#include <QCborArray>
+#include <QCborMap>
+#include <QCborValue>
 #include <QDateTime>
 #include <QDebug>
 #include <QHostAddress>
@@ -14,19 +20,67 @@
 #include <QJsonDocument>
 #include <QJsonObject>
 #include <QJsonValue>
+#include <QLatin1StringView>
 #include <QPair>
+#include <QVariant>
 #include <Qt>
+#include <QtCborCommon>
 #include <QtLogging>

+#include <algorithm>
+#include <cstdint>
 #include <iostream>
+#include <optional>
+#include <stdexcept>
 #include <string>
 #include <type_traits>
+#include <unordered_map>
 #include <utility>

+namespace ranges = std::ranges;
+using namespace std::string_literals;
 using namespace Qt::Literals::StringLiterals;

 //#define DEBUG

+
+#define MAKE_FORMATTER(type, conversion) \
+    template <> \
+    struct fmt::formatter<type, char>: fmt::formatter<std::string, char> { \
+        template <typename FmtContext> \
+        FmtContext::iterator format(const type &value, FmtContext &ctx) const \
+        { \
+            return formatter<std::string, char>::format(conversion, ctx); \
+        } \
+    }
+
+MAKE_FORMATTER(QString,  value.toStdString()           );
+MAKE_FORMATTER(QVariant, value.toString().toStdString());
+
+namespace {
+
+class InvalidRequestError: public std::invalid_argument {
+    using std::invalid_argument::invalid_argument;
+
+public:
+    QHttpServerResponse asResponse() const
+    {
+        QJsonObject error {
+            { "message", what(),                     },
+            { "type",    u"invalid_request_error"_s, },
+            { "param",   QJsonValue::Null            },
+            { "code",    QJsonValue::Null            },
+        };
+        return { QJsonObject {{ "error", error }},
+                 QHttpServerResponder::StatusCode::BadRequest };
+    }
+
+private:
+    Q_DISABLE_COPY_MOVE(InvalidRequestError)
+};
+
+} // namespace
+
 static inline QJsonObject modelToJson(const ModelInfo &info)
 {
     QJsonObject model;
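Taken together, the additions above let the server format Qt types with fmt and report failures in OpenAI's error shape. A small hypothetical sketch (assumes the file's using-directive for Qt string literals):

    QString who = u"world"_s;
    std::string msg = fmt::format("hello {}", who); // uses the QString formatter above

    // asResponse() turns this into HTTP 400 with a body like:
    // {"error": {"message": "'stream' is not supported",
    //            "type": "invalid_request_error", "param": null, "code": null}}
    throw InvalidRequestError("'stream' is not supported");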
@@ -39,7 +93,7 @@ static inline QJsonObject modelToJson(const ModelInfo &info)

     QJsonArray permissions;
     QJsonObject permissionObj;
-    permissionObj.insert("id", "foobarbaz");
+    permissionObj.insert("id", "placeholder");
     permissionObj.insert("object", "model_permission");
     permissionObj.insert("created", 0);
     permissionObj.insert("allow_create_engine", false);
@@ -70,6 +124,328 @@ static inline QJsonObject resultToJson(const ResultInfo &info)
     return result;
 }

+class BaseCompletionRequest {
+public:
+    QString model; // required
+    // NB: some parameters are not supported yet
+    int32_t max_tokens  = 16;
+    qint64  n           = 1;
+    float   temperature = 1.f;
+    float   top_p       = 1.f;
+    float   min_p       = 0.f;
+
+    BaseCompletionRequest() = default;
+    virtual ~BaseCompletionRequest() = default;
+
+    virtual BaseCompletionRequest &parse(QCborMap request)
+    {
+        parseImpl(request);
+        if (!request.isEmpty())
+            throw InvalidRequestError(fmt::format(
+                "Unrecognized request argument supplied: {}", request.keys().constFirst().toString()
+            ));
+        return *this;
+    }
+
+protected:
+    virtual void parseImpl(QCborMap &request)
+    {
+        using enum Type;
+
+        auto reqValue = [&request](auto &&...args) { return takeValue(request, args...); };
+        QCborValue value;
+
+        this->model = reqValue("model", String, /*required*/ true).toString();
+
+        value = reqValue("frequency_penalty", Number, false, /*min*/ -2, /*max*/ 2);
+        if (value.isDouble() || value.toInteger() != 0)
+            throw InvalidRequestError("'frequency_penalty' is not supported");
+
+        value = reqValue("max_tokens", Integer, false, /*min*/ 1);
+        if (!value.isNull())
+            this->max_tokens = int32_t(qMin(value.toInteger(), INT32_MAX));
+
+        value = reqValue("n", Integer, false, /*min*/ 1);
+        if (!value.isNull())
+            this->n = value.toInteger();
+
+        value = reqValue("presence_penalty", Number);
+        if (value.isDouble() || value.toInteger() != 0)
+            throw InvalidRequestError("'presence_penalty' is not supported");
+
+        value = reqValue("seed", Integer);
+        if (!value.isNull())
+            throw InvalidRequestError("'seed' is not supported");
+
+        value = reqValue("stop");
+        if (!value.isNull())
+            throw InvalidRequestError("'stop' is not supported");
+
+        value = reqValue("stream", Boolean);
+        if (value.isTrue())
+            throw InvalidRequestError("'stream' is not supported");
+
+        value = reqValue("stream_options", Object);
+        if (!value.isNull())
+            throw InvalidRequestError("'stream_options' is not supported");
+
+        value = reqValue("temperature", Number, false, /*min*/ 0, /*max*/ 2);
+        if (!value.isNull())
+            this->temperature = float(value.toDouble());
+
+        value = reqValue("top_p", Number, /*min*/ 0, /*max*/ 1);
+        if (!value.isNull())
+            this->top_p = float(value.toDouble());
+
+        value = reqValue("min_p", Number, /*min*/ 0, /*max*/ 1);
+        if (!value.isNull())
+            this->min_p = float(value.toDouble());
+
+        reqValue("user", String); // validate but don't use
+    }
+
+    enum class Type : uint8_t {
+        Boolean,
+        Integer,
+        Number,
+        String,
+        Array,
+        Object,
+    };
+
+    static const std::unordered_map<Type, const char *> s_typeNames;
+
+    static bool typeMatches(const QCborValue &value, Type type) noexcept {
+        using enum Type;
+        switch (type) {
+            case Boolean: return value.isBool();
+            case Integer: return value.isInteger();
+            case Number:  return value.isInteger() || value.isDouble();
+            case String:  return value.isString();
+            case Array:   return value.isArray();
+            case Object:  return value.isMap();
+        }
+        Q_UNREACHABLE();
+    }
+
+    static QCborValue takeValue(
+        QCborMap &obj, const char *key, std::optional<Type> type = {}, bool required = false,
+        std::optional<qint64> min = {}, std::optional<qint64> max = {}
+    ) {
+        auto value = obj.take(QLatin1StringView(key));
+        if (value.isUndefined())
+            value = QCborValue(QCborSimpleType::Null);
+        if (required && value.isNull())
+            throw InvalidRequestError(fmt::format("you must provide a {} parameter", key));
+        if (type && !value.isNull() && !typeMatches(value, *type))
+            throw InvalidRequestError(fmt::format("'{}' is not of type '{}' - '{}'",
+                value.toVariant(), s_typeNames.at(*type), key));
+        if (!value.isNull()) {
+            double num = value.toDouble();
+            if (min && num < double(*min))
+                throw InvalidRequestError(fmt::format("{} is less than the minimum of {} - '{}'", num, *min, key));
+            if (max && num > double(*max))
+                throw InvalidRequestError(fmt::format("{} is greater than the maximum of {} - '{}'", num, *max, key));
+        }
+        return value;
+    }
+
+private:
+    Q_DISABLE_COPY_MOVE(BaseCompletionRequest)
+};
+
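The parsing contract here: takeValue() removes each recognized key from the CBOR map, so anything left after parseImpl() is by definition unrecognized and rejected. A hypothetical request against the CompletionRequest subclass defined just below (model name and typo are invented; parseRequest appears later in this hunk):

    QJsonObject body = QJsonDocument::fromJson(R"({
        "model": "Llama 3 8B Instruct",
        "prompt": "Hello",
        "max_tokens": 32,
        "temperatur": 0.5
    })").object();

    CompletionRequest req;
    // throws InvalidRequestError: "Unrecognized request argument supplied: temperatur"
    parseRequest(req, std::move(body));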
+class CompletionRequest : public BaseCompletionRequest {
+public:
+    QString prompt; // required
+    // some parameters are not supported yet - these ones are
+    bool echo = false;
+
+    CompletionRequest &parse(QCborMap request) override
+    {
+        BaseCompletionRequest::parse(std::move(request));
+        return *this;
+    }
+
+protected:
+    void parseImpl(QCborMap &request) override
+    {
+        using enum Type;
+
+        auto reqValue = [&request](auto &&...args) { return takeValue(request, args...); };
+        QCborValue value;
+
+        BaseCompletionRequest::parseImpl(request);
+
+        this->prompt = reqValue("prompt", String, /*required*/ true).toString();
+
+        value = reqValue("best_of", Integer);
+        {
+            qint64 bof = value.toInteger(1);
+            if (this->n > bof)
+                throw InvalidRequestError(fmt::format(
+                    "You requested that the server return more choices than it will generate (HINT: you must set 'n' "
+                    "(currently {}) to be at most 'best_of' (currently {}), or omit either parameter if you don't "
+                    "specifically want to use them.)",
+                    this->n, bof
+                ));
+            if (bof > this->n)
+                throw InvalidRequestError("'best_of' is not supported");
+        }
+
+        value = reqValue("echo", Boolean);
+        if (value.isBool())
+            this->echo = value.toBool();
+
+        // we don't bother deeply typechecking unsupported subobjects for now
+        value = reqValue("logit_bias", Object);
+        if (!value.isNull())
+            throw InvalidRequestError("'logit_bias' is not supported");
+
+        value = reqValue("logprobs", Integer, false, /*min*/ 0);
+        if (!value.isNull())
+            throw InvalidRequestError("'logprobs' is not supported");
+
+        value = reqValue("suffix", String);
+        if (!value.isNull() && !value.toString().isEmpty())
+            throw InvalidRequestError("'suffix' is not supported");
+    }
+};
+
+const std::unordered_map<BaseCompletionRequest::Type, const char *> BaseCompletionRequest::s_typeNames = {
+    { BaseCompletionRequest::Type::Boolean, "boolean" },
+    { BaseCompletionRequest::Type::Integer, "integer" },
+    { BaseCompletionRequest::Type::Number,  "number"  },
+    { BaseCompletionRequest::Type::String,  "string"  },
+    { BaseCompletionRequest::Type::Array,   "array"   },
+    { BaseCompletionRequest::Type::Object,  "object"  },
+};
+
|
class ChatRequest : public BaseCompletionRequest {
public:
    struct Message {
        enum class Role : uint8_t {
            User,
            Assistant,
        };
        Role role;
        QString content;
    };

    QList<Message> messages; // required

    ChatRequest &parse(QCborMap request) override
    {
        BaseCompletionRequest::parse(std::move(request));
        return *this;
    }

protected:
    void parseImpl(QCborMap &request) override
    {
        using enum Type;

        auto reqValue = [&request](auto &&...args) { return takeValue(request, args...); };
        QCborValue value;

        BaseCompletionRequest::parseImpl(request);

        value = reqValue("messages", std::nullopt, /*required*/ true);
        if (!value.isArray() || value.toArray().isEmpty())
            throw InvalidRequestError(fmt::format(
                "Invalid type for 'messages': expected a non-empty array of objects, but got '{}' instead.",
                value.toVariant()
            ));

        this->messages.clear();
        {
            QCborArray arr = value.toArray();
            Message::Role nextRole = Message::Role::User;
            for (qsizetype i = 0; i < arr.size(); i++) {
                const auto &elem = arr[i];
                if (!elem.isMap())
                    throw InvalidRequestError(fmt::format(
                        "Invalid type for 'messages[{}]': expected an object, but got '{}' instead.",
                        i, elem.toVariant()
                    ));
                QCborMap msg = elem.toMap();
                Message res;
                QString role = takeValue(msg, "role", String, /*required*/ true).toString();
                if (role == u"system"_s)
                    continue; // FIXME(jared): don't ignore these
                if (role == u"user"_s) {
                    res.role = Message::Role::User;
                } else if (role == u"assistant"_s) {
                    res.role = Message::Role::Assistant;
                } else {
                    throw InvalidRequestError(fmt::format(
                        "Invalid 'messages[{}].role': expected one of 'system', 'assistant', or 'user', but got '{}'"
                        " instead.",
                        i, role.toStdString()
                    ));
                }
                res.content = takeValue(msg, "content", String, /*required*/ true).toString();
                if (res.role != nextRole)
                    throw InvalidRequestError(fmt::format(
                        "Invalid 'messages[{}].role': did not expect '{}' here", i, role
                    ));
                this->messages.append(res);
                nextRole = res.role == Message::Role::User ? Message::Role::Assistant
                                                           : Message::Role::User;

                if (!msg.isEmpty())
                    throw InvalidRequestError(fmt::format(
                        "Invalid 'messages[{}]': unrecognized key: '{}'", i, msg.keys().constFirst().toString()
                    ));
            }
        }

        // we don't bother deeply typechecking unsupported subobjects for now
        value = reqValue("logit_bias", Object);
        if (!value.isNull())
            throw InvalidRequestError("'logit_bias' is not supported");

        value = reqValue("logprobs", Boolean);
        if (value.isTrue())
            throw InvalidRequestError("'logprobs' is not supported");

        value = reqValue("top_logprobs", Integer, false, /*min*/ 0);
        if (!value.isNull())
            throw InvalidRequestError("The 'top_logprobs' parameter is only allowed when 'logprobs' is enabled.");

        value = reqValue("response_format", Object);
        if (!value.isNull())
            throw InvalidRequestError("'response_format' is not supported");

        reqValue("service_tier", String); // validate but don't use

        value = reqValue("tools", Array);
        if (!value.isNull())
            throw InvalidRequestError("'tools' is not supported");

        value = reqValue("tool_choice");
        if (!value.isNull())
            throw InvalidRequestError("'tool_choice' is not supported");

        // validate but don't use
        reqValue("parallel_tool_calls", Boolean);

        value = reqValue("function_call");
        if (!value.isNull())
            throw InvalidRequestError("'function_call' is not supported");

        value = reqValue("functions", Array);
        if (!value.isNull())
            throw InvalidRequestError("'functions' is not supported");
    }
};

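ChatRequest::parseImpl is strict about conversation shape: roles must alternate starting from 'user' (system messages are skipped for now, per the FIXME), unknown keys on a message are rejected, and the handler further down additionally expects the list to end on a user turn. Two illustrative request bodies -- the model name is hypothetical:

// Accepted: roles alternate user/assistant and the list ends on a user turn.
static const char *okBody = R"({
    "model": "Llama 3 8B Instruct",
    "messages": [
        {"role": "user",      "content": "Hi"},
        {"role": "assistant", "content": "Hello!"},
        {"role": "user",      "content": "What is 1+1?"}
    ]
})";

// Rejected with InvalidRequestError: the second consecutive 'user' turn
// arrives where 'assistant' was expected.
static const char *badBody = R"({
    "model": "Llama 3 8B Instruct",
    "messages": [
        {"role": "user", "content": "Hi"},
        {"role": "user", "content": "Are you there?"}
    ]
})";
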
template <typename T>
T &parseRequest(T &request, QJsonObject &&obj)
{
    // lossless conversion to CBOR exposes more type information
    return request.parse(QCborMap::fromJsonObject(obj));
}

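That comment is the crux of the parsing rewrite: QJsonValue stores every number as a double, while QCborValue keeps integers distinct, which is what lets the validators above tell 'integer' apart from 'number'. A minimal standalone check of this Qt behavior:

#include <QCborMap>
#include <QJsonDocument>
#include <QJsonObject>
#include <QString>
#include <cassert>

int main()
{
    const QJsonObject obj = QJsonDocument::fromJson(R"({"n": 2, "top_p": 0.9})").object();
    const QCborMap map = QCborMap::fromJsonObject(obj);
    assert(map.value(QStringLiteral("n")).isInteger());     // 2 survives as an integer
    assert(map.value(QStringLiteral("top_p")).isDouble());  // 0.9 stays a double
    return 0;
}
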
Server::Server(Chat *chat)
    : ChatLLM(chat, true /*isServer*/)
    , m_chat(chat)
@ -80,20 +456,28 @@ Server::Server(Chat *chat)
    connect(chat, &Chat::collectionListChanged, this, &Server::handleCollectionListChanged, Qt::QueuedConnection);
}

Server::~Server()
static QJsonObject requestFromJson(const QByteArray &request)
{
    QJsonParseError err;
    const QJsonDocument document = QJsonDocument::fromJson(request, &err);
    if (err.error || !document.isObject())
        throw InvalidRequestError(fmt::format(
            "error parsing request JSON: {}",
            err.error ? err.errorString().toStdString() : "not an object"s
        ));
    return document.object();
}

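requestFromJson funnels every malformed body into InvalidRequestError, whose asResponse() (defined earlier in this diff) renders the error object, so the routes below only need a try/catch. A sketch of the failure path -- the helper name demoParseFailure is hypothetical, and the assumption is that asResponse() yields an HTTP 400-style error body:

static QHttpServerResponse demoParseFailure()
{
    try {
        const QJsonObject obj = requestFromJson(QByteArrayLiteral("{not json"));
        return QHttpServerResponse(obj);
    } catch (const InvalidRequestError &e) {
        // Assumption: asResponse() carries the "error parsing request JSON: ..."
        // message in an OpenAI-style error object.
        return e.asResponse();
    }
}
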
void Server::start()
{
    m_server = new QHttpServer(this);
    m_server = std::make_unique<QHttpServer>(this);
    if (!m_server->listen(QHostAddress::LocalHost, MySettings::globalInstance()->networkPort())) {
        qWarning() << "ERROR: Unable to start the server";
        return;
    }

m_server->route("/v1/models", QHttpServerRequest::Method::Get,
|
m_server->route("/v1/models", QHttpServerRequest::Method::Get,
|
||||||
[](const QHttpServerRequest &request) {
|
[](const QHttpServerRequest &) {
|
||||||
if (!MySettings::globalInstance()->serverChat())
|
if (!MySettings::globalInstance()->serverChat())
|
||||||
return QHttpServerResponse(QHttpServerResponder::StatusCode::Unauthorized);
|
return QHttpServerResponse(QHttpServerResponder::StatusCode::Unauthorized);
|
||||||
|
|
||||||
@ -113,7 +497,7 @@ void Server::start()
|
|||||||
);
|
);
|
||||||
|
|
||||||
m_server->route("/v1/models/<arg>", QHttpServerRequest::Method::Get,
|
m_server->route("/v1/models/<arg>", QHttpServerRequest::Method::Get,
|
||||||
[](const QString &model, const QHttpServerRequest &request) {
|
[](const QString &model, const QHttpServerRequest &) {
|
||||||
if (!MySettings::globalInstance()->serverChat())
|
if (!MySettings::globalInstance()->serverChat())
|
||||||
return QHttpServerResponse(QHttpServerResponder::StatusCode::Unauthorized);
|
return QHttpServerResponse(QHttpServerResponder::StatusCode::Unauthorized);
|
||||||
|
|
||||||
@ -137,7 +521,23 @@ void Server::start()
|
|||||||
        [this](const QHttpServerRequest &request) {
            if (!MySettings::globalInstance()->serverChat())
                return QHttpServerResponse(QHttpServerResponder::StatusCode::Unauthorized);
            return handleCompletionRequest(request, false);
            try {
                auto reqObj = requestFromJson(request.body());
#if defined(DEBUG)
                qDebug().noquote() << "/v1/completions request" << QJsonDocument(reqObj).toJson(QJsonDocument::Indented);
#endif
                CompletionRequest req;
                parseRequest(req, std::move(reqObj));
                auto [resp, respObj] = handleCompletionRequest(req);
#if defined(DEBUG)
                if (respObj)
                    qDebug().noquote() << "/v1/completions reply" << QJsonDocument(*respObj).toJson(QJsonDocument::Indented);
#endif
                return std::move(resp);
            } catch (const InvalidRequestError &e) {
                return e.asResponse();
            }
        }
    );

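For reference, a body that exercises this route cleanly -- the model name is hypothetical, and every field shown is one the new CompletionRequest actually parses:

// Unsupported parameters such as "logit_bias" or a non-empty "suffix" are now
// rejected up front by CompletionRequest::parseImpl instead of being silently
// ignored, which is the point of this commit.
static const char *completionBody = R"({
    "model": "Llama 3 8B Instruct",
    "prompt": "Once upon a time",
    "max_tokens": 32,
    "temperature": 0.7,
    "top_p": 0.9,
    "echo": true
})";
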
@ -145,13 +545,30 @@ void Server::start()
        [this](const QHttpServerRequest &request) {
            if (!MySettings::globalInstance()->serverChat())
                return QHttpServerResponse(QHttpServerResponder::StatusCode::Unauthorized);
            return handleCompletionRequest(request, true);
            try {
                auto reqObj = requestFromJson(request.body());
#if defined(DEBUG)
                qDebug().noquote() << "/v1/chat/completions request" << QJsonDocument(reqObj).toJson(QJsonDocument::Indented);
#endif
                ChatRequest req;
                parseRequest(req, std::move(reqObj));
                auto [resp, respObj] = handleChatRequest(req);
                (void)respObj;
#if defined(DEBUG)
                if (respObj)
                    qDebug().noquote() << "/v1/chat/completions reply" << QJsonDocument(*respObj).toJson(QJsonDocument::Indented);
#endif
                return std::move(resp);
            } catch (const InvalidRequestError &e) {
                return e.asResponse();
            }
        }
    );

    // Respond with code 405 to wrong HTTP methods:
    m_server->route("/v1/models", QHttpServerRequest::Method::Post,
        [](const QHttpServerRequest &request) {
        [] {
            if (!MySettings::globalInstance()->serverChat())
                return QHttpServerResponse(QHttpServerResponder::StatusCode::Unauthorized);
            return QHttpServerResponse(
@ -163,7 +580,8 @@ void Server::start()
    );

    m_server->route("/v1/models/<arg>", QHttpServerRequest::Method::Post,
        [](const QString &model, const QHttpServerRequest &request) {
        [](const QString &model) {
            (void)model;
            if (!MySettings::globalInstance()->serverChat())
                return QHttpServerResponse(QHttpServerResponder::StatusCode::Unauthorized);
            return QHttpServerResponse(
@ -175,7 +593,7 @@ void Server::start()
    );

    m_server->route("/v1/completions", QHttpServerRequest::Method::Get,
        [](const QHttpServerRequest &request) {
        [] {
            if (!MySettings::globalInstance()->serverChat())
                return QHttpServerResponse(QHttpServerResponder::StatusCode::Unauthorized);
            return QHttpServerResponse(
@ -186,7 +604,7 @@ void Server::start()
    );

    m_server->route("/v1/chat/completions", QHttpServerRequest::Method::Get,
        [](const QHttpServerRequest &request) {
        [] {
            if (!MySettings::globalInstance()->serverChat())
                return QHttpServerResponse(QHttpServerResponder::StatusCode::Unauthorized);
            return QHttpServerResponse(
@ -205,268 +623,261 @@ void Server::start()
        &Chat::serverNewPromptResponsePair, Qt::BlockingQueuedConnection);
}

QHttpServerResponse Server::handleCompletionRequest(const QHttpServerRequest &request, bool isChat)
static auto makeError(auto &&...args) -> std::pair<QHttpServerResponse, std::optional<QJsonObject>>
{
    // We've been asked to do a completion...
    QJsonParseError err;
    const QJsonDocument document = QJsonDocument::fromJson(request.body(), &err);
    if (err.error || !document.isObject()) {
        std::cerr << "ERROR: invalid json in completions body" << std::endl;
        return QHttpServerResponse(QHttpServerResponder::StatusCode::NoContent);
    }
#if defined(DEBUG)
    printf("/v1/completions %s\n", qPrintable(document.toJson(QJsonDocument::Indented)));
    fflush(stdout);
#endif
    const QJsonObject body = document.object();
    if (!body.contains("model")) { // required
        std::cerr << "ERROR: completions contains no model" << std::endl;
        return QHttpServerResponse(QHttpServerResponder::StatusCode::NoContent);
    }
    QJsonArray messages;
    if (isChat) {
        if (!body.contains("messages")) {
            std::cerr << "ERROR: chat completions contains no messages" << std::endl;
            return QHttpServerResponse(QHttpServerResponder::StatusCode::NoContent);
        }
        messages = body["messages"].toArray();
    }
    return {QHttpServerResponse(args...), std::nullopt};
}

    const QString modelRequested = body["model"].toString();
auto Server::handleCompletionRequest(const CompletionRequest &request)
    -> std::pair<QHttpServerResponse, std::optional<QJsonObject>>
{
    ModelInfo modelInfo = ModelList::globalInstance()->defaultModelInfo();
    const QList<ModelInfo> modelList = ModelList::globalInstance()->selectableModelList();
    for (const ModelInfo &info : modelList) {
        Q_ASSERT(info.installed);
        if (!info.installed)
            continue;
        if (modelRequested == info.name() || modelRequested == info.filename()) {
        if (request.model == info.name() || request.model == info.filename()) {
            modelInfo = info;
            break;
        }
    }

    // We only support one prompt for now
    QList<QString> prompts;
    if (body.contains("prompt")) {
        QJsonValue promptValue = body["prompt"];
        if (promptValue.isString())
            prompts.append(promptValue.toString());
        else {
            QJsonArray array = promptValue.toArray();
            for (const QJsonValue &v : array)
                prompts.append(v.toString());
        }
    } else
        prompts.append(" ");

    int max_tokens = 16;
    if (body.contains("max_tokens"))
        max_tokens = body["max_tokens"].toInt();

    float temperature = 1.f;
    if (body.contains("temperature"))
        temperature = body["temperature"].toDouble();

    float top_p = 1.f;
    if (body.contains("top_p"))
        top_p = body["top_p"].toDouble();

    float min_p = 0.f;
    if (body.contains("min_p"))
        min_p = body["min_p"].toDouble();

    int n = 1;
    if (body.contains("n"))
        n = body["n"].toInt();

    int logprobs = -1; // supposed to be null by default??
    if (body.contains("logprobs"))
        logprobs = body["logprobs"].toInt();

    bool echo = false;
    if (body.contains("echo"))
        echo = body["echo"].toBool();

    // We currently don't support any of the following...
#if 0
    // FIXME: Need configurable reverse prompts
    QList<QString> stop;
    if (body.contains("stop")) {
        QJsonValue stopValue = body["stop"];
        if (stopValue.isString())
            stop.append(stopValue.toString());
        else {
            QJsonArray array = stopValue.toArray();
            for (QJsonValue v : array)
                stop.append(v.toString());
        }
    }

    // FIXME: QHttpServer doesn't support server-sent events
    bool stream = false;
    if (body.contains("stream"))
        stream = body["stream"].toBool();

    // FIXME: What does this do?
    QString suffix;
    if (body.contains("suffix"))
        suffix = body["suffix"].toString();

    // FIXME: We don't support
    float presence_penalty = 0.f;
    if (body.contains("presence_penalty"))
        top_p = body["presence_penalty"].toDouble();

    // FIXME: We don't support
    float frequency_penalty = 0.f;
    if (body.contains("frequency_penalty"))
        top_p = body["frequency_penalty"].toDouble();

    // FIXME: We don't support
    int best_of = 1;
    if (body.contains("best_of"))
        logprobs = body["best_of"].toInt();

    // FIXME: We don't need
    QString user;
    if (body.contains("user"))
        suffix = body["user"].toString();
#endif

    QString actualPrompt = prompts.first();

    // if we're a chat completion we have messages which means we need to prepend these to the prompt
    if (!messages.isEmpty()) {
        QList<QString> chats;
        for (int i = 0; i < messages.count(); ++i) {
            QJsonValue v = messages.at(i);
            // FIXME: Deal with system messages correctly
            QString role = v.toObject()["role"].toString();
            if (role != "user")
                continue;
            QString content = v.toObject()["content"].toString();
            if (!content.endsWith("\n") && i < messages.count() - 1)
                content += "\n";
            chats.append(content);
        }
        actualPrompt.prepend(chats.join("\n"));
    }

    // adds prompt/response items to GUI
    emit requestServerNewPromptResponsePair(actualPrompt); // blocks
    emit requestServerNewPromptResponsePair(request.prompt); // blocks
    resetResponse();

    // load the new model if necessary
    setShouldBeLoaded(true);

    if (modelInfo.filename().isEmpty()) {
        std::cerr << "ERROR: couldn't load default model " << modelRequested.toStdString() << std::endl;
        std::cerr << "ERROR: couldn't load default model " << request.model.toStdString() << std::endl;
        return QHttpServerResponse(QHttpServerResponder::StatusCode::BadRequest);
        return makeError(QHttpServerResponder::StatusCode::InternalServerError);
    }

    // NB: this resets the context, regardless of whether this model is already loaded
    if (!loadModel(modelInfo)) {
        std::cerr << "ERROR: couldn't load model " << modelInfo.name().toStdString() << std::endl;
        return QHttpServerResponse(QHttpServerResponder::StatusCode::InternalServerError);
        return makeError(QHttpServerResponder::StatusCode::InternalServerError);
    }

    const QString promptTemplate = modelInfo.promptTemplate();
    // FIXME(jared): taking parameters from the UI inhibits reproducibility of results
    const float top_k = modelInfo.topK();
    const int top_k = modelInfo.topK();
    const int n_batch = modelInfo.promptBatchSize();
    const float repeat_penalty = modelInfo.repeatPenalty();
    const auto repeat_penalty = float(modelInfo.repeatPenalty());
    const int repeat_last_n = modelInfo.repeatPenaltyTokens();

    int promptTokens = 0;
    int responseTokens = 0;
    QList<QPair<QString, QList<ResultInfo>>> responses;
    for (int i = 0; i < n; ++i) {
    for (int i = 0; i < request.n; ++i) {
        if (!promptInternal(
                m_collections,
                actualPrompt,
                request.prompt,
                promptTemplate,
                /*promptTemplate*/ u"%1"_s,
                max_tokens /*n_predict*/,
                request.max_tokens,
                top_k,
                top_p,
                request.top_p,
                min_p,
                request.min_p,
                temperature,
                request.temperature,
                n_batch,
                repeat_penalty,
                repeat_last_n)) {

            std::cerr << "ERROR: couldn't prompt model " << modelInfo.name().toStdString() << std::endl;
            return QHttpServerResponse(QHttpServerResponder::StatusCode::InternalServerError);
            return makeError(QHttpServerResponder::StatusCode::InternalServerError);
        }
        QString echoedPrompt = actualPrompt;
        QString resp = response(/*trim*/ false);
        if (!echoedPrompt.endsWith("\n"))
        if (request.echo)
            echoedPrompt += "\n";
            resp = request.prompt + resp;
        responses.append(qMakePair((echo ? u"%1\n"_s.arg(actualPrompt) : QString()) + response(), m_databaseResults));
        responses.append({resp, m_databaseResults});
        if (!promptTokens)
            promptTokens += m_promptTokens;
            promptTokens = m_promptTokens;
        responseTokens += m_promptResponseTokens - m_promptTokens;
        if (i != n - 1)
        if (i < request.n - 1)
            resetResponse();
    }

    QJsonObject responseObject;
    responseObject.insert("id", "foobarbaz");
    responseObject.insert("object", "text_completion");
    responseObject.insert("created", QDateTime::currentSecsSinceEpoch());
    responseObject.insert("model", modelInfo.name());
    QJsonObject responseObject {
        { "id", "placeholder" },
        { "object", "text_completion" },
        { "created", QDateTime::currentSecsSinceEpoch() },
        { "model", modelInfo.name() },
    };

    QJsonArray choices;
    if (isChat) {
        int index = 0;
        for (const auto &r : responses) {
            QString result = r.first;
            QList<ResultInfo> infos = r.second;
            QJsonObject choice;
            choice.insert("index", index++);
            choice.insert("finish_reason", responseTokens == max_tokens ? "length" : "stop");
            QJsonObject message;
            message.insert("role", "assistant");
            message.insert("content", result);
            choice.insert("message", message);
            if (MySettings::globalInstance()->localDocsShowReferences()) {
                QJsonArray references;
                for (const auto &ref : infos)
                    references.append(resultToJson(ref));
                choice.insert("references", references);
            }
            choices.append(choice);
        }
    } else {
        int index = 0;
        for (const auto &r : responses) {
            QString result = r.first;
            QList<ResultInfo> infos = r.second;
            QJsonObject choice;
            choice.insert("text", result);
            choice.insert("index", index++);
            choice.insert("logprobs", QJsonValue::Null); // We don't support
            choice.insert("finish_reason", responseTokens == max_tokens ? "length" : "stop");
            if (MySettings::globalInstance()->localDocsShowReferences()) {
                QJsonArray references;
                for (const auto &ref : infos)
                    references.append(resultToJson(ref));
                choice.insert("references", references);
            }
            choices.append(choice);
        }
    }
    {
        int index = 0;
        for (const auto &r : responses) {
            QString result = r.first;
            QList<ResultInfo> infos = r.second;
            QJsonObject choice {
                { "text", result },
                { "index", index++ },
                { "logprobs", QJsonValue::Null },
                { "finish_reason", responseTokens == request.max_tokens ? "length" : "stop" },
            };
            if (MySettings::globalInstance()->localDocsShowReferences()) {
                QJsonArray references;
                for (const auto &ref : infos)
                    references.append(resultToJson(ref));
                choice.insert("references", references.isEmpty() ? QJsonValue::Null : QJsonValue(references));
            }
            choices.append(choice);
        }
    }

    responseObject.insert("choices", choices);

    QJsonObject usage;
    usage.insert("prompt_tokens", int(promptTokens));
    usage.insert("completion_tokens", int(responseTokens));
    usage.insert("total_tokens", int(promptTokens + responseTokens));
    responseObject.insert("usage", usage);
#if defined(DEBUG)
    QJsonDocument newDoc(responseObject);
    printf("/v1/completions %s\n", qPrintable(newDoc.toJson(QJsonDocument::Indented)));
    fflush(stdout);
#endif
    return QHttpServerResponse(responseObject);
    responseObject.insert("usage", QJsonObject {
        { "prompt_tokens", promptTokens },
        { "completion_tokens", responseTokens },
        { "total_tokens", promptTokens + responseTokens },
    });

    return {QHttpServerResponse(responseObject), responseObject};
}

auto Server::handleChatRequest(const ChatRequest &request)
    -> std::pair<QHttpServerResponse, std::optional<QJsonObject>>
{
    ModelInfo modelInfo = ModelList::globalInstance()->defaultModelInfo();
    const QList<ModelInfo> modelList = ModelList::globalInstance()->selectableModelList();
    for (const ModelInfo &info : modelList) {
        Q_ASSERT(info.installed);
        if (!info.installed)
            continue;
        if (request.model == info.name() || request.model == info.filename()) {
            modelInfo = info;
            break;
        }
    }

    // load the new model if necessary
    setShouldBeLoaded(true);

    if (modelInfo.filename().isEmpty()) {
        std::cerr << "ERROR: couldn't load default model " << request.model.toStdString() << std::endl;
        return makeError(QHttpServerResponder::StatusCode::InternalServerError);
    }

    // NB: this resets the context, regardless of whether this model is already loaded
    if (!loadModel(modelInfo)) {
        std::cerr << "ERROR: couldn't load model " << modelInfo.name().toStdString() << std::endl;
        return makeError(QHttpServerResponder::StatusCode::InternalServerError);
    }

    const QString promptTemplate = modelInfo.promptTemplate();
    const int top_k = modelInfo.topK();
    const int n_batch = modelInfo.promptBatchSize();
    const auto repeat_penalty = float(modelInfo.repeatPenalty());
    const int repeat_last_n = modelInfo.repeatPenaltyTokens();

    int promptTokens = 0;
    int responseTokens = 0;
    QList<QPair<QString, QList<ResultInfo>>> responses;
    Q_ASSERT(!request.messages.isEmpty());
    Q_ASSERT(request.messages.size() % 2 == 1);
    for (int i = 0; i < request.messages.size() - 2; i += 2) {
        using enum ChatRequest::Message::Role;
        auto &user = request.messages[i];
        auto &assistant = request.messages[i + 1];
        Q_ASSERT(user.role == User);
        Q_ASSERT(assistant.role == Assistant);

        // adds prompt/response items to GUI
        emit requestServerNewPromptResponsePair(user.content); // blocks
        resetResponse();

        if (!promptInternal(
                {},
                user.content,
                promptTemplate,
                request.max_tokens,
                top_k,
                request.top_p,
                request.min_p,
                request.temperature,
                n_batch,
                repeat_penalty,
                repeat_last_n,
                assistant.content)
        ) {
            std::cerr << "ERROR: couldn't prompt model " << modelInfo.name().toStdString() << std::endl;
            return makeError(QHttpServerResponder::StatusCode::InternalServerError);
        }
        promptTokens += m_promptResponseTokens; // previous responses are part of current prompt
    }

    QString lastMessage = request.messages.last().content;
    // adds prompt/response items to GUI
    emit requestServerNewPromptResponsePair(lastMessage); // blocks
    resetResponse();

    for (int i = 0; i < request.n; ++i) {
        if (!promptInternal(
                m_collections,
                lastMessage,
                promptTemplate,
                request.max_tokens,
                top_k,
                request.top_p,
                request.min_p,
                request.temperature,
                n_batch,
                repeat_penalty,
                repeat_last_n)
        ) {
            std::cerr << "ERROR: couldn't prompt model " << modelInfo.name().toStdString() << std::endl;
            return makeError(QHttpServerResponder::StatusCode::InternalServerError);
        }
        responses.append({response(), m_databaseResults});
        // FIXME(jared): these are UI counts and do not include framing tokens, which they should
        if (i == 0)
            promptTokens += m_promptTokens;
        responseTokens += m_promptResponseTokens - m_promptTokens;
        if (i != request.n - 1)
            resetResponse();
    }

    QJsonObject responseObject {
        { "id", "placeholder" },
        { "object", "chat.completion" },
        { "created", QDateTime::currentSecsSinceEpoch() },
        { "model", modelInfo.name() },
    };

    QJsonArray choices;
    {
        int index = 0;
        for (const auto &r : responses) {
            QString result = r.first;
            QList<ResultInfo> infos = r.second;
            QJsonObject message {
                { "role", "assistant" },
                { "content", result },
            };
            QJsonObject choice {
                { "index", index++ },
                { "message", message },
                { "finish_reason", responseTokens == request.max_tokens ? "length" : "stop" },
                { "logprobs", QJsonValue::Null },
            };
            if (MySettings::globalInstance()->localDocsShowReferences()) {
                QJsonArray references;
                for (const auto &ref : infos)
                    references.append(resultToJson(ref));
                choice.insert("references", references.isEmpty() ? QJsonValue::Null : QJsonValue(references));
            }
            choices.append(choice);
        }
    }

    responseObject.insert("choices", choices);
    responseObject.insert("usage", QJsonObject {
        { "prompt_tokens", promptTokens },
        { "completion_tokens", responseTokens },
        { "total_tokens", promptTokens + responseTokens },
    });

    return {QHttpServerResponse(responseObject), responseObject};
}

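Assembled, a successful reply from handleChatRequest has the shape below. Values are illustrative (the model name is hypothetical), but the "id" really is the hard-coded string "placeholder", and "references" is only attached when LocalDocs references are enabled:

// Field values other than "id" and "object" vary per request.
static const char *chatReply = R"({
    "id": "placeholder",
    "object": "chat.completion",
    "created": 1725000000,
    "model": "Llama 3 8B Instruct",
    "choices": [{
        "index": 0,
        "message": { "role": "assistant", "content": "Hello!" },
        "finish_reason": "stop",
        "logprobs": null
    }],
    "usage": { "prompt_tokens": 12, "completion_tokens": 3, "total_tokens": 15 }
})";
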
@ -4,22 +4,29 @@
#include "chatllm.h"
#include "database.h"

#include <QHttpServerRequest>
#include <QHttpServer>
#include <QHttpServerResponse>
#include <QObject>
#include <QJsonObject>
#include <QList>
#include <QObject>
#include <QString>

#include <memory>
#include <optional>
#include <utility>

class Chat;
class QHttpServer;
class ChatRequest;
class CompletionRequest;


class Server : public ChatLLM
{
    Q_OBJECT

public:
    Server(Chat *parent);
    explicit Server(Chat *chat);
    virtual ~Server();
    ~Server() override = default;

public Q_SLOTS:
    void start();
@ -27,14 +34,17 @@ public Q_SLOTS:
Q_SIGNALS:
    void requestServerNewPromptResponsePair(const QString &prompt);

private:
    auto handleCompletionRequest(const CompletionRequest &request) -> std::pair<QHttpServerResponse, std::optional<QJsonObject>>;
    auto handleChatRequest(const ChatRequest &request) -> std::pair<QHttpServerResponse, std::optional<QJsonObject>>;

private Q_SLOTS:
    QHttpServerResponse handleCompletionRequest(const QHttpServerRequest &request, bool isChat);
    void handleDatabaseResultsChanged(const QList<ResultInfo> &results) { m_databaseResults = results; }
    void handleCollectionListChanged(const QList<QString> &collectionList) { m_collections = collectionList; }

private:
    Chat *m_chat;
    QHttpServer *m_server;
    std::unique_ptr<QHttpServer> m_server;
    QList<ResultInfo> m_databaseResults;
    QList<QString> m_collections;
};