diff --git a/libs/core/langchain_core/language_models/chat_models.py b/libs/core/langchain_core/language_models/chat_models.py index f0493497c1f..adc97e04c68 100644 --- a/libs/core/langchain_core/language_models/chat_models.py +++ b/libs/core/langchain_core/language_models/chat_models.py @@ -11,10 +11,11 @@ from abc import ABC, abstractmethod from collections.abc import AsyncIterator, Iterator, Sequence from functools import cached_property from operator import itemgetter -from typing import ( +from typing import ( # noqa: UP035 TYPE_CHECKING, Any, Callable, + Dict, Literal, Optional, Union, @@ -70,11 +71,13 @@ from langchain_core.rate_limiters import BaseRateLimiter from langchain_core.runnables import RunnableMap, RunnablePassthrough from langchain_core.runnables.config import ensure_config, run_in_executor from langchain_core.tracers._streaming import _StreamingCallbackHandler +from langchain_core.utils import get_pydantic_field_names from langchain_core.utils.function_calling import ( convert_to_json_schema, convert_to_openai_tool, ) from langchain_core.utils.pydantic import TypeBaseModel, is_basemodel_subclass +from langchain_core.utils.utils import _build_model_kwargs if TYPE_CHECKING: import uuid @@ -302,6 +305,9 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC): - If False (default), will always use streaming case if available. """ + model_kwargs: Dict[str, Any] = Field(default_factory=dict) # noqa: UP006 + """Holds any model parameters valid for `create` call not explicitly specified.""" + @model_validator(mode="before") @classmethod def raise_deprecation(cls, values: dict) -> Any: @@ -329,6 +335,13 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC): arbitrary_types_allowed=True, ) + @model_validator(mode="before") + @classmethod + def build_extra(cls, values: dict[str, Any]) -> Any: + """Build extra kwargs from additional params that were passed in.""" + all_required_field_names = get_pydantic_field_names(cls) + return _build_model_kwargs(values, all_required_field_names) + @cached_property def _serialized(self) -> dict[str, Any]: return dumpd(self) diff --git a/libs/core/tests/unit_tests/language_models/chat_models/test_base.py b/libs/core/tests/unit_tests/language_models/chat_models/test_base.py index 00e7c8e9f26..1a270662565 100644 --- a/libs/core/tests/unit_tests/language_models/chat_models/test_base.py +++ b/libs/core/tests/unit_tests/language_models/chat_models/test_base.py @@ -403,6 +403,40 @@ async def test_disable_streaming_no_streaming_model_async( break +def test_model_kwargs() -> None: + llm = FakeListChatModel( + responses=["a", "b", "c"], + sleep=0.1, + disable_streaming=False, + model_kwargs={"foo": "bar"}, + ) + assert llm.responses == ["a", "b", "c"] + assert llm.sleep == 0.1 + assert llm.disable_streaming is False + assert llm.model_kwargs == {"foo": "bar"} + + with pytest.warns(match="transferred to model_kwargs"): + llm = FakeListChatModel( + responses=["a", "b", "c"], + sleep=0.1, + disable_streaming=False, + foo="bar", # type: ignore[call-arg] + ) + assert llm.responses == ["a", "b", "c"] + assert llm.sleep == 0.1 + assert llm.disable_streaming is False + assert llm.model_kwargs == {"foo": "bar"} + + # Backward compatibility + with pytest.warns(match="should be specified explicitly"): + llm = FakeListChatModel( # type: ignore[call-arg] + model_kwargs={"foo": "bar", "responses": ["a", "b", "c"], "sleep": 0.1}, + ) + assert llm.responses == ["a", "b", "c"] + assert llm.sleep == 0.1 + assert llm.model_kwargs == {"foo": "bar"} + + class FakeChatModelStartTracer(FakeTracer): def __init__(self) -> None: super().__init__() diff --git a/libs/core/tests/unit_tests/runnables/__snapshots__/test_runnable.ambr b/libs/core/tests/unit_tests/runnables/__snapshots__/test_runnable.ambr index b1519bc3f75..3bd12cf21bc 100644 --- a/libs/core/tests/unit_tests/runnables/__snapshots__/test_runnable.ambr +++ b/libs/core/tests/unit_tests/runnables/__snapshots__/test_runnable.ambr @@ -97,7 +97,7 @@ "fake_chat_models", "FakeListChatModel" ], - "repr": "FakeListChatModel(responses=['foo, bar'])", + "repr": "FakeListChatModel(model_kwargs={}, responses=['foo, bar'])", "name": "FakeListChatModel" } ], @@ -227,7 +227,7 @@ "fake_chat_models", "FakeListChatModel" ], - "repr": "FakeListChatModel(responses=['baz, qux'])", + "repr": "FakeListChatModel(model_kwargs={}, responses=['baz, qux'])", "name": "FakeListChatModel" } ], @@ -346,7 +346,7 @@ "fake_chat_models", "FakeListChatModel" ], - "repr": "FakeListChatModel(responses=['foo, bar'])", + "repr": "FakeListChatModel(model_kwargs={}, responses=['foo, bar'])", "name": "FakeListChatModel" }, { @@ -457,7 +457,7 @@ "fake_chat_models", "FakeListChatModel" ], - "repr": "FakeListChatModel(responses=['baz, qux'])", + "repr": "FakeListChatModel(model_kwargs={}, responses=['baz, qux'])", "name": "FakeListChatModel" } ], @@ -1009,7 +1009,7 @@ # name: test_prompt_with_chat_model ''' ChatPromptTemplate(input_variables=['question'], input_types={}, partial_variables={}, messages=[SystemMessagePromptTemplate(prompt=PromptTemplate(input_variables=[], input_types={}, partial_variables={}, template='You are a nice assistant.'), additional_kwargs={}), HumanMessagePromptTemplate(prompt=PromptTemplate(input_variables=['question'], input_types={}, partial_variables={}, template='{question}'), additional_kwargs={})]) - | FakeListChatModel(responses=['foo']) + | FakeListChatModel(model_kwargs={}, responses=['foo']) ''' # --- # name: test_prompt_with_chat_model.1 @@ -1109,7 +1109,7 @@ "fake_chat_models", "FakeListChatModel" ], - "repr": "FakeListChatModel(responses=['foo'])", + "repr": "FakeListChatModel(model_kwargs={}, responses=['foo'])", "name": "FakeListChatModel" } }, @@ -1220,7 +1220,7 @@ "fake_chat_models", "FakeListChatModel" ], - "repr": "FakeListChatModel(responses=['foo, bar'])", + "repr": "FakeListChatModel(model_kwargs={}, responses=['foo, bar'])", "name": "FakeListChatModel" } ], @@ -1249,7 +1249,7 @@ # name: test_prompt_with_chat_model_async ''' ChatPromptTemplate(input_variables=['question'], input_types={}, partial_variables={}, messages=[SystemMessagePromptTemplate(prompt=PromptTemplate(input_variables=[], input_types={}, partial_variables={}, template='You are a nice assistant.'), additional_kwargs={}), HumanMessagePromptTemplate(prompt=PromptTemplate(input_variables=['question'], input_types={}, partial_variables={}, template='{question}'), additional_kwargs={})]) - | FakeListChatModel(responses=['foo']) + | FakeListChatModel(model_kwargs={}, responses=['foo']) ''' # --- # name: test_prompt_with_chat_model_async.1 @@ -1349,7 +1349,7 @@ "fake_chat_models", "FakeListChatModel" ], - "repr": "FakeListChatModel(responses=['foo'])", + "repr": "FakeListChatModel(model_kwargs={}, responses=['foo'])", "name": "FakeListChatModel" } }, @@ -13535,7 +13535,7 @@ just_to_test_lambda: RunnableLambda(...) } | ChatPromptTemplate(input_variables=['documents', 'question'], input_types={}, partial_variables={}, messages=[SystemMessagePromptTemplate(prompt=PromptTemplate(input_variables=[], input_types={}, partial_variables={}, template='You are a nice assistant.'), additional_kwargs={}), HumanMessagePromptTemplate(prompt=PromptTemplate(input_variables=['documents', 'question'], input_types={}, partial_variables={}, template='Context:\n{documents}\n\nQuestion:\n{question}'), additional_kwargs={})]) - | FakeListChatModel(responses=['foo, bar']) + | FakeListChatModel(model_kwargs={}, responses=['foo, bar']) | CommaSeparatedListOutputParser() ''' # --- @@ -13738,7 +13738,7 @@ "fake_chat_models", "FakeListChatModel" ], - "repr": "FakeListChatModel(responses=['foo, bar'])", + "repr": "FakeListChatModel(model_kwargs={}, responses=['foo, bar'])", "name": "FakeListChatModel" } ], @@ -13764,7 +13764,7 @@ ChatPromptTemplate(input_variables=['question'], input_types={}, partial_variables={}, messages=[SystemMessagePromptTemplate(prompt=PromptTemplate(input_variables=[], input_types={}, partial_variables={}, template='You are a nice assistant.'), additional_kwargs={}), HumanMessagePromptTemplate(prompt=PromptTemplate(input_variables=['question'], input_types={}, partial_variables={}, template='{question}'), additional_kwargs={})]) | RunnableLambda(...) | { - chat: FakeListChatModel(responses=["i'm a chatbot"]), + chat: FakeListChatModel(model_kwargs={}, responses=["i'm a chatbot"]), llm: FakeListLLM(responses=["i'm a textbot"]) } ''' @@ -13890,7 +13890,7 @@ "fake_chat_models", "FakeListChatModel" ], - "repr": "FakeListChatModel(responses=[\"i'm a chatbot\"])", + "repr": "FakeListChatModel(model_kwargs={}, responses=[\"i'm a chatbot\"])", "name": "FakeListChatModel" }, "llm": { @@ -14045,7 +14045,7 @@ "fake_chat_models", "FakeListChatModel" ], - "repr": "FakeListChatModel(responses=[\"i'm a chatbot\"])", + "repr": "FakeListChatModel(model_kwargs={}, responses=[\"i'm a chatbot\"])", "name": "FakeListChatModel" }, "kwargs": {