mirror of
https://github.com/imartinez/privateGPT.git
synced 2025-06-25 15:01:52 +00:00
fix: mypy
This commit is contained in:
parent
7c99878576
commit
dbbe5bcf07
@ -403,7 +403,7 @@ class PipelineIngestComponent(BaseIngestComponentWithIndex):
|
|||||||
self.transformations,
|
self.transformations,
|
||||||
show_progress=self.show_progress,
|
show_progress=self.show_progress,
|
||||||
)
|
)
|
||||||
self.node_q.put(("process", file_name, documents, nodes))
|
self.node_q.put(("process", file_name, documents, list(nodes)))
|
||||||
finally:
|
finally:
|
||||||
self.doc_semaphore.release()
|
self.doc_semaphore.release()
|
||||||
self.doc_q.task_done() # unblock Q joins
|
self.doc_q.task_done() # unblock Q joins
|
||||||
|
@ -120,7 +120,6 @@ class LLMComponent:
|
|||||||
api_version="",
|
api_version="",
|
||||||
temperature=settings.llm.temperature,
|
temperature=settings.llm.temperature,
|
||||||
context_window=settings.llm.context_window,
|
context_window=settings.llm.context_window,
|
||||||
max_new_tokens=settings.llm.max_new_tokens,
|
|
||||||
messages_to_prompt=prompt_style.messages_to_prompt,
|
messages_to_prompt=prompt_style.messages_to_prompt,
|
||||||
completion_to_prompt=prompt_style.completion_to_prompt,
|
completion_to_prompt=prompt_style.completion_to_prompt,
|
||||||
tokenizer=settings.llm.tokenizer,
|
tokenizer=settings.llm.tokenizer,
|
||||||
@ -184,10 +183,10 @@ class LLMComponent:
|
|||||||
|
|
||||||
return wrapper
|
return wrapper
|
||||||
|
|
||||||
Ollama.chat = add_keep_alive(Ollama.chat)
|
Ollama.chat = add_keep_alive(Ollama.chat) # type: ignore
|
||||||
Ollama.stream_chat = add_keep_alive(Ollama.stream_chat)
|
Ollama.stream_chat = add_keep_alive(Ollama.stream_chat) # type: ignore
|
||||||
Ollama.complete = add_keep_alive(Ollama.complete)
|
Ollama.complete = add_keep_alive(Ollama.complete) # type: ignore
|
||||||
Ollama.stream_complete = add_keep_alive(Ollama.stream_complete)
|
Ollama.stream_complete = add_keep_alive(Ollama.stream_complete) # type: ignore
|
||||||
|
|
||||||
self.llm = llm
|
self.llm = llm
|
||||||
|
|
||||||
|
@ -40,7 +40,8 @@ class AbstractPromptStyle(abc.ABC):
|
|||||||
logger.debug("Got for messages='%s' the prompt='%s'", messages, prompt)
|
logger.debug("Got for messages='%s' the prompt='%s'", messages, prompt)
|
||||||
return prompt
|
return prompt
|
||||||
|
|
||||||
def completion_to_prompt(self, completion: str) -> str:
|
def completion_to_prompt(self, prompt: str) -> str:
|
||||||
|
completion = prompt # Fix: Llama-index parameter has to be named as prompt
|
||||||
prompt = self._completion_to_prompt(completion)
|
prompt = self._completion_to_prompt(completion)
|
||||||
logger.debug("Got for completion='%s' the prompt='%s'", completion, prompt)
|
logger.debug("Got for completion='%s' the prompt='%s'", completion, prompt)
|
||||||
return prompt
|
return prompt
|
||||||
|
@ -38,10 +38,10 @@ class NodeStoreComponent:
|
|||||||
|
|
||||||
case "postgres":
|
case "postgres":
|
||||||
try:
|
try:
|
||||||
from llama_index.core.storage.docstore.postgres_docstore import (
|
from llama_index.storage.docstore.postgres import (
|
||||||
PostgresDocumentStore,
|
PostgresDocumentStore,
|
||||||
)
|
)
|
||||||
from llama_index.core.storage.index_store.postgres_index_store import (
|
from llama_index.storage.index_store.postgres import (
|
||||||
PostgresIndexStore,
|
PostgresIndexStore,
|
||||||
)
|
)
|
||||||
except ImportError:
|
except ImportError:
|
||||||
@ -55,6 +55,7 @@ class NodeStoreComponent:
|
|||||||
self.index_store = PostgresIndexStore.from_params(
|
self.index_store = PostgresIndexStore.from_params(
|
||||||
**settings.postgres.model_dump(exclude_none=True)
|
**settings.postgres.model_dump(exclude_none=True)
|
||||||
)
|
)
|
||||||
|
|
||||||
self.doc_store = PostgresDocumentStore.from_params(
|
self.doc_store = PostgresDocumentStore.from_params(
|
||||||
**settings.postgres.model_dump(exclude_none=True)
|
**settings.postgres.model_dump(exclude_none=True)
|
||||||
)
|
)
|
||||||
|
@ -1,14 +1,17 @@
|
|||||||
from collections.abc import Generator
|
from collections.abc import Generator, Sequence
|
||||||
from typing import Any
|
from typing import TYPE_CHECKING, Any
|
||||||
|
|
||||||
from llama_index.core.schema import BaseNode, MetadataMode
|
from llama_index.core.schema import BaseNode, MetadataMode
|
||||||
from llama_index.core.vector_stores.utils import node_to_metadata_dict
|
from llama_index.core.vector_stores.utils import node_to_metadata_dict
|
||||||
from llama_index.vector_stores.chroma import ChromaVectorStore # type: ignore
|
from llama_index.vector_stores.chroma import ChromaVectorStore # type: ignore
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from collections.abc import Mapping
|
||||||
|
|
||||||
|
|
||||||
def chunk_list(
|
def chunk_list(
|
||||||
lst: list[BaseNode], max_chunk_size: int
|
lst: Sequence[BaseNode], max_chunk_size: int
|
||||||
) -> Generator[list[BaseNode], None, None]:
|
) -> Generator[Sequence[BaseNode], None, None]:
|
||||||
"""Yield successive max_chunk_size-sized chunks from lst.
|
"""Yield successive max_chunk_size-sized chunks from lst.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -60,7 +63,7 @@ class BatchedChromaVectorStore(ChromaVectorStore): # type: ignore
|
|||||||
)
|
)
|
||||||
self.chroma_client = chroma_client
|
self.chroma_client = chroma_client
|
||||||
|
|
||||||
def add(self, nodes: list[BaseNode], **add_kwargs: Any) -> list[str]:
|
def add(self, nodes: Sequence[BaseNode], **add_kwargs: Any) -> list[str]:
|
||||||
"""Add nodes to index, batching the insertion to avoid issues.
|
"""Add nodes to index, batching the insertion to avoid issues.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -78,8 +81,8 @@ class BatchedChromaVectorStore(ChromaVectorStore): # type: ignore
|
|||||||
|
|
||||||
all_ids = []
|
all_ids = []
|
||||||
for node_chunk in node_chunks:
|
for node_chunk in node_chunks:
|
||||||
embeddings = []
|
embeddings: list[Sequence[float]] = []
|
||||||
metadatas = []
|
metadatas: list[Mapping[str, Any]] = []
|
||||||
ids = []
|
ids = []
|
||||||
documents = []
|
documents = []
|
||||||
for node in node_chunk:
|
for node in node_chunk:
|
||||||
|
@ -1,4 +1,5 @@
|
|||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
from injector import inject, singleton
|
from injector import inject, singleton
|
||||||
from llama_index.core.chat_engine import ContextChatEngine, SimpleChatEngine
|
from llama_index.core.chat_engine import ContextChatEngine, SimpleChatEngine
|
||||||
@ -26,6 +27,9 @@ from private_gpt.open_ai.extensions.context_filter import ContextFilter
|
|||||||
from private_gpt.server.chunks.chunks_service import Chunk
|
from private_gpt.server.chunks.chunks_service import Chunk
|
||||||
from private_gpt.settings.settings import Settings
|
from private_gpt.settings.settings import Settings
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from llama_index.core.postprocessor.types import BaseNodePostprocessor
|
||||||
|
|
||||||
|
|
||||||
class Completion(BaseModel):
|
class Completion(BaseModel):
|
||||||
response: str
|
response: str
|
||||||
@ -114,12 +118,15 @@ class ChatService:
|
|||||||
context_filter=context_filter,
|
context_filter=context_filter,
|
||||||
similarity_top_k=self.settings.rag.similarity_top_k,
|
similarity_top_k=self.settings.rag.similarity_top_k,
|
||||||
)
|
)
|
||||||
node_postprocessors = [
|
node_postprocessors: list[BaseNodePostprocessor] = [
|
||||||
MetadataReplacementPostProcessor(target_metadata_key="window"),
|
MetadataReplacementPostProcessor(target_metadata_key="window"),
|
||||||
|
]
|
||||||
|
if settings.rag.similarity_value:
|
||||||
|
node_postprocessors.append(
|
||||||
SimilarityPostprocessor(
|
SimilarityPostprocessor(
|
||||||
similarity_cutoff=settings.rag.similarity_value
|
similarity_cutoff=settings.rag.similarity_value
|
||||||
),
|
)
|
||||||
]
|
)
|
||||||
|
|
||||||
if settings.rag.rerank.enabled:
|
if settings.rag.rerank.enabled:
|
||||||
rerank_postprocessor = SentenceTransformerRerank(
|
rerank_postprocessor = SentenceTransformerRerank(
|
||||||
|
2
tests/fixtures/fast_api_test_client.py
vendored
2
tests/fixtures/fast_api_test_client.py
vendored
@ -5,7 +5,7 @@ from private_gpt.launcher import create_app
|
|||||||
from tests.fixtures.mock_injector import MockInjector
|
from tests.fixtures.mock_injector import MockInjector
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture()
|
@pytest.fixture
|
||||||
def test_client(request: pytest.FixtureRequest, injector: MockInjector) -> TestClient:
|
def test_client(request: pytest.FixtureRequest, injector: MockInjector) -> TestClient:
|
||||||
if request is not None and hasattr(request, "param"):
|
if request is not None and hasattr(request, "param"):
|
||||||
injector.bind_settings(request.param or {})
|
injector.bind_settings(request.param or {})
|
||||||
|
2
tests/fixtures/ingest_helper.py
vendored
2
tests/fixtures/ingest_helper.py
vendored
@ -19,6 +19,6 @@ class IngestHelper:
|
|||||||
return ingest_result
|
return ingest_result
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture()
|
@pytest.fixture
|
||||||
def ingest_helper(test_client: TestClient) -> IngestHelper:
|
def ingest_helper(test_client: TestClient) -> IngestHelper:
|
||||||
return IngestHelper(test_client)
|
return IngestHelper(test_client)
|
||||||
|
2
tests/fixtures/mock_injector.py
vendored
2
tests/fixtures/mock_injector.py
vendored
@ -37,6 +37,6 @@ class MockInjector:
|
|||||||
return self.test_injector.get(interface)
|
return self.test_injector.get(interface)
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture()
|
@pytest.fixture
|
||||||
def injector() -> MockInjector:
|
def injector() -> MockInjector:
|
||||||
return MockInjector()
|
return MockInjector()
|
||||||
|
@ -6,7 +6,7 @@ import pytest
|
|||||||
from fastapi.testclient import TestClient
|
from fastapi.testclient import TestClient
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture()
|
@pytest.fixture
|
||||||
def file_path() -> str:
|
def file_path() -> str:
|
||||||
return "test.txt"
|
return "test.txt"
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user