core, standard tests, partner packages: add test for model params (#21677)
1. Adds `_get_ls_params` to `BaseChatModel`, which returns

```python
class LangSmithParams(TypedDict, total=False):
    ls_provider: str
    ls_model_name: str
    ls_model_type: Literal["chat"]
    ls_temperature: Optional[float]
    ls_max_tokens: Optional[int]
    ls_stop: Optional[List[str]]
```

By default it returns only

```python
{"ls_model_type": "chat", "ls_stop": stop}
```

2. Adds these params to the inheritable metadata in `CallbackManager.configure`.

3. Implements `_get_ls_params` and populates all params for Anthropic and all subclasses of `BaseChatOpenAI`.

Sample trace: https://smith.langchain.com/public/d2962673-4c83-47c7-b51e-61d07aaffb1b/r

**OpenAI**:
<img width="984" alt="Screenshot 2024-05-17 at 10 03 35 AM" src="https://github.com/langchain-ai/langchain/assets/26529506/2ef41f74-a9df-4e0e-905d-da74fa82a910">

**Anthropic**:
<img width="978" alt="Screenshot 2024-05-17 at 10 06 07 AM" src="https://github.com/langchain-ai/langchain/assets/26529506/39701c9f-7da5-4f1a-ab14-84e9169d63e7">

**Mistral** (and all others for which params are not yet populated):
<img width="977" alt="Screenshot 2024-05-17 at 10 08 43 AM" src="https://github.com/langchain-ai/langchain/assets/26529506/37d7d894-fec2-4300-986f-49a5f0191b03">
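For reference, a minimal sketch of what the hook yields once a provider implementation fills it in (illustrative model and values, following the `BaseChatOpenAI` implementation in this diff):

```python
from langchain_openai import ChatOpenAI

llm = ChatOpenAI(model="gpt-3.5-turbo", temperature=0.2, max_tokens=100)
print(llm._get_ls_params(stop=["\n"]))
# Expected shape (actual values depend on configuration):
# {"ls_provider": "openai", "ls_model_name": "gpt-3.5-turbo",
#  "ls_model_type": "chat", "ls_temperature": 0.2,
#  "ls_max_tokens": 100, "ls_stop": ["\n"]}
```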
This commit is contained in:
parent 4ca2149b70
commit 181dfef118
@@ -13,6 +13,7 @@ from typing import (
     Dict,
     Iterator,
     List,
+    Literal,
     Optional,
     Sequence,
     Type,
@@ -20,6 +21,8 @@ from typing import (
     cast,
 )
 
+from typing_extensions import TypedDict
+
 from langchain_core._api import deprecated
 from langchain_core.caches import BaseCache
 from langchain_core.callbacks import (
@@ -60,6 +63,15 @@ if TYPE_CHECKING:
     from langchain_core.tools import BaseTool
 
 
+class LangSmithParams(TypedDict, total=False):
+    ls_provider: str
+    ls_model_name: str
+    ls_model_type: Literal["chat"]
+    ls_temperature: Optional[float]
+    ls_max_tokens: Optional[int]
+    ls_stop: Optional[List[str]]
+
+
 def generate_from_stream(stream: Iterator[ChatGenerationChunk]) -> ChatResult:
     """Generate from a stream."""
 
@@ -206,13 +218,17 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
         messages = self._convert_input(input).to_messages()
         params = self._get_invocation_params(stop=stop, **kwargs)
         options = {"stop": stop, **kwargs}
+        inheritable_metadata = {
+            **(config.get("metadata") or {}),
+            **self._get_ls_params(stop=stop, **kwargs),
+        }
         callback_manager = CallbackManager.configure(
             config.get("callbacks"),
             self.callbacks,
             self.verbose,
             config.get("tags"),
             self.tags,
-            config.get("metadata"),
+            inheritable_metadata,
             self.metadata,
         )
         (run_manager,) = callback_manager.on_chat_model_start(
@@ -273,13 +289,17 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
         messages = self._convert_input(input).to_messages()
         params = self._get_invocation_params(stop=stop, **kwargs)
         options = {"stop": stop, **kwargs}
+        inheritable_metadata = {
+            **(config.get("metadata") or {}),
+            **self._get_ls_params(stop=stop, **kwargs),
+        }
         callback_manager = AsyncCallbackManager.configure(
             config.get("callbacks"),
             self.callbacks,
             self.verbose,
             config.get("tags"),
             self.tags,
-            config.get("metadata"),
+            inheritable_metadata,
             self.metadata,
         )
         (run_manager,) = await callback_manager.on_chat_model_start(
@@ -336,6 +356,17 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
             params["stop"] = stop
         return {**params, **kwargs}
 
+    def _get_ls_params(
+        self,
+        stop: Optional[List[str]] = None,
+        **kwargs: Any,
+    ) -> LangSmithParams:
+        """Get standard params for tracing."""
+        ls_params = LangSmithParams(ls_model_type="chat")
+        if stop:
+            ls_params["ls_stop"] = stop
+        return ls_params
+
     def _get_llm_string(self, stop: Optional[List[str]] = None, **kwargs: Any) -> str:
         if self.is_lc_serializable():
             params = {**kwargs, **{"stop": stop}}
@@ -385,6 +416,10 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
         """
         params = self._get_invocation_params(stop=stop, **kwargs)
         options = {"stop": stop}
+        inheritable_metadata = {
+            **(metadata or {}),
+            **self._get_ls_params(stop=stop, **kwargs),
+        }
 
         callback_manager = CallbackManager.configure(
             callbacks,
@@ -392,7 +427,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
             self.verbose,
             tags,
             self.tags,
-            metadata,
+            inheritable_metadata,
             self.metadata,
         )
         run_managers = callback_manager.on_chat_model_start(
@@ -472,6 +507,10 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
         """
         params = self._get_invocation_params(stop=stop, **kwargs)
         options = {"stop": stop}
+        inheritable_metadata = {
+            **(metadata or {}),
+            **self._get_ls_params(stop=stop, **kwargs),
+        }
 
         callback_manager = AsyncCallbackManager.configure(
             callbacks,
@@ -479,7 +518,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
             self.verbose,
             tags,
             self.tags,
-            metadata,
+            inheritable_metadata,
             self.metadata,
         )
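Because the standard params are merged into the inheritable metadata before `CallbackManager.configure`, any callback handler receives them through its `metadata` argument. A minimal sketch of observing them (`LsParamsLogger` is a hypothetical handler built on the public `BaseCallbackHandler` API):

```python
from typing import Any, Dict, List, Optional
from uuid import UUID

from langchain_core.callbacks import BaseCallbackHandler
from langchain_core.messages import BaseMessage


class LsParamsLogger(BaseCallbackHandler):
    """Print the ls_* tracing params attached to each chat model run."""

    def on_chat_model_start(
        self,
        serialized: Dict[str, Any],
        messages: List[List[BaseMessage]],
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        tags: Optional[List[str]] = None,
        metadata: Optional[Dict[str, Any]] = None,
        **kwargs: Any,
    ) -> None:
        # Standard params arrive alongside any user-supplied metadata.
        ls_params = {k: v for k, v in (metadata or {}).items() if k.startswith("ls_")}
        print(f"run {run_id}: {ls_params}")
```

Attaching it via `model.invoke("hi", config={"callbacks": [LsParamsLogger()]})` should print the `ls_*` keys for that run.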
File diff suppressed because one or more lines are too long
@@ -476,7 +476,7 @@ async def test_astream_events_from_model() -> None:
         {
             "data": {"input": {"messages": [[HumanMessage(content="hello")]]}},
             "event": "on_chat_model_start",
-            "metadata": {"a": "b"},
+            "metadata": {"a": "b", "ls_model_type": "chat", "ls_stop": "<stop_token>"},
             "name": "my_model",
             "run_id": "",
             "tags": ["my_model"],
@@ -484,7 +484,7 @@ async def test_astream_events_from_model() -> None:
         {
             "data": {"chunk": AIMessageChunk(content="hello", id=AnyStr())},
             "event": "on_chat_model_stream",
-            "metadata": {"a": "b"},
+            "metadata": {"a": "b", "ls_model_type": "chat", "ls_stop": "<stop_token>"},
             "name": "my_model",
             "run_id": "",
             "tags": ["my_model"],
@@ -492,7 +492,7 @@ async def test_astream_events_from_model() -> None:
         {
             "data": {"chunk": AIMessageChunk(content=" ", id=AnyStr())},
             "event": "on_chat_model_stream",
-            "metadata": {"a": "b"},
+            "metadata": {"a": "b", "ls_model_type": "chat", "ls_stop": "<stop_token>"},
             "name": "my_model",
             "run_id": "",
             "tags": ["my_model"],
@@ -500,7 +500,7 @@ async def test_astream_events_from_model() -> None:
         {
             "data": {"chunk": AIMessageChunk(content="world!", id=AnyStr())},
             "event": "on_chat_model_stream",
-            "metadata": {"a": "b"},
+            "metadata": {"a": "b", "ls_model_type": "chat", "ls_stop": "<stop_token>"},
             "name": "my_model",
             "run_id": "",
             "tags": ["my_model"],
@@ -526,7 +526,7 @@ async def test_astream_events_from_model() -> None:
                 },
             },
             "event": "on_chat_model_end",
-            "metadata": {"a": "b"},
+            "metadata": {"a": "b", "ls_model_type": "chat", "ls_stop": "<stop_token>"},
             "name": "my_model",
             "run_id": "",
             "tags": ["my_model"],
@@ -569,7 +569,7 @@ async def test_astream_events_from_model() -> None:
         {
             "data": {"input": {"messages": [[HumanMessage(content="hello")]]}},
             "event": "on_chat_model_start",
-            "metadata": {"a": "b"},
+            "metadata": {"a": "b", "ls_model_type": "chat", "ls_stop": "<stop_token>"},
             "name": "my_model",
             "run_id": "",
             "tags": ["my_model"],
@@ -577,7 +577,7 @@ async def test_astream_events_from_model() -> None:
         {
             "data": {"chunk": AIMessageChunk(content="hello", id=AnyStr())},
             "event": "on_chat_model_stream",
-            "metadata": {"a": "b"},
+            "metadata": {"a": "b", "ls_model_type": "chat", "ls_stop": "<stop_token>"},
             "name": "my_model",
             "run_id": "",
             "tags": ["my_model"],
@@ -585,7 +585,7 @@ async def test_astream_events_from_model() -> None:
         {
             "data": {"chunk": AIMessageChunk(content=" ", id=AnyStr())},
             "event": "on_chat_model_stream",
-            "metadata": {"a": "b"},
+            "metadata": {"a": "b", "ls_model_type": "chat", "ls_stop": "<stop_token>"},
             "name": "my_model",
             "run_id": "",
             "tags": ["my_model"],
@@ -593,7 +593,7 @@ async def test_astream_events_from_model() -> None:
         {
             "data": {"chunk": AIMessageChunk(content="world!", id=AnyStr())},
             "event": "on_chat_model_stream",
-            "metadata": {"a": "b"},
+            "metadata": {"a": "b", "ls_model_type": "chat", "ls_stop": "<stop_token>"},
             "name": "my_model",
             "run_id": "",
             "tags": ["my_model"],
@@ -619,7 +619,7 @@ async def test_astream_events_from_model() -> None:
                 },
             },
             "event": "on_chat_model_end",
-            "metadata": {"a": "b"},
+            "metadata": {"a": "b", "ls_model_type": "chat", "ls_stop": "<stop_token>"},
             "name": "my_model",
             "run_id": "",
             "tags": ["my_model"],
@@ -724,7 +724,12 @@ async def test_event_stream_with_simple_chain() -> None:
                 }
             },
             "event": "on_chat_model_start",
-            "metadata": {"a": "b", "foo": "bar"},
+            "metadata": {
+                "a": "b",
+                "foo": "bar",
+                "ls_model_type": "chat",
+                "ls_stop": "<stop_token>",
+            },
             "name": "my_model",
             "run_id": "",
             "tags": ["my_chain", "my_model", "seq:step:2"],
@@ -732,7 +737,12 @@ async def test_event_stream_with_simple_chain() -> None:
         {
             "data": {"chunk": AIMessageChunk(content="hello", id="ai1")},
             "event": "on_chat_model_stream",
-            "metadata": {"a": "b", "foo": "bar"},
+            "metadata": {
+                "a": "b",
+                "foo": "bar",
+                "ls_model_type": "chat",
+                "ls_stop": "<stop_token>",
+            },
             "name": "my_model",
             "run_id": "",
             "tags": ["my_chain", "my_model", "seq:step:2"],
@@ -748,7 +758,12 @@ async def test_event_stream_with_simple_chain() -> None:
         {
             "data": {"chunk": AIMessageChunk(content=" ", id="ai1")},
             "event": "on_chat_model_stream",
-            "metadata": {"a": "b", "foo": "bar"},
+            "metadata": {
+                "a": "b",
+                "foo": "bar",
+                "ls_model_type": "chat",
+                "ls_stop": "<stop_token>",
+            },
             "name": "my_model",
             "run_id": "",
             "tags": ["my_chain", "my_model", "seq:step:2"],
@@ -764,7 +779,12 @@ async def test_event_stream_with_simple_chain() -> None:
         {
             "data": {"chunk": AIMessageChunk(content="world!", id="ai1")},
             "event": "on_chat_model_stream",
-            "metadata": {"a": "b", "foo": "bar"},
+            "metadata": {
+                "a": "b",
+                "foo": "bar",
+                "ls_model_type": "chat",
+                "ls_stop": "<stop_token>",
+            },
             "name": "my_model",
             "run_id": "",
             "tags": ["my_chain", "my_model", "seq:step:2"],
@@ -805,7 +825,12 @@ async def test_event_stream_with_simple_chain() -> None:
                 },
             },
             "event": "on_chat_model_end",
-            "metadata": {"a": "b", "foo": "bar"},
+            "metadata": {
+                "a": "b",
+                "foo": "bar",
+                "ls_model_type": "chat",
+                "ls_stop": "<stop_token>",
+            },
             "name": "my_model",
             "run_id": "",
             "tags": ["my_chain", "my_model", "seq:step:2"],
@@ -417,7 +417,7 @@ async def test_astream_events_from_model() -> None:
         {
             "data": {"input": "hello"},
             "event": "on_chat_model_start",
-            "metadata": {"a": "b"},
+            "metadata": {"a": "b", "ls_model_type": "chat", "ls_stop": "<stop_token>"},
             "name": "my_model",
             "run_id": "",
             "tags": ["my_model"],
@@ -425,7 +425,7 @@ async def test_astream_events_from_model() -> None:
         {
             "data": {"chunk": AIMessageChunk(content="hello", id=AnyStr())},
             "event": "on_chat_model_stream",
-            "metadata": {"a": "b"},
+            "metadata": {"a": "b", "ls_model_type": "chat", "ls_stop": "<stop_token>"},
             "name": "my_model",
             "run_id": "",
             "tags": ["my_model"],
@@ -433,7 +433,7 @@ async def test_astream_events_from_model() -> None:
         {
             "data": {"chunk": AIMessageChunk(content=" ", id=AnyStr())},
             "event": "on_chat_model_stream",
-            "metadata": {"a": "b"},
+            "metadata": {"a": "b", "ls_model_type": "chat", "ls_stop": "<stop_token>"},
             "name": "my_model",
             "run_id": "",
             "tags": ["my_model"],
@@ -441,7 +441,7 @@ async def test_astream_events_from_model() -> None:
         {
             "data": {"chunk": AIMessageChunk(content="world!", id=AnyStr())},
             "event": "on_chat_model_stream",
-            "metadata": {"a": "b"},
+            "metadata": {"a": "b", "ls_model_type": "chat", "ls_stop": "<stop_token>"},
             "name": "my_model",
             "run_id": "",
             "tags": ["my_model"],
@@ -451,7 +451,7 @@ async def test_astream_events_from_model() -> None:
                 "output": AIMessageChunk(content="hello world!", id=AnyStr()),
             },
             "event": "on_chat_model_end",
-            "metadata": {"a": "b"},
+            "metadata": {"a": "b", "ls_model_type": "chat", "ls_stop": "<stop_token>"},
             "name": "my_model",
             "run_id": "",
             "tags": ["my_model"],
@@ -495,7 +495,7 @@ async def test_astream_with_model_in_chain() -> None:
         {
             "data": {"input": {"messages": [[HumanMessage(content="hello")]]}},
             "event": "on_chat_model_start",
-            "metadata": {"a": "b"},
+            "metadata": {"a": "b", "ls_model_type": "chat", "ls_stop": "<stop_token>"},
             "name": "my_model",
             "run_id": "",
             "tags": ["my_model"],
@@ -503,7 +503,7 @@ async def test_astream_with_model_in_chain() -> None:
         {
             "data": {"chunk": AIMessageChunk(content="hello", id=AnyStr())},
             "event": "on_chat_model_stream",
-            "metadata": {"a": "b"},
+            "metadata": {"a": "b", "ls_model_type": "chat", "ls_stop": "<stop_token>"},
             "name": "my_model",
             "run_id": "",
             "tags": ["my_model"],
@@ -511,7 +511,7 @@ async def test_astream_with_model_in_chain() -> None:
         {
             "data": {"chunk": AIMessageChunk(content=" ", id=AnyStr())},
             "event": "on_chat_model_stream",
-            "metadata": {"a": "b"},
+            "metadata": {"a": "b", "ls_model_type": "chat", "ls_stop": "<stop_token>"},
             "name": "my_model",
             "run_id": "",
             "tags": ["my_model"],
@@ -519,7 +519,7 @@ async def test_astream_with_model_in_chain() -> None:
         {
             "data": {"chunk": AIMessageChunk(content="world!", id=AnyStr())},
             "event": "on_chat_model_stream",
-            "metadata": {"a": "b"},
+            "metadata": {"a": "b", "ls_model_type": "chat", "ls_stop": "<stop_token>"},
             "name": "my_model",
             "run_id": "",
             "tags": ["my_model"],
@@ -530,7 +530,7 @@ async def test_astream_with_model_in_chain() -> None:
                 "output": AIMessage(content="hello world!", id=AnyStr()),
             },
             "event": "on_chat_model_end",
-            "metadata": {"a": "b"},
+            "metadata": {"a": "b", "ls_model_type": "chat", "ls_stop": "<stop_token>"},
             "name": "my_model",
             "run_id": "",
             "tags": ["my_model"],
@@ -573,7 +573,7 @@ async def test_astream_with_model_in_chain() -> None:
         {
             "data": {"input": {"messages": [[HumanMessage(content="hello")]]}},
             "event": "on_chat_model_start",
-            "metadata": {"a": "b"},
+            "metadata": {"a": "b", "ls_model_type": "chat", "ls_stop": "<stop_token>"},
             "name": "my_model",
             "run_id": "",
             "tags": ["my_model"],
@@ -581,7 +581,7 @@ async def test_astream_with_model_in_chain() -> None:
         {
             "data": {"chunk": AIMessageChunk(content="hello", id=AnyStr())},
             "event": "on_chat_model_stream",
-            "metadata": {"a": "b"},
+            "metadata": {"a": "b", "ls_model_type": "chat", "ls_stop": "<stop_token>"},
             "name": "my_model",
             "run_id": "",
             "tags": ["my_model"],
@@ -589,7 +589,7 @@ async def test_astream_with_model_in_chain() -> None:
         {
             "data": {"chunk": AIMessageChunk(content=" ", id=AnyStr())},
             "event": "on_chat_model_stream",
-            "metadata": {"a": "b"},
+            "metadata": {"a": "b", "ls_model_type": "chat", "ls_stop": "<stop_token>"},
             "name": "my_model",
             "run_id": "",
             "tags": ["my_model"],
@@ -597,7 +597,7 @@ async def test_astream_with_model_in_chain() -> None:
         {
             "data": {"chunk": AIMessageChunk(content="world!", id=AnyStr())},
             "event": "on_chat_model_stream",
-            "metadata": {"a": "b"},
+            "metadata": {"a": "b", "ls_model_type": "chat", "ls_stop": "<stop_token>"},
             "name": "my_model",
             "run_id": "",
             "tags": ["my_model"],
@@ -608,7 +608,7 @@ async def test_astream_with_model_in_chain() -> None:
                 "output": AIMessage(content="hello world!", id=AnyStr()),
             },
             "event": "on_chat_model_end",
-            "metadata": {"a": "b"},
+            "metadata": {"a": "b", "ls_model_type": "chat", "ls_stop": "<stop_token>"},
             "name": "my_model",
             "run_id": "",
             "tags": ["my_model"],
@@ -713,7 +713,12 @@ async def test_event_stream_with_simple_chain() -> None:
                 }
             },
             "event": "on_chat_model_start",
-            "metadata": {"a": "b", "foo": "bar"},
+            "metadata": {
+                "a": "b",
+                "foo": "bar",
+                "ls_model_type": "chat",
+                "ls_stop": "<stop_token>",
+            },
             "name": "my_model",
             "run_id": "",
             "tags": ["my_chain", "my_model", "seq:step:2"],
@@ -721,7 +726,12 @@ async def test_event_stream_with_simple_chain() -> None:
         {
             "data": {"chunk": AIMessageChunk(content="hello", id="ai1")},
             "event": "on_chat_model_stream",
-            "metadata": {"a": "b", "foo": "bar"},
+            "metadata": {
+                "a": "b",
+                "foo": "bar",
+                "ls_model_type": "chat",
+                "ls_stop": "<stop_token>",
+            },
             "name": "my_model",
             "run_id": "",
             "tags": ["my_chain", "my_model", "seq:step:2"],
@@ -737,7 +747,12 @@ async def test_event_stream_with_simple_chain() -> None:
         {
             "data": {"chunk": AIMessageChunk(content=" ", id="ai1")},
             "event": "on_chat_model_stream",
-            "metadata": {"a": "b", "foo": "bar"},
+            "metadata": {
+                "a": "b",
+                "foo": "bar",
+                "ls_model_type": "chat",
+                "ls_stop": "<stop_token>",
+            },
             "name": "my_model",
             "run_id": "",
             "tags": ["my_chain", "my_model", "seq:step:2"],
@@ -753,7 +768,12 @@ async def test_event_stream_with_simple_chain() -> None:
         {
             "data": {"chunk": AIMessageChunk(content="world!", id="ai1")},
             "event": "on_chat_model_stream",
-            "metadata": {"a": "b", "foo": "bar"},
+            "metadata": {
+                "a": "b",
+                "foo": "bar",
+                "ls_model_type": "chat",
+                "ls_stop": "<stop_token>",
+            },
             "name": "my_model",
             "run_id": "",
             "tags": ["my_chain", "my_model", "seq:step:2"],
@@ -779,7 +799,12 @@ async def test_event_stream_with_simple_chain() -> None:
                 "output": AIMessageChunk(content="hello world!", id="ai1"),
             },
             "event": "on_chat_model_end",
-            "metadata": {"a": "b", "foo": "bar"},
+            "metadata": {
+                "a": "b",
+                "foo": "bar",
+                "ls_model_type": "chat",
+                "ls_stop": "<stop_token>",
+            },
             "name": "my_model",
             "run_id": "",
             "tags": ["my_chain", "my_model", "seq:step:2"],
@@ -1459,7 +1484,7 @@ async def test_events_astream_config() -> None:
         {
             "data": {"input": "hello"},
             "event": "on_chat_model_start",
-            "metadata": {},
+            "metadata": {"ls_model_type": "chat"},
             "name": "GenericFakeChatModel",
             "run_id": "",
             "tags": [],
@@ -1467,7 +1492,7 @@ async def test_events_astream_config() -> None:
         {
             "data": {"chunk": AIMessageChunk(content="Goodbye", id="ai2")},
             "event": "on_chat_model_stream",
-            "metadata": {},
+            "metadata": {"ls_model_type": "chat"},
             "name": "GenericFakeChatModel",
             "run_id": "",
             "tags": [],
@@ -1475,7 +1500,7 @@ async def test_events_astream_config() -> None:
         {
             "data": {"chunk": AIMessageChunk(content=" ", id="ai2")},
             "event": "on_chat_model_stream",
-            "metadata": {},
+            "metadata": {"ls_model_type": "chat"},
             "name": "GenericFakeChatModel",
             "run_id": "",
             "tags": [],
@@ -1483,7 +1508,7 @@ async def test_events_astream_config() -> None:
         {
             "data": {"chunk": AIMessageChunk(content="world", id="ai2")},
             "event": "on_chat_model_stream",
-            "metadata": {},
+            "metadata": {"ls_model_type": "chat"},
             "name": "GenericFakeChatModel",
             "run_id": "",
             "tags": [],
@@ -1493,7 +1518,7 @@ async def test_events_astream_config() -> None:
                 "output": AIMessageChunk(content="Goodbye world", id="ai2"),
             },
             "event": "on_chat_model_end",
-            "metadata": {},
+            "metadata": {"ls_model_type": "chat"},
             "name": "GenericFakeChatModel",
             "run_id": "",
             "tags": [],
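The updated expectations above reflect that `astream_events` now surfaces the standard params in every `on_chat_model_*` event's metadata. A sketch of consuming them from the stream (assumes `GenericFakeChatModel` is importable from `langchain_core.language_models.fake_chat_models`, as in these tests; the fake model only populates the default param):

```python
import asyncio

from langchain_core.language_models.fake_chat_models import GenericFakeChatModel
from langchain_core.messages import AIMessage


async def main() -> None:
    model = GenericFakeChatModel(messages=iter([AIMessage(content="hello world")]))
    async for event in model.astream_events("hello", version="v2"):
        if event["event"].startswith("on_chat_model"):
            # The base implementation only sets ls_model_type
            # (plus ls_stop when stop sequences are passed).
            assert event["metadata"]["ls_model_type"] == "chat"


asyncio.run(main())
```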
@@ -21,6 +21,17 @@ class TestAI21J2(ChatModelUnitTests):
             "api_key": "test_api_key",
         }
 
+    @pytest.mark.xfail(reason="Not implemented.")
+    def test_standard_params(
+        self,
+        chat_model_class: Type[BaseChatModel],
+        chat_model_params: dict,
+    ) -> None:
+        super().test_standard_params(
+            chat_model_class,
+            chat_model_params,
+        )
+
 
 class TestAI21Jamba(ChatModelUnitTests):
     @pytest.fixture
@@ -33,3 +44,14 @@ class TestAI21Jamba(ChatModelUnitTests):
             "model": "jamba-instruct",
             "api_key": "test_api_key",
         }
+
+    @pytest.mark.xfail(reason="Not implemented.")
+    def test_standard_params(
+        self,
+        chat_model_class: Type[BaseChatModel],
+        chat_model_params: dict,
+    ) -> None:
+        super().test_standard_params(
+            chat_model_class,
+            chat_model_params,
+        )
@@ -30,6 +30,7 @@ from langchain_core.callbacks import (
 from langchain_core.language_models import LanguageModelInput
 from langchain_core.language_models.chat_models import (
     BaseChatModel,
+    LangSmithParams,
     agenerate_from_stream,
     generate_from_stream,
 )
@@ -326,6 +327,23 @@ class ChatAnthropic(BaseChatModel):
             "default_request_timeout": self.default_request_timeout,
         }
 
+    def _get_ls_params(
+        self, stop: Optional[List[str]] = None, **kwargs: Any
+    ) -> LangSmithParams:
+        """Get the parameters used to invoke the model."""
+        params = self._get_invocation_params(stop=stop, **kwargs)
+        ls_params = LangSmithParams(
+            ls_provider="anthropic",
+            ls_model_name=self.model,
+            ls_model_type="chat",
+            ls_temperature=params.get("temperature", self.temperature),
+        )
+        if ls_max_tokens := params.get("max_tokens", self.max_tokens):
+            ls_params["ls_max_tokens"] = ls_max_tokens
+        if ls_stop := stop or params.get("stop", None):
+            ls_params["ls_stop"] = ls_stop
+        return ls_params
+
     @root_validator(pre=True)
     def build_extra(cls, values: Dict) -> Dict:
         extra = values.get("model_kwargs", {})
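Note the walrus-operator guards: unset values are omitted from the dict rather than recorded as `None`. For a configured Anthropic model this returns something like the following (illustrative values; `ls_max_tokens` picks up the class default, and `ls_stop` is absent because no stop sequences were set):

```python
from langchain_anthropic import ChatAnthropic

llm = ChatAnthropic(model="claude-3-sonnet-20240229", temperature=0.1, api_key="test")
print(llm._get_ls_params())
# Expected shape (values depend on configuration):
# {"ls_provider": "anthropic", "ls_model_name": "claude-3-sonnet-20240229",
#  "ls_model_type": "chat", "ls_temperature": 0.1, "ls_max_tokens": 1024}
```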
@@ -19,3 +19,14 @@ class TestFireworksStandard(ChatModelUnitTests):
         return {
             "api_key": "test_api_key",
         }
+
+    @pytest.mark.xfail(reason="Not implemented.")
+    def test_standard_params(
+        self,
+        chat_model_class: Type[BaseChatModel],
+        chat_model_params: dict,
+    ) -> None:
+        super().test_standard_params(
+            chat_model_class,
+            chat_model_params,
+        )
@@ -13,3 +13,14 @@ class TestGroqStandard(ChatModelUnitTests):
     @pytest.fixture
     def chat_model_class(self) -> Type[BaseChatModel]:
         return ChatGroq
+
+    @pytest.mark.xfail(reason="Not implemented.")
+    def test_standard_params(
+        self,
+        chat_model_class: Type[BaseChatModel],
+        chat_model_params: dict,
+    ) -> None:
+        super().test_standard_params(
+            chat_model_class,
+            chat_model_params,
+        )
@@ -13,3 +13,14 @@ class TestMistralStandard(ChatModelUnitTests):
     @pytest.fixture
     def chat_model_class(self) -> Type[BaseChatModel]:
         return ChatMistralAI
+
+    @pytest.mark.xfail(reason="Not implemented.")
+    def test_standard_params(
+        self,
+        chat_model_class: Type[BaseChatModel],
+        chat_model_params: dict,
+    ) -> None:
+        super().test_standard_params(
+            chat_model_class,
+            chat_model_params,
+        )
@@ -6,6 +6,7 @@ import os
 from typing import Any, Callable, Dict, List, Optional, Union
 
 import openai
+from langchain_core.language_models.chat_models import LangSmithParams
 from langchain_core.outputs import ChatResult
 from langchain_core.pydantic_v1 import Field, SecretStr, root_validator
 from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env
@@ -228,6 +229,16 @@ class AzureChatOpenAI(BaseChatOpenAI):
             "openai_api_version": self.openai_api_version,
         }
 
+    def _get_ls_params(
+        self, stop: Optional[List[str]] = None, **kwargs: Any
+    ) -> LangSmithParams:
+        """Get the parameters used to invoke the model."""
+        params = super()._get_ls_params(stop=stop, **kwargs)
+        params["ls_provider"] = "azure"
+        if self.deployment_name:
+            params["ls_model_name"] = self.deployment_name
+        return params
+
     def _create_chat_result(
         self, response: Union[dict, openai.BaseModel]
     ) -> ChatResult:
@@ -36,6 +36,7 @@ from langchain_core.callbacks import (
 from langchain_core.language_models import LanguageModelInput
 from langchain_core.language_models.chat_models import (
     BaseChatModel,
+    LangSmithParams,
     agenerate_from_stream,
     generate_from_stream,
 )
@@ -639,6 +640,23 @@ class BaseChatOpenAI(BaseChatModel):
             **kwargs,
         }
 
+    def _get_ls_params(
+        self, stop: Optional[List[str]] = None, **kwargs: Any
+    ) -> LangSmithParams:
+        """Get standard params for tracing."""
+        params = self._get_invocation_params(stop=stop, **kwargs)
+        ls_params = LangSmithParams(
+            ls_provider="openai",
+            ls_model_name=self.model_name,
+            ls_model_type="chat",
+            ls_temperature=params.get("temperature", self.temperature),
+        )
+        if ls_max_tokens := params.get("max_tokens", self.max_tokens):
+            ls_params["ls_max_tokens"] = ls_max_tokens
+        if ls_stop := stop or params.get("stop", None):
+            ls_params["ls_stop"] = ls_stop
+        return ls_params
+
     @property
     def _llm_type(self) -> str:
         """Return type of chat model."""
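Wrappers around OpenAI-compatible APIs only need to adjust `ls_provider` on top of the `BaseChatOpenAI` defaults, as the Azure, Together, and Upstage diffs in this commit show. A hypothetical provider wrapper would follow the same pattern (`ChatAcme` and the `"acme"` tag are made up for illustration):

```python
from typing import Any, List, Optional

from langchain_core.language_models.chat_models import LangSmithParams
from langchain_openai.chat_models.base import BaseChatOpenAI


class ChatAcme(BaseChatOpenAI):  # hypothetical OpenAI-compatible wrapper
    def _get_ls_params(
        self, stop: Optional[List[str]] = None, **kwargs: Any
    ) -> LangSmithParams:
        """Get the parameters used to invoke the model."""
        params = super()._get_ls_params(stop=stop, **kwargs)
        params["ls_provider"] = "acme"  # only the provider tag differs
        return params
```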
@@ -28,3 +28,7 @@ def test_initialize_more() -> None:
     assert llm.deployment_name == "35-turbo-dev"
     assert llm.openai_api_version == "2023-05-15"
     assert llm.temperature == 0
+
+    ls_params = llm._get_ls_params()
+    assert ls_params["ls_provider"] == "azure"
+    assert ls_params["ls_model_name"] == "35-turbo-dev"
@@ -9,6 +9,7 @@ from typing import (
 )
 
 import openai
+from langchain_core.language_models.chat_models import LangSmithParams
 from langchain_core.pydantic_v1 import Field, SecretStr, root_validator
 from langchain_core.utils import (
     convert_to_secret_str,
@@ -54,6 +55,14 @@ class ChatTogether(BaseChatOpenAI):
         """Return type of chat model."""
         return "together-chat"
 
+    def _get_ls_params(
+        self, stop: Optional[List[str]] = None, **kwargs: Any
+    ) -> LangSmithParams:
+        """Get the parameters used to invoke the model."""
+        params = super()._get_ls_params(stop=stop, **kwargs)
+        params["ls_provider"] = "together"
+        return params
+
     model_name: str = Field(default="meta-llama/Llama-3-8b-chat-hf", alias="model")
     """Model name to use."""
     together_api_key: Optional[SecretStr] = Field(default=None, alias="api_key")
@@ -28,6 +28,8 @@ def test_together_model_param() -> None:
     assert llm.model_name == "foo"
     llm = ChatTogether(model_name="foo")
     assert llm.model_name == "foo"
+    ls_params = llm._get_ls_params()
+    assert ls_params["ls_provider"] == "together"
 
 
 def test_function_dict_to_message_function_message() -> None:
@@ -7,6 +7,7 @@ from typing import (
 )
 
 import openai
+from langchain_core.language_models.chat_models import LangSmithParams
 from langchain_core.pydantic_v1 import Field, SecretStr, root_validator
 from langchain_core.utils import (
     convert_to_secret_str,
@@ -52,6 +53,14 @@ class ChatUpstage(BaseChatOpenAI):
         """Return type of chat model."""
         return "upstage-chat"
 
+    def _get_ls_params(
+        self, stop: Optional[List[str]] = None, **kwargs: Any
+    ) -> LangSmithParams:
+        """Get the parameters used to invoke the model."""
+        params = super()._get_ls_params(stop=stop, **kwargs)
+        params["ls_provider"] = "upstage"
+        return params
+
     model_name: str = Field(default="solar-1-mini-chat", alias="model")
     """Model name to use."""
     upstage_api_key: Optional[SecretStr] = Field(default=None, alias="api_key")
@@ -28,6 +28,8 @@ def test_upstage_model_param() -> None:
     assert llm.model_name == "foo"
     llm = ChatUpstage(model_name="foo")
     assert llm.model_name == "foo"
+    ls_params = llm._get_ls_params()
+    assert ls_params["ls_provider"] == "upstage"
 
 
 def test_function_dict_to_message_function_message() -> None:
@@ -1,9 +1,9 @@
 from abc import ABC, abstractmethod
-from typing import Type
+from typing import List, Literal, Optional, Type
 
 import pytest
 from langchain_core.language_models import BaseChatModel
-from langchain_core.pydantic_v1 import BaseModel, Field
+from langchain_core.pydantic_v1 import BaseModel, Field, ValidationError
 from langchain_core.tools import tool
 
 
@@ -89,3 +89,29 @@ class ChatModelUnitTests(ABC):
         model = chat_model_class(**chat_model_params)
         assert model is not None
         assert model.with_structured_output(Person) is not None
+
+    def test_standard_params(
+        self, chat_model_class: Type[BaseChatModel], chat_model_params: dict
+    ) -> None:
+        class ExpectedParams(BaseModel):
+            ls_provider: str
+            ls_model_name: str
+            ls_model_type: Literal["chat"]
+            ls_temperature: Optional[float]
+            ls_max_tokens: Optional[int]
+            ls_stop: Optional[List[str]]
+
+        model = chat_model_class(**chat_model_params)
+        ls_params = model._get_ls_params()
+        try:
+            ExpectedParams(**ls_params)
+        except ValidationError as e:
+            pytest.fail(f"Validation error: {e}")
+
+        # Test optional params
+        model = chat_model_class(max_tokens=10, stop=["test"], **chat_model_params)
+        ls_params = model._get_ls_params()
+        try:
+            ExpectedParams(**ls_params)
+        except ValidationError as e:
+            pytest.fail(f"Validation error: {e}")
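Downstream integrations opt into this check by subclassing `ChatModelUnitTests` and supplying the two fixtures, as the AI21, Fireworks, Groq, and Mistral files above do (marking the test `xfail` until `_get_ls_params` is implemented). A minimal sketch for a hypothetical package (`langchain_acme` and `ChatAcme` are placeholders; import path for the standard tests assumed):

```python
from typing import Type

import pytest
from langchain_core.language_models import BaseChatModel
from langchain_standard_tests.unit_tests import ChatModelUnitTests

from langchain_acme import ChatAcme  # hypothetical integration package


class TestAcmeStandard(ChatModelUnitTests):
    @pytest.fixture
    def chat_model_class(self) -> Type[BaseChatModel]:
        return ChatAcme

    @pytest.fixture
    def chat_model_params(self) -> dict:
        return {"api_key": "test_api_key"}
```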