From f4a47ec7175925270e09e4161f9375b4305645ac Mon Sep 17 00:00:00 2001
From: colegottdank
Date: Wed, 9 Aug 2023 17:34:00 -0700
Subject: [PATCH] Add optional model kwargs to ChatAnthropic to allow overrides (#9013)

---------

Co-authored-by: Bagatur
---
 libs/langchain/langchain/llms/anthropic.py     | 22 ++++++++++++---
 libs/langchain/poetry.lock                     |  4 +--
 libs/langchain/pyproject.toml                  |  1 +
 .../unit_tests/chat_models/test_anthropic.py   | 27 +++++++++++++++++++
 4 files changed, 49 insertions(+), 5 deletions(-)
 create mode 100644 libs/langchain/tests/unit_tests/chat_models/test_anthropic.py

diff --git a/libs/langchain/langchain/llms/anthropic.py b/libs/langchain/langchain/llms/anthropic.py
index f32f581d1f7..5e5695762f2 100644
--- a/libs/langchain/langchain/llms/anthropic.py
+++ b/libs/langchain/langchain/llms/anthropic.py
@@ -2,7 +2,7 @@ import re
 import warnings
 from typing import Any, AsyncIterator, Callable, Dict, Iterator, List, Mapping, Optional
 
-from pydantic import root_validator
+from pydantic import Field, root_validator
 
 from langchain.callbacks.manager import (
     AsyncCallbackManagerForLLMRun,
@@ -11,7 +11,12 @@ from langchain.callbacks.manager import (
 from langchain.llms.base import LLM
 from langchain.schema.language_model import BaseLanguageModel
 from langchain.schema.output import GenerationChunk
-from langchain.utils import check_package_version, get_from_dict_or_env
+from langchain.utils import (
+    check_package_version,
+    get_from_dict_or_env,
+    get_pydantic_field_names,
+)
+from langchain.utils.utils import build_extra_kwargs
 
 
 class _AnthropicCommon(BaseLanguageModel):
@@ -45,6 +50,16 @@ class _AnthropicCommon(BaseLanguageModel):
     HUMAN_PROMPT: Optional[str] = None
     AI_PROMPT: Optional[str] = None
     count_tokens: Optional[Callable[[str], int]] = None
+    model_kwargs: Dict[str, Any] = Field(default_factory=dict)
+
+    @root_validator(pre=True)
+    def build_extra(cls, values: Dict) -> Dict:
+        extra = values.get("model_kwargs", {})
+        all_required_field_names = get_pydantic_field_names(cls)
+        values["model_kwargs"] = build_extra_kwargs(
+            extra, values, all_required_field_names
+        )
+        return values
 
     @root_validator()
     def validate_environment(cls, values: Dict) -> Dict:
@@ -77,6 +92,7 @@ class _AnthropicCommon(BaseLanguageModel):
             values["HUMAN_PROMPT"] = anthropic.HUMAN_PROMPT
             values["AI_PROMPT"] = anthropic.AI_PROMPT
             values["count_tokens"] = values["client"].count_tokens
+
         except ImportError:
             raise ImportError(
                 "Could not import anthropic python package. "
@@ -97,7 +113,7 @@ class _AnthropicCommon(BaseLanguageModel):
             d["top_k"] = self.top_k
         if self.top_p is not None:
             d["top_p"] = self.top_p
-        return d
+        return {**d, **self.model_kwargs}
 
     @property
     def _identifying_params(self) -> Mapping[str, Any]:
diff --git a/libs/langchain/poetry.lock b/libs/langchain/poetry.lock
index dc34e82d2e6..712a4406bff 100644
--- a/libs/langchain/poetry.lock
+++ b/libs/langchain/poetry.lock
@@ -13608,7 +13608,7 @@ clarifai = ["clarifai"]
 cohere = ["cohere"]
 docarray = ["docarray"]
 embeddings = ["sentence-transformers"]
-extended-testing = ["amazon-textract-caller", "atlassian-python-api", "beautifulsoup4", "bibtexparser", "cassio", "chardet", "esprima", "feedparser", "geopandas", "gitpython", "gql", "html2text", "jinja2", "jq", "lxml", "mwparserfromhell", "mwxml", "newspaper3k", "openai", "openai", "pandas", "pdfminer-six", "pgvector", "psychicapi", "py-trello", "pymupdf", "pypdf", "pypdfium2", "pyspark", "rank-bm25", "rapidfuzz", "requests-toolbelt", "scikit-learn", "streamlit", "sympy", "telethon", "tqdm", "xata", "xinference", "xmltodict", "zep-python"]
+extended-testing = ["amazon-textract-caller", "anthropic", "atlassian-python-api", "beautifulsoup4", "bibtexparser", "cassio", "chardet", "esprima", "feedparser", "geopandas", "gitpython", "gql", "html2text", "jinja2", "jq", "lxml", "mwparserfromhell", "mwxml", "newspaper3k", "openai", "openai", "pandas", "pdfminer-six", "pgvector", "psychicapi", "py-trello", "pymupdf", "pypdf", "pypdfium2", "pyspark", "rank-bm25", "rapidfuzz", "requests-toolbelt", "scikit-learn", "streamlit", "sympy", "telethon", "tqdm", "xata", "xinference", "xmltodict", "zep-python"]
 javascript = ["esprima"]
 llms = ["anthropic", "clarifai", "cohere", "huggingface_hub", "manifest-ml", "nlpcloud", "openai", "openllm", "openlm", "torch", "transformers", "xinference"]
 openai = ["openai", "tiktoken"]
@@ -13619,4 +13619,4 @@ text-helpers = ["chardet"]
 [metadata]
 lock-version = "2.0"
 python-versions = ">=3.8.1,<4.0"
-content-hash = "a8bc3bc0555543de183b659147b47d4b686843bb80a2be94ef5c319af3cb1ed0"
+content-hash = "a8fd5dbcab821e39c502724e13a2f85b718f3e06c7c3f98062de01a44cf1ff6e"
diff --git a/libs/langchain/pyproject.toml b/libs/langchain/pyproject.toml
index 2d5aef61224..3fb30ee0d6c 100644
--- a/libs/langchain/pyproject.toml
+++ b/libs/langchain/pyproject.toml
@@ -373,6 +373,7 @@ extended_testing = [
     "feedparser",
     "xata",
     "xmltodict",
+    "anthropic",
 ]
 
 scheduled_testing = [
diff --git a/libs/langchain/tests/unit_tests/chat_models/test_anthropic.py b/libs/langchain/tests/unit_tests/chat_models/test_anthropic.py
new file mode 100644
index 00000000000..7447ec03e41
--- /dev/null
+++ b/libs/langchain/tests/unit_tests/chat_models/test_anthropic.py
@@ -0,0 +1,27 @@
+"""Test Anthropic Chat API wrapper."""
+import os
+
+import pytest
+
+from langchain.chat_models import ChatAnthropic
+
+os.environ["ANTHROPIC_API_KEY"] = "foo"
+
+
+@pytest.mark.requires("anthropic")
+def test_anthropic_model_kwargs() -> None:
+    llm = ChatAnthropic(model_kwargs={"foo": "bar"})
+    assert llm.model_kwargs == {"foo": "bar"}
+
+
+@pytest.mark.requires("anthropic")
+def test_anthropic_invalid_model_kwargs() -> None:
+    with pytest.raises(ValueError):
+        ChatAnthropic(model_kwargs={"max_tokens_to_sample": 5})
+
+
+@pytest.mark.requires("anthropic")
+def test_anthropic_incorrect_field() -> None:
+    with pytest.warns(match="not default parameter"):
+        llm = ChatAnthropic(foo="bar")
+    assert llm.model_kwargs == {"foo": "bar"}