diff --git a/docs/modules/utils/combine_docs_examples/hyde.ipynb b/docs/modules/utils/combine_docs_examples/hyde.ipynb index 1efbe5f6537..b105873a820 100644 --- a/docs/modules/utils/combine_docs_examples/hyde.ipynb +++ b/docs/modules/utils/combine_docs_examples/hyde.ipynb @@ -21,8 +21,8 @@ "outputs": [], "source": [ "from langchain.llms import OpenAI\n", - "from langchain.embeddings import OpenAIEmbeddings, HypotheticalDocumentEmbedder\n", - "from langchain.chains import LLMChain\n", + "from langchain.embeddings import OpenAIEmbeddings\n", + "from langchain.chains import LLMChain, HypotheticalDocumentEmbedder\n", "from langchain.prompts import PromptTemplate" ] }, @@ -220,7 +220,7 @@ ], "metadata": { "kernelspec": { - "display_name": "Python 3 (ipykernel)", + "display_name": "llm-env", "language": "python", "name": "python3" }, @@ -234,7 +234,12 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.9" + "version": "3.9.0 (default, Nov 15 2020, 06:25:35) \n[Clang 10.0.0 ]" + }, + "vscode": { + "interpreter": { + "hash": "9dd01537e9ab68cf47cb0398488d182358f774f73101197b3bd1b5502c6ec7f9" + } } }, "nbformat": 4, diff --git a/langchain/chains/__init__.py b/langchain/chains/__init__.py index 0079d1e26d0..21bb09d9a4c 100644 --- a/langchain/chains/__init__.py +++ b/langchain/chains/__init__.py @@ -1,6 +1,7 @@ """Chains are easily reusable components which can be linked together.""" from langchain.chains.api.base import APIChain from langchain.chains.conversation.base import ConversationChain +from langchain.chains.hyde.base import HypotheticalDocumentEmbedder from langchain.chains.llm import LLMChain from langchain.chains.llm_bash.base import LLMBashChain from langchain.chains.llm_checker.base import LLMCheckerChain @@ -41,4 +42,5 @@ __all__ = [ "OpenAIModerationChain", "SQLDatabaseSequentialChain", "load_chain", + "HypotheticalDocumentEmbedder", ] diff --git a/langchain/embeddings/hyde/__init__.py b/langchain/chains/hyde/__init__.py similarity index 100% rename from langchain/embeddings/hyde/__init__.py rename to langchain/chains/hyde/__init__.py diff --git a/langchain/embeddings/hyde/base.py b/langchain/chains/hyde/base.py similarity index 73% rename from langchain/embeddings/hyde/base.py rename to langchain/chains/hyde/base.py index dbad3535c2d..fd043bbad62 100644 --- a/langchain/embeddings/hyde/base.py +++ b/langchain/chains/hyde/base.py @@ -4,18 +4,19 @@ https://arxiv.org/abs/2212.10496 """ from __future__ import annotations -from typing import List +from typing import Dict, List import numpy as np from pydantic import BaseModel, Extra +from langchain.chains.base import Chain +from langchain.chains.hyde.prompts import PROMPT_MAP from langchain.chains.llm import LLMChain from langchain.embeddings.base import Embeddings -from langchain.embeddings.hyde.prompts import PROMPT_MAP from langchain.llms.base import BaseLLM -class HypotheticalDocumentEmbedder(Embeddings, BaseModel): +class HypotheticalDocumentEmbedder(Chain, Embeddings, BaseModel): """Generate hypothetical document for query, and then embed that. Based on https://arxiv.org/abs/2212.10496 @@ -30,10 +31,24 @@ class HypotheticalDocumentEmbedder(Embeddings, BaseModel): extra = Extra.forbid arbitrary_types_allowed = True + @property + def input_keys(self) -> List[str]: + """Input keys for Hyde's LLM chain.""" + return self.llm_chain.input_keys + + @property + def output_keys(self) -> List[str]: + """Output keys for Hyde's LLM chain.""" + return self.llm_chain.output_keys + def embed_documents(self, texts: List[str]) -> List[List[float]]: """Call the base embeddings.""" return self.base_embeddings.embed_documents(texts) + def combine_embeddings(self, embeddings: List[List[float]]) -> List[float]: + """Combine embeddings into final embeddings.""" + return list(np.array(embeddings).mean(axis=0)) + def embed_query(self, text: str) -> List[float]: """Generate a hypothetical document and embedded it.""" var_name = self.llm_chain.input_keys[0] @@ -42,9 +57,9 @@ class HypotheticalDocumentEmbedder(Embeddings, BaseModel): embeddings = self.embed_documents(documents) return self.combine_embeddings(embeddings) - def combine_embeddings(self, embeddings: List[List[float]]) -> List[float]: - """Combine embeddings into final embeddings.""" - return list(np.array(embeddings).mean(axis=0)) + def _call(self, inputs: Dict[str, str]) -> Dict[str, str]: + """Call the internal llm chain.""" + return self.llm_chain._call(inputs) @classmethod def from_llm( diff --git a/langchain/embeddings/hyde/prompts.py b/langchain/chains/hyde/prompts.py similarity index 100% rename from langchain/embeddings/hyde/prompts.py rename to langchain/chains/hyde/prompts.py diff --git a/langchain/embeddings/__init__.py b/langchain/embeddings/__init__.py index 5a2019e8bf0..6a57deb1364 100644 --- a/langchain/embeddings/__init__.py +++ b/langchain/embeddings/__init__.py @@ -2,7 +2,6 @@ from langchain.embeddings.cohere import CohereEmbeddings from langchain.embeddings.huggingface import HuggingFaceEmbeddings from langchain.embeddings.huggingface_hub import HuggingFaceHubEmbeddings -from langchain.embeddings.hyde.base import HypotheticalDocumentEmbedder from langchain.embeddings.openai import OpenAIEmbeddings __all__ = [ @@ -10,5 +9,4 @@ __all__ = [ "HuggingFaceEmbeddings", "CohereEmbeddings", "HuggingFaceHubEmbeddings", - "HypotheticalDocumentEmbedder", ] diff --git a/tests/unit_tests/test_hyde.py b/tests/unit_tests/chains/test_hyde.py similarity index 92% rename from tests/unit_tests/test_hyde.py rename to tests/unit_tests/chains/test_hyde.py index 91b7bb550dc..b609a6295b8 100644 --- a/tests/unit_tests/test_hyde.py +++ b/tests/unit_tests/chains/test_hyde.py @@ -4,9 +4,9 @@ from typing import List, Optional import numpy as np from pydantic import BaseModel +from langchain.chains.hyde.base import HypotheticalDocumentEmbedder +from langchain.chains.hyde.prompts import PROMPT_MAP from langchain.embeddings.base import Embeddings -from langchain.embeddings.hyde.base import HypotheticalDocumentEmbedder -from langchain.embeddings.hyde.prompts import PROMPT_MAP from langchain.llms.base import BaseLLM from langchain.schema import Generation, LLMResult