together: add chat models, use openai base (#21337)

**Description:** Adds chat completions, our most popular API, to the Together AI package, while staying backwards compatible with the old API so folks can continue to use the completions API as well. Also moves the embeddings API onto the OpenAI library to standardize it further.
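For illustration, a minimal sketch of the new chat class alongside the existing completions class (model names taken from the updated docs notebook):

```py
from langchain_together import ChatTogether, Together

# New: chat completions via Together's OpenAI-compatible endpoint
chat = ChatTogether(model="meta-llama/Llama-3-8b-chat-hf")
chat.invoke("Tell me fun things to do in NYC")

# Unchanged: the existing completions API keeps working
llm = Together(model="codellama/CodeLlama-70b-Python-hf")
llm.invoke("def bubble_sort(): ")
```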

**Twitter handle:** @nutlope

- [x] **Add tests and docs**: If you're adding a new integration, please include tests and an example notebook showing its use.
- [x] **Lint and test**: Run `make format`, `make lint` and `make test`
from the root of the package(s) you've modified. See contribution
guidelines for more: https://python.langchain.com/docs/contributing/

If no one reviews your PR within a few days, please @-mention one of
baskaryan, efriis, eyurtsev, hwchase17.

---------

Co-authored-by: Erick Friis <erick@langchain.dev>
Hassan El Mghari 2024-05-06 20:47:06 -04:00 committed by GitHub
parent a2d31307bb
commit d6ef5fe86a
22 changed files with 1501 additions and 473 deletions

View File: docs/docs/integrations/llms/together.ipynb

@@ -7,15 +7,21 @@
"source": [
"# Together AI\n",
"\n",
"> The Together API makes it easy to fine-tune or run leading open-source models with a couple lines of code. We have integrated the worlds leading open-source models, including Llama-2, RedPajama, Falcon, Alpaca, Stable Diffusion XL, and more. Read more: https://together.ai\n",
"> The Together API makes it easy to query and fine-tune leading open-source models with a couple lines of code. We have integrated the worlds leading open-source models, including Llama-3, Mixtral, DBRX, Stable Diffusion XL, and more. Read more: https://together.ai\n",
"\n",
"To use, you'll need an API key which you can find here:\n",
"https://api.together.xyz/settings/api-keys. This can be passed in as init param\n",
"https://api.together.ai/settings/api-keys. This can be passed in as init param\n",
"``together_api_key`` or set as environment variable ``TOGETHER_API_KEY``.\n",
"\n",
"Together API reference: https://docs.together.ai/reference"
"Together API reference: https://docs.together.ai"
]
},
{
"cell_type": "markdown",
"id": "1c47fc36",
"metadata": {},
"source": []
},
{
"cell_type": "code",
"execution_count": null,
@@ -28,40 +34,43 @@
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": null,
"id": "637bb53f",
"metadata": {},
"outputs": [],
"source": [
"# Running chat completions with Together AI\n",
"\n",
"from langchain_core.prompts import ChatPromptTemplate\n",
"from langchain_together import ChatTogether\n",
"\n",
"chat = ChatTogether()\n",
"\n",
"# using chat invoke\n",
"chat.invoke(\"Tell me fun things to do in NYC\")\n",
"\n",
"# using chat stream\n",
"for m in chat.stream(\"Tell me fun things to do in NYC\"):\n",
" print(m)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e7b7170d-d7c5-4890-9714-a37238343805",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"\n",
"A: A large language model is a neural network that is trained on a large amount of text data. It is able to generate text that is similar to the training data, and can be used for tasks such as language translation, question answering, and text summarization.\n",
"\n",
"A: A large language model is a neural network that is trained on a large amount of text data. It is able to generate text that is similar to the training data, and can be used for tasks such as language translation, question answering, and text summarization.\n",
"\n",
"A: A large language model is a neural network that is trained on\n"
]
}
],
"outputs": [],
"source": [
"# Running completions with Together AI\n",
"\n",
"from langchain_together import Together\n",
"\n",
"llm = Together(\n",
" model=\"togethercomputer/RedPajama-INCITE-7B-Base\",\n",
" temperature=0.7,\n",
" max_tokens=128,\n",
" top_k=1,\n",
" model=\"codellama/CodeLlama-70b-Python-hf\",\n",
" # together_api_key=\"...\"\n",
")\n",
"\n",
"input_ = \"\"\"You are a teacher with a deep knowledge of machine learning and AI. \\\n",
"You provide succinct and accurate answers. Answer the following question: \n",
"\n",
"What is a large language model?\"\"\"\n",
"print(llm.invoke(input_))"
"print(llm.invoke(\"def bubble_sort(): \"))"
]
}
],

View File: libs/partners/together/LICENSE

@@ -1,6 +1,6 @@
MIT License
Copyright (c) 2023 LangChain, Inc.
Copyright (c) 2024 LangChain, Inc.
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal

View File: libs/partners/together/Makefile

@@ -11,7 +1,6 @@ integration_test integration_tests: TEST_FILE=tests/integration_tests/
test tests integration_test integration_tests:
poetry run pytest $(TEST_FILE)
######################
# LINTING AND FORMATTING
######################

View File: libs/partners/together/README.md

@@ -1,48 +1,28 @@
# langchain-together
This package contains the LangChain integration for Together's generative models.
This package contains the LangChain integrations for [Together AI](https://www.together.ai/) through their [APIs](https://docs.together.ai/).
## Installation
## Installation and Setup
```sh
- Install the LangChain partner package
```bash
pip install -U langchain-together
```
- Get your Together AI API key from the [Together Dashboard](https://api.together.ai/settings/api-keys) and set it as an environment variable (`TOGETHER_API_KEY`), as sketched below
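One way to set it from Python (a sketch; `"your-api-key"` is a placeholder, and you can equivalently pass `api_key=...` to the classes in this package):

```py
import os

os.environ["TOGETHER_API_KEY"] = "your-api-key"  # placeholder value
```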
## Chat Completions
This package contains the `ChatTogether` class, which is the recommended way to interface with Together AI chat models.
Below is a minimal usage sketch; it assumes `TOGETHER_API_KEY` is set, and the model shown is the class default.
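```py
from langchain_together import ChatTogether

chat = ChatTogether(model="meta-llama/Llama-3-8b-chat-hf")

chat.invoke("Tell me fun things to do in NYC")
```

Still to do: add a usage example to the main LangChain docs, and add the image and completions endpoints as well.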
## Embeddings
You can use Together's embedding models through the `TogetherEmbeddings` class.
See a [usage example](https://python.langchain.com/docs/integrations/text_embedding/together/)
```py
from langchain_together import TogetherEmbeddings
embeddings = TogetherEmbeddings(
    model='togethercomputer/m2-bert-80M-8k-retrieval'
)
embeddings.embed_query("What is a large language model?")
```
## LLMs
You can use Together's generative AI models as LangChain LLMs:
```py
from langchain_together import Together
from langchain_core.prompts import PromptTemplate
llm = Together(
    model="togethercomputer/RedPajama-INCITE-7B-Base",
    temperature=0.7,
    max_tokens=64,
    top_k=1,
    # together_api_key="..."
)
template = """Question: {question}
Answer: """
prompt = PromptTemplate.from_template(template)
chain = prompt | llm
question = "Who was the president in the year Justin Beiber was born?"
print(chain.invoke({"question": question}))
```
Use `togethercomputer/m2-bert-80M-8k-retrieval` as the default model for embeddings.
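Under the hood, `embed_query` and `embed_documents` call the model's `-query` and `-passage` variants respectively, so pass the bare model name (a sketch, using the default model):

```py
from langchain_together import TogetherEmbeddings

embeddings = TogetherEmbeddings()  # defaults to togethercomputer/m2-bert-80M-8k-retrieval

query_vector = embeddings.embed_query("What is a large language model?")  # hits the "-query" variant
doc_vectors = embeddings.embed_documents(["foo bar", "bar foo"])  # hits the "-passage" variant
```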

View File: libs/partners/together/langchain_together/__init__.py

@@ -1,9 +1,5 @@
from langchain_together.chat_models import ChatTogether
from langchain_together.embeddings import TogetherEmbeddings
from langchain_together.llms import Together
from langchain_together.version import __version__
__all__ = [
"__version__",
"Together",
"TogetherEmbeddings",
]
__all__ = ["ChatTogether", "Together", "TogetherEmbeddings"]

View File: libs/partners/together/langchain_together/chat_models.py

@@ -0,0 +1,103 @@
"""Wrapper around Together AI's Chat Completions API."""
import os
from typing import (
Any,
Dict,
List,
Optional,
)
import openai
from langchain_core.pydantic_v1 import Field, SecretStr, root_validator
from langchain_core.utils import (
convert_to_secret_str,
get_from_dict_or_env,
)
from langchain_openai.chat_models.base import BaseChatOpenAI
class ChatTogether(BaseChatOpenAI):
"""ChatTogether chat model.
To use, you should have the environment variable `TOGETHER_API_KEY`
set with your API key or pass it as a named parameter to the constructor.
Example:
.. code-block:: python
from langchain_together import ChatTogether
model = ChatTogether()
"""
@property
def lc_secrets(self) -> Dict[str, str]:
return {"together_api_key": "TOGETHER_API_KEY"}
@classmethod
def get_lc_namespace(cls) -> List[str]:
return ["langchain", "chat_models", "together"]
@property
def lc_attributes(self) -> Dict[str, Any]:
attributes: Dict[str, Any] = {}
if self.together_api_base:
attributes["together_api_base"] = self.together_api_base
return attributes
@property
def _llm_type(self) -> str:
"""Return type of chat model."""
return "together-chat"
model_name: str = Field(default="meta-llama/Llama-3-8b-chat-hf", alias="model")
"""Model name to use."""
together_api_key: Optional[SecretStr] = Field(default=None, alias="api_key")
"""Automatically inferred from env are `TOGETHER_API_KEY` if not provided."""
together_api_base: Optional[str] = Field(
default="https://api.together.ai/v1/chat/completions", alias="base_url"
)
@root_validator()
def validate_environment(cls, values: Dict) -> Dict:
"""Validate that api key and python package exists in environment."""
if values["n"] < 1:
raise ValueError("n must be at least 1.")
if values["n"] > 1 and values["streaming"]:
raise ValueError("n must be 1 when streaming.")
values["together_api_key"] = convert_to_secret_str(
get_from_dict_or_env(values, "together_api_key", "TOGETHER_API_KEY")
)
values["together_api_base"] = values["together_api_base"] or os.getenv(
"TOGETHER_API_BASE"
)
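# Together's API is OpenAI-compatible, so the stock `openai` SDK clients
# below are simply pointed at the Together base URL.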
client_params = {
"api_key": (
values["together_api_key"].get_secret_value()
if values["together_api_key"]
else None
),
"base_url": values["together_api_base"],
"timeout": values["request_timeout"],
"max_retries": values["max_retries"],
"default_headers": values["default_headers"],
"default_query": values["default_query"],
}
if not values.get("client"):
sync_specific = {"http_client": values["http_client"]}
values["client"] = openai.OpenAI(
**client_params, **sync_specific
).chat.completions
if not values.get("async_client"):
async_specific = {"http_client": values["http_async_client"]}
values["async_client"] = openai.AsyncOpenAI(
**client_params, **async_specific
).chat.completions
return values

View File: libs/partners/together/langchain_together/embeddings.py

@@ -1,50 +1,265 @@
import os
from typing import Any, Dict, List
"""Wrapper around Together AI's Embeddings API."""
import together # type: ignore
import logging
import os
import warnings
from typing import (
Any,
Dict,
List,
Literal,
Mapping,
Optional,
Sequence,
Set,
Tuple,
Union,
)
import openai
from langchain_core.embeddings import Embeddings
from langchain_core.pydantic_v1 import BaseModel, SecretStr, root_validator
from langchain_core.utils import convert_to_secret_str
from langchain_core.pydantic_v1 import (
BaseModel,
Extra,
Field,
SecretStr,
root_validator,
)
from langchain_core.utils import (
convert_to_secret_str,
get_from_dict_or_env,
get_pydantic_field_names,
)
logger = logging.getLogger(__name__)
class TogetherEmbeddings(BaseModel, Embeddings):
"""TogetherEmbeddings embedding model.
To use, set the environment variable `TOGETHER_API_KEY` with your API key or
pass it as a named parameter to the constructor.
Example:
.. code-block:: python
from langchain_together import TogetherEmbeddings
model = TogetherEmbeddings(
model='togethercomputer/m2-bert-80M-8k-retrieval'
)
model = TogetherEmbeddings()
"""
_client: together.Together
together_api_key: SecretStr = convert_to_secret_str("")
model: str
client: Any = Field(default=None, exclude=True) #: :meta private:
async_client: Any = Field(default=None, exclude=True) #: :meta private:
model: str = "togethercomputer/m2-bert-80M-8k-retrieval"
"""Embeddings model name to use. Do not add suffixes like `-query` and `-passage`.
Instead, use 'togethercomputer/m2-bert-80M-8k-retrieval' for example.
"""
dimensions: Optional[int] = None
"""The number of dimensions the resulting output embeddings should have.
@root_validator()
def validate_environment(cls, values: Dict[str, Any]) -> Dict[str, Any]:
"""Validate environment variables."""
together_api_key = convert_to_secret_str(
values.get("together_api_key") or os.getenv("TOGETHER_API_KEY") or ""
)
values["together_api_key"] = together_api_key
Not yet supported.
"""
together_api_key: Optional[SecretStr] = Field(default=None, alias="api_key")
"""API Key for Solar API."""
together_api_base: str = Field(
default="https://api.together.ai/v1/embeddings", alias="base_url"
)
"""Endpoint URL to use."""
embedding_ctx_length: int = 4096
"""The maximum number of tokens to embed at once.
# note this sets it globally for module
# there isn't currently a way to pass it into client
together.api_key = together_api_key.get_secret_value()
values["_client"] = together.Together()
Not yet supported.
"""
allowed_special: Union[Literal["all"], Set[str]] = set()
"""Not yet supported."""
disallowed_special: Union[Literal["all"], Set[str], Sequence[str]] = "all"
"""Not yet supported."""
chunk_size: int = 1000
"""Maximum number of texts to embed in each batch.
Not yet supported.
"""
max_retries: int = 2
"""Maximum number of retries to make when generating."""
request_timeout: Optional[Union[float, Tuple[float, float], Any]] = Field(
default=None, alias="timeout"
)
"""Timeout for requests to Together embedding API. Can be float, httpx.Timeout or
None."""
show_progress_bar: bool = False
"""Whether to show a progress bar when embedding.
Not yet supported.
"""
model_kwargs: Dict[str, Any] = Field(default_factory=dict)
"""Holds any model parameters valid for `create` call not explicitly specified."""
skip_empty: bool = False
"""Whether to skip empty strings when embedding or raise an error.
Defaults to not skipping.
Not yet supported."""
default_headers: Union[Mapping[str, str], None] = None
default_query: Union[Mapping[str, object], None] = None
# Configure a custom httpx client. See the
# [httpx documentation](https://www.python-httpx.org/api/#client) for more details.
http_client: Union[Any, None] = None
"""Optional httpx.Client. Only used for sync invocations. Must specify
http_async_client as well if you'd like a custom client for async invocations.
"""
http_async_client: Union[Any, None] = None
"""Optional httpx.AsyncClient. Only used for async invocations. Must specify
http_client as well if you'd like a custom client for sync invocations."""
class Config:
extra = Extra.forbid
allow_population_by_field_name = True
@root_validator(pre=True)
def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]:
"""Build extra kwargs from additional params that were passed in."""
all_required_field_names = get_pydantic_field_names(cls)
extra = values.get("model_kwargs", {})
for field_name in list(values):
if field_name in extra:
raise ValueError(f"Found {field_name} supplied twice.")
if field_name not in all_required_field_names:
warnings.warn(
f"""WARNING! {field_name} is not default parameter.
{field_name} was transferred to model_kwargs.
Please confirm that {field_name} is what you intended."""
)
extra[field_name] = values.pop(field_name)
invalid_model_kwargs = all_required_field_names.intersection(extra.keys())
if invalid_model_kwargs:
raise ValueError(
f"Parameters {invalid_model_kwargs} should be specified explicitly. "
f"Instead they were passed in as part of `model_kwargs` parameter."
)
values["model_kwargs"] = extra
return values
@root_validator()
def validate_environment(cls, values: Dict) -> Dict:
"""Validate that api key and python package exists in environment."""
together_api_key = get_from_dict_or_env(
values, "together_api_key", "TOGETHER_API_KEY"
)
values["together_api_key"] = (
convert_to_secret_str(together_api_key) if together_api_key else None
)
values["together_api_base"] = values["together_api_base"] or os.getenv(
"TOGETHER_API_BASE"
)
client_params = {
"api_key": (
values["together_api_key"].get_secret_value()
if values["together_api_key"]
else None
),
"base_url": values["together_api_base"],
"timeout": values["request_timeout"],
"max_retries": values["max_retries"],
"default_headers": values["default_headers"],
"default_query": values["default_query"],
}
if not values.get("client"):
sync_specific = {"http_client": values["http_client"]}
values["client"] = openai.OpenAI(
**client_params, **sync_specific
).embeddings
if not values.get("async_client"):
async_specific = {"http_client": values["http_async_client"]}
values["async_client"] = openai.AsyncOpenAI(
**client_params, **async_specific
).embeddings
return values
@property
def _invocation_params(self) -> Dict[str, Any]:
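# Strip any task suffix the caller added; embed_documents and embed_query
# re-append "-passage"/"-query" themselves.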
self.model = self.model.replace("-query", "").replace("-passage", "")
params: Dict = {"model": self.model, **self.model_kwargs}
if self.dimensions is not None:
params["dimensions"] = self.dimensions
return params
def embed_documents(self, texts: List[str]) -> List[List[float]]:
"""Embed search docs."""
return [
i.embedding
for i in self._client.embeddings.create(input=texts, model=self.model).data
]
"""Embed a list of document texts using passage model.
Args:
texts: The list of texts to embed.
Returns:
List of embeddings, one for each text.
"""
embeddings = []
params = self._invocation_params
params["model"] = params["model"] + "-passage"
for text in texts:
response = self.client.create(input=text, **params)
if not isinstance(response, dict):
response = response.model_dump()
embeddings.extend([i["embedding"] for i in response["data"]])
return embeddings
def embed_query(self, text: str) -> List[float]:
"""Embed query text."""
return self.embed_documents([text])[0]
"""Embed query text using query model.
Args:
text: The text to embed.
Returns:
Embedding for the text.
"""
params = self._invocation_params
params["model"] = params["model"] + "-query"
response = self.client.create(input=text, **params)
if not isinstance(response, dict):
response = response.model_dump()
return response["data"][0]["embedding"]
async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
"""Embed a list of document texts using passage model asynchronously.
Args:
texts: The list of texts to embed.
Returns:
List of embeddings, one for each text.
"""
embeddings = []
params = self._invocation_params
params["model"] = params["model"] + "-passage"
for text in texts:
response = await self.async_client.create(input=text, **params)
if not isinstance(response, dict):
response = response.model_dump()
embeddings.extend([i["embedding"] for i in response["data"]])
return embeddings
async def aembed_query(self, text: str) -> List[float]:
"""Asynchronous Embed query text using query model.
Args:
text: The text to embed.
Returns:
Embedding for the text.
"""
params = self._invocation_params
params["model"] = params["model"] + "-query"
response = await self.async_client.create(input=text, **params)
if not isinstance(response, dict):
response = response.model_dump()
return response["data"][0]["embedding"]

View File: libs/partners/together/langchain_together/llms.py

@@ -1,4 +1,5 @@
"""Wrapper around Together AI's Completion API."""
import logging
import warnings
from typing import Any, Dict, List, Optional
@@ -13,8 +14,6 @@ from langchain_core.language_models.llms import LLM
from langchain_core.pydantic_v1 import Extra, SecretStr, root_validator
from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env
from langchain_together.version import __version__
logger = logging.getLogger(__name__)
@@ -22,7 +21,7 @@ class Together(LLM):
"""LLM models from `Together`.
To use, you'll need an API key which you can find here:
https://api.together.xyz/settings/api-keys. This can be passed in as init param
https://api.together.ai/settings/api-keys. This can be passed in as init param
``together_api_key`` or set as environment variable ``TOGETHER_API_KEY``.
Together AI API reference: https://docs.together.ai/reference/completions
@@ -35,39 +34,39 @@ class Together(LLM):
model = Together(model="mistralai/Mixtral-8x7B-Instruct-v0.1")
"""
base_url: str = "https://api.together.xyz/v1/completions"
base_url: str = "https://api.together.ai/v1/completions"
"""Base completions API URL."""
together_api_key: SecretStr
"""Together AI API key. Get it here: https://api.together.xyz/settings/api-keys"""
"""Together AI API key. Get it here: https://api.together.ai/settings/api-keys"""
model: str
"""Model name. Available models listed here:
"""Model name. Available models listed here:
Base Models: https://docs.together.ai/docs/inference-models#language-models
Chat Models: https://docs.together.ai/docs/inference-models#chat-models
"""
temperature: Optional[float] = None
"""Model temperature."""
top_p: Optional[float] = None
"""Used to dynamically adjust the number of choices for each predicted token based
on the cumulative probabilities. A value of 1 will always yield the same
output. A temperature less than 1 favors more correctness and is appropriate
for question answering or summarization. A value greater than 1 introduces more
"""Used to dynamically adjust the number of choices for each predicted token based
on the cumulative probabilities. A value of 1 will always yield the same
output. A temperature less than 1 favors more correctness and is appropriate
for question answering or summarization. A value greater than 1 introduces more
randomness in the output.
"""
top_k: Optional[int] = None
"""Used to limit the number of choices for the next predicted word or token. It
specifies the maximum number of tokens to consider at each step, based on their
probability of occurrence. This technique helps to speed up the generation
process and can improve the quality of the generated text by focusing on the
"""Used to limit the number of choices for the next predicted word or token. It
specifies the maximum number of tokens to consider at each step, based on their
probability of occurrence. This technique helps to speed up the generation
process and can improve the quality of the generated text by focusing on the
most likely options.
"""
max_tokens: Optional[int] = None
"""The maximum number of tokens to generate."""
repetition_penalty: Optional[float] = None
"""A number that controls the diversity of generated text by reducing the
"""A number that controls the diversity of generated text by reducing the
likelihood of repeated sequences. Higher values decrease repetition.
"""
logprobs: Optional[int] = None
"""An integer that specifies how many top token log probabilities are included in
"""An integer that specifies how many top token log probabilities are included in
the response for each token generation step.
"""
@@ -107,10 +106,6 @@
def _format_output(self, output: dict) -> str:
return output["choices"][0]["text"]
@staticmethod
def get_user_agent() -> str:
return f"langchain-together/{__version__}"
@property
def default_params(self) -> Dict[str, Any]:
return {

View File: libs/partners/together/langchain_together/version.py

@@ -1,8 +0,0 @@
"""Main entrypoint into package."""
from importlib import metadata
try:
__version__ = metadata.version(__package__)
except metadata.PackageNotFoundError:
# Case where package metadata is not available.
__version__ = ""

View File: libs/partners/together/poetry.lock (diff suppressed because it is too large)

View File: libs/partners/together/pyproject.toml

@@ -1,7 +1,7 @@
[tool.poetry]
name = "langchain-together"
version = "0.1.0"
description = "An integration package connecting Together and LangChain"
version = "0.1.1"
description = "An integration package connecting Together AI and LangChain"
authors = []
readme = "README.md"
repository = "https://github.com/langchain-ai/langchain"
@@ -12,8 +12,8 @@ license = "MIT"
[tool.poetry.dependencies]
python = ">=3.8.1,<4.0"
langchain-core = "^0.1"
together = "^0.2.10"
langchain-core = "^0.1.44"
langchain-openai = "^0.1.3"
requests = "^2"
aiohttp = "^3.9.1"
@@ -27,7 +27,11 @@ pytest-mock = "^3.10.0"
syrupy = "^4.0.2"
pytest-watcher = "^0.3.4"
pytest-asyncio = "^0.21.1"
langchain-openai = { path = "../openai", develop = true }
langchain-core = { path = "../../core", develop = true }
docarray = "^0.32.1"
pydantic = "^1.10.9"
langchain-standard-tests = { path = "../../standard-tests", develop = true }
[tool.poetry.group.codespell]
optional = true
@@ -46,6 +50,9 @@ optional = true
[tool.poetry.group.lint.dependencies]
ruff = "^0.1.5"
[tool.poetry.group.typing]
optional = true
[tool.poetry.group.typing.dependencies]
mypy = "^0.991"
langchain-core = { path = "../../core", develop = true }
@@ -57,12 +64,11 @@ optional = true
[tool.poetry.group.dev.dependencies]
langchain-core = { path = "../../core", develop = true }
[tool.ruff.lint]
[tool.ruff]
select = [
"E", # pycodestyle
"F", # pyflakes
"I", # isort
"T201", # print
"E", # pycodestyle
"F", # pyflakes
"I", # isort
]
[tool.mypy]

View File: libs/partners/together/scripts/check_imports.py

@@ -10,8 +10,8 @@ if __name__ == "__main__":
SourceFileLoader("x", file).load_module()
except Exception:
has_failure = True
print(file) # noqa: T201
print(file)
traceback.print_exc()
print() # noqa: T201
print()
sys.exit(1 if has_failure else 0)

View File: libs/partners/together/tests/integration_tests/test_chat_models.py

@@ -0,0 +1,136 @@
import pytest
from langchain_core.messages import BaseMessage, HumanMessage, SystemMessage
from langchain_together import ChatTogether
def test_chat_together_model() -> None:
"""Test ChatTogether wrapper handles model_name."""
chat = ChatTogether(model="foo")
assert chat.model_name == "foo"
chat = ChatTogether(model_name="bar")
assert chat.model_name == "bar"
def test_chat_together_system_message() -> None:
"""Test ChatOpenAI wrapper with system message."""
chat = ChatTogether(max_tokens=10)
system_message = SystemMessage(content="You are to chat with the user.")
human_message = HumanMessage(content="Hello")
response = chat([system_message, human_message])
assert isinstance(response, BaseMessage)
assert isinstance(response.content, str)
def test_chat_together_llm_output_contains_model_name() -> None:
"""Test llm_output contains model_name."""
chat = ChatTogether(max_tokens=10)
message = HumanMessage(content="Hello")
llm_result = chat.generate([[message]])
assert llm_result.llm_output is not None
assert llm_result.llm_output["model_name"] == chat.model_name
def test_chat_together_streaming_llm_output_contains_model_name() -> None:
"""Test llm_output contains model_name."""
chat = ChatTogether(max_tokens=10, streaming=True)
message = HumanMessage(content="Hello")
llm_result = chat.generate([[message]])
assert llm_result.llm_output is not None
assert llm_result.llm_output["model_name"] == chat.model_name
def test_chat_together_invalid_streaming_params() -> None:
"""Test that streaming correctly invokes on_llm_new_token callback."""
with pytest.raises(ValueError):
ChatTogether(
max_tokens=10,
streaming=True,
temperature=0,
n=5,
)
def test_chat_together_extra_kwargs() -> None:
"""Test extra kwargs to chat together."""
# Check that foo is saved in extra_kwargs.
llm = ChatTogether(foo=3, max_tokens=10)
assert llm.max_tokens == 10
assert llm.model_kwargs == {"foo": 3}
# Test that if extra_kwargs are provided, they are added to it.
llm = ChatTogether(foo=3, model_kwargs={"bar": 2})
assert llm.model_kwargs == {"foo": 3, "bar": 2}
# Test that if provided twice it errors
with pytest.raises(ValueError):
ChatTogether(foo=3, model_kwargs={"foo": 2})
# Test that if explicit param is specified in kwargs it errors
with pytest.raises(ValueError):
ChatTogether(model_kwargs={"temperature": 0.2})
# Test that "model" cannot be specified in kwargs
with pytest.raises(ValueError):
ChatTogether(model_kwargs={"model": "meta-llama/Llama-3-8b-chat-hf"})
def test_stream() -> None:
"""Test streaming tokens from Together AI."""
llm = ChatTogether()
for token in llm.stream("I'm Pickle Rick"):
assert isinstance(token.content, str)
async def test_astream() -> None:
"""Test streaming tokens from Together AI."""
llm = ChatTogether()
async for token in llm.astream("I'm Pickle Rick"):
assert isinstance(token.content, str)
async def test_abatch() -> None:
"""Test streaming tokens from ChatTogether."""
llm = ChatTogether()
result = await llm.abatch(["I'm Pickle Rick", "I'm not Pickle Rick"])
for token in result:
assert isinstance(token.content, str)
async def test_abatch_tags() -> None:
"""Test batch tokens from ChatTogether."""
llm = ChatTogether()
result = await llm.abatch(
["I'm Pickle Rick", "I'm not Pickle Rick"], config={"tags": ["foo"]}
)
for token in result:
assert isinstance(token.content, str)
def test_batch() -> None:
"""Test batch tokens from ChatTogether."""
llm = ChatTogether()
result = llm.batch(["I'm Pickle Rick", "I'm not Pickle Rick"])
for token in result:
assert isinstance(token.content, str)
async def test_ainvoke() -> None:
"""Test invoke tokens from ChatTogether."""
llm = ChatTogether()
result = await llm.ainvoke("I'm Pickle Rick", config={"tags": ["foo"]})
assert isinstance(result.content, str)
def test_invoke() -> None:
"""Test invoke tokens from ChatTogether."""
llm = ChatTogether()
result = llm.invoke("I'm Pickle Rick", config=dict(tags=["foo"]))
assert isinstance(result.content, str)

View File: libs/partners/together/tests/integration_tests/test_chat_models_standard.py

@@ -0,0 +1,21 @@
"""Standard LangChain interface tests"""
from typing import Type
import pytest
from langchain_core.language_models import BaseChatModel
from langchain_standard_tests.integration_tests import ChatModelIntegrationTests
from langchain_together import ChatTogether
class TestTogetherStandard(ChatModelIntegrationTests):
@pytest.fixture
def chat_model_class(self) -> Type[BaseChatModel]:
return ChatTogether
@pytest.fixture
def chat_model_params(self) -> dict:
return {
"model": "meta-llama/Llama-3-8b-chat-hf",
}

View File: libs/partners/together/tests/integration_tests/test_embeddings.py

@@ -1,19 +1,37 @@
"""Test Together embeddings."""
from langchain_together.embeddings import TogetherEmbeddings
"""Test Together AI embeddings."""
from langchain_together import TogetherEmbeddings
def test_langchain_together_embedding_documents() -> None:
"""Test cohere embeddings."""
documents = ["foo bar"]
embedding = TogetherEmbeddings(model="togethercomputer/m2-bert-80M-8k-retrieval")
def test_langchain_together_embed_documents() -> None:
"""Test Together AI embeddings."""
documents = ["foo bar", "bar foo"]
embedding = TogetherEmbeddings()
output = embedding.embed_documents(documents)
assert len(output) == 1
assert len(output) == 2
assert len(output[0]) > 0
def test_langchain_together_embedding_query() -> None:
"""Test cohere embeddings."""
document = "foo bar"
embedding = TogetherEmbeddings(model="togethercomputer/m2-bert-80M-8k-retrieval")
output = embedding.embed_query(document)
def test_langchain_together_embed_query() -> None:
"""Test Together AI embeddings."""
query = "foo bar"
embedding = TogetherEmbeddings()
output = embedding.embed_query(query)
assert len(output) > 0
async def test_langchain_together_aembed_documents() -> None:
"""Test Together AI embeddings asynchronous."""
documents = ["foo bar", "bar foo"]
embedding = TogetherEmbeddings()
output = await embedding.aembed_documents(documents)
assert len(output) == 2
assert len(output[0]) > 0
async def test_langchain_together_aembed_query() -> None:
"""Test Together AI embeddings asynchronous."""
query = "foo bar"
embedding = TogetherEmbeddings()
output = await embedding.aembed_query(query)
assert len(output) > 0

View File: libs/partners/together/tests/integration_tests/test_llms.py

@@ -1,11 +1,10 @@
"""Test Together API wrapper.
In order to run this test, you need to have a Together API key.
You can get it by registering for free at https://api.together.xyz/.
You can get it by registering for free at https://api.together.ai/.
A test key can be found at https://api.together.xyz/settings/api-keys
You'll then need to set the TOGETHER_API_KEY environment variable to your API key.
"""
import pytest
from langchain_together import Together

View File: libs/partners/together/tests/unit_tests/test_chat_models.py

@@ -0,0 +1,192 @@
import json
from typing import Any
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from langchain_core.messages import (
AIMessage,
FunctionMessage,
HumanMessage,
SystemMessage,
ToolMessage,
)
from langchain_openai.chat_models.base import (
_convert_dict_to_message,
_convert_message_to_dict,
)
from langchain_together import ChatTogether
def test_initialization() -> None:
"""Test chat model initialization."""
ChatTogether()
def test_together_model_param() -> None:
llm = ChatTogether(model="foo")
assert llm.model_name == "foo"
llm = ChatTogether(model_name="foo")
assert llm.model_name == "foo"
def test_function_dict_to_message_function_message() -> None:
content = json.dumps({"result": "Example #1"})
name = "test_function"
result = _convert_dict_to_message(
{
"role": "function",
"name": name,
"content": content,
}
)
assert isinstance(result, FunctionMessage)
assert result.name == name
assert result.content == content
def test_convert_dict_to_message_human() -> None:
message = {"role": "user", "content": "foo"}
result = _convert_dict_to_message(message)
expected_output = HumanMessage(content="foo")
assert result == expected_output
assert _convert_message_to_dict(expected_output) == message
def test__convert_dict_to_message_human_with_name() -> None:
message = {"role": "user", "content": "foo", "name": "test"}
result = _convert_dict_to_message(message)
expected_output = HumanMessage(content="foo", name="test")
assert result == expected_output
assert _convert_message_to_dict(expected_output) == message
def test_convert_dict_to_message_ai() -> None:
message = {"role": "assistant", "content": "foo"}
result = _convert_dict_to_message(message)
expected_output = AIMessage(content="foo")
assert result == expected_output
assert _convert_message_to_dict(expected_output) == message
def test_convert_dict_to_message_ai_with_name() -> None:
message = {"role": "assistant", "content": "foo", "name": "test"}
result = _convert_dict_to_message(message)
expected_output = AIMessage(content="foo", name="test")
assert result == expected_output
assert _convert_message_to_dict(expected_output) == message
def test_convert_dict_to_message_system() -> None:
message = {"role": "system", "content": "foo"}
result = _convert_dict_to_message(message)
expected_output = SystemMessage(content="foo")
assert result == expected_output
assert _convert_message_to_dict(expected_output) == message
def test_convert_dict_to_message_system_with_name() -> None:
message = {"role": "system", "content": "foo", "name": "test"}
result = _convert_dict_to_message(message)
expected_output = SystemMessage(content="foo", name="test")
assert result == expected_output
assert _convert_message_to_dict(expected_output) == message
def test_convert_dict_to_message_tool() -> None:
message = {"role": "tool", "content": "foo", "tool_call_id": "bar"}
result = _convert_dict_to_message(message)
expected_output = ToolMessage(content="foo", tool_call_id="bar")
assert result == expected_output
assert _convert_message_to_dict(expected_output) == message
@pytest.fixture
def mock_completion() -> dict:
return {
"id": "chatcmpl-7fcZavknQda3SQ",
"object": "chat.completion",
"created": 1689989000,
"model": "meta-llama/Llama-3-8b-chat-hf",
"choices": [
{
"index": 0,
"message": {
"role": "assistant",
"content": "Bab",
"name": "KimSolar",
},
"finish_reason": "stop",
}
],
}
def test_together_invoke(mock_completion: dict) -> None:
llm = ChatTogether()
mock_client = MagicMock()
completed = False
def mock_create(*args: Any, **kwargs: Any) -> Any:
nonlocal completed
completed = True
return mock_completion
mock_client.create = mock_create
with patch.object(
llm,
"client",
mock_client,
):
res = llm.invoke("bab")
assert res.content == "Bab"
assert completed
async def test_together_ainvoke(mock_completion: dict) -> None:
llm = ChatTogether()
mock_client = AsyncMock()
completed = False
async def mock_create(*args: Any, **kwargs: Any) -> Any:
nonlocal completed
completed = True
return mock_completion
mock_client.create = mock_create
with patch.object(
llm,
"async_client",
mock_client,
):
res = await llm.ainvoke("bab")
assert res.content == "Bab"
assert completed
def test_together_invoke_name(mock_completion: dict) -> None:
llm = ChatTogether()
mock_client = MagicMock()
mock_client.create.return_value = mock_completion
with patch.object(
llm,
"client",
mock_client,
):
messages = [
HumanMessage(content="Foo", name="Zorba"),
]
res = llm.invoke(messages)
call_args, call_kwargs = mock_client.create.call_args
assert len(call_args) == 0 # no positional args
call_messages = call_kwargs["messages"]
assert len(call_messages) == 1
assert call_messages[0]["role"] == "user"
assert call_messages[0]["content"] == "Foo"
assert call_messages[0]["name"] == "Zorba"
# check return type has name
assert res.content == "Bab"
assert res.name == "KimSolar"

View File: libs/partners/together/tests/unit_tests/test_chat_models_standard.py

@@ -0,0 +1,21 @@
"""Standard LangChain interface tests"""
from typing import Type
import pytest
from langchain_core.language_models import BaseChatModel
from langchain_standard_tests.unit_tests import ChatModelUnitTests
from langchain_together import ChatTogether
class TestTogetherStandard(ChatModelUnitTests):
@pytest.fixture
def chat_model_class(self) -> Type[BaseChatModel]:
return ChatTogether
@pytest.fixture
def chat_model_params(self) -> dict:
return {
"model": "meta-llama/Llama-3-8b-chat-hf",
}

View File: libs/partners/together/tests/unit_tests/test_embeddings.py

@@ -1,9 +1,25 @@
"""Test embedding model integration."""
import os
from langchain_together.embeddings import TogetherEmbeddings
import pytest
from langchain_together import TogetherEmbeddings
os.environ["TOGETHER_API_KEY"] = "foo"
def test_initialization() -> None:
"""Test embedding model initialization."""
TogetherEmbeddings(model="togethercomputer/m2-bert-80M-8k-retrieval")
TogetherEmbeddings()
def test_together_invalid_model_kwargs() -> None:
with pytest.raises(ValueError):
TogetherEmbeddings(model_kwargs={"model": "foo"})
def test_together_incorrect_field() -> None:
with pytest.warns(match="not default parameter"):
llm = TogetherEmbeddings(foo="bar")
assert llm.model_kwargs == {"foo": "bar"}

View File: libs/partners/together/tests/unit_tests/test_imports.py

@@ -1,10 +1,6 @@
from langchain_together import __all__
EXPECTED_ALL = [
"__version__",
"Together",
"TogetherEmbeddings",
]
EXPECTED_ALL = ["ChatTogether", "TogetherEmbeddings", "Together"]
def test_all_imports() -> None:

View File: libs/partners/together/tests/unit_tests/test_llms.py

@@ -1,5 +1,3 @@
"""Test Together LLM"""
from typing import cast
from langchain_core.pydantic_v1 import SecretStr

View File: libs/partners/together/tests/unit_tests/test_secrets.py

@@ -0,0 +1,13 @@
from langchain_together import ChatTogether, TogetherEmbeddings
def test_chat_together_secrets() -> None:
o = ChatTogether(together_api_key="foo")
s = str(o)
assert "foo" not in s
def test_together_embeddings_secrets() -> None:
o = TogetherEmbeddings(together_api_key="foo")
s = str(o)
assert "foo" not in s