Updated the llm component

Saurab-Shrestha 2024-05-02 10:58:03 +05:45
parent bc343206cc
commit 1963190d16
10 changed files with 145 additions and 30 deletions

.env

@@ -1,10 +1,10 @@
 PORT=8000
 ENVIRONMENT=dev
-DB_HOST=db
+DB_HOST=localhost
 DB_USER=postgres
 DB_PORT=5432
-DB_PASSWORD=admin
+DB_PASSWORD=quick
 DB_NAME=QuickGpt
 SUPER_ADMIN_EMAIL=superadmin@email.com
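
The Settings class that reads this file is not part of the diff. As a rough sketch, assuming pydantic-settings, the changed values would be consumed like this (DBSettings and the URL assembly are illustrative, not the repo's code):

from pydantic_settings import BaseSettings, SettingsConfigDict

class DBSettings(BaseSettings):
    # Read .env; ignore unrelated keys such as PORT and ENVIRONMENT.
    model_config = SettingsConfigDict(env_file=".env", extra="ignore")

    DB_HOST: str = "localhost"
    DB_USER: str = "postgres"
    DB_PORT: int = 5432
    DB_PASSWORD: str = "quick"
    DB_NAME: str = "QuickGpt"

settings = DBSettings()
print(f"postgresql://{settings.DB_USER}:{settings.DB_PASSWORD}"
      f"@{settings.DB_HOST}:{settings.DB_PORT}/{settings.DB_NAME}")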


@@ -0,0 +1,54 @@
+"""Chat items
+
+Revision ID: 739fb4ac6615
+Revises:
+Create Date: 2024-05-01 19:20:19.652290
+
+"""
+from typing import Sequence, Union
+
+from alembic import op
+import sqlalchemy as sa
+
+# revision identifiers, used by Alembic.
+revision: str = '739fb4ac6615'
+down_revision: Union[str, None] = None
+branch_labels: Union[str, Sequence[str], None] = None
+depends_on: Union[str, Sequence[str], None] = None
+
+
+def upgrade() -> None:
+    # ### commands auto generated by Alembic - please adjust! ###
+    op.create_table('chat_history',
+    sa.Column('conversation_id', sa.UUID(), nullable=False),
+    sa.Column('title', sa.String(length=255), nullable=True),
+    sa.Column('created_at', sa.DateTime(), nullable=True),
+    sa.Column('updated_at', sa.DateTime(), nullable=True),
+    sa.Column('user_id', sa.Integer(), nullable=True),
+    sa.Column('_title_generated', sa.Boolean(), nullable=True),
+    sa.ForeignKeyConstraint(['user_id'], ['users.id'], ),
+    sa.PrimaryKeyConstraint('conversation_id')
+    )
+    op.create_table('chat_items',
+    sa.Column('id', sa.Integer(), nullable=False),
+    sa.Column('index', sa.Integer(), nullable=False),
+    sa.Column('sender', sa.String(length=225), nullable=False),
+    sa.Column('content', sa.JSON(), nullable=True),
+    sa.Column('created_at', sa.DateTime(), nullable=True),
+    sa.Column('updated_at', sa.DateTime(), nullable=True),
+    sa.Column('like', sa.Boolean(), nullable=True),
+    sa.Column('conversation_id', sa.UUID(), nullable=False),
+    sa.ForeignKeyConstraint(['conversation_id'], ['chat_history.conversation_id'], ),
+    sa.PrimaryKeyConstraint('id')
+    )
+    # op.create_unique_constraint('unique_user_role', 'user_roles', ['user_id', 'role_id', 'company_id'])
+    # ### end Alembic commands ###
+
+
+def downgrade() -> None:
+    # ### commands auto generated by Alembic - please adjust! ###
+    # op.drop_constraint('unique_user_role', 'user_roles', type_='unique')
+    op.drop_table('chat_items')
+    op.drop_table('chat_history')
+    # ### end Alembic commands ###
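
For reference, ORM models matching the two tables created above would look roughly like this sketch (the actual definitions in private_gpt.users.models.chat may differ in detail):

import uuid

from sqlalchemy import JSON, Boolean, Column, DateTime, ForeignKey, Integer, String
from sqlalchemy.dialects.postgresql import UUID
from sqlalchemy.orm import declarative_base, relationship

Base = declarative_base()

class ChatHistory(Base):
    __tablename__ = "chat_history"

    conversation_id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
    title = Column(String(255))
    created_at = Column(DateTime)
    updated_at = Column(DateTime)
    user_id = Column(Integer, ForeignKey("users.id"))
    _title_generated = Column(Boolean)

    # One conversation owns many indexed chat items.
    chat_items = relationship("ChatItem", backref="chat_history")

class ChatItem(Base):
    __tablename__ = "chat_items"

    id = Column(Integer, primary_key=True)
    index = Column(Integer, nullable=False)
    sender = Column(String(225), nullable=False)
    content = Column(JSON)
    created_at = Column(DateTime)
    updated_at = Column(DateTime)
    like = Column(Boolean)
    conversation_id = Column(UUID(as_uuid=True), ForeignKey("chat_history.conversation_id"), nullable=False)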


@@ -1,4 +1,6 @@
 import logging
+from collections.abc import Callable
+from typing import Any
 
 from injector import inject, singleton
 from llama_index.core.llms import LLM, MockLLM
@@ -18,14 +20,24 @@ class LLMComponent:
     @inject
     def __init__(self, settings: Settings) -> None:
         llm_mode = settings.llm.mode
-        if settings.llm.tokenizer:
-            set_global_tokenizer(
-                AutoTokenizer.from_pretrained(
-                    pretrained_model_name_or_path=settings.llm.tokenizer,
-                    cache_dir=str(models_cache_path),
-                )
-            )
+        if settings.llm.tokenizer and settings.llm.mode != "mock":
+            # Try to download the tokenizer. If it fails, the LLM will still work
+            # using the default one, which is less accurate.
+            try:
+                set_global_tokenizer(
+                    AutoTokenizer.from_pretrained(
+                        pretrained_model_name_or_path=settings.llm.tokenizer,
+                        cache_dir=str(models_cache_path),
+                        token=settings.huggingface.access_token,
+                    )
+                )
+            except Exception as e:
+                logger.warning(
+                    "Failed to download tokenizer %s: %s. Falling back to "
+                    "default tokenizer.",
+                    settings.llm.tokenizer,
+                    e,
+                )
 
         logger.info("Initializing the LLM in mode=%s", llm_mode)
         match settings.llm.mode:
@@ -47,7 +59,8 @@ class LLMComponent:
                 "offload_kqv": True,
             }
             self.llm = LlamaCPP(
-                model_path=str(models_path / settings.llamacpp.llm_hf_model_file),
+                model_path=str(
+                    models_path / settings.llamacpp.llm_hf_model_file),
                 temperature=settings.llm.temperature,
                 max_new_tokens=settings.llm.max_new_tokens,
                 context_window=settings.llm.context_window,
@@ -130,6 +143,44 @@ class LLMComponent:
                     temperature=settings.llm.temperature,
                     context_window=settings.llm.context_window,
                     additional_kwargs=settings_kwargs,
+                    request_timeout=ollama_settings.request_timeout,
+                )
+
+                if (
+                    ollama_settings.keep_alive
+                    != ollama_settings.model_fields["keep_alive"].default
+                ):
+                    # Modify Ollama methods to use the "keep_alive" field.
+                    def add_keep_alive(func: Callable[..., Any]) -> Callable[..., Any]:
+                        def wrapper(*args: Any, **kwargs: Any) -> Any:
+                            kwargs["keep_alive"] = ollama_settings.keep_alive
+                            return func(*args, **kwargs)
+
+                        return wrapper
+
+                    Ollama.chat = add_keep_alive(Ollama.chat)
+                    Ollama.stream_chat = add_keep_alive(Ollama.stream_chat)
+                    Ollama.complete = add_keep_alive(Ollama.complete)
+                    Ollama.stream_complete = add_keep_alive(Ollama.stream_complete)
+
+            case "azopenai":
+                try:
+                    from llama_index.llms.azure_openai import (  # type: ignore
+                        AzureOpenAI,
+                    )
+                except ImportError as e:
+                    raise ImportError(
+                        "Azure OpenAI dependencies not found, install with `poetry install --extras llms-azopenai`"
+                    ) from e
+
+                azopenai_settings = settings.azopenai
+                self.llm = AzureOpenAI(
+                    model=azopenai_settings.llm_model,
+                    deployment_name=azopenai_settings.llm_deployment_name,
+                    api_key=azopenai_settings.api_key,
+                    azure_endpoint=azopenai_settings.azure_endpoint,
+                    api_version=azopenai_settings.api_version,
                 )
             case "mock":
                 self.llm = MockLLM()
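
The add_keep_alive wrapper above is a generic "force a kwarg on every call" monkey-patch. The same pattern in a standalone sketch (force_kwarg and complete are illustrative names, not repo code):

from collections.abc import Callable
from typing import Any

def force_kwarg(name: str, value: Any) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
    # Build a decorator that sets name=value on every call, overriding
    # any caller-supplied value, exactly as add_keep_alive does.
    def decorator(func: Callable[..., Any]) -> Callable[..., Any]:
        def wrapper(*args: Any, **kwargs: Any) -> Any:
            kwargs[name] = value
            return func(*args, **kwargs)
        return wrapper
    return decorator

@force_kwarg("keep_alive", "5m")
def complete(prompt: str, **kwargs: Any) -> str:
    return f"{prompt} (keep_alive={kwargs['keep_alive']})"

print(complete("hello"))                  # hello (keep_alive=5m)
print(complete("hello", keep_alive="0"))  # still 5m: the wrapper overrides

One caveat of patching the class rather than the instance: every Ollama object in the process picks up the forced keep_alive, not just self.llm.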


@@ -42,7 +42,8 @@ class AbstractPromptStyle(abc.ABC):
     def completion_to_prompt(self, completion: str) -> str:
         prompt = self._completion_to_prompt(completion)
-        logger.debug("Got for completion='%s' the prompt='%s'", completion, prompt)
+        logger.debug("Got for completion='%s' the prompt='%s'",
+                     completion, prompt)
         return prompt
@@ -58,8 +59,10 @@ class DefaultPromptStyle(AbstractPromptStyle):
         # Hacky way to override the functions
         # Override the functions to be None, and pass None to the LLM.
-        self.messages_to_prompt = None  # type: ignore[method-assign, assignment]
-        self.completion_to_prompt = None  # type: ignore[method-assign, assignment]
+        # type: ignore[method-assign, assignment]
+        self.messages_to_prompt = None
+        # type: ignore[method-assign, assignment]
+        self.completion_to_prompt = None
 
     def _messages_to_prompt(self, messages: Sequence[ChatMessage]) -> str:
         return ""
@@ -215,7 +218,8 @@ class ChatMLPromptStyle(AbstractPromptStyle):
 
 def get_prompt_style(
-    prompt_style: Literal["default", "llama2", "tag", "mistral", "chatml"] | None
+    prompt_style: Literal["default", "llama2",
+                          "tag", "mistral", "chatml"] | None
 ) -> AbstractPromptStyle:
     """Get the prompt style to use from the given string.


@@ -261,7 +261,6 @@ def register(
     )
     random_password = security.generate_random_password()
     # random_password = password
-
     try:
         company_id = current_user.company_id
         if company_id:

@@ -4,6 +4,7 @@ import uuid
 from sqlalchemy.orm import Session
 from fastapi.responses import JSONResponse
 from fastapi import APIRouter, Depends, HTTPException, status, Security
+from fastapi_pagination import Page, paginate
 
 from private_gpt.users.api import deps
 from private_gpt.users import crud, models, schemas
@@ -12,22 +13,20 @@ logger = logging.getLogger(__name__)
 router = APIRouter(prefix="/c", tags=["Chat Histories"])
 
 
-@router.get("", response_model=list[schemas.ChatHistory])
+@router.get("", response_model=Page[schemas.ChatHistory])
 def list_chat_histories(
     db: Session = Depends(deps.get_db),
-    skip: int = 0,
-    limit: int = 100,
     current_user: models.User = Security(
         deps.get_current_user,
     ),
-) -> list[schemas.ChatHistory]:
+) -> Page[schemas.ChatHistory]:
     """
     Retrieve a list of chat histories with pagination support.
     """
     try:
-        chat_histories = crud.chat.get_chat_history(
-            db, user_id=current_user.id, skip=skip, limit=limit)
-        return chat_histories
+        chat_histories = crud.chat.get_chat_history(
+            db, user_id=current_user.id)
+        return paginate(chat_histories)
     except Exception as e:
         print(traceback.format_exc())
         logger.error(f"Error listing chat histories: {str(e)}")
@@ -66,6 +65,8 @@ def create_chat_history(
 @router.get("/{conversation_id}", response_model=schemas.ChatHistory)
 def read_chat_history(
     conversation_id: uuid.UUID,
+    skip: int = 0,
+    limit: int = 20,
     db: Session = Depends(deps.get_db),
     current_user: models.User = Security(
         deps.get_current_user,
@@ -75,7 +76,7 @@ def read_chat_history(
     Read a chat history by ID
     """
     try:
-        chat_history = crud.chat.get_by_id(db, id=conversation_id)
+        chat_history = crud.chat.get_by_id(db, id=conversation_id, skip=skip, limit=limit)
         if chat_history is None or chat_history.user_id != current_user.id:
             raise HTTPException(
                 status_code=404, detail="Chat history not found")
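
Returning Page[...] from paginate() only works once fastapi-pagination is registered on the app; a minimal self-contained sketch of the moving parts (the app and Item model are illustrative, not this repo's code):

from fastapi import FastAPI
from fastapi_pagination import Page, add_pagination, paginate
from pydantic import BaseModel

app = FastAPI()

class Item(BaseModel):
    name: str

@app.get("/items", response_model=Page[Item])
def list_items() -> Page[Item]:
    # paginate() slices the sequence using the page/size query parameters.
    return paginate([Item(name=f"item-{i}") for i in range(100)])

add_pagination(app)  # adds the page/size parameters to the routes above

Clients then request GET /items?page=2&size=10 and receive the items plus total/page/size metadata.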


@@ -10,8 +10,11 @@ from private_gpt.users.core.config import settings
 
 ACCESS_TOKEN_EXPIRE_MINUTES = 60 * 24 * 1  # 24 hrs  # Default Value
 REFRESH_TOKEN_EXPIRE_MINUTES = 60 * 24 * 7  # 7 days  # Default Value
 ALGORITHM = "HS256"
-JWT_SECRET_KEY = settings.SECRET_KEY
-JWT_REFRESH_SECRET_KEY = settings.REFRESH_KEY
+# JWT_SECRET_KEY = settings.SECRET_KEY
+# JWT_REFRESH_SECRET_KEY = settings.REFRESH_KEY
+JWT_SECRET_KEY = "QUICKGPT"
+JWT_REFRESH_SECRET_KEY = "QUICKGPT_REFRESH"
 
 pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
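
The token helpers themselves are not in this hunk; assuming python-jose, issuing and verifying a token with these constants looks roughly like the sketch below (create_access_token is an illustrative name). Note that hardcoding the secrets in source, as this commit does, is riskier than reading them from settings:

from datetime import datetime, timedelta

from jose import jwt

ALGORITHM = "HS256"
JWT_SECRET_KEY = "QUICKGPT"  # hardcoded in this commit

def create_access_token(subject: str, expires_minutes: int = 60 * 24) -> str:
    claims = {"sub": subject, "exp": datetime.utcnow() + timedelta(minutes=expires_minutes)}
    return jwt.encode(claims, JWT_SECRET_KEY, algorithm=ALGORITHM)

token = create_access_token("user@example.com")
payload = jwt.decode(token, JWT_SECRET_KEY, algorithms=[ALGORITHM])  # raises if expired or invalid
assert payload["sub"] == "user@example.com"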


@@ -7,10 +7,11 @@ import uuid
 from private_gpt.users.crud.base import CRUDBase
 from private_gpt.users.models.chat import ChatHistory, ChatItem
 from private_gpt.users.schemas.chat import ChatHistoryCreate, ChatItemCreate, ChatItemUpdate
+from fastapi_pagination import Page, paginate
 
 
 class CRUDChat(CRUDBase[ChatHistory, ChatHistoryCreate, ChatHistoryCreate]):
-    def get_by_id(self, db: Session, *, id: uuid.UUID) -> Optional[ChatHistory]:
+    def get_by_id(self, db: Session, *, id: uuid.UUID, skip: int = 0, limit: int = 10) -> Optional[ChatHistory]:
         chat_history = (
             db.query(self.model)
             .filter(ChatHistory.conversation_id == id)
@@ -21,7 +22,9 @@ class CRUDChat(CRUDBase[ChatHistory, ChatHistoryCreate, ChatHistoryCreate]):
         chat_history.chat_items = (
             db.query(ChatItem)
             .filter(ChatItem.conversation_id == id)
-            .order_by(asc(getattr(ChatItem, 'index')))
+            .order_by(desc(getattr(ChatItem, 'index')))
+            .offset(skip)
+            .limit(limit)
             .all()
         )
         return chat_history
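
With the switch from asc to desc plus offset/limit, skip=0 now returns the most recent chat items first. A small illustrative helper for mapping a 1-based page number onto these arguments:

def page_to_skip(page: int, page_size: int = 10) -> int:
    # Page 1 -> skip 0, page 3 -> skip 20 (with the default limit of 10).
    return max(page - 1, 0) * page_size

assert page_to_skip(1) == 0
assert page_to_skip(3) == 20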


@@ -14,8 +14,8 @@ embedding:
   embed_dim: 768
 
 ollama:
-  llm_model: mistral
-  embedding_model: nomic-embed-text
+  llm_model: llama3
+  embedding_model: mxbai-embed-large
   api_base: http://localhost:11434
 
 nodestore:


@@ -11,8 +11,8 @@ embedding:
   mode: ollama
 
 ollama:
-  llm_model: mistral
-  embedding_model: nomic-embed-text
+  llm_model: llama3
+  embedding_model: mxbai-embed-large
   api_base: http://localhost:11434
   embedding_api_base: http://localhost:11434  # change if your embedding model runs on another ollama
   keep_alive: 5m
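
Both settings files assume the llama3 and mxbai-embed-large models have already been pulled into the local Ollama instance. Roughly how these values map onto llama-index's Ollama client, as a standalone sketch rather than the repo's actual wiring:

from llama_index.llms.ollama import Ollama

llm = Ollama(
    model="llama3",                     # llm_model
    base_url="http://localhost:11434",  # api_base
    request_timeout=120.0,              # generous timeout for local generation
)
print(llm.complete("Say hello in one word.").text)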