Add optional model kwargs to ChatAnthropic to allow overrides (#9013)

---------

Co-authored-by: Bagatur <baskaryan@gmail.com>
This commit is contained in:
colegottdank 2023-08-09 17:34:00 -07:00 committed by GitHub
parent 3b51817706
commit f4a47ec717
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 49 additions and 5 deletions

View File

@ -2,7 +2,7 @@ import re
import warnings import warnings
from typing import Any, AsyncIterator, Callable, Dict, Iterator, List, Mapping, Optional from typing import Any, AsyncIterator, Callable, Dict, Iterator, List, Mapping, Optional
from pydantic import root_validator from pydantic import Field, root_validator
from langchain.callbacks.manager import ( from langchain.callbacks.manager import (
AsyncCallbackManagerForLLMRun, AsyncCallbackManagerForLLMRun,
@ -11,7 +11,12 @@ from langchain.callbacks.manager import (
from langchain.llms.base import LLM from langchain.llms.base import LLM
from langchain.schema.language_model import BaseLanguageModel from langchain.schema.language_model import BaseLanguageModel
from langchain.schema.output import GenerationChunk from langchain.schema.output import GenerationChunk
from langchain.utils import check_package_version, get_from_dict_or_env from langchain.utils import (
check_package_version,
get_from_dict_or_env,
get_pydantic_field_names,
)
from langchain.utils.utils import build_extra_kwargs
class _AnthropicCommon(BaseLanguageModel): class _AnthropicCommon(BaseLanguageModel):
@ -45,6 +50,16 @@ class _AnthropicCommon(BaseLanguageModel):
HUMAN_PROMPT: Optional[str] = None HUMAN_PROMPT: Optional[str] = None
AI_PROMPT: Optional[str] = None AI_PROMPT: Optional[str] = None
count_tokens: Optional[Callable[[str], int]] = None count_tokens: Optional[Callable[[str], int]] = None
model_kwargs: Dict[str, Any] = Field(default_factory=dict)
@root_validator(pre=True)
def build_extra(cls, values: Dict) -> Dict:
extra = values.get("model_kwargs", {})
all_required_field_names = get_pydantic_field_names(cls)
values["model_kwargs"] = build_extra_kwargs(
extra, values, all_required_field_names
)
return values
@root_validator() @root_validator()
def validate_environment(cls, values: Dict) -> Dict: def validate_environment(cls, values: Dict) -> Dict:
@ -77,6 +92,7 @@ class _AnthropicCommon(BaseLanguageModel):
values["HUMAN_PROMPT"] = anthropic.HUMAN_PROMPT values["HUMAN_PROMPT"] = anthropic.HUMAN_PROMPT
values["AI_PROMPT"] = anthropic.AI_PROMPT values["AI_PROMPT"] = anthropic.AI_PROMPT
values["count_tokens"] = values["client"].count_tokens values["count_tokens"] = values["client"].count_tokens
except ImportError: except ImportError:
raise ImportError( raise ImportError(
"Could not import anthropic python package. " "Could not import anthropic python package. "
@ -97,7 +113,7 @@ class _AnthropicCommon(BaseLanguageModel):
d["top_k"] = self.top_k d["top_k"] = self.top_k
if self.top_p is not None: if self.top_p is not None:
d["top_p"] = self.top_p d["top_p"] = self.top_p
return d return {**d, **self.model_kwargs}
@property @property
def _identifying_params(self) -> Mapping[str, Any]: def _identifying_params(self) -> Mapping[str, Any]:

View File

@ -13608,7 +13608,7 @@ clarifai = ["clarifai"]
cohere = ["cohere"] cohere = ["cohere"]
docarray = ["docarray"] docarray = ["docarray"]
embeddings = ["sentence-transformers"] embeddings = ["sentence-transformers"]
extended-testing = ["amazon-textract-caller", "atlassian-python-api", "beautifulsoup4", "bibtexparser", "cassio", "chardet", "esprima", "feedparser", "geopandas", "gitpython", "gql", "html2text", "jinja2", "jq", "lxml", "mwparserfromhell", "mwxml", "newspaper3k", "openai", "openai", "pandas", "pdfminer-six", "pgvector", "psychicapi", "py-trello", "pymupdf", "pypdf", "pypdfium2", "pyspark", "rank-bm25", "rapidfuzz", "requests-toolbelt", "scikit-learn", "streamlit", "sympy", "telethon", "tqdm", "xata", "xinference", "xmltodict", "zep-python"] extended-testing = ["amazon-textract-caller", "anthropic", "atlassian-python-api", "beautifulsoup4", "bibtexparser", "cassio", "chardet", "esprima", "feedparser", "geopandas", "gitpython", "gql", "html2text", "jinja2", "jq", "lxml", "mwparserfromhell", "mwxml", "newspaper3k", "openai", "openai", "pandas", "pdfminer-six", "pgvector", "psychicapi", "py-trello", "pymupdf", "pypdf", "pypdfium2", "pyspark", "rank-bm25", "rapidfuzz", "requests-toolbelt", "scikit-learn", "streamlit", "sympy", "telethon", "tqdm", "xata", "xinference", "xmltodict", "zep-python"]
javascript = ["esprima"] javascript = ["esprima"]
llms = ["anthropic", "clarifai", "cohere", "huggingface_hub", "manifest-ml", "nlpcloud", "openai", "openllm", "openlm", "torch", "transformers", "xinference"] llms = ["anthropic", "clarifai", "cohere", "huggingface_hub", "manifest-ml", "nlpcloud", "openai", "openllm", "openlm", "torch", "transformers", "xinference"]
openai = ["openai", "tiktoken"] openai = ["openai", "tiktoken"]
@ -13619,4 +13619,4 @@ text-helpers = ["chardet"]
[metadata] [metadata]
lock-version = "2.0" lock-version = "2.0"
python-versions = ">=3.8.1,<4.0" python-versions = ">=3.8.1,<4.0"
content-hash = "a8bc3bc0555543de183b659147b47d4b686843bb80a2be94ef5c319af3cb1ed0" content-hash = "a8fd5dbcab821e39c502724e13a2f85b718f3e06c7c3f98062de01a44cf1ff6e"

View File

@ -373,6 +373,7 @@ extended_testing = [
"feedparser", "feedparser",
"xata", "xata",
"xmltodict", "xmltodict",
"anthropic",
] ]
scheduled_testing = [ scheduled_testing = [

View File

@ -0,0 +1,27 @@
"""Test Anthropic Chat API wrapper."""
import os
import pytest
from langchain.chat_models import ChatAnthropic
os.environ["ANTHROPIC_API_KEY"] = "foo"
@pytest.mark.requires("anthropic")
def test_anthropic_model_kwargs() -> None:
llm = ChatAnthropic(model_kwargs={"foo": "bar"})
assert llm.model_kwargs == {"foo": "bar"}
@pytest.mark.requires("anthropic")
def test_anthropic_invalid_model_kwargs() -> None:
with pytest.raises(ValueError):
ChatAnthropic(model_kwargs={"max_tokens_to_sample": 5})
@pytest.mark.requires("anthropic")
def test_anthropic_incorrect_field() -> None:
with pytest.warns(match="not default parameter"):
llm = ChatAnthropic(foo="bar")
assert llm.model_kwargs == {"foo": "bar"}