Make ChatYuan2 consume plain-dict responses from an OpenAI-compatible server.

Review notes on this revision of the patch:
  * Null-valued fields are handled with `or` instead of a `.get` default.
    `response.dict()` keeps keys whose value is None, so `.get(key, default)`
    would still hand None to the message constructors and to
    `_combine_llm_outputs`, which now calls `token_usage.items()`.
  * `finish_reason` is read with `.get` so a server omitting it does not raise.
  * The hunks were edited by hand; the post-image blob id on the first `index`
    line no longer matches and is kept only for reference.

diff --git a/libs/community/langchain_community/chat_models/yuan2.py b/libs/community/langchain_community/chat_models/yuan2.py
index 9e7ad33229d..379dc630928 100644
--- a/libs/community/langchain_community/chat_models/yuan2.py
+++ b/libs/community/langchain_community/chat_models/yuan2.py
@@ -3,7 +3,6 @@ from __future__ import annotations
 
 import logging
 from typing import (
-    TYPE_CHECKING,
     Any,
     AsyncIterator,
     Callable,
@@ -40,7 +39,7 @@ from langchain_core.messages import (
     SystemMessageChunk,
 )
 from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
-from langchain_core.pydantic_v1 import Field, root_validator
+from langchain_core.pydantic_v1 import BaseModel, Field, root_validator
 from langchain_core.utils import (
     get_from_dict_or_env,
     get_pydantic_field_names,
@@ -53,9 +52,6 @@ from tenacity import (
     wait_exponential,
 )
 
-if TYPE_CHECKING:
-    from openai.types.chat import ChatCompletion, ChatCompletionMessage
-
 logger = logging.getLogger(__name__)
 
 
@@ -91,7 +87,7 @@ class ChatYuan2(BaseChatModel):
     """Automatically inferred from env var `YUAN2_API_KEY` if not provided."""
 
     yuan2_api_base: Optional[str] = Field(
-        default="http://127.0.0.1:8000", alias="base_url"
+        default="http://127.0.0.1:8000/v1", alias="base_url"
     )
     """Base URL path for API requests, an OpenAI compatible API server."""
 
@@ -237,7 +233,7 @@ class ChatYuan2(BaseChatModel):
                 # Happens in streaming
                 continue
             token_usage = output["token_usage"]
-            for k, v in token_usage.__dict__.items():
+            for k, v in token_usage.items():
                 if k in overall_token_usage:
                     overall_token_usage[k] += v
                 else:
@@ -306,21 +302,23 @@ class ChatYuan2(BaseChatModel):
         message_dicts = [_convert_message_to_dict(m) for m in messages]
         return message_dicts, params
 
-    def _create_chat_result(self, response: ChatCompletion) -> ChatResult:
+    def _create_chat_result(self, response: Union[dict, BaseModel]) -> ChatResult:
         generations = []
         logger.debug(f"type(response): {type(response)}; response: {response}")
-        for res in response.choices:
-            message = _convert_dict_to_message(res.message)
-            generation_info = dict(finish_reason=res.finish_reason)
+        if not isinstance(response, dict):
+            response = response.dict()
+        for res in response["choices"]:
+            message = _convert_dict_to_message(res["message"])
+            generation_info = dict(finish_reason=res.get("finish_reason"))
             if "logprobs" in res:
-                generation_info["logprobs"] = res.logprobs
+                generation_info["logprobs"] = res["logprobs"]
             gen = ChatGeneration(
                 message=message,
                 generation_info=generation_info,
             )
             generations.append(gen)
         llm_output = {
-            "token_usage": response.usage,
+            "token_usage": response.get("usage") or {},
             "model_name": self.model_name,
         }
         return ChatResult(generations=generations, llm_output=llm_output)
@@ -427,7 +425,7 @@ async def acompletion_with_retry(llm: ChatYuan2, **kwargs: Any) -> Any:
 
 
 def _convert_delta_to_message_chunk(
-    _dict: ChatCompletionMessage, default_class: Type[BaseMessageChunk]
+    _dict: Mapping[str, Any], default_class: Type[BaseMessageChunk]
 ) -> BaseMessageChunk:
     role = _dict.get("role")
     content = _dict.get("content") or ""
@@ -444,17 +442,18 @@ def _convert_delta_to_message_chunk(
         return default_class(content=content)
 
 
-def _convert_dict_to_message(_dict: ChatCompletionMessage) -> BaseMessage:
+def _convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage:
+    # "content" may be present with a null value, which a ``.get`` default
+    # would not replace; ``or ""`` covers both the missing and the null case.
     role = _dict.get("role")
     if role == "user":
-        return HumanMessage(content=_dict.get("content"))
+        return HumanMessage(content=_dict.get("content") or "")
     elif role == "assistant":
-        content = _dict.get("content") or ""
-        return AIMessage(content=content)
+        return AIMessage(content=_dict.get("content") or "")
     elif role == "system":
-        return SystemMessage(content=_dict.get("content"))
+        return SystemMessage(content=_dict.get("content") or "")
     else:
-        return ChatMessage(content=_dict.get("content"), role=role)
+        return ChatMessage(content=_dict.get("content") or "", role=role)
 
 
 def _convert_message_to_dict(message: BaseMessage) -> dict:
diff --git a/libs/community/tests/integration_tests/chat_models/test_yuan2.py b/libs/community/tests/integration_tests/chat_models/test_yuan2.py
index 17a1c40a079..53678016e15 100644
--- a/libs/community/tests/integration_tests/chat_models/test_yuan2.py
+++ b/libs/community/tests/integration_tests/chat_models/test_yuan2.py
@@ -27,7 +27,7 @@ def test_chat_yuan2() -> None:
     messages = [
         HumanMessage(content="Hello"),
     ]
-    response = chat(messages)
+    response = chat.invoke(messages)
     assert isinstance(response, BaseMessage)
     assert isinstance(response.content, str)
 
@@ -46,7 +46,7 @@ def test_chat_yuan2_system_message() -> None:
         SystemMessage(content="You are an AI assistant."),
         HumanMessage(content="Hello"),
     ]
-    response = chat(messages)
+    response = chat.invoke(messages)
     assert isinstance(response, BaseMessage)
     assert isinstance(response.content, str)
 
@@ -89,12 +89,12 @@ def test_chat_yuan2_streaming() -> None:
         model_name="yuan2",
         max_retries=3,
         streaming=True,
-        callback_manager=callback_manager,
+        callbacks=callback_manager,
     )
     messages = [
         HumanMessage(content="Hello"),
     ]
-    response = chat(messages)
+    response = chat.invoke(messages)
     assert callback_handler.llm_streams > 0
     assert isinstance(response, BaseMessage)
 
@@ -136,7 +136,7 @@ async def test_async_chat_yuan2_streaming() -> None:
         model_name="yuan2",
         max_retries=3,
         streaming=True,
-        callback_manager=callback_manager,
+        callbacks=callback_manager,
     )
     messages: List = [
         HumanMessage(content="Hello"),