Mirror of https://github.com/hwchase17/langchain.git, synced 2025-09-12 12:59:07 +00:00
upstage: init package (#20574)
Co-authored-by: Sean Cho <sean@upstage.ai>
Co-authored-by: JuHyung-Son <sonju0427@gmail.com>
libs/community/langchain_community/chat_models/solar.py (modified)
@@ -2,16 +2,19 @@

 from typing import Dict

-from langchain_core.pydantic_v1 import root_validator
+from langchain_core._api import deprecated
+from langchain_core.pydantic_v1 import Field, root_validator
 from langchain_core.utils import get_from_dict_or_env

 from langchain_community.chat_models import ChatOpenAI
 from langchain_community.llms.solar import SOLAR_SERVICE_URL_BASE, SolarCommon


-class SolarChat(SolarCommon, ChatOpenAI):  # type: ignore[misc]
-    """Solar large language models.
+@deprecated(
+    since="0.0.34", removal="0.2.0", alternative_import="langchain_upstage.ChatUpstage"
+)
+class SolarChat(SolarCommon, ChatOpenAI):
+    """Wrapper around Solar large language models.

     To use, you should have the ``openai`` python package installed, and the
     environment variable ``SOLAR_API_KEY`` set with your API key.
     (Solar's chat API is compatible with OpenAI's SDK.)
@@ -24,6 +27,16 @@ class SolarChat(SolarCommon, ChatOpenAI):  # type: ignore[misc]
             solar = SolarChat(model="solar-1-mini-chat")
     """

+    max_tokens: int = Field(default=1024)
+
+    # this is needed to match ChatOpenAI superclass
+    class Config:
+        """Configuration for this pydantic object."""
+
+        allow_population_by_field_name = True
+        arbitrary_types_allowed = True
+        extra = "ignore"
+
     @root_validator()
     def validate_environment(cls, values: Dict) -> Dict:
         """Validate that the environment is set up correctly."""
@@ -42,9 +55,9 @@ class SolarChat(SolarCommon, ChatOpenAI):  # type: ignore[misc]

         client_params = {
             "api_key": values["solar_api_key"],
-            "base_url": values["base_url"]
-            if "base_url" in values
-            else SOLAR_SERVICE_URL_BASE,
+            "base_url": (
+                values["base_url"] if "base_url" in values else SOLAR_SERVICE_URL_BASE
+            ),
         }

         if not values.get("client"):
libs/community/langchain_community/embeddings/solar.py (modified)
@@ -4,6 +4,7 @@ import logging
 from typing import Any, Callable, Dict, List, Optional

 import requests
+from langchain_core._api import deprecated
 from langchain_core.embeddings import Embeddings
 from langchain_core.pydantic_v1 import BaseModel, Extra, SecretStr, root_validator
 from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env
@@ -44,6 +45,9 @@ def embed_with_retry(embeddings: SolarEmbeddings, *args: Any, **kwargs: Any) ->
     return _embed_with_retry(*args, **kwargs)


+@deprecated(
+    since="0.0.34", removal="0.2.0", alternative_import="langchain_upstage.ChatUpstage"
+)
 class SolarEmbeddings(BaseModel, Embeddings):
     """Solar's embedding service.
libs/community/langchain_community/llms/solar.py (modified)
@@ -40,7 +40,7 @@ class SolarCommon(BaseModel):
     """Solar API key. Get it here: https://console.upstage.ai/services/solar"""
     model_name: str = Field(default="solar-1-mini-chat", alias="model")
     """Model name. Available models listed here: https://console.upstage.ai/services/solar"""
-    max_tokens: int = Field(default=1024, alias="max context")
+    max_tokens: int = Field(default=1024)
    temperature = 0.3

     class Config:
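For readers unfamiliar with `langchain_core._api.deprecated`, here is a minimal toy sketch of what the decorator applied above does. The class below is a stand-in, not `SolarChat`, and the exact warning wording comes from langchain_core (paraphrased in the comment):

```python
import warnings

from langchain_core._api import deprecated


@deprecated(
    since="0.0.34", removal="0.2.0", alternative_import="langchain_upstage.ChatUpstage"
)
class OldThing:
    """Stand-in for a deprecated class such as SolarChat."""


with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    OldThing()  # instantiation triggers the deprecation warning

# The message names the version (0.0.34), the removal target (0.2.0),
# and the suggested replacement import (langchain_upstage.ChatUpstage).
print(caught[0].message)
```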
libs/partners/upstage/.gitignore (new file, 1 line, vendored)
@@ -0,0 +1 @@
__pycache__
libs/partners/upstage/LICENSE (new file, 21 lines)
@@ -0,0 +1,21 @@
MIT License

Copyright (c) 2024 LangChain, Inc.

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
libs/partners/upstage/Makefile (new file, 57 lines)
@@ -0,0 +1,57 @@
.PHONY: all format lint test tests integration_tests docker_tests help extended_tests

# Default target executed when no arguments are given to make.
all: help

# Define a variable for the test file path.
TEST_FILE ?= tests/unit_tests/

integration_test integration_tests: TEST_FILE=tests/integration_tests/

test tests integration_test integration_tests:
	poetry run pytest $(TEST_FILE)

######################
# LINTING AND FORMATTING
######################

# Define a variable for Python and notebook files.
PYTHON_FILES=.
MYPY_CACHE=.mypy_cache
lint format: PYTHON_FILES=.
lint_diff format_diff: PYTHON_FILES=$(shell git diff --relative=libs/partners/upstage --name-only --diff-filter=d master | grep -E '\.py$$|\.ipynb$$')
lint_package: PYTHON_FILES=langchain_upstage
lint_tests: PYTHON_FILES=tests
lint_tests: MYPY_CACHE=.mypy_cache_test

lint lint_diff lint_package lint_tests:
	poetry run ruff .
	poetry run ruff format $(PYTHON_FILES) --diff
	poetry run ruff --select I $(PYTHON_FILES)
	mkdir $(MYPY_CACHE); poetry run mypy $(PYTHON_FILES) --cache-dir $(MYPY_CACHE)

format format_diff:
	poetry run ruff format $(PYTHON_FILES)
	poetry run ruff --select I --fix $(PYTHON_FILES)

spell_check:
	poetry run codespell --toml pyproject.toml

spell_fix:
	poetry run codespell --toml pyproject.toml -w

check_imports: $(shell find langchain_upstage -name '*.py')
	poetry run python ./scripts/check_imports.py $^

######################
# HELP
######################

help:
	@echo '----'
	@echo 'check_imports - check imports'
	@echo 'format - run code formatters'
	@echo 'lint - run linters'
	@echo 'test - run unit tests'
	@echo 'tests - run unit tests'
	@echo 'test TEST_FILE=<test_file> - run all tests in file'
libs/partners/upstage/README.md (new file, 25 lines)
@@ -0,0 +1,25 @@
# langchain-upstage

This package contains the LangChain integrations for [Upstage](https://upstage.ai) through their [APIs](https://developers.upstage.ai/docs/getting-started/models).

## Installation and Setup

- Install the LangChain partner package
```bash
pip install -U langchain-upstage
```
- Get an Upstage API key from the [Upstage Console](https://console.upstage.ai/home) and set it as an environment variable (`UPSTAGE_API_KEY`)

## Chat Models

This package contains the `ChatUpstage` class, which is the recommended way to interface with Upstage models.

See a [usage example](https://python.langchain.com/docs/integrations/chat/upstage).

## Embeddings

See a [usage example](https://python.langchain.com/docs/integrations/text_embedding/upstage).

Use `solar-1-mini-embedding` as the default model for embeddings. Do not add suffixes such as `-query` or `-passage` to the model name.
`UpstageEmbeddings` will automatically add the suffixes based on the method called.
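A hedged quick-start sketch tying the README together (it assumes `UPSTAGE_API_KEY` is exported and makes real API calls; the printed output is illustrative):

```python
from langchain_upstage import ChatUpstage, UpstageEmbeddings

# Chat: defaults to model="solar-1-mini-chat".
chat = ChatUpstage()
print(chat.invoke("Hello, who are you?").content)

# Embeddings: pass the base model name only; the class appends the
# "-passage" / "-query" suffixes itself depending on the method called.
embeddings = UpstageEmbeddings()
doc_vectors = embeddings.embed_documents(["hello", "world"])
query_vector = embeddings.embed_query("hello")
print(len(doc_vectors), len(query_vector))
```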
libs/partners/upstage/langchain_upstage/__init__.py (new file, 4 lines)
@@ -0,0 +1,4 @@
from langchain_upstage.chat_models import ChatUpstage
from langchain_upstage.embeddings import UpstageEmbeddings

__all__ = ["ChatUpstage", "UpstageEmbeddings"]
libs/partners/upstage/langchain_upstage/chat_models.py (new file, 101 lines)
@@ -0,0 +1,101 @@
import os
from typing import (
    Any,
    Dict,
    List,
    Optional,
)

import openai
from langchain_core.pydantic_v1 import Field, SecretStr, root_validator
from langchain_core.utils import (
    convert_to_secret_str,
    get_from_dict_or_env,
)
from langchain_openai import ChatOpenAI


class ChatUpstage(ChatOpenAI):
    """ChatUpstage chat model.

    To use, you should have the environment variable `UPSTAGE_API_KEY`
    set with your API key or pass it as a named parameter to the constructor.

    Example:
        .. code-block:: python

            from langchain_upstage import ChatUpstage

            model = ChatUpstage()
    """

    @property
    def lc_secrets(self) -> Dict[str, str]:
        return {"upstage_api_key": "UPSTAGE_API_KEY"}

    @classmethod
    def get_lc_namespace(cls) -> List[str]:
        return ["langchain", "chat_models", "upstage"]

    @property
    def lc_attributes(self) -> Dict[str, Any]:
        attributes: Dict[str, Any] = {}

        if self.upstage_api_base:
            attributes["upstage_api_base"] = self.upstage_api_base

        return attributes

    @property
    def _llm_type(self) -> str:
        """Return type of chat model."""
        return "upstage-chat"

    model_name: str = Field(default="solar-1-mini-chat", alias="model")
    """Model name to use."""
    upstage_api_key: Optional[SecretStr] = Field(default=None, alias="api_key")
    """Automatically inferred from env var `UPSTAGE_API_KEY` if not provided."""
    upstage_api_base: Optional[str] = Field(
        default="https://api.upstage.ai/v1/solar", alias="base_url"
    )

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate that the API key and python package exist in the environment."""
        if values["n"] < 1:
            raise ValueError("n must be at least 1.")
        if values["n"] > 1 and values["streaming"]:
            raise ValueError("n must be 1 when streaming.")

        values["upstage_api_key"] = convert_to_secret_str(
            get_from_dict_or_env(values, "upstage_api_key", "UPSTAGE_API_KEY")
        )
        values["upstage_api_base"] = values["upstage_api_base"] or os.getenv(
            "UPSTAGE_API_BASE"
        )

        client_params = {
            "api_key": (
                values["upstage_api_key"].get_secret_value()
                if values["upstage_api_key"]
                else None
            ),
            "base_url": values["upstage_api_base"],
            "timeout": values["request_timeout"],
            "max_retries": values["max_retries"],
            "default_headers": values["default_headers"],
            "default_query": values["default_query"],
        }

        if not values.get("client"):
            sync_specific = {"http_client": values["http_client"]}
            values["client"] = openai.OpenAI(
                **client_params, **sync_specific
            ).chat.completions
        if not values.get("async_client"):
            async_specific = {"http_client": values["http_async_client"]}
            values["async_client"] = openai.AsyncOpenAI(
                **client_params, **async_specific
            ).chat.completions
        return values
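Because `ChatUpstage` subclasses `ChatOpenAI` and only swaps the credentials, base URL, and default model, it accepts the usual LangChain message types. A hedged usage sketch (assumes `UPSTAGE_API_KEY` is set; the reply text is illustrative):

```python
from langchain_core.messages import HumanMessage, SystemMessage

from langchain_upstage import ChatUpstage

chat = ChatUpstage(model="solar-1-mini-chat", max_tokens=64)
messages = [
    SystemMessage(content="You are a terse assistant."),
    HumanMessage(content="Name one use for text embeddings."),
]
# The request is routed to https://api.upstage.ai/v1/solar through the
# OpenAI-compatible client built in validate_environment above.
print(chat.invoke(messages).content)
```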
libs/partners/upstage/langchain_upstage/embeddings.py (new file, 263 lines)
@@ -0,0 +1,263 @@
import logging
import os
import warnings
from typing import (
    Any,
    Dict,
    List,
    Literal,
    Mapping,
    Optional,
    Sequence,
    Set,
    Tuple,
    Union,
)

import openai
from langchain_core.embeddings import Embeddings
from langchain_core.pydantic_v1 import (
    BaseModel,
    Extra,
    Field,
    SecretStr,
    root_validator,
)
from langchain_core.utils import (
    convert_to_secret_str,
    get_from_dict_or_env,
    get_pydantic_field_names,
)

logger = logging.getLogger(__name__)


class UpstageEmbeddings(BaseModel, Embeddings):
    """UpstageEmbeddings embedding model.

    To use, set the environment variable `UPSTAGE_API_KEY` with your API key or
    pass it as a named parameter to the constructor.

    Example:
        .. code-block:: python

            from langchain_upstage import UpstageEmbeddings

            model = UpstageEmbeddings()
    """

    client: Any = Field(default=None, exclude=True)  #: :meta private:
    async_client: Any = Field(default=None, exclude=True)  #: :meta private:
    model: str = "solar-1-mini-embedding"
    """Embeddings model name to use. Do not add suffixes like `-query` and `-passage`.
    Instead, use 'solar-1-mini-embedding' for example.
    """
    dimensions: Optional[int] = None
    """The number of dimensions the resulting output embeddings should have.

    Not yet supported.
    """
    upstage_api_key: Optional[SecretStr] = Field(default=None, alias="api_key")
    """API Key for Solar API."""
    upstage_api_base: str = Field(
        default="https://api.upstage.ai/v1/solar", alias="base_url"
    )
    """Endpoint URL to use."""
    embedding_ctx_length: int = 4096
    """The maximum number of tokens to embed at once.

    Not yet supported.
    """
    allowed_special: Union[Literal["all"], Set[str]] = set()
    """Not yet supported."""
    disallowed_special: Union[Literal["all"], Set[str], Sequence[str]] = "all"
    """Not yet supported."""
    chunk_size: int = 1000
    """Maximum number of texts to embed in each batch.

    Not yet supported.
    """
    max_retries: int = 2
    """Maximum number of retries to make when generating."""
    request_timeout: Optional[Union[float, Tuple[float, float], Any]] = Field(
        default=None, alias="timeout"
    )
    """Timeout for requests to Upstage embedding API. Can be float, httpx.Timeout or
    None."""
    show_progress_bar: bool = False
    """Whether to show a progress bar when embedding.

    Not yet supported.
    """
    model_kwargs: Dict[str, Any] = Field(default_factory=dict)
    """Holds any model parameters valid for `create` call not explicitly specified."""
    skip_empty: bool = False
    """Whether to skip empty strings when embedding or raise an error.
    Defaults to not skipping.

    Not yet supported."""
    default_headers: Union[Mapping[str, str], None] = None
    default_query: Union[Mapping[str, object], None] = None
    # Configure a custom httpx client. See the
    # [httpx documentation](https://www.python-httpx.org/api/#client) for more details.
    http_client: Union[Any, None] = None
    """Optional httpx.Client. Only used for sync invocations. Must specify
    http_async_client as well if you'd like a custom client for async invocations.
    """
    http_async_client: Union[Any, None] = None
    """Optional httpx.AsyncClient. Only used for async invocations. Must specify
    http_client as well if you'd like a custom client for sync invocations."""

    class Config:
        extra = Extra.forbid
        allow_population_by_field_name = True

    @root_validator(pre=True)
    def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        """Build extra kwargs from additional params that were passed in."""
        all_required_field_names = get_pydantic_field_names(cls)
        extra = values.get("model_kwargs", {})
        for field_name in list(values):
            if field_name in extra:
                raise ValueError(f"Found {field_name} supplied twice.")
            if field_name not in all_required_field_names:
                warnings.warn(
                    f"""WARNING! {field_name} is not default parameter.
                    {field_name} was transferred to model_kwargs.
                    Please confirm that {field_name} is what you intended."""
                )
                extra[field_name] = values.pop(field_name)

        invalid_model_kwargs = all_required_field_names.intersection(extra.keys())
        if invalid_model_kwargs:
            raise ValueError(
                f"Parameters {invalid_model_kwargs} should be specified explicitly. "
                f"Instead they were passed in as part of `model_kwargs` parameter."
            )

        values["model_kwargs"] = extra
        return values

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate that the API key and python package exist in the environment."""
        upstage_api_key = get_from_dict_or_env(
            values, "upstage_api_key", "UPSTAGE_API_KEY"
        )
        values["upstage_api_key"] = (
            convert_to_secret_str(upstage_api_key) if upstage_api_key else None
        )
        values["upstage_api_base"] = values["upstage_api_base"] or os.getenv(
            "UPSTAGE_API_BASE"
        )
        client_params = {
            "api_key": (
                values["upstage_api_key"].get_secret_value()
                if values["upstage_api_key"]
                else None
            ),
            "base_url": values["upstage_api_base"],
            "timeout": values["request_timeout"],
            "max_retries": values["max_retries"],
            "default_headers": values["default_headers"],
            "default_query": values["default_query"],
        }
        if not values.get("client"):
            sync_specific = {"http_client": values["http_client"]}
            values["client"] = openai.OpenAI(
                **client_params, **sync_specific
            ).embeddings
        if not values.get("async_client"):
            async_specific = {"http_client": values["http_async_client"]}
            values["async_client"] = openai.AsyncOpenAI(
                **client_params, **async_specific
            ).embeddings
        return values

    @property
    def _invocation_params(self) -> Dict[str, Any]:
        self.model = self.model.replace("-query", "").replace("-passage", "")

        params: Dict = {"model": self.model, **self.model_kwargs}
        if self.dimensions is not None:
            params["dimensions"] = self.dimensions
        return params

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Embed a list of document texts using the passage model.

        Args:
            texts: The list of texts to embed.

        Returns:
            List of embeddings, one for each text.
        """
        embeddings = []
        params = self._invocation_params
        params["model"] = params["model"] + "-passage"

        for text in texts:
            response = self.client.create(input=text, **params)

            if not isinstance(response, dict):
                response = response.model_dump()
            embeddings.extend([i["embedding"] for i in response["data"]])
        return embeddings

    def embed_query(self, text: str) -> List[float]:
        """Embed query text using the query model.

        Args:
            text: The text to embed.

        Returns:
            Embedding for the text.
        """
        params = self._invocation_params
        params["model"] = params["model"] + "-query"

        response = self.client.create(input=text, **params)

        if not isinstance(response, dict):
            response = response.model_dump()
        return response["data"][0]["embedding"]

    async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
        """Embed a list of document texts using the passage model asynchronously.

        Args:
            texts: The list of texts to embed.

        Returns:
            List of embeddings, one for each text.
        """
        embeddings = []
        params = self._invocation_params
        params["model"] = params["model"] + "-passage"

        for text in texts:
            response = await self.async_client.create(input=text, **params)

            if not isinstance(response, dict):
                response = response.model_dump()
            embeddings.extend([i["embedding"] for i in response["data"]])
        return embeddings

    async def aembed_query(self, text: str) -> List[float]:
        """Asynchronously embed query text using the query model.

        Args:
            text: The text to embed.

        Returns:
            Embedding for the text.
        """
        params = self._invocation_params
        params["model"] = params["model"] + "-query"

        response = await self.async_client.create(input=text, **params)

        if not isinstance(response, dict):
            response = response.model_dump()
        return response["data"][0]["embedding"]
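A hedged sketch of the suffix handling implemented in `_invocation_params`, `embed_documents`, and `embed_query` above (assumes `UPSTAGE_API_KEY` is set; network calls go to the Solar API):

```python
from langchain_upstage import UpstageEmbeddings

emb = UpstageEmbeddings(model="solar-1-mini-embedding")

# embed_documents appends "-passage": the API sees "solar-1-mini-embedding-passage".
doc_vectors = emb.embed_documents(["foo bar", "bar foo"])

# embed_query appends "-query": the API sees "solar-1-mini-embedding-query".
query_vector = emb.embed_query("foo bar")

# A suffixed name is normalized first in _invocation_params, so passing
# "solar-1-mini-embedding-query" behaves the same as passing the base name.
```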
libs/partners/upstage/langchain_upstage/py.typed (new file, 0 lines)
libs/partners/upstage/poetry.lock (new file, 1273 lines, generated)
File diff suppressed because it is too large.
libs/partners/upstage/pyproject.toml (new file, 96 lines)
@@ -0,0 +1,96 @@
[tool.poetry]
name = "langchain-upstage"
version = "0.1.0rc0"
description = "An integration package connecting Upstage and LangChain"
authors = []
readme = "README.md"
repository = "https://github.com/langchain-ai/langchain"
license = "MIT"

[tool.poetry.urls]
"Source Code" = "https://github.com/langchain-ai/langchain/tree/master/libs/partners/upstage"

[tool.poetry.dependencies]
python = ">=3.8.1,<4.0"
langchain-core = "^0.1.42"
langchain-openai = "^0.1.3"

[tool.poetry.group.test]
optional = true

[tool.poetry.group.test.dependencies]
pytest = "^7.3.0"
freezegun = "^1.2.2"
pytest-mock = "^3.10.0"
syrupy = "^4.0.2"
pytest-watcher = "^0.3.4"
pytest-asyncio = "^0.21.1"
langchain-openai = { path = "../openai", develop = true }
langchain-core = { path = "../../core", develop = true }
docarray = "^0.32.1"
pydantic = "^1.10.9"
langchain-standard-tests = { path = "../../standard-tests", develop = true }

[tool.poetry.group.codespell]
optional = true

[tool.poetry.group.codespell.dependencies]
codespell = "^2.2.0"

[tool.poetry.group.test_integration]
optional = true

[tool.poetry.group.test_integration.dependencies]

[tool.poetry.group.lint]
optional = true

[tool.poetry.group.lint.dependencies]
ruff = "^0.1.5"

[tool.poetry.group.typing.dependencies]
mypy = "^0.991"
langchain-core = { path = "../../core", develop = true }

[tool.poetry.group.dev]
optional = true

[tool.poetry.group.dev.dependencies]
langchain-core = { path = "../../core", develop = true }

[tool.ruff]
select = [
  "E",  # pycodestyle
  "F",  # pyflakes
  "I",  # isort
]

[tool.mypy]
disallow_untyped_defs = "True"

[tool.coverage.run]
omit = ["tests/*"]

[build-system]
requires = ["poetry-core>=1.0.0"]
build-backend = "poetry.core.masonry.api"

[tool.pytest.ini_options]
# --strict-markers will raise errors on unknown marks.
# https://docs.pytest.org/en/7.1.x/how-to/mark.html#raising-errors-on-unknown-marks
#
# https://docs.pytest.org/en/7.1.x/reference/reference.html
# --strict-config: any warnings encountered while parsing the `pytest`
# section of the configuration file raise errors.
#
# https://github.com/tophat/syrupy
# --snapshot-warn-unused: prints a warning on unused snapshots rather than fail the test suite.
addopts = "--snapshot-warn-unused --strict-markers --strict-config --durations=5"
# Registering custom markers.
# https://docs.pytest.org/en/7.1.x/example/markers.html#registering-markers
markers = [
    "requires: mark tests as requiring a specific library",
    "asyncio: mark tests as requiring asyncio",
    "compile: mark placeholder test used to compile integration tests without running them",
]
asyncio_mode = "auto"
libs/partners/upstage/scripts/check_imports.py (new file, 17 lines)
@@ -0,0 +1,17 @@
import sys
import traceback
from importlib.machinery import SourceFileLoader

if __name__ == "__main__":
    files = sys.argv[1:]
    has_failure = False
    for file in files:
        try:
            SourceFileLoader("x", file).load_module()
        except Exception:
            has_failure = True
            print(file)
            traceback.print_exc()
            print()

    sys.exit(1 if has_failure else 0)
libs/partners/upstage/scripts/check_pydantic.sh (new file, 27 lines, executable)
@@ -0,0 +1,27 @@
#!/bin/bash
#
# This script searches for lines starting with "import pydantic" or "from pydantic"
# in tracked files within a Git repository.
#
# Usage: ./scripts/check_pydantic.sh /path/to/repository

# Check if a path argument is provided
if [ $# -ne 1 ]; then
  echo "Usage: $0 /path/to/repository"
  exit 1
fi

repository_path="$1"

# Search for lines matching the pattern within the specified repository
result=$(git -C "$repository_path" grep -E '^import pydantic|^from pydantic')

# Check if any matching lines were found
if [ -n "$result" ]; then
  echo "ERROR: The following lines need to be updated:"
  echo "$result"
  echo "Please replace the code with an import from langchain_core.pydantic_v1."
  echo "For example, replace 'from pydantic import BaseModel'"
  echo "with 'from langchain_core.pydantic_v1 import BaseModel'"
  exit 1
fi
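The rewrite the script asks for, as a short Python sketch:

```python
# Flagged by check_pydantic.sh:
# from pydantic import BaseModel
# Accepted replacement:
from langchain_core.pydantic_v1 import BaseModel
```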
libs/partners/upstage/scripts/lint_imports.sh (new file, 17 lines, executable)
@@ -0,0 +1,17 @@
#!/bin/bash

set -eu

# Initialize a variable to keep track of errors
errors=0

# make sure not importing from langchain or langchain_experimental
git --no-pager grep '^from langchain\.' . && errors=$((errors+1))
git --no-pager grep '^from langchain_experimental\.' . && errors=$((errors+1))

# Decide on an exit status based on the errors
if [ "$errors" -gt 0 ]; then
  exit 1
else
  exit 0
fi
libs/partners/upstage/tests/__init__.py (new file, 0 lines)
libs/partners/upstage/tests/integration_tests/test_chat_models.py (new file, 136 lines)
@@ -0,0 +1,136 @@
import pytest
from langchain_core.messages import BaseMessage, HumanMessage, SystemMessage

from langchain_upstage import ChatUpstage


def test_chat_upstage_model() -> None:
    """Test ChatUpstage wrapper handles model_name."""
    chat = ChatUpstage(model="foo")
    assert chat.model_name == "foo"
    chat = ChatUpstage(model_name="bar")
    assert chat.model_name == "bar"


def test_chat_upstage_system_message() -> None:
    """Test ChatUpstage wrapper with system message."""
    chat = ChatUpstage(max_tokens=10)
    system_message = SystemMessage(content="You are to chat with the user.")
    human_message = HumanMessage(content="Hello")
    response = chat([system_message, human_message])
    assert isinstance(response, BaseMessage)
    assert isinstance(response.content, str)


def test_chat_upstage_llm_output_contains_model_name() -> None:
    """Test llm_output contains model_name."""
    chat = ChatUpstage(max_tokens=10)
    message = HumanMessage(content="Hello")
    llm_result = chat.generate([[message]])
    assert llm_result.llm_output is not None
    assert llm_result.llm_output["model_name"] == chat.model_name


def test_chat_upstage_streaming_llm_output_contains_model_name() -> None:
    """Test llm_output contains model_name."""
    chat = ChatUpstage(max_tokens=10, streaming=True)
    message = HumanMessage(content="Hello")
    llm_result = chat.generate([[message]])
    assert llm_result.llm_output is not None
    assert llm_result.llm_output["model_name"] == chat.model_name


def test_chat_upstage_invalid_streaming_params() -> None:
    """Test that streaming correctly invokes on_llm_new_token callback."""
    with pytest.raises(ValueError):
        ChatUpstage(
            max_tokens=10,
            streaming=True,
            temperature=0,
            n=5,
        )


def test_chat_upstage_extra_kwargs() -> None:
    """Test extra kwargs to chat upstage."""
    # Check that foo is saved in extra_kwargs.
    llm = ChatUpstage(foo=3, max_tokens=10)
    assert llm.max_tokens == 10
    assert llm.model_kwargs == {"foo": 3}

    # Test that if extra_kwargs are provided, they are added to it.
    llm = ChatUpstage(foo=3, model_kwargs={"bar": 2})
    assert llm.model_kwargs == {"foo": 3, "bar": 2}

    # Test that if provided twice it errors
    with pytest.raises(ValueError):
        ChatUpstage(foo=3, model_kwargs={"foo": 2})

    # Test that if explicit param is specified in kwargs it errors
    with pytest.raises(ValueError):
        ChatUpstage(model_kwargs={"temperature": 0.2})

    # Test that "model" cannot be specified in kwargs
    with pytest.raises(ValueError):
        ChatUpstage(model_kwargs={"model": "solar-1-mini-chat"})


def test_stream() -> None:
    """Test streaming tokens from ChatUpstage."""
    llm = ChatUpstage()

    for token in llm.stream("I'm Pickle Rick"):
        assert isinstance(token.content, str)


async def test_astream() -> None:
    """Test streaming tokens from ChatUpstage."""
    llm = ChatUpstage()

    async for token in llm.astream("I'm Pickle Rick"):
        assert isinstance(token.content, str)


async def test_abatch() -> None:
    """Test batch tokens from ChatUpstage."""
    llm = ChatUpstage()

    result = await llm.abatch(["I'm Pickle Rick", "I'm not Pickle Rick"])
    for token in result:
        assert isinstance(token.content, str)


async def test_abatch_tags() -> None:
    """Test batch tokens from ChatUpstage."""
    llm = ChatUpstage()

    result = await llm.abatch(
        ["I'm Pickle Rick", "I'm not Pickle Rick"], config={"tags": ["foo"]}
    )
    for token in result:
        assert isinstance(token.content, str)


def test_batch() -> None:
    """Test batch tokens from ChatUpstage."""
    llm = ChatUpstage()

    result = llm.batch(["I'm Pickle Rick", "I'm not Pickle Rick"])
    for token in result:
        assert isinstance(token.content, str)


async def test_ainvoke() -> None:
    """Test invoke tokens from ChatUpstage."""
    llm = ChatUpstage()

    result = await llm.ainvoke("I'm Pickle Rick", config={"tags": ["foo"]})
    assert isinstance(result.content, str)


def test_invoke() -> None:
    """Test invoke tokens from ChatUpstage."""
    llm = ChatUpstage()

    result = llm.invoke("I'm Pickle Rick", config=dict(tags=["foo"]))
    assert isinstance(result.content, str)
libs/partners/upstage/tests/integration_tests/test_chat_models_standard.py (new file, 32 lines)
@@ -0,0 +1,32 @@
"""Standard LangChain interface tests"""

from typing import Type

import pytest
from langchain_core.language_models import BaseChatModel
from langchain_standard_tests.integration_tests import ChatModelIntegrationTests

from langchain_upstage import ChatUpstage


class TestUpstageStandard(ChatModelIntegrationTests):
    @pytest.fixture
    def chat_model_class(self) -> Type[BaseChatModel]:
        return ChatUpstage

    @pytest.fixture
    def chat_model_params(self) -> dict:
        return {
            "model": "solar-1-mini-chat",
        }

    @pytest.mark.xfail(reason="400s with tool calling currently")
    def test_tool_message_histories(
        self,
        chat_model_class: Type[BaseChatModel],
        chat_model_params: dict,
        chat_model_has_tool_calling: bool,
    ) -> None:
        super().test_tool_message_histories(
            chat_model_class, chat_model_params, chat_model_has_tool_calling
        )
libs/partners/upstage/tests/integration_tests/test_compile.py (new file, 7 lines)
@@ -0,0 +1,7 @@
import pytest


@pytest.mark.compile
def test_placeholder() -> None:
    """Used for compiling integration tests without running any real tests."""
    pass
libs/partners/upstage/tests/integration_tests/test_embeddings.py (new file, 36 lines)
@@ -0,0 +1,36 @@
"""Test Upstage embeddings."""
from langchain_upstage import UpstageEmbeddings


def test_langchain_upstage_embed_documents() -> None:
    """Test Upstage embeddings."""
    documents = ["foo bar", "bar foo"]
    embedding = UpstageEmbeddings()
    output = embedding.embed_documents(documents)
    assert len(output) == 2
    assert len(output[0]) > 0


def test_langchain_upstage_embed_query() -> None:
    """Test Upstage embeddings."""
    query = "foo bar"
    embedding = UpstageEmbeddings()
    output = embedding.embed_query(query)
    assert len(output) > 0


async def test_langchain_upstage_aembed_documents() -> None:
    """Test Upstage embeddings asynchronously."""
    documents = ["foo bar", "bar foo"]
    embedding = UpstageEmbeddings()
    output = await embedding.aembed_documents(documents)
    assert len(output) == 2
    assert len(output[0]) > 0


async def test_langchain_upstage_aembed_query() -> None:
    """Test Upstage embeddings asynchronously."""
    query = "foo bar"
    embedding = UpstageEmbeddings()
    output = await embedding.aembed_query(query)
    assert len(output) > 0
libs/partners/upstage/tests/unit_tests/__init__.py (new file, 0 lines)
libs/partners/upstage/tests/unit_tests/test_chat_models.py (new file, 192 lines)
@@ -0,0 +1,192 @@
import json
from typing import Any
from unittest.mock import AsyncMock, MagicMock, patch

import pytest
from langchain_core.messages import (
    AIMessage,
    FunctionMessage,
    HumanMessage,
    SystemMessage,
    ToolMessage,
)
from langchain_openai.chat_models.base import (
    _convert_dict_to_message,
    _convert_message_to_dict,
)

from langchain_upstage import ChatUpstage


def test_initialization() -> None:
    """Test chat model initialization."""
    ChatUpstage()


def test_upstage_model_param() -> None:
    llm = ChatUpstage(model="foo")
    assert llm.model_name == "foo"
    llm = ChatUpstage(model_name="foo")
    assert llm.model_name == "foo"


def test_function_dict_to_message_function_message() -> None:
    content = json.dumps({"result": "Example #1"})
    name = "test_function"
    result = _convert_dict_to_message(
        {
            "role": "function",
            "name": name,
            "content": content,
        }
    )
    assert isinstance(result, FunctionMessage)
    assert result.name == name
    assert result.content == content


def test_convert_dict_to_message_human() -> None:
    message = {"role": "user", "content": "foo"}
    result = _convert_dict_to_message(message)
    expected_output = HumanMessage(content="foo")
    assert result == expected_output
    assert _convert_message_to_dict(expected_output) == message


def test__convert_dict_to_message_human_with_name() -> None:
    message = {"role": "user", "content": "foo", "name": "test"}
    result = _convert_dict_to_message(message)
    expected_output = HumanMessage(content="foo", name="test")
    assert result == expected_output
    assert _convert_message_to_dict(expected_output) == message


def test_convert_dict_to_message_ai() -> None:
    message = {"role": "assistant", "content": "foo"}
    result = _convert_dict_to_message(message)
    expected_output = AIMessage(content="foo")
    assert result == expected_output
    assert _convert_message_to_dict(expected_output) == message


def test_convert_dict_to_message_ai_with_name() -> None:
    message = {"role": "assistant", "content": "foo", "name": "test"}
    result = _convert_dict_to_message(message)
    expected_output = AIMessage(content="foo", name="test")
    assert result == expected_output
    assert _convert_message_to_dict(expected_output) == message


def test_convert_dict_to_message_system() -> None:
    message = {"role": "system", "content": "foo"}
    result = _convert_dict_to_message(message)
    expected_output = SystemMessage(content="foo")
    assert result == expected_output
    assert _convert_message_to_dict(expected_output) == message


def test_convert_dict_to_message_system_with_name() -> None:
    message = {"role": "system", "content": "foo", "name": "test"}
    result = _convert_dict_to_message(message)
    expected_output = SystemMessage(content="foo", name="test")
    assert result == expected_output
    assert _convert_message_to_dict(expected_output) == message


def test_convert_dict_to_message_tool() -> None:
    message = {"role": "tool", "content": "foo", "tool_call_id": "bar"}
    result = _convert_dict_to_message(message)
    expected_output = ToolMessage(content="foo", tool_call_id="bar")
    assert result == expected_output
    assert _convert_message_to_dict(expected_output) == message


@pytest.fixture
def mock_completion() -> dict:
    return {
        "id": "chatcmpl-7fcZavknQda3SQ",
        "object": "chat.completion",
        "created": 1689989000,
        "model": "solar-1-mini-chat",
        "choices": [
            {
                "index": 0,
                "message": {
                    "role": "assistant",
                    "content": "Bab",
                    "name": "KimSolar",
                },
                "finish_reason": "stop",
            }
        ],
    }


def test_upstage_invoke(mock_completion: dict) -> None:
    llm = ChatUpstage()
    mock_client = MagicMock()
    completed = False

    def mock_create(*args: Any, **kwargs: Any) -> Any:
        nonlocal completed
        completed = True
        return mock_completion

    mock_client.create = mock_create
    with patch.object(
        llm,
        "client",
        mock_client,
    ):
        res = llm.invoke("bab")
        assert res.content == "Bab"
    assert completed


async def test_upstage_ainvoke(mock_completion: dict) -> None:
    llm = ChatUpstage()
    mock_client = AsyncMock()
    completed = False

    async def mock_create(*args: Any, **kwargs: Any) -> Any:
        nonlocal completed
        completed = True
        return mock_completion

    mock_client.create = mock_create
    with patch.object(
        llm,
        "async_client",
        mock_client,
    ):
        res = await llm.ainvoke("bab")
        assert res.content == "Bab"
    assert completed


def test_upstage_invoke_name(mock_completion: dict) -> None:
    llm = ChatUpstage()

    mock_client = MagicMock()
    mock_client.create.return_value = mock_completion

    with patch.object(
        llm,
        "client",
        mock_client,
    ):
        messages = [
            HumanMessage(content="Foo", name="Zorba"),
        ]
        res = llm.invoke(messages)
        call_args, call_kwargs = mock_client.create.call_args
        assert len(call_args) == 0  # no positional args
        call_messages = call_kwargs["messages"]
        assert len(call_messages) == 1
        assert call_messages[0]["role"] == "user"
        assert call_messages[0]["content"] == "Foo"
        assert call_messages[0]["name"] == "Zorba"

        # check return type has name
        assert res.content == "Bab"
        assert res.name == "KimSolar"
libs/partners/upstage/tests/unit_tests/test_chat_models_standard.py (new file, 20 lines)
@@ -0,0 +1,20 @@
"""Standard LangChain interface tests"""
from typing import Type

import pytest
from langchain_core.language_models import BaseChatModel
from langchain_standard_tests.unit_tests import ChatModelUnitTests

from langchain_upstage import ChatUpstage


class TestUpstageStandard(ChatModelUnitTests):
    @pytest.fixture
    def chat_model_class(self) -> Type[BaseChatModel]:
        return ChatUpstage

    @pytest.fixture
    def chat_model_params(self) -> dict:
        return {
            "model": "solar-1-mini-chat",
        }
libs/partners/upstage/tests/unit_tests/test_embeddings.py (new file, 24 lines)
@@ -0,0 +1,24 @@
"""Test embedding model integration."""
import os

import pytest

from langchain_upstage import UpstageEmbeddings

os.environ["UPSTAGE_API_KEY"] = "foo"


def test_initialization() -> None:
    """Test embedding model initialization."""
    UpstageEmbeddings()


def test_upstage_invalid_model_kwargs() -> None:
    with pytest.raises(ValueError):
        UpstageEmbeddings(model_kwargs={"model": "foo"})


def test_upstage_incorrect_field() -> None:
    with pytest.warns(match="not default parameter"):
        llm = UpstageEmbeddings(foo="bar")
    assert llm.model_kwargs == {"foo": "bar"}
libs/partners/upstage/tests/unit_tests/test_imports.py (new file, 10 lines)
@@ -0,0 +1,10 @@
from langchain_upstage import __all__

EXPECTED_ALL = [
    "ChatUpstage",
    "UpstageEmbeddings",
]


def test_all_imports() -> None:
    assert sorted(EXPECTED_ALL) == sorted(__all__)
libs/partners/upstage/tests/unit_tests/test_secrets.py (new file, 13 lines)
@@ -0,0 +1,13 @@
from langchain_upstage import ChatUpstage, UpstageEmbeddings


def test_chat_upstage_secrets() -> None:
    o = ChatUpstage(upstage_api_key="foo")
    s = str(o)
    assert "foo" not in s


def test_upstage_embeddings_secrets() -> None:
    o = UpstageEmbeddings(upstage_api_key="foo")
    s = str(o)
    assert "foo" not in s