mirror of https://github.com/imartinez/privateGPT.git
synced 2025-08-16 14:36:56 +00:00

Updated the llm component

This commit is contained in:
parent bc343206cc
commit 1963190d16
.env (4 changes)

@@ -1,10 +1,10 @@
 PORT=8000
 ENVIRONMENT=dev

-DB_HOST=db
+DB_HOST=localhost
 DB_USER=postgres
 DB_PORT=5432
-DB_PASSWORD=admin
+DB_PASSWORD=quick
 DB_NAME=QuickGpt

 SUPER_ADMIN_EMAIL=superadmin@email.com
alembic/versions/739fb4ac6615_chat_items.py (new file, 54 lines)

@@ -0,0 +1,54 @@
+"""Chat items
+
+Revision ID: 739fb4ac6615
+Revises:
+Create Date: 2024-05-01 19:20:19.652290
+
+"""
+from typing import Sequence, Union
+
+from alembic import op
+import sqlalchemy as sa
+
+
+# revision identifiers, used by Alembic.
+revision: str = '739fb4ac6615'
+down_revision: Union[str, None] = None
+branch_labels: Union[str, Sequence[str], None] = None
+depends_on: Union[str, Sequence[str], None] = None
+
+
+def upgrade() -> None:
+    # ### commands auto generated by Alembic - please adjust! ###
+    op.create_table('chat_history',
+    sa.Column('conversation_id', sa.UUID(), nullable=False),
+    sa.Column('title', sa.String(length=255), nullable=True),
+    sa.Column('created_at', sa.DateTime(), nullable=True),
+    sa.Column('updated_at', sa.DateTime(), nullable=True),
+    sa.Column('user_id', sa.Integer(), nullable=True),
+    sa.Column('_title_generated', sa.Boolean(), nullable=True),
+    sa.ForeignKeyConstraint(['user_id'], ['users.id'], ),
+    sa.PrimaryKeyConstraint('conversation_id')
+    )
+    op.create_table('chat_items',
+    sa.Column('id', sa.Integer(), nullable=False),
+    sa.Column('index', sa.Integer(), nullable=False),
+    sa.Column('sender', sa.String(length=225), nullable=False),
+    sa.Column('content', sa.JSON(), nullable=True),
+    sa.Column('created_at', sa.DateTime(), nullable=True),
+    sa.Column('updated_at', sa.DateTime(), nullable=True),
+    sa.Column('like', sa.Boolean(), nullable=True),
+    sa.Column('conversation_id', sa.UUID(), nullable=False),
+    sa.ForeignKeyConstraint(['conversation_id'], ['chat_history.conversation_id'], ),
+    sa.PrimaryKeyConstraint('id')
+    )
+    # op.create_unique_constraint('unique_user_role', 'user_roles', ['user_id', 'role_id', 'company_id'])
+    # ### end Alembic commands ###
+
+
+def downgrade() -> None:
+    # ### commands auto generated by Alembic - please adjust! ###
+    # op.drop_constraint('unique_user_role', 'user_roles', type_='unique')
+    op.drop_table('chat_items')
+    op.drop_table('chat_history')
+    # ### end Alembic commands ###
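Note: this migration creates the chat_history and chat_items tables (chat_items rows reference chat_history via conversation_id). A minimal sketch of applying and rolling it back programmatically, assuming a standard alembic.ini at the project root; the equivalent CLI calls are `alembic upgrade head` and `alembic downgrade -1`:

    # Hypothetical driver script; the "alembic.ini" path is an assumption about this repo's layout.
    from alembic import command
    from alembic.config import Config

    cfg = Config("alembic.ini")
    command.upgrade(cfg, "739fb4ac6615")   # create chat_history and chat_items
    command.downgrade(cfg, "-1")           # drop both tables again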
@@ -1,4 +1,6 @@
 import logging
+from collections.abc import Callable
+from typing import Any

 from injector import inject, singleton
 from llama_index.core.llms import LLM, MockLLM
@@ -18,14 +20,24 @@ class LLMComponent:
     @inject
     def __init__(self, settings: Settings) -> None:
         llm_mode = settings.llm.mode
-        if settings.llm.tokenizer:
-            set_global_tokenizer(
-                AutoTokenizer.from_pretrained(
-                    pretrained_model_name_or_path=settings.llm.tokenizer,
-                    cache_dir=str(models_cache_path),
-                )
-            )
+        if settings.llm.tokenizer and settings.llm.mode != "mock":
+            # Try to download the tokenizer. If it fails, the LLM will still work
+            # using the default one, which is less accurate.
+            try:
+                set_global_tokenizer(
+                    AutoTokenizer.from_pretrained(
+                        pretrained_model_name_or_path=settings.llm.tokenizer,
+                        cache_dir=str(models_cache_path),
+                        token=settings.huggingface.access_token,
+                    )
+                )
+            except Exception as e:
+                logger.warning(
+                    "Failed to download tokenizer %s. Falling back to "
+                    "default tokenizer.",
+                    settings.llm.tokenizer,
+                    e,
+                )

         logger.info("Initializing the LLM in mode=%s", llm_mode)
         match settings.llm.mode:
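Note: the global tokenizer set here (a llama-index global used for token counting) is not required for the LLM to run, so the change wraps the Hugging Face AutoTokenizer download in a try/except, logs a warning, and keeps the default tokenizer if the download fails (for example, no network access or a gated model without settings.huggingface.access_token); it also skips the download entirely in mock mode.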
@@ -47,7 +59,8 @@ class LLMComponent:
                 "offload_kqv": True,
             }
             self.llm = LlamaCPP(
-                model_path=str(models_path / settings.llamacpp.llm_hf_model_file),
+                model_path=str(
+                    models_path / settings.llamacpp.llm_hf_model_file),
                 temperature=settings.llm.temperature,
                 max_new_tokens=settings.llm.max_new_tokens,
                 context_window=settings.llm.context_window,
@@ -130,6 +143,44 @@ class LLMComponent:
                     temperature=settings.llm.temperature,
                     context_window=settings.llm.context_window,
                     additional_kwargs=settings_kwargs,
+                    request_timeout=ollama_settings.request_timeout,
+                )
+
+                if (
+                    ollama_settings.keep_alive
+                    != ollama_settings.model_fields["keep_alive"].default
+                ):
+                    # Modify Ollama methods to use the "keep_alive" field.
+                    def add_keep_alive(func: Callable[..., Any]) -> Callable[..., Any]:
+                        def wrapper(*args: Any, **kwargs: Any) -> Any:
+                            kwargs["keep_alive"] = ollama_settings.keep_alive
+                            return func(*args, **kwargs)
+
+                        return wrapper
+
+                    Ollama.chat = add_keep_alive(Ollama.chat)
+                    Ollama.stream_chat = add_keep_alive(Ollama.stream_chat)
+                    Ollama.complete = add_keep_alive(Ollama.complete)
+                    Ollama.stream_complete = add_keep_alive(
+                        Ollama.stream_complete)
+
+            case "azopenai":
+                try:
+                    from llama_index.llms.azure_openai import (  # type: ignore
+                        AzureOpenAI,
+                    )
+                except ImportError as e:
+                    raise ImportError(
+                        "Azure OpenAI dependencies not found, install with `poetry install --extras llms-azopenai`"
+                    ) from e
+
+                azopenai_settings = settings.azopenai
+                self.llm = AzureOpenAI(
+                    model=azopenai_settings.llm_model,
+                    deployment_name=azopenai_settings.llm_deployment_name,
+                    api_key=azopenai_settings.api_key,
+                    azure_endpoint=azopenai_settings.azure_endpoint,
+                    api_version=azopenai_settings.api_version,
                 )
             case "mock":
                 self.llm = MockLLM()
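The keep_alive block above monkey-patches Ollama.chat, stream_chat, complete and stream_complete so that every call carries the configured keep_alive value. A standalone sketch of the same wrapper pattern, runnable on its own (FakeClient and the "5m" default are illustrative and not part of this commit):

    from collections.abc import Callable
    from typing import Any


    def add_keep_alive(func: Callable[..., Any], keep_alive: str = "5m") -> Callable[..., Any]:
        def wrapper(*args: Any, **kwargs: Any) -> Any:
            kwargs["keep_alive"] = keep_alive  # injected into every call
            return func(*args, **kwargs)
        return wrapper


    class FakeClient:
        def chat(self, prompt: str, **kwargs: Any) -> dict[str, Any]:
            return {"prompt": prompt, **kwargs}


    FakeClient.chat = add_keep_alive(FakeClient.chat)
    print(FakeClient().chat("hello"))  # {'prompt': 'hello', 'keep_alive': '5m'}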
@@ -42,7 +42,8 @@ class AbstractPromptStyle(abc.ABC):

     def completion_to_prompt(self, completion: str) -> str:
         prompt = self._completion_to_prompt(completion)
-        logger.debug("Got for completion='%s' the prompt='%s'", completion, prompt)
+        logger.debug("Got for completion='%s' the prompt='%s'",
+                     completion, prompt)
         return prompt

@@ -58,8 +59,10 @@ class DefaultPromptStyle(AbstractPromptStyle):

         # Hacky way to override the functions
         # Override the functions to be None, and pass None to the LLM.
-        self.messages_to_prompt = None  # type: ignore[method-assign, assignment]
-        self.completion_to_prompt = None  # type: ignore[method-assign, assignment]
+        # type: ignore[method-assign, assignment]
+        self.messages_to_prompt = None
+        # type: ignore[method-assign, assignment]
+        self.completion_to_prompt = None

     def _messages_to_prompt(self, messages: Sequence[ChatMessage]) -> str:
         return ""
@@ -215,7 +218,8 @@ class ChatMLPromptStyle(AbstractPromptStyle):


 def get_prompt_style(
-    prompt_style: Literal["default", "llama2", "tag", "mistral", "chatml"] | None
+    prompt_style: Literal["default", "llama2",
+                          "tag", "mistral", "chatml"] | None
 ) -> AbstractPromptStyle:
     """Get the prompt style to use from the given string.

@@ -261,7 +261,6 @@ def register(
     )
     random_password = security.generate_random_password()
     # random_password = password
-
     try:
         company_id = current_user.company_id
         if company_id:
@@ -4,6 +4,7 @@ import uuid
 from sqlalchemy.orm import Session
 from fastapi.responses import JSONResponse
 from fastapi import APIRouter, Depends, HTTPException, status, Security
+from fastapi_pagination import Page, paginate

 from private_gpt.users.api import deps
 from private_gpt.users import crud, models, schemas
@@ -12,22 +13,20 @@ logger = logging.getLogger(__name__)
 router = APIRouter(prefix="/c", tags=["Chat Histories"])


-@router.get("", response_model=list[schemas.ChatHistory])
+@router.get("", response_model=Page[schemas.ChatHistory])
 def list_chat_histories(
     db: Session = Depends(deps.get_db),
-    skip: int = 0,
-    limit: int = 100,
     current_user: models.User = Security(
         deps.get_current_user,
     ),
-) -> list[schemas.ChatHistory]:
+) -> Page[schemas.ChatHistory]:
     """
     Retrieve a list of chat histories with pagination support.
     """
     try:
         chat_histories = crud.chat.get_chat_history(
-            db, user_id=current_user.id, skip=skip, limit=limit)
-        return chat_histories
+            db, user_id=current_user.id)
+        return paginate(chat_histories)
     except Exception as e:
         print(traceback.format_exc())
         logger.error(f"Error listing chat histories: {str(e)}")
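The listing route now returns Page[schemas.ChatHistory] and wraps the full result set with paginate(), so paging moves from explicit skip/limit arguments to fastapi-pagination's page/size query parameters. A minimal self-contained sketch of that mechanism, assuming fastapi-pagination is installed and add_pagination() is applied to the app elsewhere in this project (not visible in this diff):

    from fastapi import FastAPI
    from fastapi_pagination import Page, add_pagination, paginate

    app = FastAPI()


    @app.get("/numbers", response_model=Page[int])
    def list_numbers() -> Page[int]:
        # paginate() slices an in-memory sequence according to the page/size
        # query parameters supplied by the client.
        return paginate(list(range(95)))


    add_pagination(app)  # registers the page/size query parameters on the routes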
@@ -66,6 +65,8 @@ def create_chat_history(
 @router.get("/{conversation_id}", response_model=schemas.ChatHistory)
 def read_chat_history(
     conversation_id: uuid.UUID,
+    skip: int = 0,
+    limit: int = 20,
     db: Session = Depends(deps.get_db),
     current_user: models.User = Security(
         deps.get_current_user,
@@ -75,7 +76,7 @@ def read_chat_history(
     Read a chat history by ID
     """
     try:
-        chat_history = crud.chat.get_by_id(db, id=conversation_id)
+        chat_history = crud.chat.get_by_id(db, id=conversation_id, skip=skip, limit=limit)
         if chat_history is None or chat_history.user_id != current_user.id:
             raise HTTPException(
                 status_code=404, detail="Chat history not found")
@@ -10,8 +10,11 @@ from private_gpt.users.core.config import settings
 ACCESS_TOKEN_EXPIRE_MINUTES = 60 * 24 * 1  # 12 hrs # Default Value
 REFRESH_TOKEN_EXPIRE_MINUTES = 60 * 24 * 7  # 7 days # Default Value
 ALGORITHM = "HS256"
-JWT_SECRET_KEY = settings.SECRET_KEY
-JWT_REFRESH_SECRET_KEY = settings.REFRESH_KEY
+# JWT_SECRET_KEY = settings.SECRET_KEY
+# JWT_REFRESH_SECRET_KEY = settings.REFRESH_KEY
+
+JWT_SECRET_KEY = "QUICKGPT"
+JWT_REFRESH_SECRET_KEY = "QUICKGPT_REFRESH"

 pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")

@@ -7,10 +7,11 @@ import uuid
 from private_gpt.users.crud.base import CRUDBase
 from private_gpt.users.models.chat import ChatHistory, ChatItem
 from private_gpt.users.schemas.chat import ChatHistoryCreate, ChatHistoryCreate, ChatItemCreate, ChatItemUpdate
+from fastapi_pagination import Page, paginate


 class CRUDChat(CRUDBase[ChatHistory, ChatHistoryCreate, ChatHistoryCreate]):
-    def get_by_id(self, db: Session, *, id: uuid.UUID) -> Optional[ChatHistory]:
+    def get_by_id(self, db: Session, *, id: uuid.UUID, skip: int=0, limit: int=10) -> Optional[ChatHistory]:
         chat_history = (
             db.query(self.model)
             .filter(ChatHistory.conversation_id == id)
@@ -21,7 +22,9 @@ class CRUDChat(CRUDBase[ChatHistory, ChatHistoryCreate, ChatHistoryCreate]):
         chat_history.chat_items = (
             db.query(ChatItem)
             .filter(ChatItem.conversation_id == id)
-            .order_by(asc(getattr(ChatItem, 'index')))
+            .order_by(desc(getattr(ChatItem, 'index')))
+            .offset(skip)
+            .limit(limit)
             .all()
         )
         return chat_history
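Note: with desc() ordering on index plus offset/limit, get_by_id now returns at most `limit` chat items per call, newest first; combined with the new skip/limit parameters on the /c/{conversation_id} route (defaults 0 and 20), this appears intended to let clients page backwards through long conversations.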
@@ -14,8 +14,8 @@ embedding:
   embed_dim: 768

 ollama:
-  llm_model: mistral
-  embedding_model: nomic-embed-text
+  llm_model: llama3
+  embedding_model: mxbai-embed-large
   api_base: http://localhost:11434

 nodestore:
@@ -11,8 +11,8 @@ embedding:
   mode: ollama

 ollama:
-  llm_model: mistral
-  embedding_model: nomic-embed-text
+  llm_model: llama3
+  embedding_model: mxbai-embed-large
   api_base: http://localhost:11434
   embedding_api_base: http://localhost:11434  # change if your embedding model runs on another ollama
   keep_alive: 5m
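Note: these new defaults assume the llama3 and mxbai-embed-large models are already available in the local Ollama instance at http://localhost:11434; if they are not, they can typically be fetched with `ollama pull llama3` and `ollama pull mxbai-embed-large`.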