Anthropic: Allow the use of kwargs consistent with ChatOpenAI. (#9515)
- Description: ~~Creates a new root_validator in `_AnthropicCommon` that allows the use of `model_name` and `max_tokens` keyword arguments.~~ Adds pydantic field aliases to support `model_name` and `max_tokens` as keyword arguments. Ultimately, this makes `ChatAnthropic` more consistent with `ChatOpenAI`, making the two classes more interchangeable for the developer.
- Issue: https://github.com/langchain-ai/langchain/issues/9510

Co-authored-by: Bagatur <baskaryan@gmail.com>
commit a9c86774da (parent a8c916955f)
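In practice, the aliases let `ChatAnthropic` accept the OpenAI-style keyword arguments alongside its original field names. A minimal sketch of the resulting behavior (assumes the `anthropic` package is installed; the API key value is a placeholder, as in the docstring examples below):

```python
from langchain.chat_models import ChatAnthropic

# OpenAI-style spellings, enabled by the new aliases...
chat = ChatAnthropic(model_name="claude-2", max_tokens=256, anthropic_api_key="my-api-key")

# ...and the original Anthropic field names still work.
chat = ChatAnthropic(model="claude-2", max_tokens_to_sample=256, anthropic_api_key="my-api-key")

assert chat.model == "claude-2"
assert chat.max_tokens_to_sample == 256
```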
@@ -36,6 +36,12 @@ class ChatAnthropic(BaseChatModel, _AnthropicCommon):
             model = ChatAnthropic(model="<model_name>", anthropic_api_key="my-api-key")
     """
 
+    class Config:
+        """Configuration for this pydantic object."""
+
+        allow_population_by_field_name = True
+        arbitrary_types_allowed = True
+
     @property
     def lc_secrets(self) -> Dict[str, str]:
         return {"anthropic_api_key": "ANTHROPIC_API_KEY"}
@@ -21,10 +21,10 @@ from langchain.utils.utils import build_extra_kwargs
 class _AnthropicCommon(BaseLanguageModel):
     client: Any = None  #: :meta private:
     async_client: Any = None  #: :meta private:
-    model: str = "claude-2"
+    model: str = Field(default="claude-2", alias="model_name")
     """Model name to use."""
 
-    max_tokens_to_sample: int = 256
+    max_tokens_to_sample: int = Field(default=256, alias="max_tokens")
     """Denotes the number of tokens to predict per generation."""
 
     temperature: Optional[float] = None
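The mechanics, as a standalone sketch using pydantic v1 semantics (which langchain relied on at the time): `Field(alias=...)` lets the alias populate the field, and the `allow_population_by_field_name = True` config added above keeps the original field name usable as well. `Demo` here is a hypothetical stand-in for `_AnthropicCommon`:

```python
from pydantic import BaseModel, Field  # pydantic v1

class Demo(BaseModel):
    # The two aliased fields from _AnthropicCommon, reduced to essentials.
    model: str = Field(default="claude-2", alias="model_name")
    max_tokens_to_sample: int = Field(default=256, alias="max_tokens")

    class Config:
        allow_population_by_field_name = True

assert Demo(model_name="foo").model == "foo"              # populate via alias
assert Demo(model="foo").model == "foo"                   # populate via field name
assert Demo(max_tokens=512).max_tokens_to_sample == 512   # max_tokens alias
```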
@@ -144,6 +144,7 @@ class Anthropic(LLM, _AnthropicCommon):
 
             import anthropic
             from langchain.llms import Anthropic
+
             model = Anthropic(model="<model_name>", anthropic_api_key="my-api-key")
 
             # Simplest invocation, automatically wrapped with HUMAN_PROMPT
@@ -157,6 +158,12 @@ class Anthropic(LLM, _AnthropicCommon):
         response = model(prompt)
     """
 
+    class Config:
+        """Configuration for this pydantic object."""
+
+        allow_population_by_field_name = True
+        arbitrary_types_allowed = True
+
     @root_validator()
     def raise_warning(cls, values: Dict) -> Dict:
         """Raise warning that this class is deprecated."""
@@ -8,6 +8,18 @@ from langchain.chat_models import ChatAnthropic
 os.environ["ANTHROPIC_API_KEY"] = "foo"
 
 
+@pytest.mark.requires("anthropic")
+def test_anthropic_model_name_param() -> None:
+    llm = ChatAnthropic(model_name="foo")
+    assert llm.model == "foo"
+
+
+@pytest.mark.requires("anthropic")
+def test_anthropic_model_param() -> None:
+    llm = ChatAnthropic(model="foo")
+    assert llm.model == "foo"
+
+
 @pytest.mark.requires("anthropic")
 def test_anthropic_model_kwargs() -> None:
     llm = ChatAnthropic(model_kwargs={"foo": "bar"})
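For context, `@pytest.mark.requires(...)` is a custom marker in langchain's test suite that gates a test on optional dependencies. A hedged sketch of how such a marker is commonly implemented in a `conftest.py`; langchain's actual hook may differ in its details:

```python
import pytest

def pytest_collection_modifyitems(config, items):
    # Skip any test marked @pytest.mark.requires("pkg") when "pkg"
    # cannot be imported in the current environment.
    for item in items:
        for marker in item.iter_markers(name="requires"):
            for pkg in marker.args:
                try:
                    __import__(pkg)
                except ImportError:
                    item.add_marker(pytest.mark.skip(reason=f"requires {pkg}"))
```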
@@ -9,6 +9,18 @@ from langchain.schema import LLMResult
 from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler
 
 
+@pytest.mark.requires("anthropic")
+def test_anthropic_model_name_param() -> None:
+    llm = Anthropic(model_name="foo")
+    assert llm.model == "foo"
+
+
+@pytest.mark.requires("anthropic")
+def test_anthropic_model_param() -> None:
+    llm = Anthropic(model="foo")
+    assert llm.model == "foo"
+
+
 def test_anthropic_call() -> None:
     """Test valid call to anthropic."""
     llm = Anthropic(model="claude-instant-1")
@@ -24,7 +36,7 @@ def test_anthropic_streaming() -> None:
     assert isinstance(generator, Generator)
 
     for token in generator:
-        assert isinstance(token["completion"], str)
+        assert isinstance(token, str)
 
 
 def test_anthropic_streaming_callback() -> None:
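The one-line assertion change records that streaming now yields plain string tokens rather than dicts keyed by `"completion"`. A minimal sketch of the asserted behavior, assuming the generator comes from `llm.stream(...)` as in the surrounding test (an integration-style call that needs real Anthropic credentials; the prompt is hypothetical):

```python
from langchain.llms import Anthropic

llm = Anthropic(model="claude-instant-1")
generator = llm.stream("Tell me a joke.")
for token in generator:
    # Each streamed token is now a plain string, not {"completion": ...}.
    assert isinstance(token, str)
```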
@@ -6,9 +6,7 @@ from unittest.mock import MagicMock, patch
 import pytest
 
 from langchain.adapters.openai import convert_dict_to_message
-from langchain.chat_models.openai import (
-    ChatOpenAI,
-)
+from langchain.chat_models.openai import ChatOpenAI
 from langchain.schema.messages import (
     AIMessage,
     FunctionMessage,
@@ -17,6 +15,14 @@ from langchain.schema.messages import (
 )
 
 
+@pytest.mark.requires("openai")
+def test_openai_model_param() -> None:
+    llm = ChatOpenAI(model="foo")
+    assert llm.model_name == "foo"
+    llm = ChatOpenAI(model_name="foo")
+    assert llm.model_name == "foo"
+
+
 def test_function_message_dict_to_function_message() -> None:
     content = json.dumps({"result": "Example #1"})
     name = "test_function"