mirror of
https://github.com/imartinez/privateGPT.git
synced 2025-09-24 12:39:07 +00:00
new round of PR reviews
This commit is contained in:
@@ -19,7 +19,7 @@ from private_gpt.components.vector_store.vector_store_component import (
|
|||||||
)
|
)
|
||||||
from private_gpt.open_ai.extensions.context_filter import ContextFilter
|
from private_gpt.open_ai.extensions.context_filter import ContextFilter
|
||||||
from private_gpt.server.chunks.chunks_service import Chunk
|
from private_gpt.server.chunks.chunks_service import Chunk
|
||||||
from private_gpt.settings.settings import settings
|
from private_gpt.settings.settings import Settings
|
||||||
|
|
||||||
|
|
||||||
class Completion(BaseModel):
|
class Completion(BaseModel):
|
||||||
@@ -75,6 +75,7 @@ class ChatService:
|
|||||||
vector_store_component: VectorStoreComponent,
|
vector_store_component: VectorStoreComponent,
|
||||||
embedding_component: EmbeddingComponent,
|
embedding_component: EmbeddingComponent,
|
||||||
node_store_component: NodeStoreComponent,
|
node_store_component: NodeStoreComponent,
|
||||||
|
settings: Settings,
|
||||||
) -> None:
|
) -> None:
|
||||||
self.llm_service = llm_component
|
self.llm_service = llm_component
|
||||||
self.vector_store_component = vector_store_component
|
self.vector_store_component = vector_store_component
|
||||||
@@ -92,17 +93,19 @@ class ChatService:
|
|||||||
service_context=self.service_context,
|
service_context=self.service_context,
|
||||||
show_progress=True,
|
show_progress=True,
|
||||||
)
|
)
|
||||||
|
self.default_context_template = settings.rag.default_context_template
|
||||||
|
|
||||||
def _chat_engine(
|
def _chat_engine(
|
||||||
self,
|
self,
|
||||||
system_prompt: str | None = None,
|
system_prompt: str | None = None,
|
||||||
use_context: bool = False,
|
use_context: bool = False,
|
||||||
context_filter: ContextFilter | None = None,
|
context_filter: ContextFilter | None = None,
|
||||||
context_template: str | None = None,
|
|
||||||
) -> BaseChatEngine:
|
) -> BaseChatEngine:
|
||||||
if use_context:
|
if use_context:
|
||||||
if context_template is None:
|
if self.default_context_template is not None:
|
||||||
context_template = settings().rag.default_context_template
|
context_template = self.default_context_template
|
||||||
|
else:
|
||||||
|
context_template = None
|
||||||
vector_index_retriever = self.vector_store_component.get_retriever(
|
vector_index_retriever = self.vector_store_component.get_retriever(
|
||||||
index=self.index, context_filter=context_filter
|
index=self.index, context_filter=context_filter
|
||||||
)
|
)
|
||||||
@@ -126,7 +129,6 @@ class ChatService:
|
|||||||
messages: list[ChatMessage],
|
messages: list[ChatMessage],
|
||||||
use_context: bool = False,
|
use_context: bool = False,
|
||||||
context_filter: ContextFilter | None = None,
|
context_filter: ContextFilter | None = None,
|
||||||
context_template: str | None = None,
|
|
||||||
) -> CompletionGen:
|
) -> CompletionGen:
|
||||||
chat_engine_input = ChatEngineInput.from_messages(messages)
|
chat_engine_input = ChatEngineInput.from_messages(messages)
|
||||||
last_message = (
|
last_message = (
|
||||||
@@ -147,7 +149,6 @@ class ChatService:
|
|||||||
system_prompt=system_prompt,
|
system_prompt=system_prompt,
|
||||||
use_context=use_context,
|
use_context=use_context,
|
||||||
context_filter=context_filter,
|
context_filter=context_filter,
|
||||||
context_template=context_template,
|
|
||||||
)
|
)
|
||||||
streaming_response = chat_engine.stream_chat(
|
streaming_response = chat_engine.stream_chat(
|
||||||
message=last_message if last_message is not None else "",
|
message=last_message if last_message is not None else "",
|
||||||
@@ -164,7 +165,6 @@ class ChatService:
|
|||||||
messages: list[ChatMessage],
|
messages: list[ChatMessage],
|
||||||
use_context: bool = False,
|
use_context: bool = False,
|
||||||
context_filter: ContextFilter | None = None,
|
context_filter: ContextFilter | None = None,
|
||||||
context_template: str | None = None,
|
|
||||||
) -> Completion:
|
) -> Completion:
|
||||||
chat_engine_input = ChatEngineInput.from_messages(messages)
|
chat_engine_input = ChatEngineInput.from_messages(messages)
|
||||||
last_message = (
|
last_message = (
|
||||||
@@ -185,7 +185,6 @@ class ChatService:
|
|||||||
system_prompt=system_prompt,
|
system_prompt=system_prompt,
|
||||||
use_context=use_context,
|
use_context=use_context,
|
||||||
context_filter=context_filter,
|
context_filter=context_filter,
|
||||||
context_template=context_template,
|
|
||||||
)
|
)
|
||||||
wrapped_response = chat_engine.chat(
|
wrapped_response = chat_engine.chat(
|
||||||
message=last_message if last_message is not None else "",
|
message=last_message if last_message is not None else "",
|
||||||
|
Reference in New Issue
Block a user