diff --git a/libs/partners/deepseek/langchain_deepseek/chat_models.py b/libs/partners/deepseek/langchain_deepseek/chat_models.py
index 2398b1f9881..4714c6454bf 100644
--- a/libs/partners/deepseek/langchain_deepseek/chat_models.py
+++ b/libs/partners/deepseek/langchain_deepseek/chat_models.py
@@ -1,21 +1,27 @@
 """DeepSeek chat models."""
 
 from json import JSONDecodeError
-from typing import Any, Dict, Iterator, List, Optional, Type, Union
+from typing import Any, Dict, Iterator, List, Literal, Optional, Type, TypeVar, Union
 
 import openai
 from langchain_core.callbacks import (
     CallbackManagerForLLMRun,
 )
+from langchain_core.language_models import LanguageModelInput
 from langchain_core.messages import AIMessageChunk, BaseMessage
 from langchain_core.outputs import ChatGenerationChunk, ChatResult
+from langchain_core.runnables import Runnable
 from langchain_core.utils import from_env, secret_from_env
 from langchain_openai.chat_models.base import BaseChatOpenAI
-from pydantic import ConfigDict, Field, SecretStr, model_validator
+from pydantic import BaseModel, ConfigDict, Field, SecretStr, model_validator
 from typing_extensions import Self
 
 DEFAULT_API_BASE = "https://api.deepseek.com/v1"
 
+_BM = TypeVar("_BM", bound=BaseModel)
+_DictOrPydanticClass = Union[Dict[str, Any], Type[_BM], Type]
+_DictOrPydantic = Union[Dict, _BM]
+
 
 class ChatDeepSeek(BaseChatOpenAI):
     """DeepSeek chat model integration to access models hosted in DeepSeek's API.
@@ -197,14 +203,15 @@ class ChatDeepSeek(BaseChatOpenAI):
 
         if not (self.client or None):
             sync_specific: dict = {"http_client": self.http_client}
-            self.client = openai.OpenAI(
-                **client_params, **sync_specific
-            ).chat.completions
+            self.root_client = openai.OpenAI(**client_params, **sync_specific)
+            self.client = self.root_client.chat.completions
         if not (self.async_client or None):
             async_specific: dict = {"http_client": self.http_async_client}
-            self.async_client = openai.AsyncOpenAI(
-                **client_params, **async_specific
-            ).chat.completions
+            self.root_async_client = openai.AsyncOpenAI(
+                **client_params,
+                **async_specific,
+            )
+            self.async_client = self.root_async_client.chat.completions
         return self
 
     def _create_chat_result(
@@ -281,3 +288,73 @@ class ChatDeepSeek(BaseChatOpenAI):
                     e.doc,
                     e.pos,
                 ) from e
+
+    def with_structured_output(
+        self,
+        schema: Optional[_DictOrPydanticClass] = None,
+        *,
+        method: Literal[
+            "function_calling", "json_mode", "json_schema"
+        ] = "function_calling",
+        include_raw: bool = False,
+        strict: Optional[bool] = None,
+        **kwargs: Any,
+    ) -> Runnable[LanguageModelInput, _DictOrPydantic]:
+        """Model wrapper that returns outputs formatted to match the given schema.
+
+        Args:
+            schema:
+                The output schema. Can be passed in as:
+
+                - an OpenAI function/tool schema,
+                - a JSON Schema,
+                - a TypedDict class (support added in 0.1.20),
+                - or a Pydantic class.
+
+                If ``schema`` is a Pydantic class then the model output will be a
+                Pydantic instance of that class, and the model-generated fields will be
+                validated by the Pydantic class. Otherwise the model output will be a
+                dict and will not be validated. See :meth:`langchain_core.utils.function_calling.convert_to_openai_tool`
+                for more on how to properly specify types and descriptions of
+                schema fields when specifying a Pydantic or TypedDict class.
+
+            method: The method for steering model generation, one of:
+
+                - "function_calling":
+                    Uses DeepSeek's `tool-calling features`_.
+                - "json_mode":
+                    Uses DeepSeek's `JSON mode feature`_.
+
+                .. versionchanged:: 0.1.3
+
+                    Added support for ``"json_mode"``.
+
+            include_raw:
+                If False then only the parsed structured output is returned. If
+                an error occurs during model output parsing it will be raised. If True
+                then both the raw model response (a BaseMessage) and the parsed model
+                response will be returned. If an error occurs during output parsing it
+                will be caught and returned as well. The final output is always a dict
+                with keys "raw", "parsed", and "parsing_error".
+
+            kwargs: Additional keyword args aren't supported.
+
+        Returns:
+            A Runnable that takes the same inputs as a :class:`langchain_core.language_models.chat.BaseChatModel`.
+
+            | If ``include_raw`` is False and ``schema`` is a Pydantic class, Runnable outputs an instance of ``schema`` (i.e., a Pydantic object). Otherwise, if ``include_raw`` is False then Runnable outputs a dict.
+
+            | If ``include_raw`` is True, then Runnable outputs a dict with keys:
+
+            - "raw": BaseMessage
+            - "parsed": None if there was a parsing error, otherwise the type depends on the ``schema`` as described above.
+            - "parsing_error": Optional[BaseException]
+
+        """  # noqa: E501
+        # Some applications require that incompatible parameters (e.g., unsupported
+        # methods) be handled.
+        if method == "json_schema":
+            method = "function_calling"
+        return super().with_structured_output(
+            schema, method=method, include_raw=include_raw, strict=strict, **kwargs
+        )
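For reference, a minimal usage sketch of the override above: a request for ``method="json_schema"`` now degrades gracefully to tool calling instead of raising. The schema, prompt, and ``deepseek-chat`` model name are illustrative, not part of the diff.

```python
# Illustrative only: schema, prompt, and model name are assumptions.
from langchain_deepseek import ChatDeepSeek
from pydantic import BaseModel, Field


class Answer(BaseModel):
    """Answer with justification."""

    answer: str = Field(description="the final answer")
    justification: str = Field(description="why the answer is correct")


llm = ChatDeepSeek(model="deepseek-chat", temperature=0)

# "json_schema" is rewritten to "function_calling" before delegating to
# BaseChatOpenAI, so both method values take the same code path.
structured_llm = llm.with_structured_output(Answer, method="json_schema")
result = structured_llm.invoke("Why is the sky blue?")  # -> Answer instance
```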
diff --git a/libs/partners/deepseek/tests/integration_tests/test_chat_models.py b/libs/partners/deepseek/tests/integration_tests/test_chat_models.py
index 521296c3ef6..3ead8786e8b 100644
--- a/libs/partners/deepseek/tests/integration_tests/test_chat_models.py
+++ b/libs/partners/deepseek/tests/integration_tests/test_chat_models.py
@@ -24,6 +24,11 @@ class TestChatDeepSeek(ChatModelIntegrationTests):
             "temperature": 0,
         }
 
+    @property
+    def supports_json_mode(self) -> bool:
+        """(bool) whether the chat model supports JSON mode."""
+        return True
+
     @pytest.mark.xfail(reason="Not yet supported.")
     def test_tool_message_histories_list_content(
         self, model: BaseChatModel, my_adder_tool: BaseTool
diff --git a/libs/partners/fireworks/langchain_fireworks/chat_models.py b/libs/partners/fireworks/langchain_fireworks/chat_models.py
index 6ac752d3777..3f776456559 100644
--- a/libs/partners/fireworks/langchain_fireworks/chat_models.py
+++ b/libs/partners/fireworks/langchain_fireworks/chat_models.py
@@ -75,6 +75,7 @@ from langchain_core.utils import (
     get_pydantic_field_names,
 )
 from langchain_core.utils.function_calling import (
+    convert_to_json_schema,
     convert_to_openai_function,
     convert_to_openai_tool,
 )
@@ -737,7 +738,9 @@ class ChatFireworks(BaseChatModel):
         self,
         schema: Optional[Union[Dict, Type[BaseModel]]] = None,
         *,
-        method: Literal["function_calling", "json_mode"] = "function_calling",
+        method: Literal[
+            "function_calling", "json_mode", "json_schema"
+        ] = "function_calling",
         include_raw: bool = False,
         **kwargs: Any,
     ) -> Runnable[LanguageModelInput, Union[Dict, BaseModel]]:
@@ -761,13 +764,19 @@ class ChatFireworks(BaseChatModel):
 
                     Added support for TypedDict class.
 
-            method:
-                The method for steering model generation, either "function_calling"
-                or "json_mode". If "function_calling" then the schema will be converted
-                to an OpenAI function and the returned model will make use of the
-                function-calling API. If "json_mode" then OpenAI's JSON mode will be
-                used. Note that if using "json_mode" then you must include instructions
-                for formatting the output into the desired schema into the model call.
+            method: The method for steering model generation, one of:
+
+                - "function_calling":
+                    Uses Fireworks's `tool-calling features`_.
+                - "json_schema":
+                    Uses Fireworks's `structured output feature`_.
+                - "json_mode":
+                    Uses Fireworks's `JSON mode feature`_.
+
+                .. versionchanged:: 0.2.8
+
+                    Added support for ``"json_schema"``.
+
             include_raw:
                 If False then only the parsed structured output is returned. If
                 an error occurs during model output parsing it will be raised. If True
@@ -928,11 +937,11 @@ class ChatFireworks(BaseChatModel):
 
                 structured_llm.invoke(
                     "Answer the following question. "
-                    "Make sure to return a JSON blob with keys 'answer' and 'justification'.\n\n"
+                    "Make sure to return a JSON blob with keys 'answer' and 'justification'. "
                     "What's heavier a pound of bricks or a pound of feathers?"
                 )
                 # -> {
-                #     'raw': AIMessage(content='{\n    "answer": "They are both the same weight.",\n    "justification": "Both a pound of bricks and a pound of feathers weigh one pound. The difference lies in the volume and density of the materials, not the weight." \n}'),
+                #     'raw': AIMessage(content='{"answer": "They are both the same weight.", "justification": "Both a pound of bricks and a pound of feathers weigh one pound. The difference lies in the volume and density of the materials, not the weight."}'),
                 #     'parsed': AnswerWithJustification(answer='They are both the same weight.', justification='Both a pound of bricks and a pound of feathers weigh one pound. The difference lies in the volume and density of the materials, not the weight.'),
                 #     'parsing_error': None
                 # }
@@ -944,11 +953,11 @@ class ChatFireworks(BaseChatModel):
 
                 structured_llm.invoke(
                     "Answer the following question. "
-                    "Make sure to return a JSON blob with keys 'answer' and 'justification'.\n\n"
+                    "Make sure to return a JSON blob with keys 'answer' and 'justification'. "
                     "What's heavier a pound of bricks or a pound of feathers?"
                 )
                 # -> {
-                #     'raw': AIMessage(content='{\n    "answer": "They are both the same weight.",\n    "justification": "Both a pound of bricks and a pound of feathers weigh one pound. The difference lies in the volume and density of the materials, not the weight." \n}'),
+                #     'raw': AIMessage(content='{"answer": "They are both the same weight.", "justification": "Both a pound of bricks and a pound of feathers weigh one pound. The difference lies in the volume and density of the materials, not the weight."}'),
                 #     'parsed': {
                 #         'answer': 'They are both the same weight.',
                 #         'justification': 'Both a pound of bricks and a pound of feathers weigh one pound. The difference lies in the volume and density of the materials, not the weight.'
                 #     },
                 #     'parsing_error': None
                 # }
 
         """  # noqa: E501
+        _ = kwargs.pop("strict", None)
         if kwargs:
             raise ValueError(f"Received unsupported arguments {kwargs}")
         is_pydantic_schema = _is_pydantic_class(schema)
@@ -984,6 +994,25 @@ class ChatFireworks(BaseChatModel):
             output_parser = JsonOutputKeyToolsParser(
                 key_name=tool_name, first_tool_only=True
             )
+        elif method == "json_schema":
+            if schema is None:
+                raise ValueError(
+                    "schema must be specified when method is 'json_schema'. "
+                    "Received None."
+                )
+            formatted_schema = convert_to_json_schema(schema)
+            llm = self.bind(
+                response_format={"type": "json_object", "schema": formatted_schema},
+                ls_structured_output_format={
+                    "kwargs": {"method": "json_schema"},
+                    "schema": schema,
+                },
+            )
+            output_parser = (
+                PydanticOutputParser(pydantic_object=schema)  # type: ignore[arg-type]
+                if is_pydantic_schema
+                else JsonOutputParser()
+            )
         elif method == "json_mode":
             llm = self.bind(
                 response_format={"type": "json_object"},
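A minimal sketch of the new Fireworks ``"json_schema"`` path: the bound model sends the converted schema inside ``response_format`` and parses the reply. The schema and prompt are illustrative; the model name matches the one used in the integration tests below.

```python
# Illustrative only: schema and prompt are assumptions; the model name is the
# one exercised in the integration tests below.
from langchain_fireworks import ChatFireworks
from pydantic import BaseModel, Field


class AnswerWithJustification(BaseModel):
    """An answer to the user question, with justification."""

    answer: str
    justification: str = Field(description="why the answer is correct")


llm = ChatFireworks(model="accounts/fireworks/models/llama-v3p1-70b-instruct")

# Internally binds response_format={"type": "json_object", "schema": ...}
# (built with convert_to_json_schema) and parses with PydanticOutputParser.
structured_llm = llm.with_structured_output(
    AnswerWithJustification, method="json_schema"
)
result = structured_llm.invoke(
    "What weighs more, a pound of bricks or a pound of feathers?"
)
# -> AnswerWithJustification(answer=..., justification=...)
```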
diff --git a/libs/partners/fireworks/pyproject.toml b/libs/partners/fireworks/pyproject.toml
index 5efa04c4be1..c89257dbee7 100644
--- a/libs/partners/fireworks/pyproject.toml
+++ b/libs/partners/fireworks/pyproject.toml
@@ -7,7 +7,7 @@ authors = []
 license = { text = "MIT" }
 requires-python = "<4.0,>=3.9"
 dependencies = [
-    "langchain-core<1.0.0,>=0.3.33",
+    "langchain-core<1.0.0,>=0.3.46",
     "fireworks-ai>=0.13.0",
     "openai<2.0.0,>=1.10.0",
     "requests<3,>=2",
diff --git a/libs/partners/fireworks/tests/integration_tests/test_chat_models.py b/libs/partners/fireworks/tests/integration_tests/test_chat_models.py
index 885e3071cff..ecaa2ebca8a 100644
--- a/libs/partners/fireworks/tests/integration_tests/test_chat_models.py
+++ b/libs/partners/fireworks/tests/integration_tests/test_chat_models.py
@@ -4,10 +4,12 @@ You will need FIREWORKS_API_KEY set in your environment to run these tests.
 """
 
 import json
-from typing import Optional
+from typing import Any, Literal, Optional
 
+import pytest
 from langchain_core.messages import AIMessage, AIMessageChunk, BaseMessageChunk
-from pydantic import BaseModel
+from pydantic import BaseModel, Field
+from typing_extensions import Annotated, TypedDict
 
 from langchain_fireworks import ChatFireworks
 
@@ -161,3 +163,54 @@ def test_invoke() -> None:
 
     result = llm.invoke("I'm Pickle Rick", config=dict(tags=["foo"]))
     assert isinstance(result.content, str)
+
+
+def _get_joke_class(
+    schema_type: Literal["pydantic", "typeddict", "json_schema"],
+) -> Any:
+    class Joke(BaseModel):
+        """Joke to tell user."""
+
+        setup: str = Field(description="question to set up a joke")
+        punchline: str = Field(description="answer to resolve the joke")
+
+    def validate_joke(result: Any) -> bool:
+        return isinstance(result, Joke)
+
+    class JokeDict(TypedDict):
+        """Joke to tell user."""
+
+        setup: Annotated[str, ..., "question to set up a joke"]
+        punchline: Annotated[str, ..., "answer to resolve the joke"]
+
+    def validate_joke_dict(result: Any) -> bool:
+        return all(key in ["setup", "punchline"] for key in result.keys())
+
+    if schema_type == "pydantic":
+        return Joke, validate_joke
+
+    elif schema_type == "typeddict":
+        return JokeDict, validate_joke_dict
+
+    elif schema_type == "json_schema":
+        return Joke.model_json_schema(), validate_joke_dict
+    else:
+        raise ValueError("Invalid schema type")
+
+
+@pytest.mark.parametrize("schema_type", ["pydantic", "typeddict", "json_schema"])
+def test_structured_output_json_schema(schema_type: str) -> None:
+    llm = ChatFireworks(model="accounts/fireworks/models/llama-v3p1-70b-instruct")
+    schema, validation_function = _get_joke_class(schema_type)  # type: ignore[arg-type]
+    chat = llm.with_structured_output(schema, method="json_schema")
+
+    # Test invoke
+    result = chat.invoke("Tell me a joke about cats.")
+    assert validation_function(result)
+
+    # Test stream
+    chunks = []
+    for chunk in chat.stream("Tell me a joke about cats."):
+        chunks.append(chunk)
+    assert chunks
+    # Streamed chunks are cumulative; the last one is the complete output.
+    assert validation_function(chunks[-1])
diff --git a/libs/partners/fireworks/uv.lock b/libs/partners/fireworks/uv.lock
index 570944acf7b..5a0446c6308 100644
--- a/libs/partners/fireworks/uv.lock
+++ b/libs/partners/fireworks/uv.lock
@@ -635,7 +635,7 @@ wheels = [
 
 [[package]]
 name = "langchain-core"
-version = "0.3.35"
+version = "0.3.46"
 source = { editable = "../../core" }
 dependencies = [
     { name = "jsonpatch" },
@@ -667,7 +667,7 @@ dev = [
 ]
 lint = [{ name = "ruff", specifier = ">=0.9.2,<1.0.0" }]
 test = [
-    { name = "blockbuster", specifier = "~=1.5.11" },
+    { name = "blockbuster", specifier = "~=1.5.18" },
     { name = "freezegun", specifier = ">=1.2.2,<2.0.0" },
     { name = "grandalf", specifier = ">=0.8,<1.0" },
    { name = "langchain-tests", directory = "../../standard-tests" },
@@ -763,7 +763,7 @@ typing = [
 
 [[package]]
 name = "langchain-tests"
-version = "0.3.11"
+version = "0.3.14"
 source = { editable = "../../standard-tests" }
 dependencies = [
     { name = "httpx" },
@@ -780,8 +780,7 @@ dependencies = [
 requires-dist = [
     { name = "httpx", specifier = ">=0.25.0,<1" },
     { name = "langchain-core", editable = "../../core" },
-    { name = "numpy", marker = "python_full_version < '3.12'", specifier = ">=1.24.0,<2.0.0" },
-    { name = "numpy", marker = "python_full_version >= '3.12'", specifier = ">=1.26.2,<3" },
+    { name = "numpy", specifier = ">=1.26.2,<3" },
     { name = "pytest", specifier = ">=7,<9" },
     { name = "pytest-asyncio", specifier = ">=0.20,<1" },
     { name = "pytest-socket", specifier = ">=0.6.0,<1" },
diff --git a/libs/partners/groq/langchain_groq/chat_models.py b/libs/partners/groq/langchain_groq/chat_models.py
index a95722695d8..560e9a5d479 100644
--- a/libs/partners/groq/langchain_groq/chat_models.py
+++ b/libs/partners/groq/langchain_groq/chat_models.py
@@ -460,6 +460,24 @@ class ChatGroq(BaseChatModel):
             ls_params["ls_stop"] = ls_stop if isinstance(ls_stop, list) else [ls_stop]
         return ls_params
 
+    def _should_stream(
+        self,
+        *,
+        async_api: bool,
+        run_manager: Optional[
+            Union[CallbackManagerForLLMRun, AsyncCallbackManagerForLLMRun]
+        ] = None,
+        **kwargs: Any,
+    ) -> bool:
+        """Determine if a given model call should hit the streaming API."""
+        base_should_stream = super()._should_stream(
+            async_api=async_api, run_manager=run_manager, **kwargs
+        )
+        if base_should_stream and ("response_format" in kwargs):
+            # Streaming not supported in JSON mode.
+            return kwargs["response_format"] != {"type": "json_object"}
+        return base_should_stream
+
     def _generate(
         self,
         messages: List[BaseMessage],
@@ -987,9 +1005,14 @@ class ChatGroq(BaseChatModel):
             # 'parsing_error': None
             # }
 
         """  # noqa: E501
+        _ = kwargs.pop("strict", None)
         if kwargs:
             raise ValueError(f"Received unsupported arguments {kwargs}")
         is_pydantic_schema = _is_pydantic_class(schema)
+        if method == "json_schema":
+            # Some applications require that incompatible parameters (e.g., unsupported
+            # methods) be handled.
+            method = "function_calling"
         if method == "function_calling":
             if schema is None:
                 raise ValueError(
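The `_should_stream` override above means JSON-mode requests transparently fall back to a single non-streaming call. A minimal sketch, assuming the `llama-3.1-8b-instant` model name and an illustrative prompt:

```python
# Illustrative only: model name and prompt are assumptions.
from langchain_groq import ChatGroq

llm = ChatGroq(model="llama-3.1-8b-instant", temperature=0)
json_llm = llm.bind(response_format={"type": "json_object"})

# Even though .stream() is called, _should_stream returns False for JSON mode,
# so the output below comes from one non-streaming completion.
for chunk in json_llm.stream(
    "Return a JSON object with keys 'setup' and 'punchline' for a cat joke."
):
    print(chunk.content, end="")
```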
+ method = "function_calling" if method == "function_calling": if schema is None: raise ValueError( diff --git a/libs/partners/groq/tests/integration_tests/test_standard.py b/libs/partners/groq/tests/integration_tests/test_standard.py index fe91a614dd7..6b4e0a3b068 100644 --- a/libs/partners/groq/tests/integration_tests/test_standard.py +++ b/libs/partners/groq/tests/integration_tests/test_standard.py @@ -47,4 +47,4 @@ class TestGroqLlama(BaseTestGroq): @property def supports_json_mode(self) -> bool: - return False # Not supported in streaming mode + return True diff --git a/libs/partners/mistralai/langchain_mistralai/chat_models.py b/libs/partners/mistralai/langchain_mistralai/chat_models.py index 98aa414e573..9cf01cd7a24 100644 --- a/libs/partners/mistralai/langchain_mistralai/chat_models.py +++ b/libs/partners/mistralai/langchain_mistralai/chat_models.py @@ -945,6 +945,7 @@ class ChatMistralAI(BaseChatModel): # } """ # noqa: E501 + _ = kwargs.pop("strict", None) if kwargs: raise ValueError(f"Received unsupported arguments {kwargs}") is_pydantic_schema = isinstance(schema, type) and is_basemodel_subclass(schema) diff --git a/libs/partners/xai/langchain_xai/chat_models.py b/libs/partners/xai/langchain_xai/chat_models.py index 0081a734cf3..f3f8553551d 100644 --- a/libs/partners/xai/langchain_xai/chat_models.py +++ b/libs/partners/xai/langchain_xai/chat_models.py @@ -4,16 +4,28 @@ from typing import ( Any, Dict, List, + Literal, Optional, + Type, + TypeVar, + Union, ) import openai -from langchain_core.language_models.chat_models import LangSmithParams +from langchain_core.language_models.chat_models import ( + LangSmithParams, + LanguageModelInput, +) +from langchain_core.runnables import Runnable from langchain_core.utils import secret_from_env from langchain_openai.chat_models.base import BaseChatOpenAI -from pydantic import ConfigDict, Field, SecretStr, model_validator +from pydantic import BaseModel, ConfigDict, Field, SecretStr, model_validator from typing_extensions import Self +_BM = TypeVar("_BM", bound=BaseModel) +_DictOrPydanticClass = Union[Dict[str, Any], Type[_BM], Type] +_DictOrPydantic = Union[Dict, _BM] + class ChatXAI(BaseChatOpenAI): # type: ignore[override] r"""ChatXAI chat model. @@ -359,3 +371,83 @@ class ChatXAI(BaseChatOpenAI): # type: ignore[override] **async_specific, ) return self + + def with_structured_output( + self, + schema: Optional[_DictOrPydanticClass] = None, + *, + method: Literal[ + "function_calling", "json_mode", "json_schema" + ] = "function_calling", + include_raw: bool = False, + strict: Optional[bool] = None, + **kwargs: Any, + ) -> Runnable[LanguageModelInput, _DictOrPydantic]: + """Model wrapper that returns outputs formatted to match the given schema. + + Args: + schema: + The output schema. Can be passed in as: + + - an OpenAI function/tool schema, + - a JSON Schema, + - a TypedDict class (support added in 0.1.20), + - or a Pydantic class. + + If ``schema`` is a Pydantic class then the model output will be a + Pydantic instance of that class, and the model-generated fields will be + validated by the Pydantic class. Otherwise the model output will be a + dict and will not be validated. See :meth:`langchain_core.utils.function_calling.convert_to_openai_tool` + for more on how to properly specify types and descriptions of + schema fields when specifying a Pydantic or TypedDict class. + + method: The method for steering model generation, one of: + + - "function_calling": + Uses xAI's `tool-calling features `_. 
+ - "json_schema": + Uses xAI's `structured output feature `_. + - "json_mode": + Uses xAI's JSON mode feature. + + include_raw: + If False then only the parsed structured output is returned. If + an error occurs during model output parsing it will be raised. If True + then both the raw model response (a BaseMessage) and the parsed model + response will be returned. If an error occurs during output parsing it + will be caught and returned as well. The final output is always a dict + with keys "raw", "parsed", and "parsing_error". + + strict: + + - True: + Model output is guaranteed to exactly match the schema. + The input schema will also be validated according to + https://platform.openai.com/docs/guides/structured-outputs/supported-schemas + - False: + Input schema will not be validated and model output will not be + validated. + - None: + ``strict`` argument will not be passed to the model. + + kwargs: Additional keyword args aren't supported. + + Returns: + A Runnable that takes same inputs as a :class:`langchain_core.language_models.chat.BaseChatModel`. + + | If ``include_raw`` is False and ``schema`` is a Pydantic class, Runnable outputs an instance of ``schema`` (i.e., a Pydantic object). Otherwise, if ``include_raw`` is False then Runnable outputs a dict. + + | If ``include_raw`` is True, then Runnable outputs a dict with keys: + + - "raw": BaseMessage + - "parsed": None if there was a parsing error, otherwise the type depends on the ``schema`` as described above. + - "parsing_error": Optional[BaseException] + + """ # noqa: E501 + # Some applications require that incompatible parameters (e.g., unsupported + # methods) be handled. + if method == "function_calling" and strict: + strict = None + return super().with_structured_output( + schema, method=method, include_raw=include_raw, strict=strict, **kwargs + ) diff --git a/libs/standard-tests/langchain_tests/unit_tests/chat_models.py b/libs/standard-tests/langchain_tests/unit_tests/chat_models.py index 65dea9079b8..f3363ef5fa7 100644 --- a/libs/standard-tests/langchain_tests/unit_tests/chat_models.py +++ b/libs/standard-tests/langchain_tests/unit_tests/chat_models.py @@ -626,6 +626,12 @@ class ChatModelUnitTests(ChatModelTests): return assert model.with_structured_output(schema) is not None + for method in ["json_schema", "function_calling", "json_mode"]: + strict_values = [None, False, True] if method != "json_mode" else [None] + for strict in strict_values: + assert model.with_structured_output( + schema, method=method, strict=strict + ) def test_standard_params(self, model: BaseChatModel) -> None: """Test that model properly generates standard parameters. These are used