diff --git a/libs/langchain/langchain/chat_models/bedrock.py b/libs/langchain/langchain/chat_models/bedrock.py
index 7ba186bb222..3cc7b84cef8 100644
--- a/libs/langchain/langchain/chat_models/bedrock.py
+++ b/libs/langchain/langchain/chat_models/bedrock.py
@@ -9,6 +9,10 @@ from langchain.llms.bedrock import BedrockBase
 from langchain.pydantic_v1 import Extra
 from langchain.schema.messages import AIMessage, AIMessageChunk, BaseMessage
 from langchain.schema.output import ChatGeneration, ChatGenerationChunk, ChatResult
+from langchain.utilities.anthropic import (
+    get_num_tokens_anthropic,
+    get_token_ids_anthropic,
+)
 
 
 class ChatPromptAdapter:
@@ -86,3 +90,15 @@ class BedrockChat(BaseChatModel, BedrockBase):
 
         message = AIMessage(content=completion)
         return ChatResult(generations=[ChatGeneration(message=message)])
+
+    def get_num_tokens(self, text: str) -> int:
+        if self._model_is_anthropic:
+            return get_num_tokens_anthropic(text)
+        else:
+            return super().get_num_tokens(text)
+
+    def get_token_ids(self, text: str) -> List[int]:
+        if self._model_is_anthropic:
+            return get_token_ids_anthropic(text)
+        else:
+            return super().get_token_ids(text)
diff --git a/libs/langchain/langchain/llms/bedrock.py b/libs/langchain/langchain/llms/bedrock.py
index 7971f6e9c54..e1bb281e728 100644
--- a/libs/langchain/langchain/llms/bedrock.py
+++ b/libs/langchain/langchain/llms/bedrock.py
@@ -8,6 +8,10 @@ from langchain.llms.base import LLM
 from langchain.llms.utils import enforce_stop_tokens
 from langchain.pydantic_v1 import BaseModel, Extra, root_validator
 from langchain.schema.output import GenerationChunk
+from langchain.utilities.anthropic import (
+    get_num_tokens_anthropic,
+    get_token_ids_anthropic,
+)
 
 HUMAN_PROMPT = "\n\nHuman:"
 ASSISTANT_PROMPT = "\n\nAssistant:"
@@ -222,6 +226,10 @@ class BedrockBase(BaseModel, ABC):
     def _get_provider(self) -> str:
         return self.model_id.split(".")[0]
 
+    @property
+    def _model_is_anthropic(self) -> bool:
+        return self._get_provider() == "anthropic"
+
     def _prepare_input_and_invoke(
         self,
         prompt: str,
@@ -318,7 +326,7 @@ class Bedrock(LLM, BedrockBase):
             from bedrock_langchain.bedrock_llm import BedrockLLM
 
             llm = BedrockLLM(
-                credentials_profile_name="default", 
+                credentials_profile_name="default",
                 model_id="amazon.titan-text-express-v1",
                 streaming=True
             )
@@ -393,3 +401,15 @@ class Bedrock(LLM, BedrockBase):
             return completion
 
         return self._prepare_input_and_invoke(prompt=prompt, stop=stop, **kwargs)
+
+    def get_num_tokens(self, text: str) -> int:
+        if self._model_is_anthropic:
+            return get_num_tokens_anthropic(text)
+        else:
+            return super().get_num_tokens(text)
+
+    def get_token_ids(self, text: str) -> List[int]:
+        if self._model_is_anthropic:
+            return get_token_ids_anthropic(text)
+        else:
+            return super().get_token_ids(text)
diff --git a/libs/langchain/langchain/utilities/anthropic.py b/libs/langchain/langchain/utilities/anthropic.py
new file mode 100644
index 00000000000..89e6fd37aac
--- /dev/null
+++ b/libs/langchain/langchain/utilities/anthropic.py
@@ -0,0 +1,32 @@
+from typing import Any, List
+
+
+def _get_anthropic_client() -> Any:
+    """Return a new `anthropic.Anthropic` client, importing the package lazily.
+
+    Raises:
+        ImportError: If the `anthropic` python package cannot be imported.
+    """
+    try:
+        import anthropic
+    except ImportError as e:
+        raise ImportError(
+            "Could not import anthropic python package. "
+            "This is needed in order to accurately tokenize the text "
+            "for anthropic models. Please install it with `pip install anthropic`."
+        ) from e
+    return anthropic.Anthropic()
+
+
+def get_num_tokens_anthropic(text: str) -> int:
+    """Return the token count of `text` as reported by the anthropic client."""
+    client = _get_anthropic_client()
+    return client.count_tokens(text=text)
+
+
+def get_token_ids_anthropic(text: str) -> List[int]:
+    """Return the ids the anthropic client's tokenizer assigns to `text`."""
+    client = _get_anthropic_client()
+    tokenizer = client.get_tokenizer()
+    encoded_text = tokenizer.encode(text)
+    return encoded_text.ids