Compare commits

...

9 Commits

Author SHA1 Message Date
Sydney Runkle
d3d2fddb53 remove run id from hot path 2025-05-13 19:43:40 -07:00
Sydney Runkle
16ea462b7d more removal of slow decorators 2025-05-13 19:42:47 -07:00
Sydney Runkle
48284ccbb4 removing id -> str costly field validator 2025-05-13 19:26:12 -07:00
Sydney Runkle
0da364fbc8 removing shielded and returning early if handlers is empty 2025-05-13 19:22:33 -07:00
Sydney Runkle
d80d208acd fix linting 2025-05-13 14:38:55 -07:00
Sydney Runkle
2687eb10db add helper function 2025-05-13 14:27:41 -07:00
Sydney Runkle
b226701b58 remove init 2025-05-13 10:50:46 -07:00
Sydney Runkle
94bd4bf313 do generation aggregation at the end 2025-05-13 10:49:24 -07:00
Sydney Runkle
31ba2844d3 remove low level init for serializable 2025-05-13 10:40:59 -07:00
7 changed files with 104 additions and 45 deletions

View File

@@ -69,7 +69,7 @@ class AgentAction(Serializable):
tool_input: The input to pass in to the Tool.
log: Additional information to log about the action.
"""
super().__init__(tool=tool, tool_input=tool_input, log=log, **kwargs)
super().__init__(tool=tool, tool_input=tool_input, log=log, **kwargs) # type: ignore[call-arg]
@classmethod
def is_lc_serializable(cls) -> bool:
@@ -149,7 +149,7 @@ class AgentFinish(Serializable):
def __init__(self, return_values: dict, log: str, **kwargs: Any):
"""Override init to support instantiation by position for backward compat."""
super().__init__(return_values=return_values, log=log, **kwargs)
super().__init__(return_values=return_values, log=log, **kwargs) # type: ignore[call-arg]
@classmethod
def is_lc_serializable(cls) -> bool:

View File

@@ -520,6 +520,8 @@ class RunManager(BaseRunManager):
Returns:
Any: The result of the callback.
"""
if not self.handlers:
return
handle_event(
self.handlers,
"on_text",
@@ -542,6 +544,8 @@ class RunManager(BaseRunManager):
retry_state (RetryCallState): The retry state.
**kwargs (Any): Additional keyword arguments.
"""
if not self.handlers:
return
handle_event(
self.handlers,
"on_retry",
@@ -601,6 +605,8 @@ class AsyncRunManager(BaseRunManager, ABC):
Returns:
Any: The result of the callback.
"""
if not self.handlers:
return
await ahandle_event(
self.handlers,
"on_text",
@@ -623,6 +629,8 @@ class AsyncRunManager(BaseRunManager, ABC):
retry_state (RetryCallState): The retry state.
**kwargs (Any): Additional keyword arguments.
"""
if not self.handlers:
return
await ahandle_event(
self.handlers,
"on_retry",
@@ -675,6 +683,8 @@ class CallbackManagerForLLMRun(RunManager, LLMManagerMixin):
The chunk. Defaults to None.
**kwargs (Any): Additional keyword arguments.
"""
if not self.handlers:
return
handle_event(
self.handlers,
"on_llm_new_token",
@@ -694,6 +704,8 @@ class CallbackManagerForLLMRun(RunManager, LLMManagerMixin):
response (LLMResult): The LLM result.
**kwargs (Any): Additional keyword arguments.
"""
if not self.handlers:
return
handle_event(
self.handlers,
"on_llm_end",
@@ -718,6 +730,8 @@ class CallbackManagerForLLMRun(RunManager, LLMManagerMixin):
- response (LLMResult): The response which was generated before
the error occurred.
"""
if not self.handlers:
return
handle_event(
self.handlers,
"on_llm_error",
@@ -750,7 +764,6 @@ class AsyncCallbackManagerForLLMRun(AsyncRunManager, LLMManagerMixin):
inheritable_metadata=self.inheritable_metadata,
)
@shielded
async def on_llm_new_token(
self,
token: str,
@@ -766,6 +779,8 @@ class AsyncCallbackManagerForLLMRun(AsyncRunManager, LLMManagerMixin):
The chunk. Defaults to None.
**kwargs (Any): Additional keyword arguments.
"""
if not self.handlers:
return
await ahandle_event(
self.handlers,
"on_llm_new_token",
@@ -786,6 +801,8 @@ class AsyncCallbackManagerForLLMRun(AsyncRunManager, LLMManagerMixin):
response (LLMResult): The LLM result.
**kwargs (Any): Additional keyword arguments.
"""
if not self.handlers:
return
await ahandle_event(
self.handlers,
"on_llm_end",
@@ -814,6 +831,8 @@ class AsyncCallbackManagerForLLMRun(AsyncRunManager, LLMManagerMixin):
"""
if not self.handlers:
return
await ahandle_event(
self.handlers,
"on_llm_error",
@@ -836,6 +855,8 @@ class CallbackManagerForChainRun(ParentRunManager, ChainManagerMixin):
outputs (Union[dict[str, Any], Any]): The outputs of the chain.
**kwargs (Any): Additional keyword arguments.
"""
if not self.handlers:
return
handle_event(
self.handlers,
"on_chain_end",
@@ -858,6 +879,8 @@ class CallbackManagerForChainRun(ParentRunManager, ChainManagerMixin):
error (Exception or KeyboardInterrupt): The error.
**kwargs (Any): Additional keyword arguments.
"""
if not self.handlers:
return
handle_event(
self.handlers,
"on_chain_error",
@@ -879,6 +902,8 @@ class CallbackManagerForChainRun(ParentRunManager, ChainManagerMixin):
Returns:
Any: The result of the callback.
"""
if not self.handlers:
return
handle_event(
self.handlers,
"on_agent_action",
@@ -900,6 +925,8 @@ class CallbackManagerForChainRun(ParentRunManager, ChainManagerMixin):
Returns:
Any: The result of the callback.
"""
if not self.handlers:
return
handle_event(
self.handlers,
"on_agent_finish",
@@ -942,6 +969,8 @@ class AsyncCallbackManagerForChainRun(AsyncParentRunManager, ChainManagerMixin):
outputs (Union[dict[str, Any], Any]): The outputs of the chain.
**kwargs (Any): Additional keyword arguments.
"""
if not self.handlers:
return
await ahandle_event(
self.handlers,
"on_chain_end",
@@ -965,6 +994,8 @@ class AsyncCallbackManagerForChainRun(AsyncParentRunManager, ChainManagerMixin):
error (Exception or KeyboardInterrupt): The error.
**kwargs (Any): Additional keyword arguments.
"""
if not self.handlers:
return
await ahandle_event(
self.handlers,
"on_chain_error",
@@ -976,7 +1007,6 @@ class AsyncCallbackManagerForChainRun(AsyncParentRunManager, ChainManagerMixin):
**kwargs,
)
@shielded
async def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any:
"""Run when agent action is received.
@@ -987,6 +1017,8 @@ class AsyncCallbackManagerForChainRun(AsyncParentRunManager, ChainManagerMixin):
Returns:
Any: The result of the callback.
"""
if not self.handlers:
return
await ahandle_event(
self.handlers,
"on_agent_action",
@@ -998,7 +1030,6 @@ class AsyncCallbackManagerForChainRun(AsyncParentRunManager, ChainManagerMixin):
**kwargs,
)
@shielded
async def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> Any:
"""Run when agent finish is received.
@@ -1009,6 +1040,8 @@ class AsyncCallbackManagerForChainRun(AsyncParentRunManager, ChainManagerMixin):
Returns:
Any: The result of the callback.
"""
if not self.handlers:
return
await ahandle_event(
self.handlers,
"on_agent_finish",
@@ -1035,6 +1068,8 @@ class CallbackManagerForToolRun(ParentRunManager, ToolManagerMixin):
output (Any): The output of the tool.
**kwargs (Any): Additional keyword arguments.
"""
if not self.handlers:
return
handle_event(
self.handlers,
"on_tool_end",
@@ -1057,6 +1092,8 @@ class CallbackManagerForToolRun(ParentRunManager, ToolManagerMixin):
error (Exception or KeyboardInterrupt): The error.
**kwargs (Any): Additional keyword arguments.
"""
if not self.handlers:
return
handle_event(
self.handlers,
"on_tool_error",
@@ -1089,7 +1126,6 @@ class AsyncCallbackManagerForToolRun(AsyncParentRunManager, ToolManagerMixin):
inheritable_metadata=self.inheritable_metadata,
)
@shielded
async def on_tool_end(self, output: Any, **kwargs: Any) -> None:
"""Async run when the tool ends running.
@@ -1097,6 +1133,8 @@ class AsyncCallbackManagerForToolRun(AsyncParentRunManager, ToolManagerMixin):
output (Any): The output of the tool.
**kwargs (Any): Additional keyword arguments.
"""
if not self.handlers:
return
await ahandle_event(
self.handlers,
"on_tool_end",
@@ -1108,7 +1146,6 @@ class AsyncCallbackManagerForToolRun(AsyncParentRunManager, ToolManagerMixin):
**kwargs,
)
@shielded
async def on_tool_error(
self,
error: BaseException,
@@ -1120,6 +1157,8 @@ class AsyncCallbackManagerForToolRun(AsyncParentRunManager, ToolManagerMixin):
error (Exception or KeyboardInterrupt): The error.
**kwargs (Any): Additional keyword arguments.
"""
if not self.handlers:
return
await ahandle_event(
self.handlers,
"on_tool_error",
@@ -1146,6 +1185,8 @@ class CallbackManagerForRetrieverRun(ParentRunManager, RetrieverManagerMixin):
documents (Sequence[Document]): The retrieved documents.
**kwargs (Any): Additional keyword arguments.
"""
if not self.handlers:
return
handle_event(
self.handlers,
"on_retriever_end",
@@ -1168,6 +1209,8 @@ class CallbackManagerForRetrieverRun(ParentRunManager, RetrieverManagerMixin):
error (BaseException): The error.
**kwargs (Any): Additional keyword arguments.
"""
if not self.handlers:
return
handle_event(
self.handlers,
"on_retriever_error",
@@ -1213,6 +1256,8 @@ class AsyncCallbackManagerForRetrieverRun(
documents (Sequence[Document]): The retrieved documents.
**kwargs (Any): Additional keyword arguments.
"""
if not self.handlers:
return
await ahandle_event(
self.handlers,
"on_retriever_end",
@@ -1236,6 +1281,8 @@ class AsyncCallbackManagerForRetrieverRun(
error (BaseException): The error.
**kwargs (Any): Additional keyword arguments.
"""
if not self.handlers:
return
await ahandle_event(
self.handlers,
"on_retriever_error",
@@ -1521,6 +1568,8 @@ class CallbackManager(BaseCallbackManager):
.. versionadded:: 0.2.14
"""
if not self.handlers:
return
if kwargs:
msg = (
"The dispatcher API does not accept additional keyword arguments."
@@ -1998,6 +2047,8 @@ class AsyncCallbackManager(BaseCallbackManager):
.. versionadded:: 0.2.14
"""
if not self.handlers:
return
if run_id is None:
run_id = uuid.uuid4()

View File

@@ -66,6 +66,7 @@ from langchain_core.outputs import (
LLMResult,
RunInfo,
)
from langchain_core.outputs.chat_generation import merge_chat_generation_chunks
from langchain_core.prompt_values import ChatPromptValue, PromptValue, StringPromptValue
from langchain_core.rate_limiters import BaseRateLimiter
from langchain_core.runnables import RunnableMap, RunnablePassthrough
@@ -485,34 +486,41 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
run_id=config.pop("run_id", None),
batch_size=1,
)
generation: Optional[ChatGenerationChunk] = None
chunks: list[ChatGenerationChunk] = []
if self.rate_limiter:
self.rate_limiter.acquire(blocking=True)
try:
input_messages = _normalize_messages(messages)
run_id = "-".join((_LC_ID_PREFIX, str(run_manager.run_id)))
for chunk in self._stream(input_messages, stop=stop, **kwargs):
if chunk.message.id is None:
chunk.message.id = f"{_LC_ID_PREFIX}-{run_manager.run_id}"
chunk.message.id = run_id
chunk.message.response_metadata = _gen_info_and_msg_metadata(chunk)
run_manager.on_llm_new_token(
cast("str", chunk.message.content), chunk=chunk
)
chunks.append(chunk)
yield chunk.message
if generation is None:
generation = chunk
else:
generation += chunk
except BaseException as e:
generations_with_error_metadata = _generate_response_from_error(e)
if generation:
generations = [[generation], generations_with_error_metadata]
chat_generation_chunk = merge_chat_generation_chunks(chunks)
if chat_generation_chunk:
generations = [
[chat_generation_chunk],
generations_with_error_metadata,
]
else:
generations = [generations_with_error_metadata]
run_manager.on_llm_error(e, response=LLMResult(generations=generations)) # type: ignore[arg-type]
run_manager.on_llm_error(
e,
response=LLMResult(generations=generations), # type: ignore[arg-type]
)
raise
generation = merge_chat_generation_chunks(chunks)
if generation is None:
err = ValueError("No generation chunks were returned")
run_manager.on_llm_error(err, response=LLMResult(generations=[]))
@@ -575,29 +583,29 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
if self.rate_limiter:
await self.rate_limiter.aacquire(blocking=True)
generation: Optional[ChatGenerationChunk] = None
chunks: list[ChatGenerationChunk] = []
try:
input_messages = _normalize_messages(messages)
run_id = "-".join((_LC_ID_PREFIX, str(run_manager.run_id)))
async for chunk in self._astream(
input_messages,
stop=stop,
**kwargs,
):
if chunk.message.id is None:
chunk.message.id = f"{_LC_ID_PREFIX}-{run_manager.run_id}"
chunk.message.id = run_id
chunk.message.response_metadata = _gen_info_and_msg_metadata(chunk)
await run_manager.on_llm_new_token(
cast("str", chunk.message.content), chunk=chunk
)
chunks.append(chunk)
yield chunk.message
if generation is None:
generation = chunk
else:
generation += chunk
except BaseException as e:
generations_with_error_metadata = _generate_response_from_error(e)
if generation:
generations = [[generation], generations_with_error_metadata]
chat_generation_chunk = merge_chat_generation_chunks(chunks)
if chat_generation_chunk:
generations = [[chat_generation_chunk], generations_with_error_metadata]
else:
generations = [generations_with_error_metadata]
await run_manager.on_llm_error(
@@ -606,7 +614,8 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
)
raise
if generation is None:
generation = merge_chat_generation_chunks(chunks)
if not generation:
err = ValueError("No generation chunks were returned")
await run_manager.on_llm_error(err, response=LLMResult(generations=[]))
raise err

View File

@@ -124,11 +124,6 @@ class Serializable(BaseModel, ABC):
as part of the serialized representation.
"""
# Remove default BaseModel init docstring.
def __init__(self, *args: Any, **kwargs: Any) -> None:
"""""" # noqa: D419
super().__init__(*args, **kwargs)
@classmethod
def is_lc_serializable(cls) -> bool:
"""Is this class serializable?

View File

@@ -4,7 +4,7 @@ from __future__ import annotations
from typing import TYPE_CHECKING, Any, Optional, Union, cast
from pydantic import ConfigDict, Field, field_validator
from pydantic import ConfigDict, Field
from langchain_core.load.serializable import Serializable
from langchain_core.utils import get_bolded_text
@@ -52,7 +52,7 @@ class BaseMessage(Serializable):
model implementation.
"""
id: Optional[str] = None
id: Optional[str] = Field(default=None, coerce_numbers_to_str=True)
"""An optional unique identifier for the message. This should ideally be
provided by the provider/model which created the message."""
@@ -60,13 +60,6 @@ class BaseMessage(Serializable):
extra="allow",
)
@field_validator("id", mode="before")
def cast_id_to_str(cls, id_value: Any) -> Optional[str]:
"""Coerce the id field to a string."""
if id_value is not None:
return str(id_value)
return id_value
def __init__(
self, content: Union[str, list[Union[str, dict]]], **kwargs: Any
) -> None:
@@ -75,7 +68,7 @@ class BaseMessage(Serializable):
Args:
content: The string contents of the message.
"""
super().__init__(content=content, **kwargs)
super().__init__(content=content, **kwargs) # type: ignore[call-arg]
@classmethod
def is_lc_serializable(cls) -> bool:

View File

@@ -2,17 +2,15 @@
from __future__ import annotations
from typing import TYPE_CHECKING, Literal, Union
from typing import Literal, Union
from pydantic import model_validator
from typing_extensions import Self
from langchain_core.messages import BaseMessage, BaseMessageChunk
from langchain_core.outputs.generation import Generation
from langchain_core.utils._merge import merge_dicts
if TYPE_CHECKING:
from typing_extensions import Self
class ChatGeneration(Generation):
"""A single chat generation output.
@@ -115,3 +113,16 @@ class ChatGenerationChunk(ChatGeneration):
)
msg = f"unsupported operand type(s) for +: '{type(self)}' and '{type(other)}'"
raise TypeError(msg)
def merge_chat_generation_chunks(
    chunks: list[ChatGenerationChunk],
) -> Union[ChatGenerationChunk, None]:
    """Merge a list of ChatGenerationChunks into a single ChatGenerationChunk."""
    if not chunks:
        # Nothing streamed — there is no chunk to merge.
        return None
    head, *rest = chunks
    # A lone chunk is returned untouched; otherwise fold the remainder into
    # the first chunk via ChatGenerationChunk.__add__, which accepts a list.
    return head + rest if rest else head

View File

@@ -98,7 +98,7 @@ def _get_jinja2_variables_from_template(template: str) -> set[str]:
raise ImportError(msg) from e
env = Environment() # noqa: S701
ast = env.parse(template)
return meta.find_undeclared_variables(ast)
return meta.find_undeclared_variables(ast) # type: ignore[no-untyped-call]
def mustache_formatter(template: str, /, **kwargs: Any) -> str: