mirror of
https://github.com/imartinez/privateGPT.git
synced 2025-08-15 14:13:47 +00:00
Updated the llm component
This commit is contained in:
parent
bc343206cc
commit
1963190d16
4
.env
4
.env
@ -1,10 +1,10 @@
|
||||
PORT=8000
|
||||
ENVIRONMENT=dev
|
||||
|
||||
DB_HOST=db
|
||||
DB_HOST=localhost
|
||||
DB_USER=postgres
|
||||
DB_PORT=5432
|
||||
DB_PASSWORD=admin
|
||||
DB_PASSWORD=quick
|
||||
DB_NAME=QuickGpt
|
||||
|
||||
SUPER_ADMIN_EMAIL=superadmin@email.com
|
||||
|
54
alembic/versions/739fb4ac6615_chat_items.py
Normal file
54
alembic/versions/739fb4ac6615_chat_items.py
Normal file
@ -0,0 +1,54 @@
|
||||
"""Chat items
|
||||
|
||||
Revision ID: 739fb4ac6615
|
||||
Revises:
|
||||
Create Date: 2024-05-01 19:20:19.652290
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = '739fb4ac6615'
|
||||
down_revision: Union[str, None] = None
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.create_table('chat_history',
|
||||
sa.Column('conversation_id', sa.UUID(), nullable=False),
|
||||
sa.Column('title', sa.String(length=255), nullable=True),
|
||||
sa.Column('created_at', sa.DateTime(), nullable=True),
|
||||
sa.Column('updated_at', sa.DateTime(), nullable=True),
|
||||
sa.Column('user_id', sa.Integer(), nullable=True),
|
||||
sa.Column('_title_generated', sa.Boolean(), nullable=True),
|
||||
sa.ForeignKeyConstraint(['user_id'], ['users.id'], ),
|
||||
sa.PrimaryKeyConstraint('conversation_id')
|
||||
)
|
||||
op.create_table('chat_items',
|
||||
sa.Column('id', sa.Integer(), nullable=False),
|
||||
sa.Column('index', sa.Integer(), nullable=False),
|
||||
sa.Column('sender', sa.String(length=225), nullable=False),
|
||||
sa.Column('content', sa.JSON(), nullable=True),
|
||||
sa.Column('created_at', sa.DateTime(), nullable=True),
|
||||
sa.Column('updated_at', sa.DateTime(), nullable=True),
|
||||
sa.Column('like', sa.Boolean(), nullable=True),
|
||||
sa.Column('conversation_id', sa.UUID(), nullable=False),
|
||||
sa.ForeignKeyConstraint(['conversation_id'], ['chat_history.conversation_id'], ),
|
||||
sa.PrimaryKeyConstraint('id')
|
||||
)
|
||||
# op.create_unique_constraint('unique_user_role', 'user_roles', ['user_id', 'role_id', 'company_id'])
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
# op.drop_constraint('unique_user_role', 'user_roles', type_='unique')
|
||||
op.drop_table('chat_items')
|
||||
op.drop_table('chat_history')
|
||||
# ### end Alembic commands ###
|
@ -1,4 +1,6 @@
|
||||
import logging
|
||||
from collections.abc import Callable
|
||||
from typing import Any
|
||||
|
||||
from injector import inject, singleton
|
||||
from llama_index.core.llms import LLM, MockLLM
|
||||
@ -18,14 +20,24 @@ class LLMComponent:
|
||||
@inject
|
||||
def __init__(self, settings: Settings) -> None:
|
||||
llm_mode = settings.llm.mode
|
||||
if settings.llm.tokenizer:
|
||||
set_global_tokenizer(
|
||||
AutoTokenizer.from_pretrained(
|
||||
pretrained_model_name_or_path=settings.llm.tokenizer,
|
||||
cache_dir=str(models_cache_path),
|
||||
|
||||
if settings.llm.tokenizer and settings.llm.mode != "mock":
|
||||
# Try to download the tokenizer. If it fails, the LLM will still work
|
||||
# using the default one, which is less accurate.
|
||||
try:
|
||||
set_global_tokenizer(
|
||||
AutoTokenizer.from_pretrained(
|
||||
pretrained_model_name_or_path=settings.llm.tokenizer,
|
||||
cache_dir=str(models_cache_path),
|
||||
token=settings.huggingface.access_token,
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
"Failed to download tokenizer %s. Falling back to "
|
||||
"default tokenizer.",
|
||||
settings.llm.tokenizer,
|
||||
e,
|
||||
)
|
||||
)
|
||||
|
||||
logger.info("Initializing the LLM in mode=%s", llm_mode)
|
||||
match settings.llm.mode:
|
||||
@ -47,7 +59,8 @@ class LLMComponent:
|
||||
"offload_kqv": True,
|
||||
}
|
||||
self.llm = LlamaCPP(
|
||||
model_path=str(models_path / settings.llamacpp.llm_hf_model_file),
|
||||
model_path=str(
|
||||
models_path / settings.llamacpp.llm_hf_model_file),
|
||||
temperature=settings.llm.temperature,
|
||||
max_new_tokens=settings.llm.max_new_tokens,
|
||||
context_window=settings.llm.context_window,
|
||||
@ -130,6 +143,44 @@ class LLMComponent:
|
||||
temperature=settings.llm.temperature,
|
||||
context_window=settings.llm.context_window,
|
||||
additional_kwargs=settings_kwargs,
|
||||
request_timeout=ollama_settings.request_timeout,
|
||||
)
|
||||
|
||||
if (
|
||||
ollama_settings.keep_alive
|
||||
!= ollama_settings.model_fields["keep_alive"].default
|
||||
):
|
||||
# Modify Ollama methods to use the "keep_alive" field.
|
||||
def add_keep_alive(func: Callable[..., Any]) -> Callable[..., Any]:
|
||||
def wrapper(*args: Any, **kwargs: Any) -> Any:
|
||||
kwargs["keep_alive"] = ollama_settings.keep_alive
|
||||
return func(*args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
Ollama.chat = add_keep_alive(Ollama.chat)
|
||||
Ollama.stream_chat = add_keep_alive(Ollama.stream_chat)
|
||||
Ollama.complete = add_keep_alive(Ollama.complete)
|
||||
Ollama.stream_complete = add_keep_alive(
|
||||
Ollama.stream_complete)
|
||||
|
||||
case "azopenai":
|
||||
try:
|
||||
from llama_index.llms.azure_openai import ( # type: ignore
|
||||
AzureOpenAI,
|
||||
)
|
||||
except ImportError as e:
|
||||
raise ImportError(
|
||||
"Azure OpenAI dependencies not found, install with `poetry install --extras llms-azopenai`"
|
||||
) from e
|
||||
|
||||
azopenai_settings = settings.azopenai
|
||||
self.llm = AzureOpenAI(
|
||||
model=azopenai_settings.llm_model,
|
||||
deployment_name=azopenai_settings.llm_deployment_name,
|
||||
api_key=azopenai_settings.api_key,
|
||||
azure_endpoint=azopenai_settings.azure_endpoint,
|
||||
api_version=azopenai_settings.api_version,
|
||||
)
|
||||
case "mock":
|
||||
self.llm = MockLLM()
|
||||
|
@ -42,7 +42,8 @@ class AbstractPromptStyle(abc.ABC):
|
||||
|
||||
def completion_to_prompt(self, completion: str) -> str:
|
||||
prompt = self._completion_to_prompt(completion)
|
||||
logger.debug("Got for completion='%s' the prompt='%s'", completion, prompt)
|
||||
logger.debug("Got for completion='%s' the prompt='%s'",
|
||||
completion, prompt)
|
||||
return prompt
|
||||
|
||||
|
||||
@ -58,8 +59,10 @@ class DefaultPromptStyle(AbstractPromptStyle):
|
||||
|
||||
# Hacky way to override the functions
|
||||
# Override the functions to be None, and pass None to the LLM.
|
||||
self.messages_to_prompt = None # type: ignore[method-assign, assignment]
|
||||
self.completion_to_prompt = None # type: ignore[method-assign, assignment]
|
||||
# type: ignore[method-assign, assignment]
|
||||
self.messages_to_prompt = None
|
||||
# type: ignore[method-assign, assignment]
|
||||
self.completion_to_prompt = None
|
||||
|
||||
def _messages_to_prompt(self, messages: Sequence[ChatMessage]) -> str:
|
||||
return ""
|
||||
@ -215,7 +218,8 @@ class ChatMLPromptStyle(AbstractPromptStyle):
|
||||
|
||||
|
||||
def get_prompt_style(
|
||||
prompt_style: Literal["default", "llama2", "tag", "mistral", "chatml"] | None
|
||||
prompt_style: Literal["default", "llama2",
|
||||
"tag", "mistral", "chatml"] | None
|
||||
) -> AbstractPromptStyle:
|
||||
"""Get the prompt style to use from the given string.
|
||||
|
||||
|
@ -261,7 +261,6 @@ def register(
|
||||
)
|
||||
random_password = security.generate_random_password()
|
||||
# random_password = password
|
||||
|
||||
try:
|
||||
company_id = current_user.company_id
|
||||
if company_id:
|
||||
|
@ -4,6 +4,7 @@ import uuid
|
||||
from sqlalchemy.orm import Session
|
||||
from fastapi.responses import JSONResponse
|
||||
from fastapi import APIRouter, Depends, HTTPException, status, Security
|
||||
from fastapi_pagination import Page, paginate
|
||||
|
||||
from private_gpt.users.api import deps
|
||||
from private_gpt.users import crud, models, schemas
|
||||
@ -12,22 +13,20 @@ logger = logging.getLogger(__name__)
|
||||
router = APIRouter(prefix="/c", tags=["Chat Histories"])
|
||||
|
||||
|
||||
@router.get("", response_model=list[schemas.ChatHistory])
|
||||
@router.get("", response_model=Page[schemas.ChatHistory])
|
||||
def list_chat_histories(
|
||||
db: Session = Depends(deps.get_db),
|
||||
skip: int = 0,
|
||||
limit: int = 100,
|
||||
current_user: models.User = Security(
|
||||
deps.get_current_user,
|
||||
),
|
||||
) -> list[schemas.ChatHistory]:
|
||||
) -> Page[schemas.ChatHistory]:
|
||||
"""
|
||||
Retrieve a list of chat histories with pagination support.
|
||||
"""
|
||||
try:
|
||||
chat_histories = crud.chat.get_chat_history(
|
||||
db, user_id=current_user.id, skip=skip, limit=limit)
|
||||
return chat_histories
|
||||
db, user_id=current_user.id)
|
||||
return paginate(chat_histories)
|
||||
except Exception as e:
|
||||
print(traceback.format_exc())
|
||||
logger.error(f"Error listing chat histories: {str(e)}")
|
||||
@ -66,6 +65,8 @@ def create_chat_history(
|
||||
@router.get("/{conversation_id}", response_model=schemas.ChatHistory)
|
||||
def read_chat_history(
|
||||
conversation_id: uuid.UUID,
|
||||
skip: int = 0,
|
||||
limit: int = 20,
|
||||
db: Session = Depends(deps.get_db),
|
||||
current_user: models.User = Security(
|
||||
deps.get_current_user,
|
||||
@ -75,7 +76,7 @@ def read_chat_history(
|
||||
Read a chat history by ID
|
||||
"""
|
||||
try:
|
||||
chat_history = crud.chat.get_by_id(db, id=conversation_id)
|
||||
chat_history = crud.chat.get_by_id(db, id=conversation_id, skip=skip, limit=limit)
|
||||
if chat_history is None or chat_history.user_id != current_user.id:
|
||||
raise HTTPException(
|
||||
status_code=404, detail="Chat history not found")
|
||||
|
@ -10,8 +10,11 @@ from private_gpt.users.core.config import settings
|
||||
ACCESS_TOKEN_EXPIRE_MINUTES = 60 * 24 * 1 # 12 hrs # Default Value
|
||||
REFRESH_TOKEN_EXPIRE_MINUTES = 60 * 24 * 7 # 7 days # Default Value
|
||||
ALGORITHM = "HS256"
|
||||
JWT_SECRET_KEY = settings.SECRET_KEY
|
||||
JWT_REFRESH_SECRET_KEY = settings.REFRESH_KEY
|
||||
# JWT_SECRET_KEY = settings.SECRET_KEY
|
||||
# JWT_REFRESH_SECRET_KEY = settings.REFRESH_KEY
|
||||
|
||||
JWT_SECRET_KEY = "QUICKGPT"
|
||||
JWT_REFRESH_SECRET_KEY = "QUICKGPT_REFRESH"
|
||||
|
||||
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
|
||||
|
||||
|
@ -7,10 +7,11 @@ import uuid
|
||||
from private_gpt.users.crud.base import CRUDBase
|
||||
from private_gpt.users.models.chat import ChatHistory, ChatItem
|
||||
from private_gpt.users.schemas.chat import ChatHistoryCreate, ChatHistoryCreate, ChatItemCreate, ChatItemUpdate
|
||||
from fastapi_pagination import Page, paginate
|
||||
|
||||
|
||||
class CRUDChat(CRUDBase[ChatHistory, ChatHistoryCreate, ChatHistoryCreate]):
|
||||
def get_by_id(self, db: Session, *, id: uuid.UUID) -> Optional[ChatHistory]:
|
||||
def get_by_id(self, db: Session, *, id: uuid.UUID, skip: int=0, limit: int=10) -> Optional[ChatHistory]:
|
||||
chat_history = (
|
||||
db.query(self.model)
|
||||
.filter(ChatHistory.conversation_id == id)
|
||||
@ -21,7 +22,9 @@ class CRUDChat(CRUDBase[ChatHistory, ChatHistoryCreate, ChatHistoryCreate]):
|
||||
chat_history.chat_items = (
|
||||
db.query(ChatItem)
|
||||
.filter(ChatItem.conversation_id == id)
|
||||
.order_by(asc(getattr(ChatItem, 'index')))
|
||||
.order_by(desc(getattr(ChatItem, 'index')))
|
||||
.offset(skip)
|
||||
.limit(limit)
|
||||
.all()
|
||||
)
|
||||
return chat_history
|
||||
|
@ -14,8 +14,8 @@ embedding:
|
||||
embed_dim: 768
|
||||
|
||||
ollama:
|
||||
llm_model: mistral
|
||||
embedding_model: nomic-embed-text
|
||||
llm_model: llama3
|
||||
embedding_model: mxbai-embed-large
|
||||
api_base: http://localhost:11434
|
||||
|
||||
nodestore:
|
||||
|
@ -11,8 +11,8 @@ embedding:
|
||||
mode: ollama
|
||||
|
||||
ollama:
|
||||
llm_model: mistral
|
||||
embedding_model: nomic-embed-text
|
||||
llm_model: llama3
|
||||
embedding_model: mxbai-embed-large
|
||||
api_base: http://localhost:11434
|
||||
embedding_api_base: http://localhost:11434 # change if your embedding model runs on another ollama
|
||||
keep_alive: 5m
|
||||
|
Loading…
Reference in New Issue
Block a user