mirror of
https://github.com/hwchase17/langchain.git
synced 2026-04-16 18:02:57 +00:00
Compare commits
3 Commits
langchain-
...
mdrxy/impr
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
b778fe6079 | ||
|
|
2b2c5fe1a6 | ||
|
|
7f8e62f1e9 |
@@ -12,6 +12,7 @@ from typing import (
|
||||
Literal,
|
||||
TypeAlias,
|
||||
TypeVar,
|
||||
overload,
|
||||
)
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field, field_validator
|
||||
@@ -107,6 +108,10 @@ LanguageModelLike = Runnable[LanguageModelInput, LanguageModelOutput]
|
||||
LanguageModelOutputVar = TypeVar("LanguageModelOutputVar", AIMessage, str)
|
||||
"""Type variable for the output of a language model."""
|
||||
|
||||
# TypeVar for with_structured_output overloads. Enables precise return type inference
|
||||
# when a Pydantic BaseModel or TypedDict is passed as the schema argument.
|
||||
_ModelT = TypeVar("_ModelT", bound=BaseModel | Mapping)
|
||||
|
||||
|
||||
def _get_verbosity() -> bool:
|
||||
return get_verbose()
|
||||
@@ -267,9 +272,44 @@ class BaseLanguageModel(
|
||||
|
||||
"""
|
||||
|
||||
# Overloads for with_structured_output provide precise return type inference:
|
||||
# - Mapping schema (JSON/dict) -> returns dict
|
||||
# - type[_ModelT] schema (Pydantic/TypedDict) -> returns that specific type
|
||||
# - include_raw=True -> returns dict (with raw response included)
|
||||
@overload
|
||||
def with_structured_output(
|
||||
self, schema: dict | type, **kwargs: Any
|
||||
) -> Runnable[LanguageModelInput, dict | BaseModel]:
|
||||
self,
|
||||
schema: Mapping[str, Any],
|
||||
*,
|
||||
include_raw: Literal[False] = False,
|
||||
**kwargs: Any,
|
||||
) -> Runnable[LanguageModelInput, dict]: ...
|
||||
|
||||
@overload
|
||||
def with_structured_output(
|
||||
self,
|
||||
schema: type[_ModelT],
|
||||
*,
|
||||
include_raw: Literal[False] = False,
|
||||
**kwargs: Any,
|
||||
) -> Runnable[LanguageModelInput, _ModelT]: ...
|
||||
|
||||
@overload
|
||||
def with_structured_output(
|
||||
self,
|
||||
schema: Mapping[str, Any] | type[_ModelT],
|
||||
*,
|
||||
include_raw: Literal[True],
|
||||
**kwargs: Any,
|
||||
) -> Runnable[LanguageModelInput, dict]: ...
|
||||
|
||||
def with_structured_output(
|
||||
self,
|
||||
schema: Mapping | type,
|
||||
*,
|
||||
include_raw: bool = False,
|
||||
**kwargs: Any,
|
||||
) -> Runnable[LanguageModelInput, Any]:
|
||||
"""Not implemented on this class."""
|
||||
# Implement this on child class if there is a way of steering the model to
|
||||
# generate responses that match a given schema.
|
||||
|
||||
@@ -3,14 +3,14 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import builtins
|
||||
import inspect
|
||||
import json
|
||||
import typing
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import AsyncIterator, Callable, Iterator, Sequence
|
||||
from collections.abc import AsyncIterator, Callable, Iterator, Mapping, Sequence
|
||||
from functools import cached_property
|
||||
from operator import itemgetter
|
||||
from typing import TYPE_CHECKING, Any, Literal, cast
|
||||
from typing import TYPE_CHECKING, Any, Literal, TypeVar, cast, overload
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
from typing_extensions import override
|
||||
@@ -73,6 +73,10 @@ from langchain_core.utils.function_calling import (
|
||||
from langchain_core.utils.pydantic import TypeBaseModel, is_basemodel_subclass
|
||||
from langchain_core.utils.utils import LC_ID_PREFIX, from_env
|
||||
|
||||
# TypeVar for with_structured_output overloads. Enables precise return type inference
|
||||
# when a Pydantic BaseModel or TypedDict is passed as the schema argument.
|
||||
_ModelT = TypeVar("_ModelT", bound=BaseModel | Mapping)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import uuid
|
||||
|
||||
@@ -224,7 +228,7 @@ async def agenerate_from_stream(
|
||||
return await run_in_executor(None, generate_from_stream, iter(chunks))
|
||||
|
||||
|
||||
def _format_ls_structured_output(ls_structured_output_format: dict | None) -> dict:
|
||||
def _format_ls_structured_output(ls_structured_output_format: Mapping | None) -> dict:
|
||||
if ls_structured_output_format:
|
||||
try:
|
||||
ls_structured_output_format_dict = {
|
||||
@@ -730,7 +734,7 @@ class BaseChatModel(BaseLanguageModel[AIMessage], ABC):
|
||||
|
||||
# --- Custom methods ---
|
||||
|
||||
def _combine_llm_outputs(self, llm_outputs: list[dict | None]) -> dict: # noqa: ARG002
|
||||
def _combine_llm_outputs(self, llm_outputs: list[Mapping | None]) -> builtins.dict: # noqa: ARG002
|
||||
return {}
|
||||
|
||||
def _convert_cached_generations(self, cache_val: list) -> list[ChatGeneration]:
|
||||
@@ -776,7 +780,7 @@ class BaseChatModel(BaseLanguageModel[AIMessage], ABC):
|
||||
self,
|
||||
stop: list[str] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> dict:
|
||||
) -> builtins.dict:
|
||||
params = self.dict()
|
||||
params["stop"] = stop
|
||||
return {**params, **kwargs}
|
||||
@@ -1492,7 +1496,7 @@ class BaseChatModel(BaseLanguageModel[AIMessage], ABC):
|
||||
"""Return type of chat model."""
|
||||
|
||||
@override
|
||||
def dict(self, **kwargs: Any) -> dict:
|
||||
def dict(self, **kwargs: Any) -> builtins.dict:
|
||||
"""Return a dictionary of the LLM."""
|
||||
starter_dict = dict(self._identifying_params)
|
||||
starter_dict["_type"] = self._llm_type
|
||||
@@ -1500,9 +1504,7 @@ class BaseChatModel(BaseLanguageModel[AIMessage], ABC):
|
||||
|
||||
def bind_tools(
|
||||
self,
|
||||
tools: Sequence[
|
||||
typing.Dict[str, Any] | type | Callable | BaseTool # noqa: UP006
|
||||
],
|
||||
tools: Sequence[Mapping[str, Any] | type | Callable | BaseTool],
|
||||
*,
|
||||
tool_choice: str | None = None,
|
||||
**kwargs: Any,
|
||||
@@ -1519,13 +1521,44 @@ class BaseChatModel(BaseLanguageModel[AIMessage], ABC):
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
# Overloads for with_structured_output provide precise return type inference:
|
||||
# - Mapping schema (JSON/dict) -> returns dict
|
||||
# - type[_ModelT] schema (Pydantic/TypedDict) -> returns that specific type
|
||||
# - include_raw=True -> returns dict (with raw response included)
|
||||
@overload
|
||||
def with_structured_output(
|
||||
self,
|
||||
schema: typing.Dict | type, # noqa: UP006
|
||||
schema: Mapping[str, Any],
|
||||
*,
|
||||
include_raw: Literal[False] = False,
|
||||
**kwargs: Any,
|
||||
) -> Runnable[LanguageModelInput, builtins.dict]: ...
|
||||
|
||||
@overload
|
||||
def with_structured_output(
|
||||
self,
|
||||
schema: type[_ModelT],
|
||||
*,
|
||||
include_raw: Literal[False] = False,
|
||||
**kwargs: Any,
|
||||
) -> Runnable[LanguageModelInput, _ModelT]: ...
|
||||
|
||||
@overload
|
||||
def with_structured_output(
|
||||
self,
|
||||
schema: Mapping[str, Any] | type[_ModelT],
|
||||
*,
|
||||
include_raw: Literal[True],
|
||||
**kwargs: Any,
|
||||
) -> Runnable[LanguageModelInput, builtins.dict]: ...
|
||||
|
||||
def with_structured_output(
|
||||
self,
|
||||
schema: Mapping | type,
|
||||
*,
|
||||
include_raw: bool = False,
|
||||
**kwargs: Any,
|
||||
) -> Runnable[LanguageModelInput, typing.Dict | BaseModel]: # noqa: UP006
|
||||
) -> Runnable[LanguageModelInput, Any]:
|
||||
"""Model wrapper that returns outputs formatted to match the given schema.
|
||||
|
||||
Args:
|
||||
|
||||
@@ -8,6 +8,7 @@ import logging
|
||||
import types
|
||||
import typing
|
||||
import uuid
|
||||
from collections.abc import Mapping
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Annotated,
|
||||
@@ -327,7 +328,7 @@ def _format_tool_to_openai_function(tool: BaseTool) -> FunctionDescription:
|
||||
|
||||
|
||||
def convert_to_openai_function(
|
||||
function: dict[str, Any] | type | Callable | BaseTool,
|
||||
function: Mapping[str, Any] | type | Callable | BaseTool,
|
||||
*,
|
||||
strict: bool | None = None,
|
||||
) -> dict[str, Any]:
|
||||
@@ -358,7 +359,7 @@ def convert_to_openai_function(
|
||||
required and guaranteed to be part of the output.
|
||||
"""
|
||||
# an Anthropic format tool
|
||||
if isinstance(function, dict) and all(
|
||||
if isinstance(function, Mapping) and all(
|
||||
k in function for k in ("name", "input_schema")
|
||||
):
|
||||
oai_function = {
|
||||
@@ -368,7 +369,7 @@ def convert_to_openai_function(
|
||||
if "description" in function:
|
||||
oai_function["description"] = function["description"]
|
||||
# an Amazon Bedrock Converse format tool
|
||||
elif isinstance(function, dict) and "toolSpec" in function:
|
||||
elif isinstance(function, Mapping) and "toolSpec" in function:
|
||||
oai_function = {
|
||||
"name": function["toolSpec"]["name"],
|
||||
"parameters": function["toolSpec"]["inputSchema"]["json"],
|
||||
@@ -376,15 +377,15 @@ def convert_to_openai_function(
|
||||
if "description" in function["toolSpec"]:
|
||||
oai_function["description"] = function["toolSpec"]["description"]
|
||||
# already in OpenAI function format
|
||||
elif isinstance(function, dict) and "name" in function:
|
||||
elif isinstance(function, Mapping) and "name" in function:
|
||||
oai_function = {
|
||||
k: v
|
||||
for k, v in function.items()
|
||||
if k in {"name", "description", "parameters", "strict"}
|
||||
}
|
||||
# a JSON schema with title and description
|
||||
elif isinstance(function, dict) and "title" in function:
|
||||
function_copy = function.copy()
|
||||
elif isinstance(function, Mapping) and "title" in function:
|
||||
function_copy = dict(function)
|
||||
oai_function = {"name": function_copy.pop("title")}
|
||||
if "description" in function_copy:
|
||||
oai_function["description"] = function_copy.pop("description")
|
||||
@@ -454,7 +455,7 @@ _WellKnownOpenAITools = (
|
||||
|
||||
|
||||
def convert_to_openai_tool(
|
||||
tool: dict[str, Any] | type[BaseModel] | Callable | BaseTool,
|
||||
tool: Mapping[str, Any] | type[BaseModel] | Callable | BaseTool,
|
||||
*,
|
||||
strict: bool | None = None,
|
||||
) -> dict[str, Any]:
|
||||
@@ -495,12 +496,12 @@ def convert_to_openai_tool(
|
||||
# Import locally to prevent circular import
|
||||
from langchain_core.tools import Tool # noqa: PLC0415
|
||||
|
||||
if isinstance(tool, dict):
|
||||
if isinstance(tool, Mapping):
|
||||
if tool.get("type") in _WellKnownOpenAITools:
|
||||
return tool
|
||||
return dict(tool)
|
||||
# As of 03.12.25 can be "web_search_preview" or "web_search_preview_2025_03_11"
|
||||
if (tool.get("type") or "").startswith("web_search_preview"):
|
||||
return tool
|
||||
return dict(tool)
|
||||
if isinstance(tool, Tool) and (tool.metadata or {}).get("type") == "custom_tool":
|
||||
oai_tool = {
|
||||
"type": "custom",
|
||||
@@ -515,7 +516,7 @@ def convert_to_openai_tool(
|
||||
|
||||
|
||||
def convert_to_json_schema(
|
||||
schema: dict[str, Any] | type[BaseModel] | Callable | BaseTool,
|
||||
schema: Mapping[str, Any] | type[BaseModel] | Callable | BaseTool,
|
||||
*,
|
||||
strict: bool | None = None,
|
||||
) -> dict[str, Any]:
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
from collections.abc import Mapping
|
||||
from functools import partial
|
||||
from inspect import isclass
|
||||
from typing import Any, cast
|
||||
@@ -17,8 +18,8 @@ from langchain_core.utils.pydantic import is_basemodel_subclass
|
||||
|
||||
|
||||
def _fake_runnable(
|
||||
_: Any, *, schema: dict | type[BaseModel], value: Any = 42, **_kwargs: Any
|
||||
) -> BaseModel | dict:
|
||||
_: Any, *, schema: Mapping | type, value: Any = 42, **_kwargs: Any
|
||||
) -> Any:
|
||||
if isclass(schema) and is_basemodel_subclass(schema):
|
||||
return schema(name="yo", value=value)
|
||||
params = cast("dict", schema)["parameters"]
|
||||
@@ -29,9 +30,7 @@ class FakeStructuredChatModel(FakeListChatModel):
|
||||
"""Fake chat model for testing purposes."""
|
||||
|
||||
@override
|
||||
def with_structured_output(
|
||||
self, schema: dict | type[BaseModel], **kwargs: Any
|
||||
) -> Runnable:
|
||||
def with_structured_output(self, schema: Mapping | type, **kwargs: Any) -> Runnable:
|
||||
return RunnableLambda(partial(_fake_runnable, schema=schema, **kwargs))
|
||||
|
||||
@property
|
||||
|
||||
@@ -1,10 +1,9 @@
|
||||
from collections.abc import AsyncIterator, Callable, Iterator, Sequence
|
||||
from collections.abc import AsyncIterator, Callable, Iterator, Mapping, Sequence
|
||||
from typing import (
|
||||
Any,
|
||||
)
|
||||
|
||||
import pytest
|
||||
from pydantic import BaseModel
|
||||
from syrupy.assertion import SnapshotAssertion
|
||||
from typing_extensions import override
|
||||
|
||||
@@ -335,15 +334,15 @@ class FakeStructuredOutputModel(BaseChatModel):
|
||||
@override
|
||||
def bind_tools(
|
||||
self,
|
||||
tools: Sequence[dict[str, Any] | type[BaseModel] | Callable | BaseTool],
|
||||
tools: Sequence[Mapping[str, Any] | type | Callable | BaseTool],
|
||||
**kwargs: Any,
|
||||
) -> Runnable[LanguageModelInput, AIMessage]:
|
||||
return self.bind(tools=tools)
|
||||
|
||||
@override
|
||||
def with_structured_output(
|
||||
self, schema: dict | type[BaseModel], **kwargs: Any
|
||||
) -> Runnable[LanguageModelInput, dict | BaseModel]:
|
||||
self, schema: Mapping | type, **kwargs: Any
|
||||
) -> Runnable[LanguageModelInput, Any]:
|
||||
return RunnableLambda(lambda _: {"foo": self.foo})
|
||||
|
||||
@property
|
||||
@@ -368,7 +367,7 @@ class FakeModel(BaseChatModel):
|
||||
@override
|
||||
def bind_tools(
|
||||
self,
|
||||
tools: Sequence[dict[str, Any] | type[BaseModel] | Callable | BaseTool],
|
||||
tools: Sequence[Mapping[str, Any] | type | Callable | BaseTool],
|
||||
**kwargs: Any,
|
||||
) -> Runnable[LanguageModelInput, AIMessage]:
|
||||
return self.bind(tools=tools)
|
||||
|
||||
Reference in New Issue
Block a user