Working refactor. Dependency clean-up pending.

imartinez 2024-02-28 18:45:54 +01:00
parent 12f3a39e8a
commit d0a7d991a2
20 changed files with 877 additions and 907 deletions

poetry.lock (generated): 1424 lines changed; diff suppressed because it is too large.


@@ -1,8 +1,7 @@
 import logging
 
 from injector import inject, singleton
-from llama_index import MockEmbedding
-from llama_index.embeddings.base import BaseEmbedding
+from llama_index.core.embeddings import BaseEmbedding, MockEmbedding
 
 from private_gpt.paths import models_cache_path
 from private_gpt.settings.settings import Settings
@@ -20,7 +19,7 @@ class EmbeddingComponent:
         logger.info("Initializing the embedding model in mode=%s", embedding_mode)
         match embedding_mode:
             case "local":
-                from llama_index.embeddings import HuggingFaceEmbedding
+                from llama_index.embeddings.huggingface import HuggingFaceEmbedding
 
                 self.embedding_model = HuggingFaceEmbedding(
                     model_name=settings.local.embedding_hf_model_name,
@@ -36,7 +35,7 @@ class EmbeddingComponent:
                     endpoint_name=settings.sagemaker.embedding_endpoint_name,
                 )
             case "openai":
-                from llama_index import OpenAIEmbedding
+                from llama_index.embeddings.openai import OpenAIEmbedding
 
                 openai_settings = settings.openai.api_key
                 self.embedding_model = OpenAIEmbedding(api_key=openai_settings)


@@ -8,16 +8,13 @@ import threading
 from pathlib import Path
 from typing import Any
 
-from llama_index import (
-    Document,
-    ServiceContext,
-    StorageContext,
-    VectorStoreIndex,
-    load_index_from_storage,
-)
-from llama_index.data_structs import IndexDict
-from llama_index.indices.base import BaseIndex
-from llama_index.ingestion import run_transformations
+from llama_index.core.data_structs import IndexDict
+from llama_index.core.embeddings.utils import EmbedType
+from llama_index.core.indices import VectorStoreIndex, load_index_from_storage
+from llama_index.core.indices.base import BaseIndex
+from llama_index.core.ingestion import run_transformations
+from llama_index.core.schema import Document, TransformComponent
+from llama_index.core.storage import StorageContext
 
 from private_gpt.components.ingest.ingest_helper import IngestionHelper
 from private_gpt.paths import local_data_path
@@ -30,13 +27,15 @@ class BaseIngestComponent(abc.ABC):
     def __init__(
         self,
         storage_context: StorageContext,
-        service_context: ServiceContext,
+        embed_model: EmbedType,
+        transformations: list[TransformComponent],
         *args: Any,
         **kwargs: Any,
     ) -> None:
         logger.debug("Initializing base ingest component type=%s", type(self).__name__)
         self.storage_context = storage_context
-        self.service_context = service_context
+        self.embed_model = embed_model
+        self.transformations = transformations
 
     @abc.abstractmethod
     def ingest(self, file_name: str, file_data: Path) -> list[Document]:
@@ -55,11 +54,12 @@ class BaseIngestComponentWithIndex(BaseIngestComponent, abc.ABC):
     def __init__(
         self,
         storage_context: StorageContext,
-        service_context: ServiceContext,
+        embed_model: EmbedType,
+        transformations: list[TransformComponent],
         *args: Any,
         **kwargs: Any,
     ) -> None:
-        super().__init__(storage_context, service_context, *args, **kwargs)
+        super().__init__(storage_context, embed_model, transformations, *args, **kwargs)
 
         self.show_progress = True
         self._index_thread_lock = (
@@ -73,9 +73,10 @@ class BaseIngestComponentWithIndex(BaseIngestComponent, abc.ABC):
             # Load the index with store_nodes_override=True to be able to delete them
             index = load_index_from_storage(
                 storage_context=self.storage_context,
-                service_context=self.service_context,
                 store_nodes_override=True,  # Force store nodes in index and document stores
                 show_progress=self.show_progress,
+                embed_model=self.embed_model,
+                transformations=self.transformations,
             )
         except ValueError:
             # There are no index in the storage context, creating a new one
@@ -83,9 +84,10 @@ class BaseIngestComponentWithIndex(BaseIngestComponent, abc.ABC):
             index = VectorStoreIndex.from_documents(
                 [],
                 storage_context=self.storage_context,
-                service_context=self.service_context,
                 store_nodes_override=True,  # Force store nodes in index and document stores
                 show_progress=self.show_progress,
+                embed_model=self.embed_model,
+                transformations=self.transformations,
             )
             index.storage_context.persist(persist_dir=local_data_path)
         return index
@@ -106,11 +108,12 @@ class SimpleIngestComponent(BaseIngestComponentWithIndex):
     def __init__(
         self,
         storage_context: StorageContext,
-        service_context: ServiceContext,
+        embed_model: EmbedType,
+        transformations: list[TransformComponent],
         *args: Any,
         **kwargs: Any,
     ) -> None:
-        super().__init__(storage_context, service_context, *args, **kwargs)
+        super().__init__(storage_context, embed_model, transformations, *args, **kwargs)
 
     def ingest(self, file_name: str, file_data: Path) -> list[Document]:
         logger.info("Ingesting file_name=%s", file_name)
@@ -151,16 +154,17 @@ class BatchIngestComponent(BaseIngestComponentWithIndex):
     def __init__(
         self,
         storage_context: StorageContext,
-        service_context: ServiceContext,
+        embed_model: EmbedType,
+        transformations: list[TransformComponent],
         count_workers: int,
         *args: Any,
         **kwargs: Any,
     ) -> None:
-        super().__init__(storage_context, service_context, *args, **kwargs)
+        super().__init__(storage_context, embed_model, transformations, *args, **kwargs)
         # Make an efficient use of the CPU and GPU, the embedding
         # must be in the transformations
         assert (
-            len(self.service_context.transformations) >= 2
+            len(self.transformations) >= 2
         ), "Embeddings must be in the transformations"
         assert count_workers > 0, "count_workers must be > 0"
         self.count_workers = count_workers
@@ -197,7 +201,7 @@ class BatchIngestComponent(BaseIngestComponentWithIndex):
         logger.debug("Transforming count=%s documents into nodes", len(documents))
         nodes = run_transformations(
             documents,  # type: ignore[arg-type]
-            self.service_context.transformations,
+            self.transformations,
             show_progress=self.show_progress,
         )
         # Locking the index to avoid concurrent writes
@@ -225,16 +229,17 @@ class ParallelizedIngestComponent(BaseIngestComponentWithIndex):
     def __init__(
         self,
         storage_context: StorageContext,
-        service_context: ServiceContext,
+        embed_model: EmbedType,
+        transformations: list[TransformComponent],
         count_workers: int,
         *args: Any,
         **kwargs: Any,
     ) -> None:
-        super().__init__(storage_context, service_context, *args, **kwargs)
+        super().__init__(storage_context, embed_model, transformations, *args, **kwargs)
        # To make an efficient use of the CPU and GPU, the embeddings
        # must be in the transformations (to be computed in batches)
         assert (
-            len(self.service_context.transformations) >= 2
+            len(self.transformations) >= 2
         ), "Embeddings must be in the transformations"
         assert count_workers > 0, "count_workers must be > 0"
         self.count_workers = count_workers
@@ -278,7 +283,7 @@ class ParallelizedIngestComponent(BaseIngestComponentWithIndex):
         logger.debug("Transforming count=%s documents into nodes", len(documents))
         nodes = run_transformations(
             documents,  # type: ignore[arg-type]
-            self.service_context.transformations,
+            self.transformations,
             show_progress=self.show_progress,
         )
         # Locking the index to avoid concurrent writes
@@ -311,18 +316,29 @@ class ParallelizedIngestComponent(BaseIngestComponentWithIndex):
 
 def get_ingestion_component(
     storage_context: StorageContext,
-    service_context: ServiceContext,
+    embed_model: EmbedType,
+    transformations: list[TransformComponent],
     settings: Settings,
 ) -> BaseIngestComponent:
     """Get the ingestion component for the given configuration."""
     ingest_mode = settings.embedding.ingest_mode
     if ingest_mode == "batch":
         return BatchIngestComponent(
-            storage_context, service_context, settings.embedding.count_workers
+            storage_context=storage_context,
+            embed_model=embed_model,
+            transformations=transformations,
+            count_workers=settings.embedding.count_workers,
         )
     elif ingest_mode == "parallel":
         return ParallelizedIngestComponent(
-            storage_context, service_context, settings.embedding.count_workers
+            storage_context=storage_context,
+            embed_model=embed_model,
+            transformations=transformations,
+            count_workers=settings.embedding.count_workers,
         )
     else:
-        return SimpleIngestComponent(storage_context, service_context)
+        return SimpleIngestComponent(
+            storage_context=storage_context,
+            embed_model=embed_model,
+            transformations=transformations,
+        )


@@ -1,14 +1,53 @@
 import logging
 from pathlib import Path
+from typing import Dict, Type
 
-from llama_index import Document
-from llama_index.readers import JSONReader, StringIterableReader
-from llama_index.readers.file.base import DEFAULT_FILE_READER_CLS
+from llama_index.core.readers import StringIterableReader
+from llama_index.core.readers.base import BaseReader
+from llama_index.core.readers.json import JSONReader
+from llama_index.core.schema import Document
 
 logger = logging.getLogger(__name__)
 
+
+# Inspired by the `llama_index.core.readers.file.base` module
+def _try_loading_included_file_formats() -> Dict[str, Type[BaseReader]]:
+    try:
+        from llama_index.readers.file.docs import DocxReader, HWPReader, PDFReader
+        from llama_index.readers.file.epub import EpubReader
+        from llama_index.readers.file.image import ImageReader
+        from llama_index.readers.file.ipynb import IPYNBReader
+        from llama_index.readers.file.markdown import MarkdownReader
+        from llama_index.readers.file.mbox import MboxReader
+        from llama_index.readers.file.tabular import PandasCSVReader
+        from llama_index.readers.file.slides import PptxReader
+        from llama_index.readers.file.video_audio import VideoAudioReader
+    except ImportError:
+        raise ImportError("`llama-index-readers-file` package not found")
+
+    default_file_reader_cls: Dict[str, Type[BaseReader]] = {
+        ".hwp": HWPReader,
+        ".pdf": PDFReader,
+        ".docx": DocxReader,
+        ".pptx": PptxReader,
+        ".ppt": PptxReader,
+        ".pptm": PptxReader,
+        ".jpg": ImageReader,
+        ".png": ImageReader,
+        ".jpeg": ImageReader,
+        ".mp3": VideoAudioReader,
+        ".mp4": VideoAudioReader,
+        ".csv": PandasCSVReader,
+        ".epub": EpubReader,
+        ".md": MarkdownReader,
+        ".mbox": MboxReader,
+        ".ipynb": IPYNBReader,
+    }
+    return default_file_reader_cls
+
 
 # Patching the default file reader to support other file types
-FILE_READER_CLS = DEFAULT_FILE_READER_CLS.copy()
+FILE_READER_CLS = _try_loading_included_file_formats()
 FILE_READER_CLS.update(
     {
         ".json": JSONReader,


@@ -1,15 +1,16 @@
 import logging
 
 from injector import inject, singleton
-from llama_index import set_global_tokenizer
-from llama_index.llms import MockLLM
-from llama_index.llms.base import LLM
+from llama_index.core.llms import LLM, MockLLM
+from llama_index.core.utils import set_global_tokenizer
+from llama_index.core.settings import Settings as LlamaIndexSettings
 from transformers import AutoTokenizer  # type: ignore
 
 from private_gpt.components.llm.prompt_helper import get_prompt_style
 from private_gpt.paths import models_cache_path, models_path
 from private_gpt.settings.settings import Settings
 
 logger = logging.getLogger(__name__)
@@ -31,7 +32,7 @@ class LLMComponent:
         logger.info("Initializing the LLM in mode=%s", llm_mode)
         match settings.llm.mode:
             case "local":
-                from llama_index.llms import LlamaCPP
+                from llama_index.llms.llama_cpp import LlamaCPP
 
                 prompt_style = get_prompt_style(settings.local.prompt_style)
 
@@ -41,6 +42,7 @@ class LLMComponent:
                     max_new_tokens=settings.llm.max_new_tokens,
                     context_window=settings.llm.context_window,
                     generate_kwargs={},
+                    callback_manager=LlamaIndexSettings.callback_manager,
                     # All to GPU
                     model_kwargs={"n_gpu_layers": -1, "offload_kqv": True},
                     # transform inputs into Llama2 format
@@ -58,7 +60,7 @@ class LLMComponent:
                     context_window=settings.llm.context_window,
                 )
             case "openai":
-                from llama_index.llms import OpenAI
+                from llama_index.llms.openai import OpenAI
 
                 openai_settings = settings.openai
                 self.llm = OpenAI(
@@ -67,7 +69,7 @@ class LLMComponent:
                     model=openai_settings.model,
                 )
             case "openailike":
-                from llama_index.llms import OpenAILike
+                from llama_index.llms.openai_like import OpenAILike
 
                 openai_settings = settings.openai
                 self.llm = OpenAILike(
@@ -81,7 +83,7 @@ class LLMComponent:
             case "mock":
                 self.llm = MockLLM()
             case "ollama":
-                from llama_index.llms import Ollama
+                from llama_index.llms.ollama import Ollama
 
                 ollama_settings = settings.ollama
                 self.llm = Ollama(


@@ -3,11 +3,7 @@ import logging
 from collections.abc import Sequence
 from typing import Any, Literal
 
-from llama_index.llms import ChatMessage, MessageRole
-from llama_index.llms.llama_utils import (
-    completion_to_prompt,
-    messages_to_prompt,
-)
+from llama_index.core.llms import ChatMessage, MessageRole
 
 logger = logging.getLogger(__name__)
 
@@ -73,7 +69,9 @@ class DefaultPromptStyle(AbstractPromptStyle):
 
 
 class Llama2PromptStyle(AbstractPromptStyle):
-    """Simple prompt style that just uses the default llama_utils functions.
+    """Simple prompt style that uses llama 2 prompt style.
+
+    Inspired by llama_index/legacy/llms/llama_utils.py
 
     It transforms the sequence of messages into a prompt that should look like:
     ```text
@@ -83,11 +81,61 @@ class Llama2PromptStyle(AbstractPromptStyle):
     ```
     """
 
+    BOS, EOS = "<s>", "</s>"
+    B_INST, E_INST = "[INST]", "[/INST]"
+    B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"
+    DEFAULT_SYSTEM_PROMPT = """\
+You are a helpful, respectful and honest assistant. \
+Always answer as helpfully as possible and follow ALL given instructions. \
+Do not speculate or make up information. \
+Do not reference any given instructions or context. \
+"""
+
     def _messages_to_prompt(self, messages: Sequence[ChatMessage]) -> str:
-        return messages_to_prompt(messages)
+        string_messages: list[str] = []
+        if messages[0].role == MessageRole.SYSTEM:
+            # pull out the system message (if it exists in messages)
+            system_message_str = messages[0].content or ""
+            messages = messages[1:]
+        else:
+            system_message_str = self.DEFAULT_SYSTEM_PROMPT
+
+        system_message_str = f"{self.B_SYS} {system_message_str.strip()} {self.E_SYS}"
+
+        for i in range(0, len(messages), 2):
+            # first message should always be a user
+            user_message = messages[i]
+            assert user_message.role == MessageRole.USER
+
+            if i == 0:
+                # make sure system prompt is included at the start
+                str_message = f"{self.BOS} {self.B_INST} {system_message_str} "
+            else:
+                # end previous user-assistant interaction
+                string_messages[-1] += f" {self.EOS}"
+                # no need to include system prompt
+                str_message = f"{self.BOS} {self.B_INST} "
+
+            # include user message content
+            str_message += f"{user_message.content} {self.E_INST}"
+
+            if len(messages) > (i + 1):
+                # if assistant message exists, add to str_message
+                assistant_message = messages[i + 1]
+                assert assistant_message.role == MessageRole.ASSISTANT
+                str_message += f" {assistant_message.content}"
+
+            string_messages.append(str_message)
+
+        return "".join(string_messages)
 
     def _completion_to_prompt(self, completion: str) -> str:
-        return completion_to_prompt(completion)
+        system_prompt_str = self.DEFAULT_SYSTEM_PROMPT
+
+        return (
+            f"{self.BOS} {self.B_INST} {self.B_SYS} {system_prompt_str.strip()} {self.E_SYS} "
+            f"{completion.strip()} {self.E_INST}"
+        )
 
 
 class TagPromptStyle(AbstractPromptStyle):


@@ -1,9 +1,9 @@
 import logging
 
 from injector import inject, singleton
-from llama_index.storage.docstore import BaseDocumentStore, SimpleDocumentStore
-from llama_index.storage.index_store import SimpleIndexStore
-from llama_index.storage.index_store.types import BaseIndexStore
+from llama_index.core.storage.docstore import BaseDocumentStore, SimpleDocumentStore
+from llama_index.core.storage.index_store import SimpleIndexStore
+from llama_index.core.storage.index_store.types import BaseIndexStore
 
 from private_gpt.paths import local_data_path


@@ -1,9 +1,25 @@
+from collections.abc import Generator
 from typing import Any
 
-from llama_index.schema import BaseNode, MetadataMode
-from llama_index.vector_stores import ChromaVectorStore
-from llama_index.vector_stores.chroma import chunk_list
-from llama_index.vector_stores.utils import node_to_metadata_dict
+from llama_index.core.schema import BaseNode, MetadataMode
+from llama_index.core.vector_stores.utils import node_to_metadata_dict
+from llama_index.vector_stores.chroma import ChromaVectorStore
 
 
+def chunk_list(
+    lst: list[BaseNode], max_chunk_size: int
+) -> Generator[list[BaseNode], None, None]:
+    """Yield successive max_chunk_size-sized chunks from lst.
+
+    Args:
+        lst (List[BaseNode]): list of nodes with embeddings
+        max_chunk_size (int): max chunk size
+
+    Yields:
+        Generator[List[BaseNode], None, None]: list of nodes with embeddings
+    """
+    for i in range(0, len(lst), max_chunk_size):
+        yield lst[i : i + max_chunk_size]
+
+
 class BatchedChromaVectorStore(ChromaVectorStore):


@@ -2,9 +2,8 @@ import logging
 import typing
 
 from injector import inject, singleton
-from llama_index import VectorStoreIndex
-from llama_index.indices.vector_store import VectorIndexRetriever
-from llama_index.vector_stores.types import VectorStore
+from llama_index.core.indices.vector_store import VectorIndexRetriever, VectorStoreIndex
+from llama_index.core.vector_stores.types import VectorStore
 
 from private_gpt.components.vector_store.batched_chroma import BatchedChromaVectorStore
 from private_gpt.open_ai.extensions.context_filter import ContextFilter
@@ -41,7 +40,7 @@ class VectorStoreComponent:
     def __init__(self, settings: Settings) -> None:
         match settings.vectorstore.database:
             case "pgvector":
-                from llama_index.vector_stores import PGVectorStore
+                from llama_index.vector_stores.postgres import PGVectorStore
 
                 if settings.pgvector is None:
                     raise ValueError(


@@ -4,6 +4,8 @@ import logging
 from fastapi import Depends, FastAPI, Request
 from fastapi.middleware.cors import CORSMiddleware
 from injector import Injector
+from llama_index.core.callbacks import CallbackManager
+from llama_index.core.callbacks.global_handlers import create_global_handler
 
 from private_gpt.server.chat.chat_router import chat_router
 from private_gpt.server.chunks.chunks_router import chunks_router
@@ -12,6 +14,7 @@ from private_gpt.server.embeddings.embeddings_router import embeddings_router
 from private_gpt.server.health.health_router import health_router
 from private_gpt.server.ingest.ingest_router import ingest_router
 from private_gpt.settings.settings import Settings
+from llama_index.core.settings import Settings as LlamaIndexSettings
 
 logger = logging.getLogger(__name__)
@@ -31,6 +34,10 @@ def create_app(root_injector: Injector) -> FastAPI:
     app.include_router(embeddings_router)
     app.include_router(health_router)
 
+    # Add LlamaIndex simple observability
+    global_handler = create_global_handler("simple")
+    LlamaIndexSettings.callback_manager = CallbackManager([global_handler])
+
     settings = root_injector.get(Settings)
     if settings.server.cors.enabled:
         logger.debug("Setting up CORS middleware")


@@ -1,11 +1,6 @@
 """FastAPI app creation, logger configuration and main API routes."""
 
-import llama_index
-
 from private_gpt.di import global_injector
 from private_gpt.launcher import create_app
 
-# Add LlamaIndex simple observability
-llama_index.set_global_handler("simple")
-
 app = create_app(global_injector)


@@ -3,7 +3,7 @@ import uuid
 from collections.abc import Iterator
 from typing import Literal
 
-from llama_index.llms import ChatResponse, CompletionResponse
+from llama_index.core.llms import ChatResponse, CompletionResponse
 from pydantic import BaseModel, Field
 
 from private_gpt.server.chunks.chunks_service import Chunk


@@ -1,5 +1,5 @@
 from fastapi import APIRouter, Depends, Request
-from llama_index.llms import ChatMessage, MessageRole
+from llama_index.core.llms import ChatMessage, MessageRole
 from pydantic import BaseModel
 from starlette.responses import StreamingResponse


@@ -1,14 +1,15 @@
 from dataclasses import dataclass
 
 from injector import inject, singleton
-from llama_index import ServiceContext, StorageContext, VectorStoreIndex
-from llama_index.chat_engine import ContextChatEngine, SimpleChatEngine
-from llama_index.chat_engine.types import (
+from llama_index.core.chat_engine import ContextChatEngine, SimpleChatEngine
+from llama_index.core.chat_engine.types import (
     BaseChatEngine,
 )
-from llama_index.indices.postprocessor import MetadataReplacementPostProcessor
-from llama_index.llms import ChatMessage, MessageRole
-from llama_index.types import TokenGen
+from llama_index.core.indices import VectorStoreIndex
+from llama_index.core.indices.postprocessor import MetadataReplacementPostProcessor
+from llama_index.core.llms import ChatMessage, MessageRole
+from llama_index.core.storage import StorageContext
+from llama_index.core.types import TokenGen
 from pydantic import BaseModel
 
 from private_gpt.components.embedding.embedding_component import EmbeddingComponent
@@ -75,20 +76,19 @@ class ChatService:
         embedding_component: EmbeddingComponent,
         node_store_component: NodeStoreComponent,
     ) -> None:
-        self.llm_service = llm_component
+        self.llm_component = llm_component
+        self.embedding_component = embedding_component
         self.vector_store_component = vector_store_component
         self.storage_context = StorageContext.from_defaults(
             vector_store=vector_store_component.vector_store,
             docstore=node_store_component.doc_store,
             index_store=node_store_component.index_store,
         )
-        self.service_context = ServiceContext.from_defaults(
-            llm=llm_component.llm, embed_model=embedding_component.embedding_model
-        )
         self.index = VectorStoreIndex.from_vector_store(
             vector_store_component.vector_store,
             storage_context=self.storage_context,
-            service_context=self.service_context,
+            llm=llm_component.llm,
+            embed_model=embedding_component.embedding_model,
             show_progress=True,
         )
 
@@ -102,10 +102,17 @@ class ChatService:
             vector_index_retriever = self.vector_store_component.get_retriever(
                 index=self.index, context_filter=context_filter
             )
+            # TODO ContextChatEngine is still not migrated by LlamaIndex to accept
+            # llm directly, so we are passing legacy ServiceContext until it is fixed.
+            from llama_index.core import ServiceContext
+
            return ContextChatEngine.from_defaults(
                 system_prompt=system_prompt,
                 retriever=vector_index_retriever,
-                service_context=self.service_context,
+                llm=self.llm_component.llm,  # Takes no effect at the moment
+                service_context=ServiceContext.from_defaults(
+                    llm=self.llm_component.llm,
+                    embed_model=self.embedding_component.embedding_model,
+                ),
                 node_postprocessors=[
                     MetadataReplacementPostProcessor(target_metadata_key="window"),
                 ],
@@ -113,7 +120,7 @@ class ChatService:
         else:
             return SimpleChatEngine.from_defaults(
                 system_prompt=system_prompt,
-                service_context=self.service_context,
+                llm=self.llm_component.llm,
             )
 
     def stream_chat(


@@ -1,8 +1,9 @@
 from typing import TYPE_CHECKING, Literal
 
 from injector import inject, singleton
-from llama_index import ServiceContext, StorageContext, VectorStoreIndex
-from llama_index.schema import NodeWithScore
+from llama_index.core.indices import VectorStoreIndex
+from llama_index.core.schema import NodeWithScore
+from llama_index.core.storage import StorageContext
 from pydantic import BaseModel, Field
 
 from private_gpt.components.embedding.embedding_component import EmbeddingComponent
@@ -15,7 +16,7 @@ from private_gpt.open_ai.extensions.context_filter import ContextFilter
 from private_gpt.server.ingest.model import IngestedDoc
 
 if TYPE_CHECKING:
-    from llama_index.schema import RelatedNodeInfo
+    from llama_index.core.schema import RelatedNodeInfo
 
 
 class Chunk(BaseModel):
@@ -63,14 +64,13 @@ class ChunksService:
         node_store_component: NodeStoreComponent,
     ) -> None:
         self.vector_store_component = vector_store_component
+        self.llm_component = llm_component
+        self.embedding_component = embedding_component
         self.storage_context = StorageContext.from_defaults(
             vector_store=vector_store_component.vector_store,
             docstore=node_store_component.doc_store,
             index_store=node_store_component.index_store,
         )
-        self.query_service_context = ServiceContext.from_defaults(
-            llm=llm_component.llm, embed_model=embedding_component.embedding_model
-        )
 
     def _get_sibling_nodes_text(
         self, node_with_score: NodeWithScore, related_number: int, forward: bool = True
@@ -103,7 +103,8 @@ class ChunksService:
         index = VectorStoreIndex.from_vector_store(
             self.vector_store_component.vector_store,
             storage_context=self.storage_context,
-            service_context=self.query_service_context,
+            llm=self.llm_component.llm,
+            embed_model=self.embedding_component.embedding_model,
             show_progress=True,
         )
         vector_index_retriever = self.vector_store_component.get_retriever(


@@ -4,11 +4,8 @@ from pathlib import Path
 from typing import AnyStr, BinaryIO
 
 from injector import inject, singleton
-from llama_index import (
-    ServiceContext,
-    StorageContext,
-)
-from llama_index.node_parser import SentenceWindowNodeParser
+from llama_index.core.node_parser import SentenceWindowNodeParser
+from llama_index.core.storage import StorageContext
 
 from private_gpt.components.embedding.embedding_component import EmbeddingComponent
 from private_gpt.components.ingest.ingest_component import get_ingestion_component
@@ -40,17 +37,12 @@ class IngestService:
             index_store=node_store_component.index_store,
         )
 
         node_parser = SentenceWindowNodeParser.from_defaults()
-        self.ingest_service_context = ServiceContext.from_defaults(
-            llm=self.llm_service.llm,
-            embed_model=embedding_component.embedding_model,
-            node_parser=node_parser,
-            # Embeddings done early in the pipeline of node transformations, right
-            # after the node parsing
-            transformations=[node_parser, embedding_component.embedding_model],
-        )
         self.ingest_component = get_ingestion_component(
-            self.storage_context, self.ingest_service_context, settings=settings()
+            self.storage_context,
+            embed_model=embedding_component.embedding_model,
+            transformations=[node_parser, embedding_component.embedding_model],
+            settings=settings(),
         )
 
     def _ingest_data(self, file_name: str, file_data: AnyStr) -> list[IngestedDoc]:


@@ -1,6 +1,6 @@
 from typing import Any, Literal
 
-from llama_index import Document
+from llama_index.core.schema import Document
 from pydantic import BaseModel, Field


@@ -10,7 +10,7 @@ import gradio as gr  # type: ignore
 from fastapi import FastAPI
 from gradio.themes.utils.colors import slate  # type: ignore
 from injector import inject, singleton
-from llama_index.llms import ChatMessage, ChatResponse, MessageRole
+from llama_index.core.llms import ChatMessage, ChatResponse, MessageRole
 from pydantic import BaseModel
 
 from private_gpt.constants import PROJECT_ROOT_PATH


@@ -6,15 +6,24 @@ authors = ["Zylon <hi@zylon.ai>"]
 
 [tool.poetry.dependencies]
 python = ">=3.11,<3.12"
-fastapi = { extras = ["all"], version = "^0.103.1" }
-boto3 = "^1.28.56"
+fastapi = { extras = ["all"], version = "^0.110.0" }
+boto3 = "^1.34.51"
 injector = "^0.21.0"
 pyyaml = "^6.0.1"
-python-multipart = "^0.0.6"
-pypdf = "^3.16.2"
-llama-index = { extras = ["local_models"], version = "0.9.3" }
-watchdog = "^3.0.0"
-qdrant-client = "^1.6.9"
+python-multipart = "^0.0.9"
+llama-index-core = "^0.10.13"
+llama-index-readers-file = "^0.1.6"
+llama-index-embeddings-huggingface = "^0.1.4"
+llama-index-embeddings-openai = "^0.1.6"
+llama-index-vector-stores-qdrant = "^0.1.3"
+llama-index-vector-stores-chroma = "^0.1.4"
+llama-index-llms-llama-cpp = "^0.1.3"
+llama-index-llms-openai = "^0.1.6"
+llama-index-llms-openai-like = "^0.1.3"
+llama-index-llms-ollama = "^0.1.2"
+llama-index-vector-stores-postgres = "^0.1.2"
+watchdog = "^4.0.0"
+qdrant-client = "^1.7.3"
 chromadb = {version = "^0.4.13", optional = true}
 asyncpg = {version = "^0.29.0", optional = true}
 pgvector = {version = "^0.2.5", optional = true}


@@ -1,5 +1,5 @@
 import pytest
-from llama_index.llms import ChatMessage, MessageRole
+from llama_index.core.llms import ChatMessage, MessageRole
 
 from private_gpt.components.llm.prompt_helper import (
     ChatMLPromptStyle,