diff --git a/.env b/.env
index e21e3d06..c8958669 100644
--- a/.env
+++ b/.env
@@ -1,10 +1,10 @@
 PORT=8000
 ENVIRONMENT=dev
-DB_HOST=db
+DB_HOST=localhost
 DB_USER=postgres
 DB_PORT=5432
-DB_PASSWORD=admin
+DB_PASSWORD=quick
 DB_NAME=QuickGpt
 
 SUPER_ADMIN_EMAIL=superadmin@email.com
diff --git a/alembic/versions/739fb4ac6615_chat_items.py b/alembic/versions/739fb4ac6615_chat_items.py
new file mode 100644
index 00000000..935d15fe
--- /dev/null
+++ b/alembic/versions/739fb4ac6615_chat_items.py
@@ -0,0 +1,54 @@
+"""Chat items
+
+Revision ID: 739fb4ac6615
+Revises:
+Create Date: 2024-05-01 19:20:19.652290
+
+"""
+from typing import Sequence, Union
+
+from alembic import op
+import sqlalchemy as sa
+
+
+# revision identifiers, used by Alembic.
+revision: str = '739fb4ac6615'
+down_revision: Union[str, None] = None
+branch_labels: Union[str, Sequence[str], None] = None
+depends_on: Union[str, Sequence[str], None] = None
+
+
+def upgrade() -> None:
+    # ### commands auto generated by Alembic - please adjust! ###
+    op.create_table('chat_history',
+    sa.Column('conversation_id', sa.UUID(), nullable=False),
+    sa.Column('title', sa.String(length=255), nullable=True),
+    sa.Column('created_at', sa.DateTime(), nullable=True),
+    sa.Column('updated_at', sa.DateTime(), nullable=True),
+    sa.Column('user_id', sa.Integer(), nullable=True),
+    sa.Column('_title_generated', sa.Boolean(), nullable=True),
+    sa.ForeignKeyConstraint(['user_id'], ['users.id'], ),
+    sa.PrimaryKeyConstraint('conversation_id')
+    )
+    op.create_table('chat_items',
+    sa.Column('id', sa.Integer(), nullable=False),
+    sa.Column('index', sa.Integer(), nullable=False),
+    sa.Column('sender', sa.String(length=225), nullable=False),
+    sa.Column('content', sa.JSON(), nullable=True),
+    sa.Column('created_at', sa.DateTime(), nullable=True),
+    sa.Column('updated_at', sa.DateTime(), nullable=True),
+    sa.Column('like', sa.Boolean(), nullable=True),
+    sa.Column('conversation_id', sa.UUID(), nullable=False),
+    sa.ForeignKeyConstraint(['conversation_id'], ['chat_history.conversation_id'], ),
+    sa.PrimaryKeyConstraint('id')
+    )
+    # op.create_unique_constraint('unique_user_role', 'user_roles', ['user_id', 'role_id', 'company_id'])
+    # ### end Alembic commands ###
+
+
+def downgrade() -> None:
+    # ### commands auto generated by Alembic - please adjust! ###
+    # op.drop_constraint('unique_user_role', 'user_roles', type_='unique')
+    op.drop_table('chat_items')
+    op.drop_table('chat_history')
+    # ### end Alembic commands ###
diff --git a/private_gpt/components/llm/llm_component.py b/private_gpt/components/llm/llm_component.py
index be29427d..aa5913c8 100644
--- a/private_gpt/components/llm/llm_component.py
+++ b/private_gpt/components/llm/llm_component.py
@@ -1,4 +1,6 @@
 import logging
+from collections.abc import Callable
+from typing import Any
 
 from injector import inject, singleton
 from llama_index.core.llms import LLM, MockLLM
@@ -18,14 +20,27 @@ class LLMComponent:
     @inject
     def __init__(self, settings: Settings) -> None:
         llm_mode = settings.llm.mode
-        if settings.llm.tokenizer:
-            set_global_tokenizer(
-                AutoTokenizer.from_pretrained(
-                    pretrained_model_name_or_path=settings.llm.tokenizer,
-                    cache_dir=str(models_cache_path),
-
+        if settings.llm.tokenizer and settings.llm.mode != "mock":
+            # Try to download the tokenizer. If it fails, the LLM will still work
+            # using the default one, which is less accurate.
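+            # NOTE: the Hugging Face access token is forwarded below because
+            # gated tokenizers (e.g. Llama's) cannot be downloaded anonymously;
+            # this assumes settings.huggingface.access_token is configured.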
+            try:
+                set_global_tokenizer(
+                    AutoTokenizer.from_pretrained(
+                        pretrained_model_name_or_path=settings.llm.tokenizer,
+                        cache_dir=str(models_cache_path),
+                        token=settings.huggingface.access_token,
+                    )
+                )
+            except Exception as e:
+                logger.warning(
+                    "Failed to download tokenizer %s: %s. Falling back to "
+                    "the default tokenizer.",
+                    settings.llm.tokenizer,
+                    e,
                 )
-            )
 
         logger.info("Initializing the LLM in mode=%s", llm_mode)
         match settings.llm.mode:
@@ -47,7 +59,8 @@ class LLMComponent:
                 "offload_kqv": True,
             }
             self.llm = LlamaCPP(
-                model_path=str(models_path / settings.llamacpp.llm_hf_model_file),
+                model_path=str(
+                    models_path / settings.llamacpp.llm_hf_model_file),
                 temperature=settings.llm.temperature,
                 max_new_tokens=settings.llm.max_new_tokens,
                 context_window=settings.llm.context_window,
@@ -130,6 +143,44 @@ class LLMComponent:
                     temperature=settings.llm.temperature,
                     context_window=settings.llm.context_window,
                     additional_kwargs=settings_kwargs,
+                    request_timeout=ollama_settings.request_timeout,
                 )
+
+                if (
+                    ollama_settings.keep_alive
+                    != ollama_settings.model_fields["keep_alive"].default
+                ):
+                    # Modify Ollama methods to use the "keep_alive" field.
+                    def add_keep_alive(func: Callable[..., Any]) -> Callable[..., Any]:
+                        def wrapper(*args: Any, **kwargs: Any) -> Any:
+                            kwargs["keep_alive"] = ollama_settings.keep_alive
+                            return func(*args, **kwargs)
+
+                        return wrapper
+
+                    Ollama.chat = add_keep_alive(Ollama.chat)
+                    Ollama.stream_chat = add_keep_alive(Ollama.stream_chat)
+                    Ollama.complete = add_keep_alive(Ollama.complete)
+                    Ollama.stream_complete = add_keep_alive(
+                        Ollama.stream_complete)
+
+            case "azopenai":
+                try:
+                    from llama_index.llms.azure_openai import (  # type: ignore
+                        AzureOpenAI,
+                    )
+                except ImportError as e:
+                    raise ImportError(
+                        "Azure OpenAI dependencies not found, install with `poetry install --extras llms-azopenai`"
+                    ) from e
+
+                azopenai_settings = settings.azopenai
+                self.llm = AzureOpenAI(
+                    model=azopenai_settings.llm_model,
+                    deployment_name=azopenai_settings.llm_deployment_name,
+                    api_key=azopenai_settings.api_key,
+                    azure_endpoint=azopenai_settings.azure_endpoint,
+                    api_version=azopenai_settings.api_version,
+                )
             case "mock":
                 self.llm = MockLLM()
diff --git a/private_gpt/components/llm/prompt_helper.py b/private_gpt/components/llm/prompt_helper.py
index 985d217b..359b8a55 100644
--- a/private_gpt/components/llm/prompt_helper.py
+++ b/private_gpt/components/llm/prompt_helper.py
@@ -42,7 +42,8 @@ class AbstractPromptStyle(abc.ABC):
 
     def completion_to_prompt(self, completion: str) -> str:
         prompt = self._completion_to_prompt(completion)
-        logger.debug("Got for completion='%s' the prompt='%s'", completion, prompt)
+        logger.debug("Got for completion='%s' the prompt='%s'",
+                     completion, prompt)
         return prompt
 
 
@@ -58,8 +59,10 @@ class DefaultPromptStyle(AbstractPromptStyle):
         # Hacky way to override the functions
         # Override the functions to be None, and pass None to the LLM.
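+        # With both attributes set to None, the wrapped LLM receives None and
+        # falls back to its own default prompt formatting.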
         self.messages_to_prompt = None  # type: ignore[method-assign, assignment]
         self.completion_to_prompt = None  # type: ignore[method-assign, assignment]
 
     def _messages_to_prompt(self, messages: Sequence[ChatMessage]) -> str:
         return ""
 
@@ -215,7 +218,8 @@ class ChatMLPromptStyle(AbstractPromptStyle):
 
 
 def get_prompt_style(
-    prompt_style: Literal["default", "llama2", "tag", "mistral", "chatml"] | None
+    prompt_style: Literal["default", "llama2",
+                          "tag", "mistral", "chatml"] | None
 ) -> AbstractPromptStyle:
     """Get the prompt style to use from the given string.
 
diff --git a/private_gpt/users/api/v1/routers/auth.py b/private_gpt/users/api/v1/routers/auth.py
index 1ee87760..1b63cf09 100644
--- a/private_gpt/users/api/v1/routers/auth.py
+++ b/private_gpt/users/api/v1/routers/auth.py
@@ -261,7 +261,6 @@ def register(
     )
     random_password = security.generate_random_password()
     # random_password = password
-
     try:
         company_id = current_user.company_id
         if company_id:
diff --git a/private_gpt/users/api/v1/routers/chat_history.py b/private_gpt/users/api/v1/routers/chat_history.py
index 93ed6294..87e48b41 100644
--- a/private_gpt/users/api/v1/routers/chat_history.py
+++ b/private_gpt/users/api/v1/routers/chat_history.py
@@ -4,6 +4,7 @@ import uuid
 from sqlalchemy.orm import Session
 from fastapi.responses import JSONResponse
 from fastapi import APIRouter, Depends, HTTPException, status, Security
+from fastapi_pagination import Page, paginate
 
 from private_gpt.users.api import deps
 from private_gpt.users import crud, models, schemas
@@ -12,22 +13,22 @@ logger = logging.getLogger(__name__)
 router = APIRouter(prefix="/c", tags=["Chat Histories"])
 
 
-@router.get("", response_model=list[schemas.ChatHistory])
+@router.get("", response_model=Page[schemas.ChatHistory])
 def list_chat_histories(
     db: Session = Depends(deps.get_db),
-    skip: int = 0,
-    limit: int = 100,
     current_user: models.User = Security(
         deps.get_current_user,
     ),
-) -> list[schemas.ChatHistory]:
+) -> Page[schemas.ChatHistory]:
     """
     Retrieve a list of chat histories with pagination support.
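+    Pagination is handled by fastapi_pagination, so the page is selected with
+    the standard `page` and `size` query parameters.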
""" try: chat_histories = crud.chat.get_chat_history( - db, user_id=current_user.id, skip=skip, limit=limit) - return chat_histories + db, user_id=current_user.id) + return paginate(chat_histories) except Exception as e: print(traceback.format_exc()) logger.error(f"Error listing chat histories: {str(e)}") @@ -66,6 +65,8 @@ def create_chat_history( @router.get("/{conversation_id}", response_model=schemas.ChatHistory) def read_chat_history( conversation_id: uuid.UUID, + skip: int = 0, + limit: int = 20, db: Session = Depends(deps.get_db), current_user: models.User = Security( deps.get_current_user, @@ -75,7 +76,7 @@ def read_chat_history( Read a chat history by ID """ try: - chat_history = crud.chat.get_by_id(db, id=conversation_id) + chat_history = crud.chat.get_by_id(db, id=conversation_id, skip=skip, limit=limit) if chat_history is None or chat_history.user_id != current_user.id: raise HTTPException( status_code=404, detail="Chat history not found") diff --git a/private_gpt/users/core/security.py b/private_gpt/users/core/security.py index f3ec97f8..c187331d 100644 --- a/private_gpt/users/core/security.py +++ b/private_gpt/users/core/security.py @@ -10,8 +10,11 @@ from private_gpt.users.core.config import settings ACCESS_TOKEN_EXPIRE_MINUTES = 60 * 24 * 1 # 12 hrs # Default Value REFRESH_TOKEN_EXPIRE_MINUTES = 60 * 24 * 7 # 7 days # Default Value ALGORITHM = "HS256" -JWT_SECRET_KEY = settings.SECRET_KEY -JWT_REFRESH_SECRET_KEY = settings.REFRESH_KEY +# JWT_SECRET_KEY = settings.SECRET_KEY +# JWT_REFRESH_SECRET_KEY = settings.REFRESH_KEY + +JWT_SECRET_KEY = "QUICKGPT" +JWT_REFRESH_SECRET_KEY = "QUICKGPT_REFRESH" pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto") diff --git a/private_gpt/users/crud/chat_crud.py b/private_gpt/users/crud/chat_crud.py index 87450a0c..7968fea9 100644 --- a/private_gpt/users/crud/chat_crud.py +++ b/private_gpt/users/crud/chat_crud.py @@ -7,10 +7,11 @@ import uuid from private_gpt.users.crud.base import CRUDBase from private_gpt.users.models.chat import ChatHistory, ChatItem from private_gpt.users.schemas.chat import ChatHistoryCreate, ChatHistoryCreate, ChatItemCreate, ChatItemUpdate +from fastapi_pagination import Page, paginate class CRUDChat(CRUDBase[ChatHistory, ChatHistoryCreate, ChatHistoryCreate]): - def get_by_id(self, db: Session, *, id: uuid.UUID) -> Optional[ChatHistory]: + def get_by_id(self, db: Session, *, id: uuid.UUID, skip: int=0, limit: int=10) -> Optional[ChatHistory]: chat_history = ( db.query(self.model) .filter(ChatHistory.conversation_id == id) @@ -21,7 +22,9 @@ class CRUDChat(CRUDBase[ChatHistory, ChatHistoryCreate, ChatHistoryCreate]): chat_history.chat_items = ( db.query(ChatItem) .filter(ChatItem.conversation_id == id) - .order_by(asc(getattr(ChatItem, 'index'))) + .order_by(desc(getattr(ChatItem, 'index'))) + .offset(skip) + .limit(limit) .all() ) return chat_history diff --git a/settings-ollama-pg.yaml b/settings-ollama-pg.yaml index b9798245..77f02cd7 100644 --- a/settings-ollama-pg.yaml +++ b/settings-ollama-pg.yaml @@ -14,8 +14,8 @@ embedding: embed_dim: 768 ollama: - llm_model: mistral - embedding_model: nomic-embed-text + llm_model: llama3 + embedding_model: mxbai-embed-large api_base: http://localhost:11434 nodestore: diff --git a/settings-ollama.yaml b/settings-ollama.yaml index 13663dc7..9bc181d3 100644 --- a/settings-ollama.yaml +++ b/settings-ollama.yaml @@ -11,8 +11,8 @@ embedding: mode: ollama ollama: - llm_model: mistral - embedding_model: nomic-embed-text + llm_model: llama3 + embedding_model: 
   api_base: http://localhost:11434
   embedding_api_base: http://localhost:11434 # change if your embedding model runs on another ollama
   keep_alive: 5m
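
Note: the Page[...] response models and paginate() calls introduced above only
take effect if pagination is registered once on the FastAPI application. A
minimal sketch, assuming the app object is created in a launcher module (the
exact location in this codebase is not shown in the diff):

    from fastapi import FastAPI
    from fastapi_pagination import add_pagination

    app = FastAPI()
    # ... app.include_router(...) calls for auth, chat_history, etc. ...

    # Wires up the page/size query parameters consumed by Page[...] responses.
    add_pagination(app)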