mirror of
https://github.com/hwchase17/langchain.git
synced 2025-08-03 18:24:10 +00:00
parent
0ffeabd14f
commit
e3df8ab6dc
@ -21,8 +21,8 @@
|
|||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"from langchain.llms import OpenAI\n",
|
"from langchain.llms import OpenAI\n",
|
||||||
"from langchain.embeddings import OpenAIEmbeddings, HypotheticalDocumentEmbedder\n",
|
"from langchain.embeddings import OpenAIEmbeddings\n",
|
||||||
"from langchain.chains import LLMChain\n",
|
"from langchain.chains import LLMChain, HypotheticalDocumentEmbedder\n",
|
||||||
"from langchain.prompts import PromptTemplate"
|
"from langchain.prompts import PromptTemplate"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
@ -220,7 +220,7 @@
|
|||||||
],
|
],
|
||||||
"metadata": {
|
"metadata": {
|
||||||
"kernelspec": {
|
"kernelspec": {
|
||||||
"display_name": "Python 3 (ipykernel)",
|
"display_name": "llm-env",
|
||||||
"language": "python",
|
"language": "python",
|
||||||
"name": "python3"
|
"name": "python3"
|
||||||
},
|
},
|
||||||
@ -234,7 +234,12 @@
|
|||||||
"name": "python",
|
"name": "python",
|
||||||
"nbconvert_exporter": "python",
|
"nbconvert_exporter": "python",
|
||||||
"pygments_lexer": "ipython3",
|
"pygments_lexer": "ipython3",
|
||||||
"version": "3.10.9"
|
"version": "3.9.0 (default, Nov 15 2020, 06:25:35) \n[Clang 10.0.0 ]"
|
||||||
|
},
|
||||||
|
"vscode": {
|
||||||
|
"interpreter": {
|
||||||
|
"hash": "9dd01537e9ab68cf47cb0398488d182358f774f73101197b3bd1b5502c6ec7f9"
|
||||||
|
}
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"nbformat": 4,
|
"nbformat": 4,
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
"""Chains are easily reusable components which can be linked together."""
|
"""Chains are easily reusable components which can be linked together."""
|
||||||
from langchain.chains.api.base import APIChain
|
from langchain.chains.api.base import APIChain
|
||||||
from langchain.chains.conversation.base import ConversationChain
|
from langchain.chains.conversation.base import ConversationChain
|
||||||
|
from langchain.chains.hyde.base import HypotheticalDocumentEmbedder
|
||||||
from langchain.chains.llm import LLMChain
|
from langchain.chains.llm import LLMChain
|
||||||
from langchain.chains.llm_bash.base import LLMBashChain
|
from langchain.chains.llm_bash.base import LLMBashChain
|
||||||
from langchain.chains.llm_checker.base import LLMCheckerChain
|
from langchain.chains.llm_checker.base import LLMCheckerChain
|
||||||
@ -41,4 +42,5 @@ __all__ = [
|
|||||||
"OpenAIModerationChain",
|
"OpenAIModerationChain",
|
||||||
"SQLDatabaseSequentialChain",
|
"SQLDatabaseSequentialChain",
|
||||||
"load_chain",
|
"load_chain",
|
||||||
|
"HypotheticalDocumentEmbedder",
|
||||||
]
|
]
|
||||||
|
@ -4,18 +4,19 @@ https://arxiv.org/abs/2212.10496
|
|||||||
"""
|
"""
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from typing import List
|
from typing import Dict, List
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from pydantic import BaseModel, Extra
|
from pydantic import BaseModel, Extra
|
||||||
|
|
||||||
|
from langchain.chains.base import Chain
|
||||||
|
from langchain.chains.hyde.prompts import PROMPT_MAP
|
||||||
from langchain.chains.llm import LLMChain
|
from langchain.chains.llm import LLMChain
|
||||||
from langchain.embeddings.base import Embeddings
|
from langchain.embeddings.base import Embeddings
|
||||||
from langchain.embeddings.hyde.prompts import PROMPT_MAP
|
|
||||||
from langchain.llms.base import BaseLLM
|
from langchain.llms.base import BaseLLM
|
||||||
|
|
||||||
|
|
||||||
class HypotheticalDocumentEmbedder(Embeddings, BaseModel):
|
class HypotheticalDocumentEmbedder(Chain, Embeddings, BaseModel):
|
||||||
"""Generate hypothetical document for query, and then embed that.
|
"""Generate hypothetical document for query, and then embed that.
|
||||||
|
|
||||||
Based on https://arxiv.org/abs/2212.10496
|
Based on https://arxiv.org/abs/2212.10496
|
||||||
@ -30,10 +31,24 @@ class HypotheticalDocumentEmbedder(Embeddings, BaseModel):
|
|||||||
extra = Extra.forbid
|
extra = Extra.forbid
|
||||||
arbitrary_types_allowed = True
|
arbitrary_types_allowed = True
|
||||||
|
|
||||||
|
@property
|
||||||
|
def input_keys(self) -> List[str]:
|
||||||
|
"""Input keys for Hyde's LLM chain."""
|
||||||
|
return self.llm_chain.input_keys
|
||||||
|
|
||||||
|
@property
|
||||||
|
def output_keys(self) -> List[str]:
|
||||||
|
"""Output keys for Hyde's LLM chain."""
|
||||||
|
return self.llm_chain.output_keys
|
||||||
|
|
||||||
def embed_documents(self, texts: List[str]) -> List[List[float]]:
|
def embed_documents(self, texts: List[str]) -> List[List[float]]:
|
||||||
"""Call the base embeddings."""
|
"""Call the base embeddings."""
|
||||||
return self.base_embeddings.embed_documents(texts)
|
return self.base_embeddings.embed_documents(texts)
|
||||||
|
|
||||||
|
def combine_embeddings(self, embeddings: List[List[float]]) -> List[float]:
|
||||||
|
"""Combine embeddings into final embeddings."""
|
||||||
|
return list(np.array(embeddings).mean(axis=0))
|
||||||
|
|
||||||
def embed_query(self, text: str) -> List[float]:
|
def embed_query(self, text: str) -> List[float]:
|
||||||
"""Generate a hypothetical document and embed it."""
|
"""Generate a hypothetical document and embed it."""
|
||||||
var_name = self.llm_chain.input_keys[0]
|
var_name = self.llm_chain.input_keys[0]
|
||||||
@ -42,9 +57,9 @@ class HypotheticalDocumentEmbedder(Embeddings, BaseModel):
|
|||||||
embeddings = self.embed_documents(documents)
|
embeddings = self.embed_documents(documents)
|
||||||
return self.combine_embeddings(embeddings)
|
return self.combine_embeddings(embeddings)
|
||||||
|
|
||||||
def combine_embeddings(self, embeddings: List[List[float]]) -> List[float]:
|
def _call(self, inputs: Dict[str, str]) -> Dict[str, str]:
|
||||||
"""Combine embeddings into final embeddings."""
|
"""Call the internal llm chain."""
|
||||||
return list(np.array(embeddings).mean(axis=0))
|
return self.llm_chain._call(inputs)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_llm(
|
def from_llm(
|
@ -2,7 +2,6 @@
|
|||||||
from langchain.embeddings.cohere import CohereEmbeddings
|
from langchain.embeddings.cohere import CohereEmbeddings
|
||||||
from langchain.embeddings.huggingface import HuggingFaceEmbeddings
|
from langchain.embeddings.huggingface import HuggingFaceEmbeddings
|
||||||
from langchain.embeddings.huggingface_hub import HuggingFaceHubEmbeddings
|
from langchain.embeddings.huggingface_hub import HuggingFaceHubEmbeddings
|
||||||
from langchain.embeddings.hyde.base import HypotheticalDocumentEmbedder
|
|
||||||
from langchain.embeddings.openai import OpenAIEmbeddings
|
from langchain.embeddings.openai import OpenAIEmbeddings
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
@ -10,5 +9,4 @@ __all__ = [
|
|||||||
"HuggingFaceEmbeddings",
|
"HuggingFaceEmbeddings",
|
||||||
"CohereEmbeddings",
|
"CohereEmbeddings",
|
||||||
"HuggingFaceHubEmbeddings",
|
"HuggingFaceHubEmbeddings",
|
||||||
"HypotheticalDocumentEmbedder",
|
|
||||||
]
|
]
|
||||||
|
@ -4,9 +4,9 @@ from typing import List, Optional
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
from langchain.chains.hyde.base import HypotheticalDocumentEmbedder
|
||||||
|
from langchain.chains.hyde.prompts import PROMPT_MAP
|
||||||
from langchain.embeddings.base import Embeddings
|
from langchain.embeddings.base import Embeddings
|
||||||
from langchain.embeddings.hyde.base import HypotheticalDocumentEmbedder
|
|
||||||
from langchain.embeddings.hyde.prompts import PROMPT_MAP
|
|
||||||
from langchain.llms.base import BaseLLM
|
from langchain.llms.base import BaseLLM
|
||||||
from langchain.schema import Generation, LLMResult
|
from langchain.schema import Generation, LLMResult
|
||||||
|
|
Loading…
Reference in New Issue
Block a user