mirror of
https://github.com/hwchase17/langchain.git
synced 2025-08-15 15:46:47 +00:00
Merge 52ae5b09ff
into 0e287763cd
This commit is contained in:
commit
7069a0661d
@ -3,8 +3,8 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
import re
|
import re
|
||||||
import time
|
import time
|
||||||
from collections.abc import AsyncIterator, Iterator
|
from collections.abc import AsyncIterator, Iterator, Sequence
|
||||||
from typing import Any, Optional, Union, cast
|
from typing import Any, Callable, Dict, Optional, Union, cast
|
||||||
|
|
||||||
from typing_extensions import override
|
from typing_extensions import override
|
||||||
|
|
||||||
@ -17,6 +17,12 @@ from langchain_core.messages import AIMessage, AIMessageChunk, BaseMessage
|
|||||||
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
|
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
|
||||||
from langchain_core.runnables import RunnableConfig
|
from langchain_core.runnables import RunnableConfig
|
||||||
|
|
||||||
|
from langchain_core.language_models.base import LanguageModelInput
|
||||||
|
|
||||||
|
from langchain_core.runnables import Runnable, RunnableConfig
|
||||||
|
from langchain_core.runnables.config import run_in_executor
|
||||||
|
from langchain_core.tools import BaseTool
|
||||||
|
|
||||||
|
|
||||||
class FakeMessagesListChatModel(BaseChatModel):
|
class FakeMessagesListChatModel(BaseChatModel):
|
||||||
"""Fake ChatModel for testing purposes."""
|
"""Fake ChatModel for testing purposes."""
|
||||||
@ -367,3 +373,191 @@ class ParrotFakeChatModel(BaseChatModel):
|
|||||||
@property
|
@property
|
||||||
def _llm_type(self) -> str:
|
def _llm_type(self) -> str:
|
||||||
return "parrot-fake-chat-model"
|
return "parrot-fake-chat-model"
|
||||||
|
|
||||||
|
|
||||||
|
class FakeToolCallingListChatModel(BaseChatModel):
|
||||||
|
"""Fake Tool calling ChatModel for testing purposes."""
|
||||||
|
|
||||||
|
responses: list[Union[dict,str,BaseMessage]]
|
||||||
|
"""List of responses to **cycle** through in order."""
|
||||||
|
sleep: Optional[float] = None
|
||||||
|
i: int = 0
|
||||||
|
"""List of responses to **cycle** through in order."""
|
||||||
|
error_on_chunk_number: Optional[int] = None
|
||||||
|
"""Internally incremented after every model invocation."""
|
||||||
|
|
||||||
|
@property
|
||||||
|
@override
|
||||||
|
def _llm_type(self) -> str:
|
||||||
|
return "fake-list-chat-model"
|
||||||
|
|
||||||
|
@override
|
||||||
|
def _generate(
|
||||||
|
self,
|
||||||
|
messages: list[BaseMessage],
|
||||||
|
stop: Optional[list[str]] = None,
|
||||||
|
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> ChatResult:
|
||||||
|
response = self.responses[self.i]
|
||||||
|
if self.i < len(self.responses) - 1:
|
||||||
|
self.i += 1
|
||||||
|
else:
|
||||||
|
self.i = 0
|
||||||
|
|
||||||
|
if isinstance(response,BaseMessage):
|
||||||
|
message = response
|
||||||
|
elif isinstance(response,str):
|
||||||
|
message = AIMessage(content=response)
|
||||||
|
elif isinstance(response,dict):
|
||||||
|
message = AIMessage(content=response["content"], tool_calls = response["tool_calls"])
|
||||||
|
else:
|
||||||
|
raise ValueError("Incorrect response type")
|
||||||
|
|
||||||
|
generation = ChatGeneration(message=message)
|
||||||
|
return ChatResult(generations=[generation])
|
||||||
|
|
||||||
|
@override
|
||||||
|
async def _agenerate(
|
||||||
|
self,
|
||||||
|
messages: list[BaseMessage],
|
||||||
|
stop: Optional[list[str]] = None,
|
||||||
|
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> ChatResult:
|
||||||
|
return await run_in_executor(
|
||||||
|
None,
|
||||||
|
self._generate,
|
||||||
|
messages,
|
||||||
|
stop=stop,
|
||||||
|
run_manager=run_manager.get_sync() if run_manager else None,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
@override
|
||||||
|
def _stream(
|
||||||
|
self,
|
||||||
|
messages: list[BaseMessage],
|
||||||
|
stop: Union[list[str], None] = None,
|
||||||
|
run_manager: Union[CallbackManagerForLLMRun, None] = None,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> Iterator[ChatGenerationChunk]:
|
||||||
|
response = self.responses[self.i]
|
||||||
|
if self.i < len(self.responses) - 1:
|
||||||
|
self.i += 1
|
||||||
|
else:
|
||||||
|
self.i = 0
|
||||||
|
|
||||||
|
tool_calls = []
|
||||||
|
if isinstance(response,BaseMessage):
|
||||||
|
output = response.content
|
||||||
|
elif isinstance(response,str):
|
||||||
|
output = response
|
||||||
|
elif isinstance(response,dict):
|
||||||
|
output = response["content"]
|
||||||
|
tool_calls = response["tool_calls"]
|
||||||
|
else:
|
||||||
|
raise ValueError("Incorrect response type")
|
||||||
|
|
||||||
|
for i_c, c in enumerate(output):
|
||||||
|
if self.sleep is not None:
|
||||||
|
time.sleep(self.sleep)
|
||||||
|
if (
|
||||||
|
self.error_on_chunk_number is not None
|
||||||
|
and i_c == self.error_on_chunk_number
|
||||||
|
):
|
||||||
|
raise FakeListChatModelError
|
||||||
|
|
||||||
|
yield ChatGenerationChunk(message=AIMessageChunk(content=c, tool_calls=tool_calls))
|
||||||
|
|
||||||
|
@override
|
||||||
|
async def _astream(
|
||||||
|
self,
|
||||||
|
messages: list[BaseMessage],
|
||||||
|
stop: Union[list[str], None] = None,
|
||||||
|
run_manager: Union[AsyncCallbackManagerForLLMRun, None] = None,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> AsyncIterator[ChatGenerationChunk]:
|
||||||
|
response = self.responses[self.i]
|
||||||
|
if self.i < len(self.responses) - 1:
|
||||||
|
self.i += 1
|
||||||
|
else:
|
||||||
|
self.i = 0
|
||||||
|
|
||||||
|
tool_calls = []
|
||||||
|
if isinstance(response,BaseMessage):
|
||||||
|
output = response.content
|
||||||
|
elif isinstance(response,str):
|
||||||
|
output = response
|
||||||
|
elif isinstance(response,dict):
|
||||||
|
output = response["content"]
|
||||||
|
tool_calls = response["tool_calls"]
|
||||||
|
else:
|
||||||
|
raise ValueError("Incorrect response type")
|
||||||
|
|
||||||
|
for i_c, c in enumerate(output):
|
||||||
|
if self.sleep is not None:
|
||||||
|
await asyncio.sleep(self.sleep)
|
||||||
|
if (
|
||||||
|
self.error_on_chunk_number is not None
|
||||||
|
and i_c == self.error_on_chunk_number
|
||||||
|
):
|
||||||
|
raise FakeListChatModelError
|
||||||
|
yield ChatGenerationChunk(message=AIMessageChunk(content=c, tool_calls=tool_calls))
|
||||||
|
|
||||||
|
@property
|
||||||
|
@override
|
||||||
|
def _identifying_params(self) -> dict[str, Any]:
|
||||||
|
return {"responses": self.responses}
|
||||||
|
|
||||||
|
@override
|
||||||
|
# manually override batch to preserve batch ordering with no concurrency
|
||||||
|
def batch(
|
||||||
|
self,
|
||||||
|
inputs: list[Any],
|
||||||
|
config: Optional[Union[RunnableConfig, list[RunnableConfig]]] = None,
|
||||||
|
*,
|
||||||
|
return_exceptions: bool = False,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> list[BaseMessage]:
|
||||||
|
if isinstance(config, list):
|
||||||
|
return [self.invoke(m, c, **kwargs) for m, c in zip(inputs, config)]
|
||||||
|
return [self.invoke(m, config, **kwargs) for m in inputs]
|
||||||
|
|
||||||
|
@override
|
||||||
|
async def abatch(
|
||||||
|
self,
|
||||||
|
inputs: list[Any],
|
||||||
|
config: Optional[Union[RunnableConfig, list[RunnableConfig]]] = None,
|
||||||
|
*,
|
||||||
|
return_exceptions: bool = False,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> list[BaseMessage]:
|
||||||
|
if isinstance(config, list):
|
||||||
|
# do Not use an async iterator here because need explicit ordering
|
||||||
|
return [await self.ainvoke(m, c, **kwargs) for m, c in zip(inputs, config)]
|
||||||
|
# do Not use an async iterator here because need explicit ordering
|
||||||
|
return [await self.ainvoke(m, config, **kwargs) for m in inputs]
|
||||||
|
|
||||||
|
@override
|
||||||
|
def bind_tools(
|
||||||
|
self,
|
||||||
|
tools: Sequence[
|
||||||
|
Union[Dict[str, Any], type, Callable, BaseTool] # noqa: UP006
|
||||||
|
],
|
||||||
|
*,
|
||||||
|
tool_choice: Optional[Union[str]] = None,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> Runnable[LanguageModelInput, BaseMessage]:
|
||||||
|
"""Bind tools to the model.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tools: Sequence of tools to bind to the model.
|
||||||
|
tool_choice: The tool to use. If "any" then any tool can be used.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A Runnable that returns a message.
|
||||||
|
"""
|
||||||
|
kwargs["tool_choice"] = tool_choice
|
||||||
|
raise super().bind(tools=tools, **kwargs)
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user