mirror of
https://github.com/hwchase17/langchain.git
synced 2025-07-09 14:35:50 +00:00
Update google palm model signatures (#3920)
Signatures out of date after callback refactors
This commit is contained in:
parent
145ff23fb1
commit
900ad106d3
@ -5,6 +5,10 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional
|
|||||||
|
|
||||||
from pydantic import BaseModel, root_validator
|
from pydantic import BaseModel, root_validator
|
||||||
|
|
||||||
|
from langchain.callbacks.manager import (
|
||||||
|
AsyncCallbackManagerForLLMRun,
|
||||||
|
CallbackManagerForLLMRun,
|
||||||
|
)
|
||||||
from langchain.chat_models.base import BaseChatModel
|
from langchain.chat_models.base import BaseChatModel
|
||||||
from langchain.schema import (
|
from langchain.schema import (
|
||||||
AIMessage,
|
AIMessage,
|
||||||
@ -216,7 +220,10 @@ class ChatGooglePalm(BaseChatModel, BaseModel):
|
|||||||
return values
|
return values
|
||||||
|
|
||||||
def _generate(
|
def _generate(
|
||||||
self, messages: List[BaseMessage], stop: Optional[List[str]] = None
|
self,
|
||||||
|
messages: List[BaseMessage],
|
||||||
|
stop: Optional[List[str]] = None,
|
||||||
|
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||||
) -> ChatResult:
|
) -> ChatResult:
|
||||||
prompt = _messages_to_prompt_dict(messages)
|
prompt = _messages_to_prompt_dict(messages)
|
||||||
|
|
||||||
@ -232,7 +239,10 @@ class ChatGooglePalm(BaseChatModel, BaseModel):
|
|||||||
return _response_to_result(response, stop)
|
return _response_to_result(response, stop)
|
||||||
|
|
||||||
async def _agenerate(
|
async def _agenerate(
|
||||||
self, messages: List[BaseMessage], stop: Optional[List[str]] = None
|
self,
|
||||||
|
messages: List[BaseMessage],
|
||||||
|
stop: Optional[List[str]] = None,
|
||||||
|
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
|
||||||
) -> ChatResult:
|
) -> ChatResult:
|
||||||
prompt = _messages_to_prompt_dict(messages)
|
prompt = _messages_to_prompt_dict(messages)
|
||||||
|
|
||||||
|
@ -5,6 +5,10 @@ from typing import Any, Dict, List, Optional
|
|||||||
|
|
||||||
from pydantic import BaseModel, root_validator
|
from pydantic import BaseModel, root_validator
|
||||||
|
|
||||||
|
from langchain.callbacks.manager import (
|
||||||
|
AsyncCallbackManagerForLLMRun,
|
||||||
|
CallbackManagerForLLMRun,
|
||||||
|
)
|
||||||
from langchain.llms import BaseLLM
|
from langchain.llms import BaseLLM
|
||||||
from langchain.schema import Generation, LLMResult
|
from langchain.schema import Generation, LLMResult
|
||||||
from langchain.utils import get_from_dict_or_env
|
from langchain.utils import get_from_dict_or_env
|
||||||
@ -74,7 +78,10 @@ class GooglePalm(BaseLLM, BaseModel):
|
|||||||
return values
|
return values
|
||||||
|
|
||||||
def _generate(
|
def _generate(
|
||||||
self, prompts: List[str], stop: Optional[List[str]] = None
|
self,
|
||||||
|
prompts: List[str],
|
||||||
|
stop: Optional[List[str]] = None,
|
||||||
|
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||||
) -> LLMResult:
|
) -> LLMResult:
|
||||||
generations = []
|
generations = []
|
||||||
for prompt in prompts:
|
for prompt in prompts:
|
||||||
@ -99,7 +106,10 @@ class GooglePalm(BaseLLM, BaseModel):
|
|||||||
return LLMResult(generations=generations)
|
return LLMResult(generations=generations)
|
||||||
|
|
||||||
async def _agenerate(
|
async def _agenerate(
|
||||||
self, prompts: List[str], stop: Optional[List[str]] = None
|
self,
|
||||||
|
prompts: List[str],
|
||||||
|
stop: Optional[List[str]] = None,
|
||||||
|
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
|
||||||
) -> LLMResult:
|
) -> LLMResult:
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user