mirror of
https://github.com/hwchase17/langchain.git
synced 2025-09-15 14:36:54 +00:00
community: Add Naver chat model & embeddings (#25162)
Reopened as a personal repo outside the organization. ## Description - Naver HyperCLOVA X community package - Add chat model & embeddings - Add unit test & integration test - Add chat model & embeddings docs - I changed partner package(https://github.com/langchain-ai/langchain/pull/24252) to community package on this PR - Could this embeddings(https://github.com/langchain-ai/langchain/pull/21890) be deprecated? We are trying to replace it with embedding model(**ClovaXEmbeddings**) in this PR. Twitter handle: None. (if needed, contact with joonha.jeon@navercorp.com) --- you can check our previous discussion below: > one question on namespaces - would it make sense to have these in .clova namespaces instead of .naver? I would like to keep it as is, unless it is essential to unify the package name. (ClovaX is a branding for the model, and I plan to add other models and components. They need to be managed as separate classes.) > also, could you clarify the difference between ClovaEmbeddings and ClovaXEmbeddings? There are 3 models that are being serviced by embedding, and all are supported in the current PR. In addition, all the functionality of CLOVA Studio that serves actual models, such as distinguishing between test apps and service apps, is supported. The existing PR does not support this content because it is hard-coded. --------- Co-authored-by: Erick Friis <erick@langchain.dev> Co-authored-by: Vadym Barda <vadym@langchain.dev>
This commit is contained in:
@@ -125,6 +125,9 @@ if TYPE_CHECKING:
|
||||
from langchain_community.chat_models.moonshot import (
|
||||
MoonshotChat,
|
||||
)
|
||||
from langchain_community.chat_models.naver import (
|
||||
ChatClovaX,
|
||||
)
|
||||
from langchain_community.chat_models.oci_data_science import (
|
||||
ChatOCIModelDeployment,
|
||||
ChatOCIModelDeploymentTGI,
|
||||
@@ -193,6 +196,7 @@ __all__ = [
|
||||
"ChatAnthropic",
|
||||
"ChatAnyscale",
|
||||
"ChatBaichuan",
|
||||
"ChatClovaX",
|
||||
"ChatCohere",
|
||||
"ChatCoze",
|
||||
"ChatOctoAI",
|
||||
@@ -257,6 +261,7 @@ _module_lookup = {
|
||||
"ChatAnthropic": "langchain_community.chat_models.anthropic",
|
||||
"ChatAnyscale": "langchain_community.chat_models.anyscale",
|
||||
"ChatBaichuan": "langchain_community.chat_models.baichuan",
|
||||
"ChatClovaX": "langchain_community.chat_models.naver",
|
||||
"ChatCohere": "langchain_community.chat_models.cohere",
|
||||
"ChatCoze": "langchain_community.chat_models.coze",
|
||||
"ChatDatabricks": "langchain_community.chat_models.databricks",
|
||||
|
524
libs/community/langchain_community/chat_models/naver.py
Normal file
524
libs/community/langchain_community/chat_models/naver.py
Normal file
@@ -0,0 +1,524 @@
|
||||
import logging
|
||||
from typing import (
|
||||
Any,
|
||||
AsyncContextManager,
|
||||
AsyncIterator,
|
||||
Callable,
|
||||
Dict,
|
||||
Iterator,
|
||||
List,
|
||||
Optional,
|
||||
Tuple,
|
||||
Type,
|
||||
Union,
|
||||
cast,
|
||||
)
|
||||
|
||||
import httpx
|
||||
from langchain_core.callbacks import (
|
||||
AsyncCallbackManagerForLLMRun,
|
||||
CallbackManagerForLLMRun,
|
||||
)
|
||||
from langchain_core.language_models.chat_models import BaseChatModel, LangSmithParams
|
||||
from langchain_core.language_models.llms import create_base_retry_decorator
|
||||
from langchain_core.messages import (
|
||||
AIMessage,
|
||||
AIMessageChunk,
|
||||
BaseMessage,
|
||||
BaseMessageChunk,
|
||||
ChatMessage,
|
||||
ChatMessageChunk,
|
||||
HumanMessage,
|
||||
HumanMessageChunk,
|
||||
SystemMessage,
|
||||
SystemMessageChunk,
|
||||
)
|
||||
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
|
||||
from langchain_core.utils import convert_to_secret_str, get_from_env
|
||||
from pydantic import AliasChoices, Field, SecretStr, model_validator
|
||||
from typing_extensions import Self
|
||||
|
||||
_DEFAULT_BASE_URL = "https://clovastudio.stream.ntruss.com"
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _convert_chunk_to_message_chunk(
|
||||
sse: Any, default_class: Type[BaseMessageChunk]
|
||||
) -> BaseMessageChunk:
|
||||
sse_data = sse.json()
|
||||
message = sse_data.get("message")
|
||||
role = message.get("role")
|
||||
content = message.get("content") or ""
|
||||
|
||||
if role == "user" or default_class == HumanMessageChunk:
|
||||
return HumanMessageChunk(content=content)
|
||||
elif role == "assistant" or default_class == AIMessageChunk:
|
||||
return AIMessageChunk(content=content)
|
||||
elif role == "system" or default_class == SystemMessageChunk:
|
||||
return SystemMessageChunk(content=content)
|
||||
elif role or default_class == ChatMessageChunk:
|
||||
return ChatMessageChunk(content=content, role=role)
|
||||
else:
|
||||
return default_class(content=content) # type: ignore[call-arg]
|
||||
|
||||
|
||||
def _convert_message_to_naver_chat_message(
|
||||
message: BaseMessage,
|
||||
) -> Dict:
|
||||
if isinstance(message, ChatMessage):
|
||||
return dict(role=message.role, content=message.content)
|
||||
elif isinstance(message, HumanMessage):
|
||||
return dict(role="user", content=message.content)
|
||||
elif isinstance(message, SystemMessage):
|
||||
return dict(role="system", content=message.content)
|
||||
elif isinstance(message, AIMessage):
|
||||
return dict(role="assistant", content=message.content)
|
||||
else:
|
||||
logger.warning(
|
||||
"FunctionMessage, ToolMessage not yet supported "
|
||||
"(https://api.ncloud-docs.com/docs/clovastudio-chatcompletions)"
|
||||
)
|
||||
raise ValueError(f"Got unknown type {message}")
|
||||
|
||||
|
||||
def _convert_naver_chat_message_to_message(
|
||||
_message: Dict,
|
||||
) -> BaseMessage:
|
||||
role = _message["role"]
|
||||
assert role in (
|
||||
"assistant",
|
||||
"system",
|
||||
"user",
|
||||
), f"Expected role to be 'assistant', 'system', 'user', got {role}"
|
||||
content = cast(str, _message["content"])
|
||||
additional_kwargs: Dict = {}
|
||||
|
||||
if role == "user":
|
||||
return HumanMessage(
|
||||
content=content,
|
||||
additional_kwargs=additional_kwargs,
|
||||
)
|
||||
elif role == "system":
|
||||
return SystemMessage(
|
||||
content=content,
|
||||
additional_kwargs=additional_kwargs,
|
||||
)
|
||||
elif role == "assistant":
|
||||
return AIMessage(
|
||||
content=content,
|
||||
additional_kwargs=additional_kwargs,
|
||||
)
|
||||
else:
|
||||
logger.warning("Got unknown role %s", role)
|
||||
raise ValueError(f"Got unknown role {role}")
|
||||
|
||||
|
||||
async def _aiter_sse(
|
||||
event_source_mgr: AsyncContextManager[Any],
|
||||
) -> AsyncIterator[Dict]:
|
||||
"""Iterate over the server-sent events."""
|
||||
async with event_source_mgr as event_source:
|
||||
await _araise_on_error(event_source.response)
|
||||
async for sse in event_source.aiter_sse():
|
||||
event_data = sse.json()
|
||||
if sse.event == "signal" and event_data.get("data", {}) == "[DONE]":
|
||||
return
|
||||
if sse.event == "result":
|
||||
return
|
||||
yield sse
|
||||
|
||||
|
||||
def _raise_on_error(response: httpx.Response) -> None:
|
||||
"""Raise an error if the response is an error."""
|
||||
if httpx.codes.is_error(response.status_code):
|
||||
error_message = response.read().decode("utf-8")
|
||||
raise httpx.HTTPStatusError(
|
||||
f"Error response {response.status_code} "
|
||||
f"while fetching {response.url}: {error_message}",
|
||||
request=response.request,
|
||||
response=response,
|
||||
)
|
||||
|
||||
|
||||
async def _araise_on_error(response: httpx.Response) -> None:
|
||||
"""Raise an error if the response is an error."""
|
||||
if httpx.codes.is_error(response.status_code):
|
||||
error_message = (await response.aread()).decode("utf-8")
|
||||
raise httpx.HTTPStatusError(
|
||||
f"Error response {response.status_code} "
|
||||
f"while fetching {response.url}: {error_message}",
|
||||
request=response.request,
|
||||
response=response,
|
||||
)
|
||||
|
||||
|
||||
class ChatClovaX(BaseChatModel):
|
||||
"""`NCP ClovaStudio` Chat Completion API.
|
||||
|
||||
following environment variables set or passed in constructor in lower case:
|
||||
- ``NCP_CLOVASTUDIO_API_KEY``
|
||||
- ``NCP_APIGW_API_KEY``
|
||||
|
||||
Example:
|
||||
.. code-block:: python
|
||||
|
||||
from langchain_core.messages import HumanMessage
|
||||
|
||||
from langchain_community import ChatClovaX
|
||||
|
||||
model = ChatClovaX()
|
||||
model.invoke([HumanMessage(content="Come up with 10 names for a song about parrots.")])
|
||||
""" # noqa: E501
|
||||
|
||||
client: httpx.Client = Field(default=None) #: :meta private:
|
||||
async_client: httpx.AsyncClient = Field(default=None) #: :meta private:
|
||||
|
||||
model_name: str = Field(
|
||||
default="HCX-003",
|
||||
validation_alias=AliasChoices("model_name", "model"),
|
||||
description="NCP ClovaStudio chat model name",
|
||||
)
|
||||
task_id: Optional[str] = Field(
|
||||
default=None, description="NCP Clova Studio chat model tuning task ID"
|
||||
)
|
||||
service_app: bool = Field(
|
||||
default=False,
|
||||
description="false: use testapp, true: use service app on NCP Clova Studio",
|
||||
)
|
||||
|
||||
ncp_clovastudio_api_key: Optional[SecretStr] = Field(default=None, alias="api_key")
|
||||
"""Automatically inferred from env are `NCP_CLOVASTUDIO_API_KEY` if not provided."""
|
||||
|
||||
ncp_apigw_api_key: Optional[SecretStr] = Field(default=None, alias="apigw_api_key")
|
||||
"""Automatically inferred from env are `NCP_APIGW_API_KEY` if not provided."""
|
||||
|
||||
base_url: str = Field(default=None, alias="base_url")
|
||||
"""
|
||||
Automatically inferred from env are `NCP_CLOVASTUDIO_API_BASE_URL` if not provided.
|
||||
"""
|
||||
|
||||
temperature: Optional[float] = Field(gt=0.0, le=1.0, default=0.5)
|
||||
top_k: Optional[int] = Field(ge=0, le=128, default=0)
|
||||
top_p: Optional[float] = Field(ge=0, le=1.0, default=0.8)
|
||||
repeat_penalty: Optional[float] = Field(gt=0.0, le=10, default=5.0)
|
||||
max_tokens: Optional[int] = Field(ge=0, le=4096, default=100)
|
||||
stop_before: Optional[list[str]] = Field(default=None, alias="stop")
|
||||
include_ai_filters: Optional[bool] = Field(default=False)
|
||||
seed: Optional[int] = Field(ge=0, le=4294967295, default=0)
|
||||
|
||||
timeout: int = Field(gt=0, default=90)
|
||||
max_retries: int = Field(ge=1, default=2)
|
||||
|
||||
class Config:
|
||||
"""Configuration for this pydantic object."""
|
||||
|
||||
populate_by_name = True
|
||||
|
||||
@property
|
||||
def _default_params(self) -> Dict[str, Any]:
|
||||
"""Get the default parameters for calling the API."""
|
||||
defaults = {
|
||||
"temperature": self.temperature,
|
||||
"topK": self.top_k,
|
||||
"topP": self.top_p,
|
||||
"repeatPenalty": self.repeat_penalty,
|
||||
"maxTokens": self.max_tokens,
|
||||
"stopBefore": self.stop_before,
|
||||
"includeAiFilters": self.include_ai_filters,
|
||||
"seed": self.seed,
|
||||
}
|
||||
filtered = {k: v for k, v in defaults.items() if v is not None}
|
||||
return filtered
|
||||
|
||||
@property
|
||||
def _identifying_params(self) -> Dict[str, Any]:
|
||||
"""Get the identifying parameters."""
|
||||
self._default_params["model_name"] = self.model_name
|
||||
return self._default_params
|
||||
|
||||
@property
|
||||
def lc_secrets(self) -> Dict[str, str]:
|
||||
return {
|
||||
"ncp_clovastudio_api_key": "NCP_CLOVASTUDIO_API_KEY",
|
||||
"ncp_apigw_api_key": "NCP_APIGW_API_KEY",
|
||||
}
|
||||
|
||||
@property
|
||||
def _llm_type(self) -> str:
|
||||
"""Return type of chat model."""
|
||||
return "chat-naver"
|
||||
|
||||
def _get_ls_params(
|
||||
self, stop: Optional[List[str]] = None, **kwargs: Any
|
||||
) -> LangSmithParams:
|
||||
"""Get the parameters used to invoke the model."""
|
||||
params = super()._get_ls_params(stop=stop, **kwargs)
|
||||
params["ls_provider"] = "naver"
|
||||
return params
|
||||
|
||||
@property
|
||||
def _client_params(self) -> Dict[str, Any]:
|
||||
"""Get the parameters used for the client."""
|
||||
return self._default_params
|
||||
|
||||
@property
|
||||
def _api_url(self) -> str:
|
||||
"""GET chat completion api url"""
|
||||
app_type = "serviceapp" if self.service_app else "testapp"
|
||||
|
||||
if self.task_id:
|
||||
return (
|
||||
f"{self.base_url}/{app_type}/v1/tasks/{self.task_id}/chat-completions"
|
||||
)
|
||||
else:
|
||||
return f"{self.base_url}/{app_type}/v1/chat-completions/{self.model_name}"
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_model_after(self) -> Self:
|
||||
if not (self.model_name or self.task_id):
|
||||
raise ValueError("either model_name or task_id must be assigned a value.")
|
||||
|
||||
if not self.ncp_clovastudio_api_key:
|
||||
self.ncp_clovastudio_api_key = convert_to_secret_str(
|
||||
get_from_env("ncp_clovastudio_api_key", "NCP_CLOVASTUDIO_API_KEY")
|
||||
)
|
||||
|
||||
if not self.ncp_apigw_api_key:
|
||||
self.ncp_apigw_api_key = convert_to_secret_str(
|
||||
get_from_env("ncp_apigw_api_key", "NCP_APIGW_API_KEY")
|
||||
)
|
||||
|
||||
if not self.base_url:
|
||||
self.base_url = get_from_env(
|
||||
"base_url", "NCP_CLOVASTUDIO_API_BASE_URL", _DEFAULT_BASE_URL
|
||||
)
|
||||
|
||||
if not self.client:
|
||||
self.client = httpx.Client(
|
||||
base_url=self.base_url,
|
||||
headers=self.default_headers(),
|
||||
timeout=self.timeout,
|
||||
)
|
||||
|
||||
if not self.async_client:
|
||||
self.async_client = httpx.AsyncClient(
|
||||
base_url=self.base_url,
|
||||
headers=self.default_headers(),
|
||||
timeout=self.timeout,
|
||||
)
|
||||
|
||||
return self
|
||||
|
||||
def default_headers(self) -> Dict[str, Any]:
|
||||
clovastudio_api_key = (
|
||||
self.ncp_clovastudio_api_key.get_secret_value()
|
||||
if self.ncp_clovastudio_api_key
|
||||
else None
|
||||
)
|
||||
apigw_api_key = (
|
||||
self.ncp_apigw_api_key.get_secret_value()
|
||||
if self.ncp_apigw_api_key
|
||||
else None
|
||||
)
|
||||
return {
|
||||
"Content-Type": "application/json",
|
||||
"Accept": "application/json",
|
||||
"X-NCP-CLOVASTUDIO-API-KEY": clovastudio_api_key,
|
||||
"X-NCP-APIGW-API-KEY": apigw_api_key,
|
||||
}
|
||||
|
||||
def _create_message_dicts(
|
||||
self, messages: List[BaseMessage], stop: Optional[List[str]]
|
||||
) -> Tuple[List[Dict], Dict[str, Any]]:
|
||||
params = self._client_params
|
||||
if stop is not None and "stopBefore" in params:
|
||||
params["stopBefore"] = stop
|
||||
|
||||
message_dicts = [_convert_message_to_naver_chat_message(m) for m in messages]
|
||||
return message_dicts, params
|
||||
|
||||
def _completion_with_retry(self, **kwargs: Any) -> Any:
|
||||
from httpx_sse import (
|
||||
ServerSentEvent,
|
||||
SSEError,
|
||||
connect_sse,
|
||||
)
|
||||
|
||||
if "stream" not in kwargs:
|
||||
kwargs["stream"] = False
|
||||
|
||||
stream = kwargs["stream"]
|
||||
if stream:
|
||||
|
||||
def iter_sse() -> Iterator[ServerSentEvent]:
|
||||
with connect_sse(
|
||||
self.client, "POST", self._api_url, json=kwargs
|
||||
) as event_source:
|
||||
_raise_on_error(event_source.response)
|
||||
for sse in event_source.iter_sse():
|
||||
event_data = sse.json()
|
||||
if (
|
||||
sse.event == "signal"
|
||||
and event_data.get("data", {}) == "[DONE]"
|
||||
):
|
||||
return
|
||||
if sse.event == "result":
|
||||
return
|
||||
if sse.event == "error":
|
||||
raise SSEError(message=sse.data)
|
||||
yield sse
|
||||
|
||||
return iter_sse()
|
||||
else:
|
||||
response = self.client.post(url=self._api_url, json=kwargs)
|
||||
_raise_on_error(response)
|
||||
return response.json()
|
||||
|
||||
async def _acompletion_with_retry(
|
||||
self,
|
||||
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
from httpx_sse import aconnect_sse
|
||||
|
||||
"""Use tenacity to retry the async completion call."""
|
||||
retry_decorator = _create_retry_decorator(self, run_manager=run_manager)
|
||||
|
||||
@retry_decorator
|
||||
async def _completion_with_retry(**kwargs: Any) -> Any:
|
||||
if "stream" not in kwargs:
|
||||
kwargs["stream"] = False
|
||||
stream = kwargs["stream"]
|
||||
if stream:
|
||||
event_source = aconnect_sse(
|
||||
self.async_client, "POST", self._api_url, json=kwargs
|
||||
)
|
||||
return _aiter_sse(event_source)
|
||||
else:
|
||||
response = await self.async_client.post(url=self._api_url, json=kwargs)
|
||||
await _araise_on_error(response)
|
||||
return response.json()
|
||||
|
||||
return await _completion_with_retry(**kwargs)
|
||||
|
||||
def _create_chat_result(self, response: Dict) -> ChatResult:
|
||||
generations = []
|
||||
result = response.get("result", {})
|
||||
msg = result.get("message", {})
|
||||
message = _convert_naver_chat_message_to_message(msg)
|
||||
|
||||
if isinstance(message, AIMessage):
|
||||
message.usage_metadata = {
|
||||
"input_tokens": result.get("inputLength"),
|
||||
"output_tokens": result.get("outputLength"),
|
||||
"total_tokens": result.get("inputLength") + result.get("outputLength"),
|
||||
}
|
||||
|
||||
gen = ChatGeneration(
|
||||
message=message,
|
||||
)
|
||||
generations.append(gen)
|
||||
|
||||
llm_output = {
|
||||
"stop_reason": result.get("stopReason"),
|
||||
"input_length": result.get("inputLength"),
|
||||
"output_length": result.get("outputLength"),
|
||||
"seed": result.get("seed"),
|
||||
"ai_filter": result.get("aiFilter"),
|
||||
}
|
||||
return ChatResult(generations=generations, llm_output=llm_output)
|
||||
|
||||
def _generate(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> ChatResult:
|
||||
message_dicts, params = self._create_message_dicts(messages, stop)
|
||||
params = {**params, **kwargs}
|
||||
|
||||
response = self._completion_with_retry(messages=message_dicts, **params)
|
||||
|
||||
return self._create_chat_result(response)
|
||||
|
||||
def _stream(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> Iterator[ChatGenerationChunk]:
|
||||
message_dicts, params = self._create_message_dicts(messages, stop)
|
||||
params = {**params, **kwargs, "stream": True}
|
||||
|
||||
default_chunk_class: Type[BaseMessageChunk] = AIMessageChunk
|
||||
for sse in self._completion_with_retry(
|
||||
messages=message_dicts, run_manager=run_manager, **params
|
||||
):
|
||||
new_chunk = _convert_chunk_to_message_chunk(sse, default_chunk_class)
|
||||
default_chunk_class = new_chunk.__class__
|
||||
gen_chunk = ChatGenerationChunk(message=new_chunk)
|
||||
|
||||
if run_manager:
|
||||
run_manager.on_llm_new_token(
|
||||
token=cast(str, new_chunk.content), chunk=gen_chunk
|
||||
)
|
||||
|
||||
yield gen_chunk
|
||||
|
||||
async def _agenerate(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> ChatResult:
|
||||
message_dicts, params = self._create_message_dicts(messages, stop)
|
||||
params = {**params, **kwargs}
|
||||
|
||||
response = await self._acompletion_with_retry(
|
||||
messages=message_dicts, run_manager=run_manager, **params
|
||||
)
|
||||
|
||||
return self._create_chat_result(response)
|
||||
|
||||
async def _astream(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> AsyncIterator[ChatGenerationChunk]:
|
||||
message_dicts, params = self._create_message_dicts(messages, stop)
|
||||
params = {**params, **kwargs, "stream": True}
|
||||
|
||||
default_chunk_class: Type[BaseMessageChunk] = AIMessageChunk
|
||||
async for chunk in await self._acompletion_with_retry(
|
||||
messages=message_dicts, run_manager=run_manager, **params
|
||||
):
|
||||
new_chunk = _convert_chunk_to_message_chunk(chunk, default_chunk_class)
|
||||
default_chunk_class = new_chunk.__class__
|
||||
gen_chunk = ChatGenerationChunk(message=new_chunk)
|
||||
|
||||
if run_manager:
|
||||
await run_manager.on_llm_new_token(
|
||||
token=cast(str, new_chunk.content), chunk=gen_chunk
|
||||
)
|
||||
|
||||
yield gen_chunk
|
||||
|
||||
|
||||
def _create_retry_decorator(
|
||||
llm: ChatClovaX,
|
||||
run_manager: Optional[
|
||||
Union[AsyncCallbackManagerForLLMRun, CallbackManagerForLLMRun]
|
||||
] = None,
|
||||
) -> Callable[[Any], Any]:
|
||||
"""Returns a tenacity retry decorator, preconfigured to handle exceptions"""
|
||||
|
||||
errors = [httpx.RequestError, httpx.StreamError]
|
||||
return create_base_retry_decorator(
|
||||
error_types=errors, max_retries=llm.max_retries, run_manager=run_manager
|
||||
)
|
@@ -151,6 +151,9 @@ if TYPE_CHECKING:
|
||||
from langchain_community.embeddings.mosaicml import (
|
||||
MosaicMLInstructorEmbeddings,
|
||||
)
|
||||
from langchain_community.embeddings.naver import (
|
||||
ClovaXEmbeddings,
|
||||
)
|
||||
from langchain_community.embeddings.nemo import (
|
||||
NeMoEmbeddings,
|
||||
)
|
||||
@@ -250,6 +253,7 @@ __all__ = [
|
||||
"BookendEmbeddings",
|
||||
"ClarifaiEmbeddings",
|
||||
"ClovaEmbeddings",
|
||||
"ClovaXEmbeddings",
|
||||
"CohereEmbeddings",
|
||||
"DashScopeEmbeddings",
|
||||
"DatabricksEmbeddings",
|
||||
@@ -332,6 +336,7 @@ _module_lookup = {
|
||||
"BookendEmbeddings": "langchain_community.embeddings.bookend",
|
||||
"ClarifaiEmbeddings": "langchain_community.embeddings.clarifai",
|
||||
"ClovaEmbeddings": "langchain_community.embeddings.clova",
|
||||
"ClovaXEmbeddings": "langchain_community.embeddings.naver",
|
||||
"CohereEmbeddings": "langchain_community.embeddings.cohere",
|
||||
"DashScopeEmbeddings": "langchain_community.embeddings.dashscope",
|
||||
"DatabricksEmbeddings": "langchain_community.embeddings.databricks",
|
||||
|
@@ -3,11 +3,17 @@ from __future__ import annotations
|
||||
from typing import Any, Dict, List, Optional, cast
|
||||
|
||||
import requests
|
||||
from langchain_core._api.deprecation import deprecated
|
||||
from langchain_core.embeddings import Embeddings
|
||||
from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env
|
||||
from pydantic import BaseModel, ConfigDict, SecretStr, model_validator
|
||||
|
||||
|
||||
@deprecated(
|
||||
since="0.3.4",
|
||||
removal="1.0.0",
|
||||
alternative_import="langchain_community.ClovaXEmbeddings",
|
||||
)
|
||||
class ClovaEmbeddings(BaseModel, Embeddings):
|
||||
"""
|
||||
Clova's embedding service.
|
||||
|
192
libs/community/langchain_community/embeddings/naver.py
Normal file
192
libs/community/langchain_community/embeddings/naver.py
Normal file
@@ -0,0 +1,192 @@
|
||||
import logging
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import httpx
|
||||
from langchain_core.embeddings import Embeddings
|
||||
from langchain_core.utils import convert_to_secret_str, get_from_env
|
||||
from pydantic import (
|
||||
AliasChoices,
|
||||
BaseModel,
|
||||
Field,
|
||||
SecretStr,
|
||||
model_validator,
|
||||
)
|
||||
from typing_extensions import Self
|
||||
|
||||
_DEFAULT_BASE_URL = "https://clovastudio.apigw.ntruss.com"
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _raise_on_error(response: httpx.Response) -> None:
|
||||
"""Raise an error if the response is an error."""
|
||||
if httpx.codes.is_error(response.status_code):
|
||||
error_message = response.read().decode("utf-8")
|
||||
raise httpx.HTTPStatusError(
|
||||
f"Error response {response.status_code} "
|
||||
f"while fetching {response.url}: {error_message}",
|
||||
request=response.request,
|
||||
response=response,
|
||||
)
|
||||
|
||||
|
||||
async def _araise_on_error(response: httpx.Response) -> None:
|
||||
"""Raise an error if the response is an error."""
|
||||
if httpx.codes.is_error(response.status_code):
|
||||
error_message = (await response.aread()).decode("utf-8")
|
||||
raise httpx.HTTPStatusError(
|
||||
f"Error response {response.status_code} "
|
||||
f"while fetching {response.url}: {error_message}",
|
||||
request=response.request,
|
||||
response=response,
|
||||
)
|
||||
|
||||
|
||||
class ClovaXEmbeddings(BaseModel, Embeddings):
|
||||
"""`NCP ClovaStudio` Embedding API.
|
||||
|
||||
following environment variables set or passed in constructor in lower case:
|
||||
- ``NCP_CLOVASTUDIO_API_KEY``
|
||||
- ``NCP_APIGW_API_KEY``
|
||||
- ``NCP_CLOVASTUDIO_APP_ID``
|
||||
|
||||
Example:
|
||||
.. code-block:: python
|
||||
|
||||
from langchain_community import ClovaXEmbeddings
|
||||
|
||||
model = ClovaXEmbeddings(model="clir-emb-dolphin")
|
||||
output = embedding.embed_documents(documents)
|
||||
""" # noqa: E501
|
||||
|
||||
client: httpx.Client = Field(default=None) #: :meta private:
|
||||
async_client: httpx.AsyncClient = Field(default=None) #: :meta private:
|
||||
|
||||
ncp_clovastudio_api_key: Optional[SecretStr] = Field(default=None, alias="api_key")
|
||||
"""Automatically inferred from env are `NCP_CLOVASTUDIO_API_KEY` if not provided."""
|
||||
|
||||
ncp_apigw_api_key: Optional[SecretStr] = Field(default=None, alias="apigw_api_key")
|
||||
"""Automatically inferred from env are `NCP_APIGW_API_KEY` if not provided."""
|
||||
|
||||
base_url: str = Field(default=None, alias="base_url")
|
||||
"""
|
||||
Automatically inferred from env are `NCP_CLOVASTUDIO_API_BASE_URL` if not provided.
|
||||
"""
|
||||
|
||||
app_id: Optional[str] = Field(default=None)
|
||||
service_app: bool = Field(
|
||||
default=False,
|
||||
description="false: use testapp, true: use service app on NCP Clova Studio",
|
||||
)
|
||||
model_name: str = Field(
|
||||
default="clir-emb-dolphin",
|
||||
validation_alias=AliasChoices("model_name", "model"),
|
||||
description="NCP ClovaStudio embedding model name",
|
||||
)
|
||||
|
||||
timeout: int = Field(gt=0, default=60)
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
@property
|
||||
def lc_secrets(self) -> Dict[str, str]:
|
||||
return {
|
||||
"ncp_clovastudio_api_key": "NCP_CLOVASTUDIO_API_KEY",
|
||||
"ncp_apigw_api_key": "NCP_APIGW_API_KEY",
|
||||
}
|
||||
|
||||
@property
|
||||
def _api_url(self) -> str:
|
||||
"""GET embedding api url"""
|
||||
app_type = "serviceapp" if self.service_app else "testapp"
|
||||
model_name = self.model_name if self.model_name != "bge-m3" else "v2"
|
||||
return (
|
||||
f"{self.base_url}/{app_type}"
|
||||
f"/v1/api-tools/embedding/{model_name}/{self.app_id}"
|
||||
)
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_model_after(self) -> Self:
|
||||
if not self.ncp_clovastudio_api_key:
|
||||
self.ncp_clovastudio_api_key = convert_to_secret_str(
|
||||
get_from_env("ncp_clovastudio_api_key", "NCP_CLOVASTUDIO_API_KEY")
|
||||
)
|
||||
|
||||
if not self.ncp_apigw_api_key:
|
||||
self.ncp_apigw_api_key = convert_to_secret_str(
|
||||
get_from_env("ncp_apigw_api_key", "NCP_APIGW_API_KEY")
|
||||
)
|
||||
|
||||
if not self.base_url:
|
||||
self.base_url = get_from_env(
|
||||
"base_url", "NCP_CLOVASTUDIO_API_BASE_URL", _DEFAULT_BASE_URL
|
||||
)
|
||||
|
||||
if not self.app_id:
|
||||
self.app_id = get_from_env("app_id", "NCP_CLOVASTUDIO_APP_ID")
|
||||
|
||||
if not self.client:
|
||||
self.client = httpx.Client(
|
||||
base_url=self.base_url,
|
||||
headers=self.default_headers(),
|
||||
timeout=self.timeout,
|
||||
)
|
||||
|
||||
if not self.async_client:
|
||||
self.async_client = httpx.AsyncClient(
|
||||
base_url=self.base_url,
|
||||
headers=self.default_headers(),
|
||||
timeout=self.timeout,
|
||||
)
|
||||
|
||||
return self
|
||||
|
||||
def default_headers(self) -> Dict[str, Any]:
|
||||
clovastudio_api_key = (
|
||||
self.ncp_clovastudio_api_key.get_secret_value()
|
||||
if self.ncp_clovastudio_api_key
|
||||
else None
|
||||
)
|
||||
apigw_api_key = (
|
||||
self.ncp_apigw_api_key.get_secret_value()
|
||||
if self.ncp_apigw_api_key
|
||||
else None
|
||||
)
|
||||
return {
|
||||
"Content-Type": "application/json",
|
||||
"Accept": "application/json",
|
||||
"X-NCP-CLOVASTUDIO-API-KEY": clovastudio_api_key,
|
||||
"X-NCP-APIGW-API-KEY": apigw_api_key,
|
||||
}
|
||||
|
||||
def _embed_text(self, text: str) -> List[float]:
|
||||
payload = {"text": text}
|
||||
response = self.client.post(url=self._api_url, json=payload)
|
||||
_raise_on_error(response)
|
||||
return response.json()["result"]["embedding"]
|
||||
|
||||
async def _aembed_text(self, text: str) -> List[float]:
|
||||
payload = {"text": text}
|
||||
response = await self.async_client.post(url=self._api_url, json=payload)
|
||||
await _araise_on_error(response)
|
||||
return response.json()["result"]["embedding"]
|
||||
|
||||
def embed_documents(self, texts: List[str]) -> List[List[float]]:
|
||||
embeddings = []
|
||||
for text in texts:
|
||||
embeddings.append(self._embed_text(text))
|
||||
return embeddings
|
||||
|
||||
def embed_query(self, text: str) -> List[float]:
|
||||
return self._embed_text(text)
|
||||
|
||||
async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
|
||||
embeddings = []
|
||||
for text in texts:
|
||||
embedding = await self._aembed_text(text)
|
||||
embeddings.append(embedding)
|
||||
return embeddings
|
||||
|
||||
async def aembed_query(self, text: str) -> List[float]:
|
||||
return await self._aembed_text(text)
|
13
libs/community/poetry.lock
generated
13
libs/community/poetry.lock
generated
@@ -1267,6 +1267,17 @@ http2 = ["h2 (>=3,<5)"]
|
||||
socks = ["socksio (==1.*)"]
|
||||
zstd = ["zstandard (>=0.18.0)"]
|
||||
|
||||
[[package]]
|
||||
name = "httpx-sse"
|
||||
version = "0.4.0"
|
||||
description = "Consume Server-Sent Event (SSE) messages with HTTPX."
|
||||
optional = false
|
||||
python-versions = ">=3.8"
|
||||
files = [
|
||||
{file = "httpx-sse-0.4.0.tar.gz", hash = "sha256:1e81a3a3070ce322add1d3529ed42eb5f70817f45ed6ec915ab753f961139721"},
|
||||
{file = "httpx_sse-0.4.0-py3-none-any.whl", hash = "sha256:f329af6eae57eaa2bdfd962b42524764af68075ea87370a2de920af5341e318f"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "idna"
|
||||
version = "3.10"
|
||||
@@ -4557,4 +4568,4 @@ type = ["pytest-mypy"]
|
||||
[metadata]
|
||||
lock-version = "2.0"
|
||||
python-versions = ">=3.9,<4.0"
|
||||
content-hash = "5c436a9ba9a1695c5c456c1ad8a81c9772a2ba0248624278c6cb606dd019b338"
|
||||
content-hash = "5c36c0453948190412f8ea627e6a10e02c56a1ea732ca215834dad3f1d785f4c"
|
||||
|
@@ -43,6 +43,7 @@ tenacity = ">=8.1.0,!=8.4.0,<10"
|
||||
dataclasses-json = ">= 0.5.7, < 0.7"
|
||||
pydantic-settings = "^2.4.0"
|
||||
langsmith = "^0.1.125"
|
||||
httpx-sse = "^0.4.0"
|
||||
|
||||
[[tool.poetry.dependencies.numpy]]
|
||||
version = "^1"
|
||||
|
@@ -0,0 +1,71 @@
|
||||
"""Test ChatNaver chat model."""
|
||||
|
||||
from langchain_core.messages import AIMessage, AIMessageChunk
|
||||
|
||||
from langchain_community.chat_models import ChatClovaX
|
||||
|
||||
|
||||
def test_stream() -> None:
|
||||
"""Test streaming tokens from ChatClovaX."""
|
||||
llm = ChatClovaX()
|
||||
|
||||
for token in llm.stream("I'm Clova"):
|
||||
assert isinstance(token, AIMessageChunk)
|
||||
assert isinstance(token.content, str)
|
||||
|
||||
|
||||
async def test_astream() -> None:
|
||||
"""Test streaming tokens from ChatClovaX."""
|
||||
llm = ChatClovaX()
|
||||
|
||||
async for token in llm.astream("I'm Clova"):
|
||||
assert isinstance(token, AIMessageChunk)
|
||||
assert isinstance(token.content, str)
|
||||
|
||||
|
||||
async def test_abatch() -> None:
|
||||
"""Test streaming tokens from ChatClovaX."""
|
||||
llm = ChatClovaX()
|
||||
|
||||
result = await llm.abatch(["I'm Clova", "I'm not Clova"])
|
||||
for token in result:
|
||||
assert isinstance(token, AIMessage)
|
||||
assert isinstance(token.content, str)
|
||||
|
||||
|
||||
async def test_abatch_tags() -> None:
|
||||
"""Test batch tokens from ChatClovaX."""
|
||||
llm = ChatClovaX()
|
||||
|
||||
result = await llm.abatch(["I'm Clova", "I'm not Clova"], config={"tags": ["foo"]})
|
||||
for token in result:
|
||||
assert isinstance(token, AIMessage)
|
||||
assert isinstance(token.content, str)
|
||||
|
||||
|
||||
def test_batch() -> None:
|
||||
"""Test batch tokens from ChatClovaX."""
|
||||
llm = ChatClovaX()
|
||||
|
||||
result = llm.batch(["I'm Clova", "I'm not Clova"])
|
||||
for token in result:
|
||||
assert isinstance(token, AIMessage)
|
||||
assert isinstance(token.content, str)
|
||||
|
||||
|
||||
async def test_ainvoke() -> None:
|
||||
"""Test invoke tokens from ChatClovaX."""
|
||||
llm = ChatClovaX()
|
||||
|
||||
result = await llm.ainvoke("I'm Clova", config={"tags": ["foo"]})
|
||||
assert isinstance(result, AIMessage)
|
||||
assert isinstance(result.content, str)
|
||||
|
||||
|
||||
def test_invoke() -> None:
|
||||
"""Test invoke tokens from ChatClovaX."""
|
||||
llm = ChatClovaX()
|
||||
|
||||
result = llm.invoke("I'm Clova", config=dict(tags=["foo"]))
|
||||
assert isinstance(result, AIMessage)
|
||||
assert isinstance(result.content, str)
|
@@ -0,0 +1,37 @@
|
||||
"""Test Naver embeddings."""
|
||||
|
||||
from langchain_community.embeddings import ClovaXEmbeddings
|
||||
|
||||
|
||||
def test_embedding_documents() -> None:
|
||||
"""Test cohere embeddings."""
|
||||
documents = ["foo bar"]
|
||||
embedding = ClovaXEmbeddings()
|
||||
output = embedding.embed_documents(documents)
|
||||
assert len(output) == 1
|
||||
assert len(output[0]) > 0
|
||||
|
||||
|
||||
async def test_aembedding_documents() -> None:
|
||||
"""Test cohere embeddings."""
|
||||
documents = ["foo bar"]
|
||||
embedding = ClovaXEmbeddings()
|
||||
output = await embedding.aembed_documents(documents)
|
||||
assert len(output) == 1
|
||||
assert len(output[0]) > 0
|
||||
|
||||
|
||||
def test_embedding_query() -> None:
|
||||
"""Test cohere embeddings."""
|
||||
document = "foo bar"
|
||||
embedding = ClovaXEmbeddings()
|
||||
output = embedding.embed_query(document)
|
||||
assert len(output) > 0
|
||||
|
||||
|
||||
async def test_aembedding_query() -> None:
|
||||
"""Test cohere embeddings."""
|
||||
document = "foo bar"
|
||||
embedding = ClovaXEmbeddings()
|
||||
output = await embedding.aembed_query(document)
|
||||
assert len(output) > 0
|
@@ -6,6 +6,7 @@ EXPECTED_ALL = [
|
||||
"ChatAnthropic",
|
||||
"ChatAnyscale",
|
||||
"ChatBaichuan",
|
||||
"ChatClovaX",
|
||||
"ChatCohere",
|
||||
"ChatCoze",
|
||||
"ChatDatabricks",
|
||||
|
197
libs/community/tests/unit_tests/chat_models/test_naver.py
Normal file
197
libs/community/tests/unit_tests/chat_models/test_naver.py
Normal file
@@ -0,0 +1,197 @@
|
||||
"""Test chat model integration."""
|
||||
|
||||
import json
|
||||
import os
|
||||
from typing import Any, AsyncGenerator, Generator, cast
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from langchain_core.callbacks import BaseCallbackHandler
|
||||
from langchain_core.messages import (
|
||||
AIMessage,
|
||||
HumanMessage,
|
||||
SystemMessage,
|
||||
)
|
||||
from pydantic import SecretStr
|
||||
|
||||
from langchain_community.chat_models import ChatClovaX
|
||||
from langchain_community.chat_models.naver import (
|
||||
_convert_message_to_naver_chat_message,
|
||||
_convert_naver_chat_message_to_message,
|
||||
)
|
||||
|
||||
os.environ["NCP_CLOVASTUDIO_API_KEY"] = "test_api_key"
|
||||
os.environ["NCP_APIGW_API_KEY"] = "test_gw_key"
|
||||
|
||||
|
||||
def test_initialization_api_key() -> None:
|
||||
"""Test chat model initialization."""
|
||||
chat_model = ChatClovaX(api_key="foo", apigw_api_key="bar") # type: ignore[arg-type]
|
||||
assert (
|
||||
cast(SecretStr, chat_model.ncp_clovastudio_api_key).get_secret_value() == "foo"
|
||||
)
|
||||
assert cast(SecretStr, chat_model.ncp_apigw_api_key).get_secret_value() == "bar"
|
||||
|
||||
|
||||
def test_initialization_model_name() -> None:
|
||||
llm = ChatClovaX(model="HCX-DASH-001") # type: ignore[call-arg]
|
||||
assert llm.model_name == "HCX-DASH-001"
|
||||
llm = ChatClovaX(model_name="HCX-DASH-001")
|
||||
assert llm.model_name == "HCX-DASH-001"
|
||||
|
||||
|
||||
def test_convert_dict_to_message_human() -> None:
|
||||
message = {"role": "user", "content": "foo"}
|
||||
result = _convert_naver_chat_message_to_message(message)
|
||||
expected_output = HumanMessage(content="foo")
|
||||
assert result == expected_output
|
||||
assert _convert_message_to_naver_chat_message(expected_output) == message
|
||||
|
||||
|
||||
def test_convert_dict_to_message_ai() -> None:
|
||||
message = {"role": "assistant", "content": "foo"}
|
||||
result = _convert_naver_chat_message_to_message(message)
|
||||
expected_output = AIMessage(content="foo")
|
||||
assert result == expected_output
|
||||
assert _convert_message_to_naver_chat_message(expected_output) == message
|
||||
|
||||
|
||||
def test_convert_dict_to_message_system() -> None:
|
||||
message = {"role": "system", "content": "foo"}
|
||||
result = _convert_naver_chat_message_to_message(message)
|
||||
expected_output = SystemMessage(content="foo")
|
||||
assert result == expected_output
|
||||
assert _convert_message_to_naver_chat_message(expected_output) == message
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_chat_completion_response() -> dict:
|
||||
return {
|
||||
"status": {"code": "20000", "message": "OK"},
|
||||
"result": {
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": "Phrases: Record what happened today and prepare "
|
||||
"for tomorrow. "
|
||||
"The diary will make your life richer.",
|
||||
},
|
||||
"stopReason": "LENGTH",
|
||||
"inputLength": 100,
|
||||
"outputLength": 10,
|
||||
"aiFilter": [
|
||||
{"groupName": "curse", "name": "insult", "score": "1"},
|
||||
{"groupName": "curse", "name": "discrimination", "score": "0"},
|
||||
{
|
||||
"groupName": "unsafeContents",
|
||||
"name": "sexualHarassment",
|
||||
"score": "2",
|
||||
},
|
||||
],
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def test_naver_invoke(mock_chat_completion_response: dict) -> None:
|
||||
llm = ChatClovaX()
|
||||
completed = False
|
||||
|
||||
def mock_completion_with_retry(*args: Any, **kwargs: Any) -> Any:
|
||||
nonlocal completed
|
||||
completed = True
|
||||
return mock_chat_completion_response
|
||||
|
||||
with patch.object(ChatClovaX, "_completion_with_retry", mock_completion_with_retry):
|
||||
res = llm.invoke("Let's test it.")
|
||||
assert (
|
||||
res.content
|
||||
== "Phrases: Record what happened today and prepare for tomorrow. "
|
||||
"The diary will make your life richer."
|
||||
)
|
||||
assert completed
|
||||
|
||||
|
||||
async def test_naver_ainvoke(mock_chat_completion_response: dict) -> None:
|
||||
llm = ChatClovaX()
|
||||
completed = False
|
||||
|
||||
async def mock_acompletion_with_retry(*args: Any, **kwargs: Any) -> Any:
|
||||
nonlocal completed
|
||||
completed = True
|
||||
return mock_chat_completion_response
|
||||
|
||||
with patch.object(
|
||||
ChatClovaX, "_acompletion_with_retry", mock_acompletion_with_retry
|
||||
):
|
||||
res = await llm.ainvoke("Let's test it.")
|
||||
assert (
|
||||
res.content
|
||||
== "Phrases: Record what happened today and prepare for tomorrow. "
|
||||
"The diary will make your life richer."
|
||||
)
|
||||
assert completed
|
||||
|
||||
|
||||
def _make_completion_response_from_token(token: str): # type: ignore[no-untyped-def]
|
||||
from httpx_sse import ServerSentEvent
|
||||
|
||||
return ServerSentEvent(
|
||||
event="token",
|
||||
data=json.dumps(
|
||||
dict(
|
||||
index=0,
|
||||
inputLength=89,
|
||||
outputLength=1,
|
||||
message=dict(
|
||||
content=token,
|
||||
role="assistant",
|
||||
),
|
||||
)
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def mock_chat_stream(*args: Any, **kwargs: Any) -> Generator:
|
||||
def it() -> Generator:
|
||||
for token in ["Hello", " how", " can", " I", " help", "?"]:
|
||||
yield _make_completion_response_from_token(token)
|
||||
|
||||
return it()
|
||||
|
||||
|
||||
async def mock_chat_astream(*args: Any, **kwargs: Any) -> AsyncGenerator:
|
||||
async def it() -> AsyncGenerator:
|
||||
for token in ["Hello", " how", " can", " I", " help", "?"]:
|
||||
yield _make_completion_response_from_token(token)
|
||||
|
||||
return it()
|
||||
|
||||
|
||||
class MyCustomHandler(BaseCallbackHandler):
|
||||
last_token: str = ""
|
||||
|
||||
def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
|
||||
self.last_token = token
|
||||
|
||||
|
||||
@patch(
|
||||
"langchain_community.chat_models.ChatClovaX._completion_with_retry",
|
||||
new=mock_chat_stream,
|
||||
)
|
||||
@pytest.mark.requires("httpx_sse")
|
||||
def test_stream_with_callback() -> None:
|
||||
callback = MyCustomHandler()
|
||||
chat = ChatClovaX(callbacks=[callback])
|
||||
for token in chat.stream("Hello"):
|
||||
assert callback.last_token == token.content
|
||||
|
||||
|
||||
@patch(
|
||||
"langchain_community.chat_models.ChatClovaX._acompletion_with_retry",
|
||||
new=mock_chat_astream,
|
||||
)
|
||||
@pytest.mark.requires("httpx_sse")
|
||||
async def test_astream_with_callback() -> None:
|
||||
callback = MyCustomHandler()
|
||||
chat = ChatClovaX(callbacks=[callback])
|
||||
async for token in chat.astream("Hello"):
|
||||
assert callback.last_token == token.content
|
@@ -7,6 +7,7 @@ EXPECTED_ALL = [
|
||||
"AzureOpenAIEmbeddings",
|
||||
"BaichuanTextEmbeddings",
|
||||
"ClarifaiEmbeddings",
|
||||
"ClovaXEmbeddings",
|
||||
"CohereEmbeddings",
|
||||
"DatabricksEmbeddings",
|
||||
"ElasticsearchEmbeddings",
|
||||
|
18
libs/community/tests/unit_tests/embeddings/test_naver.py
Normal file
18
libs/community/tests/unit_tests/embeddings/test_naver.py
Normal file
@@ -0,0 +1,18 @@
|
||||
"""Test embedding model integration."""
|
||||
|
||||
import os
|
||||
from typing import cast
|
||||
|
||||
from pydantic import SecretStr
|
||||
|
||||
from langchain_community.embeddings import ClovaXEmbeddings
|
||||
|
||||
os.environ["NCP_CLOVASTUDIO_API_KEY"] = "test_api_key"
|
||||
os.environ["NCP_APIGW_API_KEY"] = "test_gw_key"
|
||||
os.environ["NCP_CLOVASTUDIO_APP_ID"] = "test_app_id"
|
||||
|
||||
|
||||
def test_initialization_api_key() -> None:
|
||||
llm = ClovaXEmbeddings(api_key="foo", apigw_api_key="bar") # type: ignore[arg-type]
|
||||
assert cast(SecretStr, llm.ncp_clovastudio_api_key).get_secret_value() == "foo"
|
||||
assert cast(SecretStr, llm.ncp_apigw_api_key).get_secret_value() == "bar"
|
@@ -43,6 +43,7 @@ def test_required_dependencies(poetry_conf: Mapping[str, Any]) -> None:
|
||||
"SQLAlchemy",
|
||||
"aiohttp",
|
||||
"dataclasses-json",
|
||||
"httpx-sse",
|
||||
"langchain-core",
|
||||
"langsmith",
|
||||
"numpy",
|
||||
|
Reference in New Issue
Block a user