mirror of
https://github.com/hwchase17/langchain.git
synced 2026-01-23 21:31:02 +00:00
Compare commits
9 Commits
sr/fix-cod
...
sr/pydanti
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
d3d2fddb53 | ||
|
|
16ea462b7d | ||
|
|
48284ccbb4 | ||
|
|
0da364fbc8 | ||
|
|
d80d208acd | ||
|
|
2687eb10db | ||
|
|
b226701b58 | ||
|
|
94bd4bf313 | ||
|
|
31ba2844d3 |
@@ -69,7 +69,7 @@ class AgentAction(Serializable):
|
||||
tool_input: The input to pass in to the Tool.
|
||||
log: Additional information to log about the action.
|
||||
"""
|
||||
super().__init__(tool=tool, tool_input=tool_input, log=log, **kwargs)
|
||||
super().__init__(tool=tool, tool_input=tool_input, log=log, **kwargs) # type: ignore[call-arg]
|
||||
|
||||
@classmethod
|
||||
def is_lc_serializable(cls) -> bool:
|
||||
@@ -149,7 +149,7 @@ class AgentFinish(Serializable):
|
||||
|
||||
def __init__(self, return_values: dict, log: str, **kwargs: Any):
|
||||
"""Override init to support instantiation by position for backward compat."""
|
||||
super().__init__(return_values=return_values, log=log, **kwargs)
|
||||
super().__init__(return_values=return_values, log=log, **kwargs) # type: ignore[call-arg]
|
||||
|
||||
@classmethod
|
||||
def is_lc_serializable(cls) -> bool:
|
||||
|
||||
@@ -520,6 +520,8 @@ class RunManager(BaseRunManager):
|
||||
Returns:
|
||||
Any: The result of the callback.
|
||||
"""
|
||||
if not self.handlers:
|
||||
return
|
||||
handle_event(
|
||||
self.handlers,
|
||||
"on_text",
|
||||
@@ -542,6 +544,8 @@ class RunManager(BaseRunManager):
|
||||
retry_state (RetryCallState): The retry state.
|
||||
**kwargs (Any): Additional keyword arguments.
|
||||
"""
|
||||
if not self.handlers:
|
||||
return
|
||||
handle_event(
|
||||
self.handlers,
|
||||
"on_retry",
|
||||
@@ -601,6 +605,8 @@ class AsyncRunManager(BaseRunManager, ABC):
|
||||
Returns:
|
||||
Any: The result of the callback.
|
||||
"""
|
||||
if not self.handlers:
|
||||
return
|
||||
await ahandle_event(
|
||||
self.handlers,
|
||||
"on_text",
|
||||
@@ -623,6 +629,8 @@ class AsyncRunManager(BaseRunManager, ABC):
|
||||
retry_state (RetryCallState): The retry state.
|
||||
**kwargs (Any): Additional keyword arguments.
|
||||
"""
|
||||
if not self.handlers:
|
||||
return
|
||||
await ahandle_event(
|
||||
self.handlers,
|
||||
"on_retry",
|
||||
@@ -675,6 +683,8 @@ class CallbackManagerForLLMRun(RunManager, LLMManagerMixin):
|
||||
The chunk. Defaults to None.
|
||||
**kwargs (Any): Additional keyword arguments.
|
||||
"""
|
||||
if not self.handlers:
|
||||
return
|
||||
handle_event(
|
||||
self.handlers,
|
||||
"on_llm_new_token",
|
||||
@@ -694,6 +704,8 @@ class CallbackManagerForLLMRun(RunManager, LLMManagerMixin):
|
||||
response (LLMResult): The LLM result.
|
||||
**kwargs (Any): Additional keyword arguments.
|
||||
"""
|
||||
if not self.handlers:
|
||||
return
|
||||
handle_event(
|
||||
self.handlers,
|
||||
"on_llm_end",
|
||||
@@ -718,6 +730,8 @@ class CallbackManagerForLLMRun(RunManager, LLMManagerMixin):
|
||||
- response (LLMResult): The response which was generated before
|
||||
the error occurred.
|
||||
"""
|
||||
if not self.handlers:
|
||||
return
|
||||
handle_event(
|
||||
self.handlers,
|
||||
"on_llm_error",
|
||||
@@ -750,7 +764,6 @@ class AsyncCallbackManagerForLLMRun(AsyncRunManager, LLMManagerMixin):
|
||||
inheritable_metadata=self.inheritable_metadata,
|
||||
)
|
||||
|
||||
@shielded
|
||||
async def on_llm_new_token(
|
||||
self,
|
||||
token: str,
|
||||
@@ -766,6 +779,8 @@ class AsyncCallbackManagerForLLMRun(AsyncRunManager, LLMManagerMixin):
|
||||
The chunk. Defaults to None.
|
||||
**kwargs (Any): Additional keyword arguments.
|
||||
"""
|
||||
if not self.handlers:
|
||||
return
|
||||
await ahandle_event(
|
||||
self.handlers,
|
||||
"on_llm_new_token",
|
||||
@@ -786,6 +801,8 @@ class AsyncCallbackManagerForLLMRun(AsyncRunManager, LLMManagerMixin):
|
||||
response (LLMResult): The LLM result.
|
||||
**kwargs (Any): Additional keyword arguments.
|
||||
"""
|
||||
if not self.handlers:
|
||||
return
|
||||
await ahandle_event(
|
||||
self.handlers,
|
||||
"on_llm_end",
|
||||
@@ -814,6 +831,8 @@ class AsyncCallbackManagerForLLMRun(AsyncRunManager, LLMManagerMixin):
|
||||
|
||||
|
||||
"""
|
||||
if not self.handlers:
|
||||
return
|
||||
await ahandle_event(
|
||||
self.handlers,
|
||||
"on_llm_error",
|
||||
@@ -836,6 +855,8 @@ class CallbackManagerForChainRun(ParentRunManager, ChainManagerMixin):
|
||||
outputs (Union[dict[str, Any], Any]): The outputs of the chain.
|
||||
**kwargs (Any): Additional keyword arguments.
|
||||
"""
|
||||
if not self.handlers:
|
||||
return
|
||||
handle_event(
|
||||
self.handlers,
|
||||
"on_chain_end",
|
||||
@@ -858,6 +879,8 @@ class CallbackManagerForChainRun(ParentRunManager, ChainManagerMixin):
|
||||
error (Exception or KeyboardInterrupt): The error.
|
||||
**kwargs (Any): Additional keyword arguments.
|
||||
"""
|
||||
if not self.handlers:
|
||||
return
|
||||
handle_event(
|
||||
self.handlers,
|
||||
"on_chain_error",
|
||||
@@ -879,6 +902,8 @@ class CallbackManagerForChainRun(ParentRunManager, ChainManagerMixin):
|
||||
Returns:
|
||||
Any: The result of the callback.
|
||||
"""
|
||||
if not self.handlers:
|
||||
return
|
||||
handle_event(
|
||||
self.handlers,
|
||||
"on_agent_action",
|
||||
@@ -900,6 +925,8 @@ class CallbackManagerForChainRun(ParentRunManager, ChainManagerMixin):
|
||||
Returns:
|
||||
Any: The result of the callback.
|
||||
"""
|
||||
if not self.handlers:
|
||||
return
|
||||
handle_event(
|
||||
self.handlers,
|
||||
"on_agent_finish",
|
||||
@@ -942,6 +969,8 @@ class AsyncCallbackManagerForChainRun(AsyncParentRunManager, ChainManagerMixin):
|
||||
outputs (Union[dict[str, Any], Any]): The outputs of the chain.
|
||||
**kwargs (Any): Additional keyword arguments.
|
||||
"""
|
||||
if not self.handlers:
|
||||
return
|
||||
await ahandle_event(
|
||||
self.handlers,
|
||||
"on_chain_end",
|
||||
@@ -965,6 +994,8 @@ class AsyncCallbackManagerForChainRun(AsyncParentRunManager, ChainManagerMixin):
|
||||
error (Exception or KeyboardInterrupt): The error.
|
||||
**kwargs (Any): Additional keyword arguments.
|
||||
"""
|
||||
if not self.handlers:
|
||||
return
|
||||
await ahandle_event(
|
||||
self.handlers,
|
||||
"on_chain_error",
|
||||
@@ -976,7 +1007,6 @@ class AsyncCallbackManagerForChainRun(AsyncParentRunManager, ChainManagerMixin):
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
@shielded
|
||||
async def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any:
|
||||
"""Run when agent action is received.
|
||||
|
||||
@@ -987,6 +1017,8 @@ class AsyncCallbackManagerForChainRun(AsyncParentRunManager, ChainManagerMixin):
|
||||
Returns:
|
||||
Any: The result of the callback.
|
||||
"""
|
||||
if not self.handlers:
|
||||
return
|
||||
await ahandle_event(
|
||||
self.handlers,
|
||||
"on_agent_action",
|
||||
@@ -998,7 +1030,6 @@ class AsyncCallbackManagerForChainRun(AsyncParentRunManager, ChainManagerMixin):
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
@shielded
|
||||
async def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> Any:
|
||||
"""Run when agent finish is received.
|
||||
|
||||
@@ -1009,6 +1040,8 @@ class AsyncCallbackManagerForChainRun(AsyncParentRunManager, ChainManagerMixin):
|
||||
Returns:
|
||||
Any: The result of the callback.
|
||||
"""
|
||||
if not self.handlers:
|
||||
return
|
||||
await ahandle_event(
|
||||
self.handlers,
|
||||
"on_agent_finish",
|
||||
@@ -1035,6 +1068,8 @@ class CallbackManagerForToolRun(ParentRunManager, ToolManagerMixin):
|
||||
output (Any): The output of the tool.
|
||||
**kwargs (Any): Additional keyword arguments.
|
||||
"""
|
||||
if not self.handlers:
|
||||
return
|
||||
handle_event(
|
||||
self.handlers,
|
||||
"on_tool_end",
|
||||
@@ -1057,6 +1092,8 @@ class CallbackManagerForToolRun(ParentRunManager, ToolManagerMixin):
|
||||
error (Exception or KeyboardInterrupt): The error.
|
||||
**kwargs (Any): Additional keyword arguments.
|
||||
"""
|
||||
if not self.handlers:
|
||||
return
|
||||
handle_event(
|
||||
self.handlers,
|
||||
"on_tool_error",
|
||||
@@ -1089,7 +1126,6 @@ class AsyncCallbackManagerForToolRun(AsyncParentRunManager, ToolManagerMixin):
|
||||
inheritable_metadata=self.inheritable_metadata,
|
||||
)
|
||||
|
||||
@shielded
|
||||
async def on_tool_end(self, output: Any, **kwargs: Any) -> None:
|
||||
"""Async run when the tool ends running.
|
||||
|
||||
@@ -1097,6 +1133,8 @@ class AsyncCallbackManagerForToolRun(AsyncParentRunManager, ToolManagerMixin):
|
||||
output (Any): The output of the tool.
|
||||
**kwargs (Any): Additional keyword arguments.
|
||||
"""
|
||||
if not self.handlers:
|
||||
return
|
||||
await ahandle_event(
|
||||
self.handlers,
|
||||
"on_tool_end",
|
||||
@@ -1108,7 +1146,6 @@ class AsyncCallbackManagerForToolRun(AsyncParentRunManager, ToolManagerMixin):
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
@shielded
|
||||
async def on_tool_error(
|
||||
self,
|
||||
error: BaseException,
|
||||
@@ -1120,6 +1157,8 @@ class AsyncCallbackManagerForToolRun(AsyncParentRunManager, ToolManagerMixin):
|
||||
error (Exception or KeyboardInterrupt): The error.
|
||||
**kwargs (Any): Additional keyword arguments.
|
||||
"""
|
||||
if not self.handlers:
|
||||
return
|
||||
await ahandle_event(
|
||||
self.handlers,
|
||||
"on_tool_error",
|
||||
@@ -1146,6 +1185,8 @@ class CallbackManagerForRetrieverRun(ParentRunManager, RetrieverManagerMixin):
|
||||
documents (Sequence[Document]): The retrieved documents.
|
||||
**kwargs (Any): Additional keyword arguments.
|
||||
"""
|
||||
if not self.handlers:
|
||||
return
|
||||
handle_event(
|
||||
self.handlers,
|
||||
"on_retriever_end",
|
||||
@@ -1168,6 +1209,8 @@ class CallbackManagerForRetrieverRun(ParentRunManager, RetrieverManagerMixin):
|
||||
error (BaseException): The error.
|
||||
**kwargs (Any): Additional keyword arguments.
|
||||
"""
|
||||
if not self.handlers:
|
||||
return
|
||||
handle_event(
|
||||
self.handlers,
|
||||
"on_retriever_error",
|
||||
@@ -1213,6 +1256,8 @@ class AsyncCallbackManagerForRetrieverRun(
|
||||
documents (Sequence[Document]): The retrieved documents.
|
||||
**kwargs (Any): Additional keyword arguments.
|
||||
"""
|
||||
if not self.handlers:
|
||||
return
|
||||
await ahandle_event(
|
||||
self.handlers,
|
||||
"on_retriever_end",
|
||||
@@ -1236,6 +1281,8 @@ class AsyncCallbackManagerForRetrieverRun(
|
||||
error (BaseException): The error.
|
||||
**kwargs (Any): Additional keyword arguments.
|
||||
"""
|
||||
if not self.handlers:
|
||||
return
|
||||
await ahandle_event(
|
||||
self.handlers,
|
||||
"on_retriever_error",
|
||||
@@ -1521,6 +1568,8 @@ class CallbackManager(BaseCallbackManager):
|
||||
|
||||
.. versionadded:: 0.2.14
|
||||
"""
|
||||
if not self.handlers:
|
||||
return
|
||||
if kwargs:
|
||||
msg = (
|
||||
"The dispatcher API does not accept additional keyword arguments."
|
||||
@@ -1998,6 +2047,8 @@ class AsyncCallbackManager(BaseCallbackManager):
|
||||
|
||||
.. versionadded:: 0.2.14
|
||||
"""
|
||||
if not self.handlers:
|
||||
return
|
||||
if run_id is None:
|
||||
run_id = uuid.uuid4()
|
||||
|
||||
|
||||
@@ -66,6 +66,7 @@ from langchain_core.outputs import (
|
||||
LLMResult,
|
||||
RunInfo,
|
||||
)
|
||||
from langchain_core.outputs.chat_generation import merge_chat_generation_chunks
|
||||
from langchain_core.prompt_values import ChatPromptValue, PromptValue, StringPromptValue
|
||||
from langchain_core.rate_limiters import BaseRateLimiter
|
||||
from langchain_core.runnables import RunnableMap, RunnablePassthrough
|
||||
@@ -485,34 +486,41 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
|
||||
run_id=config.pop("run_id", None),
|
||||
batch_size=1,
|
||||
)
|
||||
generation: Optional[ChatGenerationChunk] = None
|
||||
|
||||
chunks: list[ChatGenerationChunk] = []
|
||||
|
||||
if self.rate_limiter:
|
||||
self.rate_limiter.acquire(blocking=True)
|
||||
|
||||
try:
|
||||
input_messages = _normalize_messages(messages)
|
||||
run_id = "-".join((_LC_ID_PREFIX, str(run_manager.run_id)))
|
||||
for chunk in self._stream(input_messages, stop=stop, **kwargs):
|
||||
if chunk.message.id is None:
|
||||
chunk.message.id = f"{_LC_ID_PREFIX}-{run_manager.run_id}"
|
||||
chunk.message.id = run_id
|
||||
chunk.message.response_metadata = _gen_info_and_msg_metadata(chunk)
|
||||
run_manager.on_llm_new_token(
|
||||
cast("str", chunk.message.content), chunk=chunk
|
||||
)
|
||||
chunks.append(chunk)
|
||||
yield chunk.message
|
||||
if generation is None:
|
||||
generation = chunk
|
||||
else:
|
||||
generation += chunk
|
||||
except BaseException as e:
|
||||
generations_with_error_metadata = _generate_response_from_error(e)
|
||||
if generation:
|
||||
generations = [[generation], generations_with_error_metadata]
|
||||
chat_generation_chunk = merge_chat_generation_chunks(chunks)
|
||||
if chat_generation_chunk:
|
||||
generations = [
|
||||
[chat_generation_chunk],
|
||||
generations_with_error_metadata,
|
||||
]
|
||||
else:
|
||||
generations = [generations_with_error_metadata]
|
||||
run_manager.on_llm_error(e, response=LLMResult(generations=generations)) # type: ignore[arg-type]
|
||||
run_manager.on_llm_error(
|
||||
e,
|
||||
response=LLMResult(generations=generations), # type: ignore[arg-type]
|
||||
)
|
||||
raise
|
||||
|
||||
generation = merge_chat_generation_chunks(chunks)
|
||||
if generation is None:
|
||||
err = ValueError("No generation chunks were returned")
|
||||
run_manager.on_llm_error(err, response=LLMResult(generations=[]))
|
||||
@@ -575,29 +583,29 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
|
||||
if self.rate_limiter:
|
||||
await self.rate_limiter.aacquire(blocking=True)
|
||||
|
||||
generation: Optional[ChatGenerationChunk] = None
|
||||
chunks: list[ChatGenerationChunk] = []
|
||||
|
||||
try:
|
||||
input_messages = _normalize_messages(messages)
|
||||
run_id = "-".join((_LC_ID_PREFIX, str(run_manager.run_id)))
|
||||
async for chunk in self._astream(
|
||||
input_messages,
|
||||
stop=stop,
|
||||
**kwargs,
|
||||
):
|
||||
if chunk.message.id is None:
|
||||
chunk.message.id = f"{_LC_ID_PREFIX}-{run_manager.run_id}"
|
||||
chunk.message.id = run_id
|
||||
chunk.message.response_metadata = _gen_info_and_msg_metadata(chunk)
|
||||
await run_manager.on_llm_new_token(
|
||||
cast("str", chunk.message.content), chunk=chunk
|
||||
)
|
||||
chunks.append(chunk)
|
||||
yield chunk.message
|
||||
if generation is None:
|
||||
generation = chunk
|
||||
else:
|
||||
generation += chunk
|
||||
except BaseException as e:
|
||||
generations_with_error_metadata = _generate_response_from_error(e)
|
||||
if generation:
|
||||
generations = [[generation], generations_with_error_metadata]
|
||||
chat_generation_chunk = merge_chat_generation_chunks(chunks)
|
||||
if chat_generation_chunk:
|
||||
generations = [[chat_generation_chunk], generations_with_error_metadata]
|
||||
else:
|
||||
generations = [generations_with_error_metadata]
|
||||
await run_manager.on_llm_error(
|
||||
@@ -606,7 +614,8 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
|
||||
)
|
||||
raise
|
||||
|
||||
if generation is None:
|
||||
generation = merge_chat_generation_chunks(chunks)
|
||||
if not generation:
|
||||
err = ValueError("No generation chunks were returned")
|
||||
await run_manager.on_llm_error(err, response=LLMResult(generations=[]))
|
||||
raise err
|
||||
|
||||
@@ -124,11 +124,6 @@ class Serializable(BaseModel, ABC):
|
||||
as part of the serialized representation.
|
||||
"""
|
||||
|
||||
# Remove default BaseModel init docstring.
|
||||
def __init__(self, *args: Any, **kwargs: Any) -> None:
|
||||
"""""" # noqa: D419
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
@classmethod
|
||||
def is_lc_serializable(cls) -> bool:
|
||||
"""Is this class serializable?
|
||||
|
||||
@@ -4,7 +4,7 @@ from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Any, Optional, Union, cast
|
||||
|
||||
from pydantic import ConfigDict, Field, field_validator
|
||||
from pydantic import ConfigDict, Field
|
||||
|
||||
from langchain_core.load.serializable import Serializable
|
||||
from langchain_core.utils import get_bolded_text
|
||||
@@ -52,7 +52,7 @@ class BaseMessage(Serializable):
|
||||
model implementation.
|
||||
"""
|
||||
|
||||
id: Optional[str] = None
|
||||
id: Optional[str] = Field(default=None, coerce_numbers_to_str=True)
|
||||
"""An optional unique identifier for the message. This should ideally be
|
||||
provided by the provider/model which created the message."""
|
||||
|
||||
@@ -60,13 +60,6 @@ class BaseMessage(Serializable):
|
||||
extra="allow",
|
||||
)
|
||||
|
||||
@field_validator("id", mode="before")
|
||||
def cast_id_to_str(cls, id_value: Any) -> Optional[str]:
|
||||
"""Coerce the id field to a string."""
|
||||
if id_value is not None:
|
||||
return str(id_value)
|
||||
return id_value
|
||||
|
||||
def __init__(
|
||||
self, content: Union[str, list[Union[str, dict]]], **kwargs: Any
|
||||
) -> None:
|
||||
@@ -75,7 +68,7 @@ class BaseMessage(Serializable):
|
||||
Args:
|
||||
content: The string contents of the message.
|
||||
"""
|
||||
super().__init__(content=content, **kwargs)
|
||||
super().__init__(content=content, **kwargs) # type: ignore[call-arg]
|
||||
|
||||
@classmethod
|
||||
def is_lc_serializable(cls) -> bool:
|
||||
|
||||
@@ -2,17 +2,15 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Literal, Union
|
||||
from typing import Literal, Union
|
||||
|
||||
from pydantic import model_validator
|
||||
from typing_extensions import Self
|
||||
|
||||
from langchain_core.messages import BaseMessage, BaseMessageChunk
|
||||
from langchain_core.outputs.generation import Generation
|
||||
from langchain_core.utils._merge import merge_dicts
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from typing_extensions import Self
|
||||
|
||||
|
||||
class ChatGeneration(Generation):
|
||||
"""A single chat generation output.
|
||||
@@ -115,3 +113,16 @@ class ChatGenerationChunk(ChatGeneration):
|
||||
)
|
||||
msg = f"unsupported operand type(s) for +: '{type(self)}' and '{type(other)}'"
|
||||
raise TypeError(msg)
|
||||
|
||||
|
||||
def merge_chat_generation_chunks(
|
||||
chunks: list[ChatGenerationChunk],
|
||||
) -> Union[ChatGenerationChunk, None]:
|
||||
"""Merge a list of ChatGenerationChunks into a single ChatGenerationChunk."""
|
||||
if not chunks:
|
||||
return None
|
||||
|
||||
if len(chunks) == 1:
|
||||
return chunks[0]
|
||||
|
||||
return chunks[0] + chunks[1:]
|
||||
|
||||
@@ -98,7 +98,7 @@ def _get_jinja2_variables_from_template(template: str) -> set[str]:
|
||||
raise ImportError(msg) from e
|
||||
env = Environment() # noqa: S701
|
||||
ast = env.parse(template)
|
||||
return meta.find_undeclared_variables(ast)
|
||||
return meta.find_undeclared_variables(ast) # type: ignore[no-untyped-call]
|
||||
|
||||
|
||||
def mustache_formatter(template: str, /, **kwargs: Any) -> str:
|
||||
|
||||
Reference in New Issue
Block a user