mirror of
https://github.com/hwchase17/langchain.git
synced 2025-05-24 16:37:46 +00:00
Use correct tokenizer for Bedrock/Anthropic LLMs (#11561)
**Description** This PR implements the usage of the correct tokenizer in Bedrock LLMs, if using anthropic models. **Issue:** #11560 **Dependencies:** optional dependency on `anthropic` python library. **Twitter handle:** jtolgyesi --------- Co-authored-by: Bagatur <baskaryan@gmail.com>
This commit is contained in:
parent
467b082c34
commit
15687a28d5
@ -9,6 +9,10 @@ from langchain.llms.bedrock import BedrockBase
|
||||
from langchain.pydantic_v1 import Extra
|
||||
from langchain.schema.messages import AIMessage, AIMessageChunk, BaseMessage
|
||||
from langchain.schema.output import ChatGeneration, ChatGenerationChunk, ChatResult
|
||||
from langchain.utilities.anthropic import (
|
||||
get_num_tokens_anthropic,
|
||||
get_token_ids_anthropic,
|
||||
)
|
||||
|
||||
|
||||
class ChatPromptAdapter:
|
||||
@ -86,3 +90,15 @@ class BedrockChat(BaseChatModel, BedrockBase):
|
||||
|
||||
message = AIMessage(content=completion)
|
||||
return ChatResult(generations=[ChatGeneration(message=message)])
|
||||
|
||||
def get_num_tokens(self, text: str) -> int:
|
||||
if self._model_is_anthropic:
|
||||
return get_num_tokens_anthropic(text)
|
||||
else:
|
||||
return super().get_num_tokens(text)
|
||||
|
||||
def get_token_ids(self, text: str) -> List[int]:
|
||||
if self._model_is_anthropic:
|
||||
return get_token_ids_anthropic(text)
|
||||
else:
|
||||
return super().get_token_ids(text)
|
||||
|
@ -8,6 +8,10 @@ from langchain.llms.base import LLM
|
||||
from langchain.llms.utils import enforce_stop_tokens
|
||||
from langchain.pydantic_v1 import BaseModel, Extra, root_validator
|
||||
from langchain.schema.output import GenerationChunk
|
||||
from langchain.utilities.anthropic import (
|
||||
get_num_tokens_anthropic,
|
||||
get_token_ids_anthropic,
|
||||
)
|
||||
|
||||
HUMAN_PROMPT = "\n\nHuman:"
|
||||
ASSISTANT_PROMPT = "\n\nAssistant:"
|
||||
@ -222,6 +226,10 @@ class BedrockBase(BaseModel, ABC):
|
||||
def _get_provider(self) -> str:
|
||||
return self.model_id.split(".")[0]
|
||||
|
||||
@property
|
||||
def _model_is_anthropic(self) -> bool:
|
||||
return self._get_provider() == "anthropic"
|
||||
|
||||
def _prepare_input_and_invoke(
|
||||
self,
|
||||
prompt: str,
|
||||
@ -393,3 +401,15 @@ class Bedrock(LLM, BedrockBase):
|
||||
return completion
|
||||
|
||||
return self._prepare_input_and_invoke(prompt=prompt, stop=stop, **kwargs)
|
||||
|
||||
def get_num_tokens(self, text: str) -> int:
|
||||
if self._model_is_anthropic:
|
||||
return get_num_tokens_anthropic(text)
|
||||
else:
|
||||
return super().get_num_tokens(text)
|
||||
|
||||
def get_token_ids(self, text: str) -> List[int]:
|
||||
if self._model_is_anthropic:
|
||||
return get_token_ids_anthropic(text)
|
||||
else:
|
||||
return super().get_token_ids(text)
|
||||
|
25
libs/langchain/langchain/utilities/anthropic.py
Normal file
25
libs/langchain/langchain/utilities/anthropic.py
Normal file
@ -0,0 +1,25 @@
|
||||
from typing import Any, List
|
||||
|
||||
|
||||
def _get_anthropic_client() -> Any:
|
||||
try:
|
||||
import anthropic
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"Could not import anthropic python package. "
|
||||
"This is needed in order to accurately tokenize the text "
|
||||
"for anthropic models. Please install it with `pip install anthropic`."
|
||||
)
|
||||
return anthropic.Anthropic()
|
||||
|
||||
|
||||
def get_num_tokens_anthropic(text: str) -> int:
|
||||
client = _get_anthropic_client()
|
||||
return client.count_tokens(text=text)
|
||||
|
||||
|
||||
def get_token_ids_anthropic(text: str) -> List[int]:
|
||||
client = _get_anthropic_client()
|
||||
tokenizer = client.get_tokenizer()
|
||||
encoded_text = tokenizer.encode(text)
|
||||
return encoded_text.ids
|
Loading…
Reference in New Issue
Block a user