core[minor]: BaseChatModel with_structured_output implementation (#22859)

This commit is contained in:
Brace Sproul
2024-06-21 08:14:03 -07:00
committed by GitHub
parent 360a70c8a8
commit abe7566d7d
4 changed files with 142 additions and 29 deletions

View File

@@ -6,7 +6,6 @@ from typing import (
Callable,
Dict,
List,
Literal,
Optional,
Sequence,
Type,
@@ -14,7 +13,6 @@ from typing import (
TypeVar,
Union,
cast,
overload,
)
from langchain_community.chat_models.ollama import ChatOllama
@@ -72,7 +70,6 @@ DEFAULT_RESPONSE_FUNCTION = {
}
_BM = TypeVar("_BM", bound=BaseModel)
_DictOrPydanticClass = Union[Dict[str, Any], Type[_BM]]
_DictOrPydantic = Union[Dict, _BM]
@@ -151,33 +148,13 @@ class OllamaFunctions(ChatOllama):
) -> Runnable[LanguageModelInput, BaseMessage]:
return self.bind(functions=tools, **kwargs)
@overload
def with_structured_output(
self,
schema: Optional[_DictOrPydanticClass] = None,
*,
include_raw: Literal[True] = True,
**kwargs: Any,
) -> Runnable[LanguageModelInput, _AllReturnType]:
...
@overload
def with_structured_output(
self,
schema: Optional[_DictOrPydanticClass] = None,
*,
include_raw: Literal[False] = False,
**kwargs: Any,
) -> Runnable[LanguageModelInput, _DictOrPydantic]:
...
def with_structured_output(
self,
schema: Optional[_DictOrPydanticClass] = None,
schema: Union[Dict, Type[BaseModel]],
*,
include_raw: bool = False,
**kwargs: Any,
) -> Runnable[LanguageModelInput, _DictOrPydantic]:
) -> Runnable[LanguageModelInput, Union[Dict, BaseModel]]:
"""Model wrapper that returns outputs formatted to match the given schema.
Args: