add default async (#11141)

This commit is contained in:
Bagatur
2023-10-04 11:40:35 -04:00
committed by GitHub
parent 88c5349196
commit 106608bc89
9 changed files with 38 additions and 48 deletions

View File

@@ -5,7 +5,6 @@ import json
import logging import logging
import warnings import warnings
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from functools import partial
from pathlib import Path from pathlib import Path
from typing import Any, Dict, List, Optional, Type, Union from typing import Any, Dict, List, Optional, Type, Union
@@ -97,12 +96,6 @@ class Chain(RunnableSerializable[Dict[str, Any], Dict[str, Any]], ABC):
config: Optional[RunnableConfig] = None, config: Optional[RunnableConfig] = None,
**kwargs: Any, **kwargs: Any,
) -> Dict[str, Any]: ) -> Dict[str, Any]:
if type(self)._acall == Chain._acall:
# If the chain does not implement async, fall back to default implementation
return await asyncio.get_running_loop().run_in_executor(
None, partial(self.invoke, input, config, **kwargs)
)
config = config or {} config = config or {}
return await self.acall( return await self.acall(
input, input,
@@ -246,7 +239,9 @@ class Chain(RunnableSerializable[Dict[str, Any], Dict[str, Any]], ABC):
A dict of named outputs. Should contain all outputs specified in A dict of named outputs. Should contain all outputs specified in
`Chain.output_keys`. `Chain.output_keys`.
""" """
raise NotImplementedError("Async call not supported for this chain type.") return await asyncio.get_running_loop().run_in_executor(
None, self._call, inputs, run_manager
)
def __call__( def __call__(
self, self,

View File

@@ -577,10 +577,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessageChunk], ABC):
) -> ChatResult: ) -> ChatResult:
"""Top Level call""" """Top Level call"""
return await asyncio.get_running_loop().run_in_executor( return await asyncio.get_running_loop().run_in_executor(
None, None, partial(self._generate, **kwargs), messages, stop, run_manager
partial(
self._generate, messages, stop=stop, run_manager=run_manager, **kwargs
),
) )
def _stream( def _stream(

View File

@@ -248,12 +248,6 @@ class BaseLLM(BaseLanguageModel[str], ABC):
stop: Optional[List[str]] = None, stop: Optional[List[str]] = None,
**kwargs: Any, **kwargs: Any,
) -> str: ) -> str:
if type(self)._agenerate == BaseLLM._agenerate:
# model doesn't implement async invoke, so use default implementation
return await asyncio.get_running_loop().run_in_executor(
None, partial(self.invoke, input, config, stop=stop, **kwargs)
)
config = config or {} config = config or {}
llm_result = await self.agenerate_prompt( llm_result = await self.agenerate_prompt(
[self._convert_input(input)], [self._convert_input(input)],
@@ -319,13 +313,6 @@ class BaseLLM(BaseLanguageModel[str], ABC):
) -> List[str]: ) -> List[str]:
if not inputs: if not inputs:
return [] return []
if type(self)._agenerate == BaseLLM._agenerate:
# model doesn't implement async batch, so use default implementation
return await asyncio.get_running_loop().run_in_executor(
None, partial(self.batch, **kwargs), inputs, config
)
config = get_config_list(config, len(inputs)) config = get_config_list(config, len(inputs))
max_concurrency = config[0].get("max_concurrency") max_concurrency = config[0].get("max_concurrency")
@@ -478,7 +465,9 @@ class BaseLLM(BaseLanguageModel[str], ABC):
**kwargs: Any, **kwargs: Any,
) -> LLMResult: ) -> LLMResult:
"""Run the LLM on the given prompts.""" """Run the LLM on the given prompts."""
raise NotImplementedError() return await asyncio.get_running_loop().run_in_executor(
None, partial(self._generate, **kwargs), prompts, stop, run_manager
)
def _stream( def _stream(
self, self,
@@ -1035,7 +1024,9 @@ class LLM(BaseLLM):
**kwargs: Any, **kwargs: Any,
) -> str: ) -> str:
"""Run the LLM on the given prompt and input.""" """Run the LLM on the given prompt and input."""
raise NotImplementedError() return await asyncio.get_running_loop().run_in_executor(
None, partial(self._call, **kwargs), prompt, stop, run_manager
)
def _generate( def _generate(
self, self,
@@ -1064,12 +1055,6 @@ class LLM(BaseLLM):
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any, **kwargs: Any,
) -> LLMResult: ) -> LLMResult:
if type(self)._acall == LLM._acall:
# model doesn't implement async call, so use default implementation
return await asyncio.get_running_loop().run_in_executor(
None, partial(self._generate, prompts, stop, run_manager, **kwargs)
)
"""Run the LLM on the given prompt and input.""" """Run the LLM on the given prompt and input."""
generations = [] generations = []
new_arg_supported = inspect.signature(self._acall).parameters.get("run_manager") new_arg_supported = inspect.signature(self._acall).parameters.get("run_manager")

View File

@@ -1,6 +1,8 @@
from __future__ import annotations from __future__ import annotations
import asyncio
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from functools import partial
from typing import Any, Sequence from typing import Any, Sequence
from langchain.load.serializable import Serializable from langchain.load.serializable import Serializable
@@ -72,7 +74,6 @@ class BaseDocumentTransformer(ABC):
A list of transformed Documents. A list of transformed Documents.
""" """
@abstractmethod
async def atransform_documents( async def atransform_documents(
self, documents: Sequence[Document], **kwargs: Any self, documents: Sequence[Document], **kwargs: Any
) -> Sequence[Document]: ) -> Sequence[Document]:
@@ -84,3 +85,6 @@ class BaseDocumentTransformer(ABC):
Returns: Returns:
A list of transformed Documents. A list of transformed Documents.
""" """
return await asyncio.get_running_loop().run_in_executor(
None, partial(self.transform_documents, **kwargs), documents
)

View File

@@ -1,3 +1,4 @@
import asyncio
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import List from typing import List
@@ -15,8 +16,12 @@ class Embeddings(ABC):
async def aembed_documents(self, texts: List[str]) -> List[List[float]]: async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
"""Asynchronous Embed search docs.""" """Asynchronous Embed search docs."""
raise NotImplementedError return await asyncio.get_running_loop().run_in_executor(
None, self.embed_documents, texts
)
async def aembed_query(self, text: str) -> List[float]: async def aembed_query(self, text: str) -> List[float]:
"""Asynchronous Embed query text.""" """Asynchronous Embed query text."""
raise NotImplementedError return await asyncio.get_running_loop().run_in_executor(
None, self.embed_query, text
)

View File

@@ -1,7 +1,9 @@
from __future__ import annotations from __future__ import annotations
import asyncio
import warnings import warnings
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from functools import partial
from inspect import signature from inspect import signature
from typing import TYPE_CHECKING, Any, Dict, List, Optional from typing import TYPE_CHECKING, Any, Dict, List, Optional
@@ -121,10 +123,6 @@ class BaseRetriever(RunnableSerializable[str, List[Document]], ABC):
config: Optional[RunnableConfig] = None, config: Optional[RunnableConfig] = None,
**kwargs: Optional[Any], **kwargs: Optional[Any],
) -> List[Document]: ) -> List[Document]:
if type(self).aget_relevant_documents == BaseRetriever.aget_relevant_documents:
# If the retriever doesn't implement async, use default implementation
return await super().ainvoke(input, config)
config = config or {} config = config or {}
return await self.aget_relevant_documents( return await self.aget_relevant_documents(
input, input,
@@ -156,7 +154,9 @@ class BaseRetriever(RunnableSerializable[str, List[Document]], ABC):
Returns: Returns:
List of relevant documents List of relevant documents
""" """
raise NotImplementedError() return await asyncio.get_running_loop().run_in_executor(
None, partial(self._get_relevant_documents, run_manager=run_manager), query
)
def get_relevant_documents( def get_relevant_documents(
self, self,

View File

@@ -87,7 +87,9 @@ class VectorStore(ABC):
**kwargs: Any, **kwargs: Any,
) -> List[str]: ) -> List[str]:
"""Run more texts through the embeddings and add to the vectorstore.""" """Run more texts through the embeddings and add to the vectorstore."""
raise NotImplementedError return await asyncio.get_running_loop().run_in_executor(
None, partial(self.add_texts, **kwargs), texts, metadatas
)
def add_documents(self, documents: List[Document], **kwargs: Any) -> List[str]: def add_documents(self, documents: List[Document], **kwargs: Any) -> List[str]:
"""Run more documents through the embeddings and add to the vectorstore. """Run more documents through the embeddings and add to the vectorstore.
@@ -451,7 +453,9 @@ class VectorStore(ABC):
**kwargs: Any, **kwargs: Any,
) -> VST: ) -> VST:
"""Return VectorStore initialized from texts and embeddings.""" """Return VectorStore initialized from texts and embeddings."""
raise NotImplementedError return await asyncio.get_running_loop().run_in_executor(
None, partial(cls.from_texts, **kwargs), texts, embedding, metadatas
)
def _get_retriever_tags(self) -> List[str]: def _get_retriever_tags(self) -> List[str]:
"""Get tags for retriever.""" """Get tags for retriever."""

View File

@@ -21,6 +21,7 @@ Note: **MarkdownHeaderTextSplitter** and **HTMLHeaderTextSplitter do not derive
from __future__ import annotations from __future__ import annotations
import asyncio
import copy import copy
import logging import logging
import pathlib import pathlib
@@ -28,6 +29,7 @@ import re
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from dataclasses import dataclass from dataclasses import dataclass
from enum import Enum from enum import Enum
from functools import partial
from io import BytesIO, StringIO from io import BytesIO, StringIO
from typing import ( from typing import (
AbstractSet, AbstractSet,
@@ -284,7 +286,9 @@ class TextSplitter(BaseDocumentTransformer, ABC):
self, documents: Sequence[Document], **kwargs: Any self, documents: Sequence[Document], **kwargs: Any
) -> Sequence[Document]: ) -> Sequence[Document]:
"""Asynchronously transform a sequence of documents by splitting them.""" """Asynchronously transform a sequence of documents by splitting them."""
raise NotImplementedError return await asyncio.get_running_loop().run_in_executor(
None, partial(self.transform_documents, **kwargs), documents
)
class CharacterTextSplitter(TextSplitter): class CharacterTextSplitter(TextSplitter):

View File

@@ -217,10 +217,6 @@ class ChildTool(BaseTool):
config: Optional[RunnableConfig] = None, config: Optional[RunnableConfig] = None,
**kwargs: Any, **kwargs: Any,
) -> Any: ) -> Any:
if type(self)._arun == BaseTool._arun:
# If the tool does not implement async, fall back to default implementation
return await super().ainvoke(input, config, **kwargs)
config = config or {} config = config or {}
return await self.arun( return await self.arun(
input, input,