mirror of
https://github.com/hwchase17/langchain.git
synced 2025-09-20 01:54:14 +00:00
langchain[patch]: Add async methods to EmbeddingRouterChain (#19603)
This commit is contained in:
committed by
GitHub
parent
b3d7b5a653
commit
7c2578bd55
@@ -2,7 +2,10 @@ from __future__ import annotations
|
|||||||
|
|
||||||
from typing import Any, Dict, List, Optional, Sequence, Tuple, Type
|
from typing import Any, Dict, List, Optional, Sequence, Tuple, Type
|
||||||
|
|
||||||
from langchain_core.callbacks import CallbackManagerForChainRun
|
from langchain_core.callbacks import (
|
||||||
|
AsyncCallbackManagerForChainRun,
|
||||||
|
CallbackManagerForChainRun,
|
||||||
|
)
|
||||||
from langchain_core.documents import Document
|
from langchain_core.documents import Document
|
||||||
from langchain_core.embeddings import Embeddings
|
from langchain_core.embeddings import Embeddings
|
||||||
from langchain_core.pydantic_v1 import Extra
|
from langchain_core.pydantic_v1 import Extra
|
||||||
@@ -40,6 +43,15 @@ class EmbeddingRouterChain(RouterChain):
|
|||||||
results = self.vectorstore.similarity_search(_input, k=1)
|
results = self.vectorstore.similarity_search(_input, k=1)
|
||||||
return {"next_inputs": inputs, "destination": results[0].metadata["name"]}
|
return {"next_inputs": inputs, "destination": results[0].metadata["name"]}
|
||||||
|
|
||||||
|
async def _acall(
|
||||||
|
self,
|
||||||
|
inputs: Dict[str, Any],
|
||||||
|
run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
|
||||||
|
) -> Dict[str, Any]:
|
||||||
|
_input = ", ".join([inputs[k] for k in self.routing_keys])
|
||||||
|
results = await self.vectorstore.asimilarity_search(_input, k=1)
|
||||||
|
return {"next_inputs": inputs, "destination": results[0].metadata["name"]}
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_names_and_descriptions(
|
def from_names_and_descriptions(
|
||||||
cls,
|
cls,
|
||||||
@@ -57,3 +69,21 @@ class EmbeddingRouterChain(RouterChain):
|
|||||||
)
|
)
|
||||||
vectorstore = vectorstore_cls.from_documents(documents, embeddings)
|
vectorstore = vectorstore_cls.from_documents(documents, embeddings)
|
||||||
return cls(vectorstore=vectorstore, **kwargs)
|
return cls(vectorstore=vectorstore, **kwargs)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def afrom_names_and_descriptions(
|
||||||
|
cls,
|
||||||
|
names_and_descriptions: Sequence[Tuple[str, Sequence[str]]],
|
||||||
|
vectorstore_cls: Type[VectorStore],
|
||||||
|
embeddings: Embeddings,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> EmbeddingRouterChain:
|
||||||
|
"""Convenience constructor."""
|
||||||
|
documents = []
|
||||||
|
for name, descriptions in names_and_descriptions:
|
||||||
|
for description in descriptions:
|
||||||
|
documents.append(
|
||||||
|
Document(page_content=description, metadata={"name": name})
|
||||||
|
)
|
||||||
|
vectorstore = await vectorstore_cls.afrom_documents(documents, embeddings)
|
||||||
|
return cls(vectorstore=vectorstore, **kwargs)
|
||||||
|
Reference in New Issue
Block a user