mirror of
https://github.com/hwchase17/langchain.git
synced 2025-05-25 08:57:48 +00:00
Use correct tokenizer for Bedrock/Anthropic LLMs (#11561)
**Description** This PR implements use of the correct tokenizer in Bedrock LLMs when using Anthropic models. **Issue:** #11560 **Dependencies:** optional dependency on the `anthropic` Python library. **Twitter handle:** jtolgyesi --------- Co-authored-by: Bagatur <baskaryan@gmail.com>
This commit is contained in:
parent
467b082c34
commit
15687a28d5
@ -9,6 +9,10 @@ from langchain.llms.bedrock import BedrockBase
|
|||||||
from langchain.pydantic_v1 import Extra
|
from langchain.pydantic_v1 import Extra
|
||||||
from langchain.schema.messages import AIMessage, AIMessageChunk, BaseMessage
|
from langchain.schema.messages import AIMessage, AIMessageChunk, BaseMessage
|
||||||
from langchain.schema.output import ChatGeneration, ChatGenerationChunk, ChatResult
|
from langchain.schema.output import ChatGeneration, ChatGenerationChunk, ChatResult
|
||||||
|
from langchain.utilities.anthropic import (
|
||||||
|
get_num_tokens_anthropic,
|
||||||
|
get_token_ids_anthropic,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class ChatPromptAdapter:
|
class ChatPromptAdapter:
|
||||||
@ -86,3 +90,15 @@ class BedrockChat(BaseChatModel, BedrockBase):
|
|||||||
|
|
||||||
message = AIMessage(content=completion)
|
message = AIMessage(content=completion)
|
||||||
return ChatResult(generations=[ChatGeneration(message=message)])
|
return ChatResult(generations=[ChatGeneration(message=message)])
|
||||||
|
|
||||||
|
def get_num_tokens(self, text: str) -> int:
    """Count the tokens in ``text``.

    Anthropic-hosted models are counted with Anthropic's own tokenizer;
    every other provider falls back to the base implementation.
    """
    if not self._model_is_anthropic:
        return super().get_num_tokens(text)
    return get_num_tokens_anthropic(text)
|
||||||
|
|
||||||
|
def get_token_ids(self, text: str) -> List[int]:
    """Return the token ids of ``text``.

    Uses Anthropic's tokenizer when the underlying model is an Anthropic
    model; otherwise defers to the base implementation.
    """
    if not self._model_is_anthropic:
        return super().get_token_ids(text)
    return get_token_ids_anthropic(text)
|
||||||
|
@ -8,6 +8,10 @@ from langchain.llms.base import LLM
|
|||||||
from langchain.llms.utils import enforce_stop_tokens
|
from langchain.llms.utils import enforce_stop_tokens
|
||||||
from langchain.pydantic_v1 import BaseModel, Extra, root_validator
|
from langchain.pydantic_v1 import BaseModel, Extra, root_validator
|
||||||
from langchain.schema.output import GenerationChunk
|
from langchain.schema.output import GenerationChunk
|
||||||
|
from langchain.utilities.anthropic import (
|
||||||
|
get_num_tokens_anthropic,
|
||||||
|
get_token_ids_anthropic,
|
||||||
|
)
|
||||||
|
|
||||||
HUMAN_PROMPT = "\n\nHuman:"
|
HUMAN_PROMPT = "\n\nHuman:"
|
||||||
ASSISTANT_PROMPT = "\n\nAssistant:"
|
ASSISTANT_PROMPT = "\n\nAssistant:"
|
||||||
@ -222,6 +226,10 @@ class BedrockBase(BaseModel, ABC):
|
|||||||
def _get_provider(self) -> str:
    """Return the provider prefix of the model id.

    Bedrock model ids are of the form ``"<provider>.<model>"``
    (e.g. ``"anthropic.claude-v2"``), so the provider is everything
    before the first dot.
    """
    provider, _, _ = self.model_id.partition(".")
    return provider
||||||
|
|
||||||
|
@property
def _model_is_anthropic(self) -> bool:
    """Whether the configured model is served by the Anthropic provider."""
    provider = self._get_provider()
    return provider == "anthropic"
|
||||||
|
|
||||||
def _prepare_input_and_invoke(
|
def _prepare_input_and_invoke(
|
||||||
self,
|
self,
|
||||||
prompt: str,
|
prompt: str,
|
||||||
@ -393,3 +401,15 @@ class Bedrock(LLM, BedrockBase):
|
|||||||
return completion
|
return completion
|
||||||
|
|
||||||
return self._prepare_input_and_invoke(prompt=prompt, stop=stop, **kwargs)
|
return self._prepare_input_and_invoke(prompt=prompt, stop=stop, **kwargs)
|
||||||
|
|
||||||
|
def get_num_tokens(self, text: str) -> int:
    """Count the tokens in ``text``.

    Anthropic models are tokenized with Anthropic's own tokenizer;
    all other providers use the inherited counter.
    """
    counter = (
        get_num_tokens_anthropic
        if self._model_is_anthropic
        else super().get_num_tokens
    )
    return counter(text)
|
||||||
|
|
||||||
|
def get_token_ids(self, text: str) -> List[int]:
    """Return the token ids of ``text``.

    Anthropic models are tokenized with Anthropic's own tokenizer;
    all other providers use the inherited implementation.
    """
    if not self._model_is_anthropic:
        return super().get_token_ids(text)
    return get_token_ids_anthropic(text)
|
||||||
|
25
libs/langchain/langchain/utilities/anthropic.py
Normal file
25
libs/langchain/langchain/utilities/anthropic.py
Normal file
@ -0,0 +1,25 @@
|
|||||||
|
from typing import Any, List
|
||||||
|
|
||||||
|
|
||||||
|
def _get_anthropic_client() -> Any:
    """Return an ``anthropic.Anthropic`` client used only for tokenization.

    Returns:
        A fresh Anthropic client instance.

    Raises:
        ImportError: If the optional ``anthropic`` package is not installed.
    """
    try:
        import anthropic
    except ImportError as e:
        # Chain the original exception so the traceback shows the real
        # import failure instead of an implicit "during handling of" note.
        raise ImportError(
            "Could not import anthropic python package. "
            "This is needed in order to accurately tokenize the text "
            "for anthropic models. Please install it with `pip install anthropic`."
        ) from e
    return anthropic.Anthropic()
|
||||||
|
|
||||||
|
|
||||||
|
def get_num_tokens_anthropic(text: str) -> int:
    """Count the number of tokens in ``text`` with Anthropic's tokenizer."""
    return _get_anthropic_client().count_tokens(text=text)
|
||||||
|
|
||||||
|
|
||||||
|
def get_token_ids_anthropic(text: str) -> List[int]:
    """Tokenize ``text`` with Anthropic's tokenizer and return the token ids."""
    tokenizer = _get_anthropic_client().get_tokenizer()
    return tokenizer.encode(text).ids
|
Loading…
Reference in New Issue
Block a user