Mirror of https://github.com/hwchase17/langchain.git, synced 2025-09-03 20:16:52 +00:00
add default async (#11141)
@@ -5,7 +5,6 @@ import json
 import logging
 import warnings
 from abc import ABC, abstractmethod
-from functools import partial
 from pathlib import Path
 from typing import Any, Dict, List, Optional, Type, Union
 
@@ -97,12 +96,6 @@ class Chain(RunnableSerializable[Dict[str, Any], Dict[str, Any]], ABC):
         config: Optional[RunnableConfig] = None,
         **kwargs: Any,
     ) -> Dict[str, Any]:
-        if type(self)._acall == Chain._acall:
-            # If the chain does not implement async, fall back to default implementation
-            return await asyncio.get_running_loop().run_in_executor(
-                None, partial(self.invoke, input, config, **kwargs)
-            )
-
         config = config or {}
         return await self.acall(
             input,
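Aside: the lines removed here (and the matching removals in `BaseLLM`, `LLM`, `BaseRetriever`, and `ChildTool` below) all rely on the same override-detection idiom: `ainvoke` compared the instance's `_acall` against the base class's attribute to decide whether an async override exists, and only fell back to running the sync path in an executor when it did not. A minimal, self-contained sketch of that idiom, using illustrative names (`Base`, `SyncOnly`, `invoke`, `_call`, `_acall`) rather than the real LangChain classes:

import asyncio
from functools import partial


class Base:
    def invoke(self, x: int) -> int:
        return self._call(x)

    def _call(self, x: int) -> int:
        raise NotImplementedError

    async def _acall(self, x: int) -> int:
        raise NotImplementedError

    async def ainvoke(self, x: int) -> int:
        # Override detection: if the subclass did not supply its own _acall,
        # run the sync invoke() in the default thread-pool executor instead.
        if type(self)._acall == Base._acall:
            return await asyncio.get_running_loop().run_in_executor(
                None, partial(self.invoke, x)
            )
        return await self._acall(x)


class SyncOnly(Base):
    def _call(self, x: int) -> int:
        return x * 2


print(asyncio.run(SyncOnly().ainvoke(21)))  # 42

With this commit the fallback moves into the base `_acall` itself (next hunk), so the per-caller check becomes redundant and is deleted.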
@@ -246,7 +239,9 @@ class Chain(RunnableSerializable[Dict[str, Any], Dict[str, Any]], ABC):
             A dict of named outputs. Should contain all outputs specified in
                 `Chain.output_keys`.
         """
-        raise NotImplementedError("Async call not supported for this chain type.")
+        return await asyncio.get_running_loop().run_in_executor(
+            None, self._call, inputs, run_manager
+        )
 
     def __call__(
         self,
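Aside: this is the core pattern the commit title refers to. Instead of raising, the base async method hands the sync implementation to asyncio's default thread-pool executor, so a chain that only implements `_call` still gets a usable `_acall`. A minimal sketch of the delegation on its own, with an illustrative class rather than `Chain` itself:

import asyncio


class SlowWorker:
    """Illustrative sync-only class (not from the diff)."""

    def _call(self, inputs: dict) -> dict:
        # Imagine blocking I/O or CPU-bound work here.
        return {"text": inputs["text"].upper()}

    async def _acall(self, inputs: dict) -> dict:
        # Default async implementation: run the sync _call on the default
        # thread-pool executor (the `None` argument) and await its result.
        # Return values and exceptions propagate back to the awaiting caller.
        return await asyncio.get_running_loop().run_in_executor(
            None, self._call, inputs
        )


print(asyncio.run(SlowWorker()._acall({"text": "hello"})))  # {'text': 'HELLO'}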
@@ -577,10 +577,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessageChunk], ABC):
     ) -> ChatResult:
         """Top Level call"""
         return await asyncio.get_running_loop().run_in_executor(
-            None,
-            partial(
-                self._generate, messages, stop=stop, run_manager=run_manager, **kwargs
-            ),
+            None, partial(self._generate, **kwargs), messages, stop, run_manager
         )
 
     def _stream(
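Aside: the old and new spellings make the same call, assuming (as in the surrounding signature) that `messages`, `stop`, and `run_manager` may be passed positionally in that order. `run_in_executor(executor, func, *args)` forwards extra positional arguments to `func`, while keyword arguments still have to be bound with `functools.partial`, because `run_in_executor` does not accept them. A small sketch of that equivalence with an illustrative function, not the LangChain `_generate`:

import asyncio
from functools import partial


def generate(messages, stop, run_manager, temperature=0.0):
    """Illustrative stand-in for a blocking _generate()."""
    return f"{len(messages)} message(s), stop={stop}, temperature={temperature}"


async def main() -> None:
    loop = asyncio.get_running_loop()
    # New spelling: keywords bound via partial, positionals forwarded by
    # run_in_executor (which only accepts positional arguments for func).
    a = await loop.run_in_executor(
        None, partial(generate, temperature=0.7), ["hi"], None, None
    )
    # Old spelling: everything bound inside the partial.
    b = await loop.run_in_executor(
        None, partial(generate, ["hi"], stop=None, run_manager=None, temperature=0.7)
    )
    assert a == b
    print(a)


asyncio.run(main())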
@@ -248,12 +248,6 @@ class BaseLLM(BaseLanguageModel[str], ABC):
         stop: Optional[List[str]] = None,
         **kwargs: Any,
     ) -> str:
-        if type(self)._agenerate == BaseLLM._agenerate:
-            # model doesn't implement async invoke, so use default implementation
-            return await asyncio.get_running_loop().run_in_executor(
-                None, partial(self.invoke, input, config, stop=stop, **kwargs)
-            )
-
         config = config or {}
         llm_result = await self.agenerate_prompt(
             [self._convert_input(input)],
@@ -319,13 +313,6 @@ class BaseLLM(BaseLanguageModel[str], ABC):
     ) -> List[str]:
         if not inputs:
             return []
 
-        if type(self)._agenerate == BaseLLM._agenerate:
-            # model doesn't implement async batch, so use default implementation
-            return await asyncio.get_running_loop().run_in_executor(
-                None, partial(self.batch, **kwargs), inputs, config
-            )
-
         config = get_config_list(config, len(inputs))
         max_concurrency = config[0].get("max_concurrency")
 
@@ -478,7 +465,9 @@ class BaseLLM(BaseLanguageModel[str], ABC):
         **kwargs: Any,
     ) -> LLMResult:
         """Run the LLM on the given prompts."""
-        raise NotImplementedError()
+        return await asyncio.get_running_loop().run_in_executor(
+            None, partial(self._generate, **kwargs), prompts, stop, run_manager
+        )
 
     def _stream(
         self,
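Aside: because the default runs the blocking `_generate` on worker threads of the default executor, several awaited calls can overlap even though the underlying model code is synchronous. A standalone sketch of that effect, with an illustrative `EchoModel` rather than a LangChain class:

import asyncio
import time
from functools import partial


class EchoModel:
    """Sync-only stand-in; _generate blocks the way a real API call would."""

    def _generate(self, prompt: str, suffix: str = "") -> str:
        time.sleep(0.2)  # simulate blocking work
        return prompt[::-1] + suffix

    async def _agenerate(self, prompt: str, **kwargs) -> str:
        # Same shape as the new default: kwargs bound via partial, the prompt
        # passed positionally, everything executed on the default thread pool.
        return await asyncio.get_running_loop().run_in_executor(
            None, partial(self._generate, **kwargs), prompt
        )


async def main() -> None:
    model = EchoModel()
    start = time.perf_counter()
    results = await asyncio.gather(
        model._agenerate("abc"), model._agenerate("xyz", suffix="!")
    )
    # Both blocking calls ran on separate threads, so this takes ~0.2s, not ~0.4s.
    print(results, f"{time.perf_counter() - start:.2f}s")


asyncio.run(main())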
@@ -1035,7 +1024,9 @@ class LLM(BaseLLM):
         **kwargs: Any,
     ) -> str:
         """Run the LLM on the given prompt and input."""
-        raise NotImplementedError()
+        return await asyncio.get_running_loop().run_in_executor(
+            None, partial(self._call, **kwargs), prompt, stop, run_manager
+        )
 
     def _generate(
         self,
@@ -1064,12 +1055,6 @@ class LLM(BaseLLM):
         run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
         **kwargs: Any,
     ) -> LLMResult:
-        if type(self)._acall == LLM._acall:
-            # model doesn't implement async call, so use default implementation
-            return await asyncio.get_running_loop().run_in_executor(
-                None, partial(self._generate, prompts, stop, run_manager, **kwargs)
-            )
-
         """Run the LLM on the given prompt and input."""
         generations = []
         new_arg_supported = inspect.signature(self._acall).parameters.get("run_manager")
@@ -1,6 +1,8 @@
 from __future__ import annotations
 
+import asyncio
 from abc import ABC, abstractmethod
+from functools import partial
 from typing import Any, Sequence
 
 from langchain.load.serializable import Serializable
@@ -72,7 +74,6 @@ class BaseDocumentTransformer(ABC):
             A list of transformed Documents.
         """
 
-    @abstractmethod
     async def atransform_documents(
         self, documents: Sequence[Document], **kwargs: Any
     ) -> Sequence[Document]:
@@ -84,3 +85,6 @@ class BaseDocumentTransformer(ABC):
         Returns:
             A list of transformed Documents.
         """
+        return await asyncio.get_running_loop().run_in_executor(
+            None, partial(self.transform_documents, **kwargs), documents
+        )
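Aside: dropping `@abstractmethod` matters as much as the new body. Previously a transformer had to override the async method to become concrete; now a subclass that only implements the sync `transform_documents` is instantiable and still awaitable. A minimal sketch with an illustrative transformer that works on plain strings standing in for `Document` objects:

import asyncio
from functools import partial
from typing import Any, Sequence


class UppercaseTransformer:
    """Illustrative sync-only transformer (strings stand in for Documents)."""

    def transform_documents(self, documents: Sequence[str], **kwargs: Any) -> Sequence[str]:
        return [d.upper() for d in documents]

    async def atransform_documents(self, documents: Sequence[str], **kwargs: Any) -> Sequence[str]:
        # Same default as the hunk above: delegate to the sync method on the
        # default thread-pool executor, forwarding kwargs through partial.
        return await asyncio.get_running_loop().run_in_executor(
            None, partial(self.transform_documents, **kwargs), documents
        )


print(asyncio.run(UppercaseTransformer().atransform_documents(["a", "b"])))  # ['A', 'B']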
@@ -1,3 +1,4 @@
+import asyncio
 from abc import ABC, abstractmethod
 from typing import List
 
@@ -15,8 +16,12 @@ class Embeddings(ABC):
 
     async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
        """Asynchronous Embed search docs."""
-        raise NotImplementedError
+        return await asyncio.get_running_loop().run_in_executor(
+            None, self.embed_documents, texts
+        )
 
     async def aembed_query(self, text: str) -> List[float]:
         """Asynchronous Embed query text."""
-        raise NotImplementedError
+        return await asyncio.get_running_loop().run_in_executor(
+            None, self.embed_query, text
+        )
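Aside: from a caller's perspective, the practical effect is that an embeddings class implementing only the sync methods can now be awaited instead of raising `NotImplementedError`. A hedged usage sketch — the `FakeEmbeddings` class and the `langchain.schema.embeddings` import path are assumptions for illustration, not part of this diff:

import asyncio
from typing import List

from langchain.schema.embeddings import Embeddings  # assumed import path for this version


class FakeEmbeddings(Embeddings):
    """Toy sync-only embeddings; no async methods are overridden."""

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        return [[float(len(t))] for t in texts]

    def embed_query(self, text: str) -> List[float]:
        return [float(len(text))]


async def main() -> None:
    emb = FakeEmbeddings()
    # With the defaults above, these awaits run the sync methods in the
    # default thread pool instead of raising NotImplementedError.
    print(await emb.aembed_documents(["hello", "hi"]))  # [[5.0], [2.0]]
    print(await emb.aembed_query("hey"))                # [3.0]


asyncio.run(main())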
@@ -1,7 +1,9 @@
 from __future__ import annotations
 
+import asyncio
 import warnings
 from abc import ABC, abstractmethod
+from functools import partial
 from inspect import signature
 from typing import TYPE_CHECKING, Any, Dict, List, Optional
 
@@ -121,10 +123,6 @@ class BaseRetriever(RunnableSerializable[str, List[Document]], ABC):
         config: Optional[RunnableConfig] = None,
         **kwargs: Optional[Any],
     ) -> List[Document]:
-        if type(self).aget_relevant_documents == BaseRetriever.aget_relevant_documents:
-            # If the retriever doesn't implement async, use default implementation
-            return await super().ainvoke(input, config)
-
         config = config or {}
         return await self.aget_relevant_documents(
             input,
@@ -156,7 +154,9 @@ class BaseRetriever(RunnableSerializable[str, List[Document]], ABC):
         Returns:
             List of relevant documents
         """
-        raise NotImplementedError()
+        return await asyncio.get_running_loop().run_in_executor(
+            None, partial(self._get_relevant_documents, run_manager=run_manager), query
+        )
 
     def get_relevant_documents(
         self,
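Aside: here `partial` binds `run_manager` by keyword while `query` travels as the positional argument that `run_in_executor` forwards. A small sketch of that argument routing with an illustrative module-level function (the keyword-only marker is part of the sketch, not a claim about the real `_get_relevant_documents` signature):

import asyncio
from functools import partial
from typing import List


def get_docs(query: str, *, run_manager: object = None) -> List[str]:
    """Illustrative sync retrieval with a keyword-only run_manager."""
    return [f"doc matching {query!r}", f"callbacks via {run_manager!r}"]


async def main() -> None:
    loop = asyncio.get_running_loop()
    # run_manager is bound with partial; the positional query is appended
    # by run_in_executor itself when it invokes the callable.
    docs = await loop.run_in_executor(
        None, partial(get_docs, run_manager="rm"), "default async"
    )
    print(docs)


asyncio.run(main())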
@@ -87,7 +87,9 @@ class VectorStore(ABC):
         **kwargs: Any,
     ) -> List[str]:
         """Run more texts through the embeddings and add to the vectorstore."""
-        raise NotImplementedError
+        return await asyncio.get_running_loop().run_in_executor(
+            None, partial(self.add_texts, **kwargs), texts, metadatas
+        )
 
     def add_documents(self, documents: List[Document], **kwargs: Any) -> List[str]:
         """Run more documents through the embeddings and add to the vectorstore.
@@ -451,7 +453,9 @@ class VectorStore(ABC):
         **kwargs: Any,
     ) -> VST:
         """Return VectorStore initialized from texts and embeddings."""
-        raise NotImplementedError
+        return await asyncio.get_running_loop().run_in_executor(
+            None, partial(cls.from_texts, **kwargs), texts, embedding, metadatas
+        )
 
     def _get_retriever_tags(self) -> List[str]:
         """Get tags for retriever."""
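Aside: the same delegation works for classmethod constructors — the async factory hands `cls.from_texts` to the executor, so the correct subclass is built even though the coroutine is defined on the base class. A minimal sketch with illustrative classes, not the real `VectorStore`:

import asyncio
from functools import partial
from typing import List, Type, TypeVar

VST = TypeVar("VST", bound="Store")


class Store:
    """Illustrative base; subclasses provide the sync constructor."""

    def __init__(self, texts: List[str]) -> None:
        self.texts = texts

    @classmethod
    def from_texts(cls: Type[VST], texts: List[str], **kwargs) -> VST:
        raise NotImplementedError

    @classmethod
    async def afrom_texts(cls: Type[VST], texts: List[str], **kwargs) -> VST:
        # cls is the concrete subclass at call time, so that subclass's sync
        # factory runs on the default thread-pool executor.
        return await asyncio.get_running_loop().run_in_executor(
            None, partial(cls.from_texts, **kwargs), texts
        )


class ListStore(Store):
    @classmethod
    def from_texts(cls, texts: List[str], **kwargs) -> "ListStore":
        return cls(list(texts))


store = asyncio.run(ListStore.afrom_texts(["a", "b"]))
print(type(store).__name__, store.texts)  # ListStore ['a', 'b']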
@@ -21,6 +21,7 @@ Note: **MarkdownHeaderTextSplitter** and **HTMLHeaderTextSplitter do not derive
 
 from __future__ import annotations
 
+import asyncio
 import copy
 import logging
 import pathlib
@@ -28,6 +29,7 @@ import re
 from abc import ABC, abstractmethod
 from dataclasses import dataclass
 from enum import Enum
+from functools import partial
 from io import BytesIO, StringIO
 from typing import (
     AbstractSet,
@@ -284,7 +286,9 @@ class TextSplitter(BaseDocumentTransformer, ABC):
         self, documents: Sequence[Document], **kwargs: Any
     ) -> Sequence[Document]:
         """Asynchronously transform a sequence of documents by splitting them."""
-        raise NotImplementedError
+        return await asyncio.get_running_loop().run_in_executor(
+            None, partial(self.transform_documents, **kwargs), documents
+        )
 
 
 class CharacterTextSplitter(TextSplitter):
@@ -217,10 +217,6 @@ class ChildTool(BaseTool):
         config: Optional[RunnableConfig] = None,
         **kwargs: Any,
     ) -> Any:
-        if type(self)._arun == BaseTool._arun:
-            # If the tool does not implement async, fall back to default implementation
-            return await super().ainvoke(input, config, **kwargs)
-
         config = config or {}
         return await self.arun(
             input,