mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-21 14:18:52 +00:00
together: add chat models, use openai base (#21337)
**Description:** Adding chat completions to the Together AI package, which is our most popular API. Also staying backwards compatible with the old API so folks can continue to use the completions API as well. Also moved the embedding API to use the OpenAI library to standardize it further. **Twitter handle:** @nutlope - [x] **Add tests and docs**: If you're adding a new integration, please include - [x] **Lint and test**: Run `make format`, `make lint` and `make test` from the root of the package(s) you've modified. See contribution guidelines for more: https://python.langchain.com/docs/contributing/ If no one reviews your PR within a few days, please @-mention one of baskaryan, efriis, eyurtsev, hwchase17. --------- Co-authored-by: Erick Friis <erick@langchain.dev>
This commit is contained in:
parent
a2d31307bb
commit
d6ef5fe86a
@ -7,15 +7,21 @@
|
||||
"source": [
|
||||
"# Together AI\n",
|
||||
"\n",
|
||||
"> The Together API makes it easy to fine-tune or run leading open-source models with a couple lines of code. We have integrated the world’s leading open-source models, including Llama-2, RedPajama, Falcon, Alpaca, Stable Diffusion XL, and more. Read more: https://together.ai\n",
|
||||
"> The Together API makes it easy to query and fine-tune leading open-source models with a couple lines of code. We have integrated the world’s leading open-source models, including Llama-3, Mixtral, DBRX, Stable Diffusion XL, and more. Read more: https://together.ai\n",
|
||||
"\n",
|
||||
"To use, you'll need an API key which you can find here:\n",
|
||||
"https://api.together.xyz/settings/api-keys. This can be passed in as init param\n",
|
||||
"https://api.together.ai/settings/api-keys. This can be passed in as init param\n",
|
||||
"``together_api_key`` or set as environment variable ``TOGETHER_API_KEY``.\n",
|
||||
"\n",
|
||||
"Together API reference: https://docs.together.ai/reference"
|
||||
"Together API reference: https://docs.together.ai"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "1c47fc36",
|
||||
"metadata": {},
|
||||
"source": []
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
@ -28,40 +34,43 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"execution_count": null,
|
||||
"id": "637bb53f",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Running chat completions with Together AI\n",
|
||||
"\n",
|
||||
"from langchain_core.prompts import ChatPromptTemplate\n",
|
||||
"from langchain_together import ChatTogether\n",
|
||||
"\n",
|
||||
"chat = ChatTogether()\n",
|
||||
"\n",
|
||||
"# using chat invoke\n",
|
||||
"chat.invoke(\"Tell me fun things to do in NYC\")\n",
|
||||
"\n",
|
||||
"# using chat stream\n",
|
||||
"for m in chat.stream(\"Tell me fun things to do in NYC\"):\n",
|
||||
" print(m)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "e7b7170d-d7c5-4890-9714-a37238343805",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"\n",
|
||||
"\n",
|
||||
"A: A large language model is a neural network that is trained on a large amount of text data. It is able to generate text that is similar to the training data, and can be used for tasks such as language translation, question answering, and text summarization.\n",
|
||||
"\n",
|
||||
"A: A large language model is a neural network that is trained on a large amount of text data. It is able to generate text that is similar to the training data, and can be used for tasks such as language translation, question answering, and text summarization.\n",
|
||||
"\n",
|
||||
"A: A large language model is a neural network that is trained on\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Running completions with Together AI\n",
|
||||
"\n",
|
||||
"from langchain_together import Together\n",
|
||||
"\n",
|
||||
"llm = Together(\n",
|
||||
" model=\"togethercomputer/RedPajama-INCITE-7B-Base\",\n",
|
||||
" temperature=0.7,\n",
|
||||
" max_tokens=128,\n",
|
||||
" top_k=1,\n",
|
||||
" model=\"codellama/CodeLlama-70b-Python-hf\",\n",
|
||||
" # together_api_key=\"...\"\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"input_ = \"\"\"You are a teacher with a deep knowledge of machine learning and AI. \\\n",
|
||||
"You provide succinct and accurate answers. Answer the following question: \n",
|
||||
"\n",
|
||||
"What is a large language model?\"\"\"\n",
|
||||
"print(llm.invoke(input_))"
|
||||
"print(llm.invoke(\"def bubble_sort(): \"))"
|
||||
]
|
||||
}
|
||||
],
|
||||
|
@ -1,6 +1,6 @@
|
||||
MIT License
|
||||
|
||||
Copyright (c) 2023 LangChain, Inc.
|
||||
Copyright (c) 2024 LangChain, Inc.
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
|
@ -11,7 +11,6 @@ integration_test integration_tests: TEST_FILE=tests/integration_tests/
|
||||
test tests integration_test integration_tests:
|
||||
poetry run pytest $(TEST_FILE)
|
||||
|
||||
|
||||
######################
|
||||
# LINTING AND FORMATTING
|
||||
######################
|
||||
|
@ -1,48 +1,28 @@
|
||||
# langchain-together
|
||||
|
||||
This package contains the LangChain integration for Together's generative models.
|
||||
This package contains the LangChain integrations for [Together AI](https://www.together.ai/) through their [APIs](https://docs.together.ai/).
|
||||
|
||||
## Installation
|
||||
## Installation and Setup
|
||||
|
||||
```sh
|
||||
- Install the LangChain partner package
|
||||
|
||||
```bash
|
||||
pip install -U langchain-together
|
||||
```
|
||||
|
||||
- Get your Together AI api key from the [Together Dashboard](https://api.together.ai/settings/api-keys) and set it as an environment variable (`TOGETHER_API_KEY`)
|
||||
|
||||
## Chat Completions
|
||||
|
||||
This package contains the `ChatTogether` class, which is the recommended way to interface with Together AI chat models.
|
||||
|
||||
ADD USAGE EXAMPLE HERE.
|
||||
Can we add this in the langchain docs?
|
||||
|
||||
NEED to add image endpoint + completions endpoint as well
|
||||
|
||||
## Embeddings
|
||||
|
||||
You can use Together's embedding models through `TogetherEmbeddings` class.
|
||||
See a [usage example](https://python.langchain.com/docs/integrations/text_embedding/together/)
|
||||
|
||||
```py
|
||||
from langchain_together import TogetherEmbeddings
|
||||
|
||||
embeddings = TogetherEmbeddings(
|
||||
model='togethercomputer/m2-bert-80M-8k-retrieval'
|
||||
)
|
||||
embeddings.embed_query("What is a large language model?")
|
||||
```
|
||||
|
||||
## LLMs
|
||||
|
||||
You can use Together's generative AI models as Langchain LLMs:
|
||||
|
||||
```py
|
||||
from langchain_together import Together
|
||||
from langchain_core.prompts import PromptTemplate
|
||||
|
||||
llm = Together(
|
||||
model="togethercomputer/RedPajama-INCITE-7B-Base",
|
||||
temperature=0.7,
|
||||
max_tokens=64,
|
||||
top_k=1,
|
||||
# together_api_key="..."
|
||||
)
|
||||
|
||||
template = """Question: {question}
|
||||
Answer: """
|
||||
prompt = PromptTemplate.from_template(template)
|
||||
|
||||
chain = prompt | llm
|
||||
|
||||
question = "Who was the president in the year Justin Beiber was born?"
|
||||
print(chain.invoke({"question": question}))
|
||||
```
|
||||
Use `togethercomputer/m2-bert-80M-8k-retrieval` as the default model for embeddings.
|
||||
|
@ -1,9 +1,5 @@
|
||||
from langchain_together.chat_models import ChatTogether
|
||||
from langchain_together.embeddings import TogetherEmbeddings
|
||||
from langchain_together.llms import Together
|
||||
from langchain_together.version import __version__
|
||||
|
||||
__all__ = [
|
||||
"__version__",
|
||||
"Together",
|
||||
"TogetherEmbeddings",
|
||||
]
|
||||
__all__ = ["ChatTogether", "Together", "TogetherEmbeddings"]
|
||||
|
103
libs/partners/together/langchain_together/chat_models.py
Normal file
103
libs/partners/together/langchain_together/chat_models.py
Normal file
@ -0,0 +1,103 @@
|
||||
"""Wrapper around Together AI's Chat Completions API."""
|
||||
|
||||
import os
|
||||
from typing import (
|
||||
Any,
|
||||
Dict,
|
||||
List,
|
||||
Optional,
|
||||
)
|
||||
|
||||
import openai
|
||||
from langchain_core.pydantic_v1 import Field, SecretStr, root_validator
|
||||
from langchain_core.utils import (
|
||||
convert_to_secret_str,
|
||||
get_from_dict_or_env,
|
||||
)
|
||||
from langchain_openai.chat_models.base import BaseChatOpenAI
|
||||
|
||||
|
||||
class ChatTogether(BaseChatOpenAI):
|
||||
"""ChatTogether chat model.
|
||||
|
||||
To use, you should have the environment variable `TOGETHER_API_KEY`
|
||||
set with your API key or pass it as a named parameter to the constructor.
|
||||
|
||||
Example:
|
||||
.. code-block:: python
|
||||
|
||||
from langchain_together import ChatTogether
|
||||
|
||||
|
||||
model = ChatTogether()
|
||||
"""
|
||||
|
||||
@property
|
||||
def lc_secrets(self) -> Dict[str, str]:
|
||||
return {"together_api_key": "TOGETHER_API_KEY"}
|
||||
|
||||
@classmethod
|
||||
def get_lc_namespace(cls) -> List[str]:
|
||||
return ["langchain", "chat_models", "together"]
|
||||
|
||||
@property
|
||||
def lc_attributes(self) -> Dict[str, Any]:
|
||||
attributes: Dict[str, Any] = {}
|
||||
|
||||
if self.together_api_base:
|
||||
attributes["together_api_base"] = self.together_api_base
|
||||
|
||||
return attributes
|
||||
|
||||
@property
|
||||
def _llm_type(self) -> str:
|
||||
"""Return type of chat model."""
|
||||
return "together-chat"
|
||||
|
||||
model_name: str = Field(default="meta-llama/Llama-3-8b-chat-hf", alias="model")
|
||||
"""Model name to use."""
|
||||
together_api_key: Optional[SecretStr] = Field(default=None, alias="api_key")
|
||||
"""Automatically inferred from env are `TOGETHER_API_KEY` if not provided."""
|
||||
together_api_base: Optional[str] = Field(
|
||||
default="https://api.together.ai/v1/chat/completions", alias="base_url"
|
||||
)
|
||||
|
||||
@root_validator()
|
||||
def validate_environment(cls, values: Dict) -> Dict:
|
||||
"""Validate that api key and python package exists in environment."""
|
||||
if values["n"] < 1:
|
||||
raise ValueError("n must be at least 1.")
|
||||
if values["n"] > 1 and values["streaming"]:
|
||||
raise ValueError("n must be 1 when streaming.")
|
||||
|
||||
values["together_api_key"] = convert_to_secret_str(
|
||||
get_from_dict_or_env(values, "together_api_key", "TOGETHER_API_KEY")
|
||||
)
|
||||
values["together_api_base"] = values["together_api_base"] or os.getenv(
|
||||
"TOGETHER_API_BASE"
|
||||
)
|
||||
|
||||
client_params = {
|
||||
"api_key": (
|
||||
values["together_api_key"].get_secret_value()
|
||||
if values["together_api_key"]
|
||||
else None
|
||||
),
|
||||
"base_url": values["together_api_base"],
|
||||
"timeout": values["request_timeout"],
|
||||
"max_retries": values["max_retries"],
|
||||
"default_headers": values["default_headers"],
|
||||
"default_query": values["default_query"],
|
||||
}
|
||||
|
||||
if not values.get("client"):
|
||||
sync_specific = {"http_client": values["http_client"]}
|
||||
values["client"] = openai.OpenAI(
|
||||
**client_params, **sync_specific
|
||||
).chat.completions
|
||||
if not values.get("async_client"):
|
||||
async_specific = {"http_client": values["http_async_client"]}
|
||||
values["async_client"] = openai.AsyncOpenAI(
|
||||
**client_params, **async_specific
|
||||
).chat.completions
|
||||
return values
|
@ -1,50 +1,265 @@
|
||||
import os
|
||||
from typing import Any, Dict, List
|
||||
"""Wrapper around Together AI's Embeddings API."""
|
||||
|
||||
import together # type: ignore
|
||||
import logging
|
||||
import os
|
||||
import warnings
|
||||
from typing import (
|
||||
Any,
|
||||
Dict,
|
||||
List,
|
||||
Literal,
|
||||
Mapping,
|
||||
Optional,
|
||||
Sequence,
|
||||
Set,
|
||||
Tuple,
|
||||
Union,
|
||||
)
|
||||
|
||||
import openai
|
||||
from langchain_core.embeddings import Embeddings
|
||||
from langchain_core.pydantic_v1 import BaseModel, SecretStr, root_validator
|
||||
from langchain_core.utils import convert_to_secret_str
|
||||
from langchain_core.pydantic_v1 import (
|
||||
BaseModel,
|
||||
Extra,
|
||||
Field,
|
||||
SecretStr,
|
||||
root_validator,
|
||||
)
|
||||
from langchain_core.utils import (
|
||||
convert_to_secret_str,
|
||||
get_from_dict_or_env,
|
||||
get_pydantic_field_names,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TogetherEmbeddings(BaseModel, Embeddings):
|
||||
"""TogetherEmbeddings embedding model.
|
||||
|
||||
To use, set the environment variable `TOGETHER_API_KEY` with your API key or
|
||||
pass it as a named parameter to the constructor.
|
||||
|
||||
Example:
|
||||
.. code-block:: python
|
||||
|
||||
from langchain_together import TogetherEmbeddings
|
||||
|
||||
model = TogetherEmbeddings(
|
||||
model='togethercomputer/m2-bert-80M-8k-retrieval'
|
||||
)
|
||||
model = TogetherEmbeddings()
|
||||
"""
|
||||
|
||||
_client: together.Together
|
||||
together_api_key: SecretStr = convert_to_secret_str("")
|
||||
model: str
|
||||
client: Any = Field(default=None, exclude=True) #: :meta private:
|
||||
async_client: Any = Field(default=None, exclude=True) #: :meta private:
|
||||
model: str = "togethercomputer/m2-bert-80M-8k-retrieval"
|
||||
"""Embeddings model name to use. Do not add suffixes like `-query` and `-passage`.
|
||||
Instead, use 'togethercomputer/m2-bert-80M-8k-retrieval' for example.
|
||||
"""
|
||||
dimensions: Optional[int] = None
|
||||
"""The number of dimensions the resulting output embeddings should have.
|
||||
|
||||
@root_validator()
|
||||
def validate_environment(cls, values: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Validate environment variables."""
|
||||
together_api_key = convert_to_secret_str(
|
||||
values.get("together_api_key") or os.getenv("TOGETHER_API_KEY") or ""
|
||||
)
|
||||
values["together_api_key"] = together_api_key
|
||||
Not yet supported.
|
||||
"""
|
||||
together_api_key: Optional[SecretStr] = Field(default=None, alias="api_key")
|
||||
"""API Key for Solar API."""
|
||||
together_api_base: str = Field(
|
||||
default="https://api.together.ai/v1/embeddings", alias="base_url"
|
||||
)
|
||||
"""Endpoint URL to use."""
|
||||
embedding_ctx_length: int = 4096
|
||||
"""The maximum number of tokens to embed at once.
|
||||
|
||||
# note this sets it globally for module
|
||||
# there isn't currently a way to pass it into client
|
||||
together.api_key = together_api_key.get_secret_value()
|
||||
values["_client"] = together.Together()
|
||||
Not yet supported.
|
||||
"""
|
||||
allowed_special: Union[Literal["all"], Set[str]] = set()
|
||||
"""Not yet supported."""
|
||||
disallowed_special: Union[Literal["all"], Set[str], Sequence[str]] = "all"
|
||||
"""Not yet supported."""
|
||||
chunk_size: int = 1000
|
||||
"""Maximum number of texts to embed in each batch.
|
||||
|
||||
Not yet supported.
|
||||
"""
|
||||
max_retries: int = 2
|
||||
"""Maximum number of retries to make when generating."""
|
||||
request_timeout: Optional[Union[float, Tuple[float, float], Any]] = Field(
|
||||
default=None, alias="timeout"
|
||||
)
|
||||
"""Timeout for requests to Together embedding API. Can be float, httpx.Timeout or
|
||||
None."""
|
||||
show_progress_bar: bool = False
|
||||
"""Whether to show a progress bar when embedding.
|
||||
|
||||
Not yet supported.
|
||||
"""
|
||||
model_kwargs: Dict[str, Any] = Field(default_factory=dict)
|
||||
"""Holds any model parameters valid for `create` call not explicitly specified."""
|
||||
skip_empty: bool = False
|
||||
"""Whether to skip empty strings when embedding or raise an error.
|
||||
Defaults to not skipping.
|
||||
|
||||
Not yet supported."""
|
||||
default_headers: Union[Mapping[str, str], None] = None
|
||||
default_query: Union[Mapping[str, object], None] = None
|
||||
# Configure a custom httpx client. See the
|
||||
# [httpx documentation](https://www.python-httpx.org/api/#client) for more details.
|
||||
http_client: Union[Any, None] = None
|
||||
"""Optional httpx.Client. Only used for sync invocations. Must specify
|
||||
http_async_client as well if you'd like a custom client for async invocations.
|
||||
"""
|
||||
http_async_client: Union[Any, None] = None
|
||||
"""Optional httpx.AsyncClient. Only used for async invocations. Must specify
|
||||
http_client as well if you'd like a custom client for sync invocations."""
|
||||
|
||||
class Config:
|
||||
extra = Extra.forbid
|
||||
allow_population_by_field_name = True
|
||||
|
||||
@root_validator(pre=True)
|
||||
def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Build extra kwargs from additional params that were passed in."""
|
||||
all_required_field_names = get_pydantic_field_names(cls)
|
||||
extra = values.get("model_kwargs", {})
|
||||
for field_name in list(values):
|
||||
if field_name in extra:
|
||||
raise ValueError(f"Found {field_name} supplied twice.")
|
||||
if field_name not in all_required_field_names:
|
||||
warnings.warn(
|
||||
f"""WARNING! {field_name} is not default parameter.
|
||||
{field_name} was transferred to model_kwargs.
|
||||
Please confirm that {field_name} is what you intended."""
|
||||
)
|
||||
extra[field_name] = values.pop(field_name)
|
||||
|
||||
invalid_model_kwargs = all_required_field_names.intersection(extra.keys())
|
||||
if invalid_model_kwargs:
|
||||
raise ValueError(
|
||||
f"Parameters {invalid_model_kwargs} should be specified explicitly. "
|
||||
f"Instead they were passed in as part of `model_kwargs` parameter."
|
||||
)
|
||||
|
||||
values["model_kwargs"] = extra
|
||||
return values
|
||||
|
||||
@root_validator()
|
||||
def validate_environment(cls, values: Dict) -> Dict:
|
||||
"""Validate that api key and python package exists in environment."""
|
||||
|
||||
together_api_key = get_from_dict_or_env(
|
||||
values, "together_api_key", "TOGETHER_API_KEY"
|
||||
)
|
||||
values["together_api_key"] = (
|
||||
convert_to_secret_str(together_api_key) if together_api_key else None
|
||||
)
|
||||
values["together_api_base"] = values["together_api_base"] or os.getenv(
|
||||
"TOGETHER_API_BASE"
|
||||
)
|
||||
client_params = {
|
||||
"api_key": (
|
||||
values["together_api_key"].get_secret_value()
|
||||
if values["together_api_key"]
|
||||
else None
|
||||
),
|
||||
"base_url": values["together_api_base"],
|
||||
"timeout": values["request_timeout"],
|
||||
"max_retries": values["max_retries"],
|
||||
"default_headers": values["default_headers"],
|
||||
"default_query": values["default_query"],
|
||||
}
|
||||
if not values.get("client"):
|
||||
sync_specific = {"http_client": values["http_client"]}
|
||||
values["client"] = openai.OpenAI(
|
||||
**client_params, **sync_specific
|
||||
).embeddings
|
||||
if not values.get("async_client"):
|
||||
async_specific = {"http_client": values["http_async_client"]}
|
||||
values["async_client"] = openai.AsyncOpenAI(
|
||||
**client_params, **async_specific
|
||||
).embeddings
|
||||
return values
|
||||
|
||||
@property
|
||||
def _invocation_params(self) -> Dict[str, Any]:
|
||||
self.model = self.model.replace("-query", "").replace("-passage", "")
|
||||
|
||||
params: Dict = {"model": self.model, **self.model_kwargs}
|
||||
if self.dimensions is not None:
|
||||
params["dimensions"] = self.dimensions
|
||||
return params
|
||||
|
||||
def embed_documents(self, texts: List[str]) -> List[List[float]]:
|
||||
"""Embed search docs."""
|
||||
return [
|
||||
i.embedding
|
||||
for i in self._client.embeddings.create(input=texts, model=self.model).data
|
||||
]
|
||||
"""Embed a list of document texts using passage model.
|
||||
|
||||
Args:
|
||||
texts: The list of texts to embed.
|
||||
|
||||
Returns:
|
||||
List of embeddings, one for each text.
|
||||
"""
|
||||
embeddings = []
|
||||
params = self._invocation_params
|
||||
params["model"] = params["model"] + "-passage"
|
||||
|
||||
for text in texts:
|
||||
response = self.client.create(input=text, **params)
|
||||
|
||||
if not isinstance(response, dict):
|
||||
response = response.model_dump()
|
||||
embeddings.extend([i["embedding"] for i in response["data"]])
|
||||
return embeddings
|
||||
|
||||
def embed_query(self, text: str) -> List[float]:
|
||||
"""Embed query text."""
|
||||
return self.embed_documents([text])[0]
|
||||
"""Embed query text using query model.
|
||||
|
||||
Args:
|
||||
text: The text to embed.
|
||||
|
||||
Returns:
|
||||
Embedding for the text.
|
||||
"""
|
||||
params = self._invocation_params
|
||||
params["model"] = params["model"] + "-query"
|
||||
|
||||
response = self.client.create(input=text, **params)
|
||||
|
||||
if not isinstance(response, dict):
|
||||
response = response.model_dump()
|
||||
return response["data"][0]["embedding"]
|
||||
|
||||
async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
|
||||
"""Embed a list of document texts using passage model asynchronously.
|
||||
|
||||
Args:
|
||||
texts: The list of texts to embed.
|
||||
|
||||
Returns:
|
||||
List of embeddings, one for each text.
|
||||
"""
|
||||
embeddings = []
|
||||
params = self._invocation_params
|
||||
params["model"] = params["model"] + "-passage"
|
||||
|
||||
for text in texts:
|
||||
response = await self.async_client.create(input=text, **params)
|
||||
|
||||
if not isinstance(response, dict):
|
||||
response = response.model_dump()
|
||||
embeddings.extend([i["embedding"] for i in response["data"]])
|
||||
return embeddings
|
||||
|
||||
async def aembed_query(self, text: str) -> List[float]:
|
||||
"""Asynchronous Embed query text using query model.
|
||||
|
||||
Args:
|
||||
text: The text to embed.
|
||||
|
||||
Returns:
|
||||
Embedding for the text.
|
||||
"""
|
||||
params = self._invocation_params
|
||||
params["model"] = params["model"] + "-query"
|
||||
|
||||
response = await self.async_client.create(input=text, **params)
|
||||
|
||||
if not isinstance(response, dict):
|
||||
response = response.model_dump()
|
||||
return response["data"][0]["embedding"]
|
||||
|
@ -1,4 +1,5 @@
|
||||
"""Wrapper around Together AI's Completion API."""
|
||||
|
||||
import logging
|
||||
import warnings
|
||||
from typing import Any, Dict, List, Optional
|
||||
@ -13,8 +14,6 @@ from langchain_core.language_models.llms import LLM
|
||||
from langchain_core.pydantic_v1 import Extra, SecretStr, root_validator
|
||||
from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env
|
||||
|
||||
from langchain_together.version import __version__
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@ -22,7 +21,7 @@ class Together(LLM):
|
||||
"""LLM models from `Together`.
|
||||
|
||||
To use, you'll need an API key which you can find here:
|
||||
https://api.together.xyz/settings/api-keys. This can be passed in as init param
|
||||
https://api.together.ai/settings/api-keys. This can be passed in as init param
|
||||
``together_api_key`` or set as environment variable ``TOGETHER_API_KEY``.
|
||||
|
||||
Together AI API reference: https://docs.together.ai/reference/completions
|
||||
@ -35,39 +34,39 @@ class Together(LLM):
|
||||
model = Together(model_name="mistralai/Mixtral-8x7B-Instruct-v0.1")
|
||||
"""
|
||||
|
||||
base_url: str = "https://api.together.xyz/v1/completions"
|
||||
base_url: str = "https://api.together.ai/v1/completions"
|
||||
"""Base completions API URL."""
|
||||
together_api_key: SecretStr
|
||||
"""Together AI API key. Get it here: https://api.together.xyz/settings/api-keys"""
|
||||
"""Together AI API key. Get it here: https://api.together.ai/settings/api-keys"""
|
||||
model: str
|
||||
"""Model name. Available models listed here:
|
||||
"""Model name. Available models listed here:
|
||||
Base Models: https://docs.together.ai/docs/inference-models#language-models
|
||||
Chat Models: https://docs.together.ai/docs/inference-models#chat-models
|
||||
"""
|
||||
temperature: Optional[float] = None
|
||||
"""Model temperature."""
|
||||
top_p: Optional[float] = None
|
||||
"""Used to dynamically adjust the number of choices for each predicted token based
|
||||
on the cumulative probabilities. A value of 1 will always yield the same
|
||||
output. A temperature less than 1 favors more correctness and is appropriate
|
||||
for question answering or summarization. A value greater than 1 introduces more
|
||||
"""Used to dynamically adjust the number of choices for each predicted token based
|
||||
on the cumulative probabilities. A value of 1 will always yield the same
|
||||
output. A temperature less than 1 favors more correctness and is appropriate
|
||||
for question answering or summarization. A value greater than 1 introduces more
|
||||
randomness in the output.
|
||||
"""
|
||||
top_k: Optional[int] = None
|
||||
"""Used to limit the number of choices for the next predicted word or token. It
|
||||
specifies the maximum number of tokens to consider at each step, based on their
|
||||
probability of occurrence. This technique helps to speed up the generation
|
||||
process and can improve the quality of the generated text by focusing on the
|
||||
"""Used to limit the number of choices for the next predicted word or token. It
|
||||
specifies the maximum number of tokens to consider at each step, based on their
|
||||
probability of occurrence. This technique helps to speed up the generation
|
||||
process and can improve the quality of the generated text by focusing on the
|
||||
most likely options.
|
||||
"""
|
||||
max_tokens: Optional[int] = None
|
||||
"""The maximum number of tokens to generate."""
|
||||
repetition_penalty: Optional[float] = None
|
||||
"""A number that controls the diversity of generated text by reducing the
|
||||
"""A number that controls the diversity of generated text by reducing the
|
||||
likelihood of repeated sequences. Higher values decrease repetition.
|
||||
"""
|
||||
logprobs: Optional[int] = None
|
||||
"""An integer that specifies how many top token log probabilities are included in
|
||||
"""An integer that specifies how many top token log probabilities are included in
|
||||
the response for each token generation step.
|
||||
"""
|
||||
|
||||
@ -107,10 +106,6 @@ class Together(LLM):
|
||||
def _format_output(self, output: dict) -> str:
|
||||
return output["choices"][0]["text"]
|
||||
|
||||
@staticmethod
|
||||
def get_user_agent() -> str:
|
||||
return f"langchain-together/{__version__}"
|
||||
|
||||
@property
|
||||
def default_params(self) -> Dict[str, Any]:
|
||||
return {
|
||||
|
@ -1,8 +0,0 @@
|
||||
"""Main entrypoint into package."""
|
||||
from importlib import metadata
|
||||
|
||||
try:
|
||||
__version__ = metadata.version(__package__)
|
||||
except metadata.PackageNotFoundError:
|
||||
# Case where package metadata is not available.
|
||||
__version__ = ""
|
937
libs/partners/together/poetry.lock
generated
937
libs/partners/together/poetry.lock
generated
File diff suppressed because it is too large
Load Diff
@ -1,7 +1,7 @@
|
||||
[tool.poetry]
|
||||
name = "langchain-together"
|
||||
version = "0.1.0"
|
||||
description = "An integration package connecting Together and LangChain"
|
||||
version = "0.1.1"
|
||||
description = "An integration package connecting Together AI and LangChain"
|
||||
authors = []
|
||||
readme = "README.md"
|
||||
repository = "https://github.com/langchain-ai/langchain"
|
||||
@ -12,8 +12,8 @@ license = "MIT"
|
||||
|
||||
[tool.poetry.dependencies]
|
||||
python = ">=3.8.1,<4.0"
|
||||
langchain-core = "^0.1"
|
||||
together = "^0.2.10"
|
||||
langchain-core = "^0.1.44"
|
||||
langchain-openai = "^0.1.3"
|
||||
requests = "^2"
|
||||
aiohttp = "^3.9.1"
|
||||
|
||||
@ -27,7 +27,11 @@ pytest-mock = "^3.10.0"
|
||||
syrupy = "^4.0.2"
|
||||
pytest-watcher = "^0.3.4"
|
||||
pytest-asyncio = "^0.21.1"
|
||||
langchain-openai = { path = "../openai", develop = true }
|
||||
langchain-core = { path = "../../core", develop = true }
|
||||
docarray = "^0.32.1"
|
||||
pydantic = "^1.10.9"
|
||||
langchain-standard-tests = { path = "../../standard-tests", develop = true }
|
||||
|
||||
[tool.poetry.group.codespell]
|
||||
optional = true
|
||||
@ -46,6 +50,9 @@ optional = true
|
||||
[tool.poetry.group.lint.dependencies]
|
||||
ruff = "^0.1.5"
|
||||
|
||||
[tool.poetry.group.typing]
|
||||
optional = true
|
||||
|
||||
[tool.poetry.group.typing.dependencies]
|
||||
mypy = "^0.991"
|
||||
langchain-core = { path = "../../core", develop = true }
|
||||
@ -57,12 +64,11 @@ optional = true
|
||||
[tool.poetry.group.dev.dependencies]
|
||||
langchain-core = { path = "../../core", develop = true }
|
||||
|
||||
[tool.ruff.lint]
|
||||
[tool.ruff]
|
||||
select = [
|
||||
"E", # pycodestyle
|
||||
"F", # pyflakes
|
||||
"I", # isort
|
||||
"T201", # print
|
||||
"E", # pycodestyle
|
||||
"F", # pyflakes
|
||||
"I", # isort
|
||||
]
|
||||
|
||||
[tool.mypy]
|
||||
|
@ -10,8 +10,8 @@ if __name__ == "__main__":
|
||||
SourceFileLoader("x", file).load_module()
|
||||
except Exception:
|
||||
has_faillure = True
|
||||
print(file) # noqa: T201
|
||||
print(file)
|
||||
traceback.print_exc()
|
||||
print() # noqa: T201
|
||||
print()
|
||||
|
||||
sys.exit(1 if has_failure else 0)
|
||||
|
@ -0,0 +1,136 @@
|
||||
import pytest
|
||||
from langchain_core.messages import BaseMessage, HumanMessage, SystemMessage
|
||||
|
||||
from langchain_together import ChatTogether
|
||||
|
||||
|
||||
def test_chat_together_model() -> None:
|
||||
"""Test ChatTogether wrapper handles model_name."""
|
||||
chat = ChatTogether(model="foo")
|
||||
assert chat.model_name == "foo"
|
||||
chat = ChatTogether(model_name="bar")
|
||||
assert chat.model_name == "bar"
|
||||
|
||||
|
||||
def test_chat_together_system_message() -> None:
|
||||
"""Test ChatOpenAI wrapper with system message."""
|
||||
chat = ChatTogether(max_tokens=10)
|
||||
system_message = SystemMessage(content="You are to chat with the user.")
|
||||
human_message = HumanMessage(content="Hello")
|
||||
response = chat([system_message, human_message])
|
||||
assert isinstance(response, BaseMessage)
|
||||
assert isinstance(response.content, str)
|
||||
|
||||
|
||||
def test_chat_together_llm_output_contains_model_name() -> None:
|
||||
"""Test llm_output contains model_name."""
|
||||
chat = ChatTogether(max_tokens=10)
|
||||
message = HumanMessage(content="Hello")
|
||||
llm_result = chat.generate([[message]])
|
||||
assert llm_result.llm_output is not None
|
||||
assert llm_result.llm_output["model_name"] == chat.model_name
|
||||
|
||||
|
||||
def test_chat_together_streaming_llm_output_contains_model_name() -> None:
|
||||
"""Test llm_output contains model_name."""
|
||||
chat = ChatTogether(max_tokens=10, streaming=True)
|
||||
message = HumanMessage(content="Hello")
|
||||
llm_result = chat.generate([[message]])
|
||||
assert llm_result.llm_output is not None
|
||||
assert llm_result.llm_output["model_name"] == chat.model_name
|
||||
|
||||
|
||||
def test_chat_together_invalid_streaming_params() -> None:
|
||||
"""Test that streaming correctly invokes on_llm_new_token callback."""
|
||||
with pytest.raises(ValueError):
|
||||
ChatTogether(
|
||||
max_tokens=10,
|
||||
streaming=True,
|
||||
temperature=0,
|
||||
n=5,
|
||||
)
|
||||
|
||||
|
||||
def test_chat_together_extra_kwargs() -> None:
|
||||
"""Test extra kwargs to chat together."""
|
||||
# Check that foo is saved in extra_kwargs.
|
||||
llm = ChatTogether(foo=3, max_tokens=10)
|
||||
assert llm.max_tokens == 10
|
||||
assert llm.model_kwargs == {"foo": 3}
|
||||
|
||||
# Test that if extra_kwargs are provided, they are added to it.
|
||||
llm = ChatTogether(foo=3, model_kwargs={"bar": 2})
|
||||
assert llm.model_kwargs == {"foo": 3, "bar": 2}
|
||||
|
||||
# Test that if provided twice it errors
|
||||
with pytest.raises(ValueError):
|
||||
ChatTogether(foo=3, model_kwargs={"foo": 2})
|
||||
|
||||
# Test that if explicit param is specified in kwargs it errors
|
||||
with pytest.raises(ValueError):
|
||||
ChatTogether(model_kwargs={"temperature": 0.2})
|
||||
|
||||
# Test that "model" cannot be specified in kwargs
|
||||
with pytest.raises(ValueError):
|
||||
ChatTogether(model_kwargs={"model": "meta-llama/Llama-3-8b-chat-hf"})
|
||||
|
||||
|
||||
def test_stream() -> None:
|
||||
"""Test streaming tokens from Together AI."""
|
||||
llm = ChatTogether()
|
||||
|
||||
for token in llm.stream("I'm Pickle Rick"):
|
||||
assert isinstance(token.content, str)
|
||||
|
||||
|
||||
async def test_astream() -> None:
|
||||
"""Test streaming tokens from Together AI."""
|
||||
llm = ChatTogether()
|
||||
|
||||
async for token in llm.astream("I'm Pickle Rick"):
|
||||
assert isinstance(token.content, str)
|
||||
|
||||
|
||||
async def test_abatch() -> None:
|
||||
"""Test streaming tokens from ChatTogether."""
|
||||
llm = ChatTogether()
|
||||
|
||||
result = await llm.abatch(["I'm Pickle Rick", "I'm not Pickle Rick"])
|
||||
for token in result:
|
||||
assert isinstance(token.content, str)
|
||||
|
||||
|
||||
async def test_abatch_tags() -> None:
|
||||
"""Test batch tokens from ChatTogether."""
|
||||
llm = ChatTogether()
|
||||
|
||||
result = await llm.abatch(
|
||||
["I'm Pickle Rick", "I'm not Pickle Rick"], config={"tags": ["foo"]}
|
||||
)
|
||||
for token in result:
|
||||
assert isinstance(token.content, str)
|
||||
|
||||
|
||||
def test_batch() -> None:
|
||||
"""Test batch tokens from ChatTogether."""
|
||||
llm = ChatTogether()
|
||||
|
||||
result = llm.batch(["I'm Pickle Rick", "I'm not Pickle Rick"])
|
||||
for token in result:
|
||||
assert isinstance(token.content, str)
|
||||
|
||||
|
||||
async def test_ainvoke() -> None:
|
||||
"""Test invoke tokens from ChatTogether."""
|
||||
llm = ChatTogether()
|
||||
|
||||
result = await llm.ainvoke("I'm Pickle Rick", config={"tags": ["foo"]})
|
||||
assert isinstance(result.content, str)
|
||||
|
||||
|
||||
def test_invoke() -> None:
|
||||
"""Test invoke tokens from ChatTogether."""
|
||||
llm = ChatTogether()
|
||||
|
||||
result = llm.invoke("I'm Pickle Rick", config=dict(tags=["foo"]))
|
||||
assert isinstance(result.content, str)
|
@ -0,0 +1,21 @@
|
||||
"""Standard LangChain interface tests"""
|
||||
|
||||
from typing import Type
|
||||
|
||||
import pytest
|
||||
from langchain_core.language_models import BaseChatModel
|
||||
from langchain_standard_tests.integration_tests import ChatModelIntegrationTests
|
||||
|
||||
from langchain_together import ChatTogether
|
||||
|
||||
|
||||
class TestTogethertandard(ChatModelIntegrationTests):
|
||||
@pytest.fixture
|
||||
def chat_model_class(self) -> Type[BaseChatModel]:
|
||||
return ChatTogether
|
||||
|
||||
@pytest.fixture
|
||||
def chat_model_params(self) -> dict:
|
||||
return {
|
||||
"model": "meta-llama/Llama-3-8b-chat-hf",
|
||||
}
|
@ -1,19 +1,37 @@
|
||||
"""Test Together embeddings."""
|
||||
from langchain_together.embeddings import TogetherEmbeddings
|
||||
"""Test Together AI embeddings."""
|
||||
|
||||
from langchain_together import TogetherEmbeddings
|
||||
|
||||
|
||||
def test_langchain_together_embedding_documents() -> None:
|
||||
"""Test cohere embeddings."""
|
||||
documents = ["foo bar"]
|
||||
embedding = TogetherEmbeddings(model="togethercomputer/m2-bert-80M-8k-retrieval")
|
||||
def test_langchain_together_embed_documents() -> None:
|
||||
"""Test Together AI embeddings."""
|
||||
documents = ["foo bar", "bar foo"]
|
||||
embedding = TogetherEmbeddings()
|
||||
output = embedding.embed_documents(documents)
|
||||
assert len(output) == 1
|
||||
assert len(output) == 2
|
||||
assert len(output[0]) > 0
|
||||
|
||||
|
||||
def test_langchain_together_embedding_query() -> None:
|
||||
"""Test cohere embeddings."""
|
||||
document = "foo bar"
|
||||
embedding = TogetherEmbeddings(model="togethercomputer/m2-bert-80M-8k-retrieval")
|
||||
output = embedding.embed_query(document)
|
||||
def test_langchain_together_embed_query() -> None:
|
||||
"""Test Together AI embeddings."""
|
||||
query = "foo bar"
|
||||
embedding = TogetherEmbeddings()
|
||||
output = embedding.embed_query(query)
|
||||
assert len(output) > 0
|
||||
|
||||
|
||||
async def test_langchain_together_aembed_documents() -> None:
|
||||
"""Test Together AI embeddings asynchronous."""
|
||||
documents = ["foo bar", "bar foo"]
|
||||
embedding = TogetherEmbeddings()
|
||||
output = await embedding.aembed_documents(documents)
|
||||
assert len(output) == 2
|
||||
assert len(output[0]) > 0
|
||||
|
||||
|
||||
async def test_langchain_together_aembed_query() -> None:
|
||||
"""Test Together AI embeddings asynchronous."""
|
||||
query = "foo bar"
|
||||
embedding = TogetherEmbeddings()
|
||||
output = await embedding.aembed_query(query)
|
||||
assert len(output) > 0
|
||||
|
@ -1,11 +1,10 @@
|
||||
"""Test Together API wrapper.
|
||||
|
||||
In order to run this test, you need to have an Together api key.
|
||||
You can get it by registering for free at https://api.together.xyz/.
|
||||
You can get it by registering for free at https://api.together.ai/.
|
||||
A test key can be found at https://api.together.xyz/settings/api-keys
|
||||
|
||||
You'll then need to set TOGETHER_API_KEY environment variable to your api key.
|
||||
"""
|
||||
|
||||
import pytest as pytest
|
||||
|
||||
from langchain_together import Together
|
||||
|
192
libs/partners/together/tests/unit_tests/test_chat_models.py
Normal file
192
libs/partners/together/tests/unit_tests/test_chat_models.py
Normal file
@ -0,0 +1,192 @@
|
||||
import json
|
||||
from typing import Any
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from langchain_core.messages import (
|
||||
AIMessage,
|
||||
FunctionMessage,
|
||||
HumanMessage,
|
||||
SystemMessage,
|
||||
ToolMessage,
|
||||
)
|
||||
from langchain_openai.chat_models.base import (
|
||||
_convert_dict_to_message,
|
||||
_convert_message_to_dict,
|
||||
)
|
||||
|
||||
from langchain_together import ChatTogether
|
||||
|
||||
|
||||
def test_initialization() -> None:
|
||||
"""Test chat model initialization."""
|
||||
ChatTogether()
|
||||
|
||||
|
||||
def test_together_model_param() -> None:
|
||||
llm = ChatTogether(model="foo")
|
||||
assert llm.model_name == "foo"
|
||||
llm = ChatTogether(model_name="foo")
|
||||
assert llm.model_name == "foo"
|
||||
|
||||
|
||||
def test_function_dict_to_message_function_message() -> None:
|
||||
content = json.dumps({"result": "Example #1"})
|
||||
name = "test_function"
|
||||
result = _convert_dict_to_message(
|
||||
{
|
||||
"role": "function",
|
||||
"name": name,
|
||||
"content": content,
|
||||
}
|
||||
)
|
||||
assert isinstance(result, FunctionMessage)
|
||||
assert result.name == name
|
||||
assert result.content == content
|
||||
|
||||
|
||||
def test_convert_dict_to_message_human() -> None:
|
||||
message = {"role": "user", "content": "foo"}
|
||||
result = _convert_dict_to_message(message)
|
||||
expected_output = HumanMessage(content="foo")
|
||||
assert result == expected_output
|
||||
assert _convert_message_to_dict(expected_output) == message
|
||||
|
||||
|
||||
def test__convert_dict_to_message_human_with_name() -> None:
|
||||
message = {"role": "user", "content": "foo", "name": "test"}
|
||||
result = _convert_dict_to_message(message)
|
||||
expected_output = HumanMessage(content="foo", name="test")
|
||||
assert result == expected_output
|
||||
assert _convert_message_to_dict(expected_output) == message
|
||||
|
||||
|
||||
def test_convert_dict_to_message_ai() -> None:
|
||||
message = {"role": "assistant", "content": "foo"}
|
||||
result = _convert_dict_to_message(message)
|
||||
expected_output = AIMessage(content="foo")
|
||||
assert result == expected_output
|
||||
assert _convert_message_to_dict(expected_output) == message
|
||||
|
||||
|
||||
def test_convert_dict_to_message_ai_with_name() -> None:
|
||||
message = {"role": "assistant", "content": "foo", "name": "test"}
|
||||
result = _convert_dict_to_message(message)
|
||||
expected_output = AIMessage(content="foo", name="test")
|
||||
assert result == expected_output
|
||||
assert _convert_message_to_dict(expected_output) == message
|
||||
|
||||
|
||||
def test_convert_dict_to_message_system() -> None:
|
||||
message = {"role": "system", "content": "foo"}
|
||||
result = _convert_dict_to_message(message)
|
||||
expected_output = SystemMessage(content="foo")
|
||||
assert result == expected_output
|
||||
assert _convert_message_to_dict(expected_output) == message
|
||||
|
||||
|
||||
def test_convert_dict_to_message_system_with_name() -> None:
|
||||
message = {"role": "system", "content": "foo", "name": "test"}
|
||||
result = _convert_dict_to_message(message)
|
||||
expected_output = SystemMessage(content="foo", name="test")
|
||||
assert result == expected_output
|
||||
assert _convert_message_to_dict(expected_output) == message
|
||||
|
||||
|
||||
def test_convert_dict_to_message_tool() -> None:
|
||||
message = {"role": "tool", "content": "foo", "tool_call_id": "bar"}
|
||||
result = _convert_dict_to_message(message)
|
||||
expected_output = ToolMessage(content="foo", tool_call_id="bar")
|
||||
assert result == expected_output
|
||||
assert _convert_message_to_dict(expected_output) == message
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_completion() -> dict:
|
||||
return {
|
||||
"id": "chatcmpl-7fcZavknQda3SQ",
|
||||
"object": "chat.completion",
|
||||
"created": 1689989000,
|
||||
"model": "meta-llama/Llama-3-8b-chat-hf",
|
||||
"choices": [
|
||||
{
|
||||
"index": 0,
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": "Bab",
|
||||
"name": "KimSolar",
|
||||
},
|
||||
"finish_reason": "stop",
|
||||
}
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
def test_together_invoke(mock_completion: dict) -> None:
|
||||
llm = ChatTogether()
|
||||
mock_client = MagicMock()
|
||||
completed = False
|
||||
|
||||
def mock_create(*args: Any, **kwargs: Any) -> Any:
|
||||
nonlocal completed
|
||||
completed = True
|
||||
return mock_completion
|
||||
|
||||
mock_client.create = mock_create
|
||||
with patch.object(
|
||||
llm,
|
||||
"client",
|
||||
mock_client,
|
||||
):
|
||||
res = llm.invoke("bab")
|
||||
assert res.content == "Bab"
|
||||
assert completed
|
||||
|
||||
|
||||
async def test_together_ainvoke(mock_completion: dict) -> None:
|
||||
llm = ChatTogether()
|
||||
mock_client = AsyncMock()
|
||||
completed = False
|
||||
|
||||
async def mock_create(*args: Any, **kwargs: Any) -> Any:
|
||||
nonlocal completed
|
||||
completed = True
|
||||
return mock_completion
|
||||
|
||||
mock_client.create = mock_create
|
||||
with patch.object(
|
||||
llm,
|
||||
"async_client",
|
||||
mock_client,
|
||||
):
|
||||
res = await llm.ainvoke("bab")
|
||||
assert res.content == "Bab"
|
||||
assert completed
|
||||
|
||||
|
||||
def test_together_invoke_name(mock_completion: dict) -> None:
|
||||
llm = ChatTogether()
|
||||
|
||||
mock_client = MagicMock()
|
||||
mock_client.create.return_value = mock_completion
|
||||
|
||||
with patch.object(
|
||||
llm,
|
||||
"client",
|
||||
mock_client,
|
||||
):
|
||||
messages = [
|
||||
HumanMessage(content="Foo", name="Zorba"),
|
||||
]
|
||||
res = llm.invoke(messages)
|
||||
call_args, call_kwargs = mock_client.create.call_args
|
||||
assert len(call_args) == 0 # no positional args
|
||||
call_messages = call_kwargs["messages"]
|
||||
assert len(call_messages) == 1
|
||||
assert call_messages[0]["role"] == "user"
|
||||
assert call_messages[0]["content"] == "Foo"
|
||||
assert call_messages[0]["name"] == "Zorba"
|
||||
|
||||
# check return type has name
|
||||
assert res.content == "Bab"
|
||||
assert res.name == "KimSolar"
|
@ -0,0 +1,21 @@
|
||||
"""Standard LangChain interface tests"""
|
||||
|
||||
from typing import Type
|
||||
|
||||
import pytest
|
||||
from langchain_core.language_models import BaseChatModel
|
||||
from langchain_standard_tests.unit_tests import ChatModelUnitTests
|
||||
|
||||
from langchain_together import ChatTogether
|
||||
|
||||
|
||||
class TestTogetherStandard(ChatModelUnitTests):
|
||||
@pytest.fixture
|
||||
def chat_model_class(self) -> Type[BaseChatModel]:
|
||||
return ChatTogether
|
||||
|
||||
@pytest.fixture
|
||||
def chat_model_params(self) -> dict:
|
||||
return {
|
||||
"model": "meta-llama/Llama-3-8b-chat-hf",
|
||||
}
|
@ -1,9 +1,25 @@
|
||||
"""Test embedding model integration."""
|
||||
|
||||
import os
|
||||
|
||||
from langchain_together.embeddings import TogetherEmbeddings
|
||||
import pytest
|
||||
|
||||
from langchain_together import TogetherEmbeddings
|
||||
|
||||
os.environ["TOGETHER_API_KEY"] = "foo"
|
||||
|
||||
|
||||
def test_initialization() -> None:
|
||||
"""Test embedding model initialization."""
|
||||
TogetherEmbeddings(model="togethercomputer/m2-bert-80M-8k-retrieval")
|
||||
TogetherEmbeddings()
|
||||
|
||||
|
||||
def test_together_invalid_model_kwargs() -> None:
|
||||
with pytest.raises(ValueError):
|
||||
TogetherEmbeddings(model_kwargs={"model": "foo"})
|
||||
|
||||
|
||||
def test_together_incorrect_field() -> None:
|
||||
with pytest.warns(match="not default parameter"):
|
||||
llm = TogetherEmbeddings(foo="bar")
|
||||
assert llm.model_kwargs == {"foo": "bar"}
|
||||
|
@ -1,10 +1,6 @@
|
||||
from langchain_together import __all__
|
||||
|
||||
EXPECTED_ALL = [
|
||||
"__version__",
|
||||
"Together",
|
||||
"TogetherEmbeddings",
|
||||
]
|
||||
EXPECTED_ALL = ["ChatTogether", "TogetherEmbeddings", "Together"]
|
||||
|
||||
|
||||
def test_all_imports() -> None:
|
||||
|
@ -1,5 +1,3 @@
|
||||
"""Test Together LLM"""
|
||||
|
||||
from typing import cast
|
||||
|
||||
from langchain_core.pydantic_v1 import SecretStr
|
||||
|
13
libs/partners/together/tests/unit_tests/test_secrets.py
Normal file
13
libs/partners/together/tests/unit_tests/test_secrets.py
Normal file
@ -0,0 +1,13 @@
|
||||
from langchain_together import ChatTogether, TogetherEmbeddings
|
||||
|
||||
|
||||
def test_chat_together_secrets() -> None:
|
||||
o = ChatTogether(together_api_key="foo")
|
||||
s = str(o)
|
||||
assert "foo" not in s
|
||||
|
||||
|
||||
def test_together_embeddings_secrets() -> None:
|
||||
o = TogetherEmbeddings(together_api_key="foo")
|
||||
s = str(o)
|
||||
assert "foo" not in s
|
Loading…
Reference in New Issue
Block a user