mirror of https://github.com/hwchase17/langchain.git
synced 2025-08-03 18:24:10 +00:00

Merge branch 'standard_outputs_copy' into mdrxy/ollama-v1

This commit is contained in: commit 92c913c212
@@ -7,6 +7,8 @@ from typing import TYPE_CHECKING, Any, Optional, Union
 from typing_extensions import Self
 
+from langchain_core.messages.v1 import AIMessage, AIMessageChunk, MessageV1
+
 if TYPE_CHECKING:
     from collections.abc import Sequence
     from uuid import UUID
@@ -64,9 +66,11 @@ class LLMManagerMixin:
 
     def on_llm_new_token(
         self,
-        token: str,
+        token: Union[str, AIMessageChunk],
         *,
-        chunk: Optional[Union[GenerationChunk, ChatGenerationChunk]] = None,
+        chunk: Optional[
+            Union[GenerationChunk, ChatGenerationChunk, AIMessageChunk]
+        ] = None,
         run_id: UUID,
         parent_run_id: Optional[UUID] = None,
         **kwargs: Any,
@@ -75,8 +79,8 @@ class LLMManagerMixin:
 
         Args:
             token (str): The new token.
-            chunk (GenerationChunk | ChatGenerationChunk): The new generated chunk,
-                containing content and other information.
+            chunk (GenerationChunk | ChatGenerationChunk | AIMessageChunk): The new
+                generated chunk, containing content and other information.
             run_id (UUID): The run ID. This is the ID of the current run.
             parent_run_id (UUID): The parent run ID. This is the ID of the parent run.
             kwargs (Any): Additional keyword arguments.
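
For orientation, a minimal sketch (not part of this commit) of how a model implementation can emit the widened token event. The helper name is invented; the `Union[str, AIMessageChunk]` token type comes from the hunk above, and passing the chunk itself as the token mirrors how the new `BaseChatModelV1.stream` later in this commit calls `run_manager.on_llm_new_token(msg)`:

from collections.abc import Iterator

from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.messages.v1 import AIMessageChunk


def _emit_stream(
    run_manager: CallbackManagerForLLMRun, chunks: Iterator[AIMessageChunk]
) -> None:
    # Hypothetical helper: each v1 chunk is forwarded as-is; no
    # ChatGenerationChunk wrapper is required with the new signature.
    for chunk in chunks:
        run_manager.on_llm_new_token(chunk)
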
@@ -84,7 +88,7 @@ class LLMManagerMixin:
 
     def on_llm_end(
         self,
-        response: LLMResult,
+        response: Union[LLMResult, AIMessage],
         *,
         run_id: UUID,
         parent_run_id: Optional[UUID] = None,
@@ -93,7 +97,7 @@ class LLMManagerMixin:
         """Run when LLM ends running.
 
         Args:
-            response (LLMResult): The response which was generated.
+            response (LLMResult | AIMessage): The response which was generated.
             run_id (UUID): The run ID. This is the ID of the current run.
             parent_run_id (UUID): The parent run ID. This is the ID of the parent run.
             kwargs (Any): Additional keyword arguments.
@@ -261,7 +265,7 @@ class CallbackManagerMixin:
     def on_chat_model_start(
         self,
         serialized: dict[str, Any],
-        messages: list[list[BaseMessage]],
+        messages: Union[list[list[BaseMessage]], list[list[MessageV1]]],
         *,
         run_id: UUID,
         parent_run_id: Optional[UUID] = None,
@@ -439,6 +443,9 @@ class BaseCallbackHandler(
     run_inline: bool = False
     """Whether to run the callback inline."""
 
+    accepts_new_messages: bool = False
+    """Whether the callback accepts new message format."""
+
     @property
     def ignore_llm(self) -> bool:
         """Whether to ignore LLM callbacks."""
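
A sketch of a handler that opts in to the new flag; the class name and print logic are hypothetical, while `accepts_new_messages` and the widened `on_llm_new_token`/`on_llm_end` signatures are taken from the hunks above. It also assumes `AIMessageChunk` exposes a `.text` attribute, as the conversion shim later in this commit does:

from typing import Any, Optional, Union
from uuid import UUID

from langchain_core.callbacks import BaseCallbackHandler
from langchain_core.messages.v1 import AIMessage, AIMessageChunk
from langchain_core.outputs import LLMResult


class V1StreamPrinter(BaseCallbackHandler):
    """Hypothetical handler that consumes v1 message payloads directly."""

    accepts_new_messages: bool = True  # skip the legacy-format conversion shim

    def on_llm_new_token(
        self,
        token: Union[str, AIMessageChunk],
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        **kwargs: Any,
    ) -> None:
        # With accepts_new_messages=True the token may arrive as an AIMessageChunk.
        text = token if isinstance(token, str) else token.text
        print(text, end="", flush=True)

    def on_llm_end(
        self,
        response: Union[LLMResult, AIMessage],
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        **kwargs: Any,
    ) -> None:
        print()  # newline once the full message has arrived
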
@@ -509,7 +516,7 @@ class AsyncCallbackHandler(BaseCallbackHandler):
     async def on_chat_model_start(
         self,
         serialized: dict[str, Any],
-        messages: list[list[BaseMessage]],
+        messages: Union[list[list[BaseMessage]], list[list[MessageV1]]],
         *,
         run_id: UUID,
         parent_run_id: Optional[UUID] = None,
@@ -538,9 +545,11 @@ class AsyncCallbackHandler(BaseCallbackHandler):
 
     async def on_llm_new_token(
         self,
-        token: str,
+        token: Union[str, AIMessageChunk],
         *,
-        chunk: Optional[Union[GenerationChunk, ChatGenerationChunk]] = None,
+        chunk: Optional[
+            Union[GenerationChunk, ChatGenerationChunk, AIMessageChunk]
+        ] = None,
         run_id: UUID,
         parent_run_id: Optional[UUID] = None,
         tags: Optional[list[str]] = None,
@@ -550,8 +559,8 @@ class AsyncCallbackHandler(BaseCallbackHandler):
 
         Args:
             token (str): The new token.
-            chunk (GenerationChunk | ChatGenerationChunk): The new generated chunk,
-                containing content and other information.
+            chunk (GenerationChunk | ChatGenerationChunk | AIMessageChunk): The new
+                generated chunk, containing content and other information.
             run_id (UUID): The run ID. This is the ID of the current run.
             parent_run_id (UUID): The parent run ID. This is the ID of the parent run.
             tags (Optional[list[str]]): The tags.
@@ -560,7 +569,7 @@ class AsyncCallbackHandler(BaseCallbackHandler):
 
     async def on_llm_end(
         self,
-        response: LLMResult,
+        response: Union[LLMResult, AIMessage],
         *,
         run_id: UUID,
         parent_run_id: Optional[UUID] = None,
@@ -570,7 +579,7 @@ class AsyncCallbackHandler(BaseCallbackHandler):
         """Run when LLM ends running.
 
         Args:
-            response (LLMResult): The response which was generated.
+            response (LLMResult | AIMessage): The response which was generated.
             run_id (UUID): The run ID. This is the ID of the current run.
             parent_run_id (UUID): The parent run ID. This is the ID of the parent run.
             tags (Optional[list[str]]): The tags.
@@ -594,8 +603,8 @@ class AsyncCallbackHandler(BaseCallbackHandler):
             parent_run_id: The parent run ID. This is the ID of the parent run.
             tags: The tags.
             kwargs (Any): Additional keyword arguments.
-                - response (LLMResult): The response which was generated before
-                    the error occurred.
+                - response (LLMResult | AIMessage): The response which was generated
+                    before the error occurred.
         """
 
     async def on_chain_start(
@@ -11,6 +11,7 @@ from abc import ABC, abstractmethod
 from concurrent.futures import ThreadPoolExecutor
 from contextlib import asynccontextmanager, contextmanager
 from contextvars import copy_context
+from dataclasses import is_dataclass
 from typing import (
     TYPE_CHECKING,
     Any,
@@ -37,6 +38,8 @@ from langchain_core.callbacks.base import (
 )
 from langchain_core.callbacks.stdout import StdOutCallbackHandler
 from langchain_core.messages import BaseMessage, get_buffer_string
+from langchain_core.messages.v1 import AIMessage, AIMessageChunk
+from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, LLMResult
 from langchain_core.tracers.schemas import Run
 from langchain_core.utils.env import env_var_is_set
 
@@ -47,7 +50,7 @@ if TYPE_CHECKING:
 
     from langchain_core.agents import AgentAction, AgentFinish
     from langchain_core.documents import Document
-    from langchain_core.outputs import ChatGenerationChunk, GenerationChunk, LLMResult
+    from langchain_core.outputs import GenerationChunk
     from langchain_core.runnables.config import RunnableConfig
 
 logger = logging.getLogger(__name__)
@@ -243,6 +246,22 @@ def shielded(func: Func) -> Func:
     return cast("Func", wrapped)
 
 
+def _convert_llm_events(
+    event_name: str, args: tuple[Any, ...], kwargs: dict[str, Any]
+) -> None:
+    if event_name == "on_chat_model_start" and isinstance(args[1], list):
+        for idx, item in enumerate(args[1]):
+            if is_dataclass(item):
+                args[1][idx] = item  # convert to old message
+    elif event_name == "on_llm_new_token" and is_dataclass(args[0]):
+        kwargs["chunk"] = ChatGenerationChunk(text=args[0].text, message=args[0])
+        args[0] = args[0].text
+    elif event_name == "on_llm_end" and is_dataclass(args[0]):
+        args[0] = LLMResult(
+            generations=[[ChatGeneration(text=args[0].text, message=args[0])]]
+        )
+
+
 def handle_event(
     handlers: list[BaseCallbackHandler],
     event_name: str,
@@ -271,6 +290,8 @@ def handle_event(
                 if ignore_condition_name is None or not getattr(
                     handler, ignore_condition_name
                 ):
+                    if not handler.accepts_new_messages:
+                        _convert_llm_events(event_name, args, kwargs)
                     event = getattr(handler, event_name)(*args, **kwargs)
                     if asyncio.iscoroutine(event):
                         coros.append(event)
@@ -365,6 +386,8 @@ async def _ahandle_event_for_handler(
 ) -> None:
     try:
         if ignore_condition_name is None or not getattr(handler, ignore_condition_name):
+            if not handler.accepts_new_messages:
+                _convert_llm_events(event_name, args, kwargs)
             event = getattr(handler, event_name)
             if asyncio.iscoroutinefunction(event):
                 await event(*args, **kwargs)
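
By contrast, existing handlers need no changes: `accepts_new_messages` defaults to `False`, so `handle_event` and `_ahandle_event_for_handler` route their payloads through `_convert_llm_events` first. A hypothetical sketch (class name and printing invented; the behavior it relies on is implied by the shim above):

from typing import Any

from langchain_core.callbacks import BaseCallbackHandler
from langchain_core.outputs import LLMResult


class LegacyStreamPrinter(BaseCallbackHandler):
    """Hypothetical pre-existing handler; it keeps its old signatures."""

    # accepts_new_messages is inherited as False, so before dispatch the
    # callback manager converts AIMessageChunk tokens to plain strings and
    # AIMessage results to an LLMResult wrapping a ChatGeneration.

    def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
        print(token, end="", flush=True)

    def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
        print()
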
@@ -672,9 +695,11 @@ class CallbackManagerForLLMRun(RunManager, LLMManagerMixin):
 
     def on_llm_new_token(
         self,
-        token: str,
+        token: Union[str, AIMessageChunk],
         *,
-        chunk: Optional[Union[GenerationChunk, ChatGenerationChunk]] = None,
+        chunk: Optional[
+            Union[GenerationChunk, ChatGenerationChunk, AIMessageChunk]
+        ] = None,
         **kwargs: Any,
     ) -> None:
         """Run when LLM generates a new token.
@@ -699,11 +724,11 @@ class CallbackManagerForLLMRun(RunManager, LLMManagerMixin):
             **kwargs,
         )
 
-    def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
+    def on_llm_end(self, response: Union[LLMResult, AIMessage], **kwargs: Any) -> None:
         """Run when LLM ends running.
 
         Args:
-            response (LLMResult): The LLM result.
+            response (LLMResult | AIMessage): The LLM result.
             **kwargs (Any): Additional keyword arguments.
         """
         if not self.handlers:
@@ -729,8 +754,8 @@ class CallbackManagerForLLMRun(RunManager, LLMManagerMixin):
         Args:
             error (Exception or KeyboardInterrupt): The error.
             kwargs (Any): Additional keyword arguments.
-                - response (LLMResult): The response which was generated before
-                    the error occurred.
+                - response (LLMResult | AIMessage): The response which was generated
+                    before the error occurred.
         """
         if not self.handlers:
             return
@@ -768,9 +793,11 @@ class AsyncCallbackManagerForLLMRun(AsyncRunManager, LLMManagerMixin):
 
     async def on_llm_new_token(
         self,
-        token: str,
+        token: Union[str, AIMessageChunk],
         *,
-        chunk: Optional[Union[GenerationChunk, ChatGenerationChunk]] = None,
+        chunk: Optional[
+            Union[GenerationChunk, ChatGenerationChunk, AIMessageChunk]
+        ] = None,
         **kwargs: Any,
     ) -> None:
         """Run when LLM generates a new token.
@@ -796,11 +823,13 @@ class AsyncCallbackManagerForLLMRun(AsyncRunManager, LLMManagerMixin):
         )
 
     @shielded
-    async def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
+    async def on_llm_end(
+        self, response: Union[LLMResult, AIMessage], **kwargs: Any
+    ) -> None:
         """Run when LLM ends running.
 
         Args:
-            response (LLMResult): The LLM result.
+            response (LLMResult | AIMessage): The LLM result.
             **kwargs (Any): Additional keyword arguments.
         """
         if not self.handlers:
@@ -827,11 +856,8 @@ class AsyncCallbackManagerForLLMRun(AsyncRunManager, LLMManagerMixin):
         Args:
             error (Exception or KeyboardInterrupt): The error.
             kwargs (Any): Additional keyword arguments.
-                - response (LLMResult): The response which was generated before
-                    the error occurred.
-
-
-
+                - response (LLMResult | AIMessage): The response which was generated
+                    before the error occurred.
         """
         if not self.handlers:
             return
@@ -1,3 +1,4 @@
+import copy
 import re
 from collections.abc import Sequence
 from typing import Optional
@@ -128,7 +129,10 @@ def _normalize_messages(messages: Sequence[BaseMessage]) -> list[BaseMessage]:
                     and _is_openai_data_block(block)
                 ):
                     if formatted_message is message:
-                        formatted_message = message.model_copy()
+                        if isinstance(message, BaseMessage):
+                            formatted_message = message.model_copy()
+                        else:
+                            formatted_message = copy.copy(message)
                         # Also shallow-copy content
                         formatted_message.content = list(formatted_message.content)
 
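
The branch exists because v0 messages are Pydantic models while v1 messages are dataclasses. A rough illustration (hypothetical variable names; it assumes the v1 HumanMessage accepts string content, as other parts of this commit suggest):

import copy
from dataclasses import is_dataclass

from langchain_core.messages import HumanMessage  # v0: Pydantic model
from langchain_core.messages.v1 import HumanMessage as HumanMessageV1  # v1: dataclass

v0_msg = HumanMessage(content="hi")
v1_msg = HumanMessageV1(content="hi")

copied_v0 = v0_msg.model_copy()  # Pydantic copy API
copied_v1 = copy.copy(v1_msg)    # dataclasses have no model_copy()

assert not is_dataclass(v0_msg) and is_dataclass(v1_msg)
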
@@ -28,6 +28,7 @@ from langchain_core.messages import (
     MessageLikeRepresentation,
     get_buffer_string,
 )
+from langchain_core.messages.v1 import AIMessage as AIMessageV1
 from langchain_core.prompt_values import PromptValue
 from langchain_core.runnables import Runnable, RunnableSerializable
 from langchain_core.utils import get_pydantic_field_names
@@ -85,7 +86,9 @@ def _get_token_ids_default_method(text: str) -> list[int]:
 LanguageModelInput = Union[PromptValue, str, Sequence[MessageLikeRepresentation]]
 LanguageModelOutput = Union[BaseMessage, str]
 LanguageModelLike = Runnable[LanguageModelInput, LanguageModelOutput]
-LanguageModelOutputVar = TypeVar("LanguageModelOutputVar", BaseMessage, str)
+LanguageModelOutputVar = TypeVar(
+    "LanguageModelOutputVar", BaseMessage, str, AIMessageV1
+)
 
 
 def _get_verbosity() -> bool:
libs/core/langchain_core/language_models/v1/__init__.py (new file, 1 line)
@@ -0,0 +1 @@
"""LangChain v1.0 chat models."""
909
libs/core/langchain_core/language_models/v1/chat_models.py
Normal file
909
libs/core/langchain_core/language_models/v1/chat_models.py
Normal file
@ -0,0 +1,909 @@
|
|||||||
|
"""Chat models for conversational AI."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import copy
|
||||||
|
import typing
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from collections.abc import AsyncIterator, Iterator, Sequence
|
||||||
|
from operator import itemgetter
|
||||||
|
from typing import (
|
||||||
|
TYPE_CHECKING,
|
||||||
|
Any,
|
||||||
|
Callable,
|
||||||
|
Literal,
|
||||||
|
Optional,
|
||||||
|
Union,
|
||||||
|
cast,
|
||||||
|
)
|
||||||
|
|
||||||
|
from pydantic import (
|
||||||
|
BaseModel,
|
||||||
|
ConfigDict,
|
||||||
|
Field,
|
||||||
|
)
|
||||||
|
from typing_extensions import override
|
||||||
|
|
||||||
|
from langchain_core.callbacks import (
|
||||||
|
AsyncCallbackManager,
|
||||||
|
AsyncCallbackManagerForLLMRun,
|
||||||
|
CallbackManager,
|
||||||
|
CallbackManagerForLLMRun,
|
||||||
|
)
|
||||||
|
from langchain_core.language_models._utils import _normalize_messages
|
||||||
|
from langchain_core.language_models.base import (
|
||||||
|
BaseLanguageModel,
|
||||||
|
LangSmithParams,
|
||||||
|
LanguageModelInput,
|
||||||
|
)
|
||||||
|
from langchain_core.messages import (
|
||||||
|
AIMessage,
|
||||||
|
BaseMessage,
|
||||||
|
convert_to_openai_image_block,
|
||||||
|
is_data_content_block,
|
||||||
|
)
|
||||||
|
from langchain_core.messages.utils import convert_to_messages_v1
|
||||||
|
from langchain_core.messages.v1 import AIMessage as AIMessageV1
|
||||||
|
from langchain_core.messages.v1 import AIMessageChunk as AIMessageChunkV1
|
||||||
|
from langchain_core.messages.v1 import HumanMessage as HumanMessageV1
|
||||||
|
from langchain_core.messages.v1 import MessageV1, add_ai_message_chunks
|
||||||
|
from langchain_core.outputs import (
|
||||||
|
ChatGeneration,
|
||||||
|
ChatGenerationChunk,
|
||||||
|
)
|
||||||
|
from langchain_core.prompt_values import PromptValue
|
||||||
|
from langchain_core.rate_limiters import BaseRateLimiter
|
||||||
|
from langchain_core.runnables import RunnableMap, RunnablePassthrough
|
||||||
|
from langchain_core.runnables.config import ensure_config, run_in_executor
|
||||||
|
from langchain_core.tracers._streaming import _StreamingCallbackHandler
|
||||||
|
from langchain_core.utils.function_calling import (
|
||||||
|
convert_to_json_schema,
|
||||||
|
convert_to_openai_tool,
|
||||||
|
)
|
||||||
|
from langchain_core.utils.pydantic import TypeBaseModel, is_basemodel_subclass
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from langchain_core.output_parsers.base import OutputParserLike
|
||||||
|
from langchain_core.runnables import Runnable, RunnableConfig
|
||||||
|
from langchain_core.tools import BaseTool
|
||||||
|
|
||||||
|
|
||||||
|
def _generate_response_from_error(error: BaseException) -> list[AIMessageV1]:
|
||||||
|
if hasattr(error, "response"):
|
||||||
|
response = error.response
|
||||||
|
metadata: dict = {}
|
||||||
|
if hasattr(response, "headers"):
|
||||||
|
try:
|
||||||
|
metadata["headers"] = dict(response.headers)
|
||||||
|
except Exception:
|
||||||
|
metadata["headers"] = None
|
||||||
|
if hasattr(response, "status_code"):
|
||||||
|
metadata["status_code"] = response.status_code
|
||||||
|
if hasattr(error, "request_id"):
|
||||||
|
metadata["request_id"] = error.request_id
|
||||||
|
generations = [AIMessageV1(content=[], response_metadata=metadata)]
|
||||||
|
else:
|
||||||
|
generations = []
|
||||||
|
|
||||||
|
return generations
|
||||||
|
|
||||||
|
|
||||||
|
def _format_for_tracing(messages: list[MessageV1]) -> list[MessageV1]:
|
||||||
|
"""Format messages for tracing in on_chat_model_start.
|
||||||
|
|
||||||
|
- Update image content blocks to OpenAI Chat Completions format (backward
|
||||||
|
compatibility).
|
||||||
|
- Add "type" key to content blocks that have a single key.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
messages: List of messages to format.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of messages formatted for tracing.
|
||||||
|
"""
|
||||||
|
messages_to_trace = []
|
||||||
|
for message in messages:
|
||||||
|
message_to_trace = message
|
||||||
|
for idx, block in enumerate(message.content):
|
||||||
|
# Update image content blocks to OpenAI # Chat Completions format.
|
||||||
|
if (
|
||||||
|
block["type"] == "image"
|
||||||
|
and is_data_content_block(block)
|
||||||
|
and block.get("source_type") != "id"
|
||||||
|
):
|
||||||
|
if message_to_trace is message:
|
||||||
|
# Shallow copy
|
||||||
|
message_to_trace = copy.copy(message)
|
||||||
|
message_to_trace.content = list(message_to_trace.content)
|
||||||
|
|
||||||
|
message_to_trace.content[idx] = convert_to_openai_image_block(block)
|
||||||
|
else:
|
||||||
|
pass
|
||||||
|
messages_to_trace.append(message_to_trace)
|
||||||
|
|
||||||
|
return messages_to_trace
|
||||||
|
|
||||||
|
|
||||||
|
def generate_from_stream(stream: Iterator[AIMessageChunkV1]) -> AIMessageV1:
|
||||||
|
"""Generate from a stream.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
stream: Iterator of AIMessageChunkV1.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
AIMessageV1: aggregated message.
|
||||||
|
"""
|
||||||
|
generation = next(stream, None)
|
||||||
|
if generation:
|
||||||
|
generation += list(stream)
|
||||||
|
if generation is None:
|
||||||
|
msg = "No generations found in stream."
|
||||||
|
raise ValueError(msg)
|
||||||
|
return generation.to_message()
|
||||||
|
|
||||||
|
|
||||||
|
async def agenerate_from_stream(
|
||||||
|
stream: AsyncIterator[AIMessageChunkV1],
|
||||||
|
) -> AIMessageV1:
|
||||||
|
"""Async generate from a stream.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
stream: Iterator of AIMessageChunkV1.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
AIMessageV1: aggregated message.
|
||||||
|
"""
|
||||||
|
chunks = [chunk async for chunk in stream]
|
||||||
|
return await run_in_executor(None, generate_from_stream, iter(chunks))
|
||||||
|
|
||||||
|
|
||||||
|
def _format_ls_structured_output(ls_structured_output_format: Optional[dict]) -> dict:
|
||||||
|
if ls_structured_output_format:
|
||||||
|
try:
|
||||||
|
ls_structured_output_format_dict = {
|
||||||
|
"ls_structured_output_format": {
|
||||||
|
"kwargs": ls_structured_output_format.get("kwargs", {}),
|
||||||
|
"schema": convert_to_json_schema(
|
||||||
|
ls_structured_output_format["schema"]
|
||||||
|
),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
except ValueError:
|
||||||
|
ls_structured_output_format_dict = {}
|
||||||
|
else:
|
||||||
|
ls_structured_output_format_dict = {}
|
||||||
|
|
||||||
|
return ls_structured_output_format_dict
|
||||||
|
|
||||||
|
|
||||||
|
class BaseChatModelV1(BaseLanguageModel[AIMessageV1], ABC):
|
||||||
|
"""Base class for chat models.
|
||||||
|
|
||||||
|
Key imperative methods:
|
||||||
|
Methods that actually call the underlying model.
|
||||||
|
|
||||||
|
+---------------------------+----------------------------------------------------------------+---------------------------------------------------------------------+--------------------------------------------------------------------------------------------------+
|
||||||
|
| Method | Input | Output | Description |
|
||||||
|
+===========================+================================================================+=====================================================================+==================================================================================================+
|
||||||
|
| `invoke` | str | list[dict | tuple | BaseMessage] | PromptValue | BaseMessage | A single chat model call. |
|
||||||
|
+---------------------------+----------------------------------------------------------------+---------------------------------------------------------------------+--------------------------------------------------------------------------------------------------+
|
||||||
|
| `ainvoke` | ''' | BaseMessage | Defaults to running invoke in an async executor. |
|
||||||
|
+---------------------------+----------------------------------------------------------------+---------------------------------------------------------------------+--------------------------------------------------------------------------------------------------+
|
||||||
|
| `stream` | ''' | Iterator[BaseMessageChunk] | Defaults to yielding output of invoke. |
|
||||||
|
+---------------------------+----------------------------------------------------------------+---------------------------------------------------------------------+--------------------------------------------------------------------------------------------------+
|
||||||
|
| `astream` | ''' | AsyncIterator[BaseMessageChunk] | Defaults to yielding output of ainvoke. |
|
||||||
|
+---------------------------+----------------------------------------------------------------+---------------------------------------------------------------------+--------------------------------------------------------------------------------------------------+
|
||||||
|
| `astream_events` | ''' | AsyncIterator[StreamEvent] | Event types: 'on_chat_model_start', 'on_chat_model_stream', 'on_chat_model_end'. |
|
||||||
|
+---------------------------+----------------------------------------------------------------+---------------------------------------------------------------------+--------------------------------------------------------------------------------------------------+
|
||||||
|
| `batch` | list['''] | list[BaseMessage] | Defaults to running invoke in concurrent threads. |
|
||||||
|
+---------------------------+----------------------------------------------------------------+---------------------------------------------------------------------+--------------------------------------------------------------------------------------------------+
|
||||||
|
| `abatch` | list['''] | list[BaseMessage] | Defaults to running ainvoke in concurrent threads. |
|
||||||
|
+---------------------------+----------------------------------------------------------------+---------------------------------------------------------------------+--------------------------------------------------------------------------------------------------+
|
||||||
|
| `batch_as_completed` | list['''] | Iterator[tuple[int, Union[BaseMessage, Exception]]] | Defaults to running invoke in concurrent threads. |
|
||||||
|
+---------------------------+----------------------------------------------------------------+---------------------------------------------------------------------+--------------------------------------------------------------------------------------------------+
|
||||||
|
| `abatch_as_completed` | list['''] | AsyncIterator[tuple[int, Union[BaseMessage, Exception]]] | Defaults to running ainvoke in concurrent threads. |
|
||||||
|
+---------------------------+----------------------------------------------------------------+---------------------------------------------------------------------+--------------------------------------------------------------------------------------------------+
|
||||||
|
|
||||||
|
This table provides a brief overview of the main imperative methods. Please see the base Runnable reference for full documentation.
|
||||||
|
|
||||||
|
Key declarative methods:
|
||||||
|
Methods for creating another Runnable using the ChatModel.
|
||||||
|
|
||||||
|
+----------------------------------+-----------------------------------------------------------------------------------------------------------+
|
||||||
|
| Method | Description |
|
||||||
|
+==================================+===========================================================================================================+
|
||||||
|
| `bind_tools` | Create ChatModel that can call tools. |
|
||||||
|
+----------------------------------+-----------------------------------------------------------------------------------------------------------+
|
||||||
|
| `with_structured_output` | Create wrapper that structures model output using schema. |
|
||||||
|
+----------------------------------+-----------------------------------------------------------------------------------------------------------+
|
||||||
|
| `with_retry` | Create wrapper that retries model calls on failure. |
|
||||||
|
+----------------------------------+-----------------------------------------------------------------------------------------------------------+
|
||||||
|
| `with_fallbacks` | Create wrapper that falls back to other models on failure. |
|
||||||
|
+----------------------------------+-----------------------------------------------------------------------------------------------------------+
|
||||||
|
| `configurable_fields` | Specify init args of the model that can be configured at runtime via the RunnableConfig. |
|
||||||
|
+----------------------------------+-----------------------------------------------------------------------------------------------------------+
|
||||||
|
| `configurable_alternatives` | Specify alternative models which can be swapped in at runtime via the RunnableConfig. |
|
||||||
|
+----------------------------------+-----------------------------------------------------------------------------------------------------------+
|
||||||
|
|
||||||
|
This table provides a brief overview of the main declarative methods. Please see the reference for each method for full documentation.
|
||||||
|
|
||||||
|
Creating custom chat model:
|
||||||
|
Custom chat model implementations should inherit from this class.
|
||||||
|
Please reference the table below for information about which
|
||||||
|
methods and properties are required or optional for implementations.
|
||||||
|
|
||||||
|
+----------------------------------+--------------------------------------------------------------------+-------------------+
|
||||||
|
| Method/Property | Description | Required/Optional |
|
||||||
|
+==================================+====================================================================+===================+
|
||||||
|
| `_generate` | Use to generate a chat result from a prompt | Required |
|
||||||
|
+----------------------------------+--------------------------------------------------------------------+-------------------+
|
||||||
|
| `_llm_type` (property) | Used to uniquely identify the type of the model. Used for logging. | Required |
|
||||||
|
+----------------------------------+--------------------------------------------------------------------+-------------------+
|
||||||
|
| `_identifying_params` (property) | Represent model parameterization for tracing purposes. | Optional |
|
||||||
|
+----------------------------------+--------------------------------------------------------------------+-------------------+
|
||||||
|
| `_stream` | Use to implement streaming | Optional |
|
||||||
|
+----------------------------------+--------------------------------------------------------------------+-------------------+
|
||||||
|
| `_agenerate` | Use to implement a native async method | Optional |
|
||||||
|
+----------------------------------+--------------------------------------------------------------------+-------------------+
|
||||||
|
| `_astream` | Use to implement async version of `_stream` | Optional |
|
||||||
|
+----------------------------------+--------------------------------------------------------------------+-------------------+
|
||||||
|
|
||||||
|
Follow the guide for more information on how to implement a custom Chat Model:
|
||||||
|
[Guide](https://python.langchain.com/docs/how_to/custom_chat_model/).
|
||||||
|
|
||||||
|
""" # noqa: E501
|
||||||
|
|
||||||
|
rate_limiter: Optional[BaseRateLimiter] = Field(default=None, exclude=True)
|
||||||
|
"An optional rate limiter to use for limiting the number of requests."
|
||||||
|
|
||||||
|
disable_streaming: Union[bool, Literal["tool_calling"]] = False
|
||||||
|
"""Whether to disable streaming for this model.
|
||||||
|
|
||||||
|
If streaming is bypassed, then ``stream()``/``astream()``/``astream_events()`` will
|
||||||
|
defer to ``invoke()``/``ainvoke()``.
|
||||||
|
|
||||||
|
- If True, will always bypass streaming case.
|
||||||
|
- If ``'tool_calling'``, will bypass streaming case only when the model is called
|
||||||
|
with a ``tools`` keyword argument. In other words, LangChain will automatically
|
||||||
|
switch to non-streaming behavior (``invoke()``) only when the tools argument is
|
||||||
|
provided. This offers the best of both worlds.
|
||||||
|
- If False (default), will always use streaming case if available.
|
||||||
|
|
||||||
|
The main reason for this flag is that code might be written using ``.stream()`` and
|
||||||
|
a user may want to swap out a given model for another model whose the implementation
|
||||||
|
does not properly support streaming.
|
||||||
|
"""
|
||||||
|
|
||||||
|
model_config = ConfigDict(
|
||||||
|
arbitrary_types_allowed=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
# --- Runnable methods ---
|
||||||
|
|
||||||
|
@property
|
||||||
|
@override
|
||||||
|
def OutputType(self) -> Any:
|
||||||
|
"""Get the output type for this runnable."""
|
||||||
|
return AIMessageV1
|
||||||
|
|
||||||
|
def _convert_input(self, model_input: LanguageModelInput) -> list[MessageV1]:
|
||||||
|
if isinstance(model_input, PromptValue):
|
||||||
|
return model_input.to_messages(output_version="v1")
|
||||||
|
if isinstance(model_input, str):
|
||||||
|
return [HumanMessageV1(content=model_input)]
|
||||||
|
if isinstance(model_input, Sequence):
|
||||||
|
return convert_to_messages_v1(model_input)
|
||||||
|
msg = (
|
||||||
|
f"Invalid input type {type(model_input)}. "
|
||||||
|
"Must be a PromptValue, str, or list of BaseMessages."
|
||||||
|
)
|
||||||
|
raise ValueError(msg)
|
||||||
|
|
||||||
|
def _should_stream(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
async_api: bool,
|
||||||
|
run_manager: Optional[
|
||||||
|
Union[CallbackManagerForLLMRun, AsyncCallbackManagerForLLMRun]
|
||||||
|
] = None,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> bool:
|
||||||
|
"""Determine if a given model call should hit the streaming API."""
|
||||||
|
sync_not_implemented = type(self)._stream == BaseChatModelV1._stream # noqa: SLF001
|
||||||
|
async_not_implemented = type(self)._astream == BaseChatModelV1._astream # noqa: SLF001
|
||||||
|
|
||||||
|
# Check if streaming is implemented.
|
||||||
|
if (not async_api) and sync_not_implemented:
|
||||||
|
return False
|
||||||
|
# Note, since async falls back to sync we check both here.
|
||||||
|
if async_api and async_not_implemented and sync_not_implemented:
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Check if streaming has been disabled on this instance.
|
||||||
|
if self.disable_streaming is True:
|
||||||
|
return False
|
||||||
|
# We assume tools are passed in via "tools" kwarg in all models.
|
||||||
|
if self.disable_streaming == "tool_calling" and kwargs.get("tools"):
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Check if a runtime streaming flag has been passed in.
|
||||||
|
if "stream" in kwargs:
|
||||||
|
return kwargs["stream"]
|
||||||
|
|
||||||
|
# Check if any streaming callback handlers have been passed in.
|
||||||
|
handlers = run_manager.handlers if run_manager else []
|
||||||
|
return any(isinstance(h, _StreamingCallbackHandler) for h in handlers)
|
||||||
|
|
||||||
|
@override
|
||||||
|
def invoke(
|
||||||
|
self,
|
||||||
|
input: LanguageModelInput,
|
||||||
|
config: Optional[RunnableConfig] = None,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> AIMessageV1:
|
||||||
|
config = ensure_config(config)
|
||||||
|
messages = self._convert_input(input)
|
||||||
|
ls_structured_output_format = kwargs.pop(
|
||||||
|
"ls_structured_output_format", None
|
||||||
|
) or kwargs.pop("structured_output_format", None)
|
||||||
|
ls_structured_output_format_dict = _format_ls_structured_output(
|
||||||
|
ls_structured_output_format
|
||||||
|
)
|
||||||
|
|
||||||
|
params = self._get_invocation_params(**kwargs)
|
||||||
|
options = {**kwargs, **ls_structured_output_format_dict}
|
||||||
|
inheritable_metadata = {
|
||||||
|
**(config.get("metadata") or {}),
|
||||||
|
**self._get_ls_params(**kwargs),
|
||||||
|
}
|
||||||
|
callback_manager = CallbackManager.configure(
|
||||||
|
config.get("callbacks"),
|
||||||
|
self.callbacks,
|
||||||
|
self.verbose,
|
||||||
|
config.get("tags"),
|
||||||
|
self.tags,
|
||||||
|
inheritable_metadata,
|
||||||
|
self.metadata,
|
||||||
|
)
|
||||||
|
(run_manager,) = callback_manager.on_chat_model_start(
|
||||||
|
{},
|
||||||
|
[_format_for_tracing(messages)],
|
||||||
|
invocation_params=params,
|
||||||
|
options=options,
|
||||||
|
name=config.get("run_name"),
|
||||||
|
run_id=config.pop("run_id", None),
|
||||||
|
batch_size=1,
|
||||||
|
)
|
||||||
|
|
||||||
|
if self.rate_limiter:
|
||||||
|
self.rate_limiter.acquire(blocking=True)
|
||||||
|
|
||||||
|
input_messages = _normalize_messages(messages)
|
||||||
|
|
||||||
|
if self._should_stream(async_api=False, **kwargs):
|
||||||
|
chunks: list[AIMessageChunkV1] = []
|
||||||
|
try:
|
||||||
|
for msg in self._stream(input_messages, **kwargs):
|
||||||
|
run_manager.on_llm_new_token(msg)
|
||||||
|
chunks.append(msg)
|
||||||
|
except BaseException as e:
|
||||||
|
run_manager.on_llm_error(e, response=_generate_response_from_error(e))
|
||||||
|
raise
|
||||||
|
msg = add_ai_message_chunks(chunks[0], *chunks[1:])
|
||||||
|
else:
|
||||||
|
try:
|
||||||
|
msg = self._invoke(input_messages, **kwargs)
|
||||||
|
except BaseException as e:
|
||||||
|
run_manager.on_llm_error(e)
|
||||||
|
raise
|
||||||
|
|
||||||
|
run_manager.on_llm_end(msg)
|
||||||
|
return msg
|
||||||
|
|
||||||
|
@override
|
||||||
|
async def ainvoke(
|
||||||
|
self,
|
||||||
|
input: LanguageModelInput,
|
||||||
|
config: Optional[RunnableConfig] = None,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> AIMessage:
|
||||||
|
config = ensure_config(config)
|
||||||
|
messages = self._convert_input(input)
|
||||||
|
ls_structured_output_format = kwargs.pop(
|
||||||
|
"ls_structured_output_format", None
|
||||||
|
) or kwargs.pop("structured_output_format", None)
|
||||||
|
ls_structured_output_format_dict = _format_ls_structured_output(
|
||||||
|
ls_structured_output_format
|
||||||
|
)
|
||||||
|
|
||||||
|
params = self._get_invocation_params(**kwargs)
|
||||||
|
options = {**kwargs, **ls_structured_output_format_dict}
|
||||||
|
inheritable_metadata = {
|
||||||
|
**(config.get("metadata") or {}),
|
||||||
|
**self._get_ls_params(**kwargs),
|
||||||
|
}
|
||||||
|
callback_manager = AsyncCallbackManager.configure(
|
||||||
|
config.get("callbacks"),
|
||||||
|
self.callbacks,
|
||||||
|
self.verbose,
|
||||||
|
config.get("tags"),
|
||||||
|
self.tags,
|
||||||
|
inheritable_metadata,
|
||||||
|
self.metadata,
|
||||||
|
)
|
||||||
|
(run_manager,) = await callback_manager.on_chat_model_start(
|
||||||
|
{},
|
||||||
|
[_format_for_tracing(messages)],
|
||||||
|
invocation_params=params,
|
||||||
|
options=options,
|
||||||
|
name=config.get("run_name"),
|
||||||
|
run_id=config.pop("run_id", None),
|
||||||
|
batch_size=1,
|
||||||
|
)
|
||||||
|
|
||||||
|
if self.rate_limiter:
|
||||||
|
await self.rate_limiter.aacquire(blocking=True)
|
||||||
|
|
||||||
|
# TODO: type openai image, audio, file types and permit in MessageV1
|
||||||
|
input_messages = _normalize_messages(messages) # type: ignore[arg-type]
|
||||||
|
|
||||||
|
if self._should_stream(async_api=True, **kwargs):
|
||||||
|
chunks: list[AIMessageChunkV1] = []
|
||||||
|
try:
|
||||||
|
async for msg in self._astream(input_messages, **kwargs):
|
||||||
|
await run_manager.on_llm_new_token(msg)
|
||||||
|
chunks.append(msg)
|
||||||
|
except BaseException as e:
|
||||||
|
await run_manager.on_llm_error(
|
||||||
|
e, response=_generate_response_from_error(e)
|
||||||
|
)
|
||||||
|
raise
|
||||||
|
msg = add_ai_message_chunks(chunks[0], *chunks[1:])
|
||||||
|
else:
|
||||||
|
try:
|
||||||
|
msg = await self._ainvoke(input_messages, **kwargs)
|
||||||
|
except BaseException as e:
|
||||||
|
await run_manager.on_llm_error(
|
||||||
|
e, response=_generate_response_from_error(e)
|
||||||
|
)
|
||||||
|
raise
|
||||||
|
|
||||||
|
await run_manager.on_llm_end(msg.to_message())
|
||||||
|
return msg
|
||||||
|
|
||||||
|
@override
|
||||||
|
def stream(
|
||||||
|
self,
|
||||||
|
input: LanguageModelInput,
|
||||||
|
config: Optional[RunnableConfig] = None,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> Iterator[AIMessageChunkV1]:
|
||||||
|
if not self._should_stream(async_api=False, **{**kwargs, "stream": True}):
|
||||||
|
# model doesn't implement streaming, so use default implementation
|
||||||
|
yield cast(
|
||||||
|
"AIMessageChunkV1",
|
||||||
|
self.invoke(input, config=config, **kwargs),
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
config = ensure_config(config)
|
||||||
|
messages = self._convert_input(input)
|
||||||
|
ls_structured_output_format = kwargs.pop(
|
||||||
|
"ls_structured_output_format", None
|
||||||
|
) or kwargs.pop("structured_output_format", None)
|
||||||
|
ls_structured_output_format_dict = _format_ls_structured_output(
|
||||||
|
ls_structured_output_format
|
||||||
|
)
|
||||||
|
|
||||||
|
params = self._get_invocation_params(**kwargs)
|
||||||
|
options = {**kwargs, **ls_structured_output_format_dict}
|
||||||
|
inheritable_metadata = {
|
||||||
|
**(config.get("metadata") or {}),
|
||||||
|
**self._get_ls_params(**kwargs),
|
||||||
|
}
|
||||||
|
callback_manager = CallbackManager.configure(
|
||||||
|
config.get("callbacks"),
|
||||||
|
self.callbacks,
|
||||||
|
self.verbose,
|
||||||
|
config.get("tags"),
|
||||||
|
self.tags,
|
||||||
|
inheritable_metadata,
|
||||||
|
self.metadata,
|
||||||
|
)
|
||||||
|
(run_manager,) = callback_manager.on_chat_model_start(
|
||||||
|
{},
|
||||||
|
[_format_for_tracing(messages)],
|
||||||
|
invocation_params=params,
|
||||||
|
options=options,
|
||||||
|
name=config.get("run_name"),
|
||||||
|
run_id=config.pop("run_id", None),
|
||||||
|
batch_size=1,
|
||||||
|
)
|
||||||
|
|
||||||
|
chunks: list[AIMessageChunkV1] = []
|
||||||
|
|
||||||
|
if self.rate_limiter:
|
||||||
|
self.rate_limiter.acquire(blocking=True)
|
||||||
|
|
||||||
|
try:
|
||||||
|
# TODO: replace this with something for new messages
|
||||||
|
input_messages = _normalize_messages(messages)
|
||||||
|
for msg in self._stream(input_messages, **kwargs):
|
||||||
|
run_manager.on_llm_new_token(msg)
|
||||||
|
chunks.append(msg)
|
||||||
|
yield msg
|
||||||
|
except BaseException as e:
|
||||||
|
run_manager.on_llm_error(e, response=_generate_response_from_error(e))
|
||||||
|
raise
|
||||||
|
|
||||||
|
msg = add_ai_message_chunks(chunks[0], *chunks[1:])
|
||||||
|
run_manager.on_llm_end(msg)
|
||||||
|
|
||||||
|
@override
|
||||||
|
async def astream(
|
||||||
|
self,
|
||||||
|
input: LanguageModelInput,
|
||||||
|
config: Optional[RunnableConfig] = None,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> AsyncIterator[AIMessageChunkV1]:
|
||||||
|
if not self._should_stream(async_api=True, **{**kwargs, "stream": True}):
|
||||||
|
# No async or sync stream is implemented, so fall back to ainvoke
|
||||||
|
yield cast(
|
||||||
|
"AIMessageChunkV1",
|
||||||
|
await self.ainvoke(input, config=config, **kwargs),
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
config = ensure_config(config)
|
||||||
|
messages = self._convert_input(input)
|
||||||
|
|
||||||
|
ls_structured_output_format = kwargs.pop(
|
||||||
|
"ls_structured_output_format", None
|
||||||
|
) or kwargs.pop("structured_output_format", None)
|
||||||
|
ls_structured_output_format_dict = _format_ls_structured_output(
|
||||||
|
ls_structured_output_format
|
||||||
|
)
|
||||||
|
|
||||||
|
params = self._get_invocation_params(**kwargs)
|
||||||
|
options = {**kwargs, **ls_structured_output_format_dict}
|
||||||
|
inheritable_metadata = {
|
||||||
|
**(config.get("metadata") or {}),
|
||||||
|
**self._get_ls_params(**kwargs),
|
||||||
|
}
|
||||||
|
callback_manager = AsyncCallbackManager.configure(
|
||||||
|
config.get("callbacks"),
|
||||||
|
self.callbacks,
|
||||||
|
self.verbose,
|
||||||
|
config.get("tags"),
|
||||||
|
self.tags,
|
||||||
|
inheritable_metadata,
|
||||||
|
self.metadata,
|
||||||
|
)
|
||||||
|
(run_manager,) = await callback_manager.on_chat_model_start(
|
||||||
|
{},
|
||||||
|
[_format_for_tracing(messages)],
|
||||||
|
invocation_params=params,
|
||||||
|
options=options,
|
||||||
|
name=config.get("run_name"),
|
||||||
|
run_id=config.pop("run_id", None),
|
||||||
|
batch_size=1,
|
||||||
|
)
|
||||||
|
|
||||||
|
if self.rate_limiter:
|
||||||
|
await self.rate_limiter.aacquire(blocking=True)
|
||||||
|
|
||||||
|
chunks: list[AIMessageChunkV1] = []
|
||||||
|
|
||||||
|
try:
|
||||||
|
input_messages = _normalize_messages(messages)
|
||||||
|
async for msg in self._astream(
|
||||||
|
input_messages,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
await run_manager.on_llm_new_token(msg)
|
||||||
|
chunks.append(msg)
|
||||||
|
yield msg
|
||||||
|
except BaseException as e:
|
||||||
|
await run_manager.on_llm_error(e, response=_generate_response_from_error(e))
|
||||||
|
raise
|
||||||
|
|
||||||
|
msg = add_ai_message_chunks(chunks[0], *chunks[1:])
|
||||||
|
await run_manager.on_llm_end(msg)
|
||||||
|
|
||||||
|
# --- Custom methods ---
|
||||||
|
|
||||||
|
def _combine_llm_outputs(self, llm_outputs: list[Optional[dict]]) -> dict: # noqa: ARG002
|
||||||
|
return {}
|
||||||
|
|
||||||
|
def _get_invocation_params(
|
||||||
|
self,
|
||||||
|
stop: Optional[list[str]] = None,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> dict:
|
||||||
|
params = self.dict()
|
||||||
|
params["stop"] = stop
|
||||||
|
return {**params, **kwargs}
|
||||||
|
|
||||||
|
def _get_ls_params(
|
||||||
|
self,
|
||||||
|
stop: Optional[list[str]] = None,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> LangSmithParams:
|
||||||
|
"""Get standard params for tracing."""
|
||||||
|
# get default provider from class name
|
||||||
|
default_provider = self.__class__.__name__
|
||||||
|
if default_provider.startswith("Chat"):
|
||||||
|
default_provider = default_provider[4:].lower()
|
||||||
|
elif default_provider.endswith("Chat"):
|
||||||
|
default_provider = default_provider[:-4]
|
||||||
|
default_provider = default_provider.lower()
|
||||||
|
|
||||||
|
ls_params = LangSmithParams(ls_provider=default_provider, ls_model_type="chat")
|
||||||
|
if stop:
|
||||||
|
ls_params["ls_stop"] = stop
|
||||||
|
|
||||||
|
# model
|
||||||
|
if hasattr(self, "model") and isinstance(self.model, str):
|
||||||
|
ls_params["ls_model_name"] = self.model
|
||||||
|
elif hasattr(self, "model_name") and isinstance(self.model_name, str):
|
||||||
|
ls_params["ls_model_name"] = self.model_name
|
||||||
|
|
||||||
|
# temperature
|
||||||
|
if "temperature" in kwargs and isinstance(kwargs["temperature"], float):
|
||||||
|
ls_params["ls_temperature"] = kwargs["temperature"]
|
||||||
|
elif hasattr(self, "temperature") and isinstance(self.temperature, float):
|
||||||
|
ls_params["ls_temperature"] = self.temperature
|
||||||
|
|
||||||
|
# max_tokens
|
||||||
|
if "max_tokens" in kwargs and isinstance(kwargs["max_tokens"], int):
|
||||||
|
ls_params["ls_max_tokens"] = kwargs["max_tokens"]
|
||||||
|
elif hasattr(self, "max_tokens") and isinstance(self.max_tokens, int):
|
||||||
|
ls_params["ls_max_tokens"] = self.max_tokens
|
||||||
|
|
||||||
|
return ls_params
|
||||||
|
|
||||||
|
def _get_llm_string(self, stop: Optional[list[str]] = None, **kwargs: Any) -> str:
|
||||||
|
params = self._get_invocation_params(stop=stop, **kwargs)
|
||||||
|
params = {**params, **kwargs}
|
||||||
|
return str(sorted(params.items()))
|
||||||
|
|
||||||
|
def _invoke(
|
||||||
|
self,
|
||||||
|
messages: list[MessageV1],
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> AIMessage:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
async def _ainvoke(
|
||||||
|
self,
|
||||||
|
messages: list[MessageV1],
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> AIMessage:
|
||||||
|
return await run_in_executor(
|
||||||
|
None,
|
||||||
|
self._invoke,
|
||||||
|
messages,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _stream(
|
||||||
|
self,
|
||||||
|
messages: list[MessageV1],
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> Iterator[AIMessageChunkV1]:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
async def _astream(
|
||||||
|
self,
|
||||||
|
messages: list[MessageV1],
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> AsyncIterator[AIMessageChunkV1]:
|
||||||
|
iterator = await run_in_executor(
|
||||||
|
None,
|
||||||
|
self._stream,
|
||||||
|
messages,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
done = object()
|
||||||
|
while True:
|
||||||
|
item = await run_in_executor(
|
||||||
|
None,
|
||||||
|
next,
|
||||||
|
iterator,
|
||||||
|
done,
|
||||||
|
)
|
||||||
|
if item is done:
|
||||||
|
break
|
||||||
|
yield item # type: ignore[misc]
|
||||||
|
|
||||||
|
@property
|
||||||
|
@abstractmethod
|
||||||
|
def _llm_type(self) -> str:
|
||||||
|
"""Return type of chat model."""
|
||||||
|
|
||||||
|
@override
|
||||||
|
def dict(self, **kwargs: Any) -> dict:
|
||||||
|
"""Return a dictionary of the LLM."""
|
||||||
|
starter_dict = dict(self._identifying_params)
|
||||||
|
starter_dict["_type"] = self._llm_type
|
||||||
|
return starter_dict
|
||||||
|
|
||||||
|
def bind_tools(
|
||||||
|
self,
|
||||||
|
tools: Sequence[
|
||||||
|
Union[typing.Dict[str, Any], type, Callable, BaseTool] # noqa: UP006
|
||||||
|
],
|
||||||
|
*,
|
||||||
|
tool_choice: Optional[Union[str]] = None,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> Runnable[LanguageModelInput, BaseMessage]:
|
||||||
|
"""Bind tools to the model.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tools: Sequence of tools to bind to the model.
|
||||||
|
tool_choice: The tool to use. If "any" then any tool can be used.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A Runnable that returns a message.
|
||||||
|
"""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def with_structured_output(
|
||||||
|
self,
|
||||||
|
schema: Union[typing.Dict, type], # noqa: UP006
|
||||||
|
*,
|
||||||
|
include_raw: bool = False,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> Runnable[LanguageModelInput, Union[typing.Dict, BaseModel]]: # noqa: UP006
|
||||||
|
"""Model wrapper that returns outputs formatted to match the given schema.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
schema:
|
||||||
|
The output schema. Can be passed in as:
|
||||||
|
- an OpenAI function/tool schema,
|
||||||
|
- a JSON Schema,
|
||||||
|
- a TypedDict class,
|
||||||
|
- or a Pydantic class.
|
||||||
|
If ``schema`` is a Pydantic class then the model output will be a
|
||||||
|
Pydantic instance of that class, and the model-generated fields will be
|
||||||
|
validated by the Pydantic class. Otherwise the model output will be a
|
||||||
|
dict and will not be validated. See :meth:`langchain_core.utils.function_calling.convert_to_openai_tool`
|
||||||
|
for more on how to properly specify types and descriptions of
|
||||||
|
schema fields when specifying a Pydantic or TypedDict class.
|
||||||
|
|
||||||
|
include_raw:
|
||||||
|
If False then only the parsed structured output is returned. If
|
||||||
|
an error occurs during model output parsing it will be raised. If True
|
||||||
|
then both the raw model response (a BaseMessage) and the parsed model
|
||||||
|
response will be returned. If an error occurs during output parsing it
|
||||||
|
will be caught and returned as well. The final output is always a dict
|
||||||
|
with keys "raw", "parsed", and "parsing_error".
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A Runnable that takes same inputs as a :class:`langchain_core.language_models.chat.BaseChatModel`.
|
||||||
|
|
||||||
|
If ``include_raw`` is False and ``schema`` is a Pydantic class, Runnable outputs
|
||||||
|
an instance of ``schema`` (i.e., a Pydantic object).
|
||||||
|
|
||||||
|
Otherwise, if ``include_raw`` is False then Runnable outputs a dict.
|
||||||
|
|
||||||
|
If ``include_raw`` is True, then Runnable outputs a dict with keys:
|
||||||
|
- ``"raw"``: BaseMessage
|
||||||
|
- ``"parsed"``: None if there was a parsing error, otherwise the type depends on the ``schema`` as described above.
|
||||||
|
- ``"parsing_error"``: Optional[BaseException]
|
||||||
|
|
||||||
|
Example: Pydantic schema (include_raw=False):
|
||||||
|
.. code-block:: python
|
||||||
|
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
class AnswerWithJustification(BaseModel):
|
||||||
|
'''An answer to the user question along with justification for the answer.'''
|
||||||
|
answer: str
|
||||||
|
justification: str
|
||||||
|
|
||||||
|
llm = ChatModel(model="model-name", temperature=0)
|
||||||
|
structured_llm = llm.with_structured_output(AnswerWithJustification)
|
||||||
|
|
||||||
|
structured_llm.invoke("What weighs more a pound of bricks or a pound of feathers")
|
||||||
|
|
||||||
|
# -> AnswerWithJustification(
|
||||||
|
# answer='They weigh the same',
|
||||||
|
# justification='Both a pound of bricks and a pound of feathers weigh one pound. The weight is the same, but the volume or density of the objects may differ.'
|
||||||
|
# )
|
||||||
|
|
||||||
|
        Example: Pydantic schema (include_raw=True):
            .. code-block:: python

                from pydantic import BaseModel

                class AnswerWithJustification(BaseModel):
                    '''An answer to the user question along with justification for the answer.'''
                    answer: str
                    justification: str

                llm = ChatModel(model="model-name", temperature=0)
                structured_llm = llm.with_structured_output(AnswerWithJustification, include_raw=True)

                structured_llm.invoke("What weighs more a pound of bricks or a pound of feathers")
                # -> {
                #     'raw': AIMessage(content='', additional_kwargs={'tool_calls': [{'id': 'call_Ao02pnFYXD6GN1yzc0uXPsvF', 'function': {'arguments': '{"answer":"They weigh the same.","justification":"Both a pound of bricks and a pound of feathers weigh one pound. The weight is the same, but the volume or density of the objects may differ."}', 'name': 'AnswerWithJustification'}, 'type': 'function'}]}),
                #     'parsed': AnswerWithJustification(answer='They weigh the same.', justification='Both a pound of bricks and a pound of feathers weigh one pound. The weight is the same, but the volume or density of the objects may differ.'),
                #     'parsing_error': None
                # }
        Example: Dict schema (include_raw=False):
            .. code-block:: python

                from pydantic import BaseModel
                from langchain_core.utils.function_calling import convert_to_openai_tool

                class AnswerWithJustification(BaseModel):
                    '''An answer to the user question along with justification for the answer.'''
                    answer: str
                    justification: str

                dict_schema = convert_to_openai_tool(AnswerWithJustification)
                llm = ChatModel(model="model-name", temperature=0)
                structured_llm = llm.with_structured_output(dict_schema)

                structured_llm.invoke("What weighs more a pound of bricks or a pound of feathers")
                # -> {
                #     'answer': 'They weigh the same',
                #     'justification': 'Both a pound of bricks and a pound of feathers weigh one pound. The weight is the same, but the volume and density of the two substances differ.'
                # }
        .. versionchanged:: 0.2.26

            Added support for TypedDict class.
        """  # noqa: E501
        _ = kwargs.pop("method", None)
        _ = kwargs.pop("strict", None)
        if kwargs:
            msg = f"Received unsupported arguments {kwargs}"
            raise ValueError(msg)

        from langchain_core.output_parsers.openai_tools import (
            JsonOutputKeyToolsParser,
            PydanticToolsParser,
        )

        if type(self).bind_tools is BaseChatModelV1.bind_tools:
            msg = "with_structured_output is not implemented for this model."
            raise NotImplementedError(msg)

        llm = self.bind_tools(
            [schema],
            tool_choice="any",
            ls_structured_output_format={
                "kwargs": {"method": "function_calling"},
                "schema": schema,
            },
        )
        if isinstance(schema, type) and is_basemodel_subclass(schema):
            output_parser: OutputParserLike = PydanticToolsParser(
                tools=[cast("TypeBaseModel", schema)], first_tool_only=True
            )
        else:
            key_name = convert_to_openai_tool(schema)["function"]["name"]
            output_parser = JsonOutputKeyToolsParser(
                key_name=key_name, first_tool_only=True
            )
        if include_raw:
            parser_assign = RunnablePassthrough.assign(
                parsed=itemgetter("raw") | output_parser, parsing_error=lambda _: None
            )
            parser_none = RunnablePassthrough.assign(parsed=lambda _: None)
            parser_with_fallback = parser_assign.with_fallbacks(
                [parser_none], exception_key="parsing_error"
            )
            return RunnableMap(raw=llm) | parser_with_fallback
        return llm | output_parser
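
To make the ``include_raw`` wiring above concrete, here is a minimal, self-contained sketch of the same raw/parsed/parsing_error pattern built from core runnables only; ``fake_llm`` and ``fake_parser`` are hypothetical stand-ins for the bound model and ``output_parser``, not part of this changeset.

.. code-block:: python

    import json
    from operator import itemgetter

    from langchain_core.runnables import RunnableLambda, RunnableMap, RunnablePassthrough

    # Hypothetical stand-ins for the tool-calling model and the tool-output parser.
    fake_llm = RunnableLambda(lambda _: "not valid json")  # pretend raw model output
    fake_parser = RunnableLambda(lambda raw: json.loads(raw))  # raises on bad output

    parser_assign = RunnablePassthrough.assign(
        parsed=itemgetter("raw") | fake_parser, parsing_error=lambda _: None
    )
    parser_none = RunnablePassthrough.assign(parsed=lambda _: None)
    chain = RunnableMap(raw=fake_llm) | parser_assign.with_fallbacks(
        [parser_none], exception_key="parsing_error"
    )

    chain.invoke("ignored")
    # -> {'raw': 'not valid json', 'parsed': None, 'parsing_error': JSONDecodeError(...)}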


def _gen_info_and_msg_metadata(
    generation: Union[ChatGeneration, ChatGenerationChunk],
) -> dict:
    return {
        **(generation.generation_info or {}),
        **generation.message.response_metadata,
    }
@ -300,6 +300,81 @@ def _create_message_from_message_type(
    return message


def _create_message_from_message_type_v1(
    message_type: str,
    content: str,
    name: Optional[str] = None,
    tool_call_id: Optional[str] = None,
    tool_calls: Optional[list[dict[str, Any]]] = None,
    id: Optional[str] = None,
    **additional_kwargs: Any,
) -> MessageV1:
    """Create a message from a message type and content string.

    Args:
        message_type: (str) the type of the message (e.g., "human", "ai", etc.).
        content: (str) the content string.
        name: (str) the name of the message. Default is None.
        tool_call_id: (str) the tool call id. Default is None.
        tool_calls: (list[dict[str, Any]]) the tool calls. Default is None.
        id: (str) the id of the message. Default is None.
        additional_kwargs: (dict[str, Any]) additional keyword arguments.

    Returns:
        a message of the appropriate type.

    Raises:
        ValueError: if the message type is not one of "human", "user", "ai",
            "assistant", "tool", "system", or "developer".
    """
    kwargs: dict[str, Any] = {}
    if name is not None:
        kwargs["name"] = name
    if tool_call_id is not None:
        kwargs["tool_call_id"] = tool_call_id
    if additional_kwargs and (
        response_metadata := additional_kwargs.pop("response_metadata", None)
    ):
        kwargs["response_metadata"] = response_metadata
    if id is not None:
        kwargs["id"] = id
    if tool_calls is not None:
        kwargs["tool_calls"] = []
        for tool_call in tool_calls:
            # Convert OpenAI-format tool call to LangChain format.
            if "function" in tool_call:
                args = tool_call["function"]["arguments"]
                if isinstance(args, str):
                    args = json.loads(args, strict=False)
                kwargs["tool_calls"].append(
                    {
                        "name": tool_call["function"]["name"],
                        "args": args,
                        "id": tool_call["id"],
                        "type": "tool_call",
                    }
                )
            else:
                kwargs["tool_calls"].append(tool_call)
    if message_type in {"human", "user"}:
        message = HumanMessageV1(content=content, **kwargs)
    elif message_type in {"ai", "assistant"}:
        message = AIMessageV1(content=content, **kwargs)
    elif message_type in {"system", "developer"}:
        if message_type == "developer":
            kwargs["custom_role"] = "developer"
        message = SystemMessageV1(content=content, **kwargs)
    elif message_type == "tool":
        artifact = kwargs.pop("artifact", None)
        message = ToolMessageV1(content=content, artifact=artifact, **kwargs)
    else:
        msg = (
            f"Unexpected message type: '{message_type}'. Use one of 'human',"
            " 'user', 'ai', 'assistant', 'function', 'tool', 'system', or 'developer'."
        )
        msg = create_message(message=msg, error_code=ErrorCode.MESSAGE_COERCION_FAILURE)
        raise ValueError(msg)
    return message
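
For reference, a small sketch of the OpenAI-format to LangChain-format tool-call conversion performed in the loop above; the payload values are invented for illustration.

.. code-block:: python

    import json

    # Hypothetical OpenAI-format tool call as it might arrive in a message dict.
    openai_tool_call = {
        "id": "call_123",
        "type": "function",
        "function": {
            "name": "AnswerWithJustification",
            "arguments": '{"answer": "They weigh the same"}',
        },
    }

    # The equivalent LangChain-format tool call built by the branch above.
    lc_tool_call = {
        "name": openai_tool_call["function"]["name"],
        "args": json.loads(openai_tool_call["function"]["arguments"], strict=False),
        "id": openai_tool_call["id"],
        "type": "tool_call",
    }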

def _convert_from_v1_message(message: MessageV1) -> BaseMessage:
    """Compatibility layer to convert v1 messages to current messages.

@ -404,6 +479,61 @@ def _convert_to_message(message: MessageLikeRepresentation) -> BaseMessage:
    return message_


def _convert_to_message_v1(message: MessageLikeRepresentation) -> MessageV1:
    """Instantiate a message from a variety of message formats.

    The message format can be one of the following:

    - BaseMessagePromptTemplate
    - BaseMessage
    - 2-tuple of (role string, template); e.g., ("human", "{user_input}")
    - dict: a message dict with role and content keys
    - string: shorthand for ("human", template); e.g., "{user_input}"

    Args:
        message: a representation of a message in one of the supported formats.

    Returns:
        an instance of a message or a message template.

    Raises:
        NotImplementedError: if the message type is not supported.
        ValueError: if the message dict does not contain the required keys.
    """
    if isinstance(message, MessageV1Types):
        message_ = message
    elif isinstance(message, str):
        message_ = _create_message_from_message_type_v1("human", message)
    elif isinstance(message, Sequence) and len(message) == 2:
        # mypy doesn't realise this can't be a string given the previous branch
        message_type_str, template = message  # type: ignore[misc]
        message_ = _create_message_from_message_type_v1(message_type_str, template)
    elif isinstance(message, dict):
        msg_kwargs = message.copy()
        try:
            try:
                msg_type = msg_kwargs.pop("role")
            except KeyError:
                msg_type = msg_kwargs.pop("type")
            # None msg content is not allowed
            msg_content = msg_kwargs.pop("content") or ""
        except KeyError as e:
            msg = f"Message dict must contain 'role' and 'content' keys, got {message}"
            msg = create_message(
                message=msg, error_code=ErrorCode.MESSAGE_COERCION_FAILURE
            )
            raise ValueError(msg) from e
        message_ = _create_message_from_message_type_v1(
            msg_type, msg_content, **msg_kwargs
        )
    else:
        msg = f"Unsupported message type: {type(message)}"
        msg = create_message(message=msg, error_code=ErrorCode.MESSAGE_COERCION_FAILURE)
        raise NotImplementedError(msg)

    return message_


def convert_to_messages(
    messages: Union[Iterable[MessageLikeRepresentation], PromptValue],
) -> list[BaseMessage]:

@ -423,6 +553,25 @@ def convert_to_messages(
    return [_convert_to_message(m) for m in messages]


def convert_to_messages_v1(
    messages: Union[Iterable[MessageLikeRepresentation], PromptValue],
) -> list[MessageV1]:
    """Convert a sequence of messages to a list of messages.

    Args:
        messages: Sequence of messages to convert.

    Returns:
        list of messages (MessageV1 objects).
    """
    # Import here to avoid circular imports
    from langchain_core.prompt_values import PromptValue

    if isinstance(messages, PromptValue):
        return messages.to_messages(output_version="v1")
    return [_convert_to_message_v1(m) for m in messages]
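
A quick usage sketch of the v1 conversion helpers above, with made-up inputs; the import path assumes these helpers live in ``langchain_core.messages.utils``.

.. code-block:: python

    from langchain_core.messages.utils import convert_to_messages_v1

    messages = convert_to_messages_v1(
        [
            "hello",  # shorthand for ("human", "hello")
            ("system", "You are a helpful assistant."),
            {"role": "assistant", "content": "Hi! How can I help?"},
        ]
    )
    # -> [HumanMessageV1(...), SystemMessageV1(...), AIMessageV1(...)]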


def _runnable_support(func: Callable) -> Callable:
    @overload
    def wrapped(

@ -366,6 +366,8 @@ def add_ai_message_chunks(
    left: AIMessageChunk, *others: AIMessageChunk
) -> AIMessageChunk:
    """Add multiple AIMessageChunks together."""
    if not others:
        return left
    content = merge_content(
        cast("list[str | dict[Any, Any]]", left.content),
        *(cast("list[str | dict[Any, Any]]", o.content) for o in others),

@ -15,6 +15,8 @@ from langchain_core.language_models import (
    ParrotFakeChatModel,
)
from langchain_core.messages import AIMessage, AIMessageChunk, BaseMessage, HumanMessage
from langchain_core.messages.v1 import AIMessageChunk as AIMessageChunkV1
from langchain_core.messages.v1 import MessageV1
from langchain_core.outputs import ChatGenerationChunk, GenerationChunk
from tests.unit_tests.stubs import (
    _any_id_ai_message,

@ -157,13 +159,13 @@ async def test_callback_handlers() -> None:
    """Verify that model is implemented correctly with handlers working."""

    class MyCustomAsyncHandler(AsyncCallbackHandler):
        def __init__(self, store: list[str]) -> None:
        def __init__(self, store: list[Union[str, AIMessageChunkV1]]) -> None:
            self.store = store

        async def on_chat_model_start(
            self,
            serialized: dict[str, Any],
            messages: list[list[BaseMessage]],
            messages: Union[list[list[BaseMessage]], list[list[MessageV1]]],
            *,
            run_id: UUID,
            parent_run_id: Optional[UUID] = None,
@ -178,9 +180,11 @@ async def test_callback_handlers() -> None:
        @override
        async def on_llm_new_token(
            self,
            token: str,
            token: Union[str, AIMessageChunkV1],
            *,
            chunk: Optional[Union[GenerationChunk, ChatGenerationChunk]] = None,
            chunk: Optional[
                Union[GenerationChunk, ChatGenerationChunk, AIMessageChunkV1]
            ] = None,
            run_id: UUID,
            parent_run_id: Optional[UUID] = None,
            tags: Optional[list[str]] = None,
@ -194,7 +198,7 @@ async def test_callback_handlers() -> None:
        ]
    )
    model = GenericFakeChatModel(messages=infinite_cycle)
    tokens: list[str] = []
    tokens: list[Union[str, AIMessageChunkV1]] = []
    # New model
    results = [
        chunk

@ -463,7 +463,7 @@ class BaseChatOpenAI(BaseChatModel):
        alias="api_key", default_factory=secret_from_env("OPENAI_API_KEY", default=None)
    )
    openai_api_base: Optional[str] = Field(default=None, alias="base_url")
    """Base URL path for API requests, leave blank if not using a proxy or service
    emulator."""
    openai_organization: Optional[str] = Field(default=None, alias="organization")
    """Automatically inferred from env var ``OPENAI_ORG_ID`` if not provided."""

@ -494,7 +494,7 @@ class BaseChatOpenAI(BaseChatModel):
    """Whether to return logprobs."""
    top_logprobs: Optional[int] = None
    """Number of most likely tokens to return at each token position, each with
    an associated log probability. `logprobs` must be set to true
    if this parameter is used."""
    logit_bias: Optional[dict[int, int]] = None
    """Modify the likelihood of specified tokens appearing in the completion."""
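
A brief usage sketch for the ``logprobs``/``top_logprobs`` fields documented above; the model name is illustrative.

.. code-block:: python

    from langchain_openai import ChatOpenAI

    # logprobs must be enabled for top_logprobs to take effect.
    llm = ChatOpenAI(model="gpt-4o-mini", logprobs=True, top_logprobs=3)
    msg = llm.invoke("Say hello")
    msg.response_metadata["logprobs"]  # per-token log probabilities, if returned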

@ -512,7 +512,7 @@ class BaseChatOpenAI(BaseChatModel):

    Reasoning models only, like OpenAI o1, o3, and o4-mini.

    Currently supported values are low, medium, and high. Reducing reasoning effort
    can result in faster responses and fewer tokens used on reasoning in a response.

    .. versionadded:: 0.2.14
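
To illustrate the ``reasoning_effort`` setting described above, a minimal sketch; the model choice is an assumption.

.. code-block:: python

    from langchain_openai import ChatOpenAI

    # Lower reasoning effort trades some reasoning depth for latency and tokens.
    llm = ChatOpenAI(model="o4-mini", reasoning_effort="low")
    llm.invoke("Summarize the trade-offs of a smaller context window.")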

@ -534,21 +534,21 @@ class BaseChatOpenAI(BaseChatModel):

    """
    tiktoken_model_name: Optional[str] = None
    """The model name to pass to tiktoken when using this class.
    Tiktoken is used to count the number of tokens in documents to constrain
    them to be under a certain limit. By default, when set to None, this will
    be the same as the embedding model name. However, there are some cases
    where you may want to use this Embedding class with a model name not
    supported by tiktoken. This can include when using Azure embeddings or
    when using one of the many model providers that expose an OpenAI-like
    API but with different models. In those cases, in order to avoid erroring
    when tiktoken is called, you can specify a model name to use here."""
    default_headers: Union[Mapping[str, str], None] = None
    default_query: Union[Mapping[str, object], None] = None
    # Configure a custom httpx client. See the
    # [httpx documentation](https://www.python-httpx.org/api/#client) for more details.
    http_client: Union[Any, None] = Field(default=None, exclude=True)
    """Optional ``httpx.Client``. Only used for sync invocations. Must specify
    ``http_async_client`` as well if you'd like a custom client for async
    invocations.
    """
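
A short sketch of the ``tiktoken_model_name`` scenario described above, for an OpenAI-compatible endpoint whose model name tiktoken does not recognize; the URL and model names are placeholders.

.. code-block:: python

    from langchain_openai import ChatOpenAI

    llm = ChatOpenAI(
        model="my-proxy-model",  # not known to tiktoken
        base_url="https://llm-proxy.example.com/v1",
        tiktoken_model_name="gpt-4o-mini",  # count tokens as if it were this model
    )
    llm.get_num_tokens("How many tokens is this?")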

@ -580,21 +580,21 @@ class BaseChatOpenAI(BaseChatModel):
    include_response_headers: bool = False
    """Whether to include response headers in the output message ``response_metadata``."""  # noqa: E501
    disabled_params: Optional[dict[str, Any]] = Field(default=None)
    """Parameters of the OpenAI client or chat.completions endpoint that should be
    disabled for the given model.

    Should be specified as ``{"param": None | ['val1', 'val2']}`` where the key is the
    parameter and the value is either None, meaning that parameter should never be
    used, or it's a list of disabled values for the parameter.

    For example, older models may not support the ``'parallel_tool_calls'`` parameter at
    all, in which case ``disabled_params={"parallel_tool_calls": None}`` can be passed
    in.

    If a parameter is disabled then it will not be used by default in any methods, e.g.
    in :meth:`~langchain_openai.chat_models.base.ChatOpenAI.with_structured_output`.
    However this does not prevent a user from directly passing in the parameter during
    invocation.
    """

    include: Optional[list[str]] = None
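
As a concrete illustration of the ``disabled_params`` behaviour described above; the model name is chosen only for illustration.

.. code-block:: python

    from langchain_openai import ChatOpenAI

    # Never send parallel_tool_calls for this model, even when tools are bound.
    llm = ChatOpenAI(
        model="gpt-4o-mini",
        disabled_params={"parallel_tool_calls": None},
    )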