mirror of
https://github.com/hwchase17/langchain.git
synced 2025-04-27 03:31:51 +00:00
implement in core
This commit is contained in:
parent
a7903280dd
commit
a85e0aed5f
@ -11,10 +11,11 @@ from abc import ABC, abstractmethod
|
||||
from collections.abc import AsyncIterator, Iterator, Sequence
|
||||
from functools import cached_property
|
||||
from operator import itemgetter
|
||||
from typing import (
|
||||
from typing import ( # noqa: UP035
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
Callable,
|
||||
Dict,
|
||||
Literal,
|
||||
Optional,
|
||||
Union,
|
||||
@ -70,11 +71,13 @@ from langchain_core.rate_limiters import BaseRateLimiter
|
||||
from langchain_core.runnables import RunnableMap, RunnablePassthrough
|
||||
from langchain_core.runnables.config import ensure_config, run_in_executor
|
||||
from langchain_core.tracers._streaming import _StreamingCallbackHandler
|
||||
from langchain_core.utils import get_pydantic_field_names
|
||||
from langchain_core.utils.function_calling import (
|
||||
convert_to_json_schema,
|
||||
convert_to_openai_tool,
|
||||
)
|
||||
from langchain_core.utils.pydantic import TypeBaseModel, is_basemodel_subclass
|
||||
from langchain_core.utils.utils import _build_model_kwargs
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import uuid
|
||||
@ -302,6 +305,9 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
|
||||
- If False (default), will always use streaming case if available.
|
||||
"""
|
||||
|
||||
model_kwargs: Dict[str, Any] = Field(default_factory=dict) # noqa: UP006
|
||||
"""Holds any model parameters valid for `create` call not explicitly specified."""
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def raise_deprecation(cls, values: dict) -> Any:
|
||||
@ -329,6 +335,13 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
|
||||
arbitrary_types_allowed=True,
|
||||
)
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def build_extra(cls, values: dict[str, Any]) -> Any:
|
||||
"""Build extra kwargs from additional params that were passed in."""
|
||||
all_required_field_names = get_pydantic_field_names(cls)
|
||||
return _build_model_kwargs(values, all_required_field_names)
|
||||
|
||||
@cached_property
|
||||
def _serialized(self) -> dict[str, Any]:
|
||||
return dumpd(self)
|
||||
|
@ -403,6 +403,40 @@ async def test_disable_streaming_no_streaming_model_async(
|
||||
break
|
||||
|
||||
|
||||
def test_model_kwargs() -> None:
|
||||
llm = FakeListChatModel(
|
||||
responses=["a", "b", "c"],
|
||||
sleep=0.1,
|
||||
disable_streaming=False,
|
||||
model_kwargs={"foo": "bar"},
|
||||
)
|
||||
assert llm.responses == ["a", "b", "c"]
|
||||
assert llm.sleep == 0.1
|
||||
assert llm.disable_streaming is False
|
||||
assert llm.model_kwargs == {"foo": "bar"}
|
||||
|
||||
with pytest.warns(match="transferred to model_kwargs"):
|
||||
llm = FakeListChatModel(
|
||||
responses=["a", "b", "c"],
|
||||
sleep=0.1,
|
||||
disable_streaming=False,
|
||||
foo="bar", # type: ignore[call-arg]
|
||||
)
|
||||
assert llm.responses == ["a", "b", "c"]
|
||||
assert llm.sleep == 0.1
|
||||
assert llm.disable_streaming is False
|
||||
assert llm.model_kwargs == {"foo": "bar"}
|
||||
|
||||
# Backward compatibility
|
||||
with pytest.warns(match="should be specified explicitly"):
|
||||
llm = FakeListChatModel( # type: ignore[call-arg]
|
||||
model_kwargs={"foo": "bar", "responses": ["a", "b", "c"], "sleep": 0.1},
|
||||
)
|
||||
assert llm.responses == ["a", "b", "c"]
|
||||
assert llm.sleep == 0.1
|
||||
assert llm.model_kwargs == {"foo": "bar"}
|
||||
|
||||
|
||||
class FakeChatModelStartTracer(FakeTracer):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
@ -97,7 +97,7 @@
|
||||
"fake_chat_models",
|
||||
"FakeListChatModel"
|
||||
],
|
||||
"repr": "FakeListChatModel(responses=['foo, bar'])",
|
||||
"repr": "FakeListChatModel(model_kwargs={}, responses=['foo, bar'])",
|
||||
"name": "FakeListChatModel"
|
||||
}
|
||||
],
|
||||
@ -227,7 +227,7 @@
|
||||
"fake_chat_models",
|
||||
"FakeListChatModel"
|
||||
],
|
||||
"repr": "FakeListChatModel(responses=['baz, qux'])",
|
||||
"repr": "FakeListChatModel(model_kwargs={}, responses=['baz, qux'])",
|
||||
"name": "FakeListChatModel"
|
||||
}
|
||||
],
|
||||
@ -346,7 +346,7 @@
|
||||
"fake_chat_models",
|
||||
"FakeListChatModel"
|
||||
],
|
||||
"repr": "FakeListChatModel(responses=['foo, bar'])",
|
||||
"repr": "FakeListChatModel(model_kwargs={}, responses=['foo, bar'])",
|
||||
"name": "FakeListChatModel"
|
||||
},
|
||||
{
|
||||
@ -457,7 +457,7 @@
|
||||
"fake_chat_models",
|
||||
"FakeListChatModel"
|
||||
],
|
||||
"repr": "FakeListChatModel(responses=['baz, qux'])",
|
||||
"repr": "FakeListChatModel(model_kwargs={}, responses=['baz, qux'])",
|
||||
"name": "FakeListChatModel"
|
||||
}
|
||||
],
|
||||
@ -1009,7 +1009,7 @@
|
||||
# name: test_prompt_with_chat_model
|
||||
'''
|
||||
ChatPromptTemplate(input_variables=['question'], input_types={}, partial_variables={}, messages=[SystemMessagePromptTemplate(prompt=PromptTemplate(input_variables=[], input_types={}, partial_variables={}, template='You are a nice assistant.'), additional_kwargs={}), HumanMessagePromptTemplate(prompt=PromptTemplate(input_variables=['question'], input_types={}, partial_variables={}, template='{question}'), additional_kwargs={})])
|
||||
| FakeListChatModel(responses=['foo'])
|
||||
| FakeListChatModel(model_kwargs={}, responses=['foo'])
|
||||
'''
|
||||
# ---
|
||||
# name: test_prompt_with_chat_model.1
|
||||
@ -1109,7 +1109,7 @@
|
||||
"fake_chat_models",
|
||||
"FakeListChatModel"
|
||||
],
|
||||
"repr": "FakeListChatModel(responses=['foo'])",
|
||||
"repr": "FakeListChatModel(model_kwargs={}, responses=['foo'])",
|
||||
"name": "FakeListChatModel"
|
||||
}
|
||||
},
|
||||
@ -1220,7 +1220,7 @@
|
||||
"fake_chat_models",
|
||||
"FakeListChatModel"
|
||||
],
|
||||
"repr": "FakeListChatModel(responses=['foo, bar'])",
|
||||
"repr": "FakeListChatModel(model_kwargs={}, responses=['foo, bar'])",
|
||||
"name": "FakeListChatModel"
|
||||
}
|
||||
],
|
||||
@ -1249,7 +1249,7 @@
|
||||
# name: test_prompt_with_chat_model_async
|
||||
'''
|
||||
ChatPromptTemplate(input_variables=['question'], input_types={}, partial_variables={}, messages=[SystemMessagePromptTemplate(prompt=PromptTemplate(input_variables=[], input_types={}, partial_variables={}, template='You are a nice assistant.'), additional_kwargs={}), HumanMessagePromptTemplate(prompt=PromptTemplate(input_variables=['question'], input_types={}, partial_variables={}, template='{question}'), additional_kwargs={})])
|
||||
| FakeListChatModel(responses=['foo'])
|
||||
| FakeListChatModel(model_kwargs={}, responses=['foo'])
|
||||
'''
|
||||
# ---
|
||||
# name: test_prompt_with_chat_model_async.1
|
||||
@ -1349,7 +1349,7 @@
|
||||
"fake_chat_models",
|
||||
"FakeListChatModel"
|
||||
],
|
||||
"repr": "FakeListChatModel(responses=['foo'])",
|
||||
"repr": "FakeListChatModel(model_kwargs={}, responses=['foo'])",
|
||||
"name": "FakeListChatModel"
|
||||
}
|
||||
},
|
||||
@ -13535,7 +13535,7 @@
|
||||
just_to_test_lambda: RunnableLambda(...)
|
||||
}
|
||||
| ChatPromptTemplate(input_variables=['documents', 'question'], input_types={}, partial_variables={}, messages=[SystemMessagePromptTemplate(prompt=PromptTemplate(input_variables=[], input_types={}, partial_variables={}, template='You are a nice assistant.'), additional_kwargs={}), HumanMessagePromptTemplate(prompt=PromptTemplate(input_variables=['documents', 'question'], input_types={}, partial_variables={}, template='Context:\n{documents}\n\nQuestion:\n{question}'), additional_kwargs={})])
|
||||
| FakeListChatModel(responses=['foo, bar'])
|
||||
| FakeListChatModel(model_kwargs={}, responses=['foo, bar'])
|
||||
| CommaSeparatedListOutputParser()
|
||||
'''
|
||||
# ---
|
||||
@ -13738,7 +13738,7 @@
|
||||
"fake_chat_models",
|
||||
"FakeListChatModel"
|
||||
],
|
||||
"repr": "FakeListChatModel(responses=['foo, bar'])",
|
||||
"repr": "FakeListChatModel(model_kwargs={}, responses=['foo, bar'])",
|
||||
"name": "FakeListChatModel"
|
||||
}
|
||||
],
|
||||
@ -13764,7 +13764,7 @@
|
||||
ChatPromptTemplate(input_variables=['question'], input_types={}, partial_variables={}, messages=[SystemMessagePromptTemplate(prompt=PromptTemplate(input_variables=[], input_types={}, partial_variables={}, template='You are a nice assistant.'), additional_kwargs={}), HumanMessagePromptTemplate(prompt=PromptTemplate(input_variables=['question'], input_types={}, partial_variables={}, template='{question}'), additional_kwargs={})])
|
||||
| RunnableLambda(...)
|
||||
| {
|
||||
chat: FakeListChatModel(responses=["i'm a chatbot"]),
|
||||
chat: FakeListChatModel(model_kwargs={}, responses=["i'm a chatbot"]),
|
||||
llm: FakeListLLM(responses=["i'm a textbot"])
|
||||
}
|
||||
'''
|
||||
@ -13890,7 +13890,7 @@
|
||||
"fake_chat_models",
|
||||
"FakeListChatModel"
|
||||
],
|
||||
"repr": "FakeListChatModel(responses=[\"i'm a chatbot\"])",
|
||||
"repr": "FakeListChatModel(model_kwargs={}, responses=[\"i'm a chatbot\"])",
|
||||
"name": "FakeListChatModel"
|
||||
},
|
||||
"llm": {
|
||||
@ -14045,7 +14045,7 @@
|
||||
"fake_chat_models",
|
||||
"FakeListChatModel"
|
||||
],
|
||||
"repr": "FakeListChatModel(responses=[\"i'm a chatbot\"])",
|
||||
"repr": "FakeListChatModel(model_kwargs={}, responses=[\"i'm a chatbot\"])",
|
||||
"name": "FakeListChatModel"
|
||||
},
|
||||
"kwargs": {
|
||||
|
Loading…
Reference in New Issue
Block a user