mirror of
https://github.com/hwchase17/langchain.git
synced 2025-09-16 06:53:16 +00:00
google-genai[patch], community[patch]: Added support for new Google GenerativeAI models (#14530)
Replace this entire comment with: - **Description:** added support for new Google GenerativeAI models - **Twitter handle:** lkuligin --------- Co-authored-by: Erick Friis <erick@langchain.dev>
This commit is contained in:
@@ -1,10 +1,12 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Dict, List, Optional
|
||||
from typing import Any, Dict, Iterator, List, Optional
|
||||
|
||||
from langchain_core._api.deprecation import deprecated
|
||||
from langchain_core.callbacks import CallbackManagerForLLMRun
|
||||
from langchain_core.outputs import Generation, LLMResult
|
||||
from langchain_core.pydantic_v1 import BaseModel, root_validator
|
||||
from langchain_core.language_models import LanguageModelInput
|
||||
from langchain_core.outputs import Generation, GenerationChunk, LLMResult
|
||||
from langchain_core.pydantic_v1 import BaseModel, SecretStr, root_validator
|
||||
from langchain_core.utils import get_from_dict_or_env
|
||||
|
||||
from langchain_community.llms import BaseLLM
|
||||
@@ -13,7 +15,9 @@ from langchain_community.utilities.vertexai import create_retry_decorator
|
||||
|
||||
def completion_with_retry(
|
||||
llm: GooglePalm,
|
||||
*args: Any,
|
||||
prompt: LanguageModelInput,
|
||||
is_gemini: bool = False,
|
||||
stream: bool = False,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
@@ -23,10 +27,23 @@ def completion_with_retry(
|
||||
)
|
||||
|
||||
@retry_decorator
|
||||
def _completion_with_retry(*args: Any, **kwargs: Any) -> Any:
|
||||
return llm.client.generate_text(*args, **kwargs)
|
||||
def _completion_with_retry(
|
||||
prompt: LanguageModelInput, is_gemini: bool, stream: bool, **kwargs: Any
|
||||
) -> Any:
|
||||
generation_config = kwargs.get("generation_config", {})
|
||||
if is_gemini:
|
||||
return llm.client.generate_content(
|
||||
contents=prompt, stream=stream, generation_config=generation_config
|
||||
)
|
||||
return llm.client.generate_text(prompt=prompt, **kwargs)
|
||||
|
||||
return _completion_with_retry(*args, **kwargs)
|
||||
return _completion_with_retry(
|
||||
prompt=prompt, is_gemini=is_gemini, stream=stream, **kwargs
|
||||
)
|
||||
|
||||
|
||||
def _is_gemini_model(model_name: str) -> bool:
|
||||
return "gemini" in model_name
|
||||
|
||||
|
||||
def _strip_erroneous_leading_spaces(text: str) -> str:
|
||||
@@ -42,11 +59,16 @@ def _strip_erroneous_leading_spaces(text: str) -> str:
|
||||
return text
|
||||
|
||||
|
||||
@deprecated("0.0.351", alternative="langchain_google_genai.GoogleGenerativeAI")
|
||||
class GooglePalm(BaseLLM, BaseModel):
|
||||
"""Google PaLM models."""
|
||||
"""
|
||||
DEPRECATED: Use `langchain_google_genai.GoogleGenerativeAI` instead.
|
||||
|
||||
Google PaLM models.
|
||||
"""
|
||||
|
||||
client: Any #: :meta private:
|
||||
google_api_key: Optional[str]
|
||||
google_api_key: Optional[SecretStr]
|
||||
model_name: str = "models/text-bison-001"
|
||||
"""Model name to use."""
|
||||
temperature: float = 0.7
|
||||
@@ -67,6 +89,11 @@ class GooglePalm(BaseLLM, BaseModel):
|
||||
max_retries: int = 6
|
||||
"""The maximum number of retries to make when generating."""
|
||||
|
||||
@property
|
||||
def is_gemini(self) -> bool:
|
||||
"""Returns whether a model is belongs to a Gemini family or not."""
|
||||
return _is_gemini_model(self.model_name)
|
||||
|
||||
@property
|
||||
def lc_secrets(self) -> Dict[str, str]:
|
||||
return {"google_api_key": "GOOGLE_API_KEY"}
|
||||
@@ -86,18 +113,25 @@ class GooglePalm(BaseLLM, BaseModel):
|
||||
google_api_key = get_from_dict_or_env(
|
||||
values, "google_api_key", "GOOGLE_API_KEY"
|
||||
)
|
||||
model_name = values["model_name"]
|
||||
try:
|
||||
import google.generativeai as genai
|
||||
|
||||
if isinstance(google_api_key, SecretStr):
|
||||
google_api_key = google_api_key.get_secret_value()
|
||||
|
||||
genai.configure(api_key=google_api_key)
|
||||
|
||||
if _is_gemini_model(model_name):
|
||||
values["client"] = genai.GenerativeModel(model_name=model_name)
|
||||
else:
|
||||
values["client"] = genai
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"Could not import google-generativeai python package. "
|
||||
"Please install it with `pip install google-generativeai`."
|
||||
)
|
||||
|
||||
values["client"] = genai
|
||||
|
||||
if values["temperature"] is not None and not 0 <= values["temperature"] <= 1:
|
||||
raise ValueError("temperature must be in the range [0.0, 1.0]")
|
||||
|
||||
@@ -119,30 +153,76 @@ class GooglePalm(BaseLLM, BaseModel):
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> LLMResult:
|
||||
generations = []
|
||||
generations: List[List[Generation]] = []
|
||||
generation_config = {
|
||||
"stop_sequences": stop,
|
||||
"temperature": self.temperature,
|
||||
"top_p": self.top_p,
|
||||
"top_k": self.top_k,
|
||||
"max_output_tokens": self.max_output_tokens,
|
||||
"candidate_count": self.n,
|
||||
}
|
||||
for prompt in prompts:
|
||||
completion = completion_with_retry(
|
||||
self,
|
||||
model=self.model_name,
|
||||
prompt=prompt,
|
||||
stop_sequences=stop,
|
||||
temperature=self.temperature,
|
||||
top_p=self.top_p,
|
||||
top_k=self.top_k,
|
||||
max_output_tokens=self.max_output_tokens,
|
||||
candidate_count=self.n,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
prompt_generations = []
|
||||
for candidate in completion.candidates:
|
||||
raw_text = candidate["output"]
|
||||
stripped_text = _strip_erroneous_leading_spaces(raw_text)
|
||||
prompt_generations.append(Generation(text=stripped_text))
|
||||
generations.append(prompt_generations)
|
||||
if self.is_gemini:
|
||||
res = completion_with_retry(
|
||||
self,
|
||||
prompt=prompt,
|
||||
stream=False,
|
||||
is_gemini=True,
|
||||
run_manager=run_manager,
|
||||
generation_config=generation_config,
|
||||
)
|
||||
candidates = [
|
||||
"".join([p.text for p in c.content.parts]) for c in res.candidates
|
||||
]
|
||||
generations.append([Generation(text=c) for c in candidates])
|
||||
else:
|
||||
res = completion_with_retry(
|
||||
self,
|
||||
model=self.model_name,
|
||||
prompt=prompt,
|
||||
stream=False,
|
||||
is_gemini=False,
|
||||
run_manager=run_manager,
|
||||
**generation_config,
|
||||
)
|
||||
prompt_generations = []
|
||||
for candidate in res.candidates:
|
||||
raw_text = candidate["output"]
|
||||
stripped_text = _strip_erroneous_leading_spaces(raw_text)
|
||||
prompt_generations.append(Generation(text=stripped_text))
|
||||
generations.append(prompt_generations)
|
||||
|
||||
return LLMResult(generations=generations)
|
||||
|
||||
def _stream(
|
||||
self,
|
||||
prompt: str,
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> Iterator[GenerationChunk]:
|
||||
generation_config = kwargs.get("generation_config", {})
|
||||
if stop:
|
||||
generation_config["stop_sequences"] = stop
|
||||
for stream_resp in completion_with_retry(
|
||||
self,
|
||||
prompt,
|
||||
stream=True,
|
||||
is_gemini=True,
|
||||
run_manager=run_manager,
|
||||
generation_config=generation_config,
|
||||
**kwargs,
|
||||
):
|
||||
chunk = GenerationChunk(text=stream_resp.text)
|
||||
yield chunk
|
||||
if run_manager:
|
||||
run_manager.on_llm_new_token(
|
||||
stream_resp.text,
|
||||
chunk=chunk,
|
||||
verbose=self.verbose,
|
||||
)
|
||||
|
||||
@property
|
||||
def _llm_type(self) -> str:
|
||||
"""Return type of llm."""
|
||||
@@ -159,5 +239,7 @@ class GooglePalm(BaseLLM, BaseModel):
|
||||
Returns:
|
||||
The integer number of tokens in the text.
|
||||
"""
|
||||
if self.is_gemini:
|
||||
raise ValueError("Counting tokens is not yet supported!")
|
||||
result = self.client.count_text_tokens(model=self.model_name, prompt=text)
|
||||
return result["token_count"]
|
||||
|
@@ -1,4 +1,4 @@
|
||||
"""Test Google PaLM Text API wrapper.
|
||||
"""Test Google GenerativeAI API wrapper.
|
||||
|
||||
Note: This test must be run with the GOOGLE_API_KEY environment variable set to a
|
||||
valid API key.
|
||||
@@ -6,35 +6,68 @@ Note: This test must be run with the GOOGLE_API_KEY environment variable set to
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
from langchain_core.outputs import LLMResult
|
||||
|
||||
from langchain_community.llms.google_palm import GooglePalm
|
||||
from langchain_community.llms.loading import load_llm
|
||||
|
||||
model_names = [None, "models/text-bison-001", "gemini-pro"]
|
||||
|
||||
def test_google_palm_call() -> None:
|
||||
"""Test valid call to Google PaLM text API."""
|
||||
llm = GooglePalm(max_output_tokens=10)
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"model_name",
|
||||
model_names,
|
||||
)
|
||||
def test_google_generativeai_call(model_name: str) -> None:
|
||||
"""Test valid call to Google GenerativeAI text API."""
|
||||
if model_name:
|
||||
llm = GooglePalm(max_output_tokens=10, model_name=model_name)
|
||||
else:
|
||||
llm = GooglePalm(max_output_tokens=10)
|
||||
output = llm("Say foo:")
|
||||
assert isinstance(output, str)
|
||||
assert llm._llm_type == "google_palm"
|
||||
assert llm.model_name == "models/text-bison-001"
|
||||
if model_name and "gemini" in model_name:
|
||||
assert llm.client.model_name == "models/gemini-pro"
|
||||
else:
|
||||
assert llm.model_name == "models/text-bison-001"
|
||||
|
||||
|
||||
def test_google_palm_generate() -> None:
|
||||
llm = GooglePalm(temperature=0.3, n=2)
|
||||
@pytest.mark.parametrize(
|
||||
"model_name",
|
||||
model_names,
|
||||
)
|
||||
def test_google_generativeai_generate(model_name: str) -> None:
|
||||
n = 1 if model_name == "gemini-pro" else 2
|
||||
if model_name:
|
||||
llm = GooglePalm(temperature=0.3, n=n, model_name=model_name)
|
||||
else:
|
||||
llm = GooglePalm(temperature=0.3, n=n)
|
||||
output = llm.generate(["Say foo:"])
|
||||
assert isinstance(output, LLMResult)
|
||||
assert len(output.generations) == 1
|
||||
assert len(output.generations[0]) == 2
|
||||
assert len(output.generations[0]) == n
|
||||
|
||||
|
||||
def test_google_palm_get_num_tokens() -> None:
|
||||
def test_google_generativeai_get_num_tokens() -> None:
|
||||
llm = GooglePalm()
|
||||
output = llm.get_num_tokens("How are you?")
|
||||
assert output == 4
|
||||
|
||||
|
||||
async def test_google_generativeai_agenerate() -> None:
|
||||
llm = GooglePalm(temperature=0, model_name="gemini-pro")
|
||||
output = await llm.agenerate(["Please say foo:"])
|
||||
assert isinstance(output, LLMResult)
|
||||
|
||||
|
||||
def test_generativeai_stream() -> None:
|
||||
llm = GooglePalm(temperature=0, model_name="gemini-pro")
|
||||
outputs = list(llm.stream("Please say foo:"))
|
||||
assert isinstance(outputs[0], str)
|
||||
|
||||
|
||||
def test_saving_loading_llm(tmp_path: Path) -> None:
|
||||
"""Test saving/loading a Google PaLM LLM."""
|
||||
llm = GooglePalm(max_output_tokens=10)
|
||||
|
@@ -6,11 +6,14 @@ This module integrates Google's Generative AI models, specifically the Gemini se
|
||||
|
||||
The `ChatGoogleGenerativeAI` class is the primary interface for interacting with Google's Gemini chat models. It allows users to send and receive messages using a specified Gemini model, suitable for various conversational AI applications.
|
||||
|
||||
**LLMs**
|
||||
|
||||
The `GoogleGenerativeAI` class is the primary interface for interacting with Google's Gemini LLMs. It allows users to generate text using a specified Gemini model.
|
||||
|
||||
**Embeddings**
|
||||
|
||||
The `GoogleGenerativeAIEmbeddings` class provides functionalities to generate embeddings using Google's models.
|
||||
These embeddings can be used for a range of NLP tasks, including semantic analysis, similarity comparisons, and more.
|
||||
|
||||
**Installation**
|
||||
|
||||
To install the package, use pip:
|
||||
@@ -29,6 +32,17 @@ llm = ChatGoogleGenerativeAI(model="gemini-pro")
|
||||
llm.invoke("Sing a ballad of LangChain.")
|
||||
```
|
||||
|
||||
## Using LLMs
|
||||
|
||||
The package also supports generating text with Google's models.
|
||||
|
||||
```python
|
||||
from langchain_google_genai import GoogleGenerativeAI
|
||||
|
||||
llm = GoogleGenerativeAI(model="gemini-pro")
|
||||
llm.invoke("Once upon a time, a library called LangChain")
|
||||
```
|
||||
|
||||
## Embedding Generation
|
||||
|
||||
The package also supports creating embeddings with Google's models, useful for textual similarity and other NLP applications.
|
||||
@@ -42,5 +56,10 @@ embeddings.embed_query("hello, world!")
|
||||
""" # noqa: E501
|
||||
from langchain_google_genai.chat_models import ChatGoogleGenerativeAI
|
||||
from langchain_google_genai.embeddings import GoogleGenerativeAIEmbeddings
|
||||
from langchain_google_genai.llms import GoogleGenerativeAI
|
||||
|
||||
__all__ = ["ChatGoogleGenerativeAI", "GoogleGenerativeAIEmbeddings"]
|
||||
__all__ = [
|
||||
"ChatGoogleGenerativeAI",
|
||||
"GoogleGenerativeAIEmbeddings",
|
||||
"GoogleGenerativeAI",
|
||||
]
|
||||
|
262
libs/partners/google-genai/langchain_google_genai/llms.py
Normal file
262
libs/partners/google-genai/langchain_google_genai/llms.py
Normal file
@@ -0,0 +1,262 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Callable, Dict, Iterator, List, Optional, Union
|
||||
|
||||
import google.api_core
|
||||
import google.generativeai as genai # type: ignore[import]
|
||||
from langchain_core.callbacks import (
|
||||
AsyncCallbackManagerForLLMRun,
|
||||
CallbackManagerForLLMRun,
|
||||
)
|
||||
from langchain_core.language_models import LanguageModelInput
|
||||
from langchain_core.language_models.llms import BaseLLM, create_base_retry_decorator
|
||||
from langchain_core.outputs import Generation, GenerationChunk, LLMResult
|
||||
from langchain_core.pydantic_v1 import BaseModel, Field, SecretStr, root_validator
|
||||
from langchain_core.utils import get_from_dict_or_env
|
||||
|
||||
|
||||
def _create_retry_decorator(
|
||||
llm: BaseLLM,
|
||||
*,
|
||||
max_retries: int = 1,
|
||||
run_manager: Optional[
|
||||
Union[AsyncCallbackManagerForLLMRun, CallbackManagerForLLMRun]
|
||||
] = None,
|
||||
) -> Callable[[Any], Any]:
|
||||
"""Creates a retry decorator for Vertex / Palm LLMs."""
|
||||
|
||||
errors = [
|
||||
google.api_core.exceptions.ResourceExhausted,
|
||||
google.api_core.exceptions.ServiceUnavailable,
|
||||
google.api_core.exceptions.Aborted,
|
||||
google.api_core.exceptions.DeadlineExceeded,
|
||||
google.api_core.exceptions.GoogleAPIError,
|
||||
]
|
||||
decorator = create_base_retry_decorator(
|
||||
error_types=errors, max_retries=max_retries, run_manager=run_manager
|
||||
)
|
||||
return decorator
|
||||
|
||||
|
||||
def _completion_with_retry(
|
||||
llm: GoogleGenerativeAI,
|
||||
prompt: LanguageModelInput,
|
||||
is_gemini: bool = False,
|
||||
stream: bool = False,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
"""Use tenacity to retry the completion call."""
|
||||
retry_decorator = _create_retry_decorator(
|
||||
llm, max_retries=llm.max_retries, run_manager=run_manager
|
||||
)
|
||||
|
||||
@retry_decorator
|
||||
def _completion_with_retry(
|
||||
prompt: LanguageModelInput, is_gemini: bool, stream: bool, **kwargs: Any
|
||||
) -> Any:
|
||||
generation_config = kwargs.get("generation_config", {})
|
||||
if is_gemini:
|
||||
return llm.client.generate_content(
|
||||
contents=prompt, stream=stream, generation_config=generation_config
|
||||
)
|
||||
return llm.client.generate_text(prompt=prompt, **kwargs)
|
||||
|
||||
return _completion_with_retry(
|
||||
prompt=prompt, is_gemini=is_gemini, stream=stream, **kwargs
|
||||
)
|
||||
|
||||
|
||||
def _is_gemini_model(model_name: str) -> bool:
|
||||
return "gemini" in model_name
|
||||
|
||||
|
||||
def _strip_erroneous_leading_spaces(text: str) -> str:
|
||||
"""Strip erroneous leading spaces from text.
|
||||
|
||||
The PaLM API will sometimes erroneously return a single leading space in all
|
||||
lines > 1. This function strips that space.
|
||||
"""
|
||||
has_leading_space = all(not line or line[0] == " " for line in text.split("\n")[1:])
|
||||
if has_leading_space:
|
||||
return text.replace("\n ", "\n")
|
||||
else:
|
||||
return text
|
||||
|
||||
|
||||
class GoogleGenerativeAI(BaseLLM, BaseModel):
|
||||
"""Google GenerativeAI models.
|
||||
|
||||
Example:
|
||||
.. code-block:: python
|
||||
|
||||
from langchain_google_genai import GoogleGenerativeAI
|
||||
llm = GoogleGenerativeAI(model="gemini-pro")
|
||||
"""
|
||||
|
||||
client: Any #: :meta private:
|
||||
model: str = Field(
|
||||
...,
|
||||
description="""The name of the model to use.
|
||||
Supported examples:
|
||||
- gemini-pro
|
||||
- models/text-bison-001""",
|
||||
)
|
||||
"""Model name to use."""
|
||||
google_api_key: Optional[SecretStr] = None
|
||||
temperature: float = 0.7
|
||||
"""Run inference with this temperature. Must by in the closed interval
|
||||
[0.0, 1.0]."""
|
||||
top_p: Optional[float] = None
|
||||
"""Decode using nucleus sampling: consider the smallest set of tokens whose
|
||||
probability sum is at least top_p. Must be in the closed interval [0.0, 1.0]."""
|
||||
top_k: Optional[int] = None
|
||||
"""Decode using top-k sampling: consider the set of top_k most probable tokens.
|
||||
Must be positive."""
|
||||
max_output_tokens: Optional[int] = None
|
||||
"""Maximum number of tokens to include in a candidate. Must be greater than zero.
|
||||
If unset, will default to 64."""
|
||||
n: int = 1
|
||||
"""Number of chat completions to generate for each prompt. Note that the API may
|
||||
not return the full n completions if duplicates are generated."""
|
||||
max_retries: int = 6
|
||||
"""The maximum number of retries to make when generating."""
|
||||
|
||||
@property
|
||||
def is_gemini(self) -> bool:
|
||||
"""Returns whether a model is belongs to a Gemini family or not."""
|
||||
return _is_gemini_model(self.model)
|
||||
|
||||
@property
|
||||
def lc_secrets(self) -> Dict[str, str]:
|
||||
return {"google_api_key": "GOOGLE_API_KEY"}
|
||||
|
||||
@root_validator()
|
||||
def validate_environment(cls, values: Dict) -> Dict:
|
||||
"""Validate api key, python package exists."""
|
||||
google_api_key = get_from_dict_or_env(
|
||||
values, "google_api_key", "GOOGLE_API_KEY"
|
||||
)
|
||||
model_name = values["model"]
|
||||
|
||||
if isinstance(google_api_key, SecretStr):
|
||||
google_api_key = google_api_key.get_secret_value()
|
||||
|
||||
genai.configure(api_key=google_api_key)
|
||||
|
||||
if _is_gemini_model(model_name):
|
||||
values["client"] = genai.GenerativeModel(model_name=model_name)
|
||||
else:
|
||||
values["client"] = genai
|
||||
|
||||
if values["temperature"] is not None and not 0 <= values["temperature"] <= 1:
|
||||
raise ValueError("temperature must be in the range [0.0, 1.0]")
|
||||
|
||||
if values["top_p"] is not None and not 0 <= values["top_p"] <= 1:
|
||||
raise ValueError("top_p must be in the range [0.0, 1.0]")
|
||||
|
||||
if values["top_k"] is not None and values["top_k"] <= 0:
|
||||
raise ValueError("top_k must be positive")
|
||||
|
||||
if values["max_output_tokens"] is not None and values["max_output_tokens"] <= 0:
|
||||
raise ValueError("max_output_tokens must be greater than zero")
|
||||
|
||||
return values
|
||||
|
||||
def _generate(
|
||||
self,
|
||||
prompts: List[str],
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> LLMResult:
|
||||
generations: List[List[Generation]] = []
|
||||
generation_config = {
|
||||
"stop_sequences": stop,
|
||||
"temperature": self.temperature,
|
||||
"top_p": self.top_p,
|
||||
"top_k": self.top_k,
|
||||
"max_output_tokens": self.max_output_tokens,
|
||||
"candidate_count": self.n,
|
||||
}
|
||||
for prompt in prompts:
|
||||
if self.is_gemini:
|
||||
res = _completion_with_retry(
|
||||
self,
|
||||
prompt=prompt,
|
||||
stream=False,
|
||||
is_gemini=True,
|
||||
run_manager=run_manager,
|
||||
generation_config=generation_config,
|
||||
)
|
||||
candidates = [
|
||||
"".join([p.text for p in c.content.parts]) for c in res.candidates
|
||||
]
|
||||
generations.append([Generation(text=c) for c in candidates])
|
||||
else:
|
||||
res = _completion_with_retry(
|
||||
self,
|
||||
model=self.model,
|
||||
prompt=prompt,
|
||||
stream=False,
|
||||
is_gemini=False,
|
||||
run_manager=run_manager,
|
||||
**generation_config,
|
||||
)
|
||||
prompt_generations = []
|
||||
for candidate in res.candidates:
|
||||
raw_text = candidate["output"]
|
||||
stripped_text = _strip_erroneous_leading_spaces(raw_text)
|
||||
prompt_generations.append(Generation(text=stripped_text))
|
||||
generations.append(prompt_generations)
|
||||
|
||||
return LLMResult(generations=generations)
|
||||
|
||||
def _stream(
|
||||
self,
|
||||
prompt: str,
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> Iterator[GenerationChunk]:
|
||||
generation_config = kwargs.get("generation_config", {})
|
||||
if stop:
|
||||
generation_config["stop_sequences"] = stop
|
||||
for stream_resp in _completion_with_retry(
|
||||
self,
|
||||
prompt,
|
||||
stream=True,
|
||||
is_gemini=True,
|
||||
run_manager=run_manager,
|
||||
generation_config=generation_config,
|
||||
**kwargs,
|
||||
):
|
||||
chunk = GenerationChunk(text=stream_resp.text)
|
||||
yield chunk
|
||||
if run_manager:
|
||||
run_manager.on_llm_new_token(
|
||||
stream_resp.text,
|
||||
chunk=chunk,
|
||||
verbose=self.verbose,
|
||||
)
|
||||
|
||||
@property
|
||||
def _llm_type(self) -> str:
|
||||
"""Return type of llm."""
|
||||
return "google_palm"
|
||||
|
||||
def get_num_tokens(self, text: str) -> int:
|
||||
"""Get the number of tokens present in the text.
|
||||
|
||||
Useful for checking if an input will fit in a model's context window.
|
||||
|
||||
Args:
|
||||
text: The string input to tokenize.
|
||||
|
||||
Returns:
|
||||
The integer number of tokens in the text.
|
||||
"""
|
||||
if self.is_gemini:
|
||||
raise ValueError("Counting tokens is not yet supported!")
|
||||
result = self.client.count_text_tokens(model=self.model, prompt=text)
|
||||
return result["token_count"]
|
@@ -0,0 +1,65 @@
|
||||
"""Test Google GenerativeAI API wrapper.
|
||||
|
||||
Note: This test must be run with the GOOGLE_API_KEY environment variable set to a
|
||||
valid API key.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from langchain_core.outputs import LLMResult
|
||||
|
||||
from langchain_google_genai.llms import GoogleGenerativeAI
|
||||
|
||||
model_names = [None, "models/text-bison-001", "gemini-pro"]
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"model_name",
|
||||
model_names,
|
||||
)
|
||||
def test_google_generativeai_call(model_name: str) -> None:
|
||||
"""Test valid call to Google GenerativeAI text API."""
|
||||
if model_name:
|
||||
llm = GoogleGenerativeAI(max_output_tokens=10, model=model_name)
|
||||
else:
|
||||
llm = GoogleGenerativeAI(max_output_tokens=10)
|
||||
output = llm("Say foo:")
|
||||
assert isinstance(output, str)
|
||||
assert llm._llm_type == "google_palm"
|
||||
if model_name and "gemini" in model_name:
|
||||
assert llm.client.model_name == "models/gemini-pro"
|
||||
else:
|
||||
assert llm.model == "models/text-bison-001"
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"model_name",
|
||||
model_names,
|
||||
)
|
||||
def test_google_generativeai_generate(model_name: str) -> None:
|
||||
n = 1 if model_name == "gemini-pro" else 2
|
||||
if model_name:
|
||||
llm = GoogleGenerativeAI(temperature=0.3, n=n, model=model_name)
|
||||
else:
|
||||
llm = GoogleGenerativeAI(temperature=0.3, n=n)
|
||||
output = llm.generate(["Say foo:"])
|
||||
assert isinstance(output, LLMResult)
|
||||
assert len(output.generations) == 1
|
||||
assert len(output.generations[0]) == n
|
||||
|
||||
|
||||
def test_google_generativeai_get_num_tokens() -> None:
|
||||
llm = GoogleGenerativeAI()
|
||||
output = llm.get_num_tokens("How are you?")
|
||||
assert output == 4
|
||||
|
||||
|
||||
async def test_google_generativeai_agenerate() -> None:
|
||||
llm = GoogleGenerativeAI(temperature=0, model="gemini-pro")
|
||||
output = await llm.agenerate(["Please say foo:"])
|
||||
assert isinstance(output, LLMResult)
|
||||
|
||||
|
||||
def test_generativeai_stream() -> None:
|
||||
llm = GoogleGenerativeAI(temperature=0, model="gemini-pro")
|
||||
outputs = list(llm.stream("Please say foo:"))
|
||||
assert isinstance(outputs[0], str)
|
@@ -3,6 +3,7 @@ from langchain_google_genai import __all__
|
||||
EXPECTED_ALL = [
|
||||
"ChatGoogleGenerativeAI",
|
||||
"GoogleGenerativeAIEmbeddings",
|
||||
"GoogleGenerativeAI",
|
||||
]
|
||||
|
||||
|
||||
|
Reference in New Issue
Block a user