mirror of
https://github.com/hwchase17/langchain.git
synced 2025-08-31 18:38:48 +00:00
implement get_num_tokens to use google's count_tokens function (#10565)
can get the correct token count instead of using the gpt-2 model **Description:** Implement get_num_tokens within VertexLLM to use Google's count_tokens function (https://cloud.google.com/vertex-ai/docs/generative-ai/get-token-count). This way we don't need to download the gpt-2 model from Hugging Face, and when we run the map-reduce chain we get the correct token count. **Tag maintainer:** @lkuligin **Twitter handle:** My twitter: @abehsu1992626 --------- Co-authored-by: Bagatur <baskaryan@gmail.com>
This commit is contained in:
@@ -276,6 +276,27 @@ class VertexAI(_VertexAICommon, BaseLLM):
|
||||
raise ValueError("Only one candidate can be generated with streaming!")
|
||||
return values
|
||||
|
||||
def get_num_tokens(self, text: str) -> int:
    """Get the number of tokens present in the text.

    Useful for checking if an input will fit in a model's context window.
    The count is delegated to ``self.client.count_tokens``.

    Args:
        text: The string input to tokenize.

    Returns:
        The integer number of tokens in the text.

    Raises:
        NotImplementedError: If the client object has no ``count_tokens``
            attribute.
    """
    try:
        # Guard only the attribute lookup, so an AttributeError raised while
        # the count is being computed is not misreported as a missing feature.
        count_tokens = self.client.count_tokens
    except AttributeError as exc:
        # The requirement is quoted because an unquoted ">=" is parsed by
        # shells as an output redirection.
        raise NotImplementedError(
            "Your google-cloud-aiplatform version didn't implement count_tokens. "
            'Please, install it with pip install "google-cloud-aiplatform>=1.35.0"'
        ) from exc

    return count_tokens(text).total_tokens
|
||||
|
||||
def _generate(
|
||||
self,
|
||||
prompts: List[str],
|
||||
|
@@ -2,7 +2,7 @@
|
||||
In order to run this test, you need to install VertexAI SDK (that is is the private
|
||||
preview) and be whitelisted to list the models themselves:
|
||||
In order to run this test, you need to install VertexAI SDK
|
||||
pip install google-cloud-aiplatform>=1.25.0
|
||||
pip install google-cloud-aiplatform>=1.35.0
|
||||
|
||||
Your end-user credentials would be used to make the calls (make sure you've run
|
||||
`gcloud auth login` first).
|
||||
|
@@ -1,6 +1,6 @@
|
||||
"""Test Vertex AI API wrapper.
|
||||
In order to run this test, you need to install VertexAI SDK
|
||||
pip install google-cloud-aiplatform>=1.25.0
|
||||
pip install google-cloud-aiplatform>=1.35.0
|
||||
|
||||
Your end-user credentials would be used to make the calls (make sure you've run
|
||||
`gcloud auth login` first).
|
||||
|
@@ -2,7 +2,7 @@
|
||||
In order to run this test, you need to install VertexAI SDK (that is is the private
|
||||
preview) and be whitelisted to list the models themselves:
|
||||
In order to run this test, you need to install VertexAI SDK
|
||||
pip install google-cloud-aiplatform>=1.25.0
|
||||
pip install google-cloud-aiplatform>=1.35.0
|
||||
|
||||
Your end-user credentials would be used to make the calls (make sure you've run
|
||||
`gcloud auth login` first).
|
||||
@@ -10,7 +10,10 @@ Your end-user credentials would be used to make the calls (make sure you've run
|
||||
import os
|
||||
|
||||
import pytest
|
||||
from pytest_mock import MockerFixture
|
||||
|
||||
from langchain.chains.summarize import load_summarize_chain
|
||||
from langchain.docstore.document import Document
|
||||
from langchain.llms import VertexAI, VertexAIModelGarden
|
||||
from langchain.schema import LLMResult
|
||||
|
||||
@@ -108,3 +111,37 @@ async def test_model_garden_agenerate() -> None:
|
||||
output = await llm.agenerate(["What is the meaning of life?", "How much is 2+2"])
|
||||
assert isinstance(output, LLMResult)
|
||||
assert len(output.generations) == 2
|
||||
|
||||
|
||||
def test_vertex_call_trigger_count_tokens() -> None:
    """Check that a default ``VertexAI`` LLM reports two tokens for ``"Hi"``."""
    token_count = VertexAI().get_num_tokens("Hi")
    assert token_count == 2
|
||||
|
||||
|
||||
@pytest.mark.requires("google.cloud.aiplatform")
def test_get_num_tokens_be_called_when_using_mapreduce_chain(
    mocker: MockerFixture,
) -> None:
    """Check that a map-reduce summarize chain calls ``count_tokens`` once.

    ``count_tokens`` on the preview text generation model is patched to
    return a canned two-token response, then the chain is run over a single
    short document.
    """
    from vertexai.language_models._language_models import CountTokensResponse

    # Canned response handed back by the patched count_tokens.
    canned_response = CountTokensResponse(
        total_tokens=2,
        total_billable_characters=2,
        _count_tokens_response={"total_tokens": 2, "total_billable_characters": 2},
    )
    count_tokens_mock = mocker.patch(
        "vertexai.preview.language_models._PreviewTextGenerationModel.count_tokens",
        return_value=canned_response,
    )

    llm = VertexAI()
    summarize_chain = load_summarize_chain(
        llm,
        chain_type="map_reduce",
        return_intermediate_steps=False,
    )
    documents = [Document(page_content="Hi")]
    result = summarize_chain({"input_documents": documents})

    assert isinstance(result["output_text"], str)
    count_tokens_mock.assert_called_once()
    assert llm._llm_type == "vertexai"
    assert llm.model_name == llm.client._model_id
|
||||
|
Reference in New Issue
Block a user