community: fix some features on Naver ChatModel & embedding model (#28228)

# Description

- adding stopReason to response_metadata to call stream and astream
- excluding NCP_APIGW_API_KEY input required validation
- to remove warning Field "model_name" has conflict with protected
namespace "model_".

cc. @vbarda
This commit is contained in:
CLOVA Studio 개발 2024-11-21 03:35:41 +09:00 committed by GitHub
parent 4da35623af
commit 218b4e073e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 36 additions and 25 deletions

View File

@ -35,7 +35,7 @@ from langchain_core.messages import (
)
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
from langchain_core.utils import convert_to_secret_str, get_from_env
from pydantic import AliasChoices, Field, SecretStr, model_validator
from pydantic import AliasChoices, ConfigDict, Field, SecretStr, model_validator
from typing_extensions import Self
_DEFAULT_BASE_URL = "https://clovastudio.stream.ntruss.com"
@ -51,6 +51,12 @@ def _convert_chunk_to_message_chunk(
role = message.get("role")
content = message.get("content") or ""
if sse.event == "result":
response_metadata = {}
if "stopReason" in sse_data:
response_metadata["stopReason"] = sse_data["stopReason"]
return AIMessageChunk(content="", response_metadata=response_metadata)
if role == "user" or default_class == HumanMessageChunk:
return HumanMessageChunk(content=content)
elif role == "assistant" or default_class == AIMessageChunk:
@ -124,8 +130,6 @@ async def _aiter_sse(
event_data = sse.json()
if sse.event == "signal" and event_data.get("data", {}) == "[DONE]":
return
if sse.event == "result":
return
yield sse
@ -210,10 +214,7 @@ class ChatClovaX(BaseChatModel):
timeout: int = Field(gt=0, default=90)
max_retries: int = Field(ge=1, default=2)
class Config:
"""Configuration for this pydantic object."""
populate_by_name = True
model_config = ConfigDict(populate_by_name=True, protected_namespaces=())
@property
def _default_params(self) -> Dict[str, Any]:
@ -286,7 +287,7 @@ class ChatClovaX(BaseChatModel):
if not self.ncp_apigw_api_key:
self.ncp_apigw_api_key = convert_to_secret_str(
get_from_env("ncp_apigw_api_key", "NCP_APIGW_API_KEY")
get_from_env("ncp_apigw_api_key", "NCP_APIGW_API_KEY", "")
)
if not self.base_url:
@ -311,22 +312,28 @@ class ChatClovaX(BaseChatModel):
return self
def default_headers(self) -> Dict[str, Any]:
headers = {
"Content-Type": "application/json",
"Accept": "application/json",
}
clovastudio_api_key = (
self.ncp_clovastudio_api_key.get_secret_value()
if self.ncp_clovastudio_api_key
else None
)
if clovastudio_api_key:
headers["X-NCP-CLOVASTUDIO-API-KEY"] = clovastudio_api_key
apigw_api_key = (
self.ncp_apigw_api_key.get_secret_value()
if self.ncp_apigw_api_key
else None
)
return {
"Content-Type": "application/json",
"Accept": "application/json",
"X-NCP-CLOVASTUDIO-API-KEY": clovastudio_api_key,
"X-NCP-APIGW-API-KEY": apigw_api_key,
}
if apigw_api_key:
headers["X-NCP-APIGW-API-KEY"] = apigw_api_key
return headers
def _create_message_dicts(
self, messages: List[BaseMessage], stop: Optional[List[str]]
@ -363,8 +370,6 @@ class ChatClovaX(BaseChatModel):
and event_data.get("data", {}) == "[DONE]"
):
return
if sse.event == "result":
return
if sse.event == "error":
raise SSEError(message=sse.data)
yield sse

View File

@ -7,6 +7,7 @@ from langchain_core.utils import convert_to_secret_str, get_from_env
from pydantic import (
AliasChoices,
BaseModel,
ConfigDict,
Field,
SecretStr,
model_validator,
@ -86,8 +87,7 @@ class ClovaXEmbeddings(BaseModel, Embeddings):
timeout: int = Field(gt=0, default=60)
class Config:
arbitrary_types_allowed = True
model_config = ConfigDict(arbitrary_types_allowed=True, protected_namespaces=())
@property
def lc_secrets(self) -> Dict[str, str]:
@ -115,7 +115,7 @@ class ClovaXEmbeddings(BaseModel, Embeddings):
if not self.ncp_apigw_api_key:
self.ncp_apigw_api_key = convert_to_secret_str(
get_from_env("ncp_apigw_api_key", "NCP_APIGW_API_KEY")
get_from_env("ncp_apigw_api_key", "NCP_APIGW_API_KEY", "")
)
if not self.base_url:
@ -143,22 +143,28 @@ class ClovaXEmbeddings(BaseModel, Embeddings):
return self
def default_headers(self) -> Dict[str, Any]:
headers = {
"Content-Type": "application/json",
"Accept": "application/json",
}
clovastudio_api_key = (
self.ncp_clovastudio_api_key.get_secret_value()
if self.ncp_clovastudio_api_key
else None
)
if clovastudio_api_key:
headers["X-NCP-CLOVASTUDIO-API-KEY"] = clovastudio_api_key
apigw_api_key = (
self.ncp_apigw_api_key.get_secret_value()
if self.ncp_apigw_api_key
else None
)
return {
"Content-Type": "application/json",
"Accept": "application/json",
"X-NCP-CLOVASTUDIO-API-KEY": clovastudio_api_key,
"X-NCP-APIGW-API-KEY": apigw_api_key,
}
if apigw_api_key:
headers["X-NCP-APIGW-API-KEY"] = apigw_api_key
return headers
def _embed_text(self, text: str) -> List[float]:
payload = {"text": text}