mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-22 14:49:29 +00:00
community: fix some features on Naver ChatModel & embedding model (#28228)
# Description - adding stopReason to response_metadata to call stream and astream - excluding NCP_APIGW_API_KEY input required validation - to remove warning Field "model_name" has conflict with protected namespace "model_". cc. @vbarda
This commit is contained in:
parent
4da35623af
commit
218b4e073e
@ -35,7 +35,7 @@ from langchain_core.messages import (
|
||||
)
|
||||
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
|
||||
from langchain_core.utils import convert_to_secret_str, get_from_env
|
||||
from pydantic import AliasChoices, Field, SecretStr, model_validator
|
||||
from pydantic import AliasChoices, ConfigDict, Field, SecretStr, model_validator
|
||||
from typing_extensions import Self
|
||||
|
||||
_DEFAULT_BASE_URL = "https://clovastudio.stream.ntruss.com"
|
||||
@ -51,6 +51,12 @@ def _convert_chunk_to_message_chunk(
|
||||
role = message.get("role")
|
||||
content = message.get("content") or ""
|
||||
|
||||
if sse.event == "result":
|
||||
response_metadata = {}
|
||||
if "stopReason" in sse_data:
|
||||
response_metadata["stopReason"] = sse_data["stopReason"]
|
||||
return AIMessageChunk(content="", response_metadata=response_metadata)
|
||||
|
||||
if role == "user" or default_class == HumanMessageChunk:
|
||||
return HumanMessageChunk(content=content)
|
||||
elif role == "assistant" or default_class == AIMessageChunk:
|
||||
@ -124,8 +130,6 @@ async def _aiter_sse(
|
||||
event_data = sse.json()
|
||||
if sse.event == "signal" and event_data.get("data", {}) == "[DONE]":
|
||||
return
|
||||
if sse.event == "result":
|
||||
return
|
||||
yield sse
|
||||
|
||||
|
||||
@ -210,10 +214,7 @@ class ChatClovaX(BaseChatModel):
|
||||
timeout: int = Field(gt=0, default=90)
|
||||
max_retries: int = Field(ge=1, default=2)
|
||||
|
||||
class Config:
|
||||
"""Configuration for this pydantic object."""
|
||||
|
||||
populate_by_name = True
|
||||
model_config = ConfigDict(populate_by_name=True, protected_namespaces=())
|
||||
|
||||
@property
|
||||
def _default_params(self) -> Dict[str, Any]:
|
||||
@ -286,7 +287,7 @@ class ChatClovaX(BaseChatModel):
|
||||
|
||||
if not self.ncp_apigw_api_key:
|
||||
self.ncp_apigw_api_key = convert_to_secret_str(
|
||||
get_from_env("ncp_apigw_api_key", "NCP_APIGW_API_KEY")
|
||||
get_from_env("ncp_apigw_api_key", "NCP_APIGW_API_KEY", "")
|
||||
)
|
||||
|
||||
if not self.base_url:
|
||||
@ -311,22 +312,28 @@ class ChatClovaX(BaseChatModel):
|
||||
return self
|
||||
|
||||
def default_headers(self) -> Dict[str, Any]:
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
"Accept": "application/json",
|
||||
}
|
||||
|
||||
clovastudio_api_key = (
|
||||
self.ncp_clovastudio_api_key.get_secret_value()
|
||||
if self.ncp_clovastudio_api_key
|
||||
else None
|
||||
)
|
||||
if clovastudio_api_key:
|
||||
headers["X-NCP-CLOVASTUDIO-API-KEY"] = clovastudio_api_key
|
||||
|
||||
apigw_api_key = (
|
||||
self.ncp_apigw_api_key.get_secret_value()
|
||||
if self.ncp_apigw_api_key
|
||||
else None
|
||||
)
|
||||
return {
|
||||
"Content-Type": "application/json",
|
||||
"Accept": "application/json",
|
||||
"X-NCP-CLOVASTUDIO-API-KEY": clovastudio_api_key,
|
||||
"X-NCP-APIGW-API-KEY": apigw_api_key,
|
||||
}
|
||||
if apigw_api_key:
|
||||
headers["X-NCP-APIGW-API-KEY"] = apigw_api_key
|
||||
|
||||
return headers
|
||||
|
||||
def _create_message_dicts(
|
||||
self, messages: List[BaseMessage], stop: Optional[List[str]]
|
||||
@ -363,8 +370,6 @@ class ChatClovaX(BaseChatModel):
|
||||
and event_data.get("data", {}) == "[DONE]"
|
||||
):
|
||||
return
|
||||
if sse.event == "result":
|
||||
return
|
||||
if sse.event == "error":
|
||||
raise SSEError(message=sse.data)
|
||||
yield sse
|
||||
|
@ -7,6 +7,7 @@ from langchain_core.utils import convert_to_secret_str, get_from_env
|
||||
from pydantic import (
|
||||
AliasChoices,
|
||||
BaseModel,
|
||||
ConfigDict,
|
||||
Field,
|
||||
SecretStr,
|
||||
model_validator,
|
||||
@ -86,8 +87,7 @@ class ClovaXEmbeddings(BaseModel, Embeddings):
|
||||
|
||||
timeout: int = Field(gt=0, default=60)
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True, protected_namespaces=())
|
||||
|
||||
@property
|
||||
def lc_secrets(self) -> Dict[str, str]:
|
||||
@ -115,7 +115,7 @@ class ClovaXEmbeddings(BaseModel, Embeddings):
|
||||
|
||||
if not self.ncp_apigw_api_key:
|
||||
self.ncp_apigw_api_key = convert_to_secret_str(
|
||||
get_from_env("ncp_apigw_api_key", "NCP_APIGW_API_KEY")
|
||||
get_from_env("ncp_apigw_api_key", "NCP_APIGW_API_KEY", "")
|
||||
)
|
||||
|
||||
if not self.base_url:
|
||||
@ -143,22 +143,28 @@ class ClovaXEmbeddings(BaseModel, Embeddings):
|
||||
return self
|
||||
|
||||
def default_headers(self) -> Dict[str, Any]:
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
"Accept": "application/json",
|
||||
}
|
||||
|
||||
clovastudio_api_key = (
|
||||
self.ncp_clovastudio_api_key.get_secret_value()
|
||||
if self.ncp_clovastudio_api_key
|
||||
else None
|
||||
)
|
||||
if clovastudio_api_key:
|
||||
headers["X-NCP-CLOVASTUDIO-API-KEY"] = clovastudio_api_key
|
||||
|
||||
apigw_api_key = (
|
||||
self.ncp_apigw_api_key.get_secret_value()
|
||||
if self.ncp_apigw_api_key
|
||||
else None
|
||||
)
|
||||
return {
|
||||
"Content-Type": "application/json",
|
||||
"Accept": "application/json",
|
||||
"X-NCP-CLOVASTUDIO-API-KEY": clovastudio_api_key,
|
||||
"X-NCP-APIGW-API-KEY": apigw_api_key,
|
||||
}
|
||||
if apigw_api_key:
|
||||
headers["X-NCP-APIGW-API-KEY"] = apigw_api_key
|
||||
|
||||
return headers
|
||||
|
||||
def _embed_text(self, text: str) -> List[float]:
|
||||
payload = {"text": text}
|
||||
|
Loading…
Reference in New Issue
Block a user