move hyde into chains (#728)

Co-authored-by: scadEfUr <>
This commit is contained in:
scadEfUr 2023-01-24 22:23:32 -08:00 committed by GitHub
parent 0ffeabd14f
commit e3df8ab6dc
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 34 additions and 14 deletions

View File

@ -21,8 +21,8 @@
"outputs": [], "outputs": [],
"source": [ "source": [
"from langchain.llms import OpenAI\n", "from langchain.llms import OpenAI\n",
"from langchain.embeddings import OpenAIEmbeddings, HypotheticalDocumentEmbedder\n", "from langchain.embeddings import OpenAIEmbeddings\n",
"from langchain.chains import LLMChain\n", "from langchain.chains import LLMChain, HypotheticalDocumentEmbedder\n",
"from langchain.prompts import PromptTemplate" "from langchain.prompts import PromptTemplate"
] ]
}, },
@ -220,7 +220,7 @@
], ],
"metadata": { "metadata": {
"kernelspec": { "kernelspec": {
"display_name": "Python 3 (ipykernel)", "display_name": "llm-env",
"language": "python", "language": "python",
"name": "python3" "name": "python3"
}, },
@ -234,7 +234,12 @@
"name": "python", "name": "python",
"nbconvert_exporter": "python", "nbconvert_exporter": "python",
"pygments_lexer": "ipython3", "pygments_lexer": "ipython3",
"version": "3.10.9" "version": "3.9.0 (default, Nov 15 2020, 06:25:35) \n[Clang 10.0.0 ]"
},
"vscode": {
"interpreter": {
"hash": "9dd01537e9ab68cf47cb0398488d182358f774f73101197b3bd1b5502c6ec7f9"
}
} }
}, },
"nbformat": 4, "nbformat": 4,

View File

@ -1,6 +1,7 @@
"""Chains are easily reusable components which can be linked together.""" """Chains are easily reusable components which can be linked together."""
from langchain.chains.api.base import APIChain from langchain.chains.api.base import APIChain
from langchain.chains.conversation.base import ConversationChain from langchain.chains.conversation.base import ConversationChain
from langchain.chains.hyde.base import HypotheticalDocumentEmbedder
from langchain.chains.llm import LLMChain from langchain.chains.llm import LLMChain
from langchain.chains.llm_bash.base import LLMBashChain from langchain.chains.llm_bash.base import LLMBashChain
from langchain.chains.llm_checker.base import LLMCheckerChain from langchain.chains.llm_checker.base import LLMCheckerChain
@ -41,4 +42,5 @@ __all__ = [
"OpenAIModerationChain", "OpenAIModerationChain",
"SQLDatabaseSequentialChain", "SQLDatabaseSequentialChain",
"load_chain", "load_chain",
"HypotheticalDocumentEmbedder",
] ]

View File

@ -4,18 +4,19 @@ https://arxiv.org/abs/2212.10496
""" """
from __future__ import annotations from __future__ import annotations
from typing import List from typing import Dict, List
import numpy as np import numpy as np
from pydantic import BaseModel, Extra from pydantic import BaseModel, Extra
from langchain.chains.base import Chain
from langchain.chains.hyde.prompts import PROMPT_MAP
from langchain.chains.llm import LLMChain from langchain.chains.llm import LLMChain
from langchain.embeddings.base import Embeddings from langchain.embeddings.base import Embeddings
from langchain.embeddings.hyde.prompts import PROMPT_MAP
from langchain.llms.base import BaseLLM from langchain.llms.base import BaseLLM
class HypotheticalDocumentEmbedder(Embeddings, BaseModel): class HypotheticalDocumentEmbedder(Chain, Embeddings, BaseModel):
"""Generate hypothetical document for query, and then embed that. """Generate hypothetical document for query, and then embed that.
Based on https://arxiv.org/abs/2212.10496 Based on https://arxiv.org/abs/2212.10496
@ -30,10 +31,24 @@ class HypotheticalDocumentEmbedder(Embeddings, BaseModel):
extra = Extra.forbid extra = Extra.forbid
arbitrary_types_allowed = True arbitrary_types_allowed = True
@property
def input_keys(self) -> List[str]:
"""Input keys for Hyde's LLM chain."""
return self.llm_chain.input_keys
@property
def output_keys(self) -> List[str]:
"""Output keys for Hyde's LLM chain."""
return self.llm_chain.output_keys
def embed_documents(self, texts: List[str]) -> List[List[float]]: def embed_documents(self, texts: List[str]) -> List[List[float]]:
"""Call the base embeddings.""" """Call the base embeddings."""
return self.base_embeddings.embed_documents(texts) return self.base_embeddings.embed_documents(texts)
def combine_embeddings(self, embeddings: List[List[float]]) -> List[float]:
"""Combine embeddings into final embeddings."""
return list(np.array(embeddings).mean(axis=0))
def embed_query(self, text: str) -> List[float]: def embed_query(self, text: str) -> List[float]:
"""Generate a hypothetical document and embed it.""" """Generate a hypothetical document and embed it."""
var_name = self.llm_chain.input_keys[0] var_name = self.llm_chain.input_keys[0]
@ -42,9 +57,9 @@ class HypotheticalDocumentEmbedder(Embeddings, BaseModel):
embeddings = self.embed_documents(documents) embeddings = self.embed_documents(documents)
return self.combine_embeddings(embeddings) return self.combine_embeddings(embeddings)
def combine_embeddings(self, embeddings: List[List[float]]) -> List[float]: def _call(self, inputs: Dict[str, str]) -> Dict[str, str]:
"""Combine embeddings into final embeddings.""" """Call the internal llm chain."""
return list(np.array(embeddings).mean(axis=0)) return self.llm_chain._call(inputs)
@classmethod @classmethod
def from_llm( def from_llm(

View File

@ -2,7 +2,6 @@
from langchain.embeddings.cohere import CohereEmbeddings from langchain.embeddings.cohere import CohereEmbeddings
from langchain.embeddings.huggingface import HuggingFaceEmbeddings from langchain.embeddings.huggingface import HuggingFaceEmbeddings
from langchain.embeddings.huggingface_hub import HuggingFaceHubEmbeddings from langchain.embeddings.huggingface_hub import HuggingFaceHubEmbeddings
from langchain.embeddings.hyde.base import HypotheticalDocumentEmbedder
from langchain.embeddings.openai import OpenAIEmbeddings from langchain.embeddings.openai import OpenAIEmbeddings
__all__ = [ __all__ = [
@ -10,5 +9,4 @@ __all__ = [
"HuggingFaceEmbeddings", "HuggingFaceEmbeddings",
"CohereEmbeddings", "CohereEmbeddings",
"HuggingFaceHubEmbeddings", "HuggingFaceHubEmbeddings",
"HypotheticalDocumentEmbedder",
] ]

View File

@ -4,9 +4,9 @@ from typing import List, Optional
import numpy as np import numpy as np
from pydantic import BaseModel from pydantic import BaseModel
from langchain.chains.hyde.base import HypotheticalDocumentEmbedder
from langchain.chains.hyde.prompts import PROMPT_MAP
from langchain.embeddings.base import Embeddings from langchain.embeddings.base import Embeddings
from langchain.embeddings.hyde.base import HypotheticalDocumentEmbedder
from langchain.embeddings.hyde.prompts import PROMPT_MAP
from langchain.llms.base import BaseLLM from langchain.llms.base import BaseLLM
from langchain.schema import Generation, LLMResult from langchain.schema import Generation, LLMResult