community[patch]: chat model mypy fixes (#17061)

Related to #17048
Author: Bagatur
Date: 2024-02-05 13:42:59 -08:00
Committed by: GitHub
Parent: d93de71d08
Commit: 66e45e8ab7
17 changed files with 101 additions and 89 deletions


@@ -14,6 +14,7 @@ from typing import (
     Mapping,
     Optional,
     Tuple,
+    Type,
     Union,
 )
@@ -27,7 +28,7 @@ from langchain_core.language_models.chat_models import (
     generate_from_stream,
 )
 from langchain_core.language_models.llms import create_base_retry_decorator
-from langchain_core.messages import AIMessageChunk, BaseMessage
+from langchain_core.messages import AIMessageChunk, BaseMessage, BaseMessageChunk
 from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
 from langchain_core.pydantic_v1 import BaseModel, Field, SecretStr, root_validator
 from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env
@@ -56,9 +57,9 @@ class GPTRouterModel(BaseModel):
     provider_name: str
 
 
-def get_ordered_generation_requests(  # type: ignore[no-untyped-def, no-untyped-def]
-    models_priority_list: List[GPTRouterModel], **kwargs
-):
+def get_ordered_generation_requests(
+    models_priority_list: List[GPTRouterModel], **kwargs: Any
+) -> List:
     """
     Return the body for the model router input.
     """
@@ -100,7 +101,7 @@ def completion_with_retry(
     models_priority_list: List[GPTRouterModel],
     run_manager: Optional[CallbackManagerForLLMRun] = None,
     **kwargs: Any,
-) -> Union[GenerationResponse, Generator[ChunkedGenerationResponse]]:  # type: ignore[type-arg]
+) -> Union[GenerationResponse, Generator[ChunkedGenerationResponse, None, None]]:
     """Use tenacity to retry the completion call."""
     retry_decorator = _create_retry_decorator(llm, run_manager=run_manager)
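For context (an inference from the `type-arg` ignore being dropped): `typing.Generator` is generic over three parameters, `Generator[YieldType, SendType, ReturnType]`, so the one-argument spelling `Generator[ChunkedGenerationResponse]` needed the ignore. A generator that only yields is written with `None` in the other two slots, roughly:

    from typing import Generator

    def count_up(limit: int) -> Generator[int, None, None]:
        # Yields ints; nothing is passed in via send() and nothing is returned.
        for i in range(limit):
            yield i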
@@ -122,7 +123,7 @@ async def acompletion_with_retry(
     models_priority_list: List[GPTRouterModel],
     run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
     **kwargs: Any,
-) -> Union[GenerationResponse, AsyncGenerator[ChunkedGenerationResponse]]:  # type: ignore[type-arg]
+) -> Union[GenerationResponse, AsyncGenerator[ChunkedGenerationResponse, None]]:
     """Use tenacity to retry the async completion call."""
     retry_decorator = _create_retry_decorator(llm, run_manager=run_manager)
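The async case is analogous, except `typing.AsyncGenerator` takes only two parameters, `AsyncGenerator[YieldType, SendType]`, since async generators cannot return a value. A minimal sketch:

    from typing import AsyncGenerator

    async def count_up_async(limit: int) -> AsyncGenerator[int, None]:
        # Yields ints; asend() is unused, so SendType is None.
        for i in range(limit):
            yield i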
@@ -282,9 +283,9 @@ class GPTRouter(BaseChatModel):
         )
         return self._create_chat_result(response)
 
-    def _create_chat_generation_chunk(  # type: ignore[no-untyped-def, no-untyped-def]
-        self, data: Mapping[str, Any], default_chunk_class
-    ):
+    def _create_chat_generation_chunk(
+        self, data: Mapping[str, Any], default_chunk_class: Type[BaseMessageChunk]
+    ) -> Tuple[ChatGenerationChunk, Type[BaseMessageChunk]]:
         chunk = _convert_delta_to_message_chunk(
             {"content": data.get("text", "")}, default_chunk_class
         )
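For reference, `Type[BaseMessageChunk]` accepts the class object `BaseMessageChunk` itself or any subclass of it, which is exactly what a default-chunk-class argument holds. A toy illustration with hypothetical classes:

    from typing import Tuple, Type

    class Base: ...
    class Child(Base): ...

    def pick(cls: Type[Base]) -> Tuple[Base, Type[Base]]:
        # cls may be Base or any subclass; calling it produces a Base instance.
        return cls(), cls

    instance, chosen = pick(Child)  # fine: Type[Child] is a Type[Base]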
@@ -293,8 +294,8 @@ class GPTRouter(BaseChatModel):
             dict(finish_reason=finish_reason) if finish_reason is not None else None
         )
         default_chunk_class = chunk.__class__
-        chunk = ChatGenerationChunk(message=chunk, generation_info=generation_info)  # type: ignore[assignment]
-        return chunk, default_chunk_class
+        gen_chunk = ChatGenerationChunk(message=chunk, generation_info=generation_info)
+        return gen_chunk, default_chunk_class
 
     def _stream(
         self,
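The rename to `gen_chunk` is what lets the `assignment` ignore go away: mypy fixes a variable's type at its first binding, so reusing `chunk` (a message chunk) to hold a `ChatGenerationChunk` is an incompatible assignment. A minimal sketch of the idea, with hypothetical classes:

    class Message: ...

    class Generation:
        def __init__(self, message: Message) -> None:
            self.message = message

    def wrap(msg: Message) -> Generation:
        # msg = Generation(msg) would need "type: ignore[assignment]";
        # binding the wrapper to a fresh name keeps both types precise.
        gen = Generation(msg)
        return gen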
@@ -306,7 +307,7 @@ class GPTRouter(BaseChatModel):
         message_dicts, params = self._create_message_dicts(messages, stop)
         params = {**params, **kwargs, "stream": True}
-        default_chunk_class = AIMessageChunk
+        default_chunk_class: Type[BaseMessageChunk] = AIMessageChunk
         generator_response = completion_with_retry(
             self,
             messages=message_dicts,
@@ -339,7 +340,7 @@ class GPTRouter(BaseChatModel):
         message_dicts, params = self._create_message_dicts(messages, stop)
         params = {**params, **kwargs, "stream": True}
-        default_chunk_class = AIMessageChunk
+        default_chunk_class: Type[BaseMessageChunk] = AIMessageChunk
         generator_response = acompletion_with_retry(
             self,
             messages=message_dicts,
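Finally, the explicit annotation on `default_chunk_class` matters because mypy would otherwise infer the narrow type `Type[AIMessageChunk]` from the initializer, and the later rebinding to `chunk.__class__` (some other chunk subclass) would not check. Declaring the common base up front keeps the rebinding legal, roughly:

    from typing import Type

    class BaseChunk: ...
    class AIChunk(BaseChunk): ...
    class HumanChunk(BaseChunk): ...

    def demo() -> None:
        chunk_class: Type[BaseChunk] = AIChunk  # annotate with the base class
        chunk_class = HumanChunk                # fine: still a Type[BaseChunk]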