mirror of
https://github.com/hwchase17/langchain.git
synced 2025-07-14 17:07:25 +00:00
langchain_google_vertexai : added logic to override get_num_tokens_from_messages() for ChatVertexAI (#16784)
<!-- Thank you for contributing to LangChain! Replace this entire comment with: - **Description: added logic to override get_num_tokens_from_messages() for ChatVertexAI. Currently ChatVertexAI was inheriting get_num_tokens_from_messages() from BaseChatModel which in-turn was calling GPT-2 tokenizer - **Issue: NA - **Dependencies: NA - **Twitter handle:@aditya_rane @lkuligin for review --------- Co-authored-by: adityarane@google.com <adityarane@google.com> Co-authored-by: Leonid Kuligin <lkuligin@yandex.ru>
This commit is contained in:
parent
00a09e1b71
commit
98176ac982
@ -45,6 +45,12 @@ from vertexai.preview.generative_models import ( # type: ignore
|
||||
Image,
|
||||
Part,
|
||||
)
|
||||
from vertexai.preview.language_models import ( # type: ignore
|
||||
ChatModel as PreviewChatModel,
|
||||
)
|
||||
from vertexai.preview.language_models import (
|
||||
CodeChatModel as PreviewCodeChatModel,
|
||||
)
|
||||
|
||||
from langchain_google_vertexai._utils import (
|
||||
get_generation_info,
|
||||
@ -316,12 +322,20 @@ class ChatVertexAI(_VertexAICommon, BaseChatModel):
|
||||
values["client"] = GenerativeModel(
|
||||
model_name=values["model_name"], safety_settings=safety_settings
|
||||
)
|
||||
values["client_preview"] = GenerativeModel(
|
||||
model_name=values["model_name"], safety_settings=safety_settings
|
||||
)
|
||||
else:
|
||||
if is_codey_model(values["model_name"]):
|
||||
model_cls = CodeChatModel
|
||||
model_cls_preview = PreviewCodeChatModel
|
||||
else:
|
||||
model_cls = ChatModel
|
||||
model_cls_preview = PreviewChatModel
|
||||
values["client"] = model_cls.from_pretrained(values["model_name"])
|
||||
values["client_preview"] = model_cls_preview.from_pretrained(
|
||||
values["model_name"]
|
||||
)
|
||||
return values
|
||||
|
||||
def _generate(
|
||||
|
@ -31,6 +31,12 @@ from vertexai.preview.generative_models import ( # type: ignore[import-untyped]
|
||||
Image,
|
||||
)
|
||||
from vertexai.preview.language_models import ( # type: ignore[import-untyped]
|
||||
ChatModel as PreviewChatModel,
|
||||
)
|
||||
from vertexai.preview.language_models import (
|
||||
CodeChatModel as PreviewCodeChatModel,
|
||||
)
|
||||
from vertexai.preview.language_models import (
|
||||
CodeGenerationModel as PreviewCodeGenerationModel,
|
||||
)
|
||||
from vertexai.preview.language_models import (
|
||||
@ -239,6 +245,27 @@ class _VertexAICommon(_VertexAIBase):
|
||||
params.pop("candidate_count")
|
||||
return params
|
||||
|
||||
def get_num_tokens(self, text: str) -> int:
|
||||
"""Get the number of tokens present in the text.
|
||||
|
||||
Useful for checking if an input will fit in a model's context window.
|
||||
|
||||
Args:
|
||||
text: The string input to tokenize.
|
||||
|
||||
Returns:
|
||||
The integer number of tokens in the text.
|
||||
"""
|
||||
is_palm_chat_model = isinstance(
|
||||
self.client_preview, PreviewChatModel
|
||||
) or isinstance(self.client_preview, PreviewCodeChatModel)
|
||||
if is_palm_chat_model:
|
||||
result = self.client_preview.start_chat().count_tokens(text)
|
||||
else:
|
||||
result = self.client_preview.count_tokens([text])
|
||||
|
||||
return result.total_tokens
|
||||
|
||||
|
||||
class VertexAI(_VertexAICommon, BaseLLM):
|
||||
"""Google Vertex AI large language models."""
|
||||
@ -300,20 +327,6 @@ class VertexAI(_VertexAICommon, BaseLLM):
|
||||
raise ValueError("Only one candidate can be generated with streaming!")
|
||||
return values
|
||||
|
||||
def get_num_tokens(self, text: str) -> int:
|
||||
"""Get the number of tokens present in the text.
|
||||
|
||||
Useful for checking if an input will fit in a model's context window.
|
||||
|
||||
Args:
|
||||
text: The string input to tokenize.
|
||||
|
||||
Returns:
|
||||
The integer number of tokens in the text.
|
||||
"""
|
||||
result = self.client_preview.count_tokens([text])
|
||||
return result.total_tokens
|
||||
|
||||
def _response_to_generation(
|
||||
self, response: TextGenerationResponse, *, stream: bool = False
|
||||
) -> GenerationChunk:
|
||||
|
@ -225,6 +225,18 @@ def test_chat_vertexai_system_message(model_name: Optional[str]) -> None:
|
||||
assert isinstance(response.content, str)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model_name", model_names_to_test)
|
||||
def test_get_num_tokens_from_messages(model_name: str) -> None:
|
||||
if model_name:
|
||||
model = ChatVertexAI(model_name=model_name, temperature=0.0)
|
||||
else:
|
||||
model = ChatVertexAI(temperature=0.0)
|
||||
message = HumanMessage(content="Hello")
|
||||
token = model.get_num_tokens_from_messages(messages=[message])
|
||||
assert isinstance(token, int)
|
||||
assert token == 3
|
||||
|
||||
|
||||
def test_chat_vertexai_gemini_function_calling() -> None:
|
||||
class MyModel(BaseModel):
|
||||
name: str
|
||||
|
@ -81,7 +81,6 @@ def test_tools() -> None:
|
||||
agent_executor = AgentExecutor(agent=agent, tools=tools, verbose=True)
|
||||
|
||||
response = agent_executor.invoke({"input": "What is 6 raised to the 0.43 power?"})
|
||||
print(response)
|
||||
assert isinstance(response, dict)
|
||||
assert response["input"] == "What is 6 raised to the 0.43 power?"
|
||||
|
||||
@ -106,7 +105,6 @@ def test_stream() -> None:
|
||||
]
|
||||
response = list(llm.stream("What is 6 raised to the 0.43 power?", functions=tools))
|
||||
assert len(response) == 1
|
||||
# for chunk in response:
|
||||
assert isinstance(response[0], AIMessageChunk)
|
||||
assert "function_call" in response[0].additional_kwargs
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user