Mirror of https://github.com/hwchase17/langchain.git
langchain_google_vertexai: added logic to override get_num_tokens_from_messages() for ChatVertexAI (#16784)

- **Description:** adds logic to override get_num_tokens_from_messages() for ChatVertexAI. Previously, ChatVertexAI inherited get_num_tokens_from_messages() from BaseChatModel, which in turn fell back to the GPT-2 tokenizer, so the counts did not reflect the actual Vertex AI models.
- **Issue:** N/A
- **Dependencies:** N/A
- **Twitter handle:** @aditya_rane (@lkuligin for review)

Co-authored-by: adityarane@google.com <adityarane@google.com>
Co-authored-by: Leonid Kuligin <lkuligin@yandex.ru>
This commit is contained in:
parent 00a09e1b71
commit 98176ac982
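The user-facing override itself is not among the hunks shown below; they add the plumbing (a preview-SDK client plus a shared get_num_tokens()) that the override builds on. A minimal sketch of how the override plausibly works, assuming messages are flattened to plain text before counting (that flattening step is an illustration, not the merged code):

from typing import List

from langchain_core.messages import BaseMessage


def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
    """Count tokens with the matching Vertex AI tokenizer rather than GPT-2."""
    # Assumed flattening step; the merged code may pass structured history.
    text = "\n".join(str(message.content) for message in messages)
    # Delegates to the get_num_tokens() this commit adds to _VertexAICommon,
    # which dispatches on the type of self.client_preview.
    return self.get_num_tokens(text)

The hunks that follow wire up the client_preview attribute this sketch relies on.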
@@ -45,6 +45,12 @@ from vertexai.preview.generative_models import (  # type: ignore
     Image,
     Part,
 )
+from vertexai.preview.language_models import (  # type: ignore
+    ChatModel as PreviewChatModel,
+)
+from vertexai.preview.language_models import (
+    CodeChatModel as PreviewCodeChatModel,
+)
 
 from langchain_google_vertexai._utils import (
     get_generation_info,
@@ -316,12 +322,20 @@ class ChatVertexAI(_VertexAICommon, BaseChatModel):
             values["client"] = GenerativeModel(
                 model_name=values["model_name"], safety_settings=safety_settings
             )
+            values["client_preview"] = GenerativeModel(
+                model_name=values["model_name"], safety_settings=safety_settings
+            )
         else:
             if is_codey_model(values["model_name"]):
                 model_cls = CodeChatModel
+                model_cls_preview = PreviewCodeChatModel
             else:
                 model_cls = ChatModel
+                model_cls_preview = PreviewChatModel
             values["client"] = model_cls.from_pretrained(values["model_name"])
+            values["client_preview"] = model_cls_preview.from_pretrained(
+                values["model_name"]
+            )
         return values
 
     def _generate(
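validate_environment now constructs a second, preview-SDK client next to the serving client, because token counting is only exposed on the vertexai preview classes. A minimal illustration of the call this enables, assuming a Gemini model (the model name is only an example):

from vertexai.preview.generative_models import GenerativeModel

# The same class the hunk above stores in values["client_preview"].
client_preview = GenerativeModel(model_name="gemini-pro")
print(client_preview.count_tokens(["Hello"]).total_tokens)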
@@ -31,6 +31,12 @@ from vertexai.preview.generative_models import (  # type: ignore[import-untyped]
     Image,
 )
 from vertexai.preview.language_models import (  # type: ignore[import-untyped]
+    ChatModel as PreviewChatModel,
+)
+from vertexai.preview.language_models import (
+    CodeChatModel as PreviewCodeChatModel,
+)
+from vertexai.preview.language_models import (
     CodeGenerationModel as PreviewCodeGenerationModel,
 )
 from vertexai.preview.language_models import (
@@ -239,6 +245,27 @@ class _VertexAICommon(_VertexAIBase):
             params.pop("candidate_count")
         return params
 
+    def get_num_tokens(self, text: str) -> int:
+        """Get the number of tokens present in the text.
+
+        Useful for checking if an input will fit in a model's context window.
+
+        Args:
+            text: The string input to tokenize.
+
+        Returns:
+            The integer number of tokens in the text.
+        """
+        is_palm_chat_model = isinstance(
+            self.client_preview, PreviewChatModel
+        ) or isinstance(self.client_preview, PreviewCodeChatModel)
+        if is_palm_chat_model:
+            result = self.client_preview.start_chat().count_tokens(text)
+        else:
+            result = self.client_preview.count_tokens([text])
+
+        return result.total_tokens
+
 
 class VertexAI(_VertexAICommon, BaseLLM):
     """Google Vertex AI large language models."""
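Because get_num_tokens() moves up to _VertexAICommon, both VertexAI and ChatVertexAI now count tokens with the model's own tokenizer via client_preview. A short usage sketch (the model name is only an example):

from langchain_google_vertexai import VertexAI

llm = VertexAI(model_name="text-bison")
# Routed through client_preview.count_tokens(), not the GPT-2 fallback.
print(llm.get_num_tokens("Hello"))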
@@ -300,20 +327,6 @@ class VertexAI(_VertexAICommon, BaseLLM):
             raise ValueError("Only one candidate can be generated with streaming!")
         return values
 
-    def get_num_tokens(self, text: str) -> int:
-        """Get the number of tokens present in the text.
-
-        Useful for checking if an input will fit in a model's context window.
-
-        Args:
-            text: The string input to tokenize.
-
-        Returns:
-            The integer number of tokens in the text.
-        """
-        result = self.client_preview.count_tokens([text])
-        return result.total_tokens
-
     def _response_to_generation(
         self, response: TextGenerationResponse, *, stream: bool = False
     ) -> GenerationChunk:
@@ -225,6 +225,18 @@ def test_chat_vertexai_system_message(model_name: Optional[str]) -> None:
     assert isinstance(response.content, str)
 
 
+@pytest.mark.parametrize("model_name", model_names_to_test)
+def test_get_num_tokens_from_messages(model_name: str) -> None:
+    if model_name:
+        model = ChatVertexAI(model_name=model_name, temperature=0.0)
+    else:
+        model = ChatVertexAI(temperature=0.0)
+    message = HumanMessage(content="Hello")
+    token = model.get_num_tokens_from_messages(messages=[message])
+    assert isinstance(token, int)
+    assert token == 3
+
+
 def test_chat_vertexai_gemini_function_calling() -> None:
     class MyModel(BaseModel):
         name: str
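The hard-coded token == 3 pins down what the Vertex AI tokenizers return for this single "Hello" message; the inherited BaseChatModel implementation, backed by the GPT-2 tokenizer, gave no such guarantee. The same flow as a standalone sketch (using the default model, as the test's else-branch does):

from langchain_core.messages import HumanMessage
from langchain_google_vertexai import ChatVertexAI

model = ChatVertexAI(temperature=0.0)
tokens = model.get_num_tokens_from_messages(messages=[HumanMessage(content="Hello")])
assert isinstance(tokens, int)  # the integration test additionally expects 3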
@@ -81,7 +81,6 @@ def test_tools() -> None:
     agent_executor = AgentExecutor(agent=agent, tools=tools, verbose=True)
 
     response = agent_executor.invoke({"input": "What is 6 raised to the 0.43 power?"})
-    print(response)
     assert isinstance(response, dict)
     assert response["input"] == "What is 6 raised to the 0.43 power?"
 
@@ -106,7 +105,6 @@ def test_stream() -> None:
     ]
     response = list(llm.stream("What is 6 raised to the 0.43 power?", functions=tools))
     assert len(response) == 1
-    # for chunk in response:
     assert isinstance(response[0], AIMessageChunk)
     assert "function_call" in response[0].additional_kwargs
 