Mirror of https://github.com/hwchase17/langchain.git (synced 2025-07-31 16:39:20 +00:00)
Parent: d93de71d08
Commit: 66e45e8ab7
@@ -20,7 +20,7 @@ from langchain_community.llms.azureml_endpoint import (
 class LlamaContentFormatter(ContentFormatterBase):
-    def __init__(self):  # type: ignore[no-untyped-def]
+    def __init__(self) -> None:
         raise TypeError(
             "`LlamaContentFormatter` is deprecated for chat models. Use "
             "`LlamaChatContentFormatter` instead."
@@ -72,12 +72,12 @@ class LlamaChatContentFormatter(ContentFormatterBase):
     def supported_api_types(self) -> List[AzureMLEndpointApiType]:
         return [AzureMLEndpointApiType.realtime, AzureMLEndpointApiType.serverless]

-    def format_request_payload(  # type: ignore[override]
+    def format_messages_request_payload(
         self,
         messages: List[BaseMessage],
         model_kwargs: Dict,
         api_type: AzureMLEndpointApiType,
-    ) -> str:
+    ) -> bytes:
         """Formats the request according to the chosen api"""
         chat_messages = [
             LlamaChatContentFormatter._convert_message_to_dict(message)
@@ -98,17 +98,19 @@ class LlamaChatContentFormatter(ContentFormatterBase):
             raise ValueError(
                 f"`api_type` {api_type} is not supported by this formatter"
             )
-        return str.encode(request_payload)  # type: ignore[return-value]
+        return str.encode(request_payload)

-    def format_response_payload(  # type: ignore[override]
-        self, output: bytes, api_type: AzureMLEndpointApiType
+    def format_response_payload(
+        self,
+        output: bytes,
+        api_type: AzureMLEndpointApiType = AzureMLEndpointApiType.realtime,
     ) -> ChatGeneration:
         """Formats response"""
         if api_type == AzureMLEndpointApiType.realtime:
             try:
                 choice = json.loads(output)["output"]
             except (KeyError, IndexError, TypeError) as e:
-                raise ValueError(self.format_error_msg.format(api_type=api_type)) from e  # type: ignore[union-attr]
+                raise ValueError(self.format_error_msg.format(api_type=api_type)) from e
             return ChatGeneration(
                 message=BaseMessage(
                     content=choice.strip(),
@@ -125,7 +127,7 @@ class LlamaChatContentFormatter(ContentFormatterBase):
                     "model. Expected `dict` but `{type(choice)}` was received."
                 )
             except (KeyError, IndexError, TypeError) as e:
-                raise ValueError(self.format_error_msg.format(api_type=api_type)) from e  # type: ignore[union-attr]
+                raise ValueError(self.format_error_msg.format(api_type=api_type)) from e
             return ChatGeneration(
                 message=BaseMessage(
                     content=choice["message"]["content"].strip(),
@@ -187,7 +189,7 @@ class AzureMLChatOnlineEndpoint(BaseChatModel, AzureMLBaseEndpoint):
         if stop:
             _model_kwargs["stop"] = stop

-        request_payload = self.content_formatter.format_request_payload(
+        request_payload = self.content_formatter.format_messages_request_payload(
             messages, _model_kwargs, self.endpoint_api_type
         )
         response_payload = self.http_client.call(
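Aside: the hunks above rename the chat-specific request hook to `format_messages_request_payload` and make it return `bytes`. A minimal sketch of a custom formatter written against that new signature might look like the following; the class name and JSON body shape are illustrative assumptions, and it assumes `ContentFormatterBase` and `AzureMLEndpointApiType` import from `langchain_community.llms.azureml_endpoint` as shown in the hunk header. Only the method name, parameters, and `bytes` return type come from the diff itself.

```python
import json
from typing import Dict, List

from langchain_community.llms.azureml_endpoint import (
    AzureMLEndpointApiType,
    ContentFormatterBase,
)
from langchain_core.messages import BaseMessage


class MyChatFormatter(ContentFormatterBase):
    """Hypothetical formatter: packs chat messages into a JSON request body."""

    def format_messages_request_payload(
        self,
        messages: List[BaseMessage],
        model_kwargs: Dict,
        api_type: AzureMLEndpointApiType,
    ) -> bytes:
        body = {
            "messages": [{"role": m.type, "content": m.content} for m in messages],
            **model_kwargs,
        }
        # The renamed hook returns bytes, mirroring str.encode(...) in the hunk above.
        return str.encode(json.dumps(body))
```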
@@ -327,7 +327,7 @@ class ChatDeepInfra(BaseChatModel):
             if chunk:
                 yield ChatGenerationChunk(message=chunk, generation_info=None)
                 if run_manager:
-                    run_manager.on_llm_new_token(chunk.content)  # type: ignore[arg-type]
+                    run_manager.on_llm_new_token(str(chunk.content))

    async def _astream(
        self,
@@ -349,7 +349,7 @@ class ChatDeepInfra(BaseChatModel):
             if chunk:
                 yield ChatGenerationChunk(message=chunk, generation_info=None)
                 if run_manager:
-                    await run_manager.on_llm_new_token(chunk.content)  # type: ignore[arg-type]
+                    await run_manager.on_llm_new_token(str(chunk.content))

    async def _agenerate(
        self,
@@ -165,6 +165,12 @@ class ChatEdenAI(BaseChatModel):
         """Return type of chat model."""
         return "edenai-chat"

+    @property
+    def _api_key(self) -> str:
+        if self.edenai_api_key:
+            return self.edenai_api_key.get_secret_value()
+        return ""
+
     def _stream(
         self,
         messages: List[BaseMessage],
@@ -175,7 +181,7 @@ class ChatEdenAI(BaseChatModel):
         """Call out to EdenAI's chat endpoint."""
         url = f"{self.edenai_api_url}/text/chat/stream"
         headers = {
-            "Authorization": f"Bearer {self.edenai_api_key.get_secret_value()}",  # type: ignore[union-attr]
+            "Authorization": f"Bearer {self._api_key}",
             "User-Agent": self.get_user_agent(),
         }
         formatted_data = _format_edenai_messages(messages=messages)
@@ -216,7 +222,7 @@ class ChatEdenAI(BaseChatModel):
     ) -> AsyncIterator[ChatGenerationChunk]:
         url = f"{self.edenai_api_url}/text/chat/stream"
         headers = {
-            "Authorization": f"Bearer {self.edenai_api_key.get_secret_value()}",  # type: ignore[union-attr]
+            "Authorization": f"Bearer {self._api_key}",
             "User-Agent": self.get_user_agent(),
         }
         formatted_data = _format_edenai_messages(messages=messages)
@@ -265,7 +271,7 @@ class ChatEdenAI(BaseChatModel):

         url = f"{self.edenai_api_url}/text/chat"
         headers = {
-            "Authorization": f"Bearer {self.edenai_api_key.get_secret_value()}",  # type: ignore[union-attr]
+            "Authorization": f"Bearer {self._api_key}",
             "User-Agent": self.get_user_agent(),
         }
         formatted_data = _format_edenai_messages(messages=messages)
@@ -323,7 +329,7 @@ class ChatEdenAI(BaseChatModel):

         url = f"{self.edenai_api_url}/text/chat"
         headers = {
-            "Authorization": f"Bearer {self.edenai_api_key.get_secret_value()}",  # type: ignore[union-attr]
+            "Authorization": f"Bearer {self._api_key}",
             "User-Agent": self.get_user_agent(),
         }
         formatted_data = _format_edenai_messages(messages=messages)
@@ -214,7 +214,7 @@ class ErnieBotChat(BaseChatModel):
         generations = [
             ChatGeneration(
                 message=AIMessage(
-                    content=response.get("result"),  # type: ignore[arg-type]
+                    content=response.get("result", ""),
                     additional_kwargs={**additional_kwargs},
                 )
             )
@@ -14,6 +14,7 @@ from typing import (
     Mapping,
     Optional,
+    Tuple,
     Type,
     Union,
 )

@@ -27,7 +28,7 @@ from langchain_core.language_models.chat_models import (
     generate_from_stream,
 )
 from langchain_core.language_models.llms import create_base_retry_decorator
-from langchain_core.messages import AIMessageChunk, BaseMessage
+from langchain_core.messages import AIMessageChunk, BaseMessage, BaseMessageChunk
 from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
 from langchain_core.pydantic_v1 import BaseModel, Field, SecretStr, root_validator
 from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env
@@ -56,9 +57,9 @@ class GPTRouterModel(BaseModel):
     provider_name: str


-def get_ordered_generation_requests(  # type: ignore[no-untyped-def, no-untyped-def]
-    models_priority_list: List[GPTRouterModel], **kwargs
-):
+def get_ordered_generation_requests(
+    models_priority_list: List[GPTRouterModel], **kwargs: Any
+) -> List:
     """
     Return the body for the model router input.
     """
@@ -100,7 +101,7 @@ def completion_with_retry(
     models_priority_list: List[GPTRouterModel],
     run_manager: Optional[CallbackManagerForLLMRun] = None,
     **kwargs: Any,
-) -> Union[GenerationResponse, Generator[ChunkedGenerationResponse]]:  # type: ignore[type-arg]
+) -> Union[GenerationResponse, Generator[ChunkedGenerationResponse, None, None]]:
     """Use tenacity to retry the completion call."""
     retry_decorator = _create_retry_decorator(llm, run_manager=run_manager)

@@ -122,7 +123,7 @@ async def acompletion_with_retry(
     models_priority_list: List[GPTRouterModel],
     run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
     **kwargs: Any,
-) -> Union[GenerationResponse, AsyncGenerator[ChunkedGenerationResponse]]:  # type: ignore[type-arg]
+) -> Union[GenerationResponse, AsyncGenerator[ChunkedGenerationResponse, None]]:
     """Use tenacity to retry the async completion call."""

     retry_decorator = _create_retry_decorator(llm, run_manager=run_manager)
@@ -282,9 +283,9 @@ class GPTRouter(BaseChatModel):
         )
         return self._create_chat_result(response)

-    def _create_chat_generation_chunk(  # type: ignore[no-untyped-def, no-untyped-def]
-        self, data: Mapping[str, Any], default_chunk_class
-    ):
+    def _create_chat_generation_chunk(
+        self, data: Mapping[str, Any], default_chunk_class: Type[BaseMessageChunk]
+    ) -> Tuple[ChatGenerationChunk, Type[BaseMessageChunk]]:
         chunk = _convert_delta_to_message_chunk(
             {"content": data.get("text", "")}, default_chunk_class
         )
@@ -293,8 +294,8 @@ class GPTRouter(BaseChatModel):
             dict(finish_reason=finish_reason) if finish_reason is not None else None
         )
         default_chunk_class = chunk.__class__
-        chunk = ChatGenerationChunk(message=chunk, generation_info=generation_info)  # type: ignore[assignment]
-        return chunk, default_chunk_class
+        gen_chunk = ChatGenerationChunk(message=chunk, generation_info=generation_info)
+        return gen_chunk, default_chunk_class

     def _stream(
         self,
@@ -306,7 +307,7 @@ class GPTRouter(BaseChatModel):
         message_dicts, params = self._create_message_dicts(messages, stop)
         params = {**params, **kwargs, "stream": True}

-        default_chunk_class = AIMessageChunk
+        default_chunk_class: Type[BaseMessageChunk] = AIMessageChunk
         generator_response = completion_with_retry(
             self,
             messages=message_dicts,
@@ -339,7 +340,7 @@ class GPTRouter(BaseChatModel):
         message_dicts, params = self._create_message_dicts(messages, stop)
         params = {**params, **kwargs, "stream": True}

-        default_chunk_class = AIMessageChunk
+        default_chunk_class: Type[BaseMessageChunk] = AIMessageChunk
         generator_response = acompletion_with_retry(
             self,
             messages=message_dicts,
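Aside: the GPTRouter return-type fixes above spell out every type parameter of `Generator` and `AsyncGenerator` instead of silencing mypy's `type-arg` error. A standalone toy illustration of the same pattern (the functions here are made up, not from the library):

```python
from typing import AsyncGenerator, Generator


def count_up(limit: int) -> Generator[int, None, None]:
    # Generator[YieldType, SendType, ReturnType]: with all three parameters
    # given, no "# type: ignore[type-arg]" is needed on the annotation.
    for i in range(limit):
        yield i


async def count_up_async(limit: int) -> AsyncGenerator[int, None]:
    # AsyncGenerator takes only two parameters: [YieldType, SendType].
    for i in range(limit):
        yield i
```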
@@ -44,7 +44,7 @@ class ChatHuggingFace(BaseChatModel):
     llm: Union[HuggingFaceTextGenInference, HuggingFaceEndpoint, HuggingFaceHub]
     system_message: SystemMessage = SystemMessage(content=DEFAULT_SYSTEM_PROMPT)
     tokenizer: Any = None
-    model_id: str = None  # type: ignore
+    model_id: Optional[str] = None

     def __init__(self, **kwargs: Any):
         super().__init__(**kwargs)
@@ -144,7 +144,7 @@ class ChatHuggingFace(BaseChatModel):

         elif isinstance(self.llm, HuggingFaceHub):
             # no need to look up model_id for HuggingFaceHub LLM
-            self.model_id = self.llm.repo_id  # type: ignore[assignment]
+            self.model_id = self.llm.repo_id
             return

         else:
@@ -25,7 +25,7 @@ logger = logging.getLogger(__name__)

 # Ignoring type because below is valid pydantic code
 # Unexpected keyword argument "extra" for "__init_subclass__" of "object" [call-arg]
-class ChatParams(BaseModel, extra=Extra.allow):  # type: ignore[call-arg]
+class ChatParams(BaseModel, extra=Extra.allow):
     """Parameters for the `Javelin AI Gateway` LLM."""

     temperature: float = 0.0
@@ -13,6 +13,7 @@ from typing import (
     Set,
     Tuple,
     Union,
+    cast,
 )

 import requests
@@ -169,7 +170,9 @@ class ChatKonko(ChatOpenAI):
         }

         if openai_api_key:
-            headers["X-OpenAI-Api-Key"] = openai_api_key.get_secret_value()  # type: ignore[union-attr]
+            headers["X-OpenAI-Api-Key"] = cast(
+                SecretStr, openai_api_key
+            ).get_secret_value()

         models_response = requests.get(models_url, headers=headers)

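Aside: the Konko hunk replaces a `# type: ignore[union-attr]` with an explicit `cast(SecretStr, ...)`, and the same idiom appears again in the Tongyi hunk further down. A small self-contained sketch of that idiom follows; the helper function is made up for illustration, while the header name and the cast mirror the diff. The library code imports `SecretStr` from `langchain_core.pydantic_v1`; plain pydantic behaves the same for this purpose.

```python
from typing import Dict, Optional, cast

from pydantic import SecretStr


def build_headers(openai_api_key: Optional[SecretStr]) -> Dict[str, str]:
    headers: Dict[str, str] = {}
    if openai_api_key:
        # cast() records the "this is not None here" assumption for the type
        # checker instead of suppressing the error with a type: ignore comment.
        headers["X-OpenAI-Api-Key"] = cast(SecretStr, openai_api_key).get_secret_value()
    return headers


print(build_headers(SecretStr("sk-demo")))  # {'X-OpenAI-Api-Key': 'sk-demo'}
```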
@@ -25,7 +25,7 @@ logger = logging.getLogger(__name__)

 # Ignoring type because below is valid pydantic code
 # Unexpected keyword argument "extra" for "__init_subclass__" of "object" [call-arg]
-class ChatParams(BaseModel, extra=Extra.allow):  # type: ignore[call-arg]
+class ChatParams(BaseModel, extra=Extra.allow):
     """Parameters for the `MLflow AI Gateway` LLM."""

     temperature: float = 0.0
@@ -1,5 +1,5 @@
 import json
-from typing import Any, AsyncIterator, Dict, Iterator, List, Optional, Union
+from typing import Any, AsyncIterator, Dict, Iterator, List, Optional, Union, cast

 from langchain_core._api import deprecated
 from langchain_core.callbacks import (
@@ -74,10 +74,15 @@ class ChatOllama(BaseChatModel, _OllamaCommon):
         if isinstance(message, ChatMessage):
             message_text = f"\n\n{message.role.capitalize()}: {message.content}"
         elif isinstance(message, HumanMessage):
-            if message.content[0].get("type") == "text":  # type: ignore[union-attr]
-                message_text = f"[INST] {message.content[0]['text']} [/INST]"  # type: ignore[index]
-            elif message.content[0].get("type") == "image_url":  # type: ignore[union-attr]
-                message_text = message.content[0]["image_url"]["url"]  # type: ignore[index, index]
+            if isinstance(message.content, List):
+                first_content = cast(List[Dict], message.content)[0]
+                content_type = first_content.get("type")
+                if content_type == "text":
+                    message_text = f"[INST] {first_content['text']} [/INST]"
+                elif content_type == "image_url":
+                    message_text = first_content["image_url"]["url"]
+            else:
+                message_text = f"[INST] {message.content} [/INST]"
         elif isinstance(message, AIMessage):
             message_text = f"{message.content}"
         elif isinstance(message, SystemMessage):
@@ -94,7 +99,7 @@ class ChatOllama(BaseChatModel, _OllamaCommon):
     def _convert_messages_to_ollama_messages(
         self, messages: List[BaseMessage]
     ) -> List[Dict[str, Union[str, List[str]]]]:
-        ollama_messages = []
+        ollama_messages: List = []
         for message in messages:
             role = ""
             if isinstance(message, HumanMessage):
@@ -111,12 +116,12 @@ class ChatOllama(BaseChatModel, _OllamaCommon):
             if isinstance(message.content, str):
                 content = message.content
             else:
-                for content_part in message.content:
-                    if content_part.get("type") == "text":  # type: ignore[union-attr]
-                        content += f"\n{content_part['text']}"  # type: ignore[index]
-                    elif content_part.get("type") == "image_url":  # type: ignore[union-attr]
-                        if isinstance(content_part.get("image_url"), str):  # type: ignore[union-attr]
-                            image_url_components = content_part["image_url"].split(",")  # type: ignore[index]
+                for content_part in cast(List[Dict], message.content):
+                    if content_part.get("type") == "text":
+                        content += f"\n{content_part['text']}"
+                    elif content_part.get("type") == "image_url":
+                        if isinstance(content_part.get("image_url"), str):
+                            image_url_components = content_part["image_url"].split(",")
                             # Support data:image/jpeg;base64,<image> format
                             # and base64 strings
                             if len(image_url_components) > 1:
@@ -142,7 +147,7 @@ class ChatOllama(BaseChatModel, _OllamaCommon):
                 }
             )

-        return ollama_messages  # type: ignore[return-value]
+        return ollama_messages

     def _create_chat_stream(
         self,
@@ -324,10 +329,7 @@ class ChatOllama(BaseChatModel, _OllamaCommon):
         run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
         **kwargs: Any,
     ) -> AsyncIterator[ChatGenerationChunk]:
-        try:
-            async for stream_resp in self._acreate_chat_stream(
-                messages, stop, **kwargs
-            ):
+        async for stream_resp in self._acreate_chat_stream(messages, stop, **kwargs):
             if stream_resp:
                 chunk = _chat_stream_response_to_chat_generation_chunk(stream_resp)
                 yield chunk
@@ -336,9 +338,6 @@ class ChatOllama(BaseChatModel, _OllamaCommon):
                         chunk.text,
                         verbose=self.verbose,
                     )
-        except OllamaEndpointNotFoundError:
-            async for chunk in self._legacy_astream(messages, stop, **kwargs):  # type: ignore[attr-defined]
-                yield chunk

     @deprecated("0.0.3", alternative="_stream")
     def _legacy_stream(
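Aside: the ChatOllama hunks handle `message.content` that may be either a plain string or a list of typed parts, narrowing the list case with `cast(List[Dict], ...)`. A toy, standalone version of that branching (the function itself is illustrative, not the library's):

```python
from typing import Dict, List, Union, cast


def render_content(content: Union[str, List[Union[str, Dict]]]) -> str:
    # Plain-string content is returned unchanged.
    if isinstance(content, str):
        return content
    text = ""
    # List content holds parts such as {"type": "text", ...} or
    # {"type": "image_url", ...}; the cast tells the type checker the parts are dicts.
    for part in cast(List[Dict], content):
        if part.get("type") == "text":
            text += f"\n{part['text']}"
        elif part.get("type") == "image_url":
            text += f"\n[image: {part['image_url']}]"
    return text


print(render_content("hello"))
print(render_content([{"type": "text", "text": "hi"}]))
```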
@@ -554,7 +554,7 @@ class ChatOpenAI(BaseChatModel):
         if self.openai_proxy:
             import openai

-            openai.proxy = {"http": self.openai_proxy, "https": self.openai_proxy}  # type: ignore[assignment]  # noqa: E501
+            openai.proxy = {"http": self.openai_proxy, "https": self.openai_proxy}
         return {**self._default_params, **openai_creds}

     def _get_invocation_params(
@@ -13,6 +13,7 @@ from typing import (
     Mapping,
     Optional,
     Union,
+    cast,
 )

 from langchain_core.callbacks import (
@@ -197,7 +198,7 @@ class ChatTongyi(BaseChatModel):
         return {
             "model": self.model_name,
             "top_p": self.top_p,
-            "api_key": self.dashscope_api_key.get_secret_value(),  # type: ignore[union-attr]
+            "api_key": cast(SecretStr, self.dashscope_api_key).get_secret_value(),
             "result_format": "message",
             **self.model_kwargs,
         }
@@ -120,11 +120,10 @@ def _parse_chat_history_gemini(
                 image = load_image_from_gcs(path=path, project=project)
             elif path.startswith("data:image/"):
                 # extract base64 component from image uri
-                try:
-                    encoded = re.search(r"data:image/\w{2,4};base64,(.*)", path).group(  # type: ignore[union-attr]
-                        1
-                    )
-                except AttributeError:
+                encoded: Any = re.search(r"data:image/\w{2,4};base64,(.*)", path)
+                if encoded:
+                    encoded = encoded.group(1)
+                else:
                     raise ValueError(
                         "Invalid image uri. It should be in the format "
                         "data:image/<image_type>;base64,<base64_encoded_image>."
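Aside: the Gemini history hunk replaces a try/except AttributeError around a chained `re.search(...).group(1)` with an explicit None check. A minimal standalone sketch of the same flow; the helper name is illustrative, while the regex and error message come from the hunk:

```python
import re
from typing import Any


def extract_base64(path: str) -> str:
    # Keep the Optional[Match] instead of chaining .group() directly, so a
    # non-matching uri is reported as a ValueError rather than an AttributeError.
    encoded: Any = re.search(r"data:image/\w{2,4};base64,(.*)", path)
    if encoded:
        return encoded.group(1)
    raise ValueError(
        "Invalid image uri. It should be in the format "
        "data:image/<image_type>;base64,<base64_encoded_image>."
    )


print(extract_base64("data:image/png;base64,iVBORw0KGgo="))
```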
@@ -52,7 +52,7 @@ def _parse_chat_history(history: List[BaseMessage]) -> List[Dict[str, str]]:
     return chat_history


-class ChatYandexGPT(_BaseYandexGPT, BaseChatModel):  # type: ignore[misc]
+class ChatYandexGPT(_BaseYandexGPT, BaseChatModel):
     """Wrapper around YandexGPT large language models.

     There are two authentication options for the service account
@@ -156,7 +156,7 @@ def _make_request(
             messages=[Message(**message) for message in message_history],
         )
         stub = TextGenerationServiceStub(channel)
-        res = stub.Completion(request, metadata=self._grpc_metadata)  # type: ignore[attr-defined]
+        res = stub.Completion(request, metadata=self._grpc_metadata)
         return list(res)[0].alternatives[0].message.text


@@ -201,7 +201,7 @@ async def _amake_request(self: ChatYandexGPT, messages: List[BaseMessage]) -> st
             messages=[Message(**message) for message in message_history],
         )
         stub = TextGenerationAsyncServiceStub(channel)
-        operation = await stub.Completion(request, metadata=self._grpc_metadata)  # type: ignore[attr-defined]
+        operation = await stub.Completion(request, metadata=self._grpc_metadata)
         async with grpc.aio.secure_channel(
             operation_api_url, channel_credentials
         ) as operation_channel:
@@ -211,7 +211,7 @@ async def _amake_request(self: ChatYandexGPT, messages: List[BaseMessage]) -> st
             operation_request = GetOperationRequest(operation_id=operation.id)
             operation = await operation_stub.Get(
                 operation_request,
-                metadata=self._grpc_metadata,  # type: ignore[attr-defined]
+                metadata=self._grpc_metadata,
             )

             completion_response = CompletionResponse()
@@ -5,7 +5,7 @@ import asyncio
 import json
 import logging
 from functools import partial
-from typing import Any, Dict, Iterator, List, Optional
+from typing import Any, Dict, Iterator, List, Optional, cast

 from langchain_core.callbacks import CallbackManagerForLLMRun
 from langchain_core.language_models.chat_models import (
@@ -161,7 +161,7 @@ class ChatZhipuAI(BaseChatModel):

         return attributes

-    def __init__(self, *args, **kwargs):  # type: ignore[no-untyped-def]
+    def __init__(self, *args: Any, **kwargs: Any) -> None:
         super().__init__(*args, **kwargs)
         try:
             import zhipuai
@@ -174,7 +174,7 @@ class ChatZhipuAI(BaseChatModel):
                 "Please install it via 'pip install zhipuai'"
             )

-    def invoke(self, prompt):  # type: ignore[no-untyped-def]
+    def invoke(self, prompt: Any) -> Any:  # type: ignore[override]
         if self.model == "chatglm_turbo":
             return self.zhipuai.model_api.invoke(
                 model=self.model,
@@ -185,17 +185,17 @@ class ChatZhipuAI(BaseChatModel):
                 return_type=self.return_type,
             )
         elif self.model == "characterglm":
-            meta = self.meta.dict()
+            _meta = cast(meta, self.meta).dict()
             return self.zhipuai.model_api.invoke(
                 model=self.model,
-                meta=meta,
+                meta=_meta,
                 prompt=prompt,
                 request_id=self.request_id,
                 return_type=self.return_type,
             )
         return None

-    def sse_invoke(self, prompt):  # type: ignore[no-untyped-def]
+    def sse_invoke(self, prompt: Any) -> Any:
         if self.model == "chatglm_turbo":
             return self.zhipuai.model_api.sse_invoke(
                 model=self.model,
@@ -207,18 +207,18 @@ class ChatZhipuAI(BaseChatModel):
                 incremental=self.incremental,
             )
         elif self.model == "characterglm":
-            meta = self.meta.dict()
+            _meta = cast(meta, self.meta).dict()
             return self.zhipuai.model_api.sse_invoke(
                 model=self.model,
                 prompt=prompt,
-                meta=meta,
+                meta=_meta,
                 request_id=self.request_id,
                 return_type=self.return_type,
                 incremental=self.incremental,
             )
         return None

-    async def async_invoke(self, prompt):  # type: ignore[no-untyped-def]
+    async def async_invoke(self, prompt: Any) -> Any:
         loop = asyncio.get_running_loop()
         partial_func = partial(
             self.zhipuai.model_api.async_invoke, model=self.model, prompt=prompt
@@ -229,7 +229,7 @@ class ChatZhipuAI(BaseChatModel):
         )
         return response

-    async def async_invoke_result(self, task_id):  # type: ignore[no-untyped-def]
+    async def async_invoke_result(self, task_id: Any) -> Any:
         loop = asyncio.get_running_loop()
         response = await loop.run_in_executor(
             None,
@@ -247,7 +247,7 @@ class ChatZhipuAI(BaseChatModel):
         **kwargs: Any,
     ) -> ChatResult:
         """Generate a chat response."""
-        prompt = []
+        prompt: List = []
         for message in messages:
             if isinstance(message, AIMessage):
                 role = "assistant"
@@ -270,7 +270,7 @@ class ChatZhipuAI(BaseChatModel):

         else:
             stream_iter = self._stream(
-                prompt=prompt,  # type: ignore[arg-type]
+                prompt=prompt,
                 stop=stop,
                 run_manager=run_manager,
                 **kwargs,
@@ -101,7 +101,7 @@ class ContentFormatterBase:
     accepts: Optional[str] = "application/json"
     """The MIME type of the response data returned from the endpoint"""

-    format_error_msg: Optional[str] = (
+    format_error_msg: str = (
         "Error while formatting response payload for chat model of type "
         " `{api_type}`. Are you using the right formatter for the deployed "
         " model and endpoint type?"
@@ -134,17 +134,17 @@ class ContentFormatterBase:

         return [AzureMLEndpointApiType.realtime]

     @abstractmethod
     def format_request_payload(
         self,
         prompt: str,
         model_kwargs: Dict,
         api_type: AzureMLEndpointApiType = AzureMLEndpointApiType.realtime,
-    ) -> bytes:
+    ) -> Any:
         """Formats the request body according to the input schema of
         the model. Returns bytes or seekable file like object in the
         format specified in the content_type request header.
         """
         raise NotImplementedError()

     @abstractmethod
     def format_response_payload(
@@ -1,7 +1,7 @@
 from __future__ import annotations

 import logging
-from typing import Any, Callable, Dict, List, Mapping, Optional
+from typing import Any, Callable, Dict, List, Optional, Sequence

 from langchain_core.callbacks import (
     AsyncCallbackManagerForLLMRun,
@@ -54,13 +54,14 @@ class _BaseYandexGPT(Serializable):
     """Maximum number of retries to make when generating."""
     sleep_interval: float = 1.0
     """Delay between API requests"""
+    _grpc_metadata: Sequence

     @property
     def _llm_type(self) -> str:
         return "yandex_gpt"

     @property
-    def _identifying_params(self) -> Mapping[str, Any]:
+    def _identifying_params(self) -> Dict[str, Any]:
         """Get the identifying parameters."""
         return {
             "model_uri": self.model_uri,