From b67b66cd447ecbbdea95e4dc8b6052c547a2ac79 Mon Sep 17 00:00:00 2001 From: Saurab-Shrestha Date: Thu, 11 Jan 2024 17:48:35 +0545 Subject: [PATCH] Added users model for authentication with roles and user roles --- .env | 15 + =1.4, | 0 =1.9.1 | 0 alembic.ini | 116 +++++++ alembic/README | 1 + alembic/env.py | 85 ++++++ alembic/script.py.mako | 26 ++ ..._create_user_model_role_model_and_user_.py | 61 ++++ .../9491fd6d6fad_create_user_model.py | 30 ++ alembic_readme.md | 7 + private_gpt/alembic.ini | 116 +++++++ private_gpt/launcher.py | 12 +- private_gpt/paths.py | 1 + private_gpt/ui/admin_ui.py | 289 ++++++++++++++++++ private_gpt/ui/common.py | 25 ++ private_gpt/ui/local_data/tests/.lock | 1 + private_gpt/ui/local_data/tests/meta.json | 1 + private_gpt/ui/ui.py | 51 +--- private_gpt/ui/users_ui.py | 141 +++++++++ private_gpt/users/__init__.py | 0 private_gpt/users/api/__init__.py | 0 private_gpt/users/api/deps.py | 68 +++++ private_gpt/users/api/v1/__init__.py | 0 private_gpt/users/api/v1/api.py | 9 + private_gpt/users/api/v1/routers/__init__.py | 0 private_gpt/users/api/v1/routers/auth.py | 132 ++++++++ private_gpt/users/api/v1/routers/roles.py | 18 ++ .../users/api/v1/routers/user_roles.py | 58 ++++ private_gpt/users/api/v1/routers/users.py | 165 ++++++++++ private_gpt/users/constants/__init__.py | 0 private_gpt/users/constants/role.py | 17 ++ private_gpt/users/core/__init__.py | 0 private_gpt/users/core/config.py | 64 ++++ private_gpt/users/core/security.py | 44 +++ private_gpt/users/crud/__init__.py | 3 + private_gpt/users/crud/base.py | 64 ++++ private_gpt/users/crud/role_crud.py | 13 + private_gpt/users/crud/user_crud.py | 77 +++++ private_gpt/users/crud/user_role_crud.py | 16 + private_gpt/users/db/__init__.py | 0 private_gpt/users/db/base.py | 4 + private_gpt/users/db/base_class.py | 15 + private_gpt/users/db/init_db.py | 11 + private_gpt/users/db/session.py | 13 + private_gpt/users/models/__init__.py | 1 + private_gpt/users/models/role.py | 7 + 
private_gpt/users/models/user.py | 47 +++ private_gpt/users/models/user_role.py | 26 ++ private_gpt/users/schemas/__init__.py | 4 + private_gpt/users/schemas/role.py | 37 +++ private_gpt/users/schemas/token.py | 16 + private_gpt/users/schemas/user.py | 50 +++ private_gpt/users/schemas/user_role.py | 41 +++ 53 files changed, 1955 insertions(+), 43 deletions(-) create mode 100644 .env create mode 100644 =1.4, create mode 100644 =1.9.1 create mode 100644 alembic.ini create mode 100644 alembic/README create mode 100644 alembic/env.py create mode 100644 alembic/script.py.mako create mode 100644 alembic/versions/19e9eccf2c81_create_user_model_role_model_and_user_.py create mode 100644 alembic/versions/9491fd6d6fad_create_user_model.py create mode 100644 alembic_readme.md create mode 100644 private_gpt/alembic.ini create mode 100644 private_gpt/ui/admin_ui.py create mode 100644 private_gpt/ui/common.py create mode 100644 private_gpt/ui/local_data/tests/.lock create mode 100644 private_gpt/ui/local_data/tests/meta.json create mode 100644 private_gpt/ui/users_ui.py create mode 100644 private_gpt/users/__init__.py create mode 100644 private_gpt/users/api/__init__.py create mode 100644 private_gpt/users/api/deps.py create mode 100644 private_gpt/users/api/v1/__init__.py create mode 100644 private_gpt/users/api/v1/api.py create mode 100644 private_gpt/users/api/v1/routers/__init__.py create mode 100644 private_gpt/users/api/v1/routers/auth.py create mode 100644 private_gpt/users/api/v1/routers/roles.py create mode 100644 private_gpt/users/api/v1/routers/user_roles.py create mode 100644 private_gpt/users/api/v1/routers/users.py create mode 100644 private_gpt/users/constants/__init__.py create mode 100644 private_gpt/users/constants/role.py create mode 100644 private_gpt/users/core/__init__.py create mode 100644 private_gpt/users/core/config.py create mode 100644 private_gpt/users/core/security.py create mode 100644 private_gpt/users/crud/__init__.py create mode 100644 
private_gpt/users/crud/base.py create mode 100644 private_gpt/users/crud/role_crud.py create mode 100644 private_gpt/users/crud/user_crud.py create mode 100644 private_gpt/users/crud/user_role_crud.py create mode 100644 private_gpt/users/db/__init__.py create mode 100644 private_gpt/users/db/base.py create mode 100644 private_gpt/users/db/base_class.py create mode 100644 private_gpt/users/db/init_db.py create mode 100644 private_gpt/users/db/session.py create mode 100644 private_gpt/users/models/__init__.py create mode 100644 private_gpt/users/models/role.py create mode 100644 private_gpt/users/models/user.py create mode 100644 private_gpt/users/models/user_role.py create mode 100644 private_gpt/users/schemas/__init__.py create mode 100644 private_gpt/users/schemas/role.py create mode 100644 private_gpt/users/schemas/token.py create mode 100644 private_gpt/users/schemas/user.py create mode 100644 private_gpt/users/schemas/user_role.py diff --git a/.env b/.env new file mode 100644 index 00000000..0b2615c0 --- /dev/null +++ b/.env @@ -0,0 +1,15 @@ +PORT=8000 +ENVIRONMENT=dev + +DB_HOST=localhost +DB_USER=postgres +DB_PASSWORD=quick +DB_NAME=QuickGpt + +SUPER_ADMIN_EMAIL=superadmin@email.com +SUPER_ADMIN_PASSWORD=supersecretpassword +SUPER_ADMIN_ACCOUNT_NAME=superaccount + +SECRET_KEY=ba9dc3f976cf8fb40519dcd152a8d7d21c0b7861d841711cdb2602be8e85fd7c +ACCESS_TOKEN_EXPIRE_MINUTES=60 +REFRESH_TOKEN_EXPIRE_MINUTES = 120 # 120 minutes = 2 hours (7 days would be 10080) diff --git a/=1.4, b/=1.4, new file mode 100644 index 00000000..e69de29b diff --git a/=1.9.1 b/=1.9.1 new file mode 100644 index 00000000..e69de29b diff --git a/alembic.ini b/alembic.ini new file mode 100644 index 00000000..c10d4ca0 --- /dev/null +++ b/alembic.ini @@ -0,0 +1,116 @@ +# A generic, single database configuration. 
+ +[alembic] +# path to migration scripts +script_location = alembic + +# template used to generate migration file names; The default value is %%(rev)s_%%(slug)s +# Uncomment the line below if you want the files to be prepended with date and time +# see https://alembic.sqlalchemy.org/en/latest/tutorial.html#editing-the-ini-file +# for all available tokens +# file_template = %%(year)d_%%(month).2d_%%(day).2d_%%(hour).2d%%(minute).2d-%%(rev)s_%%(slug)s + +# sys.path path, will be prepended to sys.path if present. +# defaults to the current working directory. +prepend_sys_path = . + +# timezone to use when rendering the date within the migration file +# as well as the filename. +# If specified, requires the python>=3.9 or backports.zoneinfo library. +# Any required deps can installed by adding `alembic[tz]` to the pip requirements +# string value is passed to ZoneInfo() +# leave blank for localtime +# timezone = + +# max length of characters to apply to the +# "slug" field +# truncate_slug_length = 40 + +# set to 'true' to run the environment during +# the 'revision' command, regardless of autogenerate +# revision_environment = false + +# set to 'true' to allow .pyc and .pyo files without +# a source .py file to be detected as revisions in the +# versions/ directory +# sourceless = false + +# version location specification; This defaults +# to alembic/versions. When using multiple version +# directories, initial revisions must be specified with --version-path. +# The path separator used here should be the separator specified by "version_path_separator" below. +# version_locations = %(here)s/bar:%(here)s/bat:alembic/versions + +# version path separator; As mentioned above, this is the character used to split +# version_locations. The default within new alembic.ini files is "os", which uses os.pathsep. +# If this key is omitted entirely, it falls back to the legacy behavior of splitting on spaces and/or commas. 
+# Valid values for version_path_separator are: +# +# version_path_separator = : +# version_path_separator = ; +# version_path_separator = space +version_path_separator = os # Use os.pathsep. Default configuration used for new projects. + +# set to 'true' to search source files recursively +# in each "version_locations" directory +# new in Alembic version 1.10 +# recursive_version_locations = false + +# the output encoding used when revision files +# are written from script.py.mako +# output_encoding = utf-8 + +sqlalchemy.url = driver://user:pass@localhost/dbname + + +[post_write_hooks] +# post_write_hooks defines scripts or Python functions that are run +# on newly generated revision scripts. See the documentation for further +# detail and examples + +# format using "black" - use the console_scripts runner, against the "black" entrypoint +# hooks = black +# black.type = console_scripts +# black.entrypoint = black +# black.options = -l 79 REVISION_SCRIPT_FILENAME + +# lint with attempts to fix using "ruff" - use the exec runner, execute a binary +# hooks = ruff +# ruff.type = exec +# ruff.executable = %(here)s/.venv/bin/ruff +# ruff.options = --fix REVISION_SCRIPT_FILENAME + +# Logging configuration +[loggers] +keys = root,sqlalchemy,alembic + +[handlers] +keys = console + +[formatters] +keys = generic + +[logger_root] +level = WARN +handlers = console +qualname = + +[logger_sqlalchemy] +level = WARN +handlers = +qualname = sqlalchemy.engine + +[logger_alembic] +level = INFO +handlers = +qualname = alembic + +[handler_console] +class = StreamHandler +args = (sys.stderr,) +level = NOTSET +formatter = generic + +[formatter_generic] +format = %(levelname)-5.5s [%(name)s] %(message)s +datefmt = %H:%M:%S diff --git a/alembic/README b/alembic/README new file mode 100644 index 00000000..98e4f9c4 --- /dev/null +++ b/alembic/README @@ -0,0 +1 @@ +Generic single-database configuration. 
\ No newline at end of file diff --git a/alembic/env.py b/alembic/env.py new file mode 100644 index 00000000..e69c895e --- /dev/null +++ b/alembic/env.py @@ -0,0 +1,85 @@ +from logging.config import fileConfig + +from sqlalchemy import engine_from_config +from sqlalchemy import pool + +from alembic import context + +from private_gpt.users.db.base_class import Base +from private_gpt.users.core.config import SQLALCHEMY_DATABASE_URI + +from private_gpt.users.models.user import User +from private_gpt.users.models.role import Role +from private_gpt.users.models.user_role import UserRole + +# this is the Alembic Config object, which provides +# access to the values within the .ini file in use. +config = context.config + +# Interpret the config file for Python logging. +# This line sets up loggers basically. +if config.config_file_name is not None: + fileConfig(config.config_file_name) + +# add your model's MetaData object here +# for 'autogenerate' support +# from myapp import mymodel +# target_metadata = mymodel.Base.metadata +target_metadata = Base.metadata + +# other values from the config, defined by the needs of env.py, +# can be acquired: +# my_important_option = config.get_main_option("my_important_option") +# ... etc. +config.set_section_option(config.config_ini_section, "sqlalchemy.url", SQLALCHEMY_DATABASE_URI) + +def run_migrations_offline() -> None: + """Run migrations in 'offline' mode. + + This configures the context with just a URL + and not an Engine, though an Engine is acceptable + here as well. By skipping the Engine creation + we don't even need a DBAPI to be available. + + Calls to context.execute() here emit the given string to the + script output. 
+ + """ + url = config.get_main_option("sqlalchemy.url") + context.configure( + url=url, + target_metadata=target_metadata, + literal_binds=True, + dialect_opts={"paramstyle": "named"}, + ) + + with context.begin_transaction(): + context.run_migrations() + + +def run_migrations_online() -> None: + """Run migrations in 'online' mode. + + In this scenario we need to create an Engine + and associate a connection with the context. + + """ + connectable = engine_from_config( + config.get_section(config.config_ini_section, {}), + prefix="sqlalchemy.", + poolclass=pool.NullPool, + ) + + with connectable.connect() as connection: + context.configure( + connection=connection, target_metadata=target_metadata + ) + + with context.begin_transaction(): + context.run_migrations() + + +if context.is_offline_mode(): + run_migrations_offline() +else: + run_migrations_online() \ No newline at end of file diff --git a/alembic/script.py.mako b/alembic/script.py.mako new file mode 100644 index 00000000..fbc4b07d --- /dev/null +++ b/alembic/script.py.mako @@ -0,0 +1,26 @@ +"""${message} + +Revision ID: ${up_revision} +Revises: ${down_revision | comma,n} +Create Date: ${create_date} + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa +${imports if imports else ""} + +# revision identifiers, used by Alembic. 
+revision: str = ${repr(up_revision)} +down_revision: Union[str, None] = ${repr(down_revision)} +branch_labels: Union[str, Sequence[str], None] = ${repr(branch_labels)} +depends_on: Union[str, Sequence[str], None] = ${repr(depends_on)} + + +def upgrade() -> None: + ${upgrades if upgrades else "pass"} + + +def downgrade() -> None: + ${downgrades if downgrades else "pass"} diff --git a/alembic/versions/19e9eccf2c81_create_user_model_role_model_and_user_.py b/alembic/versions/19e9eccf2c81_create_user_model_role_model_and_user_.py new file mode 100644 index 00000000..5830206f --- /dev/null +++ b/alembic/versions/19e9eccf2c81_create_user_model_role_model_and_user_.py @@ -0,0 +1,61 @@ +"""Create user model, role model and user role model + +Revision ID: 19e9eccf2c81 +Revises: +Create Date: 2024-01-11 16:33:53.253969 + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision: str = '19e9eccf2c81' +down_revision: Union[str, None] = None +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! 
### + op.create_table('roles', + sa.Column('id', sa.Integer(), nullable=False), + sa.Column('name', sa.String(length=100), nullable=True), + sa.Column('description', sa.Text(), nullable=True), + sa.PrimaryKeyConstraint('id') + ) + op.create_index(op.f('ix_roles_id'), 'roles', ['id'], unique=False) + op.create_index(op.f('ix_roles_name'), 'roles', ['name'], unique=False) + op.create_table('users', + sa.Column('id', sa.Integer(), nullable=False), + sa.Column('email', sa.String(length=225), nullable=False), + sa.Column('hashed_password', sa.LargeBinary(), nullable=False), + sa.Column('fullname', sa.String(length=225), nullable=False), + sa.Column('is_active', sa.Boolean(), nullable=True), + sa.Column('last_login', sa.DateTime(), nullable=True), + sa.Column('created_at', sa.DateTime(), nullable=True), + sa.Column('updated_at', sa.DateTime(), nullable=True), + sa.PrimaryKeyConstraint('id'), + sa.UniqueConstraint('email') + ) + op.create_table('user_roles', + sa.Column('user_id', sa.Integer(), nullable=False), + sa.Column('role_id', sa.Integer(), nullable=False), + sa.ForeignKeyConstraint(['role_id'], ['roles.id'], ), + sa.ForeignKeyConstraint(['user_id'], ['users.id'], ), + sa.PrimaryKeyConstraint('user_id', 'role_id'), + sa.UniqueConstraint('user_id', 'role_id', name='unique_user_role') + ) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! 
### + op.drop_table('user_roles') + op.drop_table('users') + op.drop_index(op.f('ix_roles_name'), table_name='roles') + op.drop_index(op.f('ix_roles_id'), table_name='roles') + op.drop_table('roles') + # ### end Alembic commands ### diff --git a/alembic/versions/9491fd6d6fad_create_user_model.py b/alembic/versions/9491fd6d6fad_create_user_model.py new file mode 100644 index 00000000..9b291fc2 --- /dev/null +++ b/alembic/versions/9491fd6d6fad_create_user_model.py @@ -0,0 +1,30 @@ +"""Create user model + +Revision ID: 9491fd6d6fad +Revises: 19e9eccf2c81 +Create Date: 2024-01-11 17:02:25.882848 + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision: str = '9491fd6d6fad' +down_revision: Union[str, None] = '19e9eccf2c81' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.create_unique_constraint('unique_user_role', 'user_roles', ['user_id', 'role_id']) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_constraint('unique_user_role', 'user_roles', type_='unique') + # ### end Alembic commands ### diff --git a/alembic_readme.md b/alembic_readme.md new file mode 100644 index 00000000..219862b1 --- /dev/null +++ b/alembic_readme.md @@ -0,0 +1,7 @@ +## **Alembic migrations:** + +`alembic init alembic` # initialize the Alembic environment + +`alembic revision --autogenerate -m "Create user model"` # generate the first migration + +`alembic upgrade 66b63a` # apply the migration to the database (here 66b63a is the revision ID) \ No newline at end of file diff --git a/private_gpt/alembic.ini b/private_gpt/alembic.ini new file mode 100644 index 00000000..c10d4ca0 --- /dev/null +++ b/private_gpt/alembic.ini @@ -0,0 +1,116 @@ +# A generic, single database configuration. 
+ +[alembic] +# path to migration scripts +script_location = alembic + +# template used to generate migration file names; The default value is %%(rev)s_%%(slug)s +# Uncomment the line below if you want the files to be prepended with date and time +# see https://alembic.sqlalchemy.org/en/latest/tutorial.html#editing-the-ini-file +# for all available tokens +# file_template = %%(year)d_%%(month).2d_%%(day).2d_%%(hour).2d%%(minute).2d-%%(rev)s_%%(slug)s + +# sys.path path, will be prepended to sys.path if present. +# defaults to the current working directory. +prepend_sys_path = . + +# timezone to use when rendering the date within the migration file +# as well as the filename. +# If specified, requires the python>=3.9 or backports.zoneinfo library. +# Any required deps can installed by adding `alembic[tz]` to the pip requirements +# string value is passed to ZoneInfo() +# leave blank for localtime +# timezone = + +# max length of characters to apply to the +# "slug" field +# truncate_slug_length = 40 + +# set to 'true' to run the environment during +# the 'revision' command, regardless of autogenerate +# revision_environment = false + +# set to 'true' to allow .pyc and .pyo files without +# a source .py file to be detected as revisions in the +# versions/ directory +# sourceless = false + +# version location specification; This defaults +# to alembic/versions. When using multiple version +# directories, initial revisions must be specified with --version-path. +# The path separator used here should be the separator specified by "version_path_separator" below. +# version_locations = %(here)s/bar:%(here)s/bat:alembic/versions + +# version path separator; As mentioned above, this is the character used to split +# version_locations. The default within new alembic.ini files is "os", which uses os.pathsep. +# If this key is omitted entirely, it falls back to the legacy behavior of splitting on spaces and/or commas. 
+# Valid values for version_path_separator are: +# +# version_path_separator = : +# version_path_separator = ; +# version_path_separator = space +version_path_separator = os # Use os.pathsep. Default configuration used for new projects. + +# set to 'true' to search source files recursively +# in each "version_locations" directory +# new in Alembic version 1.10 +# recursive_version_locations = false + +# the output encoding used when revision files +# are written from script.py.mako +# output_encoding = utf-8 + +sqlalchemy.url = driver://user:pass@localhost/dbname + + +[post_write_hooks] +# post_write_hooks defines scripts or Python functions that are run +# on newly generated revision scripts. See the documentation for further +# detail and examples + +# format using "black" - use the console_scripts runner, against the "black" entrypoint +# hooks = black +# black.type = console_scripts +# black.entrypoint = black +# black.options = -l 79 REVISION_SCRIPT_FILENAME + +# lint with attempts to fix using "ruff" - use the exec runner, execute a binary +# hooks = ruff +# ruff.type = exec +# ruff.executable = %(here)s/.venv/bin/ruff +# ruff.options = --fix REVISION_SCRIPT_FILENAME + +# Logging configuration +[loggers] +keys = root,sqlalchemy,alembic + +[handlers] +keys = console + +[formatters] +keys = generic + +[logger_root] +level = WARN +handlers = console +qualname = + +[logger_sqlalchemy] +level = WARN +handlers = +qualname = sqlalchemy.engine + +[logger_alembic] +level = INFO +handlers = +qualname = alembic + +[handler_console] +class = StreamHandler +args = (sys.stderr,) +level = NOTSET +formatter = generic + +[formatter_generic] +format = %(levelname)-5.5s [%(name)s] %(message)s +datefmt = %H:%M:%S diff --git a/private_gpt/launcher.py b/private_gpt/launcher.py index 791e841a..f6fb8708 100644 --- a/private_gpt/launcher.py +++ b/private_gpt/launcher.py @@ -11,6 +11,8 @@ from private_gpt.server.completions.completions_router import completions_router from 
private_gpt.server.embeddings.embeddings_router import embeddings_router from private_gpt.server.health.health_router import health_router from private_gpt.server.ingest.ingest_router import ingest_router +from private_gpt.users.api.v1.api import api_router + from private_gpt.settings.settings import Settings logger = logging.getLogger(__name__) @@ -30,6 +32,9 @@ def create_app(root_injector: Injector) -> FastAPI: app.include_router(ingest_router) app.include_router(embeddings_router) app.include_router(health_router) + + app.include_router(api_router) + settings = root_injector.get(Settings) if settings.server.cors.enabled: @@ -45,9 +50,12 @@ def create_app(root_injector: Injector) -> FastAPI: if settings.ui.enabled: logger.debug("Importing the UI module") - from private_gpt.ui.ui import PrivateGptUi + from private_gpt.ui.admin_ui import PrivateAdminGptUi + admin_ui = root_injector.get(PrivateAdminGptUi) + admin_ui.mount_in_admin_app(app, '/admin') + from private_gpt.ui.ui import PrivateGptUi ui = root_injector.get(PrivateGptUi) ui.mount_in_app(app, settings.ui.path) - return app + return app \ No newline at end of file diff --git a/private_gpt/paths.py b/private_gpt/paths.py index 59db3a49..b7b4d466 100644 --- a/private_gpt/paths.py +++ b/private_gpt/paths.py @@ -7,6 +7,7 @@ from private_gpt.settings.settings import settings def _absolute_or_from_project_root(path: str) -> Path: if path.startswith("/"): return Path(path) + return PROJECT_ROOT_PATH / path diff --git a/private_gpt/ui/admin_ui.py b/private_gpt/ui/admin_ui.py new file mode 100644 index 00000000..e77e2b8a --- /dev/null +++ b/private_gpt/ui/admin_ui.py @@ -0,0 +1,289 @@ +"""This file should be imported only and only if you want to run the UI locally.""" +import itertools +import logging +from collections.abc import Iterable +from pathlib import Path +from typing import Any + +import gradio as gr # type: ignore +from fastapi import FastAPI +from gradio.themes.utils.colors import slate # type: ignore 
+from injector import inject, singleton +from llama_index.llms import ChatMessage, ChatResponse, MessageRole +from pydantic import BaseModel + +from private_gpt.constants import PROJECT_ROOT_PATH +from private_gpt.di import global_injector +from private_gpt.server.chat.chat_service import ChatService, CompletionGen +from private_gpt.server.chunks.chunks_service import Chunk, ChunksService +from private_gpt.server.ingest.ingest_service import IngestService +from private_gpt.settings.settings import settings +from private_gpt.ui.images import logo_svg +from private_gpt.ui.common import Source + +logger = logging.getLogger(__name__) + +THIS_DIRECTORY_RELATIVE = Path(__file__).parent.relative_to(PROJECT_ROOT_PATH) +# Should be "private_gpt/ui/avatar-bot.ico" +AVATAR_BOT = THIS_DIRECTORY_RELATIVE / "avatar-bot.ico" + +UI_TAB_TITLE = "My Private GPT" + +SOURCES_SEPARATOR = "\n\n Sources: \n" + +MODES = ["Query Docs", "Search in Docs", "LLM Chat"] + +@singleton +class PrivateAdminGptUi: + @inject + def __init__( + self, + ingest_service: IngestService, + chat_service: ChatService, + chunks_service: ChunksService, + ) -> None: + self._ingest_service = ingest_service + self._chat_service = chat_service + self._chunks_service = chunks_service + + # Cache the UI blocks + self._ui_block = None + + # Initialize system prompt based on default mode + self.mode = MODES[0] + self._system_prompt = self._get_default_system_prompt(self.mode) + + def _chat(self, message: str, history: list[list[str]], mode: str, *_: Any) -> Any: + def yield_deltas(completion_gen: CompletionGen) -> Iterable[str]: + full_response: str = "" + stream = completion_gen.response + for delta in stream: + if isinstance(delta, str): + full_response += str(delta) + elif isinstance(delta, ChatResponse): + full_response += delta.delta or "" + yield full_response + + if completion_gen.sources: + full_response += SOURCES_SEPARATOR + cur_sources = Source.curate_sources(completion_gen.sources) + sources_text = 
"\n\n\n".join( + f"{index}. {source.file} (page {source.page})" + for index, source in enumerate(cur_sources, start=1) + ) + full_response += sources_text + yield full_response + + def build_history() -> list[ChatMessage]: + history_messages: list[ChatMessage] = list( + itertools.chain( + *[ + [ + ChatMessage(content=interaction[0], role=MessageRole.USER), + ChatMessage( + # Remove from history content the Sources information + content=interaction[1].split(SOURCES_SEPARATOR)[0], + role=MessageRole.ASSISTANT, + ), + ] + for interaction in history + ] + ) + ) + + # max 20 messages to try to avoid context overflow + return history_messages[:20] + + new_message = ChatMessage(content=message, role=MessageRole.USER) + all_messages = [*build_history(), new_message] + # If a system prompt is set, add it as a system message + if self._system_prompt: + all_messages.insert( + 0, + ChatMessage( + content=self._system_prompt, + role=MessageRole.SYSTEM, + ), + ) + match mode: + case "Query Docs": + query_stream = self._chat_service.stream_chat( + messages=all_messages, + use_context=True, + ) + yield from yield_deltas(query_stream) + case "LLM Chat": + llm_stream = self._chat_service.stream_chat( + messages=all_messages, + use_context=False, + ) + yield from yield_deltas(llm_stream) + + case "Search in Docs": + response = self._chunks_service.retrieve_relevant( + text=message, limit=4, prev_next_chunks=0 + ) + + sources = Source.curate_sources(response) + + yield "\n\n\n".join( + f"{index}. **{source.file} " + f"(page {source.page})**\n " + f"{source.text}" + for index, source in enumerate(sources, start=1) + ) + + # On initialization and on mode change, this function set the system prompt + # to the default prompt based on the mode (and user settings). 
+ @staticmethod + def _get_default_system_prompt(mode: str) -> str: + p = "" + match mode: + # For query chat mode, obtain default system prompt from settings + case "Query Docs": + p = settings().ui.default_query_system_prompt + # For chat mode, obtain default system prompt from settings + case "LLM Chat": + p = settings().ui.default_chat_system_prompt + # For any other mode, clear the system prompt + case _: + p = "" + return p + + def _set_system_prompt(self, system_prompt_input: str) -> None: + logger.info(f"Setting system prompt to: {system_prompt_input}") + self._system_prompt = system_prompt_input + + def _set_current_mode(self, mode: str) -> Any: + self.mode = mode + self._set_system_prompt(self._get_default_system_prompt(mode)) + # Update placeholder and allow interaction if default system prompt is set + if self._system_prompt: + return gr.update(placeholder=self._system_prompt, interactive=True) + # Update placeholder and disable interaction if no default system prompt is set + else: + return gr.update(placeholder=self._system_prompt, interactive=False) + + def _list_ingested_files(self) -> list[list[str]]: + files = set() + for ingested_document in self._ingest_service.list_ingested(): + if ingested_document.doc_metadata is None: + # Skipping documents without metadata + continue + file_name = ingested_document.doc_metadata.get( + "file_name", "[FILE NAME MISSING]" + ) + files.add(file_name) + return [[row] for row in files] + + def _upload_file(self, files: list[str]) -> None: + logger.debug("Loading count=%s files", len(files)) + paths = [Path(file) for file in files] + self._ingest_service.bulk_ingest([(str(path.name), path) for path in paths]) + + def _build_admin_ui_blocks(self) -> gr.Blocks: + logger.debug("Creating the UI blocks") + with gr.Blocks( + title=UI_TAB_TITLE, + theme=gr.themes.Soft(primary_hue=slate), + css=".logo { " + "display:flex;" + "background-color: #C7BAFF;" + "height: 80px;" + "border-radius: 8px;" + "align-content: center;" + 
"justify-content: center;" + "align-items: center;" + "}" + ".logo img { height: 25% }" + ".contain { display: flex !important; flex-direction: column !important; }" + "#component-0, #component-3, #component-10, #component-8 { height: 100% !important; }" + "#chatbot { flex-grow: 1 !important; overflow: auto !important;}" + "#col { height: calc(100vh - 112px - 16px) !important; }", + ) as blocks: + with gr.Row(): + gr.HTML(f"