Compare commits


3 Commits

Author           SHA1         Message                                                 Date
Mason Daugherty  b778fe6079   Merge branch 'master' into mdrxy/improve-core-typing    2025-12-10 02:32:12 -05:00
Mason Daugherty  2b2c5fe1a6   comments                                                2025-12-10 02:30:35 -05:00
Mason Daugherty  7f8e62f1e9   refactor(core): improved typing (WIP)                   2025-11-15 02:23:59 -05:00
5 changed files with 108 additions and 36 deletions

View File

@@ -12,6 +12,7 @@ from typing import (
Literal,
TypeAlias,
TypeVar,
overload,
)
from pydantic import BaseModel, ConfigDict, Field, field_validator
@@ -107,6 +108,10 @@ LanguageModelLike = Runnable[LanguageModelInput, LanguageModelOutput]
LanguageModelOutputVar = TypeVar("LanguageModelOutputVar", AIMessage, str)
"""Type variable for the output of a language model."""
# TypeVar for with_structured_output overloads. Enables precise return type inference
# when a Pydantic BaseModel or TypedDict is passed as the schema argument.
_ModelT = TypeVar("_ModelT", bound=BaseModel | Mapping)
def _get_verbosity() -> bool:
return get_verbose()
@@ -267,9 +272,44 @@ class BaseLanguageModel(
"""
# Overloads for with_structured_output provide precise return type inference:
# - Mapping schema (JSON/dict) -> returns dict
# - type[_ModelT] schema (Pydantic/TypedDict) -> returns that specific type
# - include_raw=True -> returns dict (with raw response included)
@overload
def with_structured_output(
self, schema: dict | type, **kwargs: Any
) -> Runnable[LanguageModelInput, dict | BaseModel]:
self,
schema: Mapping[str, Any],
*,
include_raw: Literal[False] = False,
**kwargs: Any,
) -> Runnable[LanguageModelInput, dict]: ...
@overload
def with_structured_output(
self,
schema: type[_ModelT],
*,
include_raw: Literal[False] = False,
**kwargs: Any,
) -> Runnable[LanguageModelInput, _ModelT]: ...
@overload
def with_structured_output(
self,
schema: Mapping[str, Any] | type[_ModelT],
*,
include_raw: Literal[True],
**kwargs: Any,
) -> Runnable[LanguageModelInput, dict]: ...
def with_structured_output(
self,
schema: Mapping | type,
*,
include_raw: bool = False,
**kwargs: Any,
) -> Runnable[LanguageModelInput, Any]:
"""Not implemented on this class."""
# Implement this on child class if there is a way of steering the model to
# generate responses that match a given schema.
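A type-checking sketch (not part of this diff) of the inference these overloads are meant to provide; `model` is a stand-in for any model inheriting this base class, and `Person` is an illustrative schema:

```python
from typing import TYPE_CHECKING

from pydantic import BaseModel

from langchain_core.language_models import BaseChatModel


class Person(BaseModel):
    name: str
    age: int


if TYPE_CHECKING:
    model: BaseChatModel  # stand-in; any subclass inheriting these overloads

    # Pydantic schema -> the runnable is parameterized on that exact model type.
    as_person = model.with_structured_output(Person)
    # reveal_type(as_person)   # Runnable[LanguageModelInput, Person]

    # JSON-schema mapping -> plain dict output.
    as_dict = model.with_structured_output({"title": "Person", "type": "object"})
    # reveal_type(as_dict)     # Runnable[LanguageModelInput, dict]

    # include_raw=True -> dict output (raw response returned alongside parsed).
    with_raw = model.with_structured_output(Person, include_raw=True)
    # reveal_type(with_raw)    # Runnable[LanguageModelInput, dict]
```

The `Literal[False]` / `Literal[True]` split on `include_raw` is what lets the checker select the dict-returning overload only when the raw response is requested.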

View File

@@ -3,14 +3,14 @@
from __future__ import annotations
import asyncio
import builtins
import inspect
import json
import typing
from abc import ABC, abstractmethod
from collections.abc import AsyncIterator, Callable, Iterator, Sequence
from collections.abc import AsyncIterator, Callable, Iterator, Mapping, Sequence
from functools import cached_property
from operator import itemgetter
from typing import TYPE_CHECKING, Any, Literal, cast
from typing import TYPE_CHECKING, Any, Literal, TypeVar, cast, overload
from pydantic import BaseModel, ConfigDict, Field
from typing_extensions import override
@@ -73,6 +73,10 @@ from langchain_core.utils.function_calling import (
from langchain_core.utils.pydantic import TypeBaseModel, is_basemodel_subclass
from langchain_core.utils.utils import LC_ID_PREFIX, from_env
# TypeVar for with_structured_output overloads. Enables precise return type inference
# when a Pydantic BaseModel or TypedDict is passed as the schema argument.
_ModelT = TypeVar("_ModelT", bound=BaseModel | Mapping)
if TYPE_CHECKING:
import uuid
@@ -224,7 +228,7 @@ async def agenerate_from_stream(
return await run_in_executor(None, generate_from_stream, iter(chunks))
def _format_ls_structured_output(ls_structured_output_format: dict | None) -> dict:
def _format_ls_structured_output(ls_structured_output_format: Mapping | None) -> dict:
if ls_structured_output_format:
try:
ls_structured_output_format_dict = {
@@ -730,7 +734,7 @@ class BaseChatModel(BaseLanguageModel[AIMessage], ABC):
# --- Custom methods ---
def _combine_llm_outputs(self, llm_outputs: list[dict | None]) -> dict: # noqa: ARG002
def _combine_llm_outputs(self, llm_outputs: list[Mapping | None]) -> builtins.dict: # noqa: ARG002
return {}
def _convert_cached_generations(self, cache_val: list) -> list[ChatGeneration]:
@@ -776,7 +780,7 @@ class BaseChatModel(BaseLanguageModel[AIMessage], ABC):
self,
stop: list[str] | None = None,
**kwargs: Any,
) -> dict:
) -> builtins.dict:
params = self.dict()
params["stop"] = stop
return {**params, **kwargs}
@@ -1492,7 +1496,7 @@ class BaseChatModel(BaseLanguageModel[AIMessage], ABC):
"""Return type of chat model."""
@override
def dict(self, **kwargs: Any) -> dict:
def dict(self, **kwargs: Any) -> builtins.dict:
"""Return a dictionary of the LLM."""
starter_dict = dict(self._identifying_params)
starter_dict["_type"] = self._llm_type
@@ -1500,9 +1504,7 @@ class BaseChatModel(BaseLanguageModel[AIMessage], ABC):
def bind_tools(
self,
tools: Sequence[
typing.Dict[str, Any] | type | Callable | BaseTool # noqa: UP006
],
tools: Sequence[Mapping[str, Any] | type | Callable | BaseTool],
*,
tool_choice: str | None = None,
**kwargs: Any,
@@ -1519,13 +1521,44 @@ class BaseChatModel(BaseLanguageModel[AIMessage], ABC):
"""
raise NotImplementedError
# Overloads for with_structured_output provide precise return type inference:
# - Mapping schema (JSON/dict) -> returns dict
# - type[_ModelT] schema (Pydantic/TypedDict) -> returns that specific type
# - include_raw=True -> returns dict (with raw response included)
@overload
def with_structured_output(
self,
schema: typing.Dict | type, # noqa: UP006
schema: Mapping[str, Any],
*,
include_raw: Literal[False] = False,
**kwargs: Any,
) -> Runnable[LanguageModelInput, builtins.dict]: ...
@overload
def with_structured_output(
self,
schema: type[_ModelT],
*,
include_raw: Literal[False] = False,
**kwargs: Any,
) -> Runnable[LanguageModelInput, _ModelT]: ...
@overload
def with_structured_output(
self,
schema: Mapping[str, Any] | type[_ModelT],
*,
include_raw: Literal[True],
**kwargs: Any,
) -> Runnable[LanguageModelInput, builtins.dict]: ...
def with_structured_output(
self,
schema: Mapping | type,
*,
include_raw: bool = False,
**kwargs: Any,
) -> Runnable[LanguageModelInput, typing.Dict | BaseModel]: # noqa: UP006
) -> Runnable[LanguageModelInput, Any]:
"""Model wrapper that returns outputs formatted to match the given schema.
Args:

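A note on the `builtins.dict` qualifications in this file: this is my reading rather than something stated in the diff, but `BaseChatModel` defines a method named `dict` (touched later in the same hunk set), so qualifying the builtin keeps the return annotations unambiguous inside the class body. A minimal illustration with a made-up `Wrapper` class:

```python
import builtins


class Wrapper:
    def dict(self) -> builtins.dict:
        """Serialize to a plain dict (mirrors the BaseChatModel.dict method)."""
        return {"kind": "wrapper"}

    def params(self) -> builtins.dict:
        # Within this class body the name `dict` is bound to the method above,
        # so the builtin type is referenced explicitly in annotations.
        return {**self.dict(), "verbose": False}


print(Wrapper().params())  # {'kind': 'wrapper', 'verbose': False}
```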
View File

@@ -8,6 +8,7 @@ import logging
import types
import typing
import uuid
from collections.abc import Mapping
from typing import (
TYPE_CHECKING,
Annotated,
@@ -327,7 +328,7 @@ def _format_tool_to_openai_function(tool: BaseTool) -> FunctionDescription:
def convert_to_openai_function(
function: dict[str, Any] | type | Callable | BaseTool,
function: Mapping[str, Any] | type | Callable | BaseTool,
*,
strict: bool | None = None,
) -> dict[str, Any]:
@@ -358,7 +359,7 @@ def convert_to_openai_function(
required and guaranteed to be part of the output.
"""
# an Anthropic format tool
if isinstance(function, dict) and all(
if isinstance(function, Mapping) and all(
k in function for k in ("name", "input_schema")
):
oai_function = {
@@ -368,7 +369,7 @@ def convert_to_openai_function(
if "description" in function:
oai_function["description"] = function["description"]
# an Amazon Bedrock Converse format tool
elif isinstance(function, dict) and "toolSpec" in function:
elif isinstance(function, Mapping) and "toolSpec" in function:
oai_function = {
"name": function["toolSpec"]["name"],
"parameters": function["toolSpec"]["inputSchema"]["json"],
@@ -376,15 +377,15 @@ def convert_to_openai_function(
if "description" in function["toolSpec"]:
oai_function["description"] = function["toolSpec"]["description"]
# already in OpenAI function format
elif isinstance(function, dict) and "name" in function:
elif isinstance(function, Mapping) and "name" in function:
oai_function = {
k: v
for k, v in function.items()
if k in {"name", "description", "parameters", "strict"}
}
# a JSON schema with title and description
elif isinstance(function, dict) and "title" in function:
function_copy = function.copy()
elif isinstance(function, Mapping) and "title" in function:
function_copy = dict(function)
oai_function = {"name": function_copy.pop("title")}
if "description" in function_copy:
oai_function["description"] = function_copy.pop("description")
@@ -454,7 +455,7 @@ _WellKnownOpenAITools = (
def convert_to_openai_tool(
tool: dict[str, Any] | type[BaseModel] | Callable | BaseTool,
tool: Mapping[str, Any] | type[BaseModel] | Callable | BaseTool,
*,
strict: bool | None = None,
) -> dict[str, Any]:
@@ -495,12 +496,12 @@ def convert_to_openai_tool(
# Import locally to prevent circular import
from langchain_core.tools import Tool # noqa: PLC0415
if isinstance(tool, dict):
if isinstance(tool, Mapping):
if tool.get("type") in _WellKnownOpenAITools:
return tool
return dict(tool)
# As of 03.12.25 can be "web_search_preview" or "web_search_preview_2025_03_11"
if (tool.get("type") or "").startswith("web_search_preview"):
return tool
return dict(tool)
if isinstance(tool, Tool) and (tool.metadata or {}).get("type") == "custom_tool":
oai_tool = {
"type": "custom",
@@ -515,7 +516,7 @@ def convert_to_openai_tool(
def convert_to_json_schema(
schema: dict[str, Any] | type[BaseModel] | Callable | BaseTool,
schema: Mapping[str, Any] | type[BaseModel] | Callable | BaseTool,
*,
strict: bool | None = None,
) -> dict[str, Any]:

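A short usage sketch (not part of the diff) of what the `dict` -> `Mapping` widening in `convert_to_openai_function` permits: the Anthropic-format branch shown above matches on `name` + `input_schema`, and with `isinstance(..., Mapping)` a read-only view such as `MappingProxyType` is accepted as well. The tool spec below is made up for illustration:

```python
from types import MappingProxyType

from langchain_core.utils.function_calling import convert_to_openai_function

anthropic_tool = {
    "name": "get_weather",
    "description": "Look up the current weather for a city.",
    "input_schema": {
        "type": "object",
        "properties": {"city": {"type": "string"}},
        "required": ["city"],
    },
}

# A plain dict works before and after this change...
print(convert_to_openai_function(anthropic_tool))
# ...and with the Mapping-based checks a frozen view is accepted too.
print(convert_to_openai_function(MappingProxyType(anthropic_tool)))
# Both print {'name': 'get_weather', 'parameters': {...}, 'description': '...'}
```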
View File

@@ -1,3 +1,4 @@
from collections.abc import Mapping
from functools import partial
from inspect import isclass
from typing import Any, cast
@@ -17,8 +18,8 @@ from langchain_core.utils.pydantic import is_basemodel_subclass
def _fake_runnable(
_: Any, *, schema: dict | type[BaseModel], value: Any = 42, **_kwargs: Any
) -> BaseModel | dict:
_: Any, *, schema: Mapping | type, value: Any = 42, **_kwargs: Any
) -> Any:
if isclass(schema) and is_basemodel_subclass(schema):
return schema(name="yo", value=value)
params = cast("dict", schema)["parameters"]
@@ -29,9 +30,7 @@ class FakeStructuredChatModel(FakeListChatModel):
"""Fake chat model for testing purposes."""
@override
def with_structured_output(
self, schema: dict | type[BaseModel], **kwargs: Any
) -> Runnable:
def with_structured_output(self, schema: Mapping | type, **kwargs: Any) -> Runnable:
return RunnableLambda(partial(_fake_runnable, schema=schema, **kwargs))
@property

View File

@@ -1,10 +1,9 @@
from collections.abc import AsyncIterator, Callable, Iterator, Sequence
from collections.abc import AsyncIterator, Callable, Iterator, Mapping, Sequence
from typing import (
Any,
)
import pytest
from pydantic import BaseModel
from syrupy.assertion import SnapshotAssertion
from typing_extensions import override
@@ -335,15 +334,15 @@ class FakeStructuredOutputModel(BaseChatModel):
@override
def bind_tools(
self,
tools: Sequence[dict[str, Any] | type[BaseModel] | Callable | BaseTool],
tools: Sequence[Mapping[str, Any] | type | Callable | BaseTool],
**kwargs: Any,
) -> Runnable[LanguageModelInput, AIMessage]:
return self.bind(tools=tools)
@override
def with_structured_output(
self, schema: dict | type[BaseModel], **kwargs: Any
) -> Runnable[LanguageModelInput, dict | BaseModel]:
self, schema: Mapping | type, **kwargs: Any
) -> Runnable[LanguageModelInput, Any]:
return RunnableLambda(lambda _: {"foo": self.foo})
@property
@@ -368,7 +367,7 @@ class FakeModel(BaseChatModel):
@override
def bind_tools(
self,
tools: Sequence[dict[str, Any] | type[BaseModel] | Callable | BaseTool],
tools: Sequence[Mapping[str, Any] | type | Callable | BaseTool],
**kwargs: Any,
) -> Runnable[LanguageModelInput, AIMessage]:
return self.bind(tools=tools)