Merge branch 'standard_outputs_copy' into mdrxy/ollama-v1
Commit 92c913c212
@@ -7,6 +7,8 @@ from typing import TYPE_CHECKING, Any, Optional, Union

from typing_extensions import Self

from langchain_core.messages.v1 import AIMessage, AIMessageChunk, MessageV1

if TYPE_CHECKING:
    from collections.abc import Sequence
    from uuid import UUID

@@ -64,9 +66,11 @@ class LLMManagerMixin:

    def on_llm_new_token(
        self,
        token: str,
        token: Union[str, AIMessageChunk],
        *,
        chunk: Optional[Union[GenerationChunk, ChatGenerationChunk]] = None,
        chunk: Optional[
            Union[GenerationChunk, ChatGenerationChunk, AIMessageChunk]
        ] = None,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        **kwargs: Any,
@@ -75,8 +79,8 @@ class LLMManagerMixin:

        Args:
            token (str): The new token.
            chunk (GenerationChunk | ChatGenerationChunk): The new generated chunk,
                containing content and other information.
            chunk (GenerationChunk | ChatGenerationChunk | AIMessageChunk): The new
                generated chunk, containing content and other information.
            run_id (UUID): The run ID. This is the ID of the current run.
            parent_run_id (UUID): The parent run ID. This is the ID of the parent run.
            kwargs (Any): Additional keyword arguments.
@@ -84,7 +88,7 @@ class LLMManagerMixin:

    def on_llm_end(
        self,
        response: LLMResult,
        response: Union[LLMResult, AIMessage],
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
@@ -93,7 +97,7 @@ class LLMManagerMixin:
        """Run when LLM ends running.

        Args:
            response (LLMResult): The response which was generated.
            response (LLMResult | AIMessage): The response which was generated.
            run_id (UUID): The run ID. This is the ID of the current run.
            parent_run_id (UUID): The parent run ID. This is the ID of the parent run.
            kwargs (Any): Additional keyword arguments.
@@ -261,7 +265,7 @@ class CallbackManagerMixin:
    def on_chat_model_start(
        self,
        serialized: dict[str, Any],
        messages: list[list[BaseMessage]],
        messages: Union[list[list[BaseMessage]], list[list[MessageV1]]],
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
@@ -439,6 +443,9 @@ class BaseCallbackHandler(
    run_inline: bool = False
    """Whether to run the callback inline."""

    accepts_new_messages: bool = False
    """Whether the callback accepts new message format."""

    @property
    def ignore_llm(self) -> bool:
        """Whether to ignore LLM callbacks."""
@@ -509,7 +516,7 @@ class AsyncCallbackHandler(BaseCallbackHandler):
    async def on_chat_model_start(
        self,
        serialized: dict[str, Any],
        messages: list[list[BaseMessage]],
        messages: Union[list[list[BaseMessage]], list[list[MessageV1]]],
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
@@ -538,9 +545,11 @@ class AsyncCallbackHandler(BaseCallbackHandler):

    async def on_llm_new_token(
        self,
        token: str,
        token: Union[str, AIMessageChunk],
        *,
        chunk: Optional[Union[GenerationChunk, ChatGenerationChunk]] = None,
        chunk: Optional[
            Union[GenerationChunk, ChatGenerationChunk, AIMessageChunk]
        ] = None,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        tags: Optional[list[str]] = None,
@@ -550,8 +559,8 @@ class AsyncCallbackHandler(BaseCallbackHandler):

        Args:
            token (str): The new token.
            chunk (GenerationChunk | ChatGenerationChunk): The new generated chunk,
                containing content and other information.
            chunk (GenerationChunk | ChatGenerationChunk | AIMessageChunk): The new
                generated chunk, containing content and other information.
            run_id (UUID): The run ID. This is the ID of the current run.
            parent_run_id (UUID): The parent run ID. This is the ID of the parent run.
            tags (Optional[list[str]]): The tags.
@@ -560,7 +569,7 @@ class AsyncCallbackHandler(BaseCallbackHandler):

    async def on_llm_end(
        self,
        response: LLMResult,
        response: Union[LLMResult, AIMessage],
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
@@ -570,7 +579,7 @@ class AsyncCallbackHandler(BaseCallbackHandler):
        """Run when LLM ends running.

        Args:
            response (LLMResult): The response which was generated.
            response (LLMResult | AIMessage): The response which was generated.
            run_id (UUID): The run ID. This is the ID of the current run.
            parent_run_id (UUID): The parent run ID. This is the ID of the parent run.
            tags (Optional[list[str]]): The tags.
@@ -594,8 +603,8 @@ class AsyncCallbackHandler(BaseCallbackHandler):
            parent_run_id: The parent run ID. This is the ID of the parent run.
            tags: The tags.
            kwargs (Any): Additional keyword arguments.
                - response (LLMResult): The response which was generated before
                    the error occurred.
                - response (LLMResult | AIMessage): The response which was generated
                    before the error occurred.
        """

    async def on_chain_start(
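This hunk widens the streaming callbacks: on_llm_new_token may now receive either a plain string or a v1 AIMessageChunk, on_llm_end either an LLMResult or a v1 AIMessage, and a handler opts into the new objects by setting accepts_new_messages. A minimal sketch of a handler written against the widened signatures (imports as shown in this hunk; the print-based behavior is illustrative only):

from typing import Any, Optional, Union
from uuid import UUID

from langchain_core.callbacks import BaseCallbackHandler
from langchain_core.messages.v1 import AIMessage, AIMessageChunk
from langchain_core.outputs import ChatGenerationChunk, GenerationChunk, LLMResult


class PrintingHandler(BaseCallbackHandler):
    # Opt in to receiving v1 messages instead of converted legacy objects.
    accepts_new_messages = True

    def on_llm_new_token(
        self,
        token: Union[str, AIMessageChunk],
        *,
        chunk: Optional[
            Union[GenerationChunk, ChatGenerationChunk, AIMessageChunk]
        ] = None,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        **kwargs: Any,
    ) -> None:
        # A v1 chunk carries structured content; a str is the legacy token.
        text = token.text if isinstance(token, AIMessageChunk) else token
        print(text, end="")

    def on_llm_end(self, response: Union[LLMResult, AIMessage], **kwargs: Any) -> None:
        print()  # newline after the streamed response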
@@ -11,6 +11,7 @@ from abc import ABC, abstractmethod
from concurrent.futures import ThreadPoolExecutor
from contextlib import asynccontextmanager, contextmanager
from contextvars import copy_context
from dataclasses import is_dataclass
from typing import (
    TYPE_CHECKING,
    Any,
@@ -37,6 +38,8 @@ from langchain_core.callbacks.base import (
)
from langchain_core.callbacks.stdout import StdOutCallbackHandler
from langchain_core.messages import BaseMessage, get_buffer_string
from langchain_core.messages.v1 import AIMessage, AIMessageChunk
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, LLMResult
from langchain_core.tracers.schemas import Run
from langchain_core.utils.env import env_var_is_set

@@ -47,7 +50,7 @@ if TYPE_CHECKING:

    from langchain_core.agents import AgentAction, AgentFinish
    from langchain_core.documents import Document
    from langchain_core.outputs import ChatGenerationChunk, GenerationChunk, LLMResult
    from langchain_core.outputs import GenerationChunk
    from langchain_core.runnables.config import RunnableConfig

logger = logging.getLogger(__name__)
@@ -243,6 +246,22 @@ def shielded(func: Func) -> Func:
    return cast("Func", wrapped)


def _convert_llm_events(
    event_name: str, args: tuple[Any, ...], kwargs: dict[str, Any]
) -> None:
    if event_name == "on_chat_model_start" and isinstance(args[1], list):
        for idx, item in enumerate(args[1]):
            if is_dataclass(item):
                args[1][idx] = item  # convert to old message
    elif event_name == "on_llm_new_token" and is_dataclass(args[0]):
        kwargs["chunk"] = ChatGenerationChunk(text=args[0].text, message=args[0])
        args[0] = args[0].text
    elif event_name == "on_llm_end" and is_dataclass(args[0]):
        args[0] = LLMResult(
            generations=[[ChatGeneration(text=args[0].text, message=args[0])]]
        )


def handle_event(
    handlers: list[BaseCallbackHandler],
    event_name: str,
@@ -271,6 +290,8 @@ def handle_event(
            if ignore_condition_name is None or not getattr(
                handler, ignore_condition_name
            ):
                if not handler.accepts_new_messages:
                    _convert_llm_events(event_name, args, kwargs)
                event = getattr(handler, event_name)(*args, **kwargs)
                if asyncio.iscoroutine(event):
                    coros.append(event)
@@ -365,6 +386,8 @@ async def _ahandle_event_for_handler(
) -> None:
    try:
        if ignore_condition_name is None or not getattr(handler, ignore_condition_name):
            if not handler.accepts_new_messages:
                _convert_llm_events(event_name, args, kwargs)
            event = getattr(handler, event_name)
            if asyncio.iscoroutinefunction(event):
                await event(*args, **kwargs)
@@ -672,9 +695,11 @@ class CallbackManagerForLLMRun(RunManager, LLMManagerMixin):

    def on_llm_new_token(
        self,
        token: str,
        token: Union[str, AIMessageChunk],
        *,
        chunk: Optional[Union[GenerationChunk, ChatGenerationChunk]] = None,
        chunk: Optional[
            Union[GenerationChunk, ChatGenerationChunk, AIMessageChunk]
        ] = None,
        **kwargs: Any,
    ) -> None:
        """Run when LLM generates a new token.
@@ -699,11 +724,11 @@ class CallbackManagerForLLMRun(RunManager, LLMManagerMixin):
            **kwargs,
        )

    def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
    def on_llm_end(self, response: Union[LLMResult, AIMessage], **kwargs: Any) -> None:
        """Run when LLM ends running.

        Args:
            response (LLMResult): The LLM result.
            response (LLMResult | AIMessage): The LLM result.
            **kwargs (Any): Additional keyword arguments.
        """
        if not self.handlers:
@@ -729,8 +754,8 @@ class CallbackManagerForLLMRun(RunManager, LLMManagerMixin):
        Args:
            error (Exception or KeyboardInterrupt): The error.
            kwargs (Any): Additional keyword arguments.
                - response (LLMResult): The response which was generated before
                    the error occurred.
                - response (LLMResult | AIMessage): The response which was generated
                    before the error occurred.
        """
        if not self.handlers:
            return
@@ -768,9 +793,11 @@ class AsyncCallbackManagerForLLMRun(AsyncRunManager, LLMManagerMixin):

    async def on_llm_new_token(
        self,
        token: str,
        token: Union[str, AIMessageChunk],
        *,
        chunk: Optional[Union[GenerationChunk, ChatGenerationChunk]] = None,
        chunk: Optional[
            Union[GenerationChunk, ChatGenerationChunk, AIMessageChunk]
        ] = None,
        **kwargs: Any,
    ) -> None:
        """Run when LLM generates a new token.
@@ -796,11 +823,13 @@ class AsyncCallbackManagerForLLMRun(AsyncRunManager, LLMManagerMixin):
        )

    @shielded
    async def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
    async def on_llm_end(
        self, response: Union[LLMResult, AIMessage], **kwargs: Any
    ) -> None:
        """Run when LLM ends running.

        Args:
            response (LLMResult): The LLM result.
            response (LLMResult | AIMessage): The LLM result.
            **kwargs (Any): Additional keyword arguments.
        """
        if not self.handlers:
@@ -827,11 +856,8 @@ class AsyncCallbackManagerForLLMRun(AsyncRunManager, LLMManagerMixin):
        Args:
            error (Exception or KeyboardInterrupt): The error.
            kwargs (Any): Additional keyword arguments.
                - response (LLMResult): The response which was generated before
                    the error occurred.
                - response (LLMResult | AIMessage): The response which was generated
                    before the error occurred.
        """
        if not self.handlers:
            return
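For handlers that do not set accepts_new_messages, _convert_llm_events rewrites the event arguments in place before dispatch: a v1 chunk passed to on_llm_new_token is downgraded to its text plus a legacy chunk object, and a v1 message passed to on_llm_end is wrapped in an LLMResult. A standalone sketch of the on_llm_new_token downgrade, using a stand-in dataclass rather than the real v1 AIMessageChunk so it runs on its own (the real code wraps the rich object in a ChatGenerationChunk):

from dataclasses import dataclass, is_dataclass
from typing import Any


@dataclass
class FakeChunkV1:
    """Stand-in for the v1 AIMessageChunk; only .text is used here."""

    text: str


def downgrade_new_token_event(args: list[Any], kwargs: dict[str, Any]) -> None:
    """Simplified mirror of the on_llm_new_token branch above."""
    if is_dataclass(args[0]):
        # Legacy handlers expect the token as a plain string; keep the rich
        # object available through the chunk keyword argument.
        kwargs["chunk"] = args[0]
        args[0] = args[0].text


args: list[Any] = [FakeChunkV1(text="Hello")]
kwargs: dict[str, Any] = {}
downgrade_new_token_event(args, kwargs)
assert args[0] == "Hello" and isinstance(kwargs["chunk"], FakeChunkV1)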
@@ -1,3 +1,4 @@
import copy
import re
from collections.abc import Sequence
from typing import Optional
@@ -128,7 +129,10 @@ def _normalize_messages(messages: Sequence[BaseMessage]) -> list[BaseMessage]:
                and _is_openai_data_block(block)
            ):
                if formatted_message is message:
                    formatted_message = message.model_copy()
                    if isinstance(message, BaseMessage):
                        formatted_message = message.model_copy()
                    else:
                        formatted_message = copy.copy(message)
                    # Also shallow-copy content
                    formatted_message.content = list(formatted_message.content)
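The _normalize_messages change applies copy-on-write to whichever message type comes in: pydantic BaseMessages are duplicated with model_copy(), the new dataclass-based v1 messages with copy.copy(), and in both cases the content list is shallow-copied before any block is rewritten. The pattern in isolation, as a generic sketch not tied to the real message classes:

import copy
from dataclasses import dataclass, field


@dataclass
class Msg:
    """Stand-in for a message holding a list of content blocks."""

    content: list = field(default_factory=list)


def rewrite_first_block(message: Msg, new_block: dict) -> Msg:
    # Shallow-copy the message and its content list so the caller's
    # objects are never mutated, mirroring the copy-on-write above.
    updated = copy.copy(message)
    updated.content = list(updated.content)
    updated.content[0] = new_block
    return updated


original = Msg(content=[{"type": "image"}])
changed = rewrite_first_block(original, {"type": "image_url"})
assert original.content[0]["type"] == "image"  # original left untouched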
@@ -28,6 +28,7 @@ from langchain_core.messages import (
    MessageLikeRepresentation,
    get_buffer_string,
)
from langchain_core.messages.v1 import AIMessage as AIMessageV1
from langchain_core.prompt_values import PromptValue
from langchain_core.runnables import Runnable, RunnableSerializable
from langchain_core.utils import get_pydantic_field_names
@@ -85,7 +86,9 @@ def _get_token_ids_default_method(text: str) -> list[int]:
LanguageModelInput = Union[PromptValue, str, Sequence[MessageLikeRepresentation]]
LanguageModelOutput = Union[BaseMessage, str]
LanguageModelLike = Runnable[LanguageModelInput, LanguageModelOutput]
LanguageModelOutputVar = TypeVar("LanguageModelOutputVar", BaseMessage, str)
LanguageModelOutputVar = TypeVar(
    "LanguageModelOutputVar", BaseMessage, str, AIMessageV1
)


def _get_verbosity() -> bool:
libs/core/langchain_core/language_models/v1/__init__.py (new file, 1 line)
@@ -0,0 +1 @@
"""LangChain v1.0 chat models."""

libs/core/langchain_core/language_models/v1/chat_models.py (new file, 909 lines)
@@ -0,0 +1,909 @@
|
||||
"""Chat models for conversational AI."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import copy
|
||||
import typing
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import AsyncIterator, Iterator, Sequence
|
||||
from operator import itemgetter
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
Callable,
|
||||
Literal,
|
||||
Optional,
|
||||
Union,
|
||||
cast,
|
||||
)
|
||||
|
||||
from pydantic import (
|
||||
BaseModel,
|
||||
ConfigDict,
|
||||
Field,
|
||||
)
|
||||
from typing_extensions import override
|
||||
|
||||
from langchain_core.callbacks import (
|
||||
AsyncCallbackManager,
|
||||
AsyncCallbackManagerForLLMRun,
|
||||
CallbackManager,
|
||||
CallbackManagerForLLMRun,
|
||||
)
|
||||
from langchain_core.language_models._utils import _normalize_messages
|
||||
from langchain_core.language_models.base import (
|
||||
BaseLanguageModel,
|
||||
LangSmithParams,
|
||||
LanguageModelInput,
|
||||
)
|
||||
from langchain_core.messages import (
|
||||
AIMessage,
|
||||
BaseMessage,
|
||||
convert_to_openai_image_block,
|
||||
is_data_content_block,
|
||||
)
|
||||
from langchain_core.messages.utils import convert_to_messages_v1
|
||||
from langchain_core.messages.v1 import AIMessage as AIMessageV1
|
||||
from langchain_core.messages.v1 import AIMessageChunk as AIMessageChunkV1
|
||||
from langchain_core.messages.v1 import HumanMessage as HumanMessageV1
|
||||
from langchain_core.messages.v1 import MessageV1, add_ai_message_chunks
|
||||
from langchain_core.outputs import (
|
||||
ChatGeneration,
|
||||
ChatGenerationChunk,
|
||||
)
|
||||
from langchain_core.prompt_values import PromptValue
|
||||
from langchain_core.rate_limiters import BaseRateLimiter
|
||||
from langchain_core.runnables import RunnableMap, RunnablePassthrough
|
||||
from langchain_core.runnables.config import ensure_config, run_in_executor
|
||||
from langchain_core.tracers._streaming import _StreamingCallbackHandler
|
||||
from langchain_core.utils.function_calling import (
|
||||
convert_to_json_schema,
|
||||
convert_to_openai_tool,
|
||||
)
|
||||
from langchain_core.utils.pydantic import TypeBaseModel, is_basemodel_subclass
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from langchain_core.output_parsers.base import OutputParserLike
|
||||
from langchain_core.runnables import Runnable, RunnableConfig
|
||||
from langchain_core.tools import BaseTool
|
||||
|
||||
|
||||
def _generate_response_from_error(error: BaseException) -> list[AIMessageV1]:
|
||||
if hasattr(error, "response"):
|
||||
response = error.response
|
||||
metadata: dict = {}
|
||||
if hasattr(response, "headers"):
|
||||
try:
|
||||
metadata["headers"] = dict(response.headers)
|
||||
except Exception:
|
||||
metadata["headers"] = None
|
||||
if hasattr(response, "status_code"):
|
||||
metadata["status_code"] = response.status_code
|
||||
if hasattr(error, "request_id"):
|
||||
metadata["request_id"] = error.request_id
|
||||
generations = [AIMessageV1(content=[], response_metadata=metadata)]
|
||||
else:
|
||||
generations = []
|
||||
|
||||
return generations
|
||||
|
||||
|
||||
def _format_for_tracing(messages: list[MessageV1]) -> list[MessageV1]:
|
||||
"""Format messages for tracing in on_chat_model_start.
|
||||
|
||||
- Update image content blocks to OpenAI Chat Completions format (backward
|
||||
compatibility).
|
||||
- Add "type" key to content blocks that have a single key.
|
||||
|
||||
Args:
|
||||
messages: List of messages to format.
|
||||
|
||||
Returns:
|
||||
List of messages formatted for tracing.
|
||||
"""
|
||||
messages_to_trace = []
|
||||
for message in messages:
|
||||
message_to_trace = message
|
||||
for idx, block in enumerate(message.content):
|
||||
# Update image content blocks to OpenAI Chat Completions format.
|
||||
if (
|
||||
block["type"] == "image"
|
||||
and is_data_content_block(block)
|
||||
and block.get("source_type") != "id"
|
||||
):
|
||||
if message_to_trace is message:
|
||||
# Shallow copy
|
||||
message_to_trace = copy.copy(message)
|
||||
message_to_trace.content = list(message_to_trace.content)
|
||||
|
||||
message_to_trace.content[idx] = convert_to_openai_image_block(block)
|
||||
else:
|
||||
pass
|
||||
messages_to_trace.append(message_to_trace)
|
||||
|
||||
return messages_to_trace
|
||||
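_format_for_tracing leaves messages untouched unless an image data block needs rewriting, in which case it shallow-copies the message and swaps in the OpenAI Chat Completions form via convert_to_openai_image_block. A hedged sketch of that block-level conversion; the input block follows LangChain's standard data-block shape, and the exact output keys come from convert_to_openai_image_block and may differ by version:

from langchain_core.messages import convert_to_openai_image_block

# A base64 image block in LangChain's standard data-block format.
block = {
    "type": "image",
    "source_type": "base64",
    "mime_type": "image/png",
    "data": "iVBORw0KGgo...",  # truncated base64 payload, illustrative only
}

# Rewritten for tracing as an OpenAI-style block, roughly:
# {"type": "image_url", "image_url": {"url": "data:image/png;base64,iVBORw0KGgo..."}}
openai_block = convert_to_openai_image_block(block)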
|
||||
|
||||
def generate_from_stream(stream: Iterator[AIMessageChunkV1]) -> AIMessageV1:
|
||||
"""Generate from a stream.
|
||||
|
||||
Args:
|
||||
stream: Iterator of AIMessageChunkV1.
|
||||
|
||||
Returns:
|
||||
AIMessageV1: aggregated message.
|
||||
"""
|
||||
generation = next(stream, None)
|
||||
if generation:
|
||||
generation += list(stream)
|
||||
if generation is None:
|
||||
msg = "No generations found in stream."
|
||||
raise ValueError(msg)
|
||||
return generation.to_message()
|
||||
|
||||
|
||||
async def agenerate_from_stream(
|
||||
stream: AsyncIterator[AIMessageChunkV1],
|
||||
) -> AIMessageV1:
|
||||
"""Async generate from a stream.
|
||||
|
||||
Args:
|
||||
stream: Iterator of AIMessageChunkV1.
|
||||
|
||||
Returns:
|
||||
AIMessageV1: aggregated message.
|
||||
"""
|
||||
chunks = [chunk async for chunk in stream]
|
||||
return await run_in_executor(None, generate_from_stream, iter(chunks))
|
||||
|
||||
|
||||
def _format_ls_structured_output(ls_structured_output_format: Optional[dict]) -> dict:
|
||||
if ls_structured_output_format:
|
||||
try:
|
||||
ls_structured_output_format_dict = {
|
||||
"ls_structured_output_format": {
|
||||
"kwargs": ls_structured_output_format.get("kwargs", {}),
|
||||
"schema": convert_to_json_schema(
|
||||
ls_structured_output_format["schema"]
|
||||
),
|
||||
}
|
||||
}
|
||||
except ValueError:
|
||||
ls_structured_output_format_dict = {}
|
||||
else:
|
||||
ls_structured_output_format_dict = {}
|
||||
|
||||
return ls_structured_output_format_dict
|
||||
|
||||
|
||||
class BaseChatModelV1(BaseLanguageModel[AIMessageV1], ABC):
|
||||
"""Base class for chat models.
|
||||
|
||||
Key imperative methods:
|
||||
Methods that actually call the underlying model.
|
||||
|
||||
+---------------------------+----------------------------------------------------------------+---------------------------------------------------------------------+--------------------------------------------------------------------------------------------------+
|
||||
| Method | Input | Output | Description |
|
||||
+===========================+================================================================+=====================================================================+==================================================================================================+
|
||||
| `invoke` | str | list[dict | tuple | BaseMessage] | PromptValue | BaseMessage | A single chat model call. |
|
||||
+---------------------------+----------------------------------------------------------------+---------------------------------------------------------------------+--------------------------------------------------------------------------------------------------+
|
||||
| `ainvoke` | ''' | BaseMessage | Defaults to running invoke in an async executor. |
|
||||
+---------------------------+----------------------------------------------------------------+---------------------------------------------------------------------+--------------------------------------------------------------------------------------------------+
|
||||
| `stream` | ''' | Iterator[BaseMessageChunk] | Defaults to yielding output of invoke. |
|
||||
+---------------------------+----------------------------------------------------------------+---------------------------------------------------------------------+--------------------------------------------------------------------------------------------------+
|
||||
| `astream` | ''' | AsyncIterator[BaseMessageChunk] | Defaults to yielding output of ainvoke. |
|
||||
+---------------------------+----------------------------------------------------------------+---------------------------------------------------------------------+--------------------------------------------------------------------------------------------------+
|
||||
| `astream_events` | ''' | AsyncIterator[StreamEvent] | Event types: 'on_chat_model_start', 'on_chat_model_stream', 'on_chat_model_end'. |
|
||||
+---------------------------+----------------------------------------------------------------+---------------------------------------------------------------------+--------------------------------------------------------------------------------------------------+
|
||||
| `batch` | list['''] | list[BaseMessage] | Defaults to running invoke in concurrent threads. |
|
||||
+---------------------------+----------------------------------------------------------------+---------------------------------------------------------------------+--------------------------------------------------------------------------------------------------+
|
||||
| `abatch` | list['''] | list[BaseMessage] | Defaults to running ainvoke in concurrent threads. |
|
||||
+---------------------------+----------------------------------------------------------------+---------------------------------------------------------------------+--------------------------------------------------------------------------------------------------+
|
||||
| `batch_as_completed` | list['''] | Iterator[tuple[int, Union[BaseMessage, Exception]]] | Defaults to running invoke in concurrent threads. |
|
||||
+---------------------------+----------------------------------------------------------------+---------------------------------------------------------------------+--------------------------------------------------------------------------------------------------+
|
||||
| `abatch_as_completed` | list['''] | AsyncIterator[tuple[int, Union[BaseMessage, Exception]]] | Defaults to running ainvoke in concurrent threads. |
|
||||
+---------------------------+----------------------------------------------------------------+---------------------------------------------------------------------+--------------------------------------------------------------------------------------------------+
|
||||
|
||||
This table provides a brief overview of the main imperative methods. Please see the base Runnable reference for full documentation.
|
||||
|
||||
Key declarative methods:
|
||||
Methods for creating another Runnable using the ChatModel.
|
||||
|
||||
+----------------------------------+-----------------------------------------------------------------------------------------------------------+
|
||||
| Method | Description |
|
||||
+==================================+===========================================================================================================+
|
||||
| `bind_tools` | Create ChatModel that can call tools. |
|
||||
+----------------------------------+-----------------------------------------------------------------------------------------------------------+
|
||||
| `with_structured_output` | Create wrapper that structures model output using schema. |
|
||||
+----------------------------------+-----------------------------------------------------------------------------------------------------------+
|
||||
| `with_retry` | Create wrapper that retries model calls on failure. |
|
||||
+----------------------------------+-----------------------------------------------------------------------------------------------------------+
|
||||
| `with_fallbacks` | Create wrapper that falls back to other models on failure. |
|
||||
+----------------------------------+-----------------------------------------------------------------------------------------------------------+
|
||||
| `configurable_fields` | Specify init args of the model that can be configured at runtime via the RunnableConfig. |
|
||||
+----------------------------------+-----------------------------------------------------------------------------------------------------------+
|
||||
| `configurable_alternatives` | Specify alternative models which can be swapped in at runtime via the RunnableConfig. |
|
||||
+----------------------------------+-----------------------------------------------------------------------------------------------------------+
|
||||
|
||||
This table provides a brief overview of the main declarative methods. Please see the reference for each method for full documentation.
|
||||
|
||||
Creating custom chat model:
|
||||
Custom chat model implementations should inherit from this class.
|
||||
Please reference the table below for information about which
|
||||
methods and properties are required or optional for implementations.
|
||||
|
||||
+----------------------------------+--------------------------------------------------------------------+-------------------+
|
||||
| Method/Property | Description | Required/Optional |
|
||||
+==================================+====================================================================+===================+
|
||||
| `_generate` | Use to generate a chat result from a prompt | Required |
|
||||
+----------------------------------+--------------------------------------------------------------------+-------------------+
|
||||
| `_llm_type` (property) | Used to uniquely identify the type of the model. Used for logging. | Required |
|
||||
+----------------------------------+--------------------------------------------------------------------+-------------------+
|
||||
| `_identifying_params` (property) | Represent model parameterization for tracing purposes. | Optional |
|
||||
+----------------------------------+--------------------------------------------------------------------+-------------------+
|
||||
| `_stream` | Use to implement streaming | Optional |
|
||||
+----------------------------------+--------------------------------------------------------------------+-------------------+
|
||||
| `_agenerate` | Use to implement a native async method | Optional |
|
||||
+----------------------------------+--------------------------------------------------------------------+-------------------+
|
||||
| `_astream` | Use to implement async version of `_stream` | Optional |
|
||||
+----------------------------------+--------------------------------------------------------------------+-------------------+
|
||||
|
||||
Follow the guide for more information on how to implement a custom Chat Model:
|
||||
[Guide](https://python.langchain.com/docs/how_to/custom_chat_model/).
|
||||
|
||||
""" # noqa: E501
|
||||
|
||||
rate_limiter: Optional[BaseRateLimiter] = Field(default=None, exclude=True)
|
||||
"An optional rate limiter to use for limiting the number of requests."
|
||||
|
||||
disable_streaming: Union[bool, Literal["tool_calling"]] = False
|
||||
"""Whether to disable streaming for this model.
|
||||
|
||||
If streaming is bypassed, then ``stream()``/``astream()``/``astream_events()`` will
|
||||
defer to ``invoke()``/``ainvoke()``.
|
||||
|
||||
- If True, will always bypass streaming case.
|
||||
- If ``'tool_calling'``, will bypass streaming case only when the model is called
|
||||
with a ``tools`` keyword argument. In other words, LangChain will automatically
|
||||
switch to non-streaming behavior (``invoke()``) only when the tools argument is
|
||||
provided. This offers the best of both worlds.
|
||||
- If False (default), will always use streaming case if available.
|
||||
|
||||
The main reason for this flag is that code might be written using ``.stream()`` and
|
||||
a user may want to swap out a given model for another model whose implementation
|
||||
does not properly support streaming.
|
||||
"""
|
||||
|
||||
model_config = ConfigDict(
|
||||
arbitrary_types_allowed=True,
|
||||
)
|
||||
|
||||
# --- Runnable methods ---
|
||||
|
||||
@property
|
||||
@override
|
||||
def OutputType(self) -> Any:
|
||||
"""Get the output type for this runnable."""
|
||||
return AIMessageV1
|
||||
|
||||
def _convert_input(self, model_input: LanguageModelInput) -> list[MessageV1]:
|
||||
if isinstance(model_input, PromptValue):
|
||||
return model_input.to_messages(output_version="v1")
|
||||
if isinstance(model_input, str):
|
||||
return [HumanMessageV1(content=model_input)]
|
||||
if isinstance(model_input, Sequence):
|
||||
return convert_to_messages_v1(model_input)
|
||||
msg = (
|
||||
f"Invalid input type {type(model_input)}. "
|
||||
"Must be a PromptValue, str, or list of BaseMessages."
|
||||
)
|
||||
raise ValueError(msg)
|
||||
|
||||
def _should_stream(
|
||||
self,
|
||||
*,
|
||||
async_api: bool,
|
||||
run_manager: Optional[
|
||||
Union[CallbackManagerForLLMRun, AsyncCallbackManagerForLLMRun]
|
||||
] = None,
|
||||
**kwargs: Any,
|
||||
) -> bool:
|
||||
"""Determine if a given model call should hit the streaming API."""
|
||||
sync_not_implemented = type(self)._stream == BaseChatModelV1._stream # noqa: SLF001
|
||||
async_not_implemented = type(self)._astream == BaseChatModelV1._astream # noqa: SLF001
|
||||
|
||||
# Check if streaming is implemented.
|
||||
if (not async_api) and sync_not_implemented:
|
||||
return False
|
||||
# Note, since async falls back to sync we check both here.
|
||||
if async_api and async_not_implemented and sync_not_implemented:
|
||||
return False
|
||||
|
||||
# Check if streaming has been disabled on this instance.
|
||||
if self.disable_streaming is True:
|
||||
return False
|
||||
# We assume tools are passed in via "tools" kwarg in all models.
|
||||
if self.disable_streaming == "tool_calling" and kwargs.get("tools"):
|
||||
return False
|
||||
|
||||
# Check if a runtime streaming flag has been passed in.
|
||||
if "stream" in kwargs:
|
||||
return kwargs["stream"]
|
||||
|
||||
# Check if any streaming callback handlers have been passed in.
|
||||
handlers = run_manager.handlers if run_manager else []
|
||||
return any(isinstance(h, _StreamingCallbackHandler) for h in handlers)
|
||||
|
||||
@override
|
||||
def invoke(
|
||||
self,
|
||||
input: LanguageModelInput,
|
||||
config: Optional[RunnableConfig] = None,
|
||||
**kwargs: Any,
|
||||
) -> AIMessageV1:
|
||||
config = ensure_config(config)
|
||||
messages = self._convert_input(input)
|
||||
ls_structured_output_format = kwargs.pop(
|
||||
"ls_structured_output_format", None
|
||||
) or kwargs.pop("structured_output_format", None)
|
||||
ls_structured_output_format_dict = _format_ls_structured_output(
|
||||
ls_structured_output_format
|
||||
)
|
||||
|
||||
params = self._get_invocation_params(**kwargs)
|
||||
options = {**kwargs, **ls_structured_output_format_dict}
|
||||
inheritable_metadata = {
|
||||
**(config.get("metadata") or {}),
|
||||
**self._get_ls_params(**kwargs),
|
||||
}
|
||||
callback_manager = CallbackManager.configure(
|
||||
config.get("callbacks"),
|
||||
self.callbacks,
|
||||
self.verbose,
|
||||
config.get("tags"),
|
||||
self.tags,
|
||||
inheritable_metadata,
|
||||
self.metadata,
|
||||
)
|
||||
(run_manager,) = callback_manager.on_chat_model_start(
|
||||
{},
|
||||
[_format_for_tracing(messages)],
|
||||
invocation_params=params,
|
||||
options=options,
|
||||
name=config.get("run_name"),
|
||||
run_id=config.pop("run_id", None),
|
||||
batch_size=1,
|
||||
)
|
||||
|
||||
if self.rate_limiter:
|
||||
self.rate_limiter.acquire(blocking=True)
|
||||
|
||||
input_messages = _normalize_messages(messages)
|
||||
|
||||
if self._should_stream(async_api=False, **kwargs):
|
||||
chunks: list[AIMessageChunkV1] = []
|
||||
try:
|
||||
for msg in self._stream(input_messages, **kwargs):
|
||||
run_manager.on_llm_new_token(msg)
|
||||
chunks.append(msg)
|
||||
except BaseException as e:
|
||||
run_manager.on_llm_error(e, response=_generate_response_from_error(e))
|
||||
raise
|
||||
msg = add_ai_message_chunks(chunks[0], *chunks[1:])
|
||||
else:
|
||||
try:
|
||||
msg = self._invoke(input_messages, **kwargs)
|
||||
except BaseException as e:
|
||||
run_manager.on_llm_error(e)
|
||||
raise
|
||||
|
||||
run_manager.on_llm_end(msg)
|
||||
return msg
|
||||
|
||||
@override
|
||||
async def ainvoke(
|
||||
self,
|
||||
input: LanguageModelInput,
|
||||
config: Optional[RunnableConfig] = None,
|
||||
**kwargs: Any,
|
||||
) -> AIMessage:
|
||||
config = ensure_config(config)
|
||||
messages = self._convert_input(input)
|
||||
ls_structured_output_format = kwargs.pop(
|
||||
"ls_structured_output_format", None
|
||||
) or kwargs.pop("structured_output_format", None)
|
||||
ls_structured_output_format_dict = _format_ls_structured_output(
|
||||
ls_structured_output_format
|
||||
)
|
||||
|
||||
params = self._get_invocation_params(**kwargs)
|
||||
options = {**kwargs, **ls_structured_output_format_dict}
|
||||
inheritable_metadata = {
|
||||
**(config.get("metadata") or {}),
|
||||
**self._get_ls_params(**kwargs),
|
||||
}
|
||||
callback_manager = AsyncCallbackManager.configure(
|
||||
config.get("callbacks"),
|
||||
self.callbacks,
|
||||
self.verbose,
|
||||
config.get("tags"),
|
||||
self.tags,
|
||||
inheritable_metadata,
|
||||
self.metadata,
|
||||
)
|
||||
(run_manager,) = await callback_manager.on_chat_model_start(
|
||||
{},
|
||||
[_format_for_tracing(messages)],
|
||||
invocation_params=params,
|
||||
options=options,
|
||||
name=config.get("run_name"),
|
||||
run_id=config.pop("run_id", None),
|
||||
batch_size=1,
|
||||
)
|
||||
|
||||
if self.rate_limiter:
|
||||
await self.rate_limiter.aacquire(blocking=True)
|
||||
|
||||
# TODO: type openai image, audio, file types and permit in MessageV1
|
||||
input_messages = _normalize_messages(messages) # type: ignore[arg-type]
|
||||
|
||||
if self._should_stream(async_api=True, **kwargs):
|
||||
chunks: list[AIMessageChunkV1] = []
|
||||
try:
|
||||
async for msg in self._astream(input_messages, **kwargs):
|
||||
await run_manager.on_llm_new_token(msg)
|
||||
chunks.append(msg)
|
||||
except BaseException as e:
|
||||
await run_manager.on_llm_error(
|
||||
e, response=_generate_response_from_error(e)
|
||||
)
|
||||
raise
|
||||
msg = add_ai_message_chunks(chunks[0], *chunks[1:])
|
||||
else:
|
||||
try:
|
||||
msg = await self._ainvoke(input_messages, **kwargs)
|
||||
except BaseException as e:
|
||||
await run_manager.on_llm_error(
|
||||
e, response=_generate_response_from_error(e)
|
||||
)
|
||||
raise
|
||||
|
||||
await run_manager.on_llm_end(msg.to_message())
|
||||
return msg
|
||||
|
||||
@override
|
||||
def stream(
|
||||
self,
|
||||
input: LanguageModelInput,
|
||||
config: Optional[RunnableConfig] = None,
|
||||
**kwargs: Any,
|
||||
) -> Iterator[AIMessageChunkV1]:
|
||||
if not self._should_stream(async_api=False, **{**kwargs, "stream": True}):
|
||||
# model doesn't implement streaming, so use default implementation
|
||||
yield cast(
|
||||
"AIMessageChunkV1",
|
||||
self.invoke(input, config=config, **kwargs),
|
||||
)
|
||||
else:
|
||||
config = ensure_config(config)
|
||||
messages = self._convert_input(input)
|
||||
ls_structured_output_format = kwargs.pop(
|
||||
"ls_structured_output_format", None
|
||||
) or kwargs.pop("structured_output_format", None)
|
||||
ls_structured_output_format_dict = _format_ls_structured_output(
|
||||
ls_structured_output_format
|
||||
)
|
||||
|
||||
params = self._get_invocation_params(**kwargs)
|
||||
options = {**kwargs, **ls_structured_output_format_dict}
|
||||
inheritable_metadata = {
|
||||
**(config.get("metadata") or {}),
|
||||
**self._get_ls_params(**kwargs),
|
||||
}
|
||||
callback_manager = CallbackManager.configure(
|
||||
config.get("callbacks"),
|
||||
self.callbacks,
|
||||
self.verbose,
|
||||
config.get("tags"),
|
||||
self.tags,
|
||||
inheritable_metadata,
|
||||
self.metadata,
|
||||
)
|
||||
(run_manager,) = callback_manager.on_chat_model_start(
|
||||
{},
|
||||
[_format_for_tracing(messages)],
|
||||
invocation_params=params,
|
||||
options=options,
|
||||
name=config.get("run_name"),
|
||||
run_id=config.pop("run_id", None),
|
||||
batch_size=1,
|
||||
)
|
||||
|
||||
chunks: list[AIMessageChunkV1] = []
|
||||
|
||||
if self.rate_limiter:
|
||||
self.rate_limiter.acquire(blocking=True)
|
||||
|
||||
try:
|
||||
# TODO: replace this with something for new messages
|
||||
input_messages = _normalize_messages(messages)
|
||||
for msg in self._stream(input_messages, **kwargs):
|
||||
run_manager.on_llm_new_token(msg)
|
||||
chunks.append(msg)
|
||||
yield msg
|
||||
except BaseException as e:
|
||||
run_manager.on_llm_error(e, response=_generate_response_from_error(e))
|
||||
raise
|
||||
|
||||
msg = add_ai_message_chunks(chunks[0], *chunks[1:])
|
||||
run_manager.on_llm_end(msg)
|
||||
|
||||
@override
|
||||
async def astream(
|
||||
self,
|
||||
input: LanguageModelInput,
|
||||
config: Optional[RunnableConfig] = None,
|
||||
**kwargs: Any,
|
||||
) -> AsyncIterator[AIMessageChunkV1]:
|
||||
if not self._should_stream(async_api=True, **{**kwargs, "stream": True}):
|
||||
# No async or sync stream is implemented, so fall back to ainvoke
|
||||
yield cast(
|
||||
"AIMessageChunkV1",
|
||||
await self.ainvoke(input, config=config, **kwargs),
|
||||
)
|
||||
return
|
||||
|
||||
config = ensure_config(config)
|
||||
messages = self._convert_input(input)
|
||||
|
||||
ls_structured_output_format = kwargs.pop(
|
||||
"ls_structured_output_format", None
|
||||
) or kwargs.pop("structured_output_format", None)
|
||||
ls_structured_output_format_dict = _format_ls_structured_output(
|
||||
ls_structured_output_format
|
||||
)
|
||||
|
||||
params = self._get_invocation_params(**kwargs)
|
||||
options = {**kwargs, **ls_structured_output_format_dict}
|
||||
inheritable_metadata = {
|
||||
**(config.get("metadata") or {}),
|
||||
**self._get_ls_params(**kwargs),
|
||||
}
|
||||
callback_manager = AsyncCallbackManager.configure(
|
||||
config.get("callbacks"),
|
||||
self.callbacks,
|
||||
self.verbose,
|
||||
config.get("tags"),
|
||||
self.tags,
|
||||
inheritable_metadata,
|
||||
self.metadata,
|
||||
)
|
||||
(run_manager,) = await callback_manager.on_chat_model_start(
|
||||
{},
|
||||
[_format_for_tracing(messages)],
|
||||
invocation_params=params,
|
||||
options=options,
|
||||
name=config.get("run_name"),
|
||||
run_id=config.pop("run_id", None),
|
||||
batch_size=1,
|
||||
)
|
||||
|
||||
if self.rate_limiter:
|
||||
await self.rate_limiter.aacquire(blocking=True)
|
||||
|
||||
chunks: list[AIMessageChunkV1] = []
|
||||
|
||||
try:
|
||||
input_messages = _normalize_messages(messages)
|
||||
async for msg in self._astream(
|
||||
input_messages,
|
||||
**kwargs,
|
||||
):
|
||||
await run_manager.on_llm_new_token(msg)
|
||||
chunks.append(msg)
|
||||
yield msg
|
||||
except BaseException as e:
|
||||
await run_manager.on_llm_error(e, response=_generate_response_from_error(e))
|
||||
raise
|
||||
|
||||
msg = add_ai_message_chunks(chunks[0], *chunks[1:])
|
||||
await run_manager.on_llm_end(msg)
|
||||
|
||||
# --- Custom methods ---
|
||||
|
||||
def _combine_llm_outputs(self, llm_outputs: list[Optional[dict]]) -> dict: # noqa: ARG002
|
||||
return {}
|
||||
|
||||
def _get_invocation_params(
|
||||
self,
|
||||
stop: Optional[list[str]] = None,
|
||||
**kwargs: Any,
|
||||
) -> dict:
|
||||
params = self.dict()
|
||||
params["stop"] = stop
|
||||
return {**params, **kwargs}
|
||||
|
||||
def _get_ls_params(
|
||||
self,
|
||||
stop: Optional[list[str]] = None,
|
||||
**kwargs: Any,
|
||||
) -> LangSmithParams:
|
||||
"""Get standard params for tracing."""
|
||||
# get default provider from class name
|
||||
default_provider = self.__class__.__name__
|
||||
if default_provider.startswith("Chat"):
|
||||
default_provider = default_provider[4:].lower()
|
||||
elif default_provider.endswith("Chat"):
|
||||
default_provider = default_provider[:-4]
|
||||
default_provider = default_provider.lower()
|
||||
|
||||
ls_params = LangSmithParams(ls_provider=default_provider, ls_model_type="chat")
|
||||
if stop:
|
||||
ls_params["ls_stop"] = stop
|
||||
|
||||
# model
|
||||
if hasattr(self, "model") and isinstance(self.model, str):
|
||||
ls_params["ls_model_name"] = self.model
|
||||
elif hasattr(self, "model_name") and isinstance(self.model_name, str):
|
||||
ls_params["ls_model_name"] = self.model_name
|
||||
|
||||
# temperature
|
||||
if "temperature" in kwargs and isinstance(kwargs["temperature"], float):
|
||||
ls_params["ls_temperature"] = kwargs["temperature"]
|
||||
elif hasattr(self, "temperature") and isinstance(self.temperature, float):
|
||||
ls_params["ls_temperature"] = self.temperature
|
||||
|
||||
# max_tokens
|
||||
if "max_tokens" in kwargs and isinstance(kwargs["max_tokens"], int):
|
||||
ls_params["ls_max_tokens"] = kwargs["max_tokens"]
|
||||
elif hasattr(self, "max_tokens") and isinstance(self.max_tokens, int):
|
||||
ls_params["ls_max_tokens"] = self.max_tokens
|
||||
|
||||
return ls_params
|
||||
|
||||
def _get_llm_string(self, stop: Optional[list[str]] = None, **kwargs: Any) -> str:
|
||||
params = self._get_invocation_params(stop=stop, **kwargs)
|
||||
params = {**params, **kwargs}
|
||||
return str(sorted(params.items()))
|
||||
|
||||
def _invoke(
|
||||
self,
|
||||
messages: list[MessageV1],
|
||||
**kwargs: Any,
|
||||
) -> AIMessage:
|
||||
raise NotImplementedError
|
||||
|
||||
async def _ainvoke(
|
||||
self,
|
||||
messages: list[MessageV1],
|
||||
**kwargs: Any,
|
||||
) -> AIMessage:
|
||||
return await run_in_executor(
|
||||
None,
|
||||
self._invoke,
|
||||
messages,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def _stream(
|
||||
self,
|
||||
messages: list[MessageV1],
|
||||
**kwargs: Any,
|
||||
) -> Iterator[AIMessageChunkV1]:
|
||||
raise NotImplementedError
|
||||
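_invoke and _stream are the two hooks a concrete subclass is expected to supply; the async variants below fall back to them via run_in_executor. A minimal sketch of a subclass wired against this new base class; the class name, the .text attribute, and the v1 constructors used here are assumptions for illustration, not part of the committed API surface:

from collections.abc import Iterator

from langchain_core.language_models.v1.chat_models import BaseChatModelV1
from langchain_core.messages.v1 import AIMessage as AIMessageV1
from langchain_core.messages.v1 import AIMessageChunk as AIMessageChunkV1
from langchain_core.messages.v1 import MessageV1


class EchoChatModel(BaseChatModelV1):
    """Toy model that echoes the last input message back."""

    @property
    def _llm_type(self) -> str:
        return "echo-chat"

    def _invoke(self, messages: list[MessageV1], **kwargs) -> AIMessageV1:
        # .text is assumed to expose the message's plain-text content.
        return AIMessageV1(content=messages[-1].text)

    def _stream(self, messages: list[MessageV1], **kwargs) -> Iterator[AIMessageChunkV1]:
        # Stream the echo word by word.
        for word in messages[-1].text.split():
            yield AIMessageChunkV1(content=word + " ")


# EchoChatModel().invoke("hello world")  # -> AIMessageV1 echoing "hello world"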
|
||||
async def _astream(
|
||||
self,
|
||||
messages: list[MessageV1],
|
||||
**kwargs: Any,
|
||||
) -> AsyncIterator[AIMessageChunkV1]:
|
||||
iterator = await run_in_executor(
|
||||
None,
|
||||
self._stream,
|
||||
messages,
|
||||
**kwargs,
|
||||
)
|
||||
done = object()
|
||||
while True:
|
||||
item = await run_in_executor(
|
||||
None,
|
||||
next,
|
||||
iterator,
|
||||
done,
|
||||
)
|
||||
if item is done:
|
||||
break
|
||||
yield item # type: ignore[misc]
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def _llm_type(self) -> str:
|
||||
"""Return type of chat model."""
|
||||
|
||||
@override
|
||||
def dict(self, **kwargs: Any) -> dict:
|
||||
"""Return a dictionary of the LLM."""
|
||||
starter_dict = dict(self._identifying_params)
|
||||
starter_dict["_type"] = self._llm_type
|
||||
return starter_dict
|
||||
|
||||
def bind_tools(
|
||||
self,
|
||||
tools: Sequence[
|
||||
Union[typing.Dict[str, Any], type, Callable, BaseTool] # noqa: UP006
|
||||
],
|
||||
*,
|
||||
tool_choice: Optional[Union[str]] = None,
|
||||
**kwargs: Any,
|
||||
) -> Runnable[LanguageModelInput, BaseMessage]:
|
||||
"""Bind tools to the model.
|
||||
|
||||
Args:
|
||||
tools: Sequence of tools to bind to the model.
|
||||
tool_choice: The tool to use. If "any" then any tool can be used.
|
||||
|
||||
Returns:
|
||||
A Runnable that returns a message.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def with_structured_output(
|
||||
self,
|
||||
schema: Union[typing.Dict, type], # noqa: UP006
|
||||
*,
|
||||
include_raw: bool = False,
|
||||
**kwargs: Any,
|
||||
) -> Runnable[LanguageModelInput, Union[typing.Dict, BaseModel]]: # noqa: UP006
|
||||
"""Model wrapper that returns outputs formatted to match the given schema.
|
||||
|
||||
Args:
|
||||
schema:
|
||||
The output schema. Can be passed in as:
|
||||
- an OpenAI function/tool schema,
|
||||
- a JSON Schema,
|
||||
- a TypedDict class,
|
||||
- or a Pydantic class.
|
||||
If ``schema`` is a Pydantic class then the model output will be a
|
||||
Pydantic instance of that class, and the model-generated fields will be
|
||||
validated by the Pydantic class. Otherwise the model output will be a
|
||||
dict and will not be validated. See :meth:`langchain_core.utils.function_calling.convert_to_openai_tool`
|
||||
for more on how to properly specify types and descriptions of
|
||||
schema fields when specifying a Pydantic or TypedDict class.
|
||||
|
||||
include_raw:
|
||||
If False then only the parsed structured output is returned. If
|
||||
an error occurs during model output parsing it will be raised. If True
|
||||
then both the raw model response (a BaseMessage) and the parsed model
|
||||
response will be returned. If an error occurs during output parsing it
|
||||
will be caught and returned as well. The final output is always a dict
|
||||
with keys "raw", "parsed", and "parsing_error".
|
||||
|
||||
Returns:
|
||||
A Runnable that takes same inputs as a :class:`langchain_core.language_models.chat.BaseChatModel`.
|
||||
|
||||
If ``include_raw`` is False and ``schema`` is a Pydantic class, Runnable outputs
|
||||
an instance of ``schema`` (i.e., a Pydantic object).
|
||||
|
||||
Otherwise, if ``include_raw`` is False then Runnable outputs a dict.
|
||||
|
||||
If ``include_raw`` is True, then Runnable outputs a dict with keys:
|
||||
- ``"raw"``: BaseMessage
|
||||
- ``"parsed"``: None if there was a parsing error, otherwise the type depends on the ``schema`` as described above.
|
||||
- ``"parsing_error"``: Optional[BaseException]
|
||||
|
||||
Example: Pydantic schema (include_raw=False):
|
||||
.. code-block:: python
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
class AnswerWithJustification(BaseModel):
|
||||
'''An answer to the user question along with justification for the answer.'''
|
||||
answer: str
|
||||
justification: str
|
||||
|
||||
llm = ChatModel(model="model-name", temperature=0)
|
||||
structured_llm = llm.with_structured_output(AnswerWithJustification)
|
||||
|
||||
structured_llm.invoke("What weighs more a pound of bricks or a pound of feathers")
|
||||
|
||||
# -> AnswerWithJustification(
|
||||
# answer='They weigh the same',
|
||||
# justification='Both a pound of bricks and a pound of feathers weigh one pound. The weight is the same, but the volume or density of the objects may differ.'
|
||||
# )
|
||||
|
||||
Example: Pydantic schema (include_raw=True):
|
||||
.. code-block:: python
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
class AnswerWithJustification(BaseModel):
|
||||
'''An answer to the user question along with justification for the answer.'''
|
||||
answer: str
|
||||
justification: str
|
||||
|
||||
llm = ChatModel(model="model-name", temperature=0)
|
||||
structured_llm = llm.with_structured_output(AnswerWithJustification, include_raw=True)
|
||||
|
||||
structured_llm.invoke("What weighs more a pound of bricks or a pound of feathers")
|
||||
# -> {
|
||||
# 'raw': AIMessage(content='', additional_kwargs={'tool_calls': [{'id': 'call_Ao02pnFYXD6GN1yzc0uXPsvF', 'function': {'arguments': '{"answer":"They weigh the same.","justification":"Both a pound of bricks and a pound of feathers weigh one pound. The weight is the same, but the volume or density of the objects may differ."}', 'name': 'AnswerWithJustification'}, 'type': 'function'}]}),
|
||||
# 'parsed': AnswerWithJustification(answer='They weigh the same.', justification='Both a pound of bricks and a pound of feathers weigh one pound. The weight is the same, but the volume or density of the objects may differ.'),
|
||||
# 'parsing_error': None
|
||||
# }
|
||||
|
||||
Example: Dict schema (include_raw=False):
|
||||
.. code-block:: python
|
||||
|
||||
from pydantic import BaseModel
|
||||
from langchain_core.utils.function_calling import convert_to_openai_tool
|
||||
|
||||
class AnswerWithJustification(BaseModel):
|
||||
'''An answer to the user question along with justification for the answer.'''
|
||||
answer: str
|
||||
justification: str
|
||||
|
||||
dict_schema = convert_to_openai_tool(AnswerWithJustification)
|
||||
llm = ChatModel(model="model-name", temperature=0)
|
||||
structured_llm = llm.with_structured_output(dict_schema)
|
||||
|
||||
structured_llm.invoke("What weighs more a pound of bricks or a pound of feathers")
|
||||
# -> {
|
||||
# 'answer': 'They weigh the same',
|
||||
# 'justification': 'Both a pound of bricks and a pound of feathers weigh one pound. The weight is the same, but the volume and density of the two substances differ.'
|
||||
# }
|
||||
|
||||
.. versionchanged:: 0.2.26
|
||||
|
||||
Added support for TypedDict class.
|
||||
""" # noqa: E501
|
||||
_ = kwargs.pop("method", None)
|
||||
_ = kwargs.pop("strict", None)
|
||||
if kwargs:
|
||||
msg = f"Received unsupported arguments {kwargs}"
|
||||
raise ValueError(msg)
|
||||
|
||||
from langchain_core.output_parsers.openai_tools import (
|
||||
JsonOutputKeyToolsParser,
|
||||
PydanticToolsParser,
|
||||
)
|
||||
|
||||
if type(self).bind_tools is BaseChatModelV1.bind_tools:
|
||||
msg = "with_structured_output is not implemented for this model."
|
||||
raise NotImplementedError(msg)
|
||||
|
||||
llm = self.bind_tools(
|
||||
[schema],
|
||||
tool_choice="any",
|
||||
ls_structured_output_format={
|
||||
"kwargs": {"method": "function_calling"},
|
||||
"schema": schema,
|
||||
},
|
||||
)
|
||||
if isinstance(schema, type) and is_basemodel_subclass(schema):
|
||||
output_parser: OutputParserLike = PydanticToolsParser(
|
||||
tools=[cast("TypeBaseModel", schema)], first_tool_only=True
|
||||
)
|
||||
else:
|
||||
key_name = convert_to_openai_tool(schema)["function"]["name"]
|
||||
output_parser = JsonOutputKeyToolsParser(
|
||||
key_name=key_name, first_tool_only=True
|
||||
)
|
||||
if include_raw:
|
||||
parser_assign = RunnablePassthrough.assign(
|
||||
parsed=itemgetter("raw") | output_parser, parsing_error=lambda _: None
|
||||
)
|
||||
parser_none = RunnablePassthrough.assign(parsed=lambda _: None)
|
||||
parser_with_fallback = parser_assign.with_fallbacks(
|
||||
[parser_none], exception_key="parsing_error"
|
||||
)
|
||||
return RunnableMap(raw=llm) | parser_with_fallback
|
||||
return llm | output_parser
|
||||
|
||||
|
||||
def _gen_info_and_msg_metadata(
|
||||
generation: Union[ChatGeneration, ChatGenerationChunk],
|
||||
) -> dict:
|
||||
return {
|
||||
**(generation.generation_info or {}),
|
||||
**generation.message.response_metadata,
|
||||
}
|
@ -300,6 +300,81 @@ def _create_message_from_message_type(
|
||||
return message
|
||||
|
||||
|
||||
def _create_message_from_message_type_v1(
|
||||
message_type: str,
|
||||
content: str,
|
||||
name: Optional[str] = None,
|
||||
tool_call_id: Optional[str] = None,
|
||||
tool_calls: Optional[list[dict[str, Any]]] = None,
|
||||
id: Optional[str] = None,
|
||||
**kwargs: Any,
|
||||
) -> MessageV1:
|
||||
"""Create a message from a message type and content string.
|
||||
|
||||
Args:
|
||||
message_type: (str) the type of the message (e.g., "human", "ai", etc.).
|
||||
content: (str) the content string.
|
||||
name: (str) the name of the message. Default is None.
|
||||
tool_call_id: (str) the tool call id. Default is None.
|
||||
tool_calls: (list[dict[str, Any]]) the tool calls. Default is None.
|
||||
id: (str) the id of the message. Default is None.
|
||||
kwargs: (dict[str, Any]) additional keyword arguments.
|
||||
|
||||
Returns:
|
||||
a message of the appropriate type.
|
||||
|
||||
Raises:
|
||||
ValueError: if the message type is not one of "human", "user", "ai",
|
||||
"assistant", "tool", "system", or "developer".
|
||||
"""
|
||||
kwargs: dict[str, Any] = {}
|
||||
if name is not None:
|
||||
kwargs["name"] = name
|
||||
if tool_call_id is not None:
|
||||
kwargs["tool_call_id"] = tool_call_id
|
||||
if kwargs and (response_metadata := kwargs.pop("response_metadata", None)):
|
||||
kwargs["response_metadata"] = response_metadata
|
||||
if id is not None:
|
||||
kwargs["id"] = id
|
||||
if tool_calls is not None:
|
||||
kwargs["tool_calls"] = []
|
||||
for tool_call in tool_calls:
|
||||
# Convert OpenAI-format tool call to LangChain format.
|
||||
if "function" in tool_call:
|
||||
args = tool_call["function"]["arguments"]
|
||||
if isinstance(args, str):
|
||||
args = json.loads(args, strict=False)
|
||||
kwargs["tool_calls"].append(
|
||||
{
|
||||
"name": tool_call["function"]["name"],
|
||||
"args": args,
|
||||
"id": tool_call["id"],
|
||||
"type": "tool_call",
|
||||
}
|
||||
)
|
||||
else:
|
||||
kwargs["tool_calls"].append(tool_call)
|
||||
if message_type in {"human", "user"}:
|
||||
message = HumanMessageV1(content=content, **kwargs)
|
||||
elif message_type in {"ai", "assistant"}:
|
||||
message = AIMessageV1(content=content, **kwargs)
|
||||
elif message_type in {"system", "developer"}:
|
||||
if message_type == "developer":
|
||||
kwargs["custom_role"] = "developer"
|
||||
message = SystemMessageV1(content=content, **kwargs)
|
||||
elif message_type == "tool":
|
||||
artifact = kwargs.pop("artifact", None)
|
||||
message = ToolMessageV1(content=content, artifact=artifact, **kwargs)
|
||||
else:
|
||||
msg = (
|
||||
f"Unexpected message type: '{message_type}'. Use one of 'human',"
|
||||
f" 'user', 'ai', 'assistant', 'function', 'tool', 'system', or 'developer'."
|
||||
)
|
||||
msg = create_message(message=msg, error_code=ErrorCode.MESSAGE_COERCION_FAILURE)
|
||||
raise ValueError(msg)
|
||||
return message
|
||||
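_create_message_from_message_type_v1 accepts tool calls either in OpenAI wire format (a nested "function" object with JSON-encoded arguments) or already in LangChain's format, normalizing the former. The conversion step in isolation:

import json
from typing import Any


def openai_tool_call_to_langchain(tool_call: dict[str, Any]) -> dict[str, Any]:
    """Normalize one OpenAI-format tool call into LangChain's tool_call dict."""
    if "function" not in tool_call:
        return tool_call  # already in LangChain format
    args = tool_call["function"]["arguments"]
    if isinstance(args, str):
        args = json.loads(args, strict=False)
    return {
        "name": tool_call["function"]["name"],
        "args": args,
        "id": tool_call["id"],
        "type": "tool_call",
    }


openai_call = {
    "id": "call_123",
    "function": {"name": "get_weather", "arguments": '{"city": "Paris"}'},
}
assert openai_tool_call_to_langchain(openai_call)["args"] == {"city": "Paris"}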
|
||||
|
||||
def _convert_from_v1_message(message: MessageV1) -> BaseMessage:
|
||||
"""Compatibility layer to convert v1 messages to current messages.
|
||||
|
||||
@ -404,6 +479,61 @@ def _convert_to_message(message: MessageLikeRepresentation) -> BaseMessage:
|
||||
return message_
|
||||
|
||||
|
||||
def _convert_to_message_v1(message: MessageLikeRepresentation) -> MessageV1:
|
||||
"""Instantiate a message from a variety of message formats.
|
||||
|
||||
The message format can be one of the following:
|
||||
|
||||
- BaseMessagePromptTemplate
|
||||
- BaseMessage
|
||||
- 2-tuple of (role string, template); e.g., ("human", "{user_input}")
|
||||
- dict: a message dict with role and content keys
|
||||
- string: shorthand for ("human", template); e.g., "{user_input}"
|
||||
|
||||
Args:
|
||||
message: a representation of a message in one of the supported formats.
|
||||
|
||||
Returns:
|
||||
an instance of a message or a message template.
|
||||
|
||||
Raises:
|
||||
NotImplementedError: if the message type is not supported.
|
||||
ValueError: if the message dict does not contain the required keys.
|
||||
"""
|
||||
if isinstance(message, MessageV1Types):
|
||||
message_ = message
|
||||
elif isinstance(message, str):
|
||||
message_ = _create_message_from_message_type_v1("human", message)
|
||||
elif isinstance(message, Sequence) and len(message) == 2:
|
||||
# mypy doesn't realise this can't be a string given the previous branch
|
||||
message_type_str, template = message # type: ignore[misc]
|
||||
message_ = _create_message_from_message_type_v1(message_type_str, template)
|
||||
elif isinstance(message, dict):
|
||||
msg_kwargs = message.copy()
|
||||
try:
|
||||
try:
|
||||
msg_type = msg_kwargs.pop("role")
|
||||
except KeyError:
|
||||
msg_type = msg_kwargs.pop("type")
|
||||
# None msg content is not allowed
|
||||
msg_content = msg_kwargs.pop("content") or ""
|
||||
except KeyError as e:
|
||||
msg = f"Message dict must contain 'role' and 'content' keys, got {message}"
|
||||
msg = create_message(
|
||||
message=msg, error_code=ErrorCode.MESSAGE_COERCION_FAILURE
|
||||
)
|
||||
raise ValueError(msg) from e
|
||||
message_ = _create_message_from_message_type_v1(
|
||||
msg_type, msg_content, **msg_kwargs
|
||||
)
|
||||
else:
|
||||
msg = f"Unsupported message type: {type(message)}"
|
||||
msg = create_message(message=msg, error_code=ErrorCode.MESSAGE_COERCION_FAILURE)
|
||||
raise NotImplementedError(msg)
|
||||
|
||||
return message_
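
A small usage sketch for the coercion paths above; the exact v1 constructors may differ slightly, so treat the results as indicative rather than definitive.

# Dicts may carry either a "role" or a "type" key; None content collapses to "".
human = _convert_to_message_v1({"role": "user", "content": "What is LangChain?"})
also_human = _convert_to_message_v1(("human", "What is LangChain?"))
system = _convert_to_message_v1({"type": "system", "content": None})
# human / also_human -> HumanMessageV1, system -> SystemMessageV1 with empty content.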


def convert_to_messages(
    messages: Union[Iterable[MessageLikeRepresentation], PromptValue],
) -> list[BaseMessage]:
@ -423,6 +553,25 @@ def convert_to_messages(
    return [_convert_to_message(m) for m in messages]


def convert_to_messages_v1(
    messages: Union[Iterable[MessageLikeRepresentation], PromptValue],
) -> list[MessageV1]:
    """Convert a sequence of messages to a list of messages.

    Args:
        messages: Sequence of messages to convert.

    Returns:
        list of messages (MessageV1).
    """
    # Import here to avoid circular imports
    from langchain_core.prompt_values import PromptValue

    if isinstance(messages, PromptValue):
        return messages.to_messages(output_version="v1")
    return [_convert_to_message_v1(m) for m in messages]
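
A hedged example of the new entry point; it assumes convert_to_messages_v1 ends up exported alongside convert_to_messages.

messages_v1 = convert_to_messages_v1(
    [
        ("system", "You are terse."),
        "Hello!",  # shorthand for ("human", "Hello!")
        {"role": "assistant", "content": "Hi."},
    ]
)
# -> [SystemMessageV1(...), HumanMessageV1(...), AIMessageV1(...)]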


def _runnable_support(func: Callable) -> Callable:
    @overload
    def wrapped(

@ -366,6 +366,8 @@ def add_ai_message_chunks(
    left: AIMessageChunk, *others: AIMessageChunk
) -> AIMessageChunk:
    """Add multiple AIMessageChunks together."""
    if not others:
        return left
    content = merge_content(
        cast("list[str | dict[Any, Any]]", left.content),
        *(cast("list[str | dict[Any, Any]]", o.content) for o in others),
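
For context, a quick sketch of how streamed chunks get merged; it assumes the add_ai_message_chunks helper in langchain_core.messages.ai and that AIMessageChunk addition delegates to it.

from langchain_core.messages.ai import AIMessageChunk, add_ai_message_chunks

c1 = AIMessageChunk(content="Hel")
c2 = AIMessageChunk(content="lo")
merged = add_ai_message_chunks(c1, c2)  # should be equivalent to c1 + c2
assert merged.content == "Hello"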

@ -15,6 +15,8 @@ from langchain_core.language_models import (
    ParrotFakeChatModel,
)
from langchain_core.messages import AIMessage, AIMessageChunk, BaseMessage, HumanMessage
from langchain_core.messages.v1 import AIMessageChunk as AIMessageChunkV1
from langchain_core.messages.v1 import MessageV1
from langchain_core.outputs import ChatGenerationChunk, GenerationChunk
from tests.unit_tests.stubs import (
    _any_id_ai_message,
@ -157,13 +159,13 @@ async def test_callback_handlers() -> None:
    """Verify that model is implemented correctly with handlers working."""

    class MyCustomAsyncHandler(AsyncCallbackHandler):
        def __init__(self, store: list[str]) -> None:
        def __init__(self, store: list[Union[str, AIMessageChunkV1]]) -> None:
            self.store = store

        async def on_chat_model_start(
            self,
            serialized: dict[str, Any],
            messages: list[list[BaseMessage]],
            messages: Union[list[list[BaseMessage]], list[list[MessageV1]]],
            *,
            run_id: UUID,
            parent_run_id: Optional[UUID] = None,
@ -178,9 +180,11 @@ async def test_callback_handlers() -> None:
        @override
        async def on_llm_new_token(
            self,
            token: str,
            token: Union[str, AIMessageChunkV1],
            *,
            chunk: Optional[Union[GenerationChunk, ChatGenerationChunk]] = None,
            chunk: Optional[
                Union[GenerationChunk, ChatGenerationChunk, AIMessageChunkV1]
            ] = None,
            run_id: UUID,
            parent_run_id: Optional[UUID] = None,
            tags: Optional[list[str]] = None,
@ -194,7 +198,7 @@ async def test_callback_handlers() -> None:
        ]
    )
    model = GenericFakeChatModel(messages=infinite_cycle)
    tokens: list[str] = []
    tokens: list[Union[str, AIMessageChunkV1]] = []
    # New model
    results = [
        chunk

@ -463,7 +463,7 @@ class BaseChatOpenAI(BaseChatModel):
        alias="api_key", default_factory=secret_from_env("OPENAI_API_KEY", default=None)
    )
    openai_api_base: Optional[str] = Field(default=None, alias="base_url")
    """Base URL path for API requests, leave blank if not using a proxy or service
    emulator."""
    openai_organization: Optional[str] = Field(default=None, alias="organization")
    """Automatically inferred from env var ``OPENAI_ORG_ID`` if not provided."""
@ -494,7 +494,7 @@ class BaseChatOpenAI(BaseChatModel):
    """Whether to return logprobs."""
    top_logprobs: Optional[int] = None
    """Number of most likely tokens to return at each token position, each with
    an associated log probability. `logprobs` must be set to true
    if this parameter is used."""
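
A hedged usage sketch for these two fields; the model name is illustrative, and the metadata key assumes logprobs are surfaced in response_metadata as they are today.

from langchain_openai import ChatOpenAI

llm = ChatOpenAI(model="gpt-4o-mini", logprobs=True, top_logprobs=3)
msg = llm.invoke("Say hi")
print(msg.response_metadata.get("logprobs"))  # per-token logprobs, when the API returns them
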
    logit_bias: Optional[dict[int, int]] = None
    """Modify the likelihood of specified tokens appearing in the completion."""
@ -512,7 +512,7 @@ class BaseChatOpenAI(BaseChatModel):

    Reasoning models only, like OpenAI o1, o3, and o4-mini.

    Currently supported values are low, medium, and high. Reducing reasoning effort
    can result in faster responses and fewer tokens used on reasoning in a response.

    .. versionadded:: 0.2.14
@ -534,21 +534,21 @@ class BaseChatOpenAI(BaseChatModel):

    """
    tiktoken_model_name: Optional[str] = None
    """The model name to pass to tiktoken when using this class.
    Tiktoken is used to count the number of tokens in documents to constrain
    them to be under a certain limit. By default, when set to None, this will
    be the same as the embedding model name. However, there are some cases
    where you may want to use this Embedding class with a model name not
    supported by tiktoken. This can include when using Azure embeddings or
    when using one of the many model providers that expose an OpenAI-like
    API but with different models. In those cases, in order to avoid erroring
    when tiktoken is called, you can specify a model name to use here."""
    default_headers: Union[Mapping[str, str], None] = None
    default_query: Union[Mapping[str, object], None] = None
    # Configure a custom httpx client. See the
    # [httpx documentation](https://www.python-httpx.org/api/#client) for more details.
    http_client: Union[Any, None] = Field(default=None, exclude=True)
    """Optional ``httpx.Client``. Only used for sync invocations. Must specify
    ``http_async_client`` as well if you'd like a custom client for async
    invocations.
    """
@ -580,21 +580,21 @@ class BaseChatOpenAI(BaseChatModel):
    include_response_headers: bool = False
    """Whether to include response headers in the output message ``response_metadata``."""  # noqa: E501
    disabled_params: Optional[dict[str, Any]] = Field(default=None)
    """Parameters of the OpenAI client or chat.completions endpoint that should be
    disabled for the given model.

    Should be specified as ``{"param": None | ['val1', 'val2']}`` where the key is the
    parameter and the value is either None, meaning that parameter should never be
    used, or it's a list of disabled values for the parameter.

    For example, older models may not support the ``'parallel_tool_calls'`` parameter at
    all, in which case ``disabled_params={"parallel_tool_calls": None}`` can be passed
    in.

    If a parameter is disabled then it will not be used by default in any methods, e.g.
    in :meth:`~langchain_openai.chat_models.base.ChatOpenAI.with_structured_output`.
    However this does not prevent a user from directly passing in the parameter during
    invocation.
    """

    include: Optional[list[str]] = None