implement in core

This commit is contained in:
Chester Curme 2025-04-24 16:36:59 -04:00
parent a7903280dd
commit a85e0aed5f
3 changed files with 62 additions and 15 deletions

View File

@ -11,10 +11,11 @@ from abc import ABC, abstractmethod
from collections.abc import AsyncIterator, Iterator, Sequence
from functools import cached_property
from operator import itemgetter
from typing import (
from typing import ( # noqa: UP035
TYPE_CHECKING,
Any,
Callable,
Dict,
Literal,
Optional,
Union,
@ -70,11 +71,13 @@ from langchain_core.rate_limiters import BaseRateLimiter
from langchain_core.runnables import RunnableMap, RunnablePassthrough
from langchain_core.runnables.config import ensure_config, run_in_executor
from langchain_core.tracers._streaming import _StreamingCallbackHandler
from langchain_core.utils import get_pydantic_field_names
from langchain_core.utils.function_calling import (
convert_to_json_schema,
convert_to_openai_tool,
)
from langchain_core.utils.pydantic import TypeBaseModel, is_basemodel_subclass
from langchain_core.utils.utils import _build_model_kwargs
if TYPE_CHECKING:
import uuid
@ -302,6 +305,9 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
- If False (default), will always use streaming case if available.
"""
model_kwargs: Dict[str, Any] = Field(default_factory=dict) # noqa: UP006
"""Holds any model parameters valid for `create` call not explicitly specified."""
@model_validator(mode="before")
@classmethod
def raise_deprecation(cls, values: dict) -> Any:
@ -329,6 +335,13 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
arbitrary_types_allowed=True,
)
@model_validator(mode="before")
@classmethod
def build_extra(cls, values: dict[str, Any]) -> Any:
"""Build extra kwargs from additional params that were passed in."""
all_required_field_names = get_pydantic_field_names(cls)
return _build_model_kwargs(values, all_required_field_names)
@cached_property
def _serialized(self) -> dict[str, Any]:
return dumpd(self)

View File

@ -403,6 +403,40 @@ async def test_disable_streaming_no_streaming_model_async(
break
def test_model_kwargs() -> None:
llm = FakeListChatModel(
responses=["a", "b", "c"],
sleep=0.1,
disable_streaming=False,
model_kwargs={"foo": "bar"},
)
assert llm.responses == ["a", "b", "c"]
assert llm.sleep == 0.1
assert llm.disable_streaming is False
assert llm.model_kwargs == {"foo": "bar"}
with pytest.warns(match="transferred to model_kwargs"):
llm = FakeListChatModel(
responses=["a", "b", "c"],
sleep=0.1,
disable_streaming=False,
foo="bar", # type: ignore[call-arg]
)
assert llm.responses == ["a", "b", "c"]
assert llm.sleep == 0.1
assert llm.disable_streaming is False
assert llm.model_kwargs == {"foo": "bar"}
# Backward compatibility
with pytest.warns(match="should be specified explicitly"):
llm = FakeListChatModel( # type: ignore[call-arg]
model_kwargs={"foo": "bar", "responses": ["a", "b", "c"], "sleep": 0.1},
)
assert llm.responses == ["a", "b", "c"]
assert llm.sleep == 0.1
assert llm.model_kwargs == {"foo": "bar"}
class FakeChatModelStartTracer(FakeTracer):
def __init__(self) -> None:
super().__init__()

View File

@ -97,7 +97,7 @@
"fake_chat_models",
"FakeListChatModel"
],
"repr": "FakeListChatModel(responses=['foo, bar'])",
"repr": "FakeListChatModel(model_kwargs={}, responses=['foo, bar'])",
"name": "FakeListChatModel"
}
],
@ -227,7 +227,7 @@
"fake_chat_models",
"FakeListChatModel"
],
"repr": "FakeListChatModel(responses=['baz, qux'])",
"repr": "FakeListChatModel(model_kwargs={}, responses=['baz, qux'])",
"name": "FakeListChatModel"
}
],
@ -346,7 +346,7 @@
"fake_chat_models",
"FakeListChatModel"
],
"repr": "FakeListChatModel(responses=['foo, bar'])",
"repr": "FakeListChatModel(model_kwargs={}, responses=['foo, bar'])",
"name": "FakeListChatModel"
},
{
@ -457,7 +457,7 @@
"fake_chat_models",
"FakeListChatModel"
],
"repr": "FakeListChatModel(responses=['baz, qux'])",
"repr": "FakeListChatModel(model_kwargs={}, responses=['baz, qux'])",
"name": "FakeListChatModel"
}
],
@ -1009,7 +1009,7 @@
# name: test_prompt_with_chat_model
'''
ChatPromptTemplate(input_variables=['question'], input_types={}, partial_variables={}, messages=[SystemMessagePromptTemplate(prompt=PromptTemplate(input_variables=[], input_types={}, partial_variables={}, template='You are a nice assistant.'), additional_kwargs={}), HumanMessagePromptTemplate(prompt=PromptTemplate(input_variables=['question'], input_types={}, partial_variables={}, template='{question}'), additional_kwargs={})])
| FakeListChatModel(responses=['foo'])
| FakeListChatModel(model_kwargs={}, responses=['foo'])
'''
# ---
# name: test_prompt_with_chat_model.1
@ -1109,7 +1109,7 @@
"fake_chat_models",
"FakeListChatModel"
],
"repr": "FakeListChatModel(responses=['foo'])",
"repr": "FakeListChatModel(model_kwargs={}, responses=['foo'])",
"name": "FakeListChatModel"
}
},
@ -1220,7 +1220,7 @@
"fake_chat_models",
"FakeListChatModel"
],
"repr": "FakeListChatModel(responses=['foo, bar'])",
"repr": "FakeListChatModel(model_kwargs={}, responses=['foo, bar'])",
"name": "FakeListChatModel"
}
],
@ -1249,7 +1249,7 @@
# name: test_prompt_with_chat_model_async
'''
ChatPromptTemplate(input_variables=['question'], input_types={}, partial_variables={}, messages=[SystemMessagePromptTemplate(prompt=PromptTemplate(input_variables=[], input_types={}, partial_variables={}, template='You are a nice assistant.'), additional_kwargs={}), HumanMessagePromptTemplate(prompt=PromptTemplate(input_variables=['question'], input_types={}, partial_variables={}, template='{question}'), additional_kwargs={})])
| FakeListChatModel(responses=['foo'])
| FakeListChatModel(model_kwargs={}, responses=['foo'])
'''
# ---
# name: test_prompt_with_chat_model_async.1
@ -1349,7 +1349,7 @@
"fake_chat_models",
"FakeListChatModel"
],
"repr": "FakeListChatModel(responses=['foo'])",
"repr": "FakeListChatModel(model_kwargs={}, responses=['foo'])",
"name": "FakeListChatModel"
}
},
@ -13535,7 +13535,7 @@
just_to_test_lambda: RunnableLambda(...)
}
| ChatPromptTemplate(input_variables=['documents', 'question'], input_types={}, partial_variables={}, messages=[SystemMessagePromptTemplate(prompt=PromptTemplate(input_variables=[], input_types={}, partial_variables={}, template='You are a nice assistant.'), additional_kwargs={}), HumanMessagePromptTemplate(prompt=PromptTemplate(input_variables=['documents', 'question'], input_types={}, partial_variables={}, template='Context:\n{documents}\n\nQuestion:\n{question}'), additional_kwargs={})])
| FakeListChatModel(responses=['foo, bar'])
| FakeListChatModel(model_kwargs={}, responses=['foo, bar'])
| CommaSeparatedListOutputParser()
'''
# ---
@ -13738,7 +13738,7 @@
"fake_chat_models",
"FakeListChatModel"
],
"repr": "FakeListChatModel(responses=['foo, bar'])",
"repr": "FakeListChatModel(model_kwargs={}, responses=['foo, bar'])",
"name": "FakeListChatModel"
}
],
@ -13764,7 +13764,7 @@
ChatPromptTemplate(input_variables=['question'], input_types={}, partial_variables={}, messages=[SystemMessagePromptTemplate(prompt=PromptTemplate(input_variables=[], input_types={}, partial_variables={}, template='You are a nice assistant.'), additional_kwargs={}), HumanMessagePromptTemplate(prompt=PromptTemplate(input_variables=['question'], input_types={}, partial_variables={}, template='{question}'), additional_kwargs={})])
| RunnableLambda(...)
| {
chat: FakeListChatModel(responses=["i'm a chatbot"]),
chat: FakeListChatModel(model_kwargs={}, responses=["i'm a chatbot"]),
llm: FakeListLLM(responses=["i'm a textbot"])
}
'''
@ -13890,7 +13890,7 @@
"fake_chat_models",
"FakeListChatModel"
],
"repr": "FakeListChatModel(responses=[\"i'm a chatbot\"])",
"repr": "FakeListChatModel(model_kwargs={}, responses=[\"i'm a chatbot\"])",
"name": "FakeListChatModel"
},
"llm": {
@ -14045,7 +14045,7 @@
"fake_chat_models",
"FakeListChatModel"
],
"repr": "FakeListChatModel(responses=[\"i'm a chatbot\"])",
"repr": "FakeListChatModel(model_kwargs={}, responses=[\"i'm a chatbot\"])",
"name": "FakeListChatModel"
},
"kwargs": {