new round of PR reviews

This commit is contained in:
Juan
2023-12-20 11:51:08 -03:00
parent 4c51eeb263
commit f06a8b4ba1

View File

@@ -19,7 +19,7 @@ from private_gpt.components.vector_store.vector_store_component import (
 )
 from private_gpt.open_ai.extensions.context_filter import ContextFilter
 from private_gpt.server.chunks.chunks_service import Chunk
-from private_gpt.settings.settings import settings
+from private_gpt.settings.settings import Settings

 class Completion(BaseModel):
@@ -75,6 +75,7 @@ class ChatService:
         vector_store_component: VectorStoreComponent,
         embedding_component: EmbeddingComponent,
         node_store_component: NodeStoreComponent,
+        settings: Settings,
     ) -> None:
         self.llm_service = llm_component
         self.vector_store_component = vector_store_component
@@ -92,17 +93,19 @@ class ChatService:
             service_context=self.service_context,
             show_progress=True,
         )
+        self.default_context_template = settings.rag.default_context_template

     def _chat_engine(
         self,
         system_prompt: str | None = None,
         use_context: bool = False,
         context_filter: ContextFilter | None = None,
-        context_template: str | None = None,
     ) -> BaseChatEngine:
         if use_context:
-            if context_template is None:
-                context_template = settings().rag.default_context_template
+            if self.default_context_template is not None:
+                context_template = self.default_context_template
+            else:
+                context_template = None
             vector_index_retriever = self.vector_store_component.get_retriever(
                 index=self.index, context_filter=context_filter
             )
@@ -126,7 +129,6 @@ class ChatService:
         messages: list[ChatMessage],
         use_context: bool = False,
         context_filter: ContextFilter | None = None,
-        context_template: str | None = None,
     ) -> CompletionGen:
         chat_engine_input = ChatEngineInput.from_messages(messages)
         last_message = (
@@ -147,7 +149,6 @@ class ChatService:
             system_prompt=system_prompt,
             use_context=use_context,
             context_filter=context_filter,
-            context_template=context_template,
         )
         streaming_response = chat_engine.stream_chat(
             message=last_message if last_message is not None else "",
@@ -164,7 +165,6 @@ class ChatService:
         messages: list[ChatMessage],
         use_context: bool = False,
         context_filter: ContextFilter | None = None,
-        context_template: str | None = None,
     ) -> Completion:
         chat_engine_input = ChatEngineInput.from_messages(messages)
         last_message = (
@@ -185,7 +185,6 @@ class ChatService:
             system_prompt=system_prompt,
             use_context=use_context,
             context_filter=context_filter,
-            context_template=context_template,
         )
         wrapped_response = chat_engine.chat(
             message=last_message if last_message is not None else "",