fix: mypy

This commit is contained in:
Javier Martinez 2024-09-24 17:40:13 +02:00
parent 7c99878576
commit dbbe5bcf07
No known key found for this signature in database
10 changed files with 35 additions and 24 deletions

View File

@ -403,7 +403,7 @@ class PipelineIngestComponent(BaseIngestComponentWithIndex):
self.transformations,
show_progress=self.show_progress,
)
self.node_q.put(("process", file_name, documents, nodes))
self.node_q.put(("process", file_name, documents, list(nodes)))
finally:
self.doc_semaphore.release()
self.doc_q.task_done() # unblock Q joins

View File

@ -120,7 +120,6 @@ class LLMComponent:
api_version="",
temperature=settings.llm.temperature,
context_window=settings.llm.context_window,
max_new_tokens=settings.llm.max_new_tokens,
messages_to_prompt=prompt_style.messages_to_prompt,
completion_to_prompt=prompt_style.completion_to_prompt,
tokenizer=settings.llm.tokenizer,
@ -184,10 +183,10 @@ class LLMComponent:
return wrapper
Ollama.chat = add_keep_alive(Ollama.chat)
Ollama.stream_chat = add_keep_alive(Ollama.stream_chat)
Ollama.complete = add_keep_alive(Ollama.complete)
Ollama.stream_complete = add_keep_alive(Ollama.stream_complete)
Ollama.chat = add_keep_alive(Ollama.chat) # type: ignore
Ollama.stream_chat = add_keep_alive(Ollama.stream_chat) # type: ignore
Ollama.complete = add_keep_alive(Ollama.complete) # type: ignore
Ollama.stream_complete = add_keep_alive(Ollama.stream_complete) # type: ignore
self.llm = llm

View File

@ -40,7 +40,8 @@ class AbstractPromptStyle(abc.ABC):
logger.debug("Got for messages='%s' the prompt='%s'", messages, prompt)
return prompt
def completion_to_prompt(self, completion: str) -> str:
def completion_to_prompt(self, prompt: str) -> str:
completion = prompt # Fix: Llama-index parameter has to be named as prompt
prompt = self._completion_to_prompt(completion)
logger.debug("Got for completion='%s' the prompt='%s'", completion, prompt)
return prompt

View File

@ -38,10 +38,10 @@ class NodeStoreComponent:
case "postgres":
try:
from llama_index.core.storage.docstore.postgres_docstore import (
from llama_index.storage.docstore.postgres import (
PostgresDocumentStore,
)
from llama_index.core.storage.index_store.postgres_index_store import (
from llama_index.storage.index_store.postgres import (
PostgresIndexStore,
)
except ImportError:
@ -55,6 +55,7 @@ class NodeStoreComponent:
self.index_store = PostgresIndexStore.from_params(
**settings.postgres.model_dump(exclude_none=True)
)
self.doc_store = PostgresDocumentStore.from_params(
**settings.postgres.model_dump(exclude_none=True)
)

View File

@ -1,14 +1,17 @@
from collections.abc import Generator
from typing import Any
from collections.abc import Generator, Sequence
from typing import TYPE_CHECKING, Any
from llama_index.core.schema import BaseNode, MetadataMode
from llama_index.core.vector_stores.utils import node_to_metadata_dict
from llama_index.vector_stores.chroma import ChromaVectorStore # type: ignore
if TYPE_CHECKING:
from collections.abc import Mapping
def chunk_list(
lst: list[BaseNode], max_chunk_size: int
) -> Generator[list[BaseNode], None, None]:
lst: Sequence[BaseNode], max_chunk_size: int
) -> Generator[Sequence[BaseNode], None, None]:
"""Yield successive max_chunk_size-sized chunks from lst.
Args:
@ -60,7 +63,7 @@ class BatchedChromaVectorStore(ChromaVectorStore): # type: ignore
)
self.chroma_client = chroma_client
def add(self, nodes: list[BaseNode], **add_kwargs: Any) -> list[str]:
def add(self, nodes: Sequence[BaseNode], **add_kwargs: Any) -> list[str]:
"""Add nodes to index, batching the insertion to avoid issues.
Args:
@ -78,8 +81,8 @@ class BatchedChromaVectorStore(ChromaVectorStore): # type: ignore
all_ids = []
for node_chunk in node_chunks:
embeddings = []
metadatas = []
embeddings: list[Sequence[float]] = []
metadatas: list[Mapping[str, Any]] = []
ids = []
documents = []
for node in node_chunk:

View File

@ -1,4 +1,5 @@
from dataclasses import dataclass
from typing import TYPE_CHECKING
from injector import inject, singleton
from llama_index.core.chat_engine import ContextChatEngine, SimpleChatEngine
@ -26,6 +27,9 @@ from private_gpt.open_ai.extensions.context_filter import ContextFilter
from private_gpt.server.chunks.chunks_service import Chunk
from private_gpt.settings.settings import Settings
if TYPE_CHECKING:
from llama_index.core.postprocessor.types import BaseNodePostprocessor
class Completion(BaseModel):
response: str
@ -114,12 +118,15 @@ class ChatService:
context_filter=context_filter,
similarity_top_k=self.settings.rag.similarity_top_k,
)
node_postprocessors = [
node_postprocessors: list[BaseNodePostprocessor] = [
MetadataReplacementPostProcessor(target_metadata_key="window"),
]
if settings.rag.similarity_value:
node_postprocessors.append(
SimilarityPostprocessor(
similarity_cutoff=settings.rag.similarity_value
),
]
)
)
if settings.rag.rerank.enabled:
rerank_postprocessor = SentenceTransformerRerank(

View File

@ -5,7 +5,7 @@ from private_gpt.launcher import create_app
from tests.fixtures.mock_injector import MockInjector
@pytest.fixture()
@pytest.fixture
def test_client(request: pytest.FixtureRequest, injector: MockInjector) -> TestClient:
if request is not None and hasattr(request, "param"):
injector.bind_settings(request.param or {})

View File

@ -19,6 +19,6 @@ class IngestHelper:
return ingest_result
@pytest.fixture()
@pytest.fixture
def ingest_helper(test_client: TestClient) -> IngestHelper:
return IngestHelper(test_client)

View File

@ -37,6 +37,6 @@ class MockInjector:
return self.test_injector.get(interface)
@pytest.fixture()
@pytest.fixture
def injector() -> MockInjector:
return MockInjector()

View File

@ -6,7 +6,7 @@ import pytest
from fastapi.testclient import TestClient
@pytest.fixture()
@pytest.fixture
def file_path() -> str:
return "test.txt"