From dbbe5bcf0793ff7807874d2824d8055c57d281e4 Mon Sep 17 00:00:00 2001
From: Javier Martinez
Date: Tue, 24 Sep 2024 17:40:13 +0200
Subject: [PATCH] fix: mypy

---
 .../components/ingest/ingest_component.py      |  2 +-
 private_gpt/components/llm/llm_component.py    |  9 ++++-----
 private_gpt/components/llm/prompt_helper.py    |  3 ++-
 .../node_store/node_store_component.py         |  5 +++--
 .../components/vector_store/batched_chroma.py  | 17 ++++++++++-------
 private_gpt/server/chat/chat_service.py        | 15 +++++++++++----
 tests/fixtures/fast_api_test_client.py         |  2 +-
 tests/fixtures/ingest_helper.py                |  2 +-
 tests/fixtures/mock_injector.py                |  2 +-
 tests/server/ingest/test_local_ingest.py       |  2 +-
 10 files changed, 35 insertions(+), 24 deletions(-)

diff --git a/private_gpt/components/ingest/ingest_component.py b/private_gpt/components/ingest/ingest_component.py
index 5ed03959..77db9702 100644
--- a/private_gpt/components/ingest/ingest_component.py
+++ b/private_gpt/components/ingest/ingest_component.py
@@ -403,7 +403,7 @@ class PipelineIngestComponent(BaseIngestComponentWithIndex):
                 self.transformations,
                 show_progress=self.show_progress,
             )
-            self.node_q.put(("process", file_name, documents, nodes))
+            self.node_q.put(("process", file_name, documents, list(nodes)))
         finally:
             self.doc_semaphore.release()
             self.doc_q.task_done()  # unblock Q joins
diff --git a/private_gpt/components/llm/llm_component.py b/private_gpt/components/llm/llm_component.py
index e3a02813..eb752e54 100644
--- a/private_gpt/components/llm/llm_component.py
+++ b/private_gpt/components/llm/llm_component.py
@@ -120,7 +120,6 @@ class LLMComponent:
                     api_version="",
                     temperature=settings.llm.temperature,
                     context_window=settings.llm.context_window,
-                    max_new_tokens=settings.llm.max_new_tokens,
                     messages_to_prompt=prompt_style.messages_to_prompt,
                     completion_to_prompt=prompt_style.completion_to_prompt,
                     tokenizer=settings.llm.tokenizer,
@@ -184,10 +183,10 @@ class LLMComponent:
 
                     return wrapper
 
-                Ollama.chat = add_keep_alive(Ollama.chat)
-                Ollama.stream_chat = add_keep_alive(Ollama.stream_chat)
-                Ollama.complete = add_keep_alive(Ollama.complete)
-                Ollama.stream_complete = add_keep_alive(Ollama.stream_complete)
+                Ollama.chat = add_keep_alive(Ollama.chat)  # type: ignore
+                Ollama.stream_chat = add_keep_alive(Ollama.stream_chat)  # type: ignore
+                Ollama.complete = add_keep_alive(Ollama.complete)  # type: ignore
+                Ollama.stream_complete = add_keep_alive(Ollama.stream_complete)  # type: ignore
 
             self.llm = llm
 
diff --git a/private_gpt/components/llm/prompt_helper.py b/private_gpt/components/llm/prompt_helper.py
index 0432e496..512b02c2 100644
--- a/private_gpt/components/llm/prompt_helper.py
+++ b/private_gpt/components/llm/prompt_helper.py
@@ -40,7 +40,8 @@ class AbstractPromptStyle(abc.ABC):
         logger.debug("Got for messages='%s' the prompt='%s'", messages, prompt)
         return prompt
 
-    def completion_to_prompt(self, completion: str) -> str:
+    def completion_to_prompt(self, prompt: str) -> str:
+        completion = prompt  # llama-index expects this parameter to be named `prompt`
         prompt = self._completion_to_prompt(completion)
         logger.debug("Got for completion='%s' the prompt='%s'", completion, prompt)
         return prompt
diff --git a/private_gpt/components/node_store/node_store_component.py b/private_gpt/components/node_store/node_store_component.py
index f81ce701..8008b0c5 100644
--- a/private_gpt/components/node_store/node_store_component.py
+++ b/private_gpt/components/node_store/node_store_component.py
@@ -38,10 +38,10 @@ class NodeStoreComponent:
 
             case "postgres":
                 try:
-                    from llama_index.core.storage.docstore.postgres_docstore import (
+                    from llama_index.storage.docstore.postgres import (
                         PostgresDocumentStore,
                     )
-                    from llama_index.core.storage.index_store.postgres_index_store import (
+                    from llama_index.storage.index_store.postgres import (
                         PostgresIndexStore,
                     )
                 except ImportError:
@@ -55,6 +55,7 @@ class NodeStoreComponent:
                 self.index_store = PostgresIndexStore.from_params(
                     **settings.postgres.model_dump(exclude_none=True)
                 )
+
                 self.doc_store = PostgresDocumentStore.from_params(
                     **settings.postgres.model_dump(exclude_none=True)
                 )
diff --git a/private_gpt/components/vector_store/batched_chroma.py b/private_gpt/components/vector_store/batched_chroma.py
index 4f9ea25b..54dd0490 100644
--- a/private_gpt/components/vector_store/batched_chroma.py
+++ b/private_gpt/components/vector_store/batched_chroma.py
@@ -1,14 +1,17 @@
-from collections.abc import Generator
-from typing import Any
+from collections.abc import Generator, Sequence
+from typing import TYPE_CHECKING, Any
 
 from llama_index.core.schema import BaseNode, MetadataMode
 from llama_index.core.vector_stores.utils import node_to_metadata_dict
 from llama_index.vector_stores.chroma import ChromaVectorStore  # type: ignore
 
+if TYPE_CHECKING:
+    from collections.abc import Mapping
+
 
 def chunk_list(
-    lst: list[BaseNode], max_chunk_size: int
-) -> Generator[list[BaseNode], None, None]:
+    lst: Sequence[BaseNode], max_chunk_size: int
+) -> Generator[Sequence[BaseNode], None, None]:
     """Yield successive max_chunk_size-sized chunks from lst.
 
     Args:
@@ -60,7 +63,7 @@ class BatchedChromaVectorStore(ChromaVectorStore):  # type: ignore
         )
         self.chroma_client = chroma_client
 
-    def add(self, nodes: list[BaseNode], **add_kwargs: Any) -> list[str]:
+    def add(self, nodes: Sequence[BaseNode], **add_kwargs: Any) -> list[str]:
         """Add nodes to index, batching the insertion to avoid issues.
 
         Args:
@@ -78,8 +81,8 @@ class BatchedChromaVectorStore(ChromaVectorStore):  # type: ignore
 
         all_ids = []
         for node_chunk in node_chunks:
-            embeddings = []
-            metadatas = []
+            embeddings: list[Sequence[float]] = []
+            metadatas: list[Mapping[str, Any]] = []
             ids = []
             documents = []
             for node in node_chunk:
diff --git a/private_gpt/server/chat/chat_service.py b/private_gpt/server/chat/chat_service.py
index ae8cf008..efa4a194 100644
--- a/private_gpt/server/chat/chat_service.py
+++ b/private_gpt/server/chat/chat_service.py
@@ -1,4 +1,5 @@
 from dataclasses import dataclass
+from typing import TYPE_CHECKING
 
 from injector import inject, singleton
 from llama_index.core.chat_engine import ContextChatEngine, SimpleChatEngine
@@ -26,6 +27,9 @@
 from private_gpt.open_ai.extensions.context_filter import ContextFilter
 from private_gpt.server.chunks.chunks_service import Chunk
 from private_gpt.settings.settings import Settings
 
+if TYPE_CHECKING:
+    from llama_index.core.postprocessor.types import BaseNodePostprocessor
+
 class Completion(BaseModel):
     response: str
@@ -114,12 +118,15 @@ class ChatService:
                     context_filter=context_filter,
                     similarity_top_k=self.settings.rag.similarity_top_k,
                 )
-                node_postprocessors = [
+                node_postprocessors: list[BaseNodePostprocessor] = [
                     MetadataReplacementPostProcessor(target_metadata_key="window"),
-                    SimilarityPostprocessor(
-                        similarity_cutoff=settings.rag.similarity_value
-                    ),
                 ]
+                if settings.rag.similarity_value:
+                    node_postprocessors.append(
+                        SimilarityPostprocessor(
+                            similarity_cutoff=settings.rag.similarity_value
+                        )
+                    )
 
                 if settings.rag.rerank.enabled:
                     rerank_postprocessor = SentenceTransformerRerank(
diff --git a/tests/fixtures/fast_api_test_client.py b/tests/fixtures/fast_api_test_client.py
index 77d6037c..17525412 100644
--- a/tests/fixtures/fast_api_test_client.py
+++ b/tests/fixtures/fast_api_test_client.py
@@ -5,7 +5,7 @@ from private_gpt.launcher import create_app
 from tests.fixtures.mock_injector import MockInjector
 
 
-@pytest.fixture()
+@pytest.fixture
 def test_client(request: pytest.FixtureRequest, injector: MockInjector) -> TestClient:
     if request is not None and hasattr(request, "param"):
         injector.bind_settings(request.param or {})
diff --git a/tests/fixtures/ingest_helper.py b/tests/fixtures/ingest_helper.py
index 25515f4e..0d49d41a 100644
--- a/tests/fixtures/ingest_helper.py
+++ b/tests/fixtures/ingest_helper.py
@@ -19,6 +19,6 @@ class IngestHelper:
         return ingest_result
 
 
-@pytest.fixture()
+@pytest.fixture
 def ingest_helper(test_client: TestClient) -> IngestHelper:
     return IngestHelper(test_client)
diff --git a/tests/fixtures/mock_injector.py b/tests/fixtures/mock_injector.py
index 5769b33d..6f90fc29 100644
--- a/tests/fixtures/mock_injector.py
+++ b/tests/fixtures/mock_injector.py
@@ -37,6 +37,6 @@ class MockInjector:
         return self.test_injector.get(interface)
 
 
-@pytest.fixture()
+@pytest.fixture
 def injector() -> MockInjector:
     return MockInjector()
diff --git a/tests/server/ingest/test_local_ingest.py b/tests/server/ingest/test_local_ingest.py
index 860000ef..9c6cba6d 100644
--- a/tests/server/ingest/test_local_ingest.py
+++ b/tests/server/ingest/test_local_ingest.py
@@ -6,7 +6,7 @@
 import pytest
 from fastapi.testclient import TestClient
 
 
-@pytest.fixture()
+@pytest.fixture
 def file_path() -> str:
     return "test.txt"
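
Note on the Sequence-typing pattern used in batched_chroma.py above: add() now
accepts Sequence[BaseNode] instead of list[BaseNode], and chunk_list yields
Sequence[BaseNode] slices, which is also why ingest_component.py can hand over
list(nodes) without a variance error. Below is a minimal self-contained sketch
of the same pattern; plain Python with no llama-index dependency, where the
Item alias and the store_ids helper are illustrative stand-ins, not project
code.

from collections.abc import Generator, Sequence

# Illustrative stand-in for llama-index's BaseNode.
Item = str


def chunk_list(
    lst: Sequence[Item], max_chunk_size: int
) -> Generator[Sequence[Item], None, None]:
    """Yield successive max_chunk_size-sized chunks from lst."""
    for i in range(0, len(lst), max_chunk_size):
        # Slicing a Sequence returns a Sequence, so this type-checks
        # for list and tuple inputs alike.
        yield lst[i : i + max_chunk_size]


def store_ids(nodes: Sequence[Item], max_chunk_size: int = 2) -> list[str]:
    # Accepting the read-only Sequence rather than list lets callers pass
    # tuples or other immutable sequences; a list[Item] parameter would
    # reject them, because mypy treats mutable containers as invariant.
    all_ids: list[str] = []
    for chunk in chunk_list(nodes, max_chunk_size):
        all_ids.extend(item.upper() for item in chunk)
    return all_ids


if __name__ == "__main__":
    print(store_ids(["a", "b", "c"]))  # a list type-checks
    print(store_ids(("d", "e", "f")))  # so does a tuple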