community: fix some features on Naver ChatModel & embedding model (#28228)

# Description

- adding stopReason to response_metadata to call stream and astream
- excluding NCP_APIGW_API_KEY input required validation
- to remove warning Field "model_name" has conflict with protected
namespace "model_".

cc. @vbarda
This commit is contained in:
CLOVA Studio 개발 2024-11-21 03:35:41 +09:00 committed by GitHub
parent 4da35623af
commit 218b4e073e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 36 additions and 25 deletions

View File

@ -35,7 +35,7 @@ from langchain_core.messages import (
) )
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
from langchain_core.utils import convert_to_secret_str, get_from_env from langchain_core.utils import convert_to_secret_str, get_from_env
from pydantic import AliasChoices, Field, SecretStr, model_validator from pydantic import AliasChoices, ConfigDict, Field, SecretStr, model_validator
from typing_extensions import Self from typing_extensions import Self
_DEFAULT_BASE_URL = "https://clovastudio.stream.ntruss.com" _DEFAULT_BASE_URL = "https://clovastudio.stream.ntruss.com"
@ -51,6 +51,12 @@ def _convert_chunk_to_message_chunk(
role = message.get("role") role = message.get("role")
content = message.get("content") or "" content = message.get("content") or ""
if sse.event == "result":
response_metadata = {}
if "stopReason" in sse_data:
response_metadata["stopReason"] = sse_data["stopReason"]
return AIMessageChunk(content="", response_metadata=response_metadata)
if role == "user" or default_class == HumanMessageChunk: if role == "user" or default_class == HumanMessageChunk:
return HumanMessageChunk(content=content) return HumanMessageChunk(content=content)
elif role == "assistant" or default_class == AIMessageChunk: elif role == "assistant" or default_class == AIMessageChunk:
@ -124,8 +130,6 @@ async def _aiter_sse(
event_data = sse.json() event_data = sse.json()
if sse.event == "signal" and event_data.get("data", {}) == "[DONE]": if sse.event == "signal" and event_data.get("data", {}) == "[DONE]":
return return
if sse.event == "result":
return
yield sse yield sse
@ -210,10 +214,7 @@ class ChatClovaX(BaseChatModel):
timeout: int = Field(gt=0, default=90) timeout: int = Field(gt=0, default=90)
max_retries: int = Field(ge=1, default=2) max_retries: int = Field(ge=1, default=2)
class Config: model_config = ConfigDict(populate_by_name=True, protected_namespaces=())
"""Configuration for this pydantic object."""
populate_by_name = True
@property @property
def _default_params(self) -> Dict[str, Any]: def _default_params(self) -> Dict[str, Any]:
@ -286,7 +287,7 @@ class ChatClovaX(BaseChatModel):
if not self.ncp_apigw_api_key: if not self.ncp_apigw_api_key:
self.ncp_apigw_api_key = convert_to_secret_str( self.ncp_apigw_api_key = convert_to_secret_str(
get_from_env("ncp_apigw_api_key", "NCP_APIGW_API_KEY") get_from_env("ncp_apigw_api_key", "NCP_APIGW_API_KEY", "")
) )
if not self.base_url: if not self.base_url:
@ -311,22 +312,28 @@ class ChatClovaX(BaseChatModel):
return self return self
def default_headers(self) -> Dict[str, Any]: def default_headers(self) -> Dict[str, Any]:
headers = {
"Content-Type": "application/json",
"Accept": "application/json",
}
clovastudio_api_key = ( clovastudio_api_key = (
self.ncp_clovastudio_api_key.get_secret_value() self.ncp_clovastudio_api_key.get_secret_value()
if self.ncp_clovastudio_api_key if self.ncp_clovastudio_api_key
else None else None
) )
if clovastudio_api_key:
headers["X-NCP-CLOVASTUDIO-API-KEY"] = clovastudio_api_key
apigw_api_key = ( apigw_api_key = (
self.ncp_apigw_api_key.get_secret_value() self.ncp_apigw_api_key.get_secret_value()
if self.ncp_apigw_api_key if self.ncp_apigw_api_key
else None else None
) )
return { if apigw_api_key:
"Content-Type": "application/json", headers["X-NCP-APIGW-API-KEY"] = apigw_api_key
"Accept": "application/json",
"X-NCP-CLOVASTUDIO-API-KEY": clovastudio_api_key, return headers
"X-NCP-APIGW-API-KEY": apigw_api_key,
}
def _create_message_dicts( def _create_message_dicts(
self, messages: List[BaseMessage], stop: Optional[List[str]] self, messages: List[BaseMessage], stop: Optional[List[str]]
@ -363,8 +370,6 @@ class ChatClovaX(BaseChatModel):
and event_data.get("data", {}) == "[DONE]" and event_data.get("data", {}) == "[DONE]"
): ):
return return
if sse.event == "result":
return
if sse.event == "error": if sse.event == "error":
raise SSEError(message=sse.data) raise SSEError(message=sse.data)
yield sse yield sse

View File

@ -7,6 +7,7 @@ from langchain_core.utils import convert_to_secret_str, get_from_env
from pydantic import ( from pydantic import (
AliasChoices, AliasChoices,
BaseModel, BaseModel,
ConfigDict,
Field, Field,
SecretStr, SecretStr,
model_validator, model_validator,
@ -86,8 +87,7 @@ class ClovaXEmbeddings(BaseModel, Embeddings):
timeout: int = Field(gt=0, default=60) timeout: int = Field(gt=0, default=60)
class Config: model_config = ConfigDict(arbitrary_types_allowed=True, protected_namespaces=())
arbitrary_types_allowed = True
@property @property
def lc_secrets(self) -> Dict[str, str]: def lc_secrets(self) -> Dict[str, str]:
@ -115,7 +115,7 @@ class ClovaXEmbeddings(BaseModel, Embeddings):
if not self.ncp_apigw_api_key: if not self.ncp_apigw_api_key:
self.ncp_apigw_api_key = convert_to_secret_str( self.ncp_apigw_api_key = convert_to_secret_str(
get_from_env("ncp_apigw_api_key", "NCP_APIGW_API_KEY") get_from_env("ncp_apigw_api_key", "NCP_APIGW_API_KEY", "")
) )
if not self.base_url: if not self.base_url:
@ -143,22 +143,28 @@ class ClovaXEmbeddings(BaseModel, Embeddings):
return self return self
def default_headers(self) -> Dict[str, Any]: def default_headers(self) -> Dict[str, Any]:
headers = {
"Content-Type": "application/json",
"Accept": "application/json",
}
clovastudio_api_key = ( clovastudio_api_key = (
self.ncp_clovastudio_api_key.get_secret_value() self.ncp_clovastudio_api_key.get_secret_value()
if self.ncp_clovastudio_api_key if self.ncp_clovastudio_api_key
else None else None
) )
if clovastudio_api_key:
headers["X-NCP-CLOVASTUDIO-API-KEY"] = clovastudio_api_key
apigw_api_key = ( apigw_api_key = (
self.ncp_apigw_api_key.get_secret_value() self.ncp_apigw_api_key.get_secret_value()
if self.ncp_apigw_api_key if self.ncp_apigw_api_key
else None else None
) )
return { if apigw_api_key:
"Content-Type": "application/json", headers["X-NCP-APIGW-API-KEY"] = apigw_api_key
"Accept": "application/json",
"X-NCP-CLOVASTUDIO-API-KEY": clovastudio_api_key, return headers
"X-NCP-APIGW-API-KEY": apigw_api_key,
}
def _embed_text(self, text: str) -> List[float]: def _embed_text(self, text: str) -> List[float]:
payload = {"text": text} payload = {"text": text}