Add optional model kwargs to ChatAnthropic to allow overrides (#9013)

---------

Co-authored-by: Bagatur <baskaryan@gmail.com>
This commit is contained in:
colegottdank 2023-08-09 17:34:00 -07:00 committed by GitHub
parent 3b51817706
commit f4a47ec717
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 49 additions and 5 deletions

View File

@ -2,7 +2,7 @@ import re
import warnings
from typing import Any, AsyncIterator, Callable, Dict, Iterator, List, Mapping, Optional
from pydantic import root_validator
from pydantic import Field, root_validator
from langchain.callbacks.manager import (
AsyncCallbackManagerForLLMRun,
@ -11,7 +11,12 @@ from langchain.callbacks.manager import (
from langchain.llms.base import LLM
from langchain.schema.language_model import BaseLanguageModel
from langchain.schema.output import GenerationChunk
from langchain.utils import check_package_version, get_from_dict_or_env
from langchain.utils import (
check_package_version,
get_from_dict_or_env,
get_pydantic_field_names,
)
from langchain.utils.utils import build_extra_kwargs
class _AnthropicCommon(BaseLanguageModel):
@ -45,6 +50,16 @@ class _AnthropicCommon(BaseLanguageModel):
HUMAN_PROMPT: Optional[str] = None
AI_PROMPT: Optional[str] = None
count_tokens: Optional[Callable[[str], int]] = None
model_kwargs: Dict[str, Any] = Field(default_factory=dict)
@root_validator(pre=True)
def build_extra(cls, values: Dict) -> Dict:
extra = values.get("model_kwargs", {})
all_required_field_names = get_pydantic_field_names(cls)
values["model_kwargs"] = build_extra_kwargs(
extra, values, all_required_field_names
)
return values
@root_validator()
def validate_environment(cls, values: Dict) -> Dict:
@ -77,6 +92,7 @@ class _AnthropicCommon(BaseLanguageModel):
values["HUMAN_PROMPT"] = anthropic.HUMAN_PROMPT
values["AI_PROMPT"] = anthropic.AI_PROMPT
values["count_tokens"] = values["client"].count_tokens
except ImportError:
raise ImportError(
"Could not import anthropic python package. "
@ -97,7 +113,7 @@ class _AnthropicCommon(BaseLanguageModel):
d["top_k"] = self.top_k
if self.top_p is not None:
d["top_p"] = self.top_p
return d
return {**d, **self.model_kwargs}
@property
def _identifying_params(self) -> Mapping[str, Any]:

View File

@ -13608,7 +13608,7 @@ clarifai = ["clarifai"]
cohere = ["cohere"]
docarray = ["docarray"]
embeddings = ["sentence-transformers"]
extended-testing = ["amazon-textract-caller", "atlassian-python-api", "beautifulsoup4", "bibtexparser", "cassio", "chardet", "esprima", "feedparser", "geopandas", "gitpython", "gql", "html2text", "jinja2", "jq", "lxml", "mwparserfromhell", "mwxml", "newspaper3k", "openai", "openai", "pandas", "pdfminer-six", "pgvector", "psychicapi", "py-trello", "pymupdf", "pypdf", "pypdfium2", "pyspark", "rank-bm25", "rapidfuzz", "requests-toolbelt", "scikit-learn", "streamlit", "sympy", "telethon", "tqdm", "xata", "xinference", "xmltodict", "zep-python"]
extended-testing = ["amazon-textract-caller", "anthropic", "atlassian-python-api", "beautifulsoup4", "bibtexparser", "cassio", "chardet", "esprima", "feedparser", "geopandas", "gitpython", "gql", "html2text", "jinja2", "jq", "lxml", "mwparserfromhell", "mwxml", "newspaper3k", "openai", "openai", "pandas", "pdfminer-six", "pgvector", "psychicapi", "py-trello", "pymupdf", "pypdf", "pypdfium2", "pyspark", "rank-bm25", "rapidfuzz", "requests-toolbelt", "scikit-learn", "streamlit", "sympy", "telethon", "tqdm", "xata", "xinference", "xmltodict", "zep-python"]
javascript = ["esprima"]
llms = ["anthropic", "clarifai", "cohere", "huggingface_hub", "manifest-ml", "nlpcloud", "openai", "openllm", "openlm", "torch", "transformers", "xinference"]
openai = ["openai", "tiktoken"]
@ -13619,4 +13619,4 @@ text-helpers = ["chardet"]
[metadata]
lock-version = "2.0"
python-versions = ">=3.8.1,<4.0"
content-hash = "a8bc3bc0555543de183b659147b47d4b686843bb80a2be94ef5c319af3cb1ed0"
content-hash = "a8fd5dbcab821e39c502724e13a2f85b718f3e06c7c3f98062de01a44cf1ff6e"

View File

@ -373,6 +373,7 @@ extended_testing = [
"feedparser",
"xata",
"xmltodict",
"anthropic",
]
scheduled_testing = [

View File

@ -0,0 +1,27 @@
"""Test Anthropic Chat API wrapper."""
import os
import pytest
from langchain.chat_models import ChatAnthropic
os.environ["ANTHROPIC_API_KEY"] = "foo"
@pytest.mark.requires("anthropic")
def test_anthropic_model_kwargs() -> None:
llm = ChatAnthropic(model_kwargs={"foo": "bar"})
assert llm.model_kwargs == {"foo": "bar"}
@pytest.mark.requires("anthropic")
def test_anthropic_invalid_model_kwargs() -> None:
with pytest.raises(ValueError):
ChatAnthropic(model_kwargs={"max_tokens_to_sample": 5})
@pytest.mark.requires("anthropic")
def test_anthropic_incorrect_field() -> None:
with pytest.warns(match="not default parameter"):
llm = ChatAnthropic(foo="bar")
assert llm.model_kwargs == {"foo": "bar"}