mirror of
https://github.com/hwchase17/langchain.git
synced 2025-08-16 16:11:02 +00:00
Add optional model kwargs to ChatAnthropic to allow overrides (#9013)
--------- Co-authored-by: Bagatur <baskaryan@gmail.com>
This commit is contained in:
parent
3b51817706
commit
f4a47ec717
@ -2,7 +2,7 @@ import re
|
|||||||
import warnings
|
import warnings
|
||||||
from typing import Any, AsyncIterator, Callable, Dict, Iterator, List, Mapping, Optional
|
from typing import Any, AsyncIterator, Callable, Dict, Iterator, List, Mapping, Optional
|
||||||
|
|
||||||
from pydantic import root_validator
|
from pydantic import Field, root_validator
|
||||||
|
|
||||||
from langchain.callbacks.manager import (
|
from langchain.callbacks.manager import (
|
||||||
AsyncCallbackManagerForLLMRun,
|
AsyncCallbackManagerForLLMRun,
|
||||||
@ -11,7 +11,12 @@ from langchain.callbacks.manager import (
|
|||||||
from langchain.llms.base import LLM
|
from langchain.llms.base import LLM
|
||||||
from langchain.schema.language_model import BaseLanguageModel
|
from langchain.schema.language_model import BaseLanguageModel
|
||||||
from langchain.schema.output import GenerationChunk
|
from langchain.schema.output import GenerationChunk
|
||||||
from langchain.utils import check_package_version, get_from_dict_or_env
|
from langchain.utils import (
|
||||||
|
check_package_version,
|
||||||
|
get_from_dict_or_env,
|
||||||
|
get_pydantic_field_names,
|
||||||
|
)
|
||||||
|
from langchain.utils.utils import build_extra_kwargs
|
||||||
|
|
||||||
|
|
||||||
class _AnthropicCommon(BaseLanguageModel):
|
class _AnthropicCommon(BaseLanguageModel):
|
||||||
@ -45,6 +50,16 @@ class _AnthropicCommon(BaseLanguageModel):
|
|||||||
HUMAN_PROMPT: Optional[str] = None
|
HUMAN_PROMPT: Optional[str] = None
|
||||||
AI_PROMPT: Optional[str] = None
|
AI_PROMPT: Optional[str] = None
|
||||||
count_tokens: Optional[Callable[[str], int]] = None
|
count_tokens: Optional[Callable[[str], int]] = None
|
||||||
|
model_kwargs: Dict[str, Any] = Field(default_factory=dict)
|
||||||
|
|
||||||
|
@root_validator(pre=True)
|
||||||
|
def build_extra(cls, values: Dict) -> Dict:
|
||||||
|
extra = values.get("model_kwargs", {})
|
||||||
|
all_required_field_names = get_pydantic_field_names(cls)
|
||||||
|
values["model_kwargs"] = build_extra_kwargs(
|
||||||
|
extra, values, all_required_field_names
|
||||||
|
)
|
||||||
|
return values
|
||||||
|
|
||||||
@root_validator()
|
@root_validator()
|
||||||
def validate_environment(cls, values: Dict) -> Dict:
|
def validate_environment(cls, values: Dict) -> Dict:
|
||||||
@ -77,6 +92,7 @@ class _AnthropicCommon(BaseLanguageModel):
|
|||||||
values["HUMAN_PROMPT"] = anthropic.HUMAN_PROMPT
|
values["HUMAN_PROMPT"] = anthropic.HUMAN_PROMPT
|
||||||
values["AI_PROMPT"] = anthropic.AI_PROMPT
|
values["AI_PROMPT"] = anthropic.AI_PROMPT
|
||||||
values["count_tokens"] = values["client"].count_tokens
|
values["count_tokens"] = values["client"].count_tokens
|
||||||
|
|
||||||
except ImportError:
|
except ImportError:
|
||||||
raise ImportError(
|
raise ImportError(
|
||||||
"Could not import anthropic python package. "
|
"Could not import anthropic python package. "
|
||||||
@ -97,7 +113,7 @@ class _AnthropicCommon(BaseLanguageModel):
|
|||||||
d["top_k"] = self.top_k
|
d["top_k"] = self.top_k
|
||||||
if self.top_p is not None:
|
if self.top_p is not None:
|
||||||
d["top_p"] = self.top_p
|
d["top_p"] = self.top_p
|
||||||
return d
|
return {**d, **self.model_kwargs}
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def _identifying_params(self) -> Mapping[str, Any]:
|
def _identifying_params(self) -> Mapping[str, Any]:
|
||||||
|
4
libs/langchain/poetry.lock
generated
4
libs/langchain/poetry.lock
generated
@ -13608,7 +13608,7 @@ clarifai = ["clarifai"]
|
|||||||
cohere = ["cohere"]
|
cohere = ["cohere"]
|
||||||
docarray = ["docarray"]
|
docarray = ["docarray"]
|
||||||
embeddings = ["sentence-transformers"]
|
embeddings = ["sentence-transformers"]
|
||||||
extended-testing = ["amazon-textract-caller", "atlassian-python-api", "beautifulsoup4", "bibtexparser", "cassio", "chardet", "esprima", "feedparser", "geopandas", "gitpython", "gql", "html2text", "jinja2", "jq", "lxml", "mwparserfromhell", "mwxml", "newspaper3k", "openai", "openai", "pandas", "pdfminer-six", "pgvector", "psychicapi", "py-trello", "pymupdf", "pypdf", "pypdfium2", "pyspark", "rank-bm25", "rapidfuzz", "requests-toolbelt", "scikit-learn", "streamlit", "sympy", "telethon", "tqdm", "xata", "xinference", "xmltodict", "zep-python"]
|
extended-testing = ["amazon-textract-caller", "anthropic", "atlassian-python-api", "beautifulsoup4", "bibtexparser", "cassio", "chardet", "esprima", "feedparser", "geopandas", "gitpython", "gql", "html2text", "jinja2", "jq", "lxml", "mwparserfromhell", "mwxml", "newspaper3k", "openai", "openai", "pandas", "pdfminer-six", "pgvector", "psychicapi", "py-trello", "pymupdf", "pypdf", "pypdfium2", "pyspark", "rank-bm25", "rapidfuzz", "requests-toolbelt", "scikit-learn", "streamlit", "sympy", "telethon", "tqdm", "xata", "xinference", "xmltodict", "zep-python"]
|
||||||
javascript = ["esprima"]
|
javascript = ["esprima"]
|
||||||
llms = ["anthropic", "clarifai", "cohere", "huggingface_hub", "manifest-ml", "nlpcloud", "openai", "openllm", "openlm", "torch", "transformers", "xinference"]
|
llms = ["anthropic", "clarifai", "cohere", "huggingface_hub", "manifest-ml", "nlpcloud", "openai", "openllm", "openlm", "torch", "transformers", "xinference"]
|
||||||
openai = ["openai", "tiktoken"]
|
openai = ["openai", "tiktoken"]
|
||||||
@ -13619,4 +13619,4 @@ text-helpers = ["chardet"]
|
|||||||
[metadata]
|
[metadata]
|
||||||
lock-version = "2.0"
|
lock-version = "2.0"
|
||||||
python-versions = ">=3.8.1,<4.0"
|
python-versions = ">=3.8.1,<4.0"
|
||||||
content-hash = "a8bc3bc0555543de183b659147b47d4b686843bb80a2be94ef5c319af3cb1ed0"
|
content-hash = "a8fd5dbcab821e39c502724e13a2f85b718f3e06c7c3f98062de01a44cf1ff6e"
|
||||||
|
@ -373,6 +373,7 @@ extended_testing = [
|
|||||||
"feedparser",
|
"feedparser",
|
||||||
"xata",
|
"xata",
|
||||||
"xmltodict",
|
"xmltodict",
|
||||||
|
"anthropic",
|
||||||
]
|
]
|
||||||
|
|
||||||
scheduled_testing = [
|
scheduled_testing = [
|
||||||
|
@ -0,0 +1,27 @@
|
|||||||
|
"""Test Anthropic Chat API wrapper."""
|
||||||
|
import os
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from langchain.chat_models import ChatAnthropic
|
||||||
|
|
||||||
|
os.environ["ANTHROPIC_API_KEY"] = "foo"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.requires("anthropic")
|
||||||
|
def test_anthropic_model_kwargs() -> None:
|
||||||
|
llm = ChatAnthropic(model_kwargs={"foo": "bar"})
|
||||||
|
assert llm.model_kwargs == {"foo": "bar"}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.requires("anthropic")
|
||||||
|
def test_anthropic_invalid_model_kwargs() -> None:
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
ChatAnthropic(model_kwargs={"max_tokens_to_sample": 5})
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.requires("anthropic")
|
||||||
|
def test_anthropic_incorrect_field() -> None:
|
||||||
|
with pytest.warns(match="not default parameter"):
|
||||||
|
llm = ChatAnthropic(foo="bar")
|
||||||
|
assert llm.model_kwargs == {"foo": "bar"}
|
Loading…
Reference in New Issue
Block a user