This commit is contained in:
Chester Curme 2025-07-28 11:01:42 -04:00
parent b8fed06409
commit 61e329637b
3 changed files with 33 additions and 20 deletions

View File

@ -740,7 +740,7 @@ class BaseChatModelV1(BaseLanguageModel[AIMessageV1], ABC):
*,
tool_choice: Optional[Union[str]] = None,
**kwargs: Any,
) -> Runnable[LanguageModelInput, BaseMessage]:
) -> Runnable[LanguageModelInput, AIMessageV1]:
"""Bind tools to the model.
Args:

View File

@ -9,7 +9,7 @@ from typing import Annotated, Any, Optional
from pydantic import SkipValidation, ValidationError
from langchain_core.exceptions import OutputParserException
from langchain_core.messages import AIMessage, InvalidToolCall
from langchain_core.messages import AIMessage, InvalidToolCall, ToolCall
from langchain_core.messages.tool import invalid_tool_call
from langchain_core.messages.tool import tool_call as create_tool_call
from langchain_core.output_parsers.transform import BaseCumulativeTransformOutputParser
@ -26,7 +26,7 @@ def parse_tool_call(
partial: bool = False,
strict: bool = False,
return_id: bool = True,
) -> Optional[dict[str, Any]]:
) -> Optional[ToolCall]:
"""Parse a single tool call.
Args:

View File

@ -37,9 +37,9 @@ from langchain_core.callbacks import (
CallbackManagerForLLMRun,
)
from langchain_core.language_models import LanguageModelInput
from langchain_core.language_models.chat_models import (
BaseChatModel,
LangSmithParams,
from langchain_core.language_models.chat_models import LangSmithParams
from langchain_core.language_models.v1.chat_models import (
BaseChatModelV1,
agenerate_from_stream,
generate_from_stream,
)
@ -103,6 +103,7 @@ from langchain_openai.chat_models._compat import (
)
if TYPE_CHECKING:
from langchain_core.messages import content_blocks as types
from openai.types.responses import Response
logger = logging.getLogger(__name__)
@ -138,13 +139,17 @@ def _convert_dict_to_message(_dict: Mapping[str, Any]) -> MessageV1:
elif role == "assistant":
# Fix for azure
# Also OpenAI returns None for tool invocations
content = [{"type": "text", "text": _dict.get("content", "") or ""}]
content: list[types.ContentBlock] = [
{"type": "text", "text": _dict.get("content", "") or ""}
]
tool_calls = []
invalid_tool_calls = []
if raw_tool_calls := _dict.get("tool_calls"):
for raw_tool_call in raw_tool_calls:
try:
tool_calls.append(parse_tool_call(raw_tool_call, return_id=True))
tool_call = parse_tool_call(raw_tool_call, return_id=True)
if tool_call:
tool_calls.append(tool_call)
except Exception as e:
invalid_tool_calls.append(
make_invalid_tool_call(raw_tool_call, str(e))
@ -152,7 +157,9 @@ def _convert_dict_to_message(_dict: Mapping[str, Any]) -> MessageV1:
content.extend(tool_calls)
if audio := _dict.get("audio"):
# TODO: populate standard fields
content.append({"type": "audio", "audio": audio})
content.append(
cast(types.AudioContentBlock, {"type": "audio", "audio": audio})
)
return AIMessageV1(
content=content,
name=name,
@ -368,7 +375,7 @@ class _AllReturnType(TypedDict):
parsing_error: Optional[BaseException]
class BaseChatOpenAIV1(BaseChatModel):
class BaseChatOpenAIV1(BaseChatModelV1):
client: Any = Field(default=None, exclude=True) #: :meta private:
async_client: Any = Field(default=None, exclude=True) #: :meta private:
root_client: Any = Field(default=None, exclude=True) #: :meta private:
@ -822,7 +829,7 @@ class BaseChatOpenAIV1(BaseChatModel):
if generation_chunk:
if run_manager:
run_manager.on_llm_new_token(
generation_chunk.text, chunk=generation_chunk
generation_chunk.text or "", chunk=generation_chunk
)
is_first_chunk = False
yield generation_chunk
@ -873,7 +880,7 @@ class BaseChatOpenAIV1(BaseChatModel):
if generation_chunk:
if run_manager:
await run_manager.on_llm_new_token(
generation_chunk.text, chunk=generation_chunk
generation_chunk.text or "", chunk=generation_chunk
)
is_first_chunk = False
yield generation_chunk
@ -944,7 +951,9 @@ class BaseChatOpenAIV1(BaseChatModel):
logprobs = message_chunk.response_metadata.get("logprobs")
if run_manager:
run_manager.on_llm_new_token(
message_chunk.text, chunk=message_chunk, logprobs=logprobs
message_chunk.text or "",
chunk=message_chunk,
logprobs=logprobs,
)
is_first_chunk = False
yield message_chunk
@ -954,7 +963,9 @@ class BaseChatOpenAIV1(BaseChatModel):
final_completion = response.get_final_completion()
message_chunk = self._get_message_chunk_from_completion(final_completion)
if run_manager:
run_manager.on_llm_new_token(message_chunk.text, chunk=message_chunk)
run_manager.on_llm_new_token(
message_chunk.text or "", chunk=message_chunk
)
yield message_chunk
def _generate(
@ -1029,7 +1040,7 @@ class BaseChatOpenAIV1(BaseChatModel):
stop: Optional[list[str]] = None,
**kwargs: Any,
) -> dict:
messages = self._convert_input(input_).to_messages(output_version="v1")
messages = self._convert_input(input_)
if stop is not None:
kwargs["stop"] = stop
@ -1168,7 +1179,9 @@ class BaseChatOpenAIV1(BaseChatModel):
logprobs = message_chunk.response_metadata.get("logprobs")
if run_manager:
await run_manager.on_llm_new_token(
message_chunk.text, chunk=message_chunk, logprobs=logprobs
message_chunk.text or "",
chunk=message_chunk,
logprobs=logprobs,
)
is_first_chunk = False
yield message_chunk
@ -1179,7 +1192,7 @@ class BaseChatOpenAIV1(BaseChatModel):
message_chunk = self._get_message_chunk_from_completion(final_completion)
if run_manager:
await run_manager.on_llm_new_token(
message_chunk.text, chunk=message_chunk
message_chunk.text or "", chunk=message_chunk
)
yield message_chunk
@ -1420,7 +1433,7 @@ class BaseChatOpenAIV1(BaseChatModel):
strict: Optional[bool] = None,
parallel_tool_calls: Optional[bool] = None,
**kwargs: Any,
) -> Runnable[LanguageModelInput, MessageV1]:
) -> Runnable[LanguageModelInput, AIMessageV1]:
"""Bind tool-like objects to this chat model.
Assumes model is compatible with OpenAI tool-calling API.
@ -1614,7 +1627,7 @@ class BaseChatOpenAIV1(BaseChatModel):
kwargs: Additional keyword args are passed through to the model.
Returns:
A Runnable that takes same inputs as a :class:`langchain_core.language_models.chat.BaseChatModel`.
A Runnable that takes same inputs as a :class:`from langchain_core.language_models.v1.chat_models import BaseChatModelV1`.
| If ``include_raw`` is False and ``schema`` is a Pydantic class, Runnable outputs an instance of ``schema`` (i.e., a Pydantic object). Otherwise, if ``include_raw`` is False then Runnable outputs a dict.
@ -2629,7 +2642,7 @@ class ChatOpenAI(BaseChatOpenAIV1): # type: ignore[override]
kwargs: Additional keyword args are passed through to the model.
Returns:
A Runnable that takes same inputs as a :class:`langchain_core.language_models.chat.BaseChatModel`.
A Runnable that takes same inputs as a :class:`from langchain_core.language_models.v1.chat_models import BaseChatModelV1`.
| If ``include_raw`` is False and ``schema`` is a Pydantic class, Runnable outputs an instance of ``schema`` (i.e., a Pydantic object). Otherwise, if ``include_raw`` is False then Runnable outputs a dict.