mirror of
https://github.com/hwchase17/langchain.git
synced 2025-07-07 13:40:46 +00:00
Propagate context vars in all classes/methods (#15329)
- Any direct usage of ThreadPoolExecutor or asyncio.run_in_executor needs manual handling of context vars <!-- Thank you for contributing to LangChain! Please title your PR "<package>: <description>", where <package> is whichever of langchain, community, core, experimental, etc. is being modified. Replace this entire comment with: - **Description:** a description of the change, - **Issue:** the issue # it fixes if applicable, - **Dependencies:** any dependencies required for this change, - **Twitter handle:** we announce bigger features on Twitter. If your PR gets announced, and you'd like a mention, we'll gladly shout you out! Please make sure your PR is passing linting and testing before submitting. Run `make format`, `make lint` and `make test` from the root of the package you've modified to check this locally. See contribution guidelines for more information on how to write/run tests, lint, etc: https://python.langchain.com/docs/contributing/ If you're adding a new integration, please include: 1. a test for the integration, preferably unit tests that do not rely on network access, 2. an example notebook showing its use. It lives in `docs/docs/integrations` directory. If no one reviews your PR within a few days, please @-mention one of @baskaryan, @eyurtsev, @hwchase17. -->
This commit is contained in:
commit
99000c612e
@ -1,12 +1,9 @@
|
|||||||
"""ChatModel wrapper which returns user input as the response.."""
|
"""ChatModel wrapper which returns user input as the response.."""
|
||||||
import asyncio
|
|
||||||
from functools import partial
|
|
||||||
from io import StringIO
|
from io import StringIO
|
||||||
from typing import Any, Callable, Dict, List, Mapping, Optional
|
from typing import Any, Callable, Dict, List, Mapping, Optional
|
||||||
|
|
||||||
import yaml
|
import yaml
|
||||||
from langchain_core.callbacks import (
|
from langchain_core.callbacks import (
|
||||||
AsyncCallbackManagerForLLMRun,
|
|
||||||
CallbackManagerForLLMRun,
|
CallbackManagerForLLMRun,
|
||||||
)
|
)
|
||||||
from langchain_core.language_models.chat_models import BaseChatModel
|
from langchain_core.language_models.chat_models import BaseChatModel
|
||||||
@ -111,15 +108,3 @@ class HumanInputChatModel(BaseChatModel):
|
|||||||
self.message_func(messages, **self.message_kwargs)
|
self.message_func(messages, **self.message_kwargs)
|
||||||
user_input = self.input_func(messages, stop=stop, **self.input_kwargs)
|
user_input = self.input_func(messages, stop=stop, **self.input_kwargs)
|
||||||
return ChatResult(generations=[ChatGeneration(message=user_input)])
|
return ChatResult(generations=[ChatGeneration(message=user_input)])
|
||||||
|
|
||||||
async def _agenerate(
|
|
||||||
self,
|
|
||||||
messages: List[BaseMessage],
|
|
||||||
stop: Optional[List[str]] = None,
|
|
||||||
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
|
|
||||||
**kwargs: Any,
|
|
||||||
) -> ChatResult:
|
|
||||||
func = partial(
|
|
||||||
self._generate, messages, stop=stop, run_manager=run_manager, **kwargs
|
|
||||||
)
|
|
||||||
return await asyncio.get_event_loop().run_in_executor(None, func)
|
|
||||||
|
@ -1,11 +1,8 @@
|
|||||||
import asyncio
|
|
||||||
import logging
|
import logging
|
||||||
from functools import partial
|
|
||||||
from typing import Any, Dict, List, Mapping, Optional
|
from typing import Any, Dict, List, Mapping, Optional
|
||||||
from urllib.parse import urlparse
|
from urllib.parse import urlparse
|
||||||
|
|
||||||
from langchain_core.callbacks import (
|
from langchain_core.callbacks import (
|
||||||
AsyncCallbackManagerForLLMRun,
|
|
||||||
CallbackManagerForLLMRun,
|
CallbackManagerForLLMRun,
|
||||||
)
|
)
|
||||||
from langchain_core.language_models import BaseChatModel
|
from langchain_core.language_models import BaseChatModel
|
||||||
@ -125,18 +122,6 @@ class ChatMlflow(BaseChatModel):
|
|||||||
resp = self._client.predict(endpoint=self.endpoint, inputs=data)
|
resp = self._client.predict(endpoint=self.endpoint, inputs=data)
|
||||||
return ChatMlflow._create_chat_result(resp)
|
return ChatMlflow._create_chat_result(resp)
|
||||||
|
|
||||||
async def _agenerate(
|
|
||||||
self,
|
|
||||||
messages: List[BaseMessage],
|
|
||||||
stop: Optional[List[str]] = None,
|
|
||||||
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
|
|
||||||
**kwargs: Any,
|
|
||||||
) -> ChatResult:
|
|
||||||
func = partial(
|
|
||||||
self._generate, messages, stop=stop, run_manager=run_manager, **kwargs
|
|
||||||
)
|
|
||||||
return await asyncio.get_event_loop().run_in_executor(None, func)
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def _identifying_params(self) -> Dict[str, Any]:
|
def _identifying_params(self) -> Dict[str, Any]:
|
||||||
return self._default_params
|
return self._default_params
|
||||||
|
@ -1,11 +1,8 @@
|
|||||||
import asyncio
|
|
||||||
import logging
|
import logging
|
||||||
import warnings
|
import warnings
|
||||||
from functools import partial
|
|
||||||
from typing import Any, Dict, List, Mapping, Optional
|
from typing import Any, Dict, List, Mapping, Optional
|
||||||
|
|
||||||
from langchain_core.callbacks import (
|
from langchain_core.callbacks import (
|
||||||
AsyncCallbackManagerForLLMRun,
|
|
||||||
CallbackManagerForLLMRun,
|
CallbackManagerForLLMRun,
|
||||||
)
|
)
|
||||||
from langchain_core.language_models.chat_models import BaseChatModel
|
from langchain_core.language_models.chat_models import BaseChatModel
|
||||||
@ -116,18 +113,6 @@ class ChatMLflowAIGateway(BaseChatModel):
|
|||||||
resp = mlflow.gateway.query(self.route, data=data)
|
resp = mlflow.gateway.query(self.route, data=data)
|
||||||
return ChatMLflowAIGateway._create_chat_result(resp)
|
return ChatMLflowAIGateway._create_chat_result(resp)
|
||||||
|
|
||||||
async def _agenerate(
|
|
||||||
self,
|
|
||||||
messages: List[BaseMessage],
|
|
||||||
stop: Optional[List[str]] = None,
|
|
||||||
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
|
|
||||||
**kwargs: Any,
|
|
||||||
) -> ChatResult:
|
|
||||||
func = partial(
|
|
||||||
self._generate, messages, stop=stop, run_manager=run_manager, **kwargs
|
|
||||||
)
|
|
||||||
return await asyncio.get_event_loop().run_in_executor(None, func)
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def _identifying_params(self) -> Dict[str, Any]:
|
def _identifying_params(self) -> Dict[str, Any]:
|
||||||
return self._default_params
|
return self._default_params
|
||||||
|
@ -1,7 +1,5 @@
|
|||||||
import asyncio
|
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
from functools import partial
|
|
||||||
from typing import Any, AsyncIterator, Dict, List, Optional, cast
|
from typing import Any, AsyncIterator, Dict, List, Optional, cast
|
||||||
|
|
||||||
import requests
|
import requests
|
||||||
@ -300,25 +298,3 @@ class PaiEasChatEndpoint(BaseChatModel):
|
|||||||
# break if stop sequence found
|
# break if stop sequence found
|
||||||
if stop_seq_found:
|
if stop_seq_found:
|
||||||
break
|
break
|
||||||
|
|
||||||
async def _agenerate(
|
|
||||||
self,
|
|
||||||
messages: List[BaseMessage],
|
|
||||||
stop: Optional[List[str]] = None,
|
|
||||||
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
|
|
||||||
stream: Optional[bool] = None,
|
|
||||||
**kwargs: Any,
|
|
||||||
) -> ChatResult:
|
|
||||||
if stream if stream is not None else self.streaming:
|
|
||||||
generation: Optional[ChatGenerationChunk] = None
|
|
||||||
async for chunk in self._astream(
|
|
||||||
messages=messages, stop=stop, run_manager=run_manager, **kwargs
|
|
||||||
):
|
|
||||||
generation = chunk
|
|
||||||
assert generation is not None
|
|
||||||
return ChatResult(generations=[generation])
|
|
||||||
|
|
||||||
func = partial(
|
|
||||||
self._generate, messages, stop=stop, run_manager=run_manager, **kwargs
|
|
||||||
)
|
|
||||||
return await asyncio.get_event_loop().run_in_executor(None, func)
|
|
||||||
|
@ -1,11 +1,11 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
from functools import partial
|
|
||||||
from typing import Any, Dict, List, Optional
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
from langchain_core.embeddings import Embeddings
|
from langchain_core.embeddings import Embeddings
|
||||||
from langchain_core.pydantic_v1 import BaseModel, Extra, root_validator
|
from langchain_core.pydantic_v1 import BaseModel, Extra, root_validator
|
||||||
|
from langchain_core.runnables.config import run_in_executor
|
||||||
|
|
||||||
|
|
||||||
class BedrockEmbeddings(BaseModel, Embeddings):
|
class BedrockEmbeddings(BaseModel, Embeddings):
|
||||||
@ -181,9 +181,7 @@ class BedrockEmbeddings(BaseModel, Embeddings):
|
|||||||
Embeddings for the text.
|
Embeddings for the text.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
return await asyncio.get_running_loop().run_in_executor(
|
return await run_in_executor(None, self.embed_query, text)
|
||||||
None, partial(self.embed_query, text)
|
|
||||||
)
|
|
||||||
|
|
||||||
async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
|
async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
|
||||||
"""Asynchronous compute doc embeddings using a Bedrock model.
|
"""Asynchronous compute doc embeddings using a Bedrock model.
|
||||||
|
@ -1,12 +1,12 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
import threading
|
import threading
|
||||||
from functools import partial
|
|
||||||
from typing import Dict, List, Optional
|
from typing import Dict, List, Optional
|
||||||
|
|
||||||
import requests
|
import requests
|
||||||
from langchain_core.embeddings import Embeddings
|
from langchain_core.embeddings import Embeddings
|
||||||
from langchain_core.pydantic_v1 import BaseModel, root_validator
|
from langchain_core.pydantic_v1 import BaseModel, root_validator
|
||||||
|
from langchain_core.runnables.config import run_in_executor
|
||||||
from langchain_core.utils import get_from_dict_or_env
|
from langchain_core.utils import get_from_dict_or_env
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@ -134,9 +134,7 @@ class ErnieEmbeddings(BaseModel, Embeddings):
|
|||||||
List[float]: Embeddings for the text.
|
List[float]: Embeddings for the text.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
return await asyncio.get_running_loop().run_in_executor(
|
return await run_in_executor(None, self.embed_query, text)
|
||||||
None, partial(self.embed_query, text)
|
|
||||||
)
|
|
||||||
|
|
||||||
async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
|
async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
|
||||||
"""Asynchronous Embed search docs.
|
"""Asynchronous Embed search docs.
|
||||||
|
@ -1,8 +1,6 @@
|
|||||||
import asyncio
|
|
||||||
from typing import TYPE_CHECKING, Optional, Type
|
from typing import TYPE_CHECKING, Optional, Type
|
||||||
|
|
||||||
from langchain_core.callbacks import (
|
from langchain_core.callbacks import (
|
||||||
AsyncCallbackManagerForToolRun,
|
|
||||||
CallbackManagerForToolRun,
|
CallbackManagerForToolRun,
|
||||||
)
|
)
|
||||||
from langchain_core.pydantic_v1 import BaseModel, Field
|
from langchain_core.pydantic_v1 import BaseModel, Field
|
||||||
@ -57,11 +55,3 @@ Note: SessionId must be received from previous Browser window creation."""
|
|||||||
print(f"{e}, retrying...")
|
print(f"{e}, retrying...")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise Exception(f"An error occurred: {e}")
|
raise Exception(f"An error occurred: {e}")
|
||||||
|
|
||||||
async def _arun(
|
|
||||||
self,
|
|
||||||
sessionId: str,
|
|
||||||
run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
|
|
||||||
) -> None:
|
|
||||||
loop = asyncio.get_running_loop()
|
|
||||||
await loop.run_in_executor(None, self._run, sessionId)
|
|
||||||
|
@ -1,8 +1,6 @@
|
|||||||
import asyncio
|
|
||||||
from typing import TYPE_CHECKING, Optional, Type
|
from typing import TYPE_CHECKING, Optional, Type
|
||||||
|
|
||||||
from langchain_core.callbacks import (
|
from langchain_core.callbacks import (
|
||||||
AsyncCallbackManagerForToolRun,
|
|
||||||
CallbackManagerForToolRun,
|
CallbackManagerForToolRun,
|
||||||
)
|
)
|
||||||
from langchain_core.pydantic_v1 import BaseModel, Field
|
from langchain_core.pydantic_v1 import BaseModel, Field
|
||||||
@ -67,14 +65,3 @@ class MultionCreateSession(BaseTool):
|
|||||||
}
|
}
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise Exception(f"An error occurred: {e}")
|
raise Exception(f"An error occurred: {e}")
|
||||||
|
|
||||||
async def _arun(
|
|
||||||
self,
|
|
||||||
query: str,
|
|
||||||
url: Optional[str] = "https://www.google.com/",
|
|
||||||
run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
|
|
||||||
) -> dict:
|
|
||||||
loop = asyncio.get_running_loop()
|
|
||||||
result = await loop.run_in_executor(None, self._run, query, url)
|
|
||||||
|
|
||||||
return result
|
|
||||||
|
@ -1,8 +1,6 @@
|
|||||||
import asyncio
|
|
||||||
from typing import TYPE_CHECKING, Optional, Type
|
from typing import TYPE_CHECKING, Optional, Type
|
||||||
|
|
||||||
from langchain_core.callbacks import (
|
from langchain_core.callbacks import (
|
||||||
AsyncCallbackManagerForToolRun,
|
|
||||||
CallbackManagerForToolRun,
|
CallbackManagerForToolRun,
|
||||||
)
|
)
|
||||||
from langchain_core.pydantic_v1 import BaseModel, Field
|
from langchain_core.pydantic_v1 import BaseModel, Field
|
||||||
@ -74,15 +72,3 @@ Note: sessionId must be received from previous Browser window creation."""
|
|||||||
return {"error": f"{e}", "Response": "retrying..."}
|
return {"error": f"{e}", "Response": "retrying..."}
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise Exception(f"An error occurred: {e}")
|
raise Exception(f"An error occurred: {e}")
|
||||||
|
|
||||||
async def _arun(
|
|
||||||
self,
|
|
||||||
sessionId: str,
|
|
||||||
query: str,
|
|
||||||
url: Optional[str] = "https://www.google.com/",
|
|
||||||
run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
|
|
||||||
) -> dict:
|
|
||||||
loop = asyncio.get_running_loop()
|
|
||||||
result = await loop.run_in_executor(None, self._run, sessionId, query, url)
|
|
||||||
|
|
||||||
return result
|
|
||||||
|
@ -1,10 +1,8 @@
|
|||||||
import asyncio
|
|
||||||
import platform
|
import platform
|
||||||
import warnings
|
import warnings
|
||||||
from typing import Any, List, Optional, Type, Union
|
from typing import Any, List, Optional, Type, Union
|
||||||
|
|
||||||
from langchain_core.callbacks import (
|
from langchain_core.callbacks import (
|
||||||
AsyncCallbackManagerForToolRun,
|
|
||||||
CallbackManagerForToolRun,
|
CallbackManagerForToolRun,
|
||||||
)
|
)
|
||||||
from langchain_core.pydantic_v1 import BaseModel, Field, root_validator
|
from langchain_core.pydantic_v1 import BaseModel, Field, root_validator
|
||||||
@ -77,13 +75,3 @@ class ShellTool(BaseTool):
|
|||||||
) -> str:
|
) -> str:
|
||||||
"""Run commands and return final output."""
|
"""Run commands and return final output."""
|
||||||
return self.process.run(commands)
|
return self.process.run(commands)
|
||||||
|
|
||||||
async def _arun(
|
|
||||||
self,
|
|
||||||
commands: Union[str, List[str]],
|
|
||||||
run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
|
|
||||||
) -> str:
|
|
||||||
"""Run commands asynchronously and return final output."""
|
|
||||||
return await asyncio.get_event_loop().run_in_executor(
|
|
||||||
None, self.process.run, commands
|
|
||||||
)
|
|
||||||
|
@ -1,13 +1,11 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import asyncio
|
|
||||||
import logging
|
import logging
|
||||||
import operator
|
import operator
|
||||||
import os
|
import os
|
||||||
import pickle
|
import pickle
|
||||||
import uuid
|
import uuid
|
||||||
import warnings
|
import warnings
|
||||||
from functools import partial
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import (
|
from typing import (
|
||||||
Any,
|
Any,
|
||||||
@ -24,6 +22,7 @@ from typing import (
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
from langchain_core.documents import Document
|
from langchain_core.documents import Document
|
||||||
from langchain_core.embeddings import Embeddings
|
from langchain_core.embeddings import Embeddings
|
||||||
|
from langchain_core.runnables.config import run_in_executor
|
||||||
from langchain_core.vectorstores import VectorStore
|
from langchain_core.vectorstores import VectorStore
|
||||||
|
|
||||||
from langchain_community.docstore.base import AddableMixin, Docstore
|
from langchain_community.docstore.base import AddableMixin, Docstore
|
||||||
@ -359,7 +358,8 @@ class FAISS(VectorStore):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
# This is a temporary workaround to make the similarity search asynchronous.
|
# This is a temporary workaround to make the similarity search asynchronous.
|
||||||
func = partial(
|
return await run_in_executor(
|
||||||
|
None,
|
||||||
self.similarity_search_with_score_by_vector,
|
self.similarity_search_with_score_by_vector,
|
||||||
embedding,
|
embedding,
|
||||||
k=k,
|
k=k,
|
||||||
@ -367,7 +367,6 @@ class FAISS(VectorStore):
|
|||||||
fetch_k=fetch_k,
|
fetch_k=fetch_k,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
return await asyncio.get_event_loop().run_in_executor(None, func)
|
|
||||||
|
|
||||||
def similarity_search_with_score(
|
def similarity_search_with_score(
|
||||||
self,
|
self,
|
||||||
@ -640,7 +639,8 @@ class FAISS(VectorStore):
|
|||||||
relevance and score for each.
|
relevance and score for each.
|
||||||
"""
|
"""
|
||||||
# This is a temporary workaround to make the similarity search asynchronous.
|
# This is a temporary workaround to make the similarity search asynchronous.
|
||||||
func = partial(
|
return await run_in_executor(
|
||||||
|
None,
|
||||||
self.max_marginal_relevance_search_with_score_by_vector,
|
self.max_marginal_relevance_search_with_score_by_vector,
|
||||||
embedding,
|
embedding,
|
||||||
k=k,
|
k=k,
|
||||||
@ -648,7 +648,6 @@ class FAISS(VectorStore):
|
|||||||
lambda_mult=lambda_mult,
|
lambda_mult=lambda_mult,
|
||||||
filter=filter,
|
filter=filter,
|
||||||
)
|
)
|
||||||
return await asyncio.get_event_loop().run_in_executor(None, func)
|
|
||||||
|
|
||||||
def max_marginal_relevance_search_by_vector(
|
def max_marginal_relevance_search_by_vector(
|
||||||
self,
|
self,
|
||||||
|
@ -1,11 +1,9 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import asyncio
|
|
||||||
import contextlib
|
import contextlib
|
||||||
import enum
|
import enum
|
||||||
import logging
|
import logging
|
||||||
import uuid
|
import uuid
|
||||||
from functools import partial
|
|
||||||
from typing import (
|
from typing import (
|
||||||
Any,
|
Any,
|
||||||
Callable,
|
Callable,
|
||||||
@ -31,6 +29,7 @@ except ImportError:
|
|||||||
|
|
||||||
from langchain_core.documents import Document
|
from langchain_core.documents import Document
|
||||||
from langchain_core.embeddings import Embeddings
|
from langchain_core.embeddings import Embeddings
|
||||||
|
from langchain_core.runnables.config import run_in_executor
|
||||||
from langchain_core.utils import get_from_dict_or_env
|
from langchain_core.utils import get_from_dict_or_env
|
||||||
from langchain_core.vectorstores import VectorStore
|
from langchain_core.vectorstores import VectorStore
|
||||||
|
|
||||||
@ -941,7 +940,8 @@ class PGVector(VectorStore):
|
|||||||
# This is a temporary workaround to make the similarity search
|
# This is a temporary workaround to make the similarity search
|
||||||
# asynchronous. The proper solution is to make the similarity search
|
# asynchronous. The proper solution is to make the similarity search
|
||||||
# asynchronous in the vector store implementations.
|
# asynchronous in the vector store implementations.
|
||||||
func = partial(
|
return await run_in_executor(
|
||||||
|
None,
|
||||||
self.max_marginal_relevance_search_by_vector,
|
self.max_marginal_relevance_search_by_vector,
|
||||||
embedding,
|
embedding,
|
||||||
k=k,
|
k=k,
|
||||||
@ -950,4 +950,3 @@ class PGVector(VectorStore):
|
|||||||
filter=filter,
|
filter=filter,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
return await asyncio.get_event_loop().run_in_executor(None, func)
|
|
||||||
|
@ -1,6 +1,5 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import asyncio
|
|
||||||
import functools
|
import functools
|
||||||
import uuid
|
import uuid
|
||||||
import warnings
|
import warnings
|
||||||
@ -25,6 +24,7 @@ from typing import (
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
from langchain_core.documents import Document
|
from langchain_core.documents import Document
|
||||||
from langchain_core.embeddings import Embeddings
|
from langchain_core.embeddings import Embeddings
|
||||||
|
from langchain_core.runnables.config import run_in_executor
|
||||||
from langchain_core.vectorstores import VectorStore
|
from langchain_core.vectorstores import VectorStore
|
||||||
|
|
||||||
from langchain_community.vectorstores.utils import maximal_marginal_relevance
|
from langchain_community.vectorstores.utils import maximal_marginal_relevance
|
||||||
@ -58,10 +58,9 @@ def sync_call_fallback(method: Callable) -> Callable:
|
|||||||
# by removing the first letter from the method name. For example,
|
# by removing the first letter from the method name. For example,
|
||||||
# if the async method is called ``aaad_texts``, the synchronous method
|
# if the async method is called ``aaad_texts``, the synchronous method
|
||||||
# will be called ``aad_texts``.
|
# will be called ``aad_texts``.
|
||||||
sync_method = functools.partial(
|
return await run_in_executor(
|
||||||
getattr(self, method.__name__[1:]), *args, **kwargs
|
None, getattr(self, method.__name__[1:]), *args, **kwargs
|
||||||
)
|
)
|
||||||
return await asyncio.get_event_loop().run_in_executor(None, sync_method)
|
|
||||||
|
|
||||||
return wrapper
|
return wrapper
|
||||||
|
|
||||||
|
@ -23,7 +23,7 @@ from langchain_core.runnables.base import (
|
|||||||
RunnableSerializable,
|
RunnableSerializable,
|
||||||
coerce_to_runnable,
|
coerce_to_runnable,
|
||||||
)
|
)
|
||||||
from langchain_core.runnables.config import RunnableConfig, patch_config
|
from langchain_core.runnables.config import RunnableConfig, ensure_config, patch_config
|
||||||
from langchain_core.runnables.utils import ConfigurableFieldSpec, Input, Output
|
from langchain_core.runnables.utils import ConfigurableFieldSpec, Input, Output
|
||||||
|
|
||||||
T = TypeVar("T")
|
T = TypeVar("T")
|
||||||
@ -186,7 +186,7 @@ class ContextGet(RunnableSerializable):
|
|||||||
]
|
]
|
||||||
|
|
||||||
def invoke(self, input: Any, config: Optional[RunnableConfig] = None) -> Any:
|
def invoke(self, input: Any, config: Optional[RunnableConfig] = None) -> Any:
|
||||||
config = config or {}
|
config = ensure_config(config)
|
||||||
configurable = config.get("configurable", {})
|
configurable = config.get("configurable", {})
|
||||||
if isinstance(self.key, list):
|
if isinstance(self.key, list):
|
||||||
return {key: configurable[id_]() for key, id_ in zip(self.key, self.ids)}
|
return {key: configurable[id_]() for key, id_ in zip(self.key, self.ids)}
|
||||||
@ -196,7 +196,7 @@ class ContextGet(RunnableSerializable):
|
|||||||
async def ainvoke(
|
async def ainvoke(
|
||||||
self, input: Any, config: Optional[RunnableConfig] = None, **kwargs: Any
|
self, input: Any, config: Optional[RunnableConfig] = None, **kwargs: Any
|
||||||
) -> Any:
|
) -> Any:
|
||||||
config = config or {}
|
config = ensure_config(config)
|
||||||
configurable = config.get("configurable", {})
|
configurable = config.get("configurable", {})
|
||||||
if isinstance(self.key, list):
|
if isinstance(self.key, list):
|
||||||
values = await asyncio.gather(*(configurable[id_]() for id_ in self.ids))
|
values = await asyncio.gather(*(configurable[id_]() for id_ in self.ids))
|
||||||
@ -281,7 +281,7 @@ class ContextSet(RunnableSerializable):
|
|||||||
]
|
]
|
||||||
|
|
||||||
def invoke(self, input: Any, config: Optional[RunnableConfig] = None) -> Any:
|
def invoke(self, input: Any, config: Optional[RunnableConfig] = None) -> Any:
|
||||||
config = config or {}
|
config = ensure_config(config)
|
||||||
configurable = config.get("configurable", {})
|
configurable = config.get("configurable", {})
|
||||||
for id_, mapper in zip(self.ids, self.keys.values()):
|
for id_, mapper in zip(self.ids, self.keys.values()):
|
||||||
if mapper is not None:
|
if mapper is not None:
|
||||||
@ -293,7 +293,7 @@ class ContextSet(RunnableSerializable):
|
|||||||
async def ainvoke(
|
async def ainvoke(
|
||||||
self, input: Any, config: Optional[RunnableConfig] = None, **kwargs: Any
|
self, input: Any, config: Optional[RunnableConfig] = None, **kwargs: Any
|
||||||
) -> Any:
|
) -> Any:
|
||||||
config = config or {}
|
config = ensure_config(config)
|
||||||
configurable = config.get("configurable", {})
|
configurable = config.get("configurable", {})
|
||||||
for id_, mapper in zip(self.ids, self.keys.values()):
|
for id_, mapper in zip(self.ids, self.keys.values()):
|
||||||
if mapper is not None:
|
if mapper is not None:
|
||||||
|
@ -4,13 +4,15 @@ import asyncio
|
|||||||
import functools
|
import functools
|
||||||
import logging
|
import logging
|
||||||
import uuid
|
import uuid
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
from concurrent.futures import ThreadPoolExecutor
|
from concurrent.futures import ThreadPoolExecutor
|
||||||
from contextlib import asynccontextmanager, contextmanager
|
from contextlib import asynccontextmanager, contextmanager
|
||||||
from contextvars import Context, copy_context
|
from contextvars import copy_context
|
||||||
from typing import (
|
from typing import (
|
||||||
TYPE_CHECKING,
|
TYPE_CHECKING,
|
||||||
Any,
|
Any,
|
||||||
AsyncGenerator,
|
AsyncGenerator,
|
||||||
|
Callable,
|
||||||
Coroutine,
|
Coroutine,
|
||||||
Dict,
|
Dict,
|
||||||
Generator,
|
Generator,
|
||||||
@ -272,25 +274,14 @@ def handle_event(
|
|||||||
# we end up in a deadlock, as we'd have gotten here from a
|
# we end up in a deadlock, as we'd have gotten here from a
|
||||||
# running coroutine, which we cannot interrupt to run this one.
|
# running coroutine, which we cannot interrupt to run this one.
|
||||||
# The solution is to create a new loop in a new thread.
|
# The solution is to create a new loop in a new thread.
|
||||||
with _executor_w_context(1) as executor:
|
with ThreadPoolExecutor(1) as executor:
|
||||||
executor.submit(_run_coros, coros).result()
|
executor.submit(
|
||||||
|
cast(Callable, copy_context().run), _run_coros, coros
|
||||||
|
).result()
|
||||||
else:
|
else:
|
||||||
_run_coros(coros)
|
_run_coros(coros)
|
||||||
|
|
||||||
|
|
||||||
def _set_context(context: Context) -> None:
|
|
||||||
for var, value in context.items():
|
|
||||||
var.set(value)
|
|
||||||
|
|
||||||
|
|
||||||
def _executor_w_context(max_workers: Optional[int] = None) -> ThreadPoolExecutor:
|
|
||||||
return ThreadPoolExecutor(
|
|
||||||
max_workers=max_workers,
|
|
||||||
initializer=_set_context,
|
|
||||||
initargs=(copy_context(),),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def _run_coros(coros: List[Coroutine[Any, Any, Any]]) -> None:
|
def _run_coros(coros: List[Coroutine[Any, Any, Any]]) -> None:
|
||||||
if hasattr(asyncio, "Runner"):
|
if hasattr(asyncio, "Runner"):
|
||||||
# Python 3.11+
|
# Python 3.11+
|
||||||
@ -315,7 +306,6 @@ def _run_coros(coros: List[Coroutine[Any, Any, Any]]) -> None:
|
|||||||
|
|
||||||
|
|
||||||
async def _ahandle_event_for_handler(
|
async def _ahandle_event_for_handler(
|
||||||
executor: ThreadPoolExecutor,
|
|
||||||
handler: BaseCallbackHandler,
|
handler: BaseCallbackHandler,
|
||||||
event_name: str,
|
event_name: str,
|
||||||
ignore_condition_name: Optional[str],
|
ignore_condition_name: Optional[str],
|
||||||
@ -332,13 +322,18 @@ async def _ahandle_event_for_handler(
|
|||||||
event(*args, **kwargs)
|
event(*args, **kwargs)
|
||||||
else:
|
else:
|
||||||
await asyncio.get_event_loop().run_in_executor(
|
await asyncio.get_event_loop().run_in_executor(
|
||||||
executor, functools.partial(event, *args, **kwargs)
|
None,
|
||||||
|
cast(
|
||||||
|
Callable,
|
||||||
|
functools.partial(
|
||||||
|
copy_context().run, event, *args, **kwargs
|
||||||
|
),
|
||||||
|
),
|
||||||
)
|
)
|
||||||
except NotImplementedError as e:
|
except NotImplementedError as e:
|
||||||
if event_name == "on_chat_model_start":
|
if event_name == "on_chat_model_start":
|
||||||
message_strings = [get_buffer_string(m) for m in args[1]]
|
message_strings = [get_buffer_string(m) for m in args[1]]
|
||||||
await _ahandle_event_for_handler(
|
await _ahandle_event_for_handler(
|
||||||
executor,
|
|
||||||
handler,
|
handler,
|
||||||
"on_llm_start",
|
"on_llm_start",
|
||||||
"ignore_llm",
|
"ignore_llm",
|
||||||
@ -380,25 +375,23 @@ async def ahandle_event(
|
|||||||
*args: The arguments to pass to the event handler
|
*args: The arguments to pass to the event handler
|
||||||
**kwargs: The keyword arguments to pass to the event handler
|
**kwargs: The keyword arguments to pass to the event handler
|
||||||
"""
|
"""
|
||||||
with _executor_w_context() as executor:
|
for handler in [h for h in handlers if h.run_inline]:
|
||||||
for handler in [h for h in handlers if h.run_inline]:
|
await _ahandle_event_for_handler(
|
||||||
await _ahandle_event_for_handler(
|
handler, event_name, ignore_condition_name, *args, **kwargs
|
||||||
executor, handler, event_name, ignore_condition_name, *args, **kwargs
|
|
||||||
)
|
|
||||||
await asyncio.gather(
|
|
||||||
*(
|
|
||||||
_ahandle_event_for_handler(
|
|
||||||
executor,
|
|
||||||
handler,
|
|
||||||
event_name,
|
|
||||||
ignore_condition_name,
|
|
||||||
*args,
|
|
||||||
**kwargs,
|
|
||||||
)
|
|
||||||
for handler in handlers
|
|
||||||
if not handler.run_inline
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
|
await asyncio.gather(
|
||||||
|
*(
|
||||||
|
_ahandle_event_for_handler(
|
||||||
|
handler,
|
||||||
|
event_name,
|
||||||
|
ignore_condition_name,
|
||||||
|
*args,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
for handler in handlers
|
||||||
|
if not handler.run_inline
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
BRM = TypeVar("BRM", bound="BaseRunManager")
|
BRM = TypeVar("BRM", bound="BaseRunManager")
|
||||||
@ -526,9 +519,17 @@ class ParentRunManager(RunManager):
|
|||||||
return manager
|
return manager
|
||||||
|
|
||||||
|
|
||||||
class AsyncRunManager(BaseRunManager):
|
class AsyncRunManager(BaseRunManager, ABC):
|
||||||
"""Async Run Manager."""
|
"""Async Run Manager."""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def get_sync(self) -> RunManager:
|
||||||
|
"""Get the equivalent sync RunManager.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
RunManager: The sync RunManager.
|
||||||
|
"""
|
||||||
|
|
||||||
async def on_text(
|
async def on_text(
|
||||||
self,
|
self,
|
||||||
text: str,
|
text: str,
|
||||||
@ -664,6 +665,23 @@ class CallbackManagerForLLMRun(RunManager, LLMManagerMixin):
|
|||||||
class AsyncCallbackManagerForLLMRun(AsyncRunManager, LLMManagerMixin):
|
class AsyncCallbackManagerForLLMRun(AsyncRunManager, LLMManagerMixin):
|
||||||
"""Async callback manager for LLM run."""
|
"""Async callback manager for LLM run."""
|
||||||
|
|
||||||
|
def get_sync(self) -> CallbackManagerForLLMRun:
|
||||||
|
"""Get the equivalent sync RunManager.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
CallbackManagerForLLMRun: The sync RunManager.
|
||||||
|
"""
|
||||||
|
return CallbackManagerForLLMRun(
|
||||||
|
run_id=self.run_id,
|
||||||
|
handlers=self.handlers,
|
||||||
|
inheritable_handlers=self.inheritable_handlers,
|
||||||
|
parent_run_id=self.parent_run_id,
|
||||||
|
tags=self.tags,
|
||||||
|
inheritable_tags=self.inheritable_tags,
|
||||||
|
metadata=self.metadata,
|
||||||
|
inheritable_metadata=self.inheritable_metadata,
|
||||||
|
)
|
||||||
|
|
||||||
async def on_llm_new_token(
|
async def on_llm_new_token(
|
||||||
self,
|
self,
|
||||||
token: str,
|
token: str,
|
||||||
@ -818,6 +836,23 @@ class CallbackManagerForChainRun(ParentRunManager, ChainManagerMixin):
|
|||||||
class AsyncCallbackManagerForChainRun(AsyncParentRunManager, ChainManagerMixin):
|
class AsyncCallbackManagerForChainRun(AsyncParentRunManager, ChainManagerMixin):
|
||||||
"""Async callback manager for chain run."""
|
"""Async callback manager for chain run."""
|
||||||
|
|
||||||
|
def get_sync(self) -> CallbackManagerForChainRun:
|
||||||
|
"""Get the equivalent sync RunManager.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
CallbackManagerForChainRun: The sync RunManager.
|
||||||
|
"""
|
||||||
|
return CallbackManagerForChainRun(
|
||||||
|
run_id=self.run_id,
|
||||||
|
handlers=self.handlers,
|
||||||
|
inheritable_handlers=self.inheritable_handlers,
|
||||||
|
parent_run_id=self.parent_run_id,
|
||||||
|
tags=self.tags,
|
||||||
|
inheritable_tags=self.inheritable_tags,
|
||||||
|
metadata=self.metadata,
|
||||||
|
inheritable_metadata=self.inheritable_metadata,
|
||||||
|
)
|
||||||
|
|
||||||
async def on_chain_end(
|
async def on_chain_end(
|
||||||
self, outputs: Union[Dict[str, Any], Any], **kwargs: Any
|
self, outputs: Union[Dict[str, Any], Any], **kwargs: Any
|
||||||
) -> None:
|
) -> None:
|
||||||
@ -948,6 +983,23 @@ class CallbackManagerForToolRun(ParentRunManager, ToolManagerMixin):
|
|||||||
class AsyncCallbackManagerForToolRun(AsyncParentRunManager, ToolManagerMixin):
|
class AsyncCallbackManagerForToolRun(AsyncParentRunManager, ToolManagerMixin):
|
||||||
"""Async callback manager for tool run."""
|
"""Async callback manager for tool run."""
|
||||||
|
|
||||||
|
def get_sync(self) -> CallbackManagerForToolRun:
|
||||||
|
"""Get the equivalent sync RunManager.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
CallbackManagerForToolRun: The sync RunManager.
|
||||||
|
"""
|
||||||
|
return CallbackManagerForToolRun(
|
||||||
|
run_id=self.run_id,
|
||||||
|
handlers=self.handlers,
|
||||||
|
inheritable_handlers=self.inheritable_handlers,
|
||||||
|
parent_run_id=self.parent_run_id,
|
||||||
|
tags=self.tags,
|
||||||
|
inheritable_tags=self.inheritable_tags,
|
||||||
|
metadata=self.metadata,
|
||||||
|
inheritable_metadata=self.inheritable_metadata,
|
||||||
|
)
|
||||||
|
|
||||||
async def on_tool_end(self, output: str, **kwargs: Any) -> None:
|
async def on_tool_end(self, output: str, **kwargs: Any) -> None:
|
||||||
"""Run when tool ends running.
|
"""Run when tool ends running.
|
||||||
|
|
||||||
@ -1031,6 +1083,23 @@ class AsyncCallbackManagerForRetrieverRun(
|
|||||||
):
|
):
|
||||||
"""Async callback manager for retriever run."""
|
"""Async callback manager for retriever run."""
|
||||||
|
|
||||||
|
def get_sync(self) -> CallbackManagerForRetrieverRun:
|
||||||
|
"""Get the equivalent sync RunManager.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
CallbackManagerForRetrieverRun: The sync RunManager.
|
||||||
|
"""
|
||||||
|
return CallbackManagerForRetrieverRun(
|
||||||
|
run_id=self.run_id,
|
||||||
|
handlers=self.handlers,
|
||||||
|
inheritable_handlers=self.inheritable_handlers,
|
||||||
|
parent_run_id=self.parent_run_id,
|
||||||
|
tags=self.tags,
|
||||||
|
inheritable_tags=self.inheritable_tags,
|
||||||
|
metadata=self.metadata,
|
||||||
|
inheritable_metadata=self.inheritable_metadata,
|
||||||
|
)
|
||||||
|
|
||||||
async def on_retriever_end(
|
async def on_retriever_end(
|
||||||
self, documents: Sequence[Document], **kwargs: Any
|
self, documents: Sequence[Document], **kwargs: Any
|
||||||
) -> None:
|
) -> None:
|
||||||
|
@ -1,10 +1,10 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import asyncio
|
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from functools import partial
|
|
||||||
from typing import TYPE_CHECKING, Any, Sequence
|
from typing import TYPE_CHECKING, Any, Sequence
|
||||||
|
|
||||||
|
from langchain_core.runnables.config import run_in_executor
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from langchain_core.documents import Document
|
from langchain_core.documents import Document
|
||||||
|
|
||||||
@ -69,6 +69,6 @@ class BaseDocumentTransformer(ABC):
|
|||||||
Returns:
|
Returns:
|
||||||
A list of transformed Documents.
|
A list of transformed Documents.
|
||||||
"""
|
"""
|
||||||
return await asyncio.get_running_loop().run_in_executor(
|
return await run_in_executor(
|
||||||
None, partial(self.transform_documents, **kwargs), documents
|
None, self.transform_documents, documents, **kwargs
|
||||||
)
|
)
|
||||||
|
@ -1,7 +1,8 @@
|
|||||||
import asyncio
|
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from typing import List
|
from typing import List
|
||||||
|
|
||||||
|
from langchain_core.runnables.config import run_in_executor
|
||||||
|
|
||||||
|
|
||||||
class Embeddings(ABC):
|
class Embeddings(ABC):
|
||||||
"""Interface for embedding models."""
|
"""Interface for embedding models."""
|
||||||
@ -16,12 +17,8 @@ class Embeddings(ABC):
|
|||||||
|
|
||||||
async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
|
async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
|
||||||
"""Asynchronous Embed search docs."""
|
"""Asynchronous Embed search docs."""
|
||||||
return await asyncio.get_running_loop().run_in_executor(
|
return await run_in_executor(None, self.embed_documents, texts)
|
||||||
None, self.embed_documents, texts
|
|
||||||
)
|
|
||||||
|
|
||||||
async def aembed_query(self, text: str) -> List[float]:
|
async def aembed_query(self, text: str) -> List[float]:
|
||||||
"""Asynchronous Embed query text."""
|
"""Asynchronous Embed query text."""
|
||||||
return await asyncio.get_running_loop().run_in_executor(
|
return await run_in_executor(None, self.embed_query, text)
|
||||||
None, self.embed_query, text
|
|
||||||
)
|
|
||||||
|
@ -4,7 +4,6 @@ import asyncio
|
|||||||
import inspect
|
import inspect
|
||||||
import warnings
|
import warnings
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from functools import partial
|
|
||||||
from typing import (
|
from typing import (
|
||||||
TYPE_CHECKING,
|
TYPE_CHECKING,
|
||||||
Any,
|
Any,
|
||||||
@ -45,6 +44,7 @@ from langchain_core.outputs import (
|
|||||||
)
|
)
|
||||||
from langchain_core.prompt_values import ChatPromptValue, PromptValue, StringPromptValue
|
from langchain_core.prompt_values import ChatPromptValue, PromptValue, StringPromptValue
|
||||||
from langchain_core.pydantic_v1 import Field, root_validator
|
from langchain_core.pydantic_v1 import Field, root_validator
|
||||||
|
from langchain_core.runnables.config import ensure_config, run_in_executor
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from langchain_core.runnables import RunnableConfig
|
from langchain_core.runnables import RunnableConfig
|
||||||
@ -158,7 +158,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
|
|||||||
stop: Optional[List[str]] = None,
|
stop: Optional[List[str]] = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> BaseMessage:
|
) -> BaseMessage:
|
||||||
config = config or {}
|
config = ensure_config(config)
|
||||||
return cast(
|
return cast(
|
||||||
ChatGeneration,
|
ChatGeneration,
|
||||||
self.generate_prompt(
|
self.generate_prompt(
|
||||||
@ -180,7 +180,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
|
|||||||
stop: Optional[List[str]] = None,
|
stop: Optional[List[str]] = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> BaseMessage:
|
) -> BaseMessage:
|
||||||
config = config or {}
|
config = ensure_config(config)
|
||||||
llm_result = await self.agenerate_prompt(
|
llm_result = await self.agenerate_prompt(
|
||||||
[self._convert_input(input)],
|
[self._convert_input(input)],
|
||||||
stop=stop,
|
stop=stop,
|
||||||
@ -206,7 +206,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
|
|||||||
BaseMessageChunk, self.invoke(input, config=config, stop=stop, **kwargs)
|
BaseMessageChunk, self.invoke(input, config=config, stop=stop, **kwargs)
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
config = config or {}
|
config = ensure_config(config)
|
||||||
messages = self._convert_input(input).to_messages()
|
messages = self._convert_input(input).to_messages()
|
||||||
params = self._get_invocation_params(stop=stop, **kwargs)
|
params = self._get_invocation_params(stop=stop, **kwargs)
|
||||||
options = {"stop": stop, **kwargs}
|
options = {"stop": stop, **kwargs}
|
||||||
@ -264,7 +264,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
|
|||||||
await self.ainvoke(input, config=config, stop=stop, **kwargs),
|
await self.ainvoke(input, config=config, stop=stop, **kwargs),
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
config = config or {}
|
config = ensure_config(config)
|
||||||
messages = self._convert_input(input).to_messages()
|
messages = self._convert_input(input).to_messages()
|
||||||
params = self._get_invocation_params(stop=stop, **kwargs)
|
params = self._get_invocation_params(stop=stop, **kwargs)
|
||||||
options = {"stop": stop, **kwargs}
|
options = {"stop": stop, **kwargs}
|
||||||
@ -605,8 +605,13 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
|
|||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> ChatResult:
|
) -> ChatResult:
|
||||||
"""Top Level call"""
|
"""Top Level call"""
|
||||||
return await asyncio.get_running_loop().run_in_executor(
|
return await run_in_executor(
|
||||||
None, partial(self._generate, **kwargs), messages, stop, run_manager
|
None,
|
||||||
|
self._generate,
|
||||||
|
messages,
|
||||||
|
stop,
|
||||||
|
run_manager.get_sync() if run_manager else None,
|
||||||
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
def _stream(
|
def _stream(
|
||||||
@ -766,7 +771,11 @@ class SimpleChatModel(BaseChatModel):
|
|||||||
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
|
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> ChatResult:
|
) -> ChatResult:
|
||||||
func = partial(
|
return await run_in_executor(
|
||||||
self._generate, messages, stop=stop, run_manager=run_manager, **kwargs
|
None,
|
||||||
|
self._generate,
|
||||||
|
messages,
|
||||||
|
stop=stop,
|
||||||
|
run_manager=run_manager.get_sync() if run_manager else None,
|
||||||
|
**kwargs,
|
||||||
)
|
)
|
||||||
return await asyncio.get_event_loop().run_in_executor(None, func)
|
|
||||||
|
@ -8,7 +8,6 @@ import json
|
|||||||
import logging
|
import logging
|
||||||
import warnings
|
import warnings
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from functools import partial
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import (
|
from typing import (
|
||||||
Any,
|
Any,
|
||||||
@ -52,7 +51,8 @@ from langchain_core.messages import AIMessage, BaseMessage, get_buffer_string
|
|||||||
from langchain_core.outputs import Generation, GenerationChunk, LLMResult, RunInfo
|
from langchain_core.outputs import Generation, GenerationChunk, LLMResult, RunInfo
|
||||||
from langchain_core.prompt_values import ChatPromptValue, PromptValue, StringPromptValue
|
from langchain_core.prompt_values import ChatPromptValue, PromptValue, StringPromptValue
|
||||||
from langchain_core.pydantic_v1 import Field, root_validator, validator
|
from langchain_core.pydantic_v1 import Field, root_validator, validator
|
||||||
from langchain_core.runnables import RunnableConfig, get_config_list
|
from langchain_core.runnables import RunnableConfig, ensure_config, get_config_list
|
||||||
|
from langchain_core.runnables.config import run_in_executor
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@ -221,7 +221,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
|
|||||||
stop: Optional[List[str]] = None,
|
stop: Optional[List[str]] = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> str:
|
) -> str:
|
||||||
config = config or {}
|
config = ensure_config(config)
|
||||||
return (
|
return (
|
||||||
self.generate_prompt(
|
self.generate_prompt(
|
||||||
[self._convert_input(input)],
|
[self._convert_input(input)],
|
||||||
@ -244,7 +244,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
|
|||||||
stop: Optional[List[str]] = None,
|
stop: Optional[List[str]] = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> str:
|
) -> str:
|
||||||
config = config or {}
|
config = ensure_config(config)
|
||||||
llm_result = await self.agenerate_prompt(
|
llm_result = await self.agenerate_prompt(
|
||||||
[self._convert_input(input)],
|
[self._convert_input(input)],
|
||||||
stop=stop,
|
stop=stop,
|
||||||
@ -362,7 +362,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
|
|||||||
yield self.invoke(input, config=config, stop=stop, **kwargs)
|
yield self.invoke(input, config=config, stop=stop, **kwargs)
|
||||||
else:
|
else:
|
||||||
prompt = self._convert_input(input).to_string()
|
prompt = self._convert_input(input).to_string()
|
||||||
config = config or {}
|
config = ensure_config(config)
|
||||||
params = self.dict()
|
params = self.dict()
|
||||||
params["stop"] = stop
|
params["stop"] = stop
|
||||||
params = {**params, **kwargs}
|
params = {**params, **kwargs}
|
||||||
@ -419,7 +419,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
|
|||||||
yield await self.ainvoke(input, config=config, stop=stop, **kwargs)
|
yield await self.ainvoke(input, config=config, stop=stop, **kwargs)
|
||||||
else:
|
else:
|
||||||
prompt = self._convert_input(input).to_string()
|
prompt = self._convert_input(input).to_string()
|
||||||
config = config or {}
|
config = ensure_config(config)
|
||||||
params = self.dict()
|
params = self.dict()
|
||||||
params["stop"] = stop
|
params["stop"] = stop
|
||||||
params = {**params, **kwargs}
|
params = {**params, **kwargs}
|
||||||
@ -483,8 +483,13 @@ class BaseLLM(BaseLanguageModel[str], ABC):
|
|||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> LLMResult:
|
) -> LLMResult:
|
||||||
"""Run the LLM on the given prompts."""
|
"""Run the LLM on the given prompts."""
|
||||||
return await asyncio.get_running_loop().run_in_executor(
|
return await run_in_executor(
|
||||||
None, partial(self._generate, **kwargs), prompts, stop, run_manager
|
None,
|
||||||
|
self._generate,
|
||||||
|
prompts,
|
||||||
|
stop,
|
||||||
|
run_manager.get_sync() if run_manager else None,
|
||||||
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
def _stream(
|
def _stream(
|
||||||
@ -1049,8 +1054,13 @@ class LLM(BaseLLM):
|
|||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> str:
|
) -> str:
|
||||||
"""Run the LLM on the given prompt and input."""
|
"""Run the LLM on the given prompt and input."""
|
||||||
return await asyncio.get_running_loop().run_in_executor(
|
return await run_in_executor(
|
||||||
None, partial(self._call, **kwargs), prompt, stop, run_manager
|
None,
|
||||||
|
self._call,
|
||||||
|
prompt,
|
||||||
|
stop,
|
||||||
|
run_manager.get_sync() if run_manager else None,
|
||||||
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
def _generate(
|
def _generate(
|
||||||
|
@ -1,7 +1,5 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import asyncio
|
|
||||||
import functools
|
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from typing import (
|
from typing import (
|
||||||
TYPE_CHECKING,
|
TYPE_CHECKING,
|
||||||
@ -20,6 +18,7 @@ from typing_extensions import get_args
|
|||||||
from langchain_core.messages import AnyMessage, BaseMessage
|
from langchain_core.messages import AnyMessage, BaseMessage
|
||||||
from langchain_core.outputs import ChatGeneration, Generation
|
from langchain_core.outputs import ChatGeneration, Generation
|
||||||
from langchain_core.runnables import RunnableConfig, RunnableSerializable
|
from langchain_core.runnables import RunnableConfig, RunnableSerializable
|
||||||
|
from langchain_core.runnables.config import run_in_executor
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from langchain_core.prompt_values import PromptValue
|
from langchain_core.prompt_values import PromptValue
|
||||||
@ -54,9 +53,7 @@ class BaseLLMOutputParser(Generic[T], ABC):
|
|||||||
Returns:
|
Returns:
|
||||||
Structured output.
|
Structured output.
|
||||||
"""
|
"""
|
||||||
return await asyncio.get_running_loop().run_in_executor(
|
return await run_in_executor(None, self.parse_result, result)
|
||||||
None, self.parse_result, result
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class BaseGenerationOutputParser(
|
class BaseGenerationOutputParser(
|
||||||
@ -247,9 +244,7 @@ class BaseOutputParser(
|
|||||||
Returns:
|
Returns:
|
||||||
Structured output.
|
Structured output.
|
||||||
"""
|
"""
|
||||||
return await asyncio.get_running_loop().run_in_executor(
|
return await run_in_executor(None, self.parse_result, result, partial=partial)
|
||||||
None, functools.partial(self.parse_result, partial=partial), result
|
|
||||||
)
|
|
||||||
|
|
||||||
async def aparse(self, text: str) -> T:
|
async def aparse(self, text: str) -> T:
|
||||||
"""Parse a single string model output into some structure.
|
"""Parse a single string model output into some structure.
|
||||||
@ -260,7 +255,7 @@ class BaseOutputParser(
|
|||||||
Returns:
|
Returns:
|
||||||
Structured output.
|
Structured output.
|
||||||
"""
|
"""
|
||||||
return await asyncio.get_running_loop().run_in_executor(None, self.parse, text)
|
return await run_in_executor(None, self.parse, text)
|
||||||
|
|
||||||
# TODO: rename 'completion' -> 'text'.
|
# TODO: rename 'completion' -> 'text'.
|
||||||
def parse_with_prompt(self, completion: str, prompt: PromptValue) -> Any:
|
def parse_with_prompt(self, completion: str, prompt: PromptValue) -> Any:
|
||||||
|
@ -1,15 +1,19 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import asyncio
|
|
||||||
import warnings
|
import warnings
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from functools import partial
|
|
||||||
from inspect import signature
|
from inspect import signature
|
||||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional
|
from typing import TYPE_CHECKING, Any, Dict, List, Optional
|
||||||
|
|
||||||
from langchain_core.documents import Document
|
from langchain_core.documents import Document
|
||||||
from langchain_core.load.dump import dumpd
|
from langchain_core.load.dump import dumpd
|
||||||
from langchain_core.runnables import Runnable, RunnableConfig, RunnableSerializable
|
from langchain_core.runnables import (
|
||||||
|
Runnable,
|
||||||
|
RunnableConfig,
|
||||||
|
RunnableSerializable,
|
||||||
|
ensure_config,
|
||||||
|
)
|
||||||
|
from langchain_core.runnables.config import run_in_executor
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from langchain_core.callbacks.manager import (
|
from langchain_core.callbacks.manager import (
|
||||||
@ -113,7 +117,7 @@ class BaseRetriever(RunnableSerializable[RetrieverInput, RetrieverOutput], ABC):
|
|||||||
def invoke(
|
def invoke(
|
||||||
self, input: str, config: Optional[RunnableConfig] = None
|
self, input: str, config: Optional[RunnableConfig] = None
|
||||||
) -> List[Document]:
|
) -> List[Document]:
|
||||||
config = config or {}
|
config = ensure_config(config)
|
||||||
return self.get_relevant_documents(
|
return self.get_relevant_documents(
|
||||||
input,
|
input,
|
||||||
callbacks=config.get("callbacks"),
|
callbacks=config.get("callbacks"),
|
||||||
@ -128,7 +132,7 @@ class BaseRetriever(RunnableSerializable[RetrieverInput, RetrieverOutput], ABC):
|
|||||||
config: Optional[RunnableConfig] = None,
|
config: Optional[RunnableConfig] = None,
|
||||||
**kwargs: Optional[Any],
|
**kwargs: Optional[Any],
|
||||||
) -> List[Document]:
|
) -> List[Document]:
|
||||||
config = config or {}
|
config = ensure_config(config)
|
||||||
return await self.aget_relevant_documents(
|
return await self.aget_relevant_documents(
|
||||||
input,
|
input,
|
||||||
callbacks=config.get("callbacks"),
|
callbacks=config.get("callbacks"),
|
||||||
@ -159,8 +163,11 @@ class BaseRetriever(RunnableSerializable[RetrieverInput, RetrieverOutput], ABC):
|
|||||||
Returns:
|
Returns:
|
||||||
List of relevant documents
|
List of relevant documents
|
||||||
"""
|
"""
|
||||||
return await asyncio.get_running_loop().run_in_executor(
|
return await run_in_executor(
|
||||||
None, partial(self._get_relevant_documents, run_manager=run_manager), query
|
None,
|
||||||
|
self._get_relevant_documents,
|
||||||
|
query,
|
||||||
|
run_manager=run_manager.get_sync(),
|
||||||
)
|
)
|
||||||
|
|
||||||
def get_relevant_documents(
|
def get_relevant_documents(
|
||||||
|
@ -27,8 +27,10 @@ from langchain_core.runnables.base import (
|
|||||||
from langchain_core.runnables.branch import RunnableBranch
|
from langchain_core.runnables.branch import RunnableBranch
|
||||||
from langchain_core.runnables.config import (
|
from langchain_core.runnables.config import (
|
||||||
RunnableConfig,
|
RunnableConfig,
|
||||||
|
ensure_config,
|
||||||
get_config_list,
|
get_config_list,
|
||||||
patch_config,
|
patch_config,
|
||||||
|
run_in_executor,
|
||||||
)
|
)
|
||||||
from langchain_core.runnables.fallbacks import RunnableWithFallbacks
|
from langchain_core.runnables.fallbacks import RunnableWithFallbacks
|
||||||
from langchain_core.runnables.passthrough import (
|
from langchain_core.runnables.passthrough import (
|
||||||
@ -42,6 +44,7 @@ from langchain_core.runnables.utils import (
|
|||||||
ConfigurableField,
|
ConfigurableField,
|
||||||
ConfigurableFieldMultiOption,
|
ConfigurableFieldMultiOption,
|
||||||
ConfigurableFieldSingleOption,
|
ConfigurableFieldSingleOption,
|
||||||
|
ConfigurableFieldSpec,
|
||||||
aadd,
|
aadd,
|
||||||
add,
|
add,
|
||||||
)
|
)
|
||||||
@ -51,6 +54,9 @@ __all__ = [
|
|||||||
"ConfigurableField",
|
"ConfigurableField",
|
||||||
"ConfigurableFieldSingleOption",
|
"ConfigurableFieldSingleOption",
|
||||||
"ConfigurableFieldMultiOption",
|
"ConfigurableFieldMultiOption",
|
||||||
|
"ConfigurableFieldSpec",
|
||||||
|
"ensure_config",
|
||||||
|
"run_in_executor",
|
||||||
"patch_config",
|
"patch_config",
|
||||||
"RouterInput",
|
"RouterInput",
|
||||||
"RouterRunnable",
|
"RouterRunnable",
|
||||||
|
@ -6,7 +6,7 @@ import threading
|
|||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from concurrent.futures import FIRST_COMPLETED, wait
|
from concurrent.futures import FIRST_COMPLETED, wait
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
from functools import partial, wraps
|
from functools import wraps
|
||||||
from itertools import groupby, tee
|
from itertools import groupby, tee
|
||||||
from operator import itemgetter
|
from operator import itemgetter
|
||||||
from typing import (
|
from typing import (
|
||||||
@ -47,6 +47,7 @@ from langchain_core.runnables.config import (
|
|||||||
get_executor_for_config,
|
get_executor_for_config,
|
||||||
merge_configs,
|
merge_configs,
|
||||||
patch_config,
|
patch_config,
|
||||||
|
run_in_executor,
|
||||||
)
|
)
|
||||||
from langchain_core.runnables.graph import Graph
|
from langchain_core.runnables.graph import Graph
|
||||||
from langchain_core.runnables.utils import (
|
from langchain_core.runnables.utils import (
|
||||||
@ -472,10 +473,7 @@ class Runnable(Generic[Input, Output], ABC):
|
|||||||
|
|
||||||
Subclasses should override this method if they can run asynchronously.
|
Subclasses should override this method if they can run asynchronously.
|
||||||
"""
|
"""
|
||||||
with get_executor_for_config(config) as executor:
|
return await run_in_executor(config, self.invoke, input, config, **kwargs)
|
||||||
return await asyncio.get_running_loop().run_in_executor(
|
|
||||||
executor, partial(self.invoke, **kwargs), input, config
|
|
||||||
)
|
|
||||||
|
|
||||||
def batch(
|
def batch(
|
||||||
self,
|
self,
|
||||||
@ -665,7 +663,7 @@ class Runnable(Generic[Input, Output], ABC):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Assign the stream handler to the config
|
# Assign the stream handler to the config
|
||||||
config = config or {}
|
config = ensure_config(config)
|
||||||
callbacks = config.get("callbacks")
|
callbacks = config.get("callbacks")
|
||||||
if callbacks is None:
|
if callbacks is None:
|
||||||
config["callbacks"] = [stream]
|
config["callbacks"] = [stream]
|
||||||
@ -2883,10 +2881,7 @@ class RunnableLambda(Runnable[Input, Output]):
|
|||||||
|
|
||||||
@wraps(self.func)
|
@wraps(self.func)
|
||||||
async def f(*args, **kwargs): # type: ignore[no-untyped-def]
|
async def f(*args, **kwargs): # type: ignore[no-untyped-def]
|
||||||
with get_executor_for_config(config) as executor:
|
return await run_in_executor(config, self.func, *args, **kwargs)
|
||||||
return await asyncio.get_running_loop().run_in_executor(
|
|
||||||
executor, partial(self.func, **kwargs), *args
|
|
||||||
)
|
|
||||||
|
|
||||||
afunc = f
|
afunc = f
|
||||||
|
|
||||||
@ -2913,7 +2908,7 @@ class RunnableLambda(Runnable[Input, Output]):
|
|||||||
def _config(
|
def _config(
|
||||||
self, config: Optional[RunnableConfig], callable: Callable[..., Any]
|
self, config: Optional[RunnableConfig], callable: Callable[..., Any]
|
||||||
) -> RunnableConfig:
|
) -> RunnableConfig:
|
||||||
config = config or {}
|
config = ensure_config(config)
|
||||||
|
|
||||||
if config.get("run_name") is None:
|
if config.get("run_name") is None:
|
||||||
try:
|
try:
|
||||||
@ -3052,9 +3047,7 @@ class RunnableLambda(Runnable[Input, Output]):
|
|||||||
|
|
||||||
@wraps(self.func)
|
@wraps(self.func)
|
||||||
async def f(*args, **kwargs): # type: ignore[no-untyped-def]
|
async def f(*args, **kwargs): # type: ignore[no-untyped-def]
|
||||||
return await asyncio.get_running_loop().run_in_executor(
|
return await run_in_executor(config, self.func, *args, **kwargs)
|
||||||
None, partial(self.func, **kwargs), *args
|
|
||||||
)
|
|
||||||
|
|
||||||
afunc = f
|
afunc = f
|
||||||
|
|
||||||
|
@ -1,8 +1,10 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from concurrent.futures import Executor, ThreadPoolExecutor
|
import asyncio
|
||||||
|
from concurrent.futures import Executor, Future, ThreadPoolExecutor
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
from contextvars import Context, copy_context
|
from contextvars import ContextVar, copy_context
|
||||||
|
from functools import partial
|
||||||
from typing import (
|
from typing import (
|
||||||
TYPE_CHECKING,
|
TYPE_CHECKING,
|
||||||
Any,
|
Any,
|
||||||
@ -10,13 +12,16 @@ from typing import (
|
|||||||
Callable,
|
Callable,
|
||||||
Dict,
|
Dict,
|
||||||
Generator,
|
Generator,
|
||||||
|
Iterable,
|
||||||
|
Iterator,
|
||||||
List,
|
List,
|
||||||
Optional,
|
Optional,
|
||||||
|
TypeVar,
|
||||||
Union,
|
Union,
|
||||||
cast,
|
cast,
|
||||||
)
|
)
|
||||||
|
|
||||||
from typing_extensions import TypedDict
|
from typing_extensions import ParamSpec, TypedDict
|
||||||
|
|
||||||
from langchain_core.runnables.utils import (
|
from langchain_core.runnables.utils import (
|
||||||
Input,
|
Input,
|
||||||
@ -91,6 +96,11 @@ class RunnableConfig(TypedDict, total=False):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
var_child_runnable_config = ContextVar(
|
||||||
|
"child_runnable_config", default=RunnableConfig()
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def ensure_config(config: Optional[RunnableConfig] = None) -> RunnableConfig:
|
def ensure_config(config: Optional[RunnableConfig] = None) -> RunnableConfig:
|
||||||
"""Ensure that a config is a dict with all keys present.
|
"""Ensure that a config is a dict with all keys present.
|
||||||
|
|
||||||
@ -107,6 +117,10 @@ def ensure_config(config: Optional[RunnableConfig] = None) -> RunnableConfig:
|
|||||||
callbacks=None,
|
callbacks=None,
|
||||||
recursion_limit=25,
|
recursion_limit=25,
|
||||||
)
|
)
|
||||||
|
if var_config := var_child_runnable_config.get():
|
||||||
|
empty.update(
|
||||||
|
cast(RunnableConfig, {k: v for k, v in var_config.items() if v is not None})
|
||||||
|
)
|
||||||
if config is not None:
|
if config is not None:
|
||||||
empty.update(
|
empty.update(
|
||||||
cast(RunnableConfig, {k: v for k, v in config.items() if v is not None})
|
cast(RunnableConfig, {k: v for k, v in config.items() if v is not None})
|
||||||
@ -388,9 +402,51 @@ def get_async_callback_manager_for_config(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def _set_context(context: Context) -> None:
|
P = ParamSpec("P")
|
||||||
for var, value in context.items():
|
T = TypeVar("T")
|
||||||
var.set(value)
|
|
||||||
|
|
||||||
|
class ContextThreadPoolExecutor(ThreadPoolExecutor):
|
||||||
|
"""ThreadPoolExecutor that copies the context to the child thread."""
|
||||||
|
|
||||||
|
def submit( # type: ignore[override]
|
||||||
|
self,
|
||||||
|
func: Callable[P, T],
|
||||||
|
*args: P.args,
|
||||||
|
**kwargs: P.kwargs,
|
||||||
|
) -> Future[T]:
|
||||||
|
"""Submit a function to the executor.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
func (Callable[..., T]): The function to submit.
|
||||||
|
*args (Any): The positional arguments to the function.
|
||||||
|
**kwargs (Any): The keyword arguments to the function.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Future[T]: The future for the function.
|
||||||
|
"""
|
||||||
|
return super().submit(
|
||||||
|
cast(Callable[..., T], partial(copy_context().run, func, *args, **kwargs))
|
||||||
|
)
|
||||||
|
|
||||||
|
def map(
|
||||||
|
self,
|
||||||
|
fn: Callable[..., T],
|
||||||
|
*iterables: Iterable[Any],
|
||||||
|
timeout: float | None = None,
|
||||||
|
chunksize: int = 1,
|
||||||
|
) -> Iterator[T]:
|
||||||
|
contexts = [copy_context() for _ in range(len(iterables[0]))] # type: ignore[arg-type]
|
||||||
|
|
||||||
|
def _wrapped_fn(*args: Any) -> T:
|
||||||
|
return contexts.pop().run(fn, *args)
|
||||||
|
|
||||||
|
return super().map(
|
||||||
|
_wrapped_fn,
|
||||||
|
*iterables,
|
||||||
|
timeout=timeout,
|
||||||
|
chunksize=chunksize,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@contextmanager
|
@contextmanager
|
||||||
@ -406,9 +462,36 @@ def get_executor_for_config(
|
|||||||
Generator[Executor, None, None]: The executor.
|
Generator[Executor, None, None]: The executor.
|
||||||
"""
|
"""
|
||||||
config = config or {}
|
config = config or {}
|
||||||
with ThreadPoolExecutor(
|
with ContextThreadPoolExecutor(
|
||||||
max_workers=config.get("max_concurrency"),
|
max_workers=config.get("max_concurrency")
|
||||||
initializer=_set_context,
|
|
||||||
initargs=(copy_context(),),
|
|
||||||
) as executor:
|
) as executor:
|
||||||
yield executor
|
yield executor
|
||||||
|
|
||||||
|
|
||||||
|
async def run_in_executor(
|
||||||
|
executor_or_config: Optional[Union[Executor, RunnableConfig]],
|
||||||
|
func: Callable[P, T],
|
||||||
|
*args: P.args,
|
||||||
|
**kwargs: P.kwargs,
|
||||||
|
) -> T:
|
||||||
|
"""Run a function in an executor.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
executor (Executor): The executor.
|
||||||
|
func (Callable[P, Output]): The function.
|
||||||
|
*args (Any): The positional arguments to the function.
|
||||||
|
**kwargs (Any): The keyword arguments to the function.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Output: The output of the function.
|
||||||
|
"""
|
||||||
|
if executor_or_config is None or isinstance(executor_or_config, dict):
|
||||||
|
# Use default executor with context copied from current context
|
||||||
|
return await asyncio.get_running_loop().run_in_executor(
|
||||||
|
None,
|
||||||
|
cast(Callable[..., T], partial(copy_context().run, func, *args, **kwargs)),
|
||||||
|
)
|
||||||
|
|
||||||
|
return await asyncio.get_running_loop().run_in_executor(
|
||||||
|
executor_or_config, partial(func, **kwargs), *args
|
||||||
|
)
|
||||||
|
@ -23,6 +23,7 @@ from langchain_core.pydantic_v1 import BaseModel
|
|||||||
from langchain_core.runnables.base import Runnable, RunnableSerializable
|
from langchain_core.runnables.base import Runnable, RunnableSerializable
|
||||||
from langchain_core.runnables.config import (
|
from langchain_core.runnables.config import (
|
||||||
RunnableConfig,
|
RunnableConfig,
|
||||||
|
ensure_config,
|
||||||
get_config_list,
|
get_config_list,
|
||||||
get_executor_for_config,
|
get_executor_for_config,
|
||||||
)
|
)
|
||||||
@ -259,7 +260,7 @@ class RunnableConfigurableFields(DynamicRunnable[Input, Output]):
|
|||||||
def _prepare(
|
def _prepare(
|
||||||
self, config: Optional[RunnableConfig] = None
|
self, config: Optional[RunnableConfig] = None
|
||||||
) -> Tuple[Runnable[Input, Output], RunnableConfig]:
|
) -> Tuple[Runnable[Input, Output], RunnableConfig]:
|
||||||
config = config or {}
|
config = ensure_config(config)
|
||||||
specs_by_id = {spec.id: (key, spec) for key, spec in self.fields.items()}
|
specs_by_id = {spec.id: (key, spec) for key, spec in self.fields.items()}
|
||||||
configurable_fields = {
|
configurable_fields = {
|
||||||
specs_by_id[k][0]: v
|
specs_by_id[k][0]: v
|
||||||
@ -392,7 +393,7 @@ class RunnableConfigurableAlternatives(DynamicRunnable[Input, Output]):
|
|||||||
def _prepare(
|
def _prepare(
|
||||||
self, config: Optional[RunnableConfig] = None
|
self, config: Optional[RunnableConfig] = None
|
||||||
) -> Tuple[Runnable[Input, Output], RunnableConfig]:
|
) -> Tuple[Runnable[Input, Output], RunnableConfig]:
|
||||||
config = config or {}
|
config = ensure_config(config)
|
||||||
which = config.get("configurable", {}).get(self.which.id, self.default_key)
|
which = config.get("configurable", {}).get(self.which.id, self.default_key)
|
||||||
# remap configurable keys for the chosen alternative
|
# remap configurable keys for the chosen alternative
|
||||||
if self.prefix_keys:
|
if self.prefix_keys:
|
||||||
|
@ -1,6 +1,5 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import asyncio
|
|
||||||
import inspect
|
import inspect
|
||||||
from typing import (
|
from typing import (
|
||||||
TYPE_CHECKING,
|
TYPE_CHECKING,
|
||||||
@ -18,6 +17,7 @@ from langchain_core.chat_history import BaseChatMessageHistory
|
|||||||
from langchain_core.load import load
|
from langchain_core.load import load
|
||||||
from langchain_core.pydantic_v1 import BaseModel, create_model
|
from langchain_core.pydantic_v1 import BaseModel, create_model
|
||||||
from langchain_core.runnables.base import Runnable, RunnableBindingBase, RunnableLambda
|
from langchain_core.runnables.base import Runnable, RunnableBindingBase, RunnableLambda
|
||||||
|
from langchain_core.runnables.config import run_in_executor
|
||||||
from langchain_core.runnables.passthrough import RunnablePassthrough
|
from langchain_core.runnables.passthrough import RunnablePassthrough
|
||||||
from langchain_core.runnables.utils import (
|
from langchain_core.runnables.utils import (
|
||||||
ConfigurableFieldSpec,
|
ConfigurableFieldSpec,
|
||||||
@ -331,9 +331,7 @@ class RunnableWithMessageHistory(RunnableBindingBase):
|
|||||||
async def _aenter_history(
|
async def _aenter_history(
|
||||||
self, input: Dict[str, Any], config: RunnableConfig
|
self, input: Dict[str, Any], config: RunnableConfig
|
||||||
) -> List[BaseMessage]:
|
) -> List[BaseMessage]:
|
||||||
return await asyncio.get_running_loop().run_in_executor(
|
return await run_in_executor(config, self._enter_history, input, config)
|
||||||
None, self._enter_history, input, config
|
|
||||||
)
|
|
||||||
|
|
||||||
def _exit_history(self, run: Run, config: RunnableConfig) -> None:
|
def _exit_history(self, run: Run, config: RunnableConfig) -> None:
|
||||||
hist = config["configurable"]["message_history"]
|
hist = config["configurable"]["message_history"]
|
||||||
|
@ -31,6 +31,7 @@ from langchain_core.runnables.config import (
|
|||||||
RunnableConfig,
|
RunnableConfig,
|
||||||
acall_func_with_variable_args,
|
acall_func_with_variable_args,
|
||||||
call_func_with_variable_args,
|
call_func_with_variable_args,
|
||||||
|
ensure_config,
|
||||||
get_executor_for_config,
|
get_executor_for_config,
|
||||||
patch_config,
|
patch_config,
|
||||||
)
|
)
|
||||||
@ -206,7 +207,9 @@ class RunnablePassthrough(RunnableSerializable[Other, Other]):
|
|||||||
self, input: Other, config: Optional[RunnableConfig] = None, **kwargs: Any
|
self, input: Other, config: Optional[RunnableConfig] = None, **kwargs: Any
|
||||||
) -> Other:
|
) -> Other:
|
||||||
if self.func is not None:
|
if self.func is not None:
|
||||||
call_func_with_variable_args(self.func, input, config or {}, **kwargs)
|
call_func_with_variable_args(
|
||||||
|
self.func, input, ensure_config(config), **kwargs
|
||||||
|
)
|
||||||
return self._call_with_config(identity, input, config)
|
return self._call_with_config(identity, input, config)
|
||||||
|
|
||||||
async def ainvoke(
|
async def ainvoke(
|
||||||
@ -217,10 +220,12 @@ class RunnablePassthrough(RunnableSerializable[Other, Other]):
|
|||||||
) -> Other:
|
) -> Other:
|
||||||
if self.afunc is not None:
|
if self.afunc is not None:
|
||||||
await acall_func_with_variable_args(
|
await acall_func_with_variable_args(
|
||||||
self.afunc, input, config or {}, **kwargs
|
self.afunc, input, ensure_config(config), **kwargs
|
||||||
)
|
)
|
||||||
elif self.func is not None:
|
elif self.func is not None:
|
||||||
call_func_with_variable_args(self.func, input, config or {}, **kwargs)
|
call_func_with_variable_args(
|
||||||
|
self.func, input, ensure_config(config), **kwargs
|
||||||
|
)
|
||||||
return await self._acall_with_config(aidentity, input, config)
|
return await self._acall_with_config(aidentity, input, config)
|
||||||
|
|
||||||
def transform(
|
def transform(
|
||||||
@ -243,7 +248,9 @@ class RunnablePassthrough(RunnableSerializable[Other, Other]):
|
|||||||
final = final + chunk
|
final = final + chunk
|
||||||
|
|
||||||
if final is not None:
|
if final is not None:
|
||||||
call_func_with_variable_args(self.func, final, config or {}, **kwargs)
|
call_func_with_variable_args(
|
||||||
|
self.func, final, ensure_config(config), **kwargs
|
||||||
|
)
|
||||||
|
|
||||||
async def atransform(
|
async def atransform(
|
||||||
self,
|
self,
|
||||||
@ -269,7 +276,7 @@ class RunnablePassthrough(RunnableSerializable[Other, Other]):
|
|||||||
final = final + chunk
|
final = final + chunk
|
||||||
|
|
||||||
if final is not None:
|
if final is not None:
|
||||||
config = config or {}
|
config = ensure_config(config)
|
||||||
if self.afunc is not None:
|
if self.afunc is not None:
|
||||||
await acall_func_with_variable_args(
|
await acall_func_with_variable_args(
|
||||||
self.afunc, final, config, **kwargs
|
self.afunc, final, config, **kwargs
|
||||||
@ -458,7 +465,7 @@ class RunnableAssign(RunnableSerializable[Dict[str, Any], Dict[str, Any]]):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# get executor to start map output stream in background
|
# get executor to start map output stream in background
|
||||||
with get_executor_for_config(config or {}) as executor:
|
with get_executor_for_config(config) as executor:
|
||||||
# start map output stream
|
# start map output stream
|
||||||
first_map_chunk_future = executor.submit(
|
first_map_chunk_future = executor.submit(
|
||||||
next,
|
next,
|
||||||
|
@ -1,11 +1,9 @@
|
|||||||
"""Base implementation for tools or skills."""
|
"""Base implementation for tools or skills."""
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import asyncio
|
|
||||||
import inspect
|
import inspect
|
||||||
import warnings
|
import warnings
|
||||||
from abc import abstractmethod
|
from abc import abstractmethod
|
||||||
from functools import partial
|
|
||||||
from inspect import signature
|
from inspect import signature
|
||||||
from typing import Any, Awaitable, Callable, Dict, List, Optional, Tuple, Type, Union
|
from typing import Any, Awaitable, Callable, Dict, List, Optional, Tuple, Type, Union
|
||||||
|
|
||||||
@ -26,7 +24,13 @@ from langchain_core.pydantic_v1 import (
|
|||||||
root_validator,
|
root_validator,
|
||||||
validate_arguments,
|
validate_arguments,
|
||||||
)
|
)
|
||||||
from langchain_core.runnables import Runnable, RunnableConfig, RunnableSerializable
|
from langchain_core.runnables import (
|
||||||
|
Runnable,
|
||||||
|
RunnableConfig,
|
||||||
|
RunnableSerializable,
|
||||||
|
ensure_config,
|
||||||
|
)
|
||||||
|
from langchain_core.runnables.config import run_in_executor
|
||||||
|
|
||||||
|
|
||||||
class SchemaAnnotationError(TypeError):
|
class SchemaAnnotationError(TypeError):
|
||||||
@ -202,7 +206,7 @@ class ChildTool(BaseTool):
|
|||||||
config: Optional[RunnableConfig] = None,
|
config: Optional[RunnableConfig] = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> Any:
|
) -> Any:
|
||||||
config = config or {}
|
config = ensure_config(config)
|
||||||
return self.run(
|
return self.run(
|
||||||
input,
|
input,
|
||||||
callbacks=config.get("callbacks"),
|
callbacks=config.get("callbacks"),
|
||||||
@ -218,7 +222,7 @@ class ChildTool(BaseTool):
|
|||||||
config: Optional[RunnableConfig] = None,
|
config: Optional[RunnableConfig] = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> Any:
|
) -> Any:
|
||||||
config = config or {}
|
config = ensure_config(config)
|
||||||
return await self.arun(
|
return await self.arun(
|
||||||
input,
|
input,
|
||||||
callbacks=config.get("callbacks"),
|
callbacks=config.get("callbacks"),
|
||||||
@ -280,11 +284,7 @@ class ChildTool(BaseTool):
|
|||||||
Add run_manager: Optional[AsyncCallbackManagerForToolRun] = None
|
Add run_manager: Optional[AsyncCallbackManagerForToolRun] = None
|
||||||
to child implementations to enable tracing,
|
to child implementations to enable tracing,
|
||||||
"""
|
"""
|
||||||
return await asyncio.get_running_loop().run_in_executor(
|
return await run_in_executor(None, self._run, *args, **kwargs)
|
||||||
None,
|
|
||||||
partial(self._run, **kwargs),
|
|
||||||
*args,
|
|
||||||
)
|
|
||||||
|
|
||||||
def _to_args_and_kwargs(self, tool_input: Union[str, Dict]) -> Tuple[Tuple, Dict]:
|
def _to_args_and_kwargs(self, tool_input: Union[str, Dict]) -> Tuple[Tuple, Dict]:
|
||||||
# For backwards compatibility, if run_input is a string,
|
# For backwards compatibility, if run_input is a string,
|
||||||
@ -468,9 +468,7 @@ class Tool(BaseTool):
|
|||||||
) -> Any:
|
) -> Any:
|
||||||
if not self.coroutine:
|
if not self.coroutine:
|
||||||
# If the tool does not implement async, fall back to default implementation
|
# If the tool does not implement async, fall back to default implementation
|
||||||
return await asyncio.get_running_loop().run_in_executor(
|
return await run_in_executor(config, self.invoke, input, config, **kwargs)
|
||||||
None, partial(self.invoke, input, config, **kwargs)
|
|
||||||
)
|
|
||||||
|
|
||||||
return await super().ainvoke(input, config, **kwargs)
|
return await super().ainvoke(input, config, **kwargs)
|
||||||
|
|
||||||
@ -538,8 +536,12 @@ class Tool(BaseTool):
|
|||||||
else await self.coroutine(*args, **kwargs)
|
else await self.coroutine(*args, **kwargs)
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
return await asyncio.get_running_loop().run_in_executor(
|
return await run_in_executor(
|
||||||
None, partial(self._run, run_manager=run_manager, **kwargs), *args
|
None,
|
||||||
|
self._run,
|
||||||
|
run_manager=run_manager.get_sync() if run_manager else None,
|
||||||
|
*args,
|
||||||
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
# TODO: this is for backwards compatibility, remove in future
|
# TODO: this is for backwards compatibility, remove in future
|
||||||
@ -599,9 +601,7 @@ class StructuredTool(BaseTool):
|
|||||||
) -> Any:
|
) -> Any:
|
||||||
if not self.coroutine:
|
if not self.coroutine:
|
||||||
# If the tool does not implement async, fall back to default implementation
|
# If the tool does not implement async, fall back to default implementation
|
||||||
return await asyncio.get_running_loop().run_in_executor(
|
return await run_in_executor(config, self.invoke, input, config, **kwargs)
|
||||||
None, partial(self.invoke, input, config, **kwargs)
|
|
||||||
)
|
|
||||||
|
|
||||||
return await super().ainvoke(input, config, **kwargs)
|
return await super().ainvoke(input, config, **kwargs)
|
||||||
|
|
||||||
@ -652,10 +652,12 @@ class StructuredTool(BaseTool):
|
|||||||
if new_argument_supported
|
if new_argument_supported
|
||||||
else await self.coroutine(*args, **kwargs)
|
else await self.coroutine(*args, **kwargs)
|
||||||
)
|
)
|
||||||
return await asyncio.get_running_loop().run_in_executor(
|
return await run_in_executor(
|
||||||
None,
|
None,
|
||||||
partial(self._run, run_manager=run_manager, **kwargs),
|
self._run,
|
||||||
|
run_manager=run_manager.get_sync() if run_manager else None,
|
||||||
*args,
|
*args,
|
||||||
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
@ -1,11 +1,9 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import asyncio
|
|
||||||
import logging
|
import logging
|
||||||
import math
|
import math
|
||||||
import warnings
|
import warnings
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from functools import partial
|
|
||||||
from typing import (
|
from typing import (
|
||||||
TYPE_CHECKING,
|
TYPE_CHECKING,
|
||||||
Any,
|
Any,
|
||||||
@ -24,6 +22,7 @@ from typing import (
|
|||||||
from langchain_core.embeddings import Embeddings
|
from langchain_core.embeddings import Embeddings
|
||||||
from langchain_core.pydantic_v1 import Field, root_validator
|
from langchain_core.pydantic_v1 import Field, root_validator
|
||||||
from langchain_core.retrievers import BaseRetriever
|
from langchain_core.retrievers import BaseRetriever
|
||||||
|
from langchain_core.runnables.config import run_in_executor
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from langchain_core.callbacks.manager import (
|
from langchain_core.callbacks.manager import (
|
||||||
@ -103,9 +102,7 @@ class VectorStore(ABC):
|
|||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> List[str]:
|
) -> List[str]:
|
||||||
"""Run more texts through the embeddings and add to the vectorstore."""
|
"""Run more texts through the embeddings and add to the vectorstore."""
|
||||||
return await asyncio.get_running_loop().run_in_executor(
|
return await run_in_executor(None, self.add_texts, texts, metadatas, **kwargs)
|
||||||
None, partial(self.add_texts, **kwargs), texts, metadatas
|
|
||||||
)
|
|
||||||
|
|
||||||
def add_documents(self, documents: List[Document], **kwargs: Any) -> List[str]:
|
def add_documents(self, documents: List[Document], **kwargs: Any) -> List[str]:
|
||||||
"""Run more documents through the embeddings and add to the vectorstore.
|
"""Run more documents through the embeddings and add to the vectorstore.
|
||||||
@ -224,8 +221,9 @@ class VectorStore(ABC):
|
|||||||
# This is a temporary workaround to make the similarity search
|
# This is a temporary workaround to make the similarity search
|
||||||
# asynchronous. The proper solution is to make the similarity search
|
# asynchronous. The proper solution is to make the similarity search
|
||||||
# asynchronous in the vector store implementations.
|
# asynchronous in the vector store implementations.
|
||||||
func = partial(self.similarity_search_with_score, *args, **kwargs)
|
return await run_in_executor(
|
||||||
return await asyncio.get_event_loop().run_in_executor(None, func)
|
None, self.similarity_search_with_score, *args, **kwargs
|
||||||
|
)
|
||||||
|
|
||||||
def _similarity_search_with_relevance_scores(
|
def _similarity_search_with_relevance_scores(
|
||||||
self,
|
self,
|
||||||
@ -383,8 +381,7 @@ class VectorStore(ABC):
|
|||||||
# This is a temporary workaround to make the similarity search
|
# This is a temporary workaround to make the similarity search
|
||||||
# asynchronous. The proper solution is to make the similarity search
|
# asynchronous. The proper solution is to make the similarity search
|
||||||
# asynchronous in the vector store implementations.
|
# asynchronous in the vector store implementations.
|
||||||
func = partial(self.similarity_search, query, k=k, **kwargs)
|
return await run_in_executor(None, self.similarity_search, query, k=k, **kwargs)
|
||||||
return await asyncio.get_event_loop().run_in_executor(None, func)
|
|
||||||
|
|
||||||
def similarity_search_by_vector(
|
def similarity_search_by_vector(
|
||||||
self, embedding: List[float], k: int = 4, **kwargs: Any
|
self, embedding: List[float], k: int = 4, **kwargs: Any
|
||||||
@ -408,8 +405,9 @@ class VectorStore(ABC):
|
|||||||
# This is a temporary workaround to make the similarity search
|
# This is a temporary workaround to make the similarity search
|
||||||
# asynchronous. The proper solution is to make the similarity search
|
# asynchronous. The proper solution is to make the similarity search
|
||||||
# asynchronous in the vector store implementations.
|
# asynchronous in the vector store implementations.
|
||||||
func = partial(self.similarity_search_by_vector, embedding, k=k, **kwargs)
|
return await run_in_executor(
|
||||||
return await asyncio.get_event_loop().run_in_executor(None, func)
|
None, self.similarity_search_by_vector, embedding, k=k, **kwargs
|
||||||
|
)
|
||||||
|
|
||||||
def max_marginal_relevance_search(
|
def max_marginal_relevance_search(
|
||||||
self,
|
self,
|
||||||
@ -450,7 +448,8 @@ class VectorStore(ABC):
|
|||||||
# This is a temporary workaround to make the similarity search
|
# This is a temporary workaround to make the similarity search
|
||||||
# asynchronous. The proper solution is to make the similarity search
|
# asynchronous. The proper solution is to make the similarity search
|
||||||
# asynchronous in the vector store implementations.
|
# asynchronous in the vector store implementations.
|
||||||
func = partial(
|
return await run_in_executor(
|
||||||
|
None,
|
||||||
self.max_marginal_relevance_search,
|
self.max_marginal_relevance_search,
|
||||||
query,
|
query,
|
||||||
k=k,
|
k=k,
|
||||||
@ -458,7 +457,6 @@ class VectorStore(ABC):
|
|||||||
lambda_mult=lambda_mult,
|
lambda_mult=lambda_mult,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
return await asyncio.get_event_loop().run_in_executor(None, func)
|
|
||||||
|
|
||||||
def max_marginal_relevance_search_by_vector(
|
def max_marginal_relevance_search_by_vector(
|
||||||
self,
|
self,
|
||||||
@ -541,8 +539,8 @@ class VectorStore(ABC):
|
|||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> VST:
|
) -> VST:
|
||||||
"""Return VectorStore initialized from texts and embeddings."""
|
"""Return VectorStore initialized from texts and embeddings."""
|
||||||
return await asyncio.get_running_loop().run_in_executor(
|
return await run_in_executor(
|
||||||
None, partial(cls.from_texts, **kwargs), texts, embedding, metadatas
|
None, cls.from_texts, texts, embedding, metadatas, **kwargs
|
||||||
)
|
)
|
||||||
|
|
||||||
def _get_retriever_tags(self) -> List[str]:
|
def _get_retriever_tags(self) -> List[str]:
|
||||||
|
@ -5,6 +5,9 @@ EXPECTED_ALL = [
|
|||||||
"ConfigurableField",
|
"ConfigurableField",
|
||||||
"ConfigurableFieldSingleOption",
|
"ConfigurableFieldSingleOption",
|
||||||
"ConfigurableFieldMultiOption",
|
"ConfigurableFieldMultiOption",
|
||||||
|
"ConfigurableFieldSpec",
|
||||||
|
"ensure_config",
|
||||||
|
"run_in_executor",
|
||||||
"patch_config",
|
"patch_config",
|
||||||
"RouterInput",
|
"RouterInput",
|
||||||
"RouterRunnable",
|
"RouterRunnable",
|
||||||
|
@ -1,7 +1,6 @@
|
|||||||
"""A tool for running python code in a REPL."""
|
"""A tool for running python code in a REPL."""
|
||||||
|
|
||||||
import ast
|
import ast
|
||||||
import asyncio
|
|
||||||
import re
|
import re
|
||||||
import sys
|
import sys
|
||||||
from contextlib import redirect_stdout
|
from contextlib import redirect_stdout
|
||||||
@ -14,6 +13,7 @@ from langchain.callbacks.manager import (
|
|||||||
)
|
)
|
||||||
from langchain.pydantic_v1 import BaseModel, Field, root_validator
|
from langchain.pydantic_v1 import BaseModel, Field, root_validator
|
||||||
from langchain.tools.base import BaseTool
|
from langchain.tools.base import BaseTool
|
||||||
|
from langchain_core.runnables.config import run_in_executor
|
||||||
|
|
||||||
from langchain_experimental.utilities.python import PythonREPL
|
from langchain_experimental.utilities.python import PythonREPL
|
||||||
|
|
||||||
@ -72,10 +72,7 @@ class PythonREPLTool(BaseTool):
|
|||||||
if self.sanitize_input:
|
if self.sanitize_input:
|
||||||
query = sanitize_input(query)
|
query = sanitize_input(query)
|
||||||
|
|
||||||
loop = asyncio.get_running_loop()
|
return await run_in_executor(None, self.run, query)
|
||||||
result = await loop.run_in_executor(None, self.run, query)
|
|
||||||
|
|
||||||
return result
|
|
||||||
|
|
||||||
|
|
||||||
class PythonInputs(BaseModel):
|
class PythonInputs(BaseModel):
|
||||||
@ -144,7 +141,4 @@ class PythonAstREPLTool(BaseTool):
|
|||||||
) -> Any:
|
) -> Any:
|
||||||
"""Use the tool asynchronously."""
|
"""Use the tool asynchronously."""
|
||||||
|
|
||||||
loop = asyncio.get_running_loop()
|
return await run_in_executor(None, self._run, query)
|
||||||
result = await loop.run_in_executor(None, self._run, query)
|
|
||||||
|
|
||||||
return result
|
|
||||||
|
@ -30,7 +30,7 @@ from langchain_core.prompts import BasePromptTemplate
|
|||||||
from langchain_core.prompts.few_shot import FewShotPromptTemplate
|
from langchain_core.prompts.few_shot import FewShotPromptTemplate
|
||||||
from langchain_core.prompts.prompt import PromptTemplate
|
from langchain_core.prompts.prompt import PromptTemplate
|
||||||
from langchain_core.pydantic_v1 import BaseModel, root_validator
|
from langchain_core.pydantic_v1 import BaseModel, root_validator
|
||||||
from langchain_core.runnables import Runnable, RunnableConfig
|
from langchain_core.runnables import Runnable, RunnableConfig, ensure_config
|
||||||
from langchain_core.runnables.utils import AddableDict
|
from langchain_core.runnables.utils import AddableDict
|
||||||
from langchain_core.tools import BaseTool
|
from langchain_core.tools import BaseTool
|
||||||
from langchain_core.utils.input import get_color_mapping
|
from langchain_core.utils.input import get_color_mapping
|
||||||
@ -1437,7 +1437,7 @@ class AgentExecutor(Chain):
|
|||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> Iterator[AddableDict]:
|
) -> Iterator[AddableDict]:
|
||||||
"""Enables streaming over steps taken to reach final output."""
|
"""Enables streaming over steps taken to reach final output."""
|
||||||
config = config or {}
|
config = ensure_config(config)
|
||||||
iterator = AgentExecutorIterator(
|
iterator = AgentExecutorIterator(
|
||||||
self,
|
self,
|
||||||
input,
|
input,
|
||||||
@ -1458,7 +1458,7 @@ class AgentExecutor(Chain):
|
|||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> AsyncIterator[AddableDict]:
|
) -> AsyncIterator[AddableDict]:
|
||||||
"""Enables streaming over steps taken to reach final output."""
|
"""Enables streaming over steps taken to reach final output."""
|
||||||
config = config or {}
|
config = ensure_config(config)
|
||||||
iterator = AgentExecutorIterator(
|
iterator = AgentExecutorIterator(
|
||||||
self,
|
self,
|
||||||
input,
|
input,
|
||||||
|
@ -8,7 +8,7 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple, Un
|
|||||||
from langchain_core.agents import AgentAction, AgentFinish
|
from langchain_core.agents import AgentAction, AgentFinish
|
||||||
from langchain_core.load import dumpd
|
from langchain_core.load import dumpd
|
||||||
from langchain_core.pydantic_v1 import Field
|
from langchain_core.pydantic_v1 import Field
|
||||||
from langchain_core.runnables import RunnableConfig, RunnableSerializable
|
from langchain_core.runnables import RunnableConfig, RunnableSerializable, ensure_config
|
||||||
from langchain_core.tools import BaseTool
|
from langchain_core.tools import BaseTool
|
||||||
|
|
||||||
from langchain.callbacks.manager import CallbackManager
|
from langchain.callbacks.manager import CallbackManager
|
||||||
@ -222,7 +222,7 @@ class OpenAIAssistantRunnable(RunnableSerializable[Dict, OutputType]):
|
|||||||
Union[List[ThreadMessage], List[RequiredActionFunctionToolCall]].
|
Union[List[ThreadMessage], List[RequiredActionFunctionToolCall]].
|
||||||
"""
|
"""
|
||||||
|
|
||||||
config = config or {}
|
config = ensure_config(config)
|
||||||
callback_manager = CallbackManager.configure(
|
callback_manager = CallbackManager.configure(
|
||||||
inheritable_callbacks=config.get("callbacks"),
|
inheritable_callbacks=config.get("callbacks"),
|
||||||
inheritable_tags=config.get("tags"),
|
inheritable_tags=config.get("tags"),
|
||||||
|
@ -1,4 +1,3 @@
|
|||||||
import asyncio
|
|
||||||
import json
|
import json
|
||||||
from json import JSONDecodeError
|
from json import JSONDecodeError
|
||||||
from typing import List, Union
|
from typing import List, Union
|
||||||
@ -85,12 +84,5 @@ class OpenAIFunctionsAgentOutputParser(AgentOutputParser):
|
|||||||
message = result[0].message
|
message = result[0].message
|
||||||
return self._parse_ai_message(message)
|
return self._parse_ai_message(message)
|
||||||
|
|
||||||
async def aparse_result(
|
|
||||||
self, result: List[Generation], *, partial: bool = False
|
|
||||||
) -> Union[AgentAction, AgentFinish]:
|
|
||||||
return await asyncio.get_running_loop().run_in_executor(
|
|
||||||
None, self.parse_result, result
|
|
||||||
)
|
|
||||||
|
|
||||||
def parse(self, text: str) -> Union[AgentAction, AgentFinish]:
|
def parse(self, text: str) -> Union[AgentAction, AgentFinish]:
|
||||||
raise ValueError("Can only parse messages")
|
raise ValueError("Can only parse messages")
|
||||||
|
@ -1,4 +1,3 @@
|
|||||||
import asyncio
|
|
||||||
import json
|
import json
|
||||||
from json import JSONDecodeError
|
from json import JSONDecodeError
|
||||||
from typing import List, Union
|
from typing import List, Union
|
||||||
@ -92,12 +91,5 @@ class OpenAIToolsAgentOutputParser(MultiActionAgentOutputParser):
|
|||||||
message = result[0].message
|
message = result[0].message
|
||||||
return parse_ai_message_to_openai_tool_action(message)
|
return parse_ai_message_to_openai_tool_action(message)
|
||||||
|
|
||||||
async def aparse_result(
|
|
||||||
self, result: List[Generation], *, partial: bool = False
|
|
||||||
) -> Union[List[AgentAction], AgentFinish]:
|
|
||||||
return await asyncio.get_running_loop().run_in_executor(
|
|
||||||
None, self.parse_result, result
|
|
||||||
)
|
|
||||||
|
|
||||||
def parse(self, text: str) -> Union[List[AgentAction], AgentFinish]:
|
def parse(self, text: str) -> Union[List[AgentAction], AgentFinish]:
|
||||||
raise ValueError("Can only parse messages")
|
raise ValueError("Can only parse messages")
|
||||||
|
@ -1,5 +1,4 @@
|
|||||||
"""Base interface that all chains should implement."""
|
"""Base interface that all chains should implement."""
|
||||||
import asyncio
|
|
||||||
import inspect
|
import inspect
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
@ -19,7 +18,12 @@ from langchain_core.pydantic_v1 import (
|
|||||||
root_validator,
|
root_validator,
|
||||||
validator,
|
validator,
|
||||||
)
|
)
|
||||||
from langchain_core.runnables import RunnableConfig, RunnableSerializable
|
from langchain_core.runnables import (
|
||||||
|
RunnableConfig,
|
||||||
|
RunnableSerializable,
|
||||||
|
ensure_config,
|
||||||
|
run_in_executor,
|
||||||
|
)
|
||||||
|
|
||||||
from langchain.callbacks.base import BaseCallbackManager
|
from langchain.callbacks.base import BaseCallbackManager
|
||||||
from langchain.callbacks.manager import (
|
from langchain.callbacks.manager import (
|
||||||
@ -85,7 +89,7 @@ class Chain(RunnableSerializable[Dict[str, Any], Dict[str, Any]], ABC):
|
|||||||
config: Optional[RunnableConfig] = None,
|
config: Optional[RunnableConfig] = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> Dict[str, Any]:
|
) -> Dict[str, Any]:
|
||||||
config = config or {}
|
config = ensure_config(config)
|
||||||
return self(
|
return self(
|
||||||
input,
|
input,
|
||||||
callbacks=config.get("callbacks"),
|
callbacks=config.get("callbacks"),
|
||||||
@ -101,7 +105,7 @@ class Chain(RunnableSerializable[Dict[str, Any], Dict[str, Any]], ABC):
|
|||||||
config: Optional[RunnableConfig] = None,
|
config: Optional[RunnableConfig] = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> Dict[str, Any]:
|
) -> Dict[str, Any]:
|
||||||
config = config or {}
|
config = ensure_config(config)
|
||||||
return await self.acall(
|
return await self.acall(
|
||||||
input,
|
input,
|
||||||
callbacks=config.get("callbacks"),
|
callbacks=config.get("callbacks"),
|
||||||
@ -245,8 +249,8 @@ class Chain(RunnableSerializable[Dict[str, Any], Dict[str, Any]], ABC):
|
|||||||
A dict of named outputs. Should contain all outputs specified in
|
A dict of named outputs. Should contain all outputs specified in
|
||||||
`Chain.output_keys`.
|
`Chain.output_keys`.
|
||||||
"""
|
"""
|
||||||
return await asyncio.get_running_loop().run_in_executor(
|
return await run_in_executor(
|
||||||
None, self._call, inputs, run_manager
|
None, self._call, inputs, run_manager.get_sync() if run_manager else None
|
||||||
)
|
)
|
||||||
|
|
||||||
def __call__(
|
def __call__(
|
||||||
|
@ -1,16 +1,15 @@
|
|||||||
"""Interfaces to be implemented by general evaluators."""
|
"""Interfaces to be implemented by general evaluators."""
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import asyncio
|
|
||||||
import logging
|
import logging
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from functools import partial
|
|
||||||
from typing import Any, Optional, Sequence, Tuple, Union
|
from typing import Any, Optional, Sequence, Tuple, Union
|
||||||
from warnings import warn
|
from warnings import warn
|
||||||
|
|
||||||
from langchain_core.agents import AgentAction
|
from langchain_core.agents import AgentAction
|
||||||
from langchain_core.language_models import BaseLanguageModel
|
from langchain_core.language_models import BaseLanguageModel
|
||||||
|
from langchain_core.runnables.config import run_in_executor
|
||||||
|
|
||||||
from langchain.chains.base import Chain
|
from langchain.chains.base import Chain
|
||||||
|
|
||||||
@ -189,15 +188,13 @@ class StringEvaluator(_EvalArgsMixin, ABC):
|
|||||||
- value: the string value of the evaluation, if applicable.
|
- value: the string value of the evaluation, if applicable.
|
||||||
- reasoning: the reasoning for the evaluation, if applicable.
|
- reasoning: the reasoning for the evaluation, if applicable.
|
||||||
""" # noqa: E501
|
""" # noqa: E501
|
||||||
return await asyncio.get_running_loop().run_in_executor(
|
return await run_in_executor(
|
||||||
None,
|
None,
|
||||||
partial(
|
self._evaluate_strings,
|
||||||
self._evaluate_strings,
|
prediction=prediction,
|
||||||
prediction=prediction,
|
reference=reference,
|
||||||
reference=reference,
|
input=input,
|
||||||
input=input,
|
**kwargs,
|
||||||
**kwargs,
|
|
||||||
),
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def evaluate_strings(
|
def evaluate_strings(
|
||||||
@ -292,16 +289,14 @@ class PairwiseStringEvaluator(_EvalArgsMixin, ABC):
|
|||||||
Returns:
|
Returns:
|
||||||
dict: A dictionary containing the preference, scores, and/or other information.
|
dict: A dictionary containing the preference, scores, and/or other information.
|
||||||
""" # noqa: E501
|
""" # noqa: E501
|
||||||
return await asyncio.get_running_loop().run_in_executor(
|
return await run_in_executor(
|
||||||
None,
|
None,
|
||||||
partial(
|
self._evaluate_string_pairs,
|
||||||
self._evaluate_string_pairs,
|
prediction=prediction,
|
||||||
prediction=prediction,
|
prediction_b=prediction_b,
|
||||||
prediction_b=prediction_b,
|
reference=reference,
|
||||||
reference=reference,
|
input=input,
|
||||||
input=input,
|
**kwargs,
|
||||||
**kwargs,
|
|
||||||
),
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def evaluate_string_pairs(
|
def evaluate_string_pairs(
|
||||||
@ -415,16 +410,14 @@ class AgentTrajectoryEvaluator(_EvalArgsMixin, ABC):
|
|||||||
Returns:
|
Returns:
|
||||||
dict: The evaluation result.
|
dict: The evaluation result.
|
||||||
"""
|
"""
|
||||||
return await asyncio.get_running_loop().run_in_executor(
|
return await run_in_executor(
|
||||||
None,
|
None,
|
||||||
partial(
|
self._evaluate_agent_trajectory,
|
||||||
self._evaluate_agent_trajectory,
|
prediction=prediction,
|
||||||
prediction=prediction,
|
agent_trajectory=agent_trajectory,
|
||||||
agent_trajectory=agent_trajectory,
|
reference=reference,
|
||||||
reference=reference,
|
input=input,
|
||||||
input=input,
|
**kwargs,
|
||||||
**kwargs,
|
|
||||||
),
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def evaluate_agent_trajectory(
|
def evaluate_agent_trajectory(
|
||||||
|
@ -1,10 +1,10 @@
|
|||||||
import asyncio
|
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from inspect import signature
|
from inspect import signature
|
||||||
from typing import List, Optional, Sequence, Union
|
from typing import List, Optional, Sequence, Union
|
||||||
|
|
||||||
from langchain_core.documents import BaseDocumentTransformer, Document
|
from langchain_core.documents import BaseDocumentTransformer, Document
|
||||||
from langchain_core.pydantic_v1 import BaseModel
|
from langchain_core.pydantic_v1 import BaseModel
|
||||||
|
from langchain_core.runnables.config import run_in_executor
|
||||||
|
|
||||||
from langchain.callbacks.manager import Callbacks
|
from langchain.callbacks.manager import Callbacks
|
||||||
|
|
||||||
@ -28,7 +28,7 @@ class BaseDocumentCompressor(BaseModel, ABC):
|
|||||||
callbacks: Optional[Callbacks] = None,
|
callbacks: Optional[Callbacks] = None,
|
||||||
) -> Sequence[Document]:
|
) -> Sequence[Document]:
|
||||||
"""Compress retrieved documents given the query context."""
|
"""Compress retrieved documents given the query context."""
|
||||||
return await asyncio.get_running_loop().run_in_executor(
|
return await run_in_executor(
|
||||||
None, self.compress_documents, documents, query, callbacks
|
None, self.compress_documents, documents, query, callbacks
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -21,7 +21,6 @@ Note: **MarkdownHeaderTextSplitter** and **HTMLHeaderTextSplitter do not derive
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import asyncio
|
|
||||||
import copy
|
import copy
|
||||||
import logging
|
import logging
|
||||||
import pathlib
|
import pathlib
|
||||||
@ -29,7 +28,6 @@ import re
|
|||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from functools import partial
|
|
||||||
from io import BytesIO, StringIO
|
from io import BytesIO, StringIO
|
||||||
from typing import (
|
from typing import (
|
||||||
AbstractSet,
|
AbstractSet,
|
||||||
@ -283,14 +281,6 @@ class TextSplitter(BaseDocumentTransformer, ABC):
|
|||||||
"""Transform sequence of documents by splitting them."""
|
"""Transform sequence of documents by splitting them."""
|
||||||
return self.split_documents(list(documents))
|
return self.split_documents(list(documents))
|
||||||
|
|
||||||
async def atransform_documents(
|
|
||||||
self, documents: Sequence[Document], **kwargs: Any
|
|
||||||
) -> Sequence[Document]:
|
|
||||||
"""Asynchronously transform a sequence of documents by splitting them."""
|
|
||||||
return await asyncio.get_running_loop().run_in_executor(
|
|
||||||
None, partial(self.transform_documents, **kwargs), documents
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class CharacterTextSplitter(TextSplitter):
|
class CharacterTextSplitter(TextSplitter):
|
||||||
"""Splitting text that looks at characters."""
|
"""Splitting text that looks at characters."""
|
||||||
|
Loading…
Reference in New Issue
Block a user