mirror of
https://github.com/hwchase17/langchain.git
synced 2025-08-15 07:36:08 +00:00
Add optional model kwargs to ChatAnthropic to allow overrides (#9013)
--------- Co-authored-by: Bagatur <baskaryan@gmail.com>
This commit is contained in:
parent
3b51817706
commit
f4a47ec717
@ -2,7 +2,7 @@ import re
|
||||
import warnings
|
||||
from typing import Any, AsyncIterator, Callable, Dict, Iterator, List, Mapping, Optional
|
||||
|
||||
from pydantic import root_validator
|
||||
from pydantic import Field, root_validator
|
||||
|
||||
from langchain.callbacks.manager import (
|
||||
AsyncCallbackManagerForLLMRun,
|
||||
@ -11,7 +11,12 @@ from langchain.callbacks.manager import (
|
||||
from langchain.llms.base import LLM
|
||||
from langchain.schema.language_model import BaseLanguageModel
|
||||
from langchain.schema.output import GenerationChunk
|
||||
from langchain.utils import check_package_version, get_from_dict_or_env
|
||||
from langchain.utils import (
|
||||
check_package_version,
|
||||
get_from_dict_or_env,
|
||||
get_pydantic_field_names,
|
||||
)
|
||||
from langchain.utils.utils import build_extra_kwargs
|
||||
|
||||
|
||||
class _AnthropicCommon(BaseLanguageModel):
|
||||
@ -45,6 +50,16 @@ class _AnthropicCommon(BaseLanguageModel):
|
||||
HUMAN_PROMPT: Optional[str] = None
|
||||
AI_PROMPT: Optional[str] = None
|
||||
count_tokens: Optional[Callable[[str], int]] = None
|
||||
model_kwargs: Dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
@root_validator(pre=True)
|
||||
def build_extra(cls, values: Dict) -> Dict:
|
||||
extra = values.get("model_kwargs", {})
|
||||
all_required_field_names = get_pydantic_field_names(cls)
|
||||
values["model_kwargs"] = build_extra_kwargs(
|
||||
extra, values, all_required_field_names
|
||||
)
|
||||
return values
|
||||
|
||||
@root_validator()
|
||||
def validate_environment(cls, values: Dict) -> Dict:
|
||||
@ -77,6 +92,7 @@ class _AnthropicCommon(BaseLanguageModel):
|
||||
values["HUMAN_PROMPT"] = anthropic.HUMAN_PROMPT
|
||||
values["AI_PROMPT"] = anthropic.AI_PROMPT
|
||||
values["count_tokens"] = values["client"].count_tokens
|
||||
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"Could not import anthropic python package. "
|
||||
@ -97,7 +113,7 @@ class _AnthropicCommon(BaseLanguageModel):
|
||||
d["top_k"] = self.top_k
|
||||
if self.top_p is not None:
|
||||
d["top_p"] = self.top_p
|
||||
return d
|
||||
return {**d, **self.model_kwargs}
|
||||
|
||||
@property
|
||||
def _identifying_params(self) -> Mapping[str, Any]:
|
||||
|
4
libs/langchain/poetry.lock
generated
4
libs/langchain/poetry.lock
generated
@ -13608,7 +13608,7 @@ clarifai = ["clarifai"]
|
||||
cohere = ["cohere"]
|
||||
docarray = ["docarray"]
|
||||
embeddings = ["sentence-transformers"]
|
||||
extended-testing = ["amazon-textract-caller", "atlassian-python-api", "beautifulsoup4", "bibtexparser", "cassio", "chardet", "esprima", "feedparser", "geopandas", "gitpython", "gql", "html2text", "jinja2", "jq", "lxml", "mwparserfromhell", "mwxml", "newspaper3k", "openai", "openai", "pandas", "pdfminer-six", "pgvector", "psychicapi", "py-trello", "pymupdf", "pypdf", "pypdfium2", "pyspark", "rank-bm25", "rapidfuzz", "requests-toolbelt", "scikit-learn", "streamlit", "sympy", "telethon", "tqdm", "xata", "xinference", "xmltodict", "zep-python"]
|
||||
extended-testing = ["amazon-textract-caller", "anthropic", "atlassian-python-api", "beautifulsoup4", "bibtexparser", "cassio", "chardet", "esprima", "feedparser", "geopandas", "gitpython", "gql", "html2text", "jinja2", "jq", "lxml", "mwparserfromhell", "mwxml", "newspaper3k", "openai", "openai", "pandas", "pdfminer-six", "pgvector", "psychicapi", "py-trello", "pymupdf", "pypdf", "pypdfium2", "pyspark", "rank-bm25", "rapidfuzz", "requests-toolbelt", "scikit-learn", "streamlit", "sympy", "telethon", "tqdm", "xata", "xinference", "xmltodict", "zep-python"]
|
||||
javascript = ["esprima"]
|
||||
llms = ["anthropic", "clarifai", "cohere", "huggingface_hub", "manifest-ml", "nlpcloud", "openai", "openllm", "openlm", "torch", "transformers", "xinference"]
|
||||
openai = ["openai", "tiktoken"]
|
||||
@ -13619,4 +13619,4 @@ text-helpers = ["chardet"]
|
||||
[metadata]
|
||||
lock-version = "2.0"
|
||||
python-versions = ">=3.8.1,<4.0"
|
||||
content-hash = "a8bc3bc0555543de183b659147b47d4b686843bb80a2be94ef5c319af3cb1ed0"
|
||||
content-hash = "a8fd5dbcab821e39c502724e13a2f85b718f3e06c7c3f98062de01a44cf1ff6e"
|
||||
|
@ -373,6 +373,7 @@ extended_testing = [
|
||||
"feedparser",
|
||||
"xata",
|
||||
"xmltodict",
|
||||
"anthropic",
|
||||
]
|
||||
|
||||
scheduled_testing = [
|
||||
|
@ -0,0 +1,27 @@
|
||||
"""Test Anthropic Chat API wrapper."""
|
||||
import os
|
||||
|
||||
import pytest
|
||||
|
||||
from langchain.chat_models import ChatAnthropic
|
||||
|
||||
os.environ["ANTHROPIC_API_KEY"] = "foo"
|
||||
|
||||
|
||||
@pytest.mark.requires("anthropic")
|
||||
def test_anthropic_model_kwargs() -> None:
|
||||
llm = ChatAnthropic(model_kwargs={"foo": "bar"})
|
||||
assert llm.model_kwargs == {"foo": "bar"}
|
||||
|
||||
|
||||
@pytest.mark.requires("anthropic")
|
||||
def test_anthropic_invalid_model_kwargs() -> None:
|
||||
with pytest.raises(ValueError):
|
||||
ChatAnthropic(model_kwargs={"max_tokens_to_sample": 5})
|
||||
|
||||
|
||||
@pytest.mark.requires("anthropic")
|
||||
def test_anthropic_incorrect_field() -> None:
|
||||
with pytest.warns(match="not default parameter"):
|
||||
llm = ChatAnthropic(foo="bar")
|
||||
assert llm.model_kwargs == {"foo": "bar"}
|
Loading…
Reference in New Issue
Block a user