mirror of
https://github.com/hwchase17/langchain.git
synced 2025-08-31 10:23:18 +00:00
add default async (#11141)
This commit is contained in:
@@ -5,7 +5,6 @@ import json
|
||||
import logging
|
||||
import warnings
|
||||
from abc import ABC, abstractmethod
|
||||
from functools import partial
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Type, Union
|
||||
|
||||
@@ -97,12 +96,6 @@ class Chain(RunnableSerializable[Dict[str, Any], Dict[str, Any]], ABC):
|
||||
config: Optional[RunnableConfig] = None,
|
||||
**kwargs: Any,
|
||||
) -> Dict[str, Any]:
|
||||
if type(self)._acall == Chain._acall:
|
||||
# If the chain does not implement async, fall back to default implementation
|
||||
return await asyncio.get_running_loop().run_in_executor(
|
||||
None, partial(self.invoke, input, config, **kwargs)
|
||||
)
|
||||
|
||||
config = config or {}
|
||||
return await self.acall(
|
||||
input,
|
||||
@@ -246,7 +239,9 @@ class Chain(RunnableSerializable[Dict[str, Any], Dict[str, Any]], ABC):
|
||||
A dict of named outputs. Should contain all outputs specified in
|
||||
`Chain.output_keys`.
|
||||
"""
|
||||
raise NotImplementedError("Async call not supported for this chain type.")
|
||||
return await asyncio.get_running_loop().run_in_executor(
|
||||
None, self._call, inputs, run_manager
|
||||
)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
|
@@ -577,10 +577,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessageChunk], ABC):
|
||||
) -> ChatResult:
|
||||
"""Top Level call"""
|
||||
return await asyncio.get_running_loop().run_in_executor(
|
||||
None,
|
||||
partial(
|
||||
self._generate, messages, stop=stop, run_manager=run_manager, **kwargs
|
||||
),
|
||||
None, partial(self._generate, **kwargs), messages, stop, run_manager
|
||||
)
|
||||
|
||||
def _stream(
|
||||
|
@@ -248,12 +248,6 @@ class BaseLLM(BaseLanguageModel[str], ABC):
|
||||
stop: Optional[List[str]] = None,
|
||||
**kwargs: Any,
|
||||
) -> str:
|
||||
if type(self)._agenerate == BaseLLM._agenerate:
|
||||
# model doesn't implement async invoke, so use default implementation
|
||||
return await asyncio.get_running_loop().run_in_executor(
|
||||
None, partial(self.invoke, input, config, stop=stop, **kwargs)
|
||||
)
|
||||
|
||||
config = config or {}
|
||||
llm_result = await self.agenerate_prompt(
|
||||
[self._convert_input(input)],
|
||||
@@ -319,13 +313,6 @@ class BaseLLM(BaseLanguageModel[str], ABC):
|
||||
) -> List[str]:
|
||||
if not inputs:
|
||||
return []
|
||||
|
||||
if type(self)._agenerate == BaseLLM._agenerate:
|
||||
# model doesn't implement async batch, so use default implementation
|
||||
return await asyncio.get_running_loop().run_in_executor(
|
||||
None, partial(self.batch, **kwargs), inputs, config
|
||||
)
|
||||
|
||||
config = get_config_list(config, len(inputs))
|
||||
max_concurrency = config[0].get("max_concurrency")
|
||||
|
||||
@@ -478,7 +465,9 @@ class BaseLLM(BaseLanguageModel[str], ABC):
|
||||
**kwargs: Any,
|
||||
) -> LLMResult:
|
||||
"""Run the LLM on the given prompts."""
|
||||
raise NotImplementedError()
|
||||
return await asyncio.get_running_loop().run_in_executor(
|
||||
None, partial(self._generate, **kwargs), prompts, stop, run_manager
|
||||
)
|
||||
|
||||
def _stream(
|
||||
self,
|
||||
@@ -1035,7 +1024,9 @@ class LLM(BaseLLM):
|
||||
**kwargs: Any,
|
||||
) -> str:
|
||||
"""Run the LLM on the given prompt and input."""
|
||||
raise NotImplementedError()
|
||||
return await asyncio.get_running_loop().run_in_executor(
|
||||
None, partial(self._call, **kwargs), prompt, stop, run_manager
|
||||
)
|
||||
|
||||
def _generate(
|
||||
self,
|
||||
@@ -1064,12 +1055,6 @@ class LLM(BaseLLM):
|
||||
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> LLMResult:
|
||||
if type(self)._acall == LLM._acall:
|
||||
# model doesn't implement async call, so use default implementation
|
||||
return await asyncio.get_running_loop().run_in_executor(
|
||||
None, partial(self._generate, prompts, stop, run_manager, **kwargs)
|
||||
)
|
||||
|
||||
"""Run the LLM on the given prompt and input."""
|
||||
generations = []
|
||||
new_arg_supported = inspect.signature(self._acall).parameters.get("run_manager")
|
||||
|
@@ -1,6 +1,8 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from abc import ABC, abstractmethod
|
||||
from functools import partial
|
||||
from typing import Any, Sequence
|
||||
|
||||
from langchain.load.serializable import Serializable
|
||||
@@ -72,7 +74,6 @@ class BaseDocumentTransformer(ABC):
|
||||
A list of transformed Documents.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
async def atransform_documents(
|
||||
self, documents: Sequence[Document], **kwargs: Any
|
||||
) -> Sequence[Document]:
|
||||
@@ -84,3 +85,6 @@ class BaseDocumentTransformer(ABC):
|
||||
Returns:
|
||||
A list of transformed Documents.
|
||||
"""
|
||||
return await asyncio.get_running_loop().run_in_executor(
|
||||
None, partial(self.transform_documents, **kwargs), documents
|
||||
)
|
||||
|
@@ -1,3 +1,4 @@
|
||||
import asyncio
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import List
|
||||
|
||||
@@ -15,8 +16,12 @@ class Embeddings(ABC):
|
||||
|
||||
async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
|
||||
"""Asynchronous Embed search docs."""
|
||||
raise NotImplementedError
|
||||
return await asyncio.get_running_loop().run_in_executor(
|
||||
None, self.embed_documents, texts
|
||||
)
|
||||
|
||||
async def aembed_query(self, text: str) -> List[float]:
|
||||
"""Asynchronous Embed query text."""
|
||||
raise NotImplementedError
|
||||
return await asyncio.get_running_loop().run_in_executor(
|
||||
None, self.embed_query, text
|
||||
)
|
||||
|
@@ -1,7 +1,9 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import warnings
|
||||
from abc import ABC, abstractmethod
|
||||
from functools import partial
|
||||
from inspect import signature
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional
|
||||
|
||||
@@ -121,10 +123,6 @@ class BaseRetriever(RunnableSerializable[str, List[Document]], ABC):
|
||||
config: Optional[RunnableConfig] = None,
|
||||
**kwargs: Optional[Any],
|
||||
) -> List[Document]:
|
||||
if type(self).aget_relevant_documents == BaseRetriever.aget_relevant_documents:
|
||||
# If the retriever doesn't implement async, use default implementation
|
||||
return await super().ainvoke(input, config)
|
||||
|
||||
config = config or {}
|
||||
return await self.aget_relevant_documents(
|
||||
input,
|
||||
@@ -156,7 +154,9 @@ class BaseRetriever(RunnableSerializable[str, List[Document]], ABC):
|
||||
Returns:
|
||||
List of relevant documents
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
return await asyncio.get_running_loop().run_in_executor(
|
||||
None, partial(self._get_relevant_documents, run_manager=run_manager), query
|
||||
)
|
||||
|
||||
def get_relevant_documents(
|
||||
self,
|
||||
|
@@ -87,7 +87,9 @@ class VectorStore(ABC):
|
||||
**kwargs: Any,
|
||||
) -> List[str]:
|
||||
"""Run more texts through the embeddings and add to the vectorstore."""
|
||||
raise NotImplementedError
|
||||
return await asyncio.get_running_loop().run_in_executor(
|
||||
None, partial(self.add_texts, **kwargs), texts, metadatas
|
||||
)
|
||||
|
||||
def add_documents(self, documents: List[Document], **kwargs: Any) -> List[str]:
|
||||
"""Run more documents through the embeddings and add to the vectorstore.
|
||||
@@ -451,7 +453,9 @@ class VectorStore(ABC):
|
||||
**kwargs: Any,
|
||||
) -> VST:
|
||||
"""Return VectorStore initialized from texts and embeddings."""
|
||||
raise NotImplementedError
|
||||
return await asyncio.get_running_loop().run_in_executor(
|
||||
None, partial(cls.from_texts, **kwargs), texts, embedding, metadatas
|
||||
)
|
||||
|
||||
def _get_retriever_tags(self) -> List[str]:
|
||||
"""Get tags for retriever."""
|
||||
|
@@ -21,6 +21,7 @@ Note: **MarkdownHeaderTextSplitter** and **HTMLHeaderTextSplitter do not derive
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import copy
|
||||
import logging
|
||||
import pathlib
|
||||
@@ -28,6 +29,7 @@ import re
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from functools import partial
|
||||
from io import BytesIO, StringIO
|
||||
from typing import (
|
||||
AbstractSet,
|
||||
@@ -284,7 +286,9 @@ class TextSplitter(BaseDocumentTransformer, ABC):
|
||||
self, documents: Sequence[Document], **kwargs: Any
|
||||
) -> Sequence[Document]:
|
||||
"""Asynchronously transform a sequence of documents by splitting them."""
|
||||
raise NotImplementedError
|
||||
return await asyncio.get_running_loop().run_in_executor(
|
||||
None, partial(self.transform_documents, **kwargs), documents
|
||||
)
|
||||
|
||||
|
||||
class CharacterTextSplitter(TextSplitter):
|
||||
|
@@ -217,10 +217,6 @@ class ChildTool(BaseTool):
|
||||
config: Optional[RunnableConfig] = None,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
if type(self)._arun == BaseTool._arun:
|
||||
# If the tool does not implement async, fall back to default implementation
|
||||
return await super().ainvoke(input, config, **kwargs)
|
||||
|
||||
config = config or {}
|
||||
return await self.arun(
|
||||
input,
|
||||
|
Reference in New Issue
Block a user