commit a85e0aed5f (parent a7903280dd)
Author: Chester Curme
Date:   2025-04-24 16:36:59 -04:00

    implement in core

3 changed files with 62 additions and 15 deletions


@@ -11,10 +11,11 @@ from abc import ABC, abstractmethod
 from collections.abc import AsyncIterator, Iterator, Sequence
 from functools import cached_property
 from operator import itemgetter
-from typing import (
+from typing import (  # noqa: UP035
     TYPE_CHECKING,
     Any,
     Callable,
+    Dict,
     Literal,
     Optional,
     Union,
@@ -70,11 +71,13 @@ from langchain_core.rate_limiters import BaseRateLimiter
 from langchain_core.runnables import RunnableMap, RunnablePassthrough
 from langchain_core.runnables.config import ensure_config, run_in_executor
 from langchain_core.tracers._streaming import _StreamingCallbackHandler
+from langchain_core.utils import get_pydantic_field_names
 from langchain_core.utils.function_calling import (
     convert_to_json_schema,
     convert_to_openai_tool,
 )
 from langchain_core.utils.pydantic import TypeBaseModel, is_basemodel_subclass
+from langchain_core.utils.utils import _build_model_kwargs
 
 if TYPE_CHECKING:
     import uuid
@@ -302,6 +305,9 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
     - If False (default), will always use streaming case if available.
     """
 
+    model_kwargs: Dict[str, Any] = Field(default_factory=dict)  # noqa: UP006
+    """Holds any model parameters valid for `create` call not explicitly specified."""
+
     @model_validator(mode="before")
     @classmethod
     def raise_deprecation(cls, values: dict) -> Any:
@@ -329,6 +335,13 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
         arbitrary_types_allowed=True,
     )
 
+    @model_validator(mode="before")
+    @classmethod
+    def build_extra(cls, values: dict[str, Any]) -> Any:
+        """Build extra kwargs from additional params that were passed in."""
+        all_required_field_names = get_pydantic_field_names(cls)
+        return _build_model_kwargs(values, all_required_field_names)
+
     @cached_property
     def _serialized(self) -> dict[str, Any]:
         return dumpd(self)
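
For context: build_extra defers to _build_model_kwargs from langchain_core.utils.utils. A minimal sketch of the behavior that helper needs to have for the tests below to pass (names and warning messages here are illustrative, not the actual implementation):

    import warnings
    from typing import Any

    def build_model_kwargs_sketch(
        values: dict[str, Any], all_required_field_names: set[str]
    ) -> dict[str, Any]:
        extra = values.get("model_kwargs", {})
        for name in list(values):
            if name != "model_kwargs" and name not in all_required_field_names:
                # Unrecognized constructor kwarg: warn and move it into model_kwargs.
                warnings.warn(f"{name} was transferred to model_kwargs.")
                extra[name] = values.pop(name)
        for name in all_required_field_names.intersection(extra):
            # Declared fields hidden inside model_kwargs are surfaced back out.
            warnings.warn(f"Parameters like {name} should be specified explicitly.")
            values[name] = extra.pop(name)
        values["model_kwargs"] = extra
        return values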


@@ -403,6 +403,40 @@ async def test_disable_streaming_no_streaming_model_async(
         break
 
 
+def test_model_kwargs() -> None:
+    llm = FakeListChatModel(
+        responses=["a", "b", "c"],
+        sleep=0.1,
+        disable_streaming=False,
+        model_kwargs={"foo": "bar"},
+    )
+    assert llm.responses == ["a", "b", "c"]
+    assert llm.sleep == 0.1
+    assert llm.disable_streaming is False
+    assert llm.model_kwargs == {"foo": "bar"}
+
+    with pytest.warns(match="transferred to model_kwargs"):
+        llm = FakeListChatModel(
+            responses=["a", "b", "c"],
+            sleep=0.1,
+            disable_streaming=False,
+            foo="bar",  # type: ignore[call-arg]
+        )
+    assert llm.responses == ["a", "b", "c"]
+    assert llm.sleep == 0.1
+    assert llm.disable_streaming is False
+    assert llm.model_kwargs == {"foo": "bar"}
+
+    # Backward compatibility
+    with pytest.warns(match="should be specified explicitly"):
+        llm = FakeListChatModel(  # type: ignore[call-arg]
+            model_kwargs={"foo": "bar", "responses": ["a", "b", "c"], "sleep": 0.1},
+        )
+    assert llm.responses == ["a", "b", "c"]
+    assert llm.sleep == 0.1
+    assert llm.model_kwargs == {"foo": "bar"}
+
+
 class FakeChatModelStartTracer(FakeTracer):
     def __init__(self) -> None:
         super().__init__()
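
Usage, as exercised by test_model_kwargs above (a sketch; FakeListChatModel is the in-repo fake model, and the warning text is matched loosely, as in the test):

    from langchain_core.language_models.fake_chat_models import FakeListChatModel

    # An unknown constructor kwarg is collected into model_kwargs rather than
    # rejected, with a warning naming the transferred parameter.
    llm = FakeListChatModel(responses=["ok"], foo="bar")  # warns on transfer
    assert llm.model_kwargs == {"foo": "bar"}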

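The remaining hunks update stored serialization snapshots: with the new field, FakeListChatModel's repr now includes model_kwargs={}, so every snapshot that embeds that repr changes accordingly.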

@@ -97,7 +97,7 @@
         "fake_chat_models",
         "FakeListChatModel"
       ],
-      "repr": "FakeListChatModel(responses=['foo, bar'])",
+      "repr": "FakeListChatModel(model_kwargs={}, responses=['foo, bar'])",
       "name": "FakeListChatModel"
     }
   ],
@@ -227,7 +227,7 @@
         "fake_chat_models",
         "FakeListChatModel"
       ],
-      "repr": "FakeListChatModel(responses=['baz, qux'])",
+      "repr": "FakeListChatModel(model_kwargs={}, responses=['baz, qux'])",
       "name": "FakeListChatModel"
     }
   ],
@@ -346,7 +346,7 @@
         "fake_chat_models",
         "FakeListChatModel"
       ],
-      "repr": "FakeListChatModel(responses=['foo, bar'])",
+      "repr": "FakeListChatModel(model_kwargs={}, responses=['foo, bar'])",
       "name": "FakeListChatModel"
     },
     {
@@ -457,7 +457,7 @@
         "fake_chat_models",
         "FakeListChatModel"
       ],
-      "repr": "FakeListChatModel(responses=['baz, qux'])",
+      "repr": "FakeListChatModel(model_kwargs={}, responses=['baz, qux'])",
       "name": "FakeListChatModel"
     }
   ],
@@ -1009,7 +1009,7 @@
 # name: test_prompt_with_chat_model
   '''
   ChatPromptTemplate(input_variables=['question'], input_types={}, partial_variables={}, messages=[SystemMessagePromptTemplate(prompt=PromptTemplate(input_variables=[], input_types={}, partial_variables={}, template='You are a nice assistant.'), additional_kwargs={}), HumanMessagePromptTemplate(prompt=PromptTemplate(input_variables=['question'], input_types={}, partial_variables={}, template='{question}'), additional_kwargs={})])
-  | FakeListChatModel(responses=['foo'])
+  | FakeListChatModel(model_kwargs={}, responses=['foo'])
   '''
 # ---
 # name: test_prompt_with_chat_model.1
@@ -1109,7 +1109,7 @@
         "fake_chat_models",
         "FakeListChatModel"
       ],
-      "repr": "FakeListChatModel(responses=['foo'])",
+      "repr": "FakeListChatModel(model_kwargs={}, responses=['foo'])",
       "name": "FakeListChatModel"
     }
   },
@@ -1220,7 +1220,7 @@
         "fake_chat_models",
         "FakeListChatModel"
       ],
-      "repr": "FakeListChatModel(responses=['foo, bar'])",
+      "repr": "FakeListChatModel(model_kwargs={}, responses=['foo, bar'])",
       "name": "FakeListChatModel"
     }
   ],
@@ -1249,7 +1249,7 @@
 # name: test_prompt_with_chat_model_async
   '''
   ChatPromptTemplate(input_variables=['question'], input_types={}, partial_variables={}, messages=[SystemMessagePromptTemplate(prompt=PromptTemplate(input_variables=[], input_types={}, partial_variables={}, template='You are a nice assistant.'), additional_kwargs={}), HumanMessagePromptTemplate(prompt=PromptTemplate(input_variables=['question'], input_types={}, partial_variables={}, template='{question}'), additional_kwargs={})])
-  | FakeListChatModel(responses=['foo'])
+  | FakeListChatModel(model_kwargs={}, responses=['foo'])
   '''
 # ---
 # name: test_prompt_with_chat_model_async.1
@@ -1349,7 +1349,7 @@
         "fake_chat_models",
         "FakeListChatModel"
       ],
-      "repr": "FakeListChatModel(responses=['foo'])",
+      "repr": "FakeListChatModel(model_kwargs={}, responses=['foo'])",
       "name": "FakeListChatModel"
     }
   },
@@ -13535,7 +13535,7 @@
       just_to_test_lambda: RunnableLambda(...)
     }
   | ChatPromptTemplate(input_variables=['documents', 'question'], input_types={}, partial_variables={}, messages=[SystemMessagePromptTemplate(prompt=PromptTemplate(input_variables=[], input_types={}, partial_variables={}, template='You are a nice assistant.'), additional_kwargs={}), HumanMessagePromptTemplate(prompt=PromptTemplate(input_variables=['documents', 'question'], input_types={}, partial_variables={}, template='Context:\n{documents}\n\nQuestion:\n{question}'), additional_kwargs={})])
-  | FakeListChatModel(responses=['foo, bar'])
+  | FakeListChatModel(model_kwargs={}, responses=['foo, bar'])
   | CommaSeparatedListOutputParser()
   '''
 # ---
@@ -13738,7 +13738,7 @@
         "fake_chat_models",
         "FakeListChatModel"
       ],
-      "repr": "FakeListChatModel(responses=['foo, bar'])",
+      "repr": "FakeListChatModel(model_kwargs={}, responses=['foo, bar'])",
       "name": "FakeListChatModel"
     }
   ],
@@ -13764,7 +13764,7 @@
   ChatPromptTemplate(input_variables=['question'], input_types={}, partial_variables={}, messages=[SystemMessagePromptTemplate(prompt=PromptTemplate(input_variables=[], input_types={}, partial_variables={}, template='You are a nice assistant.'), additional_kwargs={}), HumanMessagePromptTemplate(prompt=PromptTemplate(input_variables=['question'], input_types={}, partial_variables={}, template='{question}'), additional_kwargs={})])
   | RunnableLambda(...)
   | {
-      chat: FakeListChatModel(responses=["i'm a chatbot"]),
+      chat: FakeListChatModel(model_kwargs={}, responses=["i'm a chatbot"]),
       llm: FakeListLLM(responses=["i'm a textbot"])
     }
   '''
@@ -13890,7 +13890,7 @@
         "fake_chat_models",
         "FakeListChatModel"
       ],
-      "repr": "FakeListChatModel(responses=[\"i'm a chatbot\"])",
+      "repr": "FakeListChatModel(model_kwargs={}, responses=[\"i'm a chatbot\"])",
       "name": "FakeListChatModel"
     },
     "llm": {
@@ -14045,7 +14045,7 @@
         "fake_chat_models",
         "FakeListChatModel"
       ],
-      "repr": "FakeListChatModel(responses=[\"i'm a chatbot\"])",
+      "repr": "FakeListChatModel(model_kwargs={}, responses=[\"i'm a chatbot\"])",
       "name": "FakeListChatModel"
     },
     "kwargs": {