mirror of
https://github.com/hwchase17/langchain.git
synced 2025-07-06 05:08:20 +00:00
Anthropic: Allow the use of kwargs consistent with ChatOpenAI. (#9515)
- Description: ~~Creates a new root_validator in `_AnthropicCommon` that allows the use of `model_name` and `max_tokens` keyword arguments.~~ Adds pydantic field aliases to support `model_name` and `max_tokens` as keyword arguments. Ultimately, this makes `ChatAnthropic` more consistent with `ChatOpenAI`, making the two classes more interchangeable for the developer. - Issue: https://github.com/langchain-ai/langchain/issues/9510 --------- Co-authored-by: Bagatur <baskaryan@gmail.com>
This commit is contained in:
parent
a8c916955f
commit
a9c86774da
@ -36,6 +36,12 @@ class ChatAnthropic(BaseChatModel, _AnthropicCommon):
|
|||||||
model = ChatAnthropic(model="<model_name>", anthropic_api_key="my-api-key")
|
model = ChatAnthropic(model="<model_name>", anthropic_api_key="my-api-key")
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
class Config:
|
||||||
|
"""Configuration for this pydantic object."""
|
||||||
|
|
||||||
|
allow_population_by_field_name = True
|
||||||
|
arbitrary_types_allowed = True
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def lc_secrets(self) -> Dict[str, str]:
|
def lc_secrets(self) -> Dict[str, str]:
|
||||||
return {"anthropic_api_key": "ANTHROPIC_API_KEY"}
|
return {"anthropic_api_key": "ANTHROPIC_API_KEY"}
|
||||||
|
@ -21,10 +21,10 @@ from langchain.utils.utils import build_extra_kwargs
|
|||||||
class _AnthropicCommon(BaseLanguageModel):
|
class _AnthropicCommon(BaseLanguageModel):
|
||||||
client: Any = None #: :meta private:
|
client: Any = None #: :meta private:
|
||||||
async_client: Any = None #: :meta private:
|
async_client: Any = None #: :meta private:
|
||||||
model: str = "claude-2"
|
model: str = Field(default="claude-2", alias="model_name")
|
||||||
"""Model name to use."""
|
"""Model name to use."""
|
||||||
|
|
||||||
max_tokens_to_sample: int = 256
|
max_tokens_to_sample: int = Field(default=256, alias="max_tokens")
|
||||||
"""Denotes the number of tokens to predict per generation."""
|
"""Denotes the number of tokens to predict per generation."""
|
||||||
|
|
||||||
temperature: Optional[float] = None
|
temperature: Optional[float] = None
|
||||||
@ -144,6 +144,7 @@ class Anthropic(LLM, _AnthropicCommon):
|
|||||||
|
|
||||||
import anthropic
|
import anthropic
|
||||||
from langchain.llms import Anthropic
|
from langchain.llms import Anthropic
|
||||||
|
|
||||||
model = Anthropic(model="<model_name>", anthropic_api_key="my-api-key")
|
model = Anthropic(model="<model_name>", anthropic_api_key="my-api-key")
|
||||||
|
|
||||||
# Simplest invocation, automatically wrapped with HUMAN_PROMPT
|
# Simplest invocation, automatically wrapped with HUMAN_PROMPT
|
||||||
@ -157,6 +158,12 @@ class Anthropic(LLM, _AnthropicCommon):
|
|||||||
response = model(prompt)
|
response = model(prompt)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
class Config:
|
||||||
|
"""Configuration for this pydantic object."""
|
||||||
|
|
||||||
|
allow_population_by_field_name = True
|
||||||
|
arbitrary_types_allowed = True
|
||||||
|
|
||||||
@root_validator()
|
@root_validator()
|
||||||
def raise_warning(cls, values: Dict) -> Dict:
|
def raise_warning(cls, values: Dict) -> Dict:
|
||||||
"""Raise warning that this class is deprecated."""
|
"""Raise warning that this class is deprecated."""
|
||||||
|
@ -8,6 +8,18 @@ from langchain.chat_models import ChatAnthropic
|
|||||||
os.environ["ANTHROPIC_API_KEY"] = "foo"
|
os.environ["ANTHROPIC_API_KEY"] = "foo"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.requires("anthropic")
|
||||||
|
def test_anthropic_model_name_param() -> None:
|
||||||
|
llm = ChatAnthropic(model_name="foo")
|
||||||
|
assert llm.model == "foo"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.requires("anthropic")
|
||||||
|
def test_anthropic_model_param() -> None:
|
||||||
|
llm = ChatAnthropic(model="foo")
|
||||||
|
assert llm.model == "foo"
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.requires("anthropic")
|
@pytest.mark.requires("anthropic")
|
||||||
def test_anthropic_model_kwargs() -> None:
|
def test_anthropic_model_kwargs() -> None:
|
||||||
llm = ChatAnthropic(model_kwargs={"foo": "bar"})
|
llm = ChatAnthropic(model_kwargs={"foo": "bar"})
|
||||||
|
@ -9,6 +9,18 @@ from langchain.schema import LLMResult
|
|||||||
from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler
|
from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.requires("anthropic")
|
||||||
|
def test_anthropic_model_name_param() -> None:
|
||||||
|
llm = Anthropic(model_name="foo")
|
||||||
|
assert llm.model == "foo"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.requires("anthropic")
|
||||||
|
def test_anthropic_model_param() -> None:
|
||||||
|
llm = Anthropic(model="foo")
|
||||||
|
assert llm.model == "foo"
|
||||||
|
|
||||||
|
|
||||||
def test_anthropic_call() -> None:
|
def test_anthropic_call() -> None:
|
||||||
"""Test valid call to anthropic."""
|
"""Test valid call to anthropic."""
|
||||||
llm = Anthropic(model="claude-instant-1")
|
llm = Anthropic(model="claude-instant-1")
|
||||||
@ -24,7 +36,7 @@ def test_anthropic_streaming() -> None:
|
|||||||
assert isinstance(generator, Generator)
|
assert isinstance(generator, Generator)
|
||||||
|
|
||||||
for token in generator:
|
for token in generator:
|
||||||
assert isinstance(token["completion"], str)
|
assert isinstance(token, str)
|
||||||
|
|
||||||
|
|
||||||
def test_anthropic_streaming_callback() -> None:
|
def test_anthropic_streaming_callback() -> None:
|
||||||
|
@ -6,9 +6,7 @@ from unittest.mock import MagicMock, patch
|
|||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from langchain.adapters.openai import convert_dict_to_message
|
from langchain.adapters.openai import convert_dict_to_message
|
||||||
from langchain.chat_models.openai import (
|
from langchain.chat_models.openai import ChatOpenAI
|
||||||
ChatOpenAI,
|
|
||||||
)
|
|
||||||
from langchain.schema.messages import (
|
from langchain.schema.messages import (
|
||||||
AIMessage,
|
AIMessage,
|
||||||
FunctionMessage,
|
FunctionMessage,
|
||||||
@ -17,6 +15,14 @@ from langchain.schema.messages import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.requires("openai")
|
||||||
|
def test_openai_model_param() -> None:
|
||||||
|
llm = ChatOpenAI(model="foo")
|
||||||
|
assert llm.model_name == "foo"
|
||||||
|
llm = ChatOpenAI(model_name="foo")
|
||||||
|
assert llm.model_name == "foo"
|
||||||
|
|
||||||
|
|
||||||
def test_function_message_dict_to_function_message() -> None:
|
def test_function_message_dict_to_function_message() -> None:
|
||||||
content = json.dumps({"result": "Example #1"})
|
content = json.dumps({"result": "Example #1"})
|
||||||
name = "test_function"
|
name = "test_function"
|
||||||
|
Loading…
Reference in New Issue
Block a user