From bc53c928fc1b221d0038b839d111039d31729def Mon Sep 17 00:00:00 2001
From: Harrison Chase
Date: Mon, 6 Feb 2023 22:29:25 -0800
Subject: [PATCH] Harrison/athropic (#921)

Co-authored-by: Mike Lambert
Co-authored-by: mrbean
Co-authored-by: mrbean <43734688+sam-h-bean@users.noreply.github.com>
Co-authored-by: Ivan Vendrov
---
 langchain/__init__.py                         |   3 +-
 langchain/llms/__init__.py                    |   3 +
 langchain/llms/anthropic.py                   | 155 ++++++++++++++++++
 poetry.lock                                   |  35 +++-
 pyproject.toml                                |   5 +-
 .../integration_tests/llms/test_anthropic.py  |  23 +++
 6 files changed, 212 insertions(+), 12 deletions(-)
 create mode 100644 langchain/llms/anthropic.py
 create mode 100644 tests/integration_tests/llms/test_anthropic.py

diff --git a/langchain/__init__.py b/langchain/__init__.py
index 9b9bfb35719..3096f77a474 100644
--- a/langchain/__init__.py
+++ b/langchain/__init__.py
@@ -22,7 +22,7 @@ from langchain.chains import (
     VectorDBQAWithSourcesChain,
 )
 from langchain.docstore import InMemoryDocstore, Wikipedia
-from langchain.llms import Cohere, HuggingFaceHub, OpenAI
+from langchain.llms import Anthropic, Cohere, HuggingFaceHub, OpenAI
 from langchain.llms.huggingface_pipeline import HuggingFacePipeline
 from langchain.prompts import (
     BasePromptTemplate,
@@ -50,6 +50,7 @@ __all__ = [
     "SerpAPIChain",
     "GoogleSearchAPIWrapper",
     "WolframAlphaAPIWrapper",
+    "Anthropic",
     "Cohere",
     "OpenAI",
     "BasePromptTemplate",
diff --git a/langchain/llms/__init__.py b/langchain/llms/__init__.py
index e07d0d5561e..dac7fb67bc5 100644
--- a/langchain/llms/__init__.py
+++ b/langchain/llms/__init__.py
@@ -2,6 +2,7 @@
 from typing import Dict, Type
 
 from langchain.llms.ai21 import AI21
+from langchain.llms.anthropic import Anthropic
 from langchain.llms.base import BaseLLM
 from langchain.llms.cohere import Cohere
 from langchain.llms.huggingface_hub import HuggingFaceHub
@@ -10,6 +11,7 @@ from langchain.llms.nlpcloud import NLPCloud
 from langchain.llms.openai import AzureOpenAI, OpenAI
 
 __all__ = [
+    "Anthropic",
     "Cohere",
     "NLPCloud",
     "OpenAI",
@@ -21,6 +23,7 @@ __all__ = [
 
 type_to_cls_dict: Dict[str, Type[BaseLLM]] = {
     "ai21": AI21,
+    "anthropic": Anthropic,
     "cohere": Cohere,
     "huggingface_hub": HuggingFaceHub,
     "nlpcloud": NLPCloud,
diff --git a/langchain/llms/anthropic.py b/langchain/llms/anthropic.py
new file mode 100644
index 00000000000..85da3dd7675
--- /dev/null
+++ b/langchain/llms/anthropic.py
@@ -0,0 +1,155 @@
+"""Wrapper around Anthropic APIs."""
+from typing import Any, Dict, Generator, List, Mapping, Optional
+
+from pydantic import BaseModel, Extra, root_validator
+
+from langchain.llms.base import LLM
+from langchain.utils import get_from_dict_or_env
+
+
+class Anthropic(LLM, BaseModel):
+    """Wrapper around Anthropic large language models.
+
+    To use, you should have the ``anthropic`` python package installed, and the
+    environment variable ``ANTHROPIC_API_KEY`` set with your API key, or pass
+    it as a named parameter to the constructor.
+
+    Example:
+        .. code-block:: python
+
+            from langchain import Anthropic
+            anthropic = Anthropic(model="", anthropic_api_key="my-api-key")
+    """
+
+    client: Any  #: :meta private:
+    model: Optional[str] = None
+    """Model name to use."""
+
+    max_tokens_to_sample: int = 256
+    """Denotes the number of tokens to predict per generation."""
+
+    temperature: float = 1.0
+    """A non-negative float that tunes the degree of randomness in generation."""
+
+    top_k: int = 0
+    """Number of most likely tokens to consider at each step."""
+
+    top_p: float = 1
+    """Total probability mass of tokens to consider at each step."""
+
+    anthropic_api_key: Optional[str] = None
+
+    HUMAN_PROMPT: Optional[str] = None
+    AI_PROMPT: Optional[str] = None
+
+    class Config:
+        """Configuration for this pydantic object."""
+
+        extra = Extra.forbid
+
+    @root_validator()
+    def validate_environment(cls, values: Dict) -> Dict:
+        """Validate that api key and python package exists in environment."""
+        anthropic_api_key = get_from_dict_or_env(
+            values, "anthropic_api_key", "ANTHROPIC_API_KEY"
+        )
+        try:
+            import anthropic
+
+            values["client"] = anthropic.Client(anthropic_api_key)
+            values["HUMAN_PROMPT"] = anthropic.HUMAN_PROMPT
+            values["AI_PROMPT"] = anthropic.AI_PROMPT
+        except ImportError:
+            raise ValueError(
+                "Could not import anthropic python package. "
+                "Please install it with `pip install anthropic`."
+            )
+        return values
+
+    @property
+    def _default_params(self) -> Mapping[str, Any]:
+        """Get the default parameters for calling Anthropic API."""
+        return {
+            "max_tokens_to_sample": self.max_tokens_to_sample,
+            "temperature": self.temperature,
+            "top_k": self.top_k,
+            "top_p": self.top_p,
+        }
+
+    @property
+    def _identifying_params(self) -> Mapping[str, Any]:
+        """Get the identifying parameters."""
+        return {**{"model": self.model}, **self._default_params}
+
+    @property
+    def _llm_type(self) -> str:
+        """Return type of llm."""
+        return "anthropic"
+
+    def _call(
+        self, prompt: str, stop: Optional[List[str]] = None, instruct_mode: bool = True
+    ) -> str:
+        r"""Call out to Anthropic's completion endpoint.
+
+        Will by default act like an instruction-following model, by wrapping the
+        prompt in Human: and Assistant: turns. If you want to use it for chat or
+        few-shot prompting, pass in instruct_mode=False.
+
+        Args:
+            prompt: The prompt to pass into the model.
+            stop: Optional list of stop words to use when generating.
+            instruct_mode: Whether to emulate an instruction-following model.
+
+        Returns:
+            The string generated by the model.
+
+        Example:
+            .. code-block:: python
+
+                response = anthropic("Tell me a joke.")
+
+                response = anthropic(
+                    "\n\nHuman: Tell me a joke.\n\nAssistant:", instruct_mode=False
+                )
+
+        """
+        if stop is None:
+            stop = []
+        if not self.HUMAN_PROMPT or not self.AI_PROMPT:
+            raise NameError("Please ensure the anthropic package is loaded")
+        # Never want model to invent new turns of Human / Assistant dialog.
+        stop.extend([self.HUMAN_PROMPT, self.AI_PROMPT])
+
+        if instruct_mode:
+            # Wrap the prompt so it emulates an instruction-following model.
+            prompt = f"{self.HUMAN_PROMPT} {prompt}{self.AI_PROMPT} Sure, here you go:\n"
+
+        response = self.client.completion(
+            model=self.model, prompt=prompt, stop_sequences=stop, **self._default_params
+        )
+        text = response["completion"]
+        return text
+
+    def stream(self, prompt: str, stop: Optional[List[str]] = None) -> Generator:
+        """Call Anthropic completion_stream and return the resulting generator.
+
+        BETA: this is a beta feature while we figure out the right abstraction.
+        Once that happens, this interface could change.
+
+        Args:
+            prompt: The prompt to pass into the model.
+            stop: Optional list of stop words to use when generating.
+
+        Returns:
+            A generator representing the stream of tokens from Anthropic.
+
+        Example:
+            .. code-block:: python
+
+                generator = anthropic.stream("Tell me a joke.")
+                for token in generator:
+                    yield token
+        """
+        return self.client.completion_stream(
+            model=self.model, prompt=prompt, stop_sequences=stop, **self._default_params
+        )
diff --git a/poetry.lock b/poetry.lock
index bd75920e089..7324ad8c20b 100644
--- a/poetry.lock
+++ b/poetry.lock
@@ -24,6 +24,28 @@ files = [
     {file = "alabaster-0.7.13.tar.gz", hash = "sha256:a27a4a084d5e690e16e01e03ad2b2e552c61a65469419b907243193de1a84ae2"},
 ]
 
+[[package]]
+name = "anthropic"
+version = "0.2.1"
+description = "Library for accessing the anthropic API"
+category = "main"
+optional = true
+python-versions = ">=3.8"
+files = []
+develop = false
+
+[package.dependencies]
+requests = "*"
+
+[package.extras]
+dev = ["black (>=22.3.0)", "pytest"]
+
+[package.source]
+type = "git"
+url = "https://github.com/anthropics/anthropic-sdk-python.git"
+reference = "HEAD"
+resolved_reference = "b61e0637b6a4b34399aafff4adc756fe649325bb"
+
 [[package]]
 name = "anyio"
 version = "3.6.2"
@@ -851,6 +873,7 @@ files = [
     {file = "debugpy-1.6.6-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9b5d1b13d7c7bf5d7cf700e33c0b8ddb7baf030fcf502f76fc061ddd9405d16c"},
     {file = "debugpy-1.6.6-cp38-cp38-win32.whl", hash = "sha256:70ab53918fd907a3ade01909b3ed783287ede362c80c75f41e79596d5ccacd32"},
     {file = "debugpy-1.6.6-cp38-cp38-win_amd64.whl", hash = "sha256:c05349890804d846eca32ce0623ab66c06f8800db881af7a876dc073ac1c2225"},
+    {file = "debugpy-1.6.6-cp39-cp39-macosx_11_0_x86_64.whl", hash = "sha256:11a0f3a106f69901e4a9a5683ce943a7a5605696024134b522aa1bfda25b5fec"},
     {file = "debugpy-1.6.6-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a771739902b1ae22a120dbbb6bd91b2cae6696c0e318b5007c5348519a4211c6"},
     {file = "debugpy-1.6.6-cp39-cp39-win32.whl", hash = "sha256:549ae0cb2d34fc09d1675f9b01942499751d174381b6082279cf19cdb3c47cbe"},
     {file = "debugpy-1.6.6-cp39-cp39-win_amd64.whl", hash = "sha256:de4a045fbf388e120bb6ec66501458d3134f4729faed26ff95de52a754abddb1"},
@@ -2410,7 +2433,6 @@ files = [
     {file = "lxml-4.9.2-cp35-cp35m-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:ca989b91cf3a3ba28930a9fc1e9aeafc2a395448641df1f387a2d394638943b0"},
     {file = "lxml-4.9.2-cp35-cp35m-manylinux_2_5_x86_64.manylinux1_x86_64.whl", hash = "sha256:822068f85e12a6e292803e112ab876bc03ed1f03dddb80154c395f891ca6b31e"},
     {file = "lxml-4.9.2-cp35-cp35m-win32.whl", hash = "sha256:be7292c55101e22f2a3d4d8913944cbea71eea90792bf914add27454a13905df"},
-    {file = "lxml-4.9.2-cp35-cp35m-win_amd64.whl", hash = "sha256:998c7c41910666d2976928c38ea96a70d1aa43be6fe502f21a651e17483a43c5"},
     {file = "lxml-4.9.2-cp36-cp36m-macosx_10_15_x86_64.whl", hash = "sha256:b26a29f0b7fc6f0897f043ca366142d2b609dc60756ee6e4e90b5f762c6adc53"},
     {file = "lxml-4.9.2-cp36-cp36m-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_24_i686.whl", hash = "sha256:ab323679b8b3030000f2be63e22cdeea5b47ee0abd2d6a1dc0c8103ddaa56cd7"},
     {file = "lxml-4.9.2-cp36-cp36m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:689bb688a1db722485e4610a503e3e9210dcc20c520b45ac8f7533c837be76fe"},
@@ -2420,7 +2442,6 @@ files = [
     {file = "lxml-4.9.2-cp36-cp36m-musllinux_1_1_aarch64.whl", hash = "sha256:58bfa3aa19ca4c0f28c5dde0ff56c520fbac6f0daf4fac66ed4c8d2fb7f22e74"},
"sha256:58bfa3aa19ca4c0f28c5dde0ff56c520fbac6f0daf4fac66ed4c8d2fb7f22e74"}, {file = "lxml-4.9.2-cp36-cp36m-musllinux_1_1_x86_64.whl", hash = "sha256:bc718cd47b765e790eecb74d044cc8d37d58562f6c314ee9484df26276d36a38"}, {file = "lxml-4.9.2-cp36-cp36m-win32.whl", hash = "sha256:d5bf6545cd27aaa8a13033ce56354ed9e25ab0e4ac3b5392b763d8d04b08e0c5"}, - {file = "lxml-4.9.2-cp36-cp36m-win_amd64.whl", hash = "sha256:3ab9fa9d6dc2a7f29d7affdf3edebf6ece6fb28a6d80b14c3b2fb9d39b9322c3"}, {file = "lxml-4.9.2-cp37-cp37m-macosx_10_15_x86_64.whl", hash = "sha256:05ca3f6abf5cf78fe053da9b1166e062ade3fa5d4f92b4ed688127ea7d7b1d03"}, {file = "lxml-4.9.2-cp37-cp37m-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_24_i686.whl", hash = "sha256:a5da296eb617d18e497bcf0a5c528f5d3b18dadb3619fbdadf4ed2356ef8d941"}, {file = "lxml-4.9.2-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.manylinux_2_24_aarch64.whl", hash = "sha256:04876580c050a8c5341d706dd464ff04fd597095cc8c023252566a8826505726"}, @@ -5521,14 +5542,11 @@ files = [ {file = "tokenizers-0.13.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:47ef745dbf9f49281e900e9e72915356d69de3a4e4d8a475bda26bfdb5047736"}, {file = "tokenizers-0.13.2-cp310-cp310-win32.whl", hash = "sha256:96cedf83864bcc15a3ffd088a6f81a8a8f55b8b188eabd7a7f2a4469477036df"}, {file = "tokenizers-0.13.2-cp310-cp310-win_amd64.whl", hash = "sha256:eda77de40a0262690c666134baf19ec5c4f5b8bde213055911d9f5a718c506e1"}, - {file = "tokenizers-0.13.2-cp311-cp311-macosx_10_11_universal2.whl", hash = "sha256:9eee037bb5aa14daeb56b4c39956164b2bebbe6ab4ca7779d88aa16b79bd4e17"}, - {file = "tokenizers-0.13.2-cp311-cp311-macosx_12_0_arm64.whl", hash = "sha256:d1b079c4c9332048fec4cb9c2055c2373c74fbb336716a5524c9a720206d787e"}, {file = "tokenizers-0.13.2-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a689654fc745135cce4eea3b15e29c372c3e0b01717c6978b563de5c38af9811"}, {file = "tokenizers-0.13.2-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:3606528c07cda0566cff6cbfbda2b167f923661be595feac95701ffcdcbdbb21"}, {file = "tokenizers-0.13.2-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:41291d0160946084cbd53c8ec3d029df3dc2af2673d46b25ff1a7f31a9d55d51"}, {file = "tokenizers-0.13.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7892325f9ca1cc5fca0333d5bfd96a19044ce9b092ce2df625652109a3de16b8"}, {file = "tokenizers-0.13.2-cp311-cp311-win32.whl", hash = "sha256:93714958d4ebe5362d3de7a6bd73dc86c36b5af5941ebef6c325ac900fa58865"}, - {file = "tokenizers-0.13.2-cp311-cp311-win_amd64.whl", hash = "sha256:fa7ef7ee380b1f49211bbcfac8a006b1a3fa2fa4c7f4ee134ae384eb4ea5e453"}, {file = "tokenizers-0.13.2-cp37-cp37m-macosx_10_11_x86_64.whl", hash = "sha256:da521bfa94df6a08a6254bb8214ea04854bb9044d61063ae2529361688b5440a"}, {file = "tokenizers-0.13.2-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a739d4d973d422e1073989769723f3b6ad8b11e59e635a63de99aea4b2208188"}, {file = "tokenizers-0.13.2-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:cac01fc0b868e4d0a3aa7c5c53396da0a0a63136e81475d32fcf5c348fcb2866"}, @@ -5537,7 +5555,6 @@ files = [ {file = "tokenizers-0.13.2-cp37-cp37m-win32.whl", hash = "sha256:a537061ee18ba104b7f3daa735060c39db3a22c8a9595845c55b6c01d36c5e87"}, {file = "tokenizers-0.13.2-cp37-cp37m-win_amd64.whl", hash = "sha256:c82fb87b1cbfa984d8f05b2b3c3c73e428b216c1d4f0e286d0a3b27f521b32eb"}, {file = 
"tokenizers-0.13.2-cp38-cp38-macosx_10_11_x86_64.whl", hash = "sha256:ce298605a833ac7f81b8062d3102a42dcd9fa890493e8f756112c346339fe5c5"}, - {file = "tokenizers-0.13.2-cp38-cp38-macosx_12_0_arm64.whl", hash = "sha256:f44d59bafe3d61e8a56b9e0a963075187c0f0091023120b13fbe37a87936f171"}, {file = "tokenizers-0.13.2-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a51b93932daba12ed07060935978a6779593a59709deab04a0d10e6fd5c29e60"}, {file = "tokenizers-0.13.2-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:6969e5ea7ccb909ce7d6d4dfd009115dc72799b0362a2ea353267168667408c4"}, {file = "tokenizers-0.13.2-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:92f040c4d938ea64683526b45dfc81c580e3b35aaebe847e7eec374961231734"}, @@ -6268,10 +6285,10 @@ docs = ["furo", "jaraco.packaging (>=9)", "jaraco.tidelift (>=1.4)", "rst.linker testing = ["flake8 (<5)", "func-timeout", "jaraco.functools", "jaraco.itertools", "more-itertools", "pytest (>=6)", "pytest-black (>=0.3.7)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=1.3)", "pytest-flake8", "pytest-mypy (>=0.9.1)"] [extras] -all = ["manifest-ml", "elasticsearch", "faiss-cpu", "transformers", "spacy", "nltk", "wikipedia", "beautifulsoup4", "tiktoken", "torch", "jinja2", "pinecone-client", "weaviate-client", "redis", "google-api-python-client", "wolframalpha", "qdrant-client", "tensorflow-text"] -llms = ["manifest-ml", "torch", "transformers"] +all = ["anthropic", "manifest-ml", "elasticsearch", "faiss-cpu", "transformers", "spacy", "nltk", "wikipedia", "beautifulsoup4", "tiktoken", "torch", "jinja2", "pinecone-client", "weaviate-client", "redis", "google-api-python-client", "wolframalpha", "qdrant-client", "tensorflow-text"] +llms = ["anthropic", "manifest-ml", "torch", "transformers"] [metadata] lock-version = "2.0" python-versions = ">=3.8.1,<4.0" -content-hash = "6066217c63ce7620ee615d5db269327e288991f48304526767c88a1870f0c7a0" +content-hash = "1f44af9c9fa6aa7fe3fed70c0342c71489e7288562af54f6721bd622695ab6f4" diff --git a/pyproject.toml b/pyproject.toml index 1258d596537..3dff7a3bd4b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -32,6 +32,7 @@ tiktoken = {version = "^0", optional = true, python="^3.9"} pinecone-client = {version = "^2", optional = true} weaviate-client = {version = "^3", optional = true} google-api-python-client = {version = "2.70.0", optional = true} +anthropic = {git = "https://github.com/anthropics/anthropic-sdk-python.git", optional = true} wolframalpha = {version = "5.0.0", optional = true} qdrant-client = {version = "^0.11.7", optional = true} dataclasses-json = "^0.5.7" @@ -82,8 +83,8 @@ jupyter = "^1.0.0" playwright = "^1.28.0" [tool.poetry.extras] -llms = ["cohere", "openai", "nlpcloud", "huggingface_hub", "manifest-ml", "torch", "transformers"] -all = ["cohere", "openai", "nlpcloud", "huggingface_hub", "manifest-ml", "elasticsearch", "google-search-results", "faiss-cpu", "sentence_transformers", "transformers", "spacy", "nltk", "wikipedia", "beautifulsoup4", "tiktoken", "torch", "jinja2", "pinecone-client", "weaviate-client", "redis", "google-api-python-client", "wolframalpha", "qdrant-client", "tensorflow-text"] +llms = ["anthropic", "cohere", "openai", "nlpcloud", "huggingface_hub", "manifest-ml", "torch", "transformers"] +all = ["anthropic", "cohere", "openai", "nlpcloud", "huggingface_hub", "manifest-ml", "elasticsearch", "google-search-results", "faiss-cpu", "sentence_transformers", "transformers", "spacy", "nltk", "wikipedia", 
"beautifulsoup4", "tiktoken", "torch", "jinja2", "pinecone-client", "weaviate-client", "redis", "google-api-python-client", "wolframalpha", "qdrant-client", "tensorflow-text"] [tool.isort] profile = "black" diff --git a/tests/integration_tests/llms/test_anthropic.py b/tests/integration_tests/llms/test_anthropic.py new file mode 100644 index 00000000000..9077633abbf --- /dev/null +++ b/tests/integration_tests/llms/test_anthropic.py @@ -0,0 +1,23 @@ +"""Test Anthropic API wrapper.""" + +from typing import Generator + +from langchain.llms.anthropic import Anthropic + + +def test_anthropic_call() -> None: + """Test valid call to anthropic.""" + llm = Anthropic(model="bare-nano-0") + output = llm("Say foo:") + assert isinstance(output, str) + + +def test_anthropic_streaming() -> None: + """Test streaming tokens from anthropic.""" + llm = Anthropic(model="bare-nano-0") + generator = llm.stream("I'm Pickle Rick") + + assert isinstance(generator, Generator) + + for token in generator: + assert isinstance(token["completion"], str)