"""Wrapper around Google's PaLM Chat API."""
 | 
						|
from __future__ import annotations
 | 
						|
 | 
						|
import logging
 | 
						|
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, cast
 | 
						|
 | 
						|
from langchain_core.callbacks import (
 | 
						|
    AsyncCallbackManagerForLLMRun,
 | 
						|
    CallbackManagerForLLMRun,
 | 
						|
)
 | 
						|
from langchain_core.language_models.chat_models import BaseChatModel
 | 
						|
from langchain_core.messages import (
 | 
						|
    AIMessage,
 | 
						|
    BaseMessage,
 | 
						|
    ChatMessage,
 | 
						|
    HumanMessage,
 | 
						|
    SystemMessage,
 | 
						|
)
 | 
						|
from langchain_core.outputs import (
 | 
						|
    ChatGeneration,
 | 
						|
    ChatResult,
 | 
						|
)
 | 
						|
from langchain_core.pydantic_v1 import BaseModel, SecretStr, root_validator
 | 
						|
from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env
 | 
						|
from tenacity import (
 | 
						|
    before_sleep_log,
 | 
						|
    retry,
 | 
						|
    retry_if_exception_type,
 | 
						|
    stop_after_attempt,
 | 
						|
    wait_exponential,
 | 
						|
)
 | 
						|
 | 
						|
if TYPE_CHECKING:
 | 
						|
    import google.generativeai as genai
 | 
						|
 | 
						|
logger = logging.getLogger(__name__)
 | 
						|
 | 
						|
 | 
						|
class ChatGooglePalmError(Exception):
 | 
						|
    """Error with the `Google PaLM` API."""
 | 
						|
 | 
						|
 | 
						|
def _truncate_at_stop_tokens(
 | 
						|
    text: str,
 | 
						|
    stop: Optional[List[str]],
 | 
						|
) -> str:
 | 
						|
    """Truncates text at the earliest stop token found."""
 | 
						|
    if stop is None:
 | 
						|
        return text
 | 
						|
 | 
						|
    for stop_token in stop:
 | 
						|
        stop_token_idx = text.find(stop_token)
 | 
						|
        if stop_token_idx != -1:
 | 
						|
            text = text[:stop_token_idx]
 | 
						|
    return text
 | 
						|
 | 
						|
 | 
						|
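

# Illustrative example (not from the original module): truncating at each stop
# token in turn nets out to cutting at the earliest match in the text:
#
#   _truncate_at_stop_tokens("Hello world. Goodbye.", ["Goodbye", "."])
#   # -> "Hello world"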


def _response_to_result(
    response: genai.types.ChatResponse,
    stop: Optional[List[str]],
) -> ChatResult:
    """Converts a PaLM API response into a LangChain ChatResult."""
    if not response.candidates:
        raise ChatGooglePalmError("ChatResponse must have at least one candidate.")

    generations: List[ChatGeneration] = []
    for candidate in response.candidates:
        author = candidate.get("author")
        if author is None:
            raise ChatGooglePalmError(f"ChatResponse must have an author: {candidate}")

        content = _truncate_at_stop_tokens(candidate.get("content", ""), stop)
        if content is None:
            raise ChatGooglePalmError(f"ChatResponse must have a content: {candidate}")

        if author == "ai":
            generations.append(
                ChatGeneration(text=content, message=AIMessage(content=content))
            )
        elif author == "human":
            generations.append(
                ChatGeneration(
                    text=content,
                    message=HumanMessage(content=content),
                )
            )
        else:
            generations.append(
                ChatGeneration(
                    text=content,
                    message=ChatMessage(role=author, content=content),
                )
            )

    return ChatResult(generations=generations)
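

# Illustrative mapping (hedged; assumes the dict-like candidate shape checked
# above): a candidate {"author": "ai", "content": "Hi!"} becomes a
# ChatGeneration wrapping AIMessage(content="Hi!"), "human" maps to
# HumanMessage, and any other author falls back to ChatMessage(role=author).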


def _messages_to_prompt_dict(
    input_messages: List[BaseMessage],
) -> genai.types.MessagePromptDict:
    """Converts a list of LangChain messages into a PaLM API MessagePrompt structure."""
    import google.generativeai as genai

    context: str = ""
    examples: List[genai.types.MessageDict] = []
    messages: List[genai.types.MessageDict] = []

    remaining = list(enumerate(input_messages))

    while remaining:
        index, input_message = remaining.pop(0)

        if isinstance(input_message, SystemMessage):
            if index != 0:
                raise ChatGooglePalmError("System message must be first input message.")
            context = cast(str, input_message.content)
        elif isinstance(input_message, HumanMessage) and input_message.example:
            if messages:
                raise ChatGooglePalmError(
                    "Message examples must come before other messages."
                )
            # Guard against a trailing example message with no paired response.
            if not remaining:
                raise ChatGooglePalmError(
                    "Human example message must be immediately followed by an "
                    "AI example response."
                )
            _, next_input_message = remaining.pop(0)
            if isinstance(next_input_message, AIMessage) and next_input_message.example:
                examples.extend(
                    [
                        genai.types.MessageDict(
                            author="human", content=input_message.content
                        ),
                        genai.types.MessageDict(
                            author="ai", content=next_input_message.content
                        ),
                    ]
                )
            else:
                raise ChatGooglePalmError(
                    "Human example message must be immediately followed by an "
                    "AI example response."
                )
        elif isinstance(input_message, AIMessage) and input_message.example:
            raise ChatGooglePalmError(
                "AI example message must be immediately preceded by a Human "
                "example message."
            )
        elif isinstance(input_message, AIMessage):
            messages.append(
                genai.types.MessageDict(author="ai", content=input_message.content)
            )
        elif isinstance(input_message, HumanMessage):
            messages.append(
                genai.types.MessageDict(author="human", content=input_message.content)
            )
        elif isinstance(input_message, ChatMessage):
            messages.append(
                genai.types.MessageDict(
                    author=input_message.role, content=input_message.content
                )
            )
        else:
            raise ChatGooglePalmError(
                "Messages without an explicit role not supported by PaLM API."
            )

    return genai.types.MessagePromptDict(
        context=context,
        examples=examples,
        messages=messages,
    )
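

# Illustrative conversion (hedged example, not from the original module):
#
#   _messages_to_prompt_dict([
#       SystemMessage(content="Be terse."),
#       HumanMessage(content="2+2?", example=True),
#       AIMessage(content="4", example=True),
#       HumanMessage(content="3+3?"),
#   ])
#   # -> {"context": "Be terse.",
#   #     "examples": [{"author": "human", "content": "2+2?"},
#   #                  {"author": "ai", "content": "4"}],
#   #     "messages": [{"author": "human", "content": "3+3?"}]}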


def _create_retry_decorator() -> Callable[[Any], Any]:
    """Returns a tenacity retry decorator, preconfigured to handle PaLM exceptions."""
    import google.api_core.exceptions

    multiplier = 2
    min_seconds = 1
    max_seconds = 60
    max_retries = 10

    return retry(
        reraise=True,
        stop=stop_after_attempt(max_retries),
        wait=wait_exponential(multiplier=multiplier, min=min_seconds, max=max_seconds),
        retry=(
            retry_if_exception_type(google.api_core.exceptions.ResourceExhausted)
            | retry_if_exception_type(google.api_core.exceptions.ServiceUnavailable)
            | retry_if_exception_type(google.api_core.exceptions.GoogleAPIError)
        ),
        before_sleep=before_sleep_log(logger, logging.WARNING),
    )
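

# Net effect (approximate): waits between retries grow exponentially, clamped
# to the [min_seconds, max_seconds] window; retrying stops after max_retries
# attempts; only the listed google.api_core exceptions trigger a retry; and
# each sleep is logged at WARNING level before it happens.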


def chat_with_retry(llm: ChatGooglePalm, **kwargs: Any) -> Any:
    """Use tenacity to retry the completion call."""
    retry_decorator = _create_retry_decorator()

    @retry_decorator
    def _chat_with_retry(**kwargs: Any) -> Any:
        return llm.client.chat(**kwargs)

    return _chat_with_retry(**kwargs)


async def achat_with_retry(llm: ChatGooglePalm, **kwargs: Any) -> Any:
    """Use tenacity to retry the async completion call."""
    retry_decorator = _create_retry_decorator()

    @retry_decorator
    async def _achat_with_retry(**kwargs: Any) -> Any:
        # Await the PaLM SDK's async chat entry point.
        return await llm.client.chat_async(**kwargs)

    return await _achat_with_retry(**kwargs)
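

# Hedged call sketch (mirrors how _generate/_agenerate below invoke these
# wrappers; kwargs are forwarded verbatim to genai.chat / genai.chat_async):
#
#   response = chat_with_retry(llm, model=llm.model_name, prompt=prompt)
#   response = await achat_with_retry(llm, model=llm.model_name, prompt=prompt)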


class ChatGooglePalm(BaseChatModel, BaseModel):
    """`Google PaLM` Chat models API.

    To use you must have the google.generativeai Python package installed and
    either:

        1. The ``GOOGLE_API_KEY`` environment variable set with your API key, or
        2. Pass your API key using the google_api_key kwarg to the
           ChatGooglePalm constructor.

    Example:
        .. code-block:: python

            from langchain_community.chat_models import ChatGooglePalm
            chat = ChatGooglePalm()

    """

    client: Any  #: :meta private:
    model_name: str = "models/chat-bison-001"
    """Model name to use."""
    google_api_key: Optional[SecretStr] = None
    temperature: Optional[float] = None
    """Run inference with this temperature. Must be in the closed
       interval [0.0, 1.0]."""
    top_p: Optional[float] = None
    """Decode using nucleus sampling: consider the smallest set of tokens whose
       probability sum is at least top_p. Must be in the closed interval [0.0, 1.0]."""
    top_k: Optional[int] = None
    """Decode using top-k sampling: consider the set of top_k most probable tokens.
       Must be positive."""
    n: int = 1
    """Number of chat completions to generate for each prompt. Note that the API may
       not return the full n completions if duplicates are generated."""

    @property
    def lc_secrets(self) -> Dict[str, str]:
        return {"google_api_key": "GOOGLE_API_KEY"}

    @classmethod
    def is_lc_serializable(cls) -> bool:
        return True

    @classmethod
    def get_lc_namespace(cls) -> List[str]:
        """Get the namespace of the langchain object."""
        return ["langchain", "chat_models", "google_palm"]

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate api key, python package exists, temperature, top_p, and top_k."""
        google_api_key = convert_to_secret_str(
            get_from_dict_or_env(values, "google_api_key", "GOOGLE_API_KEY")
        )
        try:
            import google.generativeai as genai

            genai.configure(api_key=google_api_key.get_secret_value())
        except ImportError:
            raise ChatGooglePalmError(
                "Could not import google.generativeai python package. "
                "Please install it with `pip install google-generativeai`"
            )

        values["client"] = genai

        if values["temperature"] is not None and not 0 <= values["temperature"] <= 1:
            raise ValueError("temperature must be in the range [0.0, 1.0]")

        if values["top_p"] is not None and not 0 <= values["top_p"] <= 1:
            raise ValueError("top_p must be in the range [0.0, 1.0]")

        if values["top_k"] is not None and values["top_k"] <= 0:
            raise ValueError("top_k must be positive")

        return values
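
    # Hedged usage note (not part of the original module): the validator above
    # resolves the key from either source, so these are equivalent:
    #
    #   chat = ChatGooglePalm(google_api_key="<your-key>")
    #   # or: set GOOGLE_API_KEY in the environment and call ChatGooglePalm()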

    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        prompt = _messages_to_prompt_dict(messages)

        response: genai.types.ChatResponse = chat_with_retry(
            self,
            model=self.model_name,
            prompt=prompt,
            temperature=self.temperature,
            top_p=self.top_p,
            top_k=self.top_k,
            candidate_count=self.n,
            **kwargs,
        )

        return _response_to_result(response, stop)

    async def _agenerate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        prompt = _messages_to_prompt_dict(messages)

        response: genai.types.ChatResponse = await achat_with_retry(
            self,
            model=self.model_name,
            prompt=prompt,
            temperature=self.temperature,
            top_p=self.top_p,
            top_k=self.top_k,
            candidate_count=self.n,
            # Forward kwargs for parity with the synchronous _generate above.
            **kwargs,
        )

        return _response_to_result(response, stop)

    @property
    def _identifying_params(self) -> Dict[str, Any]:
        """Get the identifying parameters."""
        return {
            "model_name": self.model_name,
            "temperature": self.temperature,
            "top_p": self.top_p,
            "top_k": self.top_k,
            "n": self.n,
        }

    @property
    def _llm_type(self) -> str:
        return "google-palm-chat"
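

# Hedged end-to-end sketch (assumes a valid GOOGLE_API_KEY and access to the
# PaLM chat API; the message and parameter values are illustrative):
#
#   from langchain_core.messages import HumanMessage
#
#   chat = ChatGooglePalm(temperature=0.2, n=1)
#   result = chat.invoke([HumanMessage(content="Say hello.")])
#   print(result.content)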