Updated the llm component

This commit is contained in:
Saurab-Shrestha 2024-05-02 10:58:03 +05:45
parent bc343206cc
commit 1963190d16
10 changed files with 145 additions and 30 deletions

4
.env
View File

@ -1,10 +1,10 @@
PORT=8000
ENVIRONMENT=dev
DB_HOST=db
DB_HOST=localhost
DB_USER=postgres
DB_PORT=5432
DB_PASSWORD=admin
DB_PASSWORD=quick
DB_NAME=QuickGpt
SUPER_ADMIN_EMAIL=superadmin@email.com

View File

@ -0,0 +1,54 @@
"""Chat items
Revision ID: 739fb4ac6615
Revises:
Create Date: 2024-05-01 19:20:19.652290
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision: str = '739fb4ac6615'
down_revision: Union[str, None] = None
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.create_table('chat_history',
sa.Column('conversation_id', sa.UUID(), nullable=False),
sa.Column('title', sa.String(length=255), nullable=True),
sa.Column('created_at', sa.DateTime(), nullable=True),
sa.Column('updated_at', sa.DateTime(), nullable=True),
sa.Column('user_id', sa.Integer(), nullable=True),
sa.Column('_title_generated', sa.Boolean(), nullable=True),
sa.ForeignKeyConstraint(['user_id'], ['users.id'], ),
sa.PrimaryKeyConstraint('conversation_id')
)
op.create_table('chat_items',
sa.Column('id', sa.Integer(), nullable=False),
sa.Column('index', sa.Integer(), nullable=False),
sa.Column('sender', sa.String(length=225), nullable=False),
sa.Column('content', sa.JSON(), nullable=True),
sa.Column('created_at', sa.DateTime(), nullable=True),
sa.Column('updated_at', sa.DateTime(), nullable=True),
sa.Column('like', sa.Boolean(), nullable=True),
sa.Column('conversation_id', sa.UUID(), nullable=False),
sa.ForeignKeyConstraint(['conversation_id'], ['chat_history.conversation_id'], ),
sa.PrimaryKeyConstraint('id')
)
# op.create_unique_constraint('unique_user_role', 'user_roles', ['user_id', 'role_id', 'company_id'])
# ### end Alembic commands ###
def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
# op.drop_constraint('unique_user_role', 'user_roles', type_='unique')
op.drop_table('chat_items')
op.drop_table('chat_history')
# ### end Alembic commands ###

View File

@ -1,4 +1,6 @@
import logging
from collections.abc import Callable
from typing import Any
from injector import inject, singleton
from llama_index.core.llms import LLM, MockLLM
@ -18,14 +20,24 @@ class LLMComponent:
@inject
def __init__(self, settings: Settings) -> None:
llm_mode = settings.llm.mode
if settings.llm.tokenizer:
set_global_tokenizer(
AutoTokenizer.from_pretrained(
pretrained_model_name_or_path=settings.llm.tokenizer,
cache_dir=str(models_cache_path),
if settings.llm.tokenizer and settings.llm.mode != "mock":
# Try to download the tokenizer. If it fails, the LLM will still work
# using the default one, which is less accurate.
try:
set_global_tokenizer(
AutoTokenizer.from_pretrained(
pretrained_model_name_or_path=settings.llm.tokenizer,
cache_dir=str(models_cache_path),
token=settings.huggingface.access_token,
)
)
except Exception as e:
logger.warning(
"Failed to download tokenizer %s. Falling back to "
"default tokenizer.",
settings.llm.tokenizer,
e,
)
)
logger.info("Initializing the LLM in mode=%s", llm_mode)
match settings.llm.mode:
@ -47,7 +59,8 @@ class LLMComponent:
"offload_kqv": True,
}
self.llm = LlamaCPP(
model_path=str(models_path / settings.llamacpp.llm_hf_model_file),
model_path=str(
models_path / settings.llamacpp.llm_hf_model_file),
temperature=settings.llm.temperature,
max_new_tokens=settings.llm.max_new_tokens,
context_window=settings.llm.context_window,
@ -130,6 +143,44 @@ class LLMComponent:
temperature=settings.llm.temperature,
context_window=settings.llm.context_window,
additional_kwargs=settings_kwargs,
request_timeout=ollama_settings.request_timeout,
)
if (
ollama_settings.keep_alive
!= ollama_settings.model_fields["keep_alive"].default
):
# Modify Ollama methods to use the "keep_alive" field.
def add_keep_alive(func: Callable[..., Any]) -> Callable[..., Any]:
def wrapper(*args: Any, **kwargs: Any) -> Any:
kwargs["keep_alive"] = ollama_settings.keep_alive
return func(*args, **kwargs)
return wrapper
Ollama.chat = add_keep_alive(Ollama.chat)
Ollama.stream_chat = add_keep_alive(Ollama.stream_chat)
Ollama.complete = add_keep_alive(Ollama.complete)
Ollama.stream_complete = add_keep_alive(
Ollama.stream_complete)
case "azopenai":
try:
from llama_index.llms.azure_openai import ( # type: ignore
AzureOpenAI,
)
except ImportError as e:
raise ImportError(
"Azure OpenAI dependencies not found, install with `poetry install --extras llms-azopenai`"
) from e
azopenai_settings = settings.azopenai
self.llm = AzureOpenAI(
model=azopenai_settings.llm_model,
deployment_name=azopenai_settings.llm_deployment_name,
api_key=azopenai_settings.api_key,
azure_endpoint=azopenai_settings.azure_endpoint,
api_version=azopenai_settings.api_version,
)
case "mock":
self.llm = MockLLM()

View File

@ -42,7 +42,8 @@ class AbstractPromptStyle(abc.ABC):
def completion_to_prompt(self, completion: str) -> str:
prompt = self._completion_to_prompt(completion)
logger.debug("Got for completion='%s' the prompt='%s'", completion, prompt)
logger.debug("Got for completion='%s' the prompt='%s'",
completion, prompt)
return prompt
@ -58,8 +59,10 @@ class DefaultPromptStyle(AbstractPromptStyle):
# Hacky way to override the functions
# Override the functions to be None, and pass None to the LLM.
self.messages_to_prompt = None # type: ignore[method-assign, assignment]
self.completion_to_prompt = None # type: ignore[method-assign, assignment]
# type: ignore[method-assign, assignment]
self.messages_to_prompt = None
# type: ignore[method-assign, assignment]
self.completion_to_prompt = None
def _messages_to_prompt(self, messages: Sequence[ChatMessage]) -> str:
return ""
@ -215,7 +218,8 @@ class ChatMLPromptStyle(AbstractPromptStyle):
def get_prompt_style(
prompt_style: Literal["default", "llama2", "tag", "mistral", "chatml"] | None
prompt_style: Literal["default", "llama2",
"tag", "mistral", "chatml"] | None
) -> AbstractPromptStyle:
"""Get the prompt style to use from the given string.

View File

@ -261,7 +261,6 @@ def register(
)
random_password = security.generate_random_password()
# random_password = password
try:
company_id = current_user.company_id
if company_id:

View File

@ -4,6 +4,7 @@ import uuid
from sqlalchemy.orm import Session
from fastapi.responses import JSONResponse
from fastapi import APIRouter, Depends, HTTPException, status, Security
from fastapi_pagination import Page, paginate
from private_gpt.users.api import deps
from private_gpt.users import crud, models, schemas
@ -12,22 +13,20 @@ logger = logging.getLogger(__name__)
router = APIRouter(prefix="/c", tags=["Chat Histories"])
@router.get("", response_model=list[schemas.ChatHistory])
@router.get("", response_model=Page[schemas.ChatHistory])
def list_chat_histories(
db: Session = Depends(deps.get_db),
skip: int = 0,
limit: int = 100,
current_user: models.User = Security(
deps.get_current_user,
),
) -> list[schemas.ChatHistory]:
) -> Page[schemas.ChatHistory]:
"""
Retrieve a list of chat histories with pagination support.
"""
try:
chat_histories = crud.chat.get_chat_history(
db, user_id=current_user.id, skip=skip, limit=limit)
return chat_histories
db, user_id=current_user.id)
return paginate(chat_histories)
except Exception as e:
print(traceback.format_exc())
logger.error(f"Error listing chat histories: {str(e)}")
@ -66,6 +65,8 @@ def create_chat_history(
@router.get("/{conversation_id}", response_model=schemas.ChatHistory)
def read_chat_history(
conversation_id: uuid.UUID,
skip: int = 0,
limit: int = 20,
db: Session = Depends(deps.get_db),
current_user: models.User = Security(
deps.get_current_user,
@ -75,7 +76,7 @@ def read_chat_history(
Read a chat history by ID
"""
try:
chat_history = crud.chat.get_by_id(db, id=conversation_id)
chat_history = crud.chat.get_by_id(db, id=conversation_id, skip=skip, limit=limit)
if chat_history is None or chat_history.user_id != current_user.id:
raise HTTPException(
status_code=404, detail="Chat history not found")

View File

@ -10,8 +10,11 @@ from private_gpt.users.core.config import settings
ACCESS_TOKEN_EXPIRE_MINUTES = 60 * 24 * 1 # 12 hrs # Default Value
REFRESH_TOKEN_EXPIRE_MINUTES = 60 * 24 * 7 # 7 days # Default Value
ALGORITHM = "HS256"
JWT_SECRET_KEY = settings.SECRET_KEY
JWT_REFRESH_SECRET_KEY = settings.REFRESH_KEY
# JWT_SECRET_KEY = settings.SECRET_KEY
# JWT_REFRESH_SECRET_KEY = settings.REFRESH_KEY
JWT_SECRET_KEY = "QUICKGPT"
JWT_REFRESH_SECRET_KEY = "QUICKGPT_REFRESH"
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")

View File

@ -7,10 +7,11 @@ import uuid
from private_gpt.users.crud.base import CRUDBase
from private_gpt.users.models.chat import ChatHistory, ChatItem
from private_gpt.users.schemas.chat import ChatHistoryCreate, ChatHistoryCreate, ChatItemCreate, ChatItemUpdate
from fastapi_pagination import Page, paginate
class CRUDChat(CRUDBase[ChatHistory, ChatHistoryCreate, ChatHistoryCreate]):
def get_by_id(self, db: Session, *, id: uuid.UUID) -> Optional[ChatHistory]:
def get_by_id(self, db: Session, *, id: uuid.UUID, skip: int=0, limit: int=10) -> Optional[ChatHistory]:
chat_history = (
db.query(self.model)
.filter(ChatHistory.conversation_id == id)
@ -21,7 +22,9 @@ class CRUDChat(CRUDBase[ChatHistory, ChatHistoryCreate, ChatHistoryCreate]):
chat_history.chat_items = (
db.query(ChatItem)
.filter(ChatItem.conversation_id == id)
.order_by(asc(getattr(ChatItem, 'index')))
.order_by(desc(getattr(ChatItem, 'index')))
.offset(skip)
.limit(limit)
.all()
)
return chat_history

View File

@ -14,8 +14,8 @@ embedding:
embed_dim: 768
ollama:
llm_model: mistral
embedding_model: nomic-embed-text
llm_model: llama3
embedding_model: mxbai-embed-large
api_base: http://localhost:11434
nodestore:

View File

@ -11,8 +11,8 @@ embedding:
mode: ollama
ollama:
llm_model: mistral
embedding_model: nomic-embed-text
llm_model: llama3
embedding_model: mxbai-embed-large
api_base: http://localhost:11434
embedding_api_base: http://localhost:11434 # change if your embedding model runs on another ollama
keep_alive: 5m