From 9c0ad0cebbd8bf0d8fcacdf7fa6afa71973eb6e1 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Yves=20Zumb=C3=BChl?=
Date: Thu, 30 Nov 2023 02:40:53 +0900
Subject: [PATCH] langchain[patch]: Improve HyDe with custom prompts and
 ability to supply the run_manager (#14016)

- **Description:** The class allows to only select between a few
  predefined prompts from the paper. That is not ideal, since other use
  cases might need a custom prompt. The changes made allow for this. To
  be able to monitor those, I also added functionality to supply a
  custom run_manager.
- **Issue:** no issue, but a new feature,
- **Dependencies:** none,
- **Tag maintainer:** @hwchase17,
- **Twitter handle:** @yvesloy

---------

Co-authored-by: Bagatur
---
 libs/langchain/langchain/chains/hyde/base.py | 17 ++++++++++++++---
 1 file changed, 14 insertions(+), 3 deletions(-)

diff --git a/libs/langchain/langchain/chains/hyde/base.py b/libs/langchain/langchain/chains/hyde/base.py
index 0d633246cd3..a9573f6b068 100644
--- a/libs/langchain/langchain/chains/hyde/base.py
+++ b/libs/langchain/langchain/chains/hyde/base.py
@@ -9,6 +9,7 @@ from typing import Any, Dict, List, Optional
 import numpy as np
 from langchain_core.embeddings import Embeddings
 from langchain_core.language_models import BaseLanguageModel
+from langchain_core.prompts import BasePromptTemplate
 from langchain_core.pydantic_v1 import Extra
 
 from langchain.callbacks.manager import CallbackManagerForChainRun
@@ -72,11 +73,21 @@ class HypotheticalDocumentEmbedder(Chain, Embeddings):
         cls,
         llm: BaseLanguageModel,
         base_embeddings: Embeddings,
-        prompt_key: str,
+        prompt_key: Optional[str] = None,
+        custom_prompt: Optional[BasePromptTemplate] = None,
         **kwargs: Any,
     ) -> HypotheticalDocumentEmbedder:
-        """Load and use LLMChain for a specific prompt key."""
-        prompt = PROMPT_MAP[prompt_key]
+        """Load and use LLMChain with either a specific prompt key or custom prompt."""
+        if custom_prompt is not None:
+            prompt = custom_prompt
+        elif prompt_key is not None and prompt_key in PROMPT_MAP:
+            prompt = PROMPT_MAP[prompt_key]
+        else:
+            raise ValueError(
+                f"Must specify prompt_key if custom_prompt not provided. Should be one "
+                f"of {list(PROMPT_MAP.keys())}."
+            )
+
         llm_chain = LLMChain(llm=llm, prompt=prompt)
         return cls(base_embeddings=base_embeddings, llm_chain=llm_chain, **kwargs)
 