anthropic: beta messages integration (#14928)

Erick Friis
2023-12-19 18:55:19 -08:00
committed by GitHub
parent 795cf2ddda
commit 8a3360edf6
22 changed files with 1707 additions and 54 deletions

View File

@@ -41,6 +41,7 @@ jobs:
       shell: bash
       env:
         GOOGLE_API_KEY: ${{ secrets.GOOGLE_API_KEY }}
+        ANTHROPIC_API_KEY: ${{ secrets.ANTHROPIC_API_KEY }}
         MISTRAL_API_KEY: ${{ secrets.MISTRAL_API_KEY }}
         TOGETHER_API_KEY: ${{ secrets.TOGETHER_API_KEY }}
       run: |

View File

@@ -153,6 +153,7 @@ jobs:
       if: ${{ startsWith(inputs.working-directory, 'libs/partners/') }}
       env:
         GOOGLE_API_KEY: ${{ secrets.GOOGLE_API_KEY }}
+        ANTHROPIC_API_KEY: ${{ secrets.ANTHROPIC_API_KEY }}
         MISTRAL_API_KEY: ${{ secrets.MISTRAL_API_KEY }}
         TOGETHER_API_KEY: ${{ secrets.TOGETHER_API_KEY }}
       run: make integration_tests
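Both workflow hunks export the new ANTHROPIC_API_KEY secret so the partner package's integration tests can authenticate. For orientation, a minimal sketch of how the key is picked up at runtime (the environment-variable fallback is implemented in validate_environment in chat_models.py below; the model name is illustrative):

import os

from langchain_anthropic import ChatAnthropicMessages

# ChatAnthropicMessages falls back to the ANTHROPIC_API_KEY environment
# variable when no anthropic_api_key argument is passed, which is why the
# CI jobs only need the secret exported into their env.
assert os.environ.get("ANTHROPIC_API_KEY"), "export ANTHROPIC_API_KEY first"
chat = ChatAnthropicMessages(model_name="claude-instant-1.2")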

View File

@@ -22,7 +22,7 @@
    },
    {
     "cell_type": "code",
-    "execution_count": 1,
+    "execution_count": null,
     "id": "d4a7c55d-b235-4ca4-a579-c90cc9570da9",
     "metadata": {
      "tags": []
@@ -35,7 +35,7 @@
    },
    {
     "cell_type": "code",
-    "execution_count": 2,
+    "execution_count": null,
     "id": "70cf04e8-423a-4ff6-8b09-f11fb711c817",
     "metadata": {
      "tags": []
@@ -47,30 +47,19 @@
    },
    {
     "cell_type": "code",
-    "execution_count": 3,
+    "execution_count": null,
     "id": "8199ef8f-eb8b-4253-9ea0-6c24a013ca4c",
     "metadata": {
      "tags": []
     },
-    "outputs": [
-     {
-      "data": {
-       "text/plain": [
-        "AIMessage(content=\" J'aime la programmation.\", additional_kwargs={}, example=False)"
-       ]
-      },
-      "execution_count": 3,
-      "metadata": {},
-      "output_type": "execute_result"
-     }
-    ],
+    "outputs": [],
     "source": [
      "messages = [\n",
      "    HumanMessage(\n",
      "        content=\"Translate this sentence from English to French. I love programming.\"\n",
      "    )\n",
      "]\n",
-     "chat(messages)"
+     "chat.invoke(messages)"
     ]
    },
    {
@@ -83,7 +72,7 @@
    },
    {
     "cell_type": "code",
-    "execution_count": 4,
+    "execution_count": null,
     "id": "93a21c5c-6ef9-4688-be60-b2e1f94842fb",
     "metadata": {
      "tags": []
@@ -96,60 +85,41 @@
    },
    {
     "cell_type": "code",
-    "execution_count": 5,
+    "execution_count": null,
     "id": "c5fac0e9-05a4-4fc1-a3b3-e5bbb24b971b",
     "metadata": {
      "tags": []
     },
-    "outputs": [
-     {
-      "data": {
-       "text/plain": [
-        "LLMResult(generations=[[ChatGeneration(text=\" J'aime programmer.\", generation_info=None, message=AIMessage(content=\" J'aime programmer.\", additional_kwargs={}, example=False))]], llm_output={}, run=[RunInfo(run_id=UUID('8cc8fb68-1c35-439c-96a0-695036a93652'))])"
-       ]
-      },
-      "execution_count": 5,
-      "metadata": {},
-      "output_type": "execute_result"
-     }
-    ],
+    "outputs": [],
     "source": [
-     "await chat.agenerate([messages])"
+     "await chat.ainvoke([messages])"
     ]
    },
    {
     "cell_type": "code",
-    "execution_count": 6,
+    "execution_count": null,
     "id": "025be980-e50d-4a68-93dc-c9c7b500ce34",
     "metadata": {
      "tags": []
     },
-    "outputs": [
-     {
-      "name": "stdout",
-      "output_type": "stream",
-      "text": [
-       " J'aime la programmation."
-      ]
-     },
-     {
-      "data": {
-       "text/plain": [
-        "AIMessage(content=\" J'aime la programmation.\", additional_kwargs={}, example=False)"
-       ]
-      },
-      "execution_count": 6,
-      "metadata": {},
-      "output_type": "execute_result"
-     }
-    ],
+    "outputs": [],
     "source": [
      "chat = ChatAnthropic(\n",
      "    streaming=True,\n",
      "    verbose=True,\n",
      "    callback_manager=CallbackManager([StreamingStdOutCallbackHandler()]),\n",
      ")\n",
-     "chat(messages)"
+     "chat.stream(messages)"
+    ]
+   },
+   {
+    "cell_type": "markdown",
+    "id": "3737fc8d",
+    "metadata": {},
+    "source": [
+     "# ChatAnthropicMessages\n",
+     "\n",
+     "LangChain also offers the beta Anthropic Messages endpoint through the new `langchain-anthropic` package."
     ]
    },
    {
@@ -158,7 +128,22 @@
"id": "c253883f", "id": "c253883f",
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [] "source": [
"!pip install langchain-anthropic"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "07c47c2a",
"metadata": {},
"outputs": [],
"source": [
"from langchain_anthropic import ChatAnthropicMessages\n",
"\n",
"chat = ChatAnthropicMessages(model_name=\"claude-instant-1.2\")\n",
"chat.invoke(messages)"
]
} }
], ],
"metadata": { "metadata": {
@@ -177,7 +162,7 @@
"name": "python", "name": "python",
"nbconvert_exporter": "python", "nbconvert_exporter": "python",
"pygments_lexer": "ipython3", "pygments_lexer": "ipython3",
"version": "3.9.1" "version": "3.11.4"
} }
}, },
"nbformat": 4, "nbformat": 4,

libs/partners/anthropic/.gitignore (new file, 1 line)
View File

@@ -0,0 +1 @@
__pycache__

View File

@@ -0,0 +1,21 @@
MIT License
Copyright (c) 2023 LangChain, Inc.
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

View File

@@ -0,0 +1,59 @@
.PHONY: all format lint test tests integration_tests docker_tests help extended_tests

# Default target executed when no arguments are given to make.
all: help

# Define a variable for the test file path.
TEST_FILE ?= tests/unit_tests/

test:
	poetry run pytest $(TEST_FILE)

tests:
	poetry run pytest $(TEST_FILE)

######################
# LINTING AND FORMATTING
######################

# Define a variable for Python and notebook files.
PYTHON_FILES=.
MYPY_CACHE=.mypy_cache
lint format: PYTHON_FILES=.
lint_diff format_diff: PYTHON_FILES=$(shell git diff --relative=libs/partners/anthropic --name-only --diff-filter=d master | grep -E '\.py$$|\.ipynb$$')
lint_package: PYTHON_FILES=langchain_anthropic
lint_tests: PYTHON_FILES=tests
lint_tests: MYPY_CACHE=.mypy_cache_test

lint lint_diff lint_package lint_tests:
	poetry run ruff .
	poetry run ruff format $(PYTHON_FILES) --diff
	poetry run ruff --select I $(PYTHON_FILES)
	mkdir $(MYPY_CACHE); poetry run mypy $(PYTHON_FILES) --cache-dir $(MYPY_CACHE)

format format_diff:
	poetry run ruff format $(PYTHON_FILES)
	poetry run ruff --select I --fix $(PYTHON_FILES)

spell_check:
	poetry run codespell --toml pyproject.toml

spell_fix:
	poetry run codespell --toml pyproject.toml -w

check_imports: $(shell find langchain_anthropic -name '*.py')
	poetry run python ./scripts/check_imports.py $^

######################
# HELP
######################

help:
	@echo '----'
	@echo 'check_imports                - check imports'
	@echo 'format                       - run code formatters'
	@echo 'lint                         - run linters'
	@echo 'test                         - run unit tests'
	@echo 'tests                        - run unit tests'
	@echo 'test TEST_FILE=<test_file>   - run all tests in file'

View File

@@ -0,0 +1 @@
# langchain-anthropic

View File

@@ -0,0 +1,3 @@
from langchain_anthropic.chat_models import ChatAnthropicMessages
__all__ = ["ChatAnthropicMessages"]

View File

@@ -0,0 +1,187 @@
import os
from typing import Any, AsyncIterator, Dict, Iterator, List, Optional, Tuple
import anthropic
from langchain_core.callbacks import (
AsyncCallbackManagerForLLMRun,
CallbackManagerForLLMRun,
)
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.messages import (
AIMessage,
AIMessageChunk,
BaseMessage,
)
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
from langchain_core.pydantic_v1 import Field, SecretStr, root_validator
from langchain_core.utils import convert_to_secret_str
_message_type_lookups = {"human": "user", "assistant": "ai"}
def _format_messages(messages: List[BaseMessage]) -> Tuple[Optional[str], List[Dict]]:
"""Format messages for anthropic."""
"""
[
{
"role": _message_type_lookups[m.type],
"content": [_AnthropicMessageContent(text=m.content).dict()],
}
for m in messages
]
"""
system = None
formatted_messages = []
for i, message in enumerate(messages):
if not isinstance(message.content, str):
raise ValueError("Anthropic Messages API only supports text generation.")
if message.type == "system":
if i != 0:
raise ValueError("System message must be at beginning of message list.")
system = message.content
else:
formatted_messages.append(
{
"role": _message_type_lookups[message.type],
"content": message.content,
}
)
return system, formatted_messages
class ChatAnthropicMessages(BaseChatModel):
"""Beta ChatAnthropicMessages chat model.
Example:
.. code-block:: python
from langchain_anthropic import ChatAnthropicMessages
model = ChatAnthropicMessages()
"""
_client: anthropic.Client = Field(default_factory=anthropic.Client)
_async_client: anthropic.AsyncClient = Field(default_factory=anthropic.AsyncClient)
model: str = Field(alias="model_name")
"""Model name to use."""
max_tokens: int = Field(default=256)
"""Denotes the number of tokens to predict per generation."""
temperature: Optional[float] = None
"""A non-negative float that tunes the degree of randomness in generation."""
top_k: Optional[int] = None
"""Number of most likely tokens to consider at each step."""
top_p: Optional[float] = None
"""Total probability mass of tokens to consider at each step."""
default_request_timeout: Optional[float] = None
"""Timeout for requests to Anthropic Completion API. Default is 600 seconds."""
anthropic_api_url: str = "https://api.anthropic.com"
anthropic_api_key: Optional[SecretStr] = None
model_kwargs: Dict[str, Any] = Field(default_factory=dict)
@property
def _llm_type(self) -> str:
"""Return type of chat model."""
return "chat-anthropic-messages"
@root_validator()
def validate_environment(cls, values: Dict) -> Dict:
anthropic_api_key = convert_to_secret_str(
values.get("anthropic_api_key") or os.environ.get("ANTHROPIC_API_KEY") or ""
)
values["anthropic_api_key"] = anthropic_api_key
values["_client"] = anthropic.Client(
api_key=anthropic_api_key.get_secret_value()
)
values["_async_client"] = anthropic.AsyncClient(
api_key=anthropic_api_key.get_secret_value()
)
return values
def _format_params(
self,
*,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
**kwargs: Dict,
) -> Dict:
# get system prompt if any
system, formatted_messages = _format_messages(messages)
rtn = {
"model": self.model,
"max_tokens": self.max_tokens,
"messages": formatted_messages,
"temperature": self.temperature,
"top_k": self.top_k,
"top_p": self.top_p,
"stop_sequences": stop,
"system": system,
}
rtn = {k: v for k, v in rtn.items() if v is not None}
return rtn
def _stream(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> Iterator[ChatGenerationChunk]:
params = self._format_params(messages=messages, stop=stop, **kwargs)
with self._client.beta.messages.stream(**params) as stream:
for text in stream.text_stream:
yield ChatGenerationChunk(message=AIMessageChunk(content=text))
async def _astream(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> AsyncIterator[ChatGenerationChunk]:
params = self._format_params(messages=messages, stop=stop, **kwargs)
async with self._async_client.beta.messages.stream(**params) as stream:
async for text in stream.text_stream:
yield ChatGenerationChunk(message=AIMessageChunk(content=text))
def _generate(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
params = self._format_params(messages=messages, stop=stop, **kwargs)
data = self._client.beta.messages.create(**params)
return ChatResult(
generations=[
ChatGeneration(message=AIMessage(content=data.content[0].text))
],
llm_output=data,
)
async def _agenerate(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
params = self._format_params(messages=messages, stop=stop, **kwargs)
data = await self._async_client.beta.messages.create(**params)
return ChatResult(
generations=[
ChatGeneration(message=AIMessage(content=data.content[0].text))
],
llm_output=data,
)
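Because _format_messages splits a leading system message off into the Messages API's dedicated system parameter, callers can keep passing ordinary LangChain message lists. A quick illustration of the transformation (expected values assume the corrected role lookup above; this calls the private helper directly, so treat it as a sketch rather than public API):

from langchain_core.messages import AIMessage, HumanMessage, SystemMessage

from langchain_anthropic.chat_models import _format_messages

system, formatted = _format_messages(
    [
        SystemMessage(content="You are a helpful translator."),
        HumanMessage(content="I love programming."),
        AIMessage(content="J'aime la programmation."),
    ]
)
# system == "You are a helpful translator."
# formatted == [
#     {"role": "user", "content": "I love programming."},
#     {"role": "assistant", "content": "J'aime la programmation."},
# ]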

libs/partners/anthropic/poetry.lock (generated, 1134 lines)

File diff suppressed because it is too large.

View File

@@ -0,0 +1,90 @@
[tool.poetry]
name = "langchain-anthropic"
version = "0.0.1"
description = "An integration package connecting AnthropicMessages and LangChain"
authors = []
readme = "README.md"

[tool.poetry.dependencies]
python = ">=3.8.1,<4.0"
langchain-core = ">=0.0.12"
anthropic = "^0.8.0"

[tool.poetry.group.test]
optional = true

[tool.poetry.group.test.dependencies]
pytest = "^7.3.0"
freezegun = "^1.2.2"
pytest-mock = "^3.10.0"
syrupy = "^4.0.2"
pytest-watcher = "^0.3.4"
pytest-asyncio = "^0.21.1"
langchain-core = {path = "../../core", develop = true}

[tool.poetry.group.codespell]
optional = true

[tool.poetry.group.codespell.dependencies]
codespell = "^2.2.0"

[tool.poetry.group.lint]
optional = true

[tool.poetry.group.lint.dependencies]
ruff = "^0.1.5"

[tool.poetry.group.typing.dependencies]
mypy = "^0.991"
langchain-core = {path = "../../core", develop = true}

[tool.poetry.group.dev]
optional = true

[tool.poetry.group.dev.dependencies]
langchain-core = {path = "../../core", develop = true}

[tool.poetry.group.test_integration]
optional = true

[tool.poetry.group.test_integration.dependencies]
langchain-core = {path = "../../core", develop = true}

[tool.ruff]
select = [
  "E",  # pycodestyle
  "F",  # pyflakes
  "I",  # isort
]

[tool.mypy]
disallow_untyped_defs = "True"

[tool.coverage.run]
omit = [
    "tests/*",
]

[build-system]
requires = ["poetry-core>=1.0.0"]
build-backend = "poetry.core.masonry.api"

[tool.pytest.ini_options]
# --strict-markers will raise errors on unknown marks.
# https://docs.pytest.org/en/7.1.x/how-to/mark.html#raising-errors-on-unknown-marks
#
# https://docs.pytest.org/en/7.1.x/reference/reference.html
# --strict-config       any warnings encountered while parsing the `pytest`
#                       section of the configuration file raise errors.
#
# https://github.com/tophat/syrupy
# --snapshot-warn-unused    Prints a warning on unused snapshots rather than fail the test suite.
addopts = "--snapshot-warn-unused --strict-markers --strict-config --durations=5"
# Registering custom markers.
# https://docs.pytest.org/en/7.1.x/example/markers.html#registering-markers
markers = [
    "requires: mark tests as requiring a specific library",
    "asyncio: mark tests as requiring asyncio",
    "compile: mark placeholder test used to compile integration tests without running them",
]
asyncio_mode = "auto"

View File

@@ -0,0 +1,17 @@
import sys
import traceback
from importlib.machinery import SourceFileLoader
if __name__ == "__main__":
files = sys.argv[1:]
has_failure = False
for file in files:
try:
SourceFileLoader("x", file).load_module()
except Exception:
            has_failure = True
print(file)
traceback.print_exc()
print()
sys.exit(1 if has_failure else 0)
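The script relies on the fact that loading a module executes its top-level imports, so any broken import surfaces as an exception. A one-file sketch of the same check (file path illustrative):

from importlib.machinery import SourceFileLoader

# Loading the file as a module runs its import statements; a missing or
# circular dependency raises here instead of at first use.
SourceFileLoader("probe", "langchain_anthropic/chat_models.py").load_module()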

View File

@@ -0,0 +1,27 @@
#!/bin/bash
#
# This script searches for lines starting with "import pydantic" or "from pydantic"
# in tracked files within a Git repository.
#
# Usage: ./scripts/check_pydantic.sh /path/to/repository
# Check if a path argument is provided
if [ $# -ne 1 ]; then
echo "Usage: $0 /path/to/repository"
exit 1
fi
repository_path="$1"
# Search for lines matching the pattern within the specified repository
result=$(git -C "$repository_path" grep -E '^import pydantic|^from pydantic')
# Check if any matching lines were found
if [ -n "$result" ]; then
echo "ERROR: The following lines need to be updated:"
echo "$result"
echo "Please replace the code with an import from langchain_core.pydantic_v1."
echo "For example, replace 'from pydantic import BaseModel'"
echo "with 'from langchain_core.pydantic_v1 import BaseModel'"
exit 1
fi
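The rewrite the script asks for is mechanical; a before/after sketch in Python (the model class is illustrative):

# Flagged by scripts/check_pydantic.sh:
#     from pydantic import BaseModel, Field
# Replacement via the shim, which keeps partner packages on pydantic v1 semantics:
from langchain_core.pydantic_v1 import BaseModel, Field

class Example(BaseModel):
    text: str = Field(description="text sent to the model")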

View File

@@ -0,0 +1,17 @@
#!/bin/bash
set -eu
# Initialize a variable to keep track of errors
errors=0
# make sure not importing from langchain or langchain_experimental
git --no-pager grep '^from langchain\.' . && errors=$((errors+1))
git --no-pager grep '^from langchain_experimental\.' . && errors=$((errors+1))
# Decide on an exit status based on the errors
if [ "$errors" -gt 0 ]; then
exit 1
else
exit 0
fi

View File

@@ -0,0 +1,86 @@
"""Test ChatAnthropicMessages chat model."""
from langchain_core.prompts import ChatPromptTemplate
from langchain_anthropic.chat_models import ChatAnthropicMessages
def test_stream() -> None:
"""Test streaming tokens from OpenAI."""
llm = ChatAnthropicMessages(model_name="claude-instant-1.2")
for token in llm.stream("I'm Pickle Rick"):
assert isinstance(token.content, str)
async def test_astream() -> None:
"""Test streaming tokens from OpenAI."""
llm = ChatAnthropicMessages(model_name="claude-instant-1.2")
async for token in llm.astream("I'm Pickle Rick"):
assert isinstance(token.content, str)
async def test_abatch() -> None:
"""Test streaming tokens from ChatAnthropicMessages."""
llm = ChatAnthropicMessages(model_name="claude-instant-1.2")
result = await llm.abatch(["I'm Pickle Rick", "I'm not Pickle Rick"])
for token in result:
assert isinstance(token.content, str)
async def test_abatch_tags() -> None:
"""Test batch tokens from ChatAnthropicMessages."""
llm = ChatAnthropicMessages(model_name="claude-instant-1.2")
result = await llm.abatch(
["I'm Pickle Rick", "I'm not Pickle Rick"], config={"tags": ["foo"]}
)
for token in result:
assert isinstance(token.content, str)
def test_batch() -> None:
"""Test batch tokens from ChatAnthropicMessages."""
llm = ChatAnthropicMessages(model_name="claude-instant-1.2")
result = llm.batch(["I'm Pickle Rick", "I'm not Pickle Rick"])
for token in result:
assert isinstance(token.content, str)
async def test_ainvoke() -> None:
"""Test invoke tokens from ChatAnthropicMessages."""
llm = ChatAnthropicMessages(model_name="claude-instant-1.2")
result = await llm.ainvoke("I'm Pickle Rick", config={"tags": ["foo"]})
assert isinstance(result.content, str)
def test_invoke() -> None:
"""Test invoke tokens from ChatAnthropicMessages."""
llm = ChatAnthropicMessages(model_name="claude-instant-1.2")
result = llm.invoke("I'm Pickle Rick", config=dict(tags=["foo"]))
assert isinstance(result.content, str)
def test_system_invoke() -> None:
"""Test invoke tokens with a system message"""
llm = ChatAnthropicMessages(model_name="claude-instant-1.2")
prompt = ChatPromptTemplate.from_messages(
[
(
"system",
"You are an expert cartographer. If asked, you are a cartographer. "
"STAY IN CHARACTER",
),
("human", "Are you a mathematician?"),
]
)
chain = prompt | llm
result = chain.invoke({})
assert isinstance(result.content, str)

View File

@@ -0,0 +1,7 @@
import pytest
@pytest.mark.compile
def test_placeholder() -> None:
"""Used for compiling integration tests without running any real tests."""
pass

View File

@@ -0,0 +1,9 @@
"""Test chat model integration."""
from langchain_anthropic.chat_models import ChatAnthropicMessages
def test_initialization() -> None:
"""Test chat model initialization."""
ChatAnthropicMessages(model_name="claude-instant-1.2", anthropic_api_key="xyz")

View File

@@ -0,0 +1,7 @@
from langchain_anthropic import __all__
EXPECTED_ALL = ["ChatAnthropicMessages"]
def test_all_imports() -> None:
assert sorted(EXPECTED_ALL) == sorted(__all__)