diff --git a/private_gpt/server/chat/chat_service.py b/private_gpt/server/chat/chat_service.py index da6f494e..cd1b5041 100644 --- a/private_gpt/server/chat/chat_service.py +++ b/private_gpt/server/chat/chat_service.py @@ -19,7 +19,7 @@ from private_gpt.components.vector_store.vector_store_component import ( ) from private_gpt.open_ai.extensions.context_filter import ContextFilter from private_gpt.server.chunks.chunks_service import Chunk -from private_gpt.settings.settings import settings +from private_gpt.settings.settings import Settings class Completion(BaseModel): @@ -75,6 +75,7 @@ class ChatService: vector_store_component: VectorStoreComponent, embedding_component: EmbeddingComponent, node_store_component: NodeStoreComponent, + settings: Settings, ) -> None: self.llm_service = llm_component self.vector_store_component = vector_store_component @@ -92,17 +93,19 @@ class ChatService: service_context=self.service_context, show_progress=True, ) + self.default_context_template = settings.rag.default_context_template def _chat_engine( self, system_prompt: str | None = None, use_context: bool = False, context_filter: ContextFilter | None = None, - context_template: str | None = None, ) -> BaseChatEngine: if use_context: - if context_template is None: - context_template = settings().rag.default_context_template + if self.default_context_template is not None: + context_template = self.default_context_template + else: + context_template = None vector_index_retriever = self.vector_store_component.get_retriever( index=self.index, context_filter=context_filter ) @@ -126,7 +129,6 @@ class ChatService: messages: list[ChatMessage], use_context: bool = False, context_filter: ContextFilter | None = None, - context_template: str | None = None, ) -> CompletionGen: chat_engine_input = ChatEngineInput.from_messages(messages) last_message = ( @@ -147,7 +149,6 @@ class ChatService: system_prompt=system_prompt, use_context=use_context, context_filter=context_filter, - context_template=context_template, ) streaming_response = chat_engine.stream_chat( message=last_message if last_message is not None else "", @@ -164,7 +165,6 @@ class ChatService: messages: list[ChatMessage], use_context: bool = False, context_filter: ContextFilter | None = None, - context_template: str | None = None, ) -> Completion: chat_engine_input = ChatEngineInput.from_messages(messages) last_message = ( @@ -185,7 +185,6 @@ class ChatService: system_prompt=system_prompt, use_context=use_context, context_filter=context_filter, - context_template=context_template, ) wrapped_response = chat_engine.chat( message=last_message if last_message is not None else "",