mirror of https://github.com/hwchase17/langchain.git
synced 2025-08-13 22:59:05 +00:00

commit 8cf97e838c (parent 9507d0f21c)
fix(core): lint standard outputs branch (#32311)
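The diff below reworks the v1 callback surface: on_llm_new_token takes a plain token: str (with any v1 AIMessageChunk carried in the chunk argument), on_chat_model_start accepts Union[list[list[BaseMessage]], list[MessageV1]], and on_llm_end accepts Union[LLMResult, AIMessage]. As a rough orientation, here is a minimal handler sketch written against those updated signatures; the handler name and print statements are illustrative only, and the langchain_core.messages.v1 module is assumed to exist as on this branch:

from typing import Any, Optional, Union
from uuid import UUID

from langchain_core.callbacks.base import BaseCallbackHandler
from langchain_core.messages import BaseMessage
from langchain_core.messages.v1 import AIMessage, AIMessageChunk, MessageV1
from langchain_core.outputs import ChatGenerationChunk, GenerationChunk, LLMResult


class PrintingHandler(BaseCallbackHandler):
    """Toy handler using the updated, v1-aware callback signatures."""

    def on_chat_model_start(
        self,
        serialized: dict[str, Any],
        messages: Union[list[list[BaseMessage]], list[MessageV1]],
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        **kwargs: Any,
    ) -> None:
        # v1 passes a flat list of MessageV1; the legacy shape is a list of lists.
        print(f"chat model start: {len(messages)} item(s)")

    def on_llm_new_token(
        self,
        token: str,
        *,
        chunk: Optional[
            Union[GenerationChunk, ChatGenerationChunk, AIMessageChunk]
        ] = None,
        **kwargs: Any,
    ) -> None:
        # The token is always a plain string; the richer object (possibly a
        # v1 AIMessageChunk) rides along in `chunk`.
        print(token, end="", flush=True)

    def on_llm_end(self, response: Union[LLMResult, AIMessage], **kwargs: Any) -> None:
        if isinstance(response, AIMessage):
            print(f"\nfinished with v1 message: {response.text!r}")
        else:
            print(f"\nfinished with {len(response.generations)} generation(s)")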
@@ -66,7 +66,7 @@ class LLMManagerMixin:

def on_llm_new_token(
self,
-token: Union[str, AIMessageChunk],
+token: str,
*,
chunk: Optional[
Union[GenerationChunk, ChatGenerationChunk, AIMessageChunk]
@@ -265,7 +265,7 @@ class CallbackManagerMixin:
def on_chat_model_start(
self,
serialized: dict[str, Any],
-messages: Union[list[list[BaseMessage]], list[list[MessageV1]]],
+messages: Union[list[list[BaseMessage]], list[MessageV1]],
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
@@ -516,7 +516,7 @@ class AsyncCallbackHandler(BaseCallbackHandler):
async def on_chat_model_start(
self,
serialized: dict[str, Any],
-messages: Union[list[list[BaseMessage]], list[list[MessageV1]]],
+messages: Union[list[list[BaseMessage]], list[MessageV1]],
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
@@ -545,7 +545,7 @@ class AsyncCallbackHandler(BaseCallbackHandler):

async def on_llm_new_token(
self,
-token: Union[str, AIMessageChunk],
+token: str,
*,
chunk: Optional[
Union[GenerationChunk, ChatGenerationChunk, AIMessageChunk]
@@ -11,7 +11,6 @@ from abc import ABC, abstractmethod
from concurrent.futures import ThreadPoolExecutor
from contextlib import asynccontextmanager, contextmanager
from contextvars import copy_context
-from dataclasses import is_dataclass
from typing import (
TYPE_CHECKING,
Any,
@@ -38,7 +37,13 @@ from langchain_core.callbacks.base import (
)
from langchain_core.callbacks.stdout import StdOutCallbackHandler
from langchain_core.messages import BaseMessage, get_buffer_string
-from langchain_core.messages.v1 import AIMessage, AIMessageChunk
+from langchain_core.messages.utils import convert_from_v1_message
+from langchain_core.messages.v1 import (
+AIMessage,
+AIMessageChunk,
+MessageV1,
+MessageV1Types,
+)
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, LLMResult
from langchain_core.tracers.schemas import Run
from langchain_core.utils.env import env_var_is_set
@@ -249,17 +254,31 @@ def shielded(func: Func) -> Func:
def _convert_llm_events(
event_name: str, args: tuple[Any, ...], kwargs: dict[str, Any]
) -> None:
-if event_name == "on_chat_model_start" and isinstance(args[1], list):
-for idx, item in enumerate(args[1]):
-if is_dataclass(item):
-args[1][idx] = item # convert to old message
-elif event_name == "on_llm_new_token" and is_dataclass(args[0]):
-kwargs["chunk"] = ChatGenerationChunk(text=args[0].text, message=args[0])
-args[0] = args[0].text
-elif event_name == "on_llm_end" and is_dataclass(args[0]):
-args[0] = LLMResult(
+if (
+event_name == "on_chat_model_start"
+and isinstance(args[1], list)
+and args[1]
+and isinstance(args[1][0], MessageV1Types)
+):
+batch = [
+convert_from_v1_message(item)
+for item in args[1]
+if isinstance(item, MessageV1Types)
+]
+args[1] = [batch]  # type: ignore[index]
+elif (
+event_name == "on_llm_new_token"
+and "chunk" in kwargs
+and isinstance(kwargs["chunk"], MessageV1Types)
+):
+chunk = kwargs["chunk"]
+kwargs["chunk"] = ChatGenerationChunk(text=chunk.text, message=chunk)
+elif event_name == "on_llm_end" and isinstance(args[0], MessageV1Types):
+args[0] = LLMResult(  # type: ignore[index]
generations=[[ChatGeneration(text=args[0].text, message=args[0])]]
)
else:
return


def handle_event(
@@ -695,7 +714,7 @@ class CallbackManagerForLLMRun(RunManager, LLMManagerMixin):

def on_llm_new_token(
self,
-token: Union[str, AIMessageChunk],
+token: str,
*,
chunk: Optional[
Union[GenerationChunk, ChatGenerationChunk, AIMessageChunk]
@@ -793,7 +812,7 @@ class AsyncCallbackManagerForLLMRun(AsyncRunManager, LLMManagerMixin):

async def on_llm_new_token(
self,
-token: Union[str, AIMessageChunk],
+token: str,
*,
chunk: Optional[
Union[GenerationChunk, ChatGenerationChunk, AIMessageChunk]
@@ -1380,7 +1399,7 @@ class CallbackManager(BaseCallbackManager):
def on_chat_model_start(
self,
serialized: dict[str, Any],
-messages: list[list[BaseMessage]],
+messages: Union[list[list[BaseMessage]], list[MessageV1]],
run_id: Optional[UUID] = None,
**kwargs: Any,
) -> list[CallbackManagerForLLMRun]:
@@ -1388,7 +1407,7 @@ class CallbackManager(BaseCallbackManager):

Args:
serialized (dict[str, Any]): The serialized LLM.
-messages (list[list[BaseMessage]]): The list of messages.
+messages (list[list[BaseMessage | MessageV1]]): The list of messages.
run_id (UUID, optional): The ID of the run. Defaults to None.
**kwargs (Any): Additional keyword arguments.

@@ -1890,7 +1909,7 @@ class AsyncCallbackManager(BaseCallbackManager):
async def on_chat_model_start(
self,
serialized: dict[str, Any],
-messages: list[list[BaseMessage]],
+messages: Union[list[list[BaseMessage]], list[MessageV1]],
run_id: Optional[UUID] = None,
**kwargs: Any,
) -> list[AsyncCallbackManagerForLLMRun]:
@@ -1898,7 +1917,7 @@ class AsyncCallbackManager(BaseCallbackManager):

Args:
serialized (dict[str, Any]): The serialized LLM.
-messages (list[list[BaseMessage]]): The list of messages.
+messages (list[list[BaseMessage | MessageV1]]): The list of messages.
run_id (UUID, optional): The ID of the run. Defaults to None.
**kwargs (Any): Additional keyword arguments.
@@ -3,7 +3,7 @@
from __future__ import annotations

import sys
-from typing import TYPE_CHECKING, Any
+from typing import TYPE_CHECKING, Any, Union

from typing_extensions import override

@@ -12,6 +12,7 @@ from langchain_core.callbacks.base import BaseCallbackHandler
if TYPE_CHECKING:
from langchain_core.agents import AgentAction, AgentFinish
from langchain_core.messages import BaseMessage
+from langchain_core.messages.v1 import AIMessage, MessageV1
from langchain_core.outputs import LLMResult


@@ -32,7 +33,7 @@ class StreamingStdOutCallbackHandler(BaseCallbackHandler):
def on_chat_model_start(
self,
serialized: dict[str, Any],
-messages: list[list[BaseMessage]],
+messages: Union[list[list[BaseMessage]], list[MessageV1]],
**kwargs: Any,
) -> None:
"""Run when LLM starts running.
@@ -54,7 +55,7 @@ class StreamingStdOutCallbackHandler(BaseCallbackHandler):
sys.stdout.write(token)
sys.stdout.flush()

-def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
+def on_llm_end(self, response: Union[LLMResult, AIMessage], **kwargs: Any) -> None:
"""Run when LLM ends running.

Args:
@@ -4,13 +4,15 @@ import threading
from collections.abc import Generator
from contextlib import contextmanager
from contextvars import ContextVar
-from typing import Any, Optional
+from typing import Any, Optional, Union

from typing_extensions import override

from langchain_core.callbacks import BaseCallbackHandler
from langchain_core.messages import AIMessage
from langchain_core.messages.ai import UsageMetadata, add_usage
+from langchain_core.messages.utils import convert_from_v1_message
+from langchain_core.messages.v1 import AIMessage as AIMessageV1
from langchain_core.outputs import ChatGeneration, LLMResult


@@ -58,9 +60,17 @@ class UsageMetadataCallbackHandler(BaseCallbackHandler):
return str(self.usage_metadata)

@override
-def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
+def on_llm_end(
+self, response: Union[LLMResult, AIMessageV1], **kwargs: Any
+) -> None:
"""Collect token usage."""
# Check for usage_metadata (langchain-core >= 0.2.2)
+if isinstance(response, AIMessageV1):
+response = LLMResult(
+generations=[
+[ChatGeneration(message=convert_from_v1_message(response))]
+]
+)
try:
generation = response.generations[0][0]
except IndexError:
@@ -4,6 +4,7 @@ from collections.abc import Sequence
from typing import Optional

from langchain_core.messages import BaseMessage
+from langchain_core.messages.v1 import MessageV1


def _is_openai_data_block(block: dict) -> bool:
@@ -129,10 +130,7 @@ def _normalize_messages(messages: Sequence[BaseMessage]) -> list[BaseMessage]:
and _is_openai_data_block(block)
):
if formatted_message is message:
-if isinstance(message, BaseMessage):
-formatted_message = message.model_copy()
-else:
-formatted_message = copy.copy(message)
+formatted_message = message.model_copy()
# Also shallow-copy content
formatted_message.content = list(formatted_message.content)

@@ -142,3 +140,37 @@ def _normalize_messages(messages: Sequence[BaseMessage]) -> list[BaseMessage]:
formatted_messages.append(formatted_message)

return formatted_messages


+def _normalize_messages_v1(messages: Sequence[MessageV1]) -> list[MessageV1]:
+"""Extend support for message formats.

+Chat models implement support for images in OpenAI Chat Completions format, as well
+as other multimodal data as standard data blocks. This function extends support to
+audio and file data in OpenAI Chat Completions format by converting them to standard
+data blocks.
+"""
+formatted_messages = []
+for message in messages:
+formatted_message = message
+if isinstance(message.content, list):
+for idx, block in enumerate(message.content):
+if (
+isinstance(block, dict)
+# Subset to (PDF) files and audio, as most relevant chat models
+# support images in OAI format (and some may not yet support the
+# standard data block format)
+and block.get("type") in {"file", "input_audio"}
+and _is_openai_data_block(block)  # type: ignore[arg-type]
+):
+if formatted_message is message:
+formatted_message = copy.copy(message)
+# Also shallow-copy content
+formatted_message.content = list(formatted_message.content)

+formatted_message.content[idx] = (  # type: ignore[call-overload]
+_convert_openai_format_to_data_block(block)  # type: ignore[arg-type]
+)
+formatted_messages.append(formatted_message)

+return formatted_messages
@@ -2,7 +2,8 @@

from __future__ import annotations

-from abc import ABC
+import warnings
+from abc import ABC, abstractmethod
from collections.abc import Mapping, Sequence
from functools import cache
from typing import (
@@ -25,6 +26,7 @@ from langchain_core.messages import (
AnyMessage,
BaseMessage,
MessageLikeRepresentation,
+get_buffer_string,
)
from langchain_core.messages.v1 import AIMessage as AIMessageV1
from langchain_core.prompt_values import PromptValue
@@ -164,6 +166,7 @@ class BaseLanguageModel(
list[AnyMessage],
]

+@abstractmethod
def generate_prompt(
self,
prompts: list[PromptValue],
@@ -198,6 +201,7 @@ class BaseLanguageModel(
prompt and additional model provider-specific output.
"""

+@abstractmethod
async def agenerate_prompt(
self,
prompts: list[PromptValue],
@@ -241,6 +245,7 @@ class BaseLanguageModel(
raise NotImplementedError

@deprecated("0.1.7", alternative="invoke", removal="1.0")
+@abstractmethod
def predict(
self, text: str, *, stop: Optional[Sequence[str]] = None, **kwargs: Any
) -> str:
@@ -261,6 +266,7 @@ class BaseLanguageModel(
"""

@deprecated("0.1.7", alternative="invoke", removal="1.0")
+@abstractmethod
def predict_messages(
self,
messages: list[BaseMessage],
@@ -285,6 +291,7 @@ class BaseLanguageModel(
"""

@deprecated("0.1.7", alternative="ainvoke", removal="1.0")
+@abstractmethod
async def apredict(
self, text: str, *, stop: Optional[Sequence[str]] = None, **kwargs: Any
) -> str:
@@ -305,6 +312,7 @@ class BaseLanguageModel(
"""

@deprecated("0.1.7", alternative="ainvoke", removal="1.0")
+@abstractmethod
async def apredict_messages(
self,
messages: list[BaseMessage],
@@ -360,6 +368,33 @@ class BaseLanguageModel(
"""
return len(self.get_token_ids(text))

+def get_num_tokens_from_messages(
+self,
+messages: list[BaseMessage],
+tools: Optional[Sequence] = None,
+) -> int:
+"""Get the number of tokens in the messages.

+Useful for checking if an input fits in a model's context window.

+**Note**: the base implementation of get_num_tokens_from_messages ignores
+tool schemas.

+Args:
+messages: The message inputs to tokenize.
+tools: If provided, sequence of dict, BaseModel, function, or BaseTools
+to be converted to tool schemas.

+Returns:
+The sum of the number of tokens across the messages.
+"""
+if tools is not None:
+warnings.warn(
+"Counting tokens in tool schemas is not yet supported. Ignoring tools.",
+stacklevel=2,
+)
+return sum(self.get_num_tokens(get_buffer_string([m])) for m in messages)

@classmethod
def _all_required_field_names(cls) -> set:
"""DEPRECATED: Kept for backwards compatibility.
@@ -55,7 +55,6 @@ from langchain_core.messages import (
HumanMessage,
convert_to_messages,
convert_to_openai_image_block,
-get_buffer_string,
is_data_content_block,
message_chunk_to_message,
)
@@ -1352,33 +1351,6 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
starter_dict["_type"] = self._llm_type
return starter_dict

-def get_num_tokens_from_messages(
-self,
-messages: list[BaseMessage],
-tools: Optional[Sequence] = None,
-) -> int:
-"""Get the number of tokens in the messages.

-Useful for checking if an input fits in a model's context window.

-**Note**: the base implementation of get_num_tokens_from_messages ignores
-tool schemas.

-Args:
-messages: The message inputs to tokenize.
-tools: If provided, sequence of dict, BaseModel, function, or BaseTools
-to be converted to tool schemas.

-Returns:
-The sum of the number of tokens across the messages.
-"""
-if tools is not None:
-warnings.warn(
-"Counting tokens in tool schemas is not yet supported. Ignoring tools.",
-stacklevel=2,
-)
-return sum(self.get_num_tokens(get_buffer_string([m])) for m in messages)

def bind_tools(
self,
tools: Sequence[
@@ -6,7 +6,7 @@ import copy
import typing
import warnings
from abc import ABC, abstractmethod
-from collections.abc import AsyncIterator, Iterator, Sequence
+from collections.abc import AsyncIterator, Iterator, Mapping, Sequence
from operator import itemgetter
from typing import (
TYPE_CHECKING,
@@ -22,29 +22,32 @@ from pydantic import (
BaseModel,
ConfigDict,
Field,
field_validator,
)
-from typing_extensions import override
+from typing_extensions import TypeAlias, override

from langchain_core.caches import BaseCache
from langchain_core.callbacks import (
AsyncCallbackManager,
AsyncCallbackManagerForLLMRun,
CallbackManager,
CallbackManagerForLLMRun,
Callbacks,
)
-from langchain_core.language_models._utils import _normalize_messages
+from langchain_core.language_models._utils import _normalize_messages_v1
from langchain_core.language_models.base import (
BaseLanguageModel,
LangSmithParams,
LanguageModelInput,
_get_token_ids_default_method,
_get_verbosity,
)
from langchain_core.messages import (
AIMessage,
convert_to_openai_image_block,
get_buffer_string,
is_data_content_block,
)
from langchain_core.messages.utils import (
-_convert_from_v1_message,
+convert_from_v1_message,
convert_to_messages_v1,
)
from langchain_core.messages.v1 import AIMessage as AIMessageV1
@@ -58,6 +61,7 @@ from langchain_core.outputs import (
from langchain_core.prompt_values import PromptValue
from langchain_core.rate_limiters import BaseRateLimiter
from langchain_core.runnables import RunnableMap, RunnablePassthrough
+from langchain_core.runnables.base import RunnableSerializable
from langchain_core.runnables.config import ensure_config, run_in_executor
from langchain_core.tracers._streaming import _StreamingCallbackHandler
from langchain_core.utils.function_calling import (
@@ -85,14 +89,15 @@ def _generate_response_from_error(error: BaseException) -> list[AIMessageV1]:
metadata["status_code"] = response.status_code
if hasattr(error, "request_id"):
metadata["request_id"] = error.request_id
-generations = [AIMessageV1(content=[], response_metadata=metadata)]
+# Permit response_metadata without model_name, model_provider fields
+generations = [AIMessageV1(content=[], response_metadata=metadata)]  # type: ignore[arg-type]
else:
generations = []

return generations


-def _format_for_tracing(messages: list[MessageV1]) -> list[MessageV1]:
+def _format_for_tracing(messages: Sequence[MessageV1]) -> list[MessageV1]:
"""Format messages for tracing in on_chat_model_start.

- Update image content blocks to OpenAI Chat Completions format (backward
@@ -112,7 +117,7 @@ def _format_for_tracing(messages: list[MessageV1]) -> list[MessageV1]:
# Update image content blocks to OpenAI # Chat Completions format.
if (
block["type"] == "image"
-and is_data_content_block(block)
+and is_data_content_block(block)  # type: ignore[arg-type] # permit unnecessary runtime check
and block.get("source_type") != "id"
):
if message_to_trace is message:
@@ -120,7 +125,9 @@ def _format_for_tracing(messages: list[MessageV1]) -> list[MessageV1]:
message_to_trace = copy.copy(message)
message_to_trace.content = list(message_to_trace.content)

-message_to_trace.content[idx] = convert_to_openai_image_block(block)
+# TODO: for tracing purposes we store non-standard types (OpenAI format)
+# in message content. Consider typing these block formats.
+message_to_trace.content[idx] = convert_to_openai_image_block(block)  # type: ignore[arg-type, call-overload]
else:
pass
messages_to_trace.append(message_to_trace)
@@ -180,7 +187,7 @@ def _format_ls_structured_output(ls_structured_output_format: Optional[dict]) ->
return ls_structured_output_format_dict


-class BaseChatModelV1(BaseLanguageModel[AIMessageV1], ABC):
+class BaseChatModelV1(RunnableSerializable[LanguageModelInput, AIMessageV1], ABC):
"""Base class for chat models.

Key imperative methods:
@@ -278,12 +285,69 @@ class BaseChatModelV1(BaseLanguageModel[AIMessageV1], ABC):
does not properly support streaming.
"""

+cache: Union[BaseCache, bool, None] = Field(default=None, exclude=True)
+"""Whether to cache the response.

+* If true, will use the global cache.
+* If false, will not use a cache
+* If None, will use the global cache if it's set, otherwise no cache.
+* If instance of BaseCache, will use the provided cache.

+Caching is not currently supported for streaming methods of models.
+"""
+verbose: bool = Field(default_factory=_get_verbosity, exclude=True, repr=False)
+"""Whether to print out response text."""
+callbacks: Callbacks = Field(default=None, exclude=True)
+"""Callbacks to add to the run trace."""
+tags: Optional[list[str]] = Field(default=None, exclude=True)
+"""Tags to add to the run trace."""
+metadata: Optional[dict[str, Any]] = Field(default=None, exclude=True)
+"""Metadata to add to the run trace."""
+custom_get_token_ids: Optional[Callable[[str], list[int]]] = Field(
+default=None, exclude=True
+)
+"""Optional encoder to use for counting tokens."""

model_config = ConfigDict(
arbitrary_types_allowed=True,
)

# --- Runnable methods ---

+@field_validator("verbose", mode="before")
+def set_verbose(cls, verbose: Optional[bool]) -> bool:  # noqa: FBT001
+"""If verbose is None, set it.

+This allows users to pass in None as verbose to access the global setting.

+Args:
+verbose: The verbosity setting to use.

+Returns:
+The verbosity setting to use.
+"""
+if verbose is None:
+return _get_verbosity()
+return verbose

+@property
+@override
+def InputType(self) -> TypeAlias:
+"""Get the input type for this runnable."""
+from langchain_core.prompt_values import (
+ChatPromptValueConcrete,
+StringPromptValue,
+)

+# This is a version of LanguageModelInput which replaces the abstract
+# base class BaseMessage with a union of its subclasses, which makes
+# for a much better schema.
+return Union[
+str,
+Union[StringPromptValue, ChatPromptValueConcrete],
+list[MessageV1],
+]

@property
@override
def OutputType(self) -> Any:
@@ -299,7 +363,7 @@ class BaseChatModelV1(BaseLanguageModel[AIMessageV1], ABC):
return convert_to_messages_v1(model_input)
msg = (
f"Invalid input type {type(model_input)}. "
-"Must be a PromptValue, str, or list of BaseMessages."
+"Must be a PromptValue, str, or list of Messages."
)
raise ValueError(msg)

@@ -371,7 +435,7 @@ class BaseChatModelV1(BaseLanguageModel[AIMessageV1], ABC):
)
(run_manager,) = callback_manager.on_chat_model_start(
{},
-[_format_for_tracing(messages)],
+_format_for_tracing(messages),
invocation_params=params,
options=options,
name=config.get("run_name"),
@@ -382,27 +446,27 @@ class BaseChatModelV1(BaseLanguageModel[AIMessageV1], ABC):
if self.rate_limiter:
self.rate_limiter.acquire(blocking=True)

-input_messages = _normalize_messages(messages)
+input_messages = _normalize_messages_v1(messages)

if self._should_stream(async_api=False, **kwargs):
chunks: list[AIMessageChunkV1] = []
try:
for msg in self._stream(input_messages, **kwargs):
-run_manager.on_llm_new_token(msg)
+run_manager.on_llm_new_token(msg.text or "")
chunks.append(msg)
except BaseException as e:
run_manager.on_llm_error(e, response=_generate_response_from_error(e))
raise
-msg = add_ai_message_chunks(chunks[0], *chunks[1:])
+full_message = add_ai_message_chunks(chunks[0], *chunks[1:]).to_message()
else:
try:
-msg = self._invoke(input_messages, **kwargs)
+full_message = self._invoke(input_messages, **kwargs)
except BaseException as e:
run_manager.on_llm_error(e)
raise

-run_manager.on_llm_end(msg)
-return msg
+run_manager.on_llm_end(full_message)
+return full_message

@override
async def ainvoke(
@@ -410,7 +474,7 @@ class BaseChatModelV1(BaseLanguageModel[AIMessageV1], ABC):
input: LanguageModelInput,
config: Optional[RunnableConfig] = None,
**kwargs: Any,
-) -> AIMessage:
+) -> AIMessageV1:
config = ensure_config(config)
messages = self._convert_input(input)
ls_structured_output_format = kwargs.pop(
@@ -437,7 +501,7 @@ class BaseChatModelV1(BaseLanguageModel[AIMessageV1], ABC):
)
(run_manager,) = await callback_manager.on_chat_model_start(
{},
-[_format_for_tracing(messages)],
+_format_for_tracing(messages),
invocation_params=params,
options=options,
name=config.get("run_name"),
@@ -449,31 +513,31 @@ class BaseChatModelV1(BaseLanguageModel[AIMessageV1], ABC):
await self.rate_limiter.aacquire(blocking=True)

# TODO: type openai image, audio, file types and permit in MessageV1
-input_messages = _normalize_messages(messages)  # type: ignore[arg-type]
+input_messages = _normalize_messages_v1(messages)

if self._should_stream(async_api=True, **kwargs):
chunks: list[AIMessageChunkV1] = []
try:
async for msg in self._astream(input_messages, **kwargs):
-await run_manager.on_llm_new_token(msg)
+await run_manager.on_llm_new_token(msg.text or "")
chunks.append(msg)
except BaseException as e:
await run_manager.on_llm_error(
e, response=_generate_response_from_error(e)
)
raise
-msg = add_ai_message_chunks(chunks[0], *chunks[1:])
+full_message = add_ai_message_chunks(chunks[0], *chunks[1:]).to_message()
else:
try:
-msg = await self._ainvoke(input_messages, **kwargs)
+full_message = await self._ainvoke(input_messages, **kwargs)
except BaseException as e:
await run_manager.on_llm_error(
e, response=_generate_response_from_error(e)
)
raise

-await run_manager.on_llm_end(msg.to_message())
-return msg
+await run_manager.on_llm_end(full_message)
+return full_message

@override
def stream(
@@ -515,7 +579,7 @@ class BaseChatModelV1(BaseLanguageModel[AIMessageV1], ABC):
)
(run_manager,) = callback_manager.on_chat_model_start(
{},
-[_format_for_tracing(messages)],
+_format_for_tracing(messages),
invocation_params=params,
options=options,
name=config.get("run_name"),
@@ -530,9 +594,9 @@ class BaseChatModelV1(BaseLanguageModel[AIMessageV1], ABC):

try:
# TODO: replace this with something for new messages
-input_messages = _normalize_messages(messages)
+input_messages = _normalize_messages_v1(messages)
for msg in self._stream(input_messages, **kwargs):
-run_manager.on_llm_new_token(msg)
+run_manager.on_llm_new_token(msg.text or "")
chunks.append(msg)
yield msg
except BaseException as e:
@@ -584,7 +648,7 @@ class BaseChatModelV1(BaseLanguageModel[AIMessageV1], ABC):
)
(run_manager,) = await callback_manager.on_chat_model_start(
{},
-[_format_for_tracing(messages)],
+_format_for_tracing(messages),
invocation_params=params,
options=options,
name=config.get("run_name"),
@@ -598,12 +662,12 @@ class BaseChatModelV1(BaseLanguageModel[AIMessageV1], ABC):
chunks: list[AIMessageChunkV1] = []

try:
-input_messages = _normalize_messages(messages)
+input_messages = _normalize_messages_v1(messages)
async for msg in self._astream(
input_messages,
**kwargs,
):
-await run_manager.on_llm_new_token(msg)
+await run_manager.on_llm_new_token(msg.text or "")
chunks.append(msg)
yield msg
except BaseException as e:
@@ -623,7 +687,7 @@ class BaseChatModelV1(BaseLanguageModel[AIMessageV1], ABC):
stop: Optional[list[str]] = None,
**kwargs: Any,
) -> dict:
-params = self.dict()
+params = self.dump()
params["stop"] = stop
return {**params, **kwargs}
@@ -674,14 +738,14 @@ class BaseChatModelV1(BaseLanguageModel[AIMessageV1], ABC):
self,
messages: list[MessageV1],
**kwargs: Any,
-) -> AIMessage:
+) -> AIMessageV1:
raise NotImplementedError

async def _ainvoke(
self,
messages: list[MessageV1],
**kwargs: Any,
-) -> AIMessage:
+) -> AIMessageV1:
return await run_in_executor(
None,
self._invoke,
@@ -724,8 +788,7 @@ class BaseChatModelV1(BaseLanguageModel[AIMessageV1], ABC):
def _llm_type(self) -> str:
"""Return type of chat model."""

-@override
-def dict(self, **kwargs: Any) -> dict:
+def dump(self, **kwargs: Any) -> dict:  # noqa: ARG002
"""Return a dictionary of the LLM."""
starter_dict = dict(self._identifying_params)
starter_dict["_type"] = self._llm_type
@@ -903,6 +966,38 @@ class BaseChatModelV1(BaseLanguageModel[AIMessageV1], ABC):
return RunnableMap(raw=llm) | parser_with_fallback
return llm | output_parser

+@property
+def _identifying_params(self) -> Mapping[str, Any]:
+"""Get the identifying parameters."""
+return self.lc_attributes

+def get_token_ids(self, text: str) -> list[int]:
+"""Return the ordered ids of the tokens in a text.

+Args:
+text: The string input to tokenize.

+Returns:
+A list of ids corresponding to the tokens in the text, in order they occur
+in the text.
+"""
+if self.custom_get_token_ids is not None:
+return self.custom_get_token_ids(text)
+return _get_token_ids_default_method(text)

+def get_num_tokens(self, text: str) -> int:
+"""Get the number of tokens present in the text.

+Useful for checking if an input fits in a model's context window.

+Args:
+text: The string input to tokenize.

+Returns:
+The integer number of tokens in the text.
+"""
+return len(self.get_token_ids(text))

def get_num_tokens_from_messages(
self,
messages: list[MessageV1],
@@ -923,7 +1018,7 @@ class BaseChatModelV1(BaseLanguageModel[AIMessageV1], ABC):
Returns:
The sum of the number of tokens across the messages.
"""
-messages_v0 = [_convert_from_v1_message(message) for message in messages]
+messages_v0 = [convert_from_v1_message(message) for message in messages]
if tools is not None:
warnings.warn(
"Counting tokens in tool schemas is not yet supported. Ignoring tools.",
@@ -327,7 +327,6 @@ def _create_message_from_message_type_v1(
ValueError: if the message type is not one of "human", "user", "ai",
"assistant", "tool", "system", or "developer".
"""
kwargs: dict[str, Any] = {}
if name is not None:
kwargs["name"] = name
if tool_call_id is not None:
@@ -355,7 +354,7 @@ def _create_message_from_message_type_v1(
else:
kwargs["tool_calls"].append(tool_call)
if message_type in {"human", "user"}:
-message = HumanMessageV1(content=content, **kwargs)
+message: MessageV1 = HumanMessageV1(content=content, **kwargs)
elif message_type in {"ai", "assistant"}:
message = AIMessageV1(content=content, **kwargs)
elif message_type in {"system", "developer"}:
@@ -375,7 +374,7 @@ def _create_message_from_message_type_v1(
return message


-def _convert_from_v1_message(message: MessageV1) -> BaseMessage:
+def convert_from_v1_message(message: MessageV1) -> BaseMessage:
"""Compatibility layer to convert v1 messages to current messages.

Args:
@@ -469,7 +468,7 @@ def _convert_to_message(message: MessageLikeRepresentation) -> BaseMessage:
msg_type, msg_content, **msg_kwargs
)
elif isinstance(message, MessageV1Types):
-message_ = _convert_from_v1_message(message)
+message_ = convert_from_v1_message(message)
else:
msg = f"Unsupported message type: {type(message)}"
msg = create_message(message=msg, error_code=ErrorCode.MESSAGE_COERCION_FAILURE)
@@ -501,7 +500,7 @@ def _convert_to_message_v1(message: MessageLikeRepresentation) -> MessageV1:
"""
if isinstance(message, MessageV1Types):
if isinstance(message, AIMessageChunkV1):
-message_ = message.to_message()
+message_: MessageV1 = message.to_message()
else:
message_ = message
elif isinstance(message, str):
@@ -188,7 +188,7 @@ class AIMessage:


@dataclass
-class AIMessageChunk:
+class AIMessageChunk(AIMessage):
"""A partial chunk of an AI message during streaming.

Represents a portion of an AI response that is delivered incrementally
@@ -203,43 +203,7 @@ class AIMessageChunk:
usage_metadata: Optional metadata about token usage and costs.
"""

-type: Literal["ai_chunk"] = "ai_chunk"

-name: Optional[str] = None
-"""An optional name for the message.

-This can be used to provide a human-readable name for the message.

-Usage of this field is optional, and whether it's used or not is up to the
-model implementation.
-"""

-id: Optional[str] = None
-"""Unique identifier for the message chunk.

-If the provider assigns a meaningful ID, it should be used here.
-"""

-lc_version: str = "v1"
-"""Encoding version for the message."""

-content: list[types.ContentBlock] = field(init=False)

-usage_metadata: Optional[UsageMetadata] = None
-"""If provided, usage metadata for a message chunk, such as token counts.

-These data represent incremental usage statistics, as opposed to a running total.
-"""

-response_metadata: ResponseMetadata = field(init=False)
-"""Metadata about the response chunk.

-This field should include non-standard data returned by the provider, such as
-response headers, service tiers, or log probabilities.
-"""

-parsed: Optional[Union[dict[str, Any], BaseModel]] = None
-"""Auto-parsed message contents, if applicable."""
+type: Literal["ai_chunk"] = "ai_chunk"  # type: ignore[assignment]

tool_call_chunks: list[types.ToolCallChunk] = field(init=False)

@@ -342,7 +306,7 @@ class AIMessageChunk:
if isinstance(args_, dict):
tool_calls.append(
create_tool_call(
-name=chunk.get("name", ""),
+name=chunk.get("name") or "",
args=args_,
id=chunk.get("id", ""),
)
@@ -9,7 +9,7 @@ from typing import Annotated, Any, Optional
from pydantic import SkipValidation, ValidationError

from langchain_core.exceptions import OutputParserException
-from langchain_core.messages import AIMessage, InvalidToolCall, ToolCall
+from langchain_core.messages import AIMessage, InvalidToolCall
from langchain_core.messages.tool import invalid_tool_call
from langchain_core.messages.tool import tool_call as create_tool_call
from langchain_core.output_parsers.transform import BaseCumulativeTransformOutputParser
@@ -26,7 +26,7 @@ def parse_tool_call(
partial: bool = False,
strict: bool = False,
return_id: bool = True,
-) -> Optional[ToolCall]:
+) -> Optional[dict[str, Any]]:
"""Parse a single tool call.

Args:
@@ -16,6 +16,7 @@ from typing_extensions import override

from langchain_core.callbacks.base import AsyncCallbackHandler, BaseCallbackHandler
from langchain_core.exceptions import TracerException  # noqa: F401
+from langchain_core.messages.v1 import AIMessage, AIMessageChunk, MessageV1
from langchain_core.tracers.core import _TracerCore

if TYPE_CHECKING:
@@ -54,7 +55,7 @@ class BaseTracer(_TracerCore, BaseCallbackHandler, ABC):
def on_chat_model_start(
self,
serialized: dict[str, Any],
-messages: list[list[BaseMessage]],
+messages: Union[list[list[BaseMessage]], list[MessageV1]],
*,
run_id: UUID,
tags: Optional[list[str]] = None,
@@ -138,7 +139,9 @@ class BaseTracer(_TracerCore, BaseCallbackHandler, ABC):
self,
token: str,
*,
-chunk: Optional[Union[GenerationChunk, ChatGenerationChunk]] = None,
+chunk: Optional[
+Union[GenerationChunk, ChatGenerationChunk, AIMessageChunk]
+] = None,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
**kwargs: Any,
@@ -190,7 +193,9 @@ class BaseTracer(_TracerCore, BaseCallbackHandler, ABC):
)

@override
-def on_llm_end(self, response: LLMResult, *, run_id: UUID, **kwargs: Any) -> Run:
+def on_llm_end(
+self, response: Union[LLMResult, AIMessage], *, run_id: UUID, **kwargs: Any
+) -> Run:
"""End a trace for an LLM run.

Args:
@@ -562,7 +567,7 @@ class AsyncBaseTracer(_TracerCore, AsyncCallbackHandler, ABC):
async def on_chat_model_start(
self,
serialized: dict[str, Any],
-messages: list[list[BaseMessage]],
+messages: Union[list[list[BaseMessage]], list[MessageV1]],
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
@@ -617,7 +622,9 @@ class AsyncBaseTracer(_TracerCore, AsyncCallbackHandler, ABC):
self,
token: str,
*,
-chunk: Optional[Union[GenerationChunk, ChatGenerationChunk]] = None,
+chunk: Optional[
+Union[GenerationChunk, ChatGenerationChunk, AIMessageChunk]
+] = None,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
**kwargs: Any,
@@ -646,7 +653,7 @@ class AsyncBaseTracer(_TracerCore, AsyncCallbackHandler, ABC):
@override
async def on_llm_end(
self,
-response: LLMResult,
+response: Union[LLMResult, AIMessage],
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
@@ -882,7 +889,7 @@ class AsyncBaseTracer(_TracerCore, AsyncCallbackHandler, ABC):
self,
run: Run,
token: str,
-chunk: Optional[Union[GenerationChunk, ChatGenerationChunk]],
+chunk: Optional[Union[GenerationChunk, ChatGenerationChunk, AIMessageChunk]],
) -> None:
"""Process new LLM token."""
@@ -18,6 +18,13 @@ from typing import (

from langchain_core.exceptions import TracerException
from langchain_core.load import dumpd
+from langchain_core.messages.utils import convert_from_v1_message
+from langchain_core.messages.v1 import (
+AIMessage,
+AIMessageChunk,
+MessageV1,
+MessageV1Types,
+)
from langchain_core.outputs import (
ChatGeneration,
ChatGenerationChunk,
@@ -156,7 +163,7 @@ class _TracerCore(ABC):
def _create_chat_model_run(
self,
serialized: dict[str, Any],
-messages: list[list[BaseMessage]],
+messages: Union[list[list[BaseMessage]], list[MessageV1]],
run_id: UUID,
tags: Optional[list[str]] = None,
parent_run_id: Optional[UUID] = None,
@@ -181,6 +188,12 @@ class _TracerCore(ABC):
start_time = datetime.now(timezone.utc)
if metadata:
kwargs.update({"metadata": metadata})
+if isinstance(messages[0], MessageV1Types):
+# Convert from v1 messages to BaseMessage
+messages = [
+[convert_from_v1_message(msg) for msg in messages]  # type: ignore[arg-type]
+]
+messages = cast("list[list[BaseMessage]]", messages)
return Run(
id=run_id,
parent_run_id=parent_run_id,
@@ -230,7 +243,9 @@ class _TracerCore(ABC):
self,
token: str,
run_id: UUID,
-chunk: Optional[Union[GenerationChunk, ChatGenerationChunk]] = None,
+chunk: Optional[
+Union[GenerationChunk, ChatGenerationChunk, AIMessageChunk]
+] = None,
parent_run_id: Optional[UUID] = None,  # noqa: ARG002
) -> Run:
"""Append token event to LLM run and return the run."""
@@ -276,7 +291,15 @@ class _TracerCore(ABC):
)
return llm_run

-def _complete_llm_run(self, response: LLMResult, run_id: UUID) -> Run:
+def _complete_llm_run(
+self, response: Union[LLMResult, AIMessage], run_id: UUID
+) -> Run:
+if isinstance(response, AIMessage):
+response = LLMResult(
+generations=[
+[ChatGeneration(message=convert_from_v1_message(response))]
+]
+)
llm_run = self._get_run(run_id, run_type={"llm", "chat_model"})
if getattr(llm_run, "outputs", None) is None:
llm_run.outputs = {}
@@ -558,7 +581,7 @@ class _TracerCore(ABC):
self,
run: Run,  # noqa: ARG002
token: str,  # noqa: ARG002
-chunk: Optional[Union[GenerationChunk, ChatGenerationChunk]],  # noqa: ARG002
+chunk: Optional[Union[GenerationChunk, ChatGenerationChunk, AIMessageChunk]],  # noqa: ARG002
) -> Union[None, Coroutine[Any, Any, None]]:
"""Process new LLM token."""
return None
@@ -19,6 +19,7 @@ from typing_extensions import NotRequired, TypedDict, override

from langchain_core.callbacks.base import AsyncCallbackHandler
from langchain_core.messages import AIMessageChunk, BaseMessage, BaseMessageChunk
+from langchain_core.messages.v1 import MessageV1
from langchain_core.outputs import (
ChatGenerationChunk,
GenerationChunk,
@@ -43,6 +44,8 @@ if TYPE_CHECKING:
from collections.abc import AsyncIterator, Iterator, Sequence

from langchain_core.documents import Document
+from langchain_core.messages.v1 import AIMessage as AIMessageV1
+from langchain_core.messages.v1 import AIMessageChunk as AIMessageChunkV1
from langchain_core.runnables import Runnable, RunnableConfig
from langchain_core.tracers.log_stream import LogEntry

@@ -297,7 +300,7 @@ class _AstreamEventsCallbackHandler(AsyncCallbackHandler, _StreamingCallbackHand
async def on_chat_model_start(
self,
serialized: dict[str, Any],
-messages: list[list[BaseMessage]],
+messages: Union[list[list[BaseMessage]], list[MessageV1]],
*,
run_id: UUID,
tags: Optional[list[str]] = None,
@@ -307,6 +310,8 @@ class _AstreamEventsCallbackHandler(AsyncCallbackHandler, _StreamingCallbackHand
**kwargs: Any,
) -> None:
"""Start a trace for an LLM run."""
+# below cast is because type is converted in handle_event
+messages = cast("list[list[BaseMessage]]", messages)
name_ = _assign_name(name, serialized)
run_type = "chat_model"

@@ -407,13 +412,18 @@ class _AstreamEventsCallbackHandler(AsyncCallbackHandler, _StreamingCallbackHand
self,
token: str,
*,
-chunk: Optional[Union[GenerationChunk, ChatGenerationChunk]] = None,
+chunk: Optional[
+Union[GenerationChunk, ChatGenerationChunk, AIMessageChunkV1]
+] = None,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
**kwargs: Any,
) -> None:
"""Run on new LLM token. Only available when streaming is enabled."""
run_info = self.run_map.get(run_id)
+chunk = cast(
+"Optional[Union[GenerationChunk, ChatGenerationChunk]]", chunk
+) # converted in handle_event
chunk_: Union[GenerationChunk, BaseMessageChunk]

if run_info is None:
@@ -456,9 +466,10 @@ class _AstreamEventsCallbackHandler(AsyncCallbackHandler, _StreamingCallbackHand

@override
async def on_llm_end(
-self, response: LLMResult, *, run_id: UUID, **kwargs: Any
+self, response: Union[LLMResult, AIMessageV1], *, run_id: UUID, **kwargs: Any
) -> None:
"""End a trace for an LLM run."""
+response = cast("LLMResult", response)  # converted in handle_event
run_info = self.run_map.pop(run_id)
inputs_ = run_info["inputs"]
@@ -5,7 +5,7 @@ from __future__ import annotations
import logging
from concurrent.futures import ThreadPoolExecutor
from datetime import datetime, timezone
-from typing import TYPE_CHECKING, Any, Optional, Union
+from typing import TYPE_CHECKING, Any, Optional, Union, cast
from uuid import UUID

from langsmith import Client
@@ -21,11 +21,14 @@ from typing_extensions import override

from langchain_core.env import get_runtime_environment
from langchain_core.load import dumpd
+from langchain_core.messages.utils import convert_from_v1_message
+from langchain_core.messages.v1 import MessageV1Types
from langchain_core.tracers.base import BaseTracer
from langchain_core.tracers.schemas import Run

if TYPE_CHECKING:
from langchain_core.messages import BaseMessage
+from langchain_core.messages.v1 import AIMessageChunk, MessageV1
from langchain_core.outputs import ChatGenerationChunk, GenerationChunk

logger = logging.getLogger(__name__)
@@ -113,7 +116,7 @@ class LangChainTracer(BaseTracer):
def on_chat_model_start(
self,
serialized: dict[str, Any],
-messages: list[list[BaseMessage]],
+messages: Union[list[list[BaseMessage]], list[MessageV1]],
*,
run_id: UUID,
tags: Optional[list[str]] = None,
@@ -140,6 +143,12 @@ class LangChainTracer(BaseTracer):
start_time = datetime.now(timezone.utc)
if metadata:
kwargs.update({"metadata": metadata})
+if isinstance(messages[0], MessageV1Types):
+# Convert from v1 messages to BaseMessage
+messages = [
+[convert_from_v1_message(msg) for msg in messages]  # type: ignore[arg-type]
+]
+messages = cast("list[list[BaseMessage]]", messages)
chat_model_run = Run(
id=run_id,
parent_run_id=parent_run_id,
@@ -232,7 +241,9 @@ class LangChainTracer(BaseTracer):
self,
token: str,
run_id: UUID,
-chunk: Optional[Union[GenerationChunk, ChatGenerationChunk]] = None,
+chunk: Optional[
+Union[GenerationChunk, ChatGenerationChunk, AIMessageChunk]
+] = None,
parent_run_id: Optional[UUID] = None,
) -> Run:
"""Append token event to LLM run and return the run."""
@@ -32,6 +32,7 @@ if TYPE_CHECKING:
from collections.abc import AsyncIterator, Iterator, Sequence
from uuid import UUID

+from langchain_core.messages.v1 import AIMessageChunk
from langchain_core.runnables.utils import Input, Output
from langchain_core.tracers.schemas import Run

@@ -485,7 +486,7 @@ class LogStreamCallbackHandler(BaseTracer, _StreamingCallbackHandler):
self,
run: Run,
token: str,
-chunk: Optional[Union[GenerationChunk, ChatGenerationChunk]],
+chunk: Optional[Union[GenerationChunk, ChatGenerationChunk, AIMessageChunk]],
) -> None:
"""Process new LLM token."""
index = self._key_map_by_run_id.get(run.id)
@@ -10,6 +10,8 @@ from typing_extensions import override
from langchain_core.callbacks.base import AsyncCallbackHandler
from langchain_core.language_models import GenericFakeChatModel
from langchain_core.messages import AIMessage, BaseMessage
+from langchain_core.messages.v1 import AIMessageChunk as AIMessageChunkV1
+from langchain_core.messages.v1 import MessageV1
from langchain_core.outputs import ChatGenerationChunk, GenerationChunk


@@ -18,7 +20,7 @@ class MyCustomAsyncHandler(AsyncCallbackHandler):
async def on_chat_model_start(
self,
serialized: dict[str, Any],
-messages: list[list[BaseMessage]],
+messages: Union[list[list[BaseMessage]], list[MessageV1]],
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
@@ -35,7 +37,9 @@ class MyCustomAsyncHandler(AsyncCallbackHandler):
self,
token: str,
*,
-chunk: Optional[Union[GenerationChunk, ChatGenerationChunk]] = None,
+chunk: Optional[
+Union[GenerationChunk, ChatGenerationChunk, AIMessageChunkV1]
+] = None,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
tags: Optional[list[str]] = None,
@@ -9,6 +9,7 @@ from typing_extensions import override

from langchain_core.callbacks.base import AsyncCallbackHandler, BaseCallbackHandler
from langchain_core.messages import BaseMessage
+from langchain_core.messages.v1 import MessageV1


class BaseFakeCallbackHandler(BaseModel):
@@ -285,7 +286,7 @@ class FakeCallbackHandlerWithChatStart(FakeCallbackHandler):
def on_chat_model_start(
self,
serialized: dict[str, Any],
-messages: list[list[BaseMessage]],
+messages: Union[list[list[BaseMessage]], list[MessageV1]],
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
@@ -165,7 +165,7 @@ async def test_callback_handlers() -> None:
async def on_chat_model_start(
self,
serialized: dict[str, Any],
-messages: Union[list[list[BaseMessage]], list[list[MessageV1]]],
+messages: Union[list[list[BaseMessage]], list[MessageV1]],
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
@@ -3,7 +3,7 @@
from __future__ import annotations

from datetime import datetime, timezone
-from typing import Any
+from typing import TYPE_CHECKING, Any
from uuid import uuid4

import pytest
@@ -16,6 +16,9 @@ from langchain_core.outputs import LLMResult
from langchain_core.tracers.base import AsyncBaseTracer
from langchain_core.tracers.schemas import Run

+if TYPE_CHECKING:
+from langchain_core.messages import BaseMessage

SERIALIZED = {"id": ["llm"]}
SERIALIZED_CHAT = {"id": ["chat_model"]}

@@ -84,8 +87,9 @@ async def test_tracer_chat_model_run() -> None:
"""Test tracer on a Chat Model run."""
tracer = FakeAsyncTracer()
manager = AsyncCallbackManager(handlers=[tracer])
+messages: list[list[BaseMessage]] = [[HumanMessage(content="")]]
run_managers = await manager.on_chat_model_start(
-serialized=SERIALIZED_CHAT, messages=[[HumanMessage(content="")]]
+serialized=SERIALIZED_CHAT, messages=messages
)
compare_run = Run(
id=str(run_managers[0].run_id),  # type: ignore[arg-type]
@@ -3,7 +3,7 @@
from __future__ import annotations

from datetime import datetime, timezone
-from typing import Any
+from typing import TYPE_CHECKING, Any
from unittest.mock import MagicMock
from uuid import uuid4

@@ -20,6 +20,9 @@ from langchain_core.runnables import chain as as_runnable
from langchain_core.tracers.base import BaseTracer
from langchain_core.tracers.schemas import Run

+if TYPE_CHECKING:
+from langchain_core.messages import BaseMessage

SERIALIZED = {"id": ["llm"]}
SERIALIZED_CHAT = {"id": ["chat_model"]}

@@ -89,8 +92,10 @@ def test_tracer_chat_model_run() -> None:
"""Test tracer on a Chat Model run."""
tracer = FakeTracer()
manager = CallbackManager(handlers=[tracer])
+# TODO: why is this annotation needed
+messages: list[list[BaseMessage]] = [[HumanMessage(content="")]]
run_managers = manager.on_chat_model_start(
-serialized=SERIALIZED_CHAT, messages=[[HumanMessage(content="")]]
+serialized=SERIALIZED_CHAT, messages=messages
)
compare_run = Run(
id=str(run_managers[0].run_id),  # type: ignore[arg-type]
@@ -5,6 +5,7 @@ from collections.abc import AsyncIterator
from typing import Any, Literal, Union, cast

from langchain_core.callbacks import AsyncCallbackHandler
+from langchain_core.messages.v1 import AIMessage
from langchain_core.outputs import LLMResult
from typing_extensions import override

@@ -44,7 +45,9 @@ class AsyncIteratorCallbackHandler(AsyncCallbackHandler):
self.queue.put_nowait(token)

@override
-async def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
+async def on_llm_end(
+self, response: Union[LLMResult, AIMessage], **kwargs: Any
+) -> None:
self.done.set()

@override
@@ -1,7 +1,8 @@
from __future__ import annotations

-from typing import Any, Optional
+from typing import Any, Optional, Union

+from langchain_core.messages.v1 import AIMessage
from langchain_core.outputs import LLMResult
from typing_extensions import override

@@ -75,7 +76,9 @@ class AsyncFinalIteratorCallbackHandler(AsyncIteratorCallbackHandler):
self.answer_reached = False

@override
-async def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
+async def on_llm_end(
+self, response: Union[LLMResult, AIMessage], **kwargs: Any
+) -> None:
if self.answer_reached:
self.done.set()
@@ -2,11 +2,12 @@

import threading
from collections.abc import Sequence
-from typing import Any, Optional
+from typing import Any, Optional, Union
from uuid import UUID

from langchain_core.callbacks import base as base_callbacks
from langchain_core.documents import Document
+from langchain_core.messages.v1 import AIMessage
from langchain_core.outputs import LLMResult
from typing_extensions import override

@@ -111,7 +112,7 @@ class ProgressBarCallback(base_callbacks.BaseCallbackHandler):
@override
def on_llm_end(
self,
-response: LLMResult,
+response: Union[LLMResult, AIMessage],
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
@@ -6,6 +6,7 @@ from uuid import UUID

from langchain_core.callbacks.base import AsyncCallbackHandler, BaseCallbackHandler
from langchain_core.messages import BaseMessage
+from langchain_core.messages.v1 import MessageV1
from pydantic import BaseModel
from typing_extensions import override

@@ -281,7 +282,7 @@ class FakeCallbackHandlerWithChatStart(FakeCallbackHandler):
def on_chat_model_start(
self,
serialized: dict[str, Any],
-messages: list[list[BaseMessage]],
+messages: Union[list[list[BaseMessage]], list[MessageV1]],
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
@@ -6,6 +6,8 @@ from uuid import UUID

from langchain_core.callbacks.base import AsyncCallbackHandler
from langchain_core.messages import AIMessage, AIMessageChunk, BaseMessage
+from langchain_core.messages.v1 import AIMessageChunk as AIMessageChunkV1
+from langchain_core.messages.v1 import MessageV1
from langchain_core.outputs import ChatGenerationChunk, GenerationChunk
from typing_extensions import override

@@ -155,7 +157,7 @@ async def test_callback_handlers() -> None:
async def on_chat_model_start(
self,
serialized: dict[str, Any],
-messages: list[list[BaseMessage]],
+messages: Union[list[list[BaseMessage]], list[MessageV1]],
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
@@ -172,7 +174,9 @@ async def test_callback_handlers() -> None:
self,
token: str,
*,
-chunk: Optional[Union[GenerationChunk, ChatGenerationChunk]] = None,
+chunk: Optional[
+Union[GenerationChunk, ChatGenerationChunk, AIMessageChunkV1]
+] = None,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
tags: Optional[list[str]] = None,