diff --git a/libs/partners/google-vertexai/langchain_google_vertexai/chat_models.py b/libs/partners/google-vertexai/langchain_google_vertexai/chat_models.py
index 1c62d76c75c..32255d268b7 100644
--- a/libs/partners/google-vertexai/langchain_google_vertexai/chat_models.py
+++ b/libs/partners/google-vertexai/langchain_google_vertexai/chat_models.py
@@ -45,6 +45,12 @@ from vertexai.preview.generative_models import (  # type: ignore
     Image,
     Part,
 )
+from vertexai.preview.language_models import (  # type: ignore
+    ChatModel as PreviewChatModel,
+)
+from vertexai.preview.language_models import (
+    CodeChatModel as PreviewCodeChatModel,
+)
 
 from langchain_google_vertexai._utils import (
     get_generation_info,
@@ -316,12 +322,20 @@ class ChatVertexAI(_VertexAICommon, BaseChatModel):
             values["client"] = GenerativeModel(
                 model_name=values["model_name"], safety_settings=safety_settings
             )
+            values["client_preview"] = GenerativeModel(
+                model_name=values["model_name"], safety_settings=safety_settings
+            )
         else:
             if is_codey_model(values["model_name"]):
                 model_cls = CodeChatModel
+                model_cls_preview = PreviewCodeChatModel
             else:
                 model_cls = ChatModel
+                model_cls_preview = PreviewChatModel
             values["client"] = model_cls.from_pretrained(values["model_name"])
+            values["client_preview"] = model_cls_preview.from_pretrained(
+                values["model_name"]
+            )
         return values
 
     def _generate(
diff --git a/libs/partners/google-vertexai/langchain_google_vertexai/llms.py b/libs/partners/google-vertexai/langchain_google_vertexai/llms.py
index bd4d346a019..9b3099b4c33 100644
--- a/libs/partners/google-vertexai/langchain_google_vertexai/llms.py
+++ b/libs/partners/google-vertexai/langchain_google_vertexai/llms.py
@@ -31,6 +31,12 @@ from vertexai.preview.generative_models import (  # type: ignore[import-untyped]
     Image,
 )
 from vertexai.preview.language_models import (  # type: ignore[import-untyped]
+    ChatModel as PreviewChatModel,
+)
+from vertexai.preview.language_models import (
+    CodeChatModel as PreviewCodeChatModel,
+)
+from vertexai.preview.language_models import (
     CodeGenerationModel as PreviewCodeGenerationModel,
 )
 from vertexai.preview.language_models import (
@@ -239,6 +245,27 @@ class _VertexAICommon(_VertexAIBase):
             params.pop("candidate_count")
         return params
 
+    def get_num_tokens(self, text: str) -> int:
+        """Get the number of tokens present in the text.
+
+        Useful for checking if an input will fit in a model's context window.
+
+        Args:
+            text: The string input to tokenize.
+
+        Returns:
+            The integer number of tokens in the text.
+        """
+        is_palm_chat_model = isinstance(
+            self.client_preview, PreviewChatModel
+        ) or isinstance(self.client_preview, PreviewCodeChatModel)
+        if is_palm_chat_model:
+            result = self.client_preview.start_chat().count_tokens(text)
+        else:
+            result = self.client_preview.count_tokens([text])
+
+        return result.total_tokens
+
 
 class VertexAI(_VertexAICommon, BaseLLM):
     """Google Vertex AI large language models."""
@@ -300,20 +327,6 @@ class VertexAI(_VertexAICommon, BaseLLM):
             raise ValueError("Only one candidate can be generated with streaming!")
         return values
 
-    def get_num_tokens(self, text: str) -> int:
-        """Get the number of tokens present in the text.
-
-        Useful for checking if an input will fit in a model's context window.
-
-        Args:
-            text: The string input to tokenize.
-
-        Returns:
-            The integer number of tokens in the text.
-        """
-        result = self.client_preview.count_tokens([text])
-        return result.total_tokens
-
     def _response_to_generation(
         self, response: TextGenerationResponse, *, stream: bool = False
     ) -> GenerationChunk:
diff --git a/libs/partners/google-vertexai/tests/integration_tests/test_chat_models.py b/libs/partners/google-vertexai/tests/integration_tests/test_chat_models.py
index 94c7ea6a55b..338533103ea 100644
--- a/libs/partners/google-vertexai/tests/integration_tests/test_chat_models.py
+++ b/libs/partners/google-vertexai/tests/integration_tests/test_chat_models.py
@@ -225,6 +225,18 @@ def test_chat_vertexai_system_message(model_name: Optional[str]) -> None:
     assert isinstance(response.content, str)
 
 
+@pytest.mark.parametrize("model_name", model_names_to_test)
+def test_get_num_tokens_from_messages(model_name: str) -> None:
+    if model_name:
+        model = ChatVertexAI(model_name=model_name, temperature=0.0)
+    else:
+        model = ChatVertexAI(temperature=0.0)
+    message = HumanMessage(content="Hello")
+    token = model.get_num_tokens_from_messages(messages=[message])
+    assert isinstance(token, int)
+    assert token == 3
+
+
 def test_chat_vertexai_gemini_function_calling() -> None:
     class MyModel(BaseModel):
         name: str
diff --git a/libs/partners/google-vertexai/tests/integration_tests/test_tools.py b/libs/partners/google-vertexai/tests/integration_tests/test_tools.py
index f239727ecac..e33187e6a10 100644
--- a/libs/partners/google-vertexai/tests/integration_tests/test_tools.py
+++ b/libs/partners/google-vertexai/tests/integration_tests/test_tools.py
@@ -81,7 +81,6 @@ def test_tools() -> None:
     agent_executor = AgentExecutor(agent=agent, tools=tools, verbose=True)
 
     response = agent_executor.invoke({"input": "What is 6 raised to the 0.43 power?"})
-    print(response)
     assert isinstance(response, dict)
     assert response["input"] == "What is 6 raised to the 0.43 power?"
 
@@ -106,7 +105,6 @@ def test_stream() -> None:
     ]
     response = list(llm.stream("What is 6 raised to the 0.43 power?", functions=tools))
     assert len(response) == 1
-    # for chunk in response:
     assert isinstance(response[0], AIMessageChunk)
     assert "function_call" in response[0].additional_kwargs
 