mirror of
https://github.com/hwchase17/langchain.git
synced 2025-12-03 17:17:21 +00:00
1. Adds `._get_ls_params` to `BaseChatModel`, which returns
```python
class LangSmithParams(TypedDict, total=False):
    ls_provider: str
    ls_model_name: str
    ls_model_type: Literal["chat"]
    ls_temperature: Optional[float]
    ls_max_tokens: Optional[int]
    ls_stop: Optional[List[str]]
```
by default it will only return
```python
{"ls_model_type": "chat", "ls_stop": stop}
```
2. Add these params to inheritable metadata in
`CallbackManager.configure`
3. Implement `._get_ls_params` and populate all params for Anthropic +
all subclasses of `BaseChatOpenAI`
Sample trace:
https://smith.langchain.com/public/d2962673-4c83-47c7-b51e-61d07aaffb1b/r
**OpenAI**:
<img width="984" alt="Screenshot 2024-05-17 at 10 03 35 AM"
src="https://github.com/langchain-ai/langchain/assets/26529506/2ef41f74-a9df-4e0e-905d-da74fa82a910">
**Anthropic**:
<img width="978" alt="Screenshot 2024-05-17 at 10 06 07 AM"
src="https://github.com/langchain-ai/langchain/assets/26529506/39701c9f-7da5-4f1a-ab14-84e9169d63e7">
**Mistral** (and all others for which params are not yet populated):
<img width="977" alt="Screenshot 2024-05-17 at 10 08 43 AM"
src="https://github.com/langchain-ai/langchain/assets/26529506/37d7d894-fec2-4300-986f-49a5f0191b03">
118 lines
3.7 KiB
Python
118 lines
3.7 KiB
Python
from abc import ABC, abstractmethod
|
|
from typing import List, Literal, Optional, Type
|
|
|
|
import pytest
|
|
from langchain_core.language_models import BaseChatModel
|
|
from langchain_core.pydantic_v1 import BaseModel, Field, ValidationError
|
|
from langchain_core.tools import tool
|
|
|
|
|
|
# Minimal two-field schema handed to ``bind_tools`` and
# ``with_structured_output`` in the tests below.
# Documented with comments rather than a class docstring on purpose: a
# docstring would presumably be picked up as the model's schema description
# (verify against pydantic), which would change what the tests send.
class Person(BaseModel):
    # Both fields are required: ``...`` as the default marks them mandatory.
    name: str = Field(default=..., description="The name of the person.")
    age: int = Field(default=..., description="The age of the person.")
|
|
|
|
|
|
# NOTE: the docstring below is presumably consumed by ``@tool`` as the tool's
# runtime description (confirm against langchain_core), so its wording is
# part of the behaviour and must stay as-is.
@tool
def my_adder_tool(a: int, b: int) -> int:
    """Takes two integers, a and b, and returns their sum."""
    total = a + b
    return total
|
|
|
|
|
|
class ChatModelUnitTests(ABC):
    """Reusable unit-test suite for ``BaseChatModel`` implementations.

    Subclass this in a test module and override the ``chat_model_class``
    fixture (and ``chat_model_params`` when the constructor needs arguments).
    Every ``test_*`` method is then run against that model class.
    """

    # NOTE(review): the ``abc`` docs recommend ``abstractmethod`` as the
    # innermost decorator; the original stacking order is kept here so the
    # fixture behaves exactly as before -- confirm against the supported
    # pytest versions before changing it.
    @abstractmethod
    @pytest.fixture
    def chat_model_class(self) -> Type[BaseChatModel]:
        """Return the chat model class under test; subclasses must override."""
        ...

    @pytest.fixture
    def chat_model_params(self) -> dict:
        """Return constructor keyword arguments for the model (none by default)."""
        return {}

    @pytest.fixture
    def chat_model_has_tool_calling(
        self, chat_model_class: Type[BaseChatModel]
    ) -> bool:
        """Report whether the class overrides ``BaseChatModel.bind_tools``."""
        return chat_model_class.bind_tools is not BaseChatModel.bind_tools

    @pytest.fixture
    def chat_model_has_structured_output(
        self, chat_model_class: Type[BaseChatModel]
    ) -> bool:
        """Report whether the class overrides ``with_structured_output``."""
        return (
            chat_model_class.with_structured_output
            is not BaseChatModel.with_structured_output
        )

    @staticmethod
    def _assert_valid_ls_params(model: BaseChatModel) -> None:
        """Fail the current test unless ``model._get_ls_params()`` validates.

        The returned mapping is validated against the expected schema:
        ``ls_provider``, ``ls_model_name`` and ``ls_model_type`` are required,
        the remaining keys are optional.
        """

        class ExpectedParams(BaseModel):
            ls_provider: str
            ls_model_name: str
            ls_model_type: Literal["chat"]
            ls_temperature: Optional[float]
            ls_max_tokens: Optional[int]
            ls_stop: Optional[List[str]]

        ls_params = model._get_ls_params()
        try:
            ExpectedParams(**ls_params)
        except ValidationError as e:
            pytest.fail(f"Validation error: {e}")

    def test_chat_model_init(
        self, chat_model_class: Type[BaseChatModel], chat_model_params: dict
    ) -> None:
        """The model can be constructed from ``chat_model_params`` alone."""
        model = chat_model_class(**chat_model_params)
        assert model is not None

    def test_chat_model_init_api_key(
        self, chat_model_class: Type[BaseChatModel], chat_model_params: dict
    ) -> None:
        """The model accepts an ``api_key`` keyword argument."""
        params = {**chat_model_params, "api_key": "test"}
        model = chat_model_class(**params)  # type: ignore
        assert model is not None

    def test_chat_model_init_streaming(
        self, chat_model_class: Type[BaseChatModel], chat_model_params: dict
    ) -> None:
        """The model accepts ``streaming=True``."""
        # Merge instead of mixing an explicit keyword with ``**``: if a
        # subclass already sets "streaming" in its params, the explicit
        # keyword form raises TypeError (duplicate keyword argument).
        params = {**chat_model_params, "streaming": True}
        model = chat_model_class(**params)  # type: ignore
        assert model is not None

    def test_chat_model_bind_tool_pydantic(
        self,
        chat_model_class: Type[BaseChatModel],
        chat_model_params: dict,
        chat_model_has_tool_calling: bool,
    ) -> None:
        """A pydantic model can be bound as a tool, when tools are supported."""
        if not chat_model_has_tool_calling:
            # Report as skipped rather than as a silent pass.
            pytest.skip("Chat model does not override bind_tools.")

        model = chat_model_class(**chat_model_params)

        assert hasattr(model, "bind_tools")
        tool_model = model.bind_tools([Person])
        assert tool_model is not None

    def test_chat_model_with_structured_output(
        self,
        chat_model_class: Type[BaseChatModel],
        chat_model_params: dict,
        chat_model_has_structured_output: bool,
    ) -> None:
        """``with_structured_output`` accepts a pydantic schema, when supported."""
        if not chat_model_has_structured_output:
            # Report as skipped rather than as a silent pass.
            pytest.skip("Chat model does not override with_structured_output.")

        model = chat_model_class(**chat_model_params)
        assert model is not None
        assert model.with_structured_output(Person) is not None

    def test_standard_params(
        self, chat_model_class: Type[BaseChatModel], chat_model_params: dict
    ) -> None:
        """``_get_ls_params`` returns a well-formed set of tracing params."""
        self._assert_valid_ls_params(chat_model_class(**chat_model_params))

        # Test optional params. Merged into one dict so a subclass that
        # already supplies "max_tokens" or "stop" does not trigger a
        # duplicate-keyword TypeError.
        params = {**chat_model_params, "max_tokens": 10, "stop": ["test"]}
        model = chat_model_class(**params)  # type: ignore
        self._assert_valid_ls_params(model)
|