mirror of
https://github.com/hwchase17/langchain.git
synced 2025-09-24 20:09:01 +00:00
core, standard tests, partner packages: add test for model params (#21677)
1. Adds `.get_ls_params` to BaseChatModel which returns ```python class LangSmithParams(TypedDict, total=False): ls_provider: str ls_model_name: str ls_model_type: Literal["chat"] ls_temperature: Optional[float] ls_max_tokens: Optional[int] ls_stop: Optional[List[str]] ``` by default it will only return ```python {ls_model_type="chat", ls_stop=stop} ``` 2. Add these params to inheritable metadata in `CallbackManager.configure` 3. Implement `.get_ls_params` and populate all params for Anthropic + all subclasses of BaseChatOpenAI Sample trace: https://smith.langchain.com/public/d2962673-4c83-47c7-b51e-61d07aaffb1b/r **OpenAI**: <img width="984" alt="Screenshot 2024-05-17 at 10 03 35 AM" src="https://github.com/langchain-ai/langchain/assets/26529506/2ef41f74-a9df-4e0e-905d-da74fa82a910"> **Anthropic**: <img width="978" alt="Screenshot 2024-05-17 at 10 06 07 AM" src="https://github.com/langchain-ai/langchain/assets/26529506/39701c9f-7da5-4f1a-ab14-84e9169d63e7"> **Mistral** (and all others for which params are not yet populated): <img width="977" alt="Screenshot 2024-05-17 at 10 08 43 AM" src="https://github.com/langchain-ai/langchain/assets/26529506/37d7d894-fec2-4300-986f-49a5f0191b03">
This commit is contained in:
@@ -1,9 +1,9 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Type
|
||||
from typing import List, Literal, Optional, Type
|
||||
|
||||
import pytest
|
||||
from langchain_core.language_models import BaseChatModel
|
||||
from langchain_core.pydantic_v1 import BaseModel, Field
|
||||
from langchain_core.pydantic_v1 import BaseModel, Field, ValidationError
|
||||
from langchain_core.tools import tool
|
||||
|
||||
|
||||
@@ -89,3 +89,29 @@ class ChatModelUnitTests(ABC):
|
||||
model = chat_model_class(**chat_model_params)
|
||||
assert model is not None
|
||||
assert model.with_structured_output(Person) is not None
|
||||
|
||||
def test_standard_params(
|
||||
self, chat_model_class: Type[BaseChatModel], chat_model_params: dict
|
||||
) -> None:
|
||||
class ExpectedParams(BaseModel):
|
||||
ls_provider: str
|
||||
ls_model_name: str
|
||||
ls_model_type: Literal["chat"]
|
||||
ls_temperature: Optional[float]
|
||||
ls_max_tokens: Optional[int]
|
||||
ls_stop: Optional[List[str]]
|
||||
|
||||
model = chat_model_class(**chat_model_params)
|
||||
ls_params = model._get_ls_params()
|
||||
try:
|
||||
ExpectedParams(**ls_params)
|
||||
except ValidationError as e:
|
||||
pytest.fail(f"Validation error: {e}")
|
||||
|
||||
# Test optional params
|
||||
model = chat_model_class(max_tokens=10, stop=["test"], **chat_model_params)
|
||||
ls_params = model._get_ls_params()
|
||||
try:
|
||||
ExpectedParams(**ls_params)
|
||||
except ValidationError as e:
|
||||
pytest.fail(f"Validation error: {e}")
|
||||
|
Reference in New Issue
Block a user