ai21: migrate to external repo (#25827)

Erick Friis 2024-08-28 14:24:07 -07:00 committed by GitHub
parent 095b712a26
commit 8fb594fd2a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
44 changed files with 2 additions and 4900 deletions


@ -1,21 +0,0 @@
MIT License
Copyright (c) 2023 LangChain, Inc.
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.


@ -1,55 +0,0 @@
.PHONY: all format lint test tests integration_tests docker_tests help extended_tests
# Default target executed when no arguments are given to make.
all: help
# Define a variable for the test file path.
TEST_FILE ?= tests/unit_tests/
integration_test integration_tests: TEST_FILE = tests/integration_tests/
test tests integration_test integration_tests:
poetry run pytest $(TEST_FILE)
######################
# LINTING AND FORMATTING
######################
# Define a variable for Python and notebook files.
PYTHON_FILES=.
MYPY_CACHE=.mypy_cache
lint format: PYTHON_FILES=.
lint_diff format_diff: PYTHON_FILES=$(shell git diff --relative=libs/partners/ai21 --name-only --diff-filter=d master | grep -E '\.py$$|\.ipynb$$')
lint_package: PYTHON_FILES=langchain_ai21
lint_tests: PYTHON_FILES=tests
lint_tests: MYPY_CACHE=.mypy_cache_test
lint lint_diff lint_package lint_tests:
[ "$(PYTHON_FILES)" = "" ] || poetry run ruff check $(PYTHON_FILES)
[ "$(PYTHON_FILES)" = "" ] || poetry run ruff format $(PYTHON_FILES) --diff
[ "$(PYTHON_FILES)" = "" ] || mkdir -p $(MYPY_CACHE) && poetry run mypy $(PYTHON_FILES) --cache-dir $(MYPY_CACHE)
format format_diff:
[ "$(PYTHON_FILES)" = "" ] || poetry run ruff format $(PYTHON_FILES)
[ "$(PYTHON_FILES)" = "" ] || poetry run ruff check --select I --fix $(PYTHON_FILES)
spell_check:
poetry run codespell --toml pyproject.toml
spell_fix:
poetry run codespell --toml pyproject.toml -w
check_imports: $(shell find langchain_ai21 -name '*.py')
poetry run python ./scripts/check_imports.py $^
######################
# HELP
######################
help:
@echo '----'
@echo 'check_imports - check imports'
@echo 'format - run code formatters'
@echo 'lint - run linters'
@echo 'test - run unit tests'
@echo 'tests - run unit tests'
@echo 'test TEST_FILE=<test_file> - run all tests in file'


@ -1,217 +1,3 @@
# langchain-ai21
This package has moved!
This package contains the LangChain integrations for [AI21](https://docs.ai21.com/) models and tools.
## Installation and Setup
- Install the AI21 partner package
```bash
pip install langchain-ai21
```
- Get an AI21 API key and set it as an environment variable (`AI21_API_KEY`)
## Chat Models
This package contains the `ChatAI21` class, which is the recommended way to interface with AI21 chat models, including Jamba-Instruct
and any Jurassic chat models.
To use, install the requirements and configure your environment.
```bash
export AI21_API_KEY=your-api-key
```
Then initialize
```python
from langchain_core.messages import HumanMessage
from langchain_ai21.chat_models import ChatAI21
chat = ChatAI21(model="jamba-instruct")
messages = [HumanMessage(content="Hello from AI21")]
chat.invoke(messages)
```
For a list of the supported models, see [this page](https://docs.ai21.com/reference/python-sdk#chat)
### Streaming in Chat
Streaming is supported by the latest models. To use streaming, set the `streaming` parameter to `True` when initializing the model.
```python
from langchain_core.messages import HumanMessage
from langchain_ai21.chat_models import ChatAI21
chat = ChatAI21(model="jamba-instruct", streaming=True)
messages = [HumanMessage(content="Hello from AI21")]
response = chat.invoke(messages)
```
or use the `stream` method directly
```python
from langchain_core.messages import HumanMessage
from langchain_ai21.chat_models import ChatAI21
chat = ChatAI21(model="jamba-instruct")
messages = [HumanMessage(content="Hello from AI21")]
for chunk in chat.stream(messages):
print(chunk)
```
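Asynchronous streaming is also available through the standard LangChain `astream` method. Here is a minimal sketch, assuming an asyncio entry point and the same `AI21_API_KEY` environment variable:
```python
import asyncio
from langchain_core.messages import HumanMessage
from langchain_ai21.chat_models import ChatAI21

async def main() -> None:
    chat = ChatAI21(model="jamba-instruct")
    # Chunks arrive as AIMessageChunk objects; print their text as it streams in
    async for chunk in chat.astream([HumanMessage(content="Hello from AI21")]):
        print(chunk.content, end="", flush=True)

asyncio.run(main())
```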
## LLMs
You can use AI21's Jurassic generative AI models as LangChain LLMs.
To use the newer Jamba model, use the [ChatAI21 chat model](#chat-models), which
supports single-turn instruction/question answering capabilities.
```python
from langchain_core.prompts import PromptTemplate
from langchain_ai21 import AI21LLM
llm = AI21LLM(model="j2-ultra")
template = """Question: {question}
Answer: Let's think step by step."""
prompt = PromptTemplate.from_template(template)
chain = prompt | llm
question = "Which scientist discovered relativity?"
print(chain.invoke({"question": question}))
```
## Embeddings
You can use AI21's [embeddings model](https://docs.ai21.com/reference/embeddings-ref) as shown here:
### Query
```python
from langchain_ai21 import AI21Embeddings
embeddings = AI21Embeddings()
embeddings.embed_query("Hello! This is some query")
```
### Document
```python
from langchain_ai21 import AI21Embeddings
embeddings = AI21Embeddings()
embeddings.embed_documents(["Hello! This is document 1", "And this is document 2!"])
```
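Both `embed_documents` and `embed_query` also accept a `batch_size` keyword argument controlling how many texts are sent to the API per request (the default is 128). A small sketch, where 64 is an arbitrary example value:
```python
from langchain_ai21 import AI21Embeddings

embeddings = AI21Embeddings()
# Send at most 64 texts per API call instead of the default batch size of 128
embeddings.embed_documents(
    ["Hello! This is document 1", "And this is document 2!"],
    batch_size=64,
)
```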
## Task-Specific Models
### Contextual Answers
You can use AI21's [contextual answers model](https://docs.ai21.com/reference/contextual-answers-ref) to parse
given text and answer a question based entirely on the provided information.
This means that if the answer to your question is not in the document,
the model will indicate this instead of providing a false answer.
```python
from langchain_ai21 import AI21ContextualAnswers
tsm = AI21ContextualAnswers()
response = tsm.invoke(input={"context": "Lots of information here", "question": "Your question about the context"})
```
You can also use it with chains and output parsers and vector DBs:
```python
from langchain_ai21 import AI21ContextualAnswers
from langchain_core.output_parsers import StrOutputParser
tsm = AI21ContextualAnswers()
chain = tsm | StrOutputParser()
response = chain.invoke(
{"context": "Your context", "question": "Your question"},
)
```
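The `context` field accepts either a plain string or a list of strings and LangChain `Document` objects, which are joined before being sent to the API. A minimal sketch:
```python
from langchain_core.documents import Document
from langchain_ai21 import AI21ContextualAnswers

tsm = AI21ContextualAnswers()
response = tsm.invoke(
    {
        "context": [
            Document(page_content="The meeting is scheduled for Tuesday at 10am."),
            "The meeting takes place in the third-floor conference room.",
        ],
        "question": "When is the meeting?",
    }
)
```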
## Text Splitters
### Semantic Text Splitter
You can use AI21's semantic [text segmentation model](https://docs.ai21.com/reference/text-segmentation-ref) to split a text into segments by topic.
Text is split at each point where the topic changes.
For a list of examples, see [this page](https://github.com/langchain-ai/langchain/blob/master/docs/docs/modules/data_connection/document_transformers/semantic_text_splitter.ipynb).
```python
from langchain_ai21 import AI21SemanticTextSplitter
splitter = AI21SemanticTextSplitter()
response = splitter.split_text("Your text")
```
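The splitter can also return LangChain `Document` objects, with each segment's type recorded in its metadata, via `split_text_to_documents`:
```python
from langchain_ai21 import AI21SemanticTextSplitter

splitter = AI21SemanticTextSplitter()
documents = splitter.split_text_to_documents("Your text")
for document in documents:
    print(document.metadata["source_type"], document.page_content)
```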
## Tool calls
### Function calling
AI21 models incorporate the Function Calling feature to support custom user functions. The models generate structured
data that includes the function name and proposed arguments. This data empowers applications to call external APIs and
incorporate the resulting information into subsequent model prompts, enriching responses with real-time data and
context. Through function calling, users can access and utilize various services like transportation APIs and financial
data providers to obtain more accurate and relevant answers. Here is an example of how to use function calling
with AI21 models in LangChain:
```python
import os
from getpass import getpass
from langchain_core.messages import HumanMessage, ToolMessage, SystemMessage
from langchain_core.tools import tool
from langchain_ai21.chat_models import ChatAI21
from langchain_core.utils.function_calling import convert_to_openai_tool
os.environ["AI21_API_KEY"] = getpass()
@tool
def get_weather(location: str, date: str) -> str:
"""“Provide the weather for the specified location on the given date.”"""
if location == "New York" and date == "2024-12-05":
return "25 celsius"
elif location == "New York" and date == "2024-12-06":
return "27 celsius"
elif location == "London" and date == "2024-12-05":
return "22 celsius"
return "32 celsius"
llm = ChatAI21(model="jamba-1.5-mini")
llm_with_tools = llm.bind_tools([convert_to_openai_tool(get_weather)])
chat_messages = [SystemMessage(content="You are a helpful assistant. You can use the provided tools "
"to assist with various tasks and provide accurate information")]
human_messages = [
HumanMessage(content="What is the forecast for the weather in New York on December 5, 2024?"),
HumanMessage(content="And what about the 2024-12-06?"),
HumanMessage(content="OK, thank you."),
HumanMessage(content="What is the expected weather in London on December 5, 2024?")]
for human_message in human_messages:
print(f"User: {human_message.content}")
chat_messages.append(human_message)
response = llm_with_tools.invoke(chat_messages)
chat_messages.append(response)
if response.tool_calls:
tool_call = response.tool_calls[0]
if tool_call["name"] == "get_weather":
weather = get_weather.invoke(
{"location": tool_call["args"]["location"], "date": tool_call["args"]["date"]})
chat_messages.append(ToolMessage(content=weather, tool_call_id=tool_call["id"]))
llm_answer = llm_with_tools.invoke(chat_messages)
print(f"Assistant: {llm_answer.content}")
else:
print(f"Assistant: {response.content}")
```
https://github.com/langchain-ai/langchain-ai21/tree/main/libs/ai21


@ -1,13 +0,0 @@
from langchain_ai21.chat_models import ChatAI21
from langchain_ai21.contextual_answers import AI21ContextualAnswers
from langchain_ai21.embeddings import AI21Embeddings
from langchain_ai21.llms import AI21LLM
from langchain_ai21.semantic_text_splitter import AI21SemanticTextSplitter
__all__ = [
"AI21LLM",
"ChatAI21",
"AI21Embeddings",
"AI21ContextualAnswers",
"AI21SemanticTextSplitter",
]


@ -1,64 +0,0 @@
import os
from typing import Any, Dict, Optional
from ai21 import AI21Client
from langchain_core.pydantic_v1 import BaseModel, Field, SecretStr, root_validator
from langchain_core.utils import convert_to_secret_str
_DEFAULT_TIMEOUT_SEC = 300
class AI21Base(BaseModel):
"""Base class for AI21 models."""
class Config:
arbitrary_types_allowed = True
client: Any = Field(default=None, exclude=True) #: :meta private:
api_key: Optional[SecretStr] = None
"""API key for AI21 API."""
api_host: Optional[str] = None
"""Host URL"""
timeout_sec: Optional[float] = None
"""Timeout in seconds.
If not set, it will default to the value of the environment
variable `AI21_TIMEOUT_SEC` or 300 seconds.
"""
num_retries: Optional[int] = None
"""Maximum number of retries for API requests before giving up."""
@root_validator(pre=True)
def validate_environment(cls, values: Dict) -> Dict:
api_key = convert_to_secret_str(
values.get("api_key") or os.getenv("AI21_API_KEY") or ""
)
values["api_key"] = api_key
api_host = (
values.get("api_host")
or os.getenv("AI21_API_URL")
or "https://api.ai21.com"
)
values["api_host"] = api_host
timeout_sec = values.get("timeout_sec") or float(
os.getenv("AI21_TIMEOUT_SEC", _DEFAULT_TIMEOUT_SEC)
)
values["timeout_sec"] = timeout_sec
return values
@root_validator(pre=False, skip_on_failure=True)
def post_init(cls, values: Dict) -> Dict:
api_key = values["api_key"]
api_host = values["api_host"]
timeout_sec = values["timeout_sec"]
if values.get("client") is None:
values["client"] = AI21Client(
api_key=api_key.get_secret_value(),
api_host=api_host,
timeout_sec=None if timeout_sec is None else float(timeout_sec),
via="langchain",
)
return values


@ -1,324 +0,0 @@
from __future__ import annotations
from abc import ABC, abstractmethod
from typing import Any, Dict, Iterator, List, Literal, Optional, Union, cast, overload
from ai21.models import ChatMessage as J2ChatMessage
from ai21.models import RoleType
from ai21.models.chat import (
AssistantMessage as AI21AssistantMessage,
)
from ai21.models.chat import ChatCompletionChunk, ChatMessageParam
from ai21.models.chat import ChatMessage as AI21ChatMessage
from ai21.models.chat import SystemMessage as AI21SystemMessage
from ai21.models.chat import ToolCall as AI21ToolCall
from ai21.models.chat import ToolFunction as AI21ToolFunction
from ai21.models.chat import ToolMessage as AI21ToolMessage
from ai21.models.chat import UserMessage as AI21UserMessage
from ai21.stream.stream import Stream as AI21Stream
from langchain_core.messages import (
AIMessage,
AIMessageChunk,
BaseMessage,
BaseMessageChunk,
HumanMessage,
SystemMessage,
ToolCall,
ToolMessage,
)
from langchain_core.messages.ai import UsageMetadata
from langchain_core.output_parsers.openai_tools import parse_tool_call
from langchain_core.outputs import ChatGenerationChunk
_ChatMessageTypes = Union[AI21ChatMessage, J2ChatMessage]
_SYSTEM_ERR_MESSAGE = "System message must be at beginning of message list."
_ROLE_TYPE = Union[str, RoleType]
class ChatAdapter(ABC):
"""Common interface for the different Chat models available in AI21.
It converts LangChain messages to AI21 messages.
Calls the appropriate AI21 model API with the converted messages.
"""
@abstractmethod
def convert_messages(
self,
messages: List[BaseMessage],
) -> Dict[str, Any]:
pass
def _convert_message_to_ai21_message(
self,
message: BaseMessage,
) -> _ChatMessageTypes:
role = self._parse_role(message)
return self._chat_message(role=role, message=message)
def _parse_role(self, message: BaseMessage) -> _ROLE_TYPE:
role = None
if isinstance(message, SystemMessage):
return RoleType.SYSTEM
if isinstance(message, HumanMessage):
return RoleType.USER
if isinstance(message, AIMessage):
return RoleType.ASSISTANT
if isinstance(message, ToolMessage):
return RoleType.TOOL
if isinstance(self, J2ChatAdapter):
if not role:
raise ValueError(
f"Could not resolve role type from message {message}. "
f"Only support {HumanMessage.__name__} and {AIMessage.__name__}."
)
# if it gets here, we rely on the server to handle the role type
return message.type
@abstractmethod
def _chat_message(
self,
role: _ROLE_TYPE,
message: BaseMessage,
) -> _ChatMessageTypes:
pass
@overload
def call(
self,
client: Any,
stream: Literal[True],
**params: Any,
) -> Iterator[ChatGenerationChunk]:
pass
@overload
def call(
self,
client: Any,
stream: Literal[False],
**params: Any,
) -> List[BaseMessage]:
pass
@abstractmethod
def call(
self,
client: Any,
stream: Literal[True] | Literal[False],
**params: Any,
) -> List[BaseMessage] | Iterator[ChatGenerationChunk]:
pass
def _get_system_message_from_message(self, message: BaseMessage) -> str:
if not isinstance(message.content, str):
raise ValueError(
f"System Message must be of type str. Got {type(message.content)}"
)
return message.content
class J2ChatAdapter(ChatAdapter):
"""Adapter for J2Chat models."""
def convert_messages(self, messages: List[BaseMessage]) -> Dict[str, Any]:
system_message = ""
converted_messages = [] # type: ignore
for i, message in enumerate(messages):
if message.type == "system":
if i != 0:
raise ValueError(_SYSTEM_ERR_MESSAGE)
else:
system_message = self._get_system_message_from_message(message)
else:
converted_message = self._convert_message_to_ai21_message(message)
converted_messages.append(converted_message)
return {"system": system_message, "messages": converted_messages}
def _chat_message(
self,
role: _ROLE_TYPE,
message: BaseMessage,
) -> J2ChatMessage:
return J2ChatMessage(role=RoleType(role), text=cast(str, message.content))
@overload
def call(
self,
client: Any,
stream: Literal[True],
**params: Any,
) -> Iterator[ChatGenerationChunk]: ...
@overload
def call(
self,
client: Any,
stream: Literal[False],
**params: Any,
) -> List[BaseMessage]: ...
def call(
self,
client: Any,
stream: Literal[True] | Literal[False],
**params: Any,
) -> List[BaseMessage] | Iterator[ChatGenerationChunk]:
if stream:
raise NotImplementedError("Streaming is not supported for Jurassic models.")
response = client.chat.create(**params)
return [AIMessage(output.text) for output in response.outputs]
class JambaChatCompletionsAdapter(ChatAdapter):
"""Adapter for Jamba Chat Completions."""
def convert_messages(self, messages: List[BaseMessage]) -> Dict[str, Any]:
return {
"messages": [
self._convert_message_to_ai21_message(message) for message in messages
],
}
def _convert_lc_tool_calls_to_ai21_tool_calls(
self, tool_calls: List[ToolCall]
) -> Optional[List[AI21ToolCall]]:
"""
Convert LangChain ToolCalls to AI21 ToolCalls.
"""
ai21_tool_calls: List[AI21ToolCall] = []
for lc_tool_call in tool_calls:
if "id" not in lc_tool_call or not lc_tool_call["id"]:
raise ValueError("Tool call ID is missing or empty.")
ai21_tool_call = AI21ToolCall(
id=lc_tool_call["id"],
type="function",
function=AI21ToolFunction(
name=lc_tool_call["name"],
arguments=str(lc_tool_call["args"]),
),
)
ai21_tool_calls.append(ai21_tool_call)
return ai21_tool_calls
def _get_content_as_string(self, base_message: BaseMessage) -> str:
if isinstance(base_message.content, str):
return base_message.content
elif isinstance(base_message.content, list):
return "\n".join(str(item) for item in base_message.content)
else:
raise ValueError("Unsupported content type")
def _chat_message(
self,
role: _ROLE_TYPE,
message: BaseMessage,
) -> ChatMessageParam:
content = self._get_content_as_string(message)
if isinstance(message, AIMessage):
return AI21AssistantMessage(
tool_calls=self._convert_lc_tool_calls_to_ai21_tool_calls(
message.tool_calls
),
content=content or None,
)
if isinstance(message, ToolMessage):
return AI21ToolMessage(
tool_call_id=message.tool_call_id,
content=content,
)
if isinstance(message, HumanMessage):
return AI21UserMessage(
content=content,
)
if isinstance(message, SystemMessage):
return AI21SystemMessage(
content=content,
)
return AI21ChatMessage(
role=role.value if isinstance(role, RoleType) else role,
content=content,
)
@overload
def call(
self,
client: Any,
stream: Literal[True],
**params: Any,
) -> Iterator[ChatGenerationChunk]: ...
@overload
def call(
self,
client: Any,
stream: Literal[False],
**params: Any,
) -> List[BaseMessage]: ...
def call(
self,
client: Any,
stream: Literal[True] | Literal[False],
**params: Any,
) -> List[BaseMessage] | Iterator[ChatGenerationChunk]:
response = client.chat.completions.create(stream=stream, **params)
if stream:
return self._stream_response(response)
ai_messages: List[BaseMessage] = []
for message in response.choices:
if message.message.tool_calls:
tool_calls = [
parse_tool_call(tool_call.model_dump(), return_id=True)
for tool_call in message.message.tool_calls
]
ai_messages.append(AIMessage("", tool_calls=tool_calls))
else:
ai_messages.append(AIMessage(message.message.content))
return ai_messages
def _stream_response(
self,
response: AI21Stream[ChatCompletionChunk],
) -> Iterator[ChatGenerationChunk]:
for chunk in response:
converted_message = self._convert_ai21_chunk_to_chunk(chunk)
yield ChatGenerationChunk(message=converted_message)
def _convert_ai21_chunk_to_chunk(
self,
chunk: ChatCompletionChunk,
) -> BaseMessageChunk:
usage = chunk.usage
content = chunk.choices[0].delta.content or ""
if usage is None:
return AIMessageChunk(
content=content,
)
return AIMessageChunk(
content=content,
usage_metadata=UsageMetadata(
input_tokens=usage.prompt_tokens,
output_tokens=usage.completion_tokens,
total_tokens=usage.total_tokens,
),
)


@ -1,23 +0,0 @@
from langchain_ai21.chat.chat_adapter import (
ChatAdapter,
J2ChatAdapter,
JambaChatCompletionsAdapter,
)
def create_chat_adapter(model: str) -> ChatAdapter:
"""Create a chat adapter based on the model.
Args:
model: The model to create the chat adapter for.
Returns:
The chat adapter.
"""
if "j2" in model:
return J2ChatAdapter()
if "jamba" in model:
return JambaChatCompletionsAdapter()
raise ValueError(f"Model {model} not supported.")


@ -1,271 +0,0 @@
import asyncio
from functools import partial
from typing import (
Any,
Callable,
Dict,
Iterator,
List,
Mapping,
Optional,
Sequence,
Type,
Union,
)
from langchain_core.callbacks import (
AsyncCallbackManagerForLLMRun,
CallbackManagerForLLMRun,
)
from langchain_core.language_models import LanguageModelInput
from langchain_core.language_models.chat_models import (
BaseChatModel,
LangSmithParams,
generate_from_stream,
)
from langchain_core.messages import (
BaseMessage,
)
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
from langchain_core.pydantic_v1 import root_validator
from langchain_core.runnables import Runnable
from langchain_core.tools import BaseTool
from langchain_core.utils.function_calling import convert_to_openai_tool
from langchain_ai21.ai21_base import AI21Base
from langchain_ai21.chat.chat_adapter import ChatAdapter
from langchain_ai21.chat.chat_factory import create_chat_adapter
class ChatAI21(BaseChatModel, AI21Base):
"""ChatAI21 chat model. Different model types support different parameters and
different parameter values. Please read the [AI21 reference documentation]
(https://docs.ai21.com/reference) for your model to understand which parameters
are available.
Example:
.. code-block:: python
from langchain_ai21 import ChatAI21
model = ChatAI21(
# defaults to os.environ.get("AI21_API_KEY")
api_key="my_api_key"
)
"""
model: str
"""Model type you wish to interact with.
You can view the options at https://github.com/AI21Labs/ai21-python?tab=readme-ov-file#model-types"""
num_results: int = 1
"""The number of responses to generate for a given prompt."""
stop: Optional[List[str]] = None
"""Default stop sequences."""
max_tokens: int = 512
"""The maximum number of tokens to generate for each response."""
min_tokens: int = 0
"""The minimum number of tokens to generate for each response.
_Not supported for all models._"""
temperature: float = 0.4
"""A value controlling the "creativity" of the model's responses."""
top_p: float = 1
"""A value controlling the diversity of the model's responses."""
top_k_return: int = 0
"""The number of top-scoring tokens to consider for each generation step.
_Not supported for all models._"""
frequency_penalty: Optional[Any] = None
"""A penalty applied to tokens that are frequently generated.
_Not supported for all models._"""
presence_penalty: Optional[Any] = None
""" A penalty applied to tokens that are already present in the prompt.
_Not supported for all models._"""
count_penalty: Optional[Any] = None
"""A penalty applied to tokens based on their frequency
in the generated responses. _Not supported for all models._"""
n: int = 1
"""Number of chat completions to generate for each prompt."""
streaming: bool = False
_chat_adapter: ChatAdapter
@root_validator(pre=False, skip_on_failure=True)
def validate_environment(cls, values: Dict) -> Dict:
"""Validate the environment."""
model = values["model"]
values["_chat_adapter"] = create_chat_adapter(model)
return values
class Config:
"""Configuration for this pydantic object."""
arbitrary_types_allowed = True
@property
def _llm_type(self) -> str:
"""Return type of chat model."""
return "chat-ai21"
@property
def _default_params(self) -> Mapping[str, Any]:
base_params = {
"model": self.model,
"num_results": self.num_results,
"max_tokens": self.max_tokens,
"min_tokens": self.min_tokens,
"temperature": self.temperature,
"top_p": self.top_p,
"top_k_return": self.top_k_return,
"n": self.n,
}
if self.stop:
base_params["stop_sequences"] = self.stop
if self.count_penalty is not None:
base_params["count_penalty"] = self.count_penalty.to_dict()
if self.frequency_penalty is not None:
base_params["frequency_penalty"] = self.frequency_penalty.to_dict()
if self.presence_penalty is not None:
base_params["presence_penalty"] = self.presence_penalty.to_dict()
return base_params
def _get_ls_params(
self, stop: Optional[List[str]] = None, **kwargs: Any
) -> LangSmithParams:
"""Get standard params for tracing."""
params = self._get_invocation_params(stop=stop, **kwargs)
ls_params = LangSmithParams(
ls_provider="ai21",
ls_model_name=self.model,
ls_model_type="chat",
ls_temperature=params.get("temperature", self.temperature),
)
if ls_max_tokens := params.get("max_tokens", self.max_tokens):
ls_params["ls_max_tokens"] = ls_max_tokens
if ls_stop := stop or params.get("stop", None) or self.stop:
ls_params["ls_stop"] = ls_stop
return ls_params
def _build_params_for_request(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
**kwargs: Any,
) -> Mapping[str, Any]:
params = {}
converted_messages = self._chat_adapter.convert_messages(messages)
if stop is not None:
if "stop" in kwargs:
raise ValueError("stop is defined in both stop and kwargs")
params["stop_sequences"] = stop
return {
**converted_messages,
**self._default_params,
**params,
**kwargs,
}
def _generate(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
stream: Optional[bool] = None,
**kwargs: Any,
) -> ChatResult:
should_stream = stream or self.streaming
if should_stream:
return self._handle_stream_from_generate(
messages=messages,
stop=stop,
run_manager=run_manager,
**kwargs,
)
params = self._build_params_for_request(
messages=messages,
stop=stop,
stream=should_stream,
**kwargs,
)
messages = self._chat_adapter.call(self.client, **params)
generations = [ChatGeneration(message=message) for message in messages]
return ChatResult(generations=generations)
def _handle_stream_from_generate(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
stream_iter = self._stream(
messages=messages,
stop=stop,
run_manager=run_manager,
**kwargs,
)
return generate_from_stream(stream_iter)
def _stream(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> Iterator[ChatGenerationChunk]:
params = self._build_params_for_request(
messages=messages,
stop=stop,
stream=True,
**kwargs,
)
for chunk in self._chat_adapter.call(self.client, **params):
if run_manager and isinstance(chunk.message.content, str):
run_manager.on_llm_new_token(token=chunk.message.content, chunk=chunk)
yield chunk
async def _agenerate(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
return await asyncio.get_running_loop().run_in_executor(
None, partial(self._generate, **kwargs), messages, stop, run_manager
)
def _get_system_message_from_message(self, message: BaseMessage) -> str:
if not isinstance(message.content, str):
raise ValueError(
f"System Message must be of type str. Got {type(message.content)}"
)
return message.content
def bind_tools(
self,
tools: Sequence[Union[Dict[str, Any], Type, Callable, BaseTool]],
**kwargs: Any,
) -> Runnable[LanguageModelInput, BaseMessage]:
formatted_tools = [convert_to_openai_tool(tool) for tool in tools]
return super().bind(tools=formatted_tools, **kwargs)


@ -1,112 +0,0 @@
from typing import (
Any,
List,
Optional,
Tuple,
Type,
TypedDict,
Union,
)
from langchain_core.documents import Document
from langchain_core.runnables import RunnableConfig, RunnableSerializable, ensure_config
from langchain_ai21.ai21_base import AI21Base
ANSWER_NOT_IN_CONTEXT_RESPONSE = "Answer not in context"
ContextType = Union[str, List[Union[Document, str]]]
class ContextualAnswerInput(TypedDict):
"""Input for the ContextualAnswers runnable."""
context: ContextType
question: str
class AI21ContextualAnswers(RunnableSerializable[ContextualAnswerInput, str], AI21Base):
"""Runnable for the AI21 Contextual Answers API."""
class Config:
"""Configuration for this pydantic object."""
arbitrary_types_allowed = True
@property
def InputType(self) -> Type[ContextualAnswerInput]:
"""Get the input type for this runnable."""
return ContextualAnswerInput
@property
def OutputType(self) -> Type[str]:
"""Get the input type for this runnable."""
return str
def invoke(
self,
input: ContextualAnswerInput,
config: Optional[RunnableConfig] = None,
response_if_no_answer_found: str = ANSWER_NOT_IN_CONTEXT_RESPONSE,
**kwargs: Any,
) -> str:
config = ensure_config(config)
return self._call_with_config(
func=lambda inner_input: self._call_contextual_answers(
inner_input, response_if_no_answer_found
),
input=input,
config=config,
run_type="llm",
)
def _call_contextual_answers(
self,
input: ContextualAnswerInput,
response_if_no_answer_found: str,
) -> str:
context, question = self._convert_input(input)
response = self.client.answer.create(context=context, question=question)
if response.answer is None:
return response_if_no_answer_found
return response.answer
def _convert_input(self, input: ContextualAnswerInput) -> Tuple[str, str]:
context, question = self._extract_context_and_question(input)
context = self._parse_context(context)
return context, question
def _extract_context_and_question(
self,
input: ContextualAnswerInput,
) -> Tuple[ContextType, str]:
context = input.get("context")
question = input.get("question")
if not context or not question:
raise ValueError(
f"Input must contain a 'context' and 'question' fields. Got {input}"
)
if not isinstance(context, list) and not isinstance(context, str):
raise ValueError(
f"Expected input to be a list of strings or Documents."
f" Received {type(input)}"
)
return context, question
def _parse_context(self, context: ContextType) -> str:
if isinstance(context, str):
return context
docs = [
item.page_content if isinstance(item, Document) else item
for item in context
]
return "\n".join(docs)


@ -1,126 +0,0 @@
from itertools import islice
from typing import Any, Iterator, List, Optional
from ai21.models import EmbedType
from langchain_core.embeddings import Embeddings
from langchain_ai21.ai21_base import AI21Base
_DEFAULT_BATCH_SIZE = 128
def _split_texts_into_batches(texts: List[str], batch_size: int) -> Iterator[List[str]]:
texts_itr = iter(texts)
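# iter(callable, sentinel): keep yielding lists of up to batch_size texts until an empty list signals exhaustion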
return iter(lambda: list(islice(texts_itr, batch_size)), [])
class AI21Embeddings(Embeddings, AI21Base):
"""AI21 embedding model integration.
Install ``langchain_ai21`` and set environment variable ``AI21_API_KEY``.
.. code-block:: bash
pip install -U langchain_ai21
export AI21_API_KEY="your-api-key"
Key init args client params:
api_key: Optional[SecretStr]
batch_size: int
The number of texts that will be sent to the API in each batch.
Use larger batch sizes if working with many short texts. This will reduce
the number of API calls made, and can improve the time it takes to embed
a large number of texts.
num_retries: Optional[int]
Maximum number of retries for API requests before giving up.
timeout_sec: Optional[float]
Timeout in seconds for API requests. If not set, it will default to the
value of the environment variable `AI21_TIMEOUT_SEC` or 300 seconds.
See full list of supported init args and their descriptions in the params section.
Instantiate:
.. code-block:: python
from langchain_ai21 import AI21Embeddings
embed = AI21Embeddings(
# api_key="...",
# batch_size=128,
)
Embed single text:
.. code-block:: python
input_text = "The meaning of life is 42"
vector = embed.embed_query(input_text)
print(vector[:3])
.. code-block:: python
[-0.024603435769677162, -0.007543657906353474, 0.0039630369283258915]
Embed multiple texts:
.. code-block:: python
input_texts = ["Document 1...", "Document 2..."]
vectors = embed.embed_documents(input_texts)
print(len(vectors))
# The first 3 coordinates for the first vector
print(vectors[0][:3])
.. code-block:: python
2
[-0.024603435769677162, -0.007543657906353474, 0.0039630369283258915]
"""
batch_size: int = _DEFAULT_BATCH_SIZE
"""Maximum number of texts to embed in each batch"""
def embed_documents(
self,
texts: List[str],
*,
batch_size: Optional[int] = None,
**kwargs: Any,
) -> List[List[float]]:
"""Embed search docs."""
return self._send_embeddings(
texts=texts,
batch_size=batch_size or self.batch_size,
embed_type=EmbedType.SEGMENT,
**kwargs,
)
def embed_query(
self,
text: str,
*,
batch_size: Optional[int] = None,
**kwargs: Any,
) -> List[float]:
"""Embed query text."""
return self._send_embeddings(
texts=[text],
batch_size=batch_size or self.batch_size,
embed_type=EmbedType.QUERY,
**kwargs,
)[0]
def _send_embeddings(
self, texts: List[str], *, batch_size: int, embed_type: EmbedType, **kwargs: Any
) -> List[List[float]]:
chunks = _split_texts_into_batches(texts, batch_size)
responses = [
self.client.embed.create(
texts=chunk,
type=embed_type,
**kwargs,
)
for chunk in chunks
]
return [
result.embedding for response in responses for result in response.results
]


@ -1,188 +0,0 @@
import asyncio
from functools import partial
from typing import (
Any,
List,
Mapping,
Optional,
)
from ai21.models import CompletionsResponse
from langchain_core.callbacks import (
AsyncCallbackManagerForLLMRun,
CallbackManagerForLLMRun,
)
from langchain_core.language_models import BaseLLM
from langchain_core.outputs import Generation, LLMResult
from langchain_ai21.ai21_base import AI21Base
class AI21LLM(BaseLLM, AI21Base):
"""AI21 large language models. Different model types support different parameters
and different parameter values. Please read the [AI21 reference documentation]
(https://docs.ai21.com/reference) for your model to understand which parameters
are available.
AI21LLM supports only the older Jurassic models.
We recommend using ChatAI21 with the newest models, for better results and more
features.
Example:
.. code-block:: python
from langchain_ai21 import AI21LLM
model = AI21LLM(
# defaults to os.environ.get("AI21_API_KEY")
api_key="my_api_key"
)
"""
model: str
"""Model type you wish to interact with.
You can view the options at https://github.com/AI21Labs/ai21-python?tab=readme-ov-file#model-types"""
num_results: int = 1
"""The number of responses to generate for a given prompt."""
max_tokens: int = 16
"""The maximum number of tokens to generate for each response."""
min_tokens: int = 0
"""The minimum number of tokens to generate for each response.
_Not supported for all models._"""
temperature: float = 0.7
"""A value controlling the "creativity" of the model's responses."""
top_p: float = 1
"""A value controlling the diversity of the model's responses."""
top_k_return: int = 0
"""The number of top-scoring tokens to consider for each generation step.
_Not supported for all models._"""
frequency_penalty: Optional[Any] = None
"""A penalty applied to tokens that are frequently generated.
_Not supported for all models._"""
presence_penalty: Optional[Any] = None
""" A penalty applied to tokens that are already present in the prompt.
_Not supported for all models._"""
count_penalty: Optional[Any] = None
"""A penalty applied to tokens based on their frequency
in the generated responses. _Not supported for all models._"""
custom_model: Optional[str] = None
epoch: Optional[int] = None
class Config:
"""Configuration for this pydantic object."""
allow_population_by_field_name = True
@property
def _llm_type(self) -> str:
"""Return type of LLM."""
return "ai21-llm"
@property
def _default_params(self) -> Mapping[str, Any]:
base_params = {
"model": self.model,
"num_results": self.num_results,
"max_tokens": self.max_tokens,
"min_tokens": self.min_tokens,
"temperature": self.temperature,
"top_p": self.top_p,
"top_k_return": self.top_k_return,
}
if self.count_penalty is not None:
base_params["count_penalty"] = self.count_penalty.to_dict()
if self.custom_model is not None:
base_params["custom_model"] = self.custom_model
if self.epoch is not None:
base_params["epoch"] = self.epoch
if self.frequency_penalty is not None:
base_params["frequency_penalty"] = self.frequency_penalty.to_dict()
if self.presence_penalty is not None:
base_params["presence_penalty"] = self.presence_penalty.to_dict()
return base_params
def _build_params_for_request(
self, stop: Optional[List[str]] = None, **kwargs: Any
) -> Mapping[str, Any]:
params = {}
if stop is not None:
if "stop" in kwargs:
raise ValueError("stop is defined in both stop and kwargs")
params["stop_sequences"] = stop
return {
**self._default_params,
**params,
**kwargs,
}
def _generate(
self,
prompts: List[str],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> LLMResult:
generations: List[List[Generation]] = []
token_count = 0
params = self._build_params_for_request(stop=stop, **kwargs)
for prompt in prompts:
response = self._invoke_completion(prompt=prompt, **params)
generation = self._response_to_generation(response)
generations.append(generation)
token_count += self.client.count_tokens(prompt)
llm_output = {"token_count": token_count, "model_name": self.model}
return LLMResult(generations=generations, llm_output=llm_output)
async def _agenerate(
self,
prompts: List[str],
stop: Optional[List[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> LLMResult:
# Change implementation if integration natively supports async generation.
return await asyncio.get_running_loop().run_in_executor(
None, partial(self._generate, **kwargs), prompts, stop, run_manager
)
def _invoke_completion(
self,
prompt: str,
**kwargs: Any,
) -> CompletionsResponse:
return self.client.completion.create(
prompt=prompt,
**kwargs,
)
def _response_to_generation(
self, response: CompletionsResponse
) -> List[Generation]:
return [
Generation(
text=completion.data.text, # type: ignore[arg-type]
generation_info=completion.to_dict(),
)
for completion in response.completions
]


@ -1,158 +0,0 @@
import copy
import logging
import re
from typing import (
Any,
Iterable,
List,
Optional,
)
from ai21.models import DocumentType
from langchain_core.documents import Document
from langchain_core.pydantic_v1 import SecretStr
from langchain_text_splitters import TextSplitter
from langchain_ai21.ai21_base import AI21Base
logger = logging.getLogger(__name__)
class AI21SemanticTextSplitter(TextSplitter):
"""Splitting text into coherent and readable units,
based on distinct topics and lines.
"""
def __init__(
self,
chunk_size: int = 0,
chunk_overlap: int = 0,
client: Optional[Any] = None,
api_key: Optional[SecretStr] = None,
api_host: Optional[str] = None,
timeout_sec: Optional[float] = None,
num_retries: Optional[int] = None,
**kwargs: Any,
) -> None:
"""Create a new TextSplitter."""
super().__init__(
chunk_size=chunk_size,
chunk_overlap=chunk_overlap,
**kwargs,
)
self._segmentation = AI21Base(
client=client,
api_key=api_key,
api_host=api_host,
timeout_sec=timeout_sec,
num_retries=num_retries,
).client.segmentation
def split_text(self, source: str) -> List[str]:
"""Split text into multiple components.
Args:
source: Specifies the text input for text segmentation
"""
response = self._segmentation.create(
source=source, source_type=DocumentType.TEXT
)
segments = [segment.segment_text for segment in response.segments]
if self._chunk_size > 0:
return self._merge_splits_no_seperator(segments)
return segments
def split_text_to_documents(self, source: str) -> List[Document]:
"""Split text into multiple documents.
Args:
source: Specifies the text input for text segmentation
"""
response = self._segmentation.create(
source=source, source_type=DocumentType.TEXT
)
return [
Document(
page_content=segment.segment_text,
metadata={"source_type": segment.segment_type},
)
for segment in response.segments
]
def create_documents(
self, texts: List[str], metadatas: Optional[List[dict]] = None
) -> List[Document]:
"""Create documents from a list of texts."""
_metadatas = metadatas or [{}] * len(texts)
documents = []
for i, text in enumerate(texts):
normalized_text = self._normalized_text(text)
index = 0
previous_chunk_len = 0
for chunk in self.split_text_to_documents(text):
# merge metadata from user (if exists) and from segmentation api
metadata = copy.deepcopy(_metadatas[i])
metadata.update(chunk.metadata)
if self._add_start_index:
# find the start index of the chunk
offset = index + previous_chunk_len - self._chunk_overlap
normalized_chunk = self._normalized_text(chunk.page_content)
index = normalized_text.find(normalized_chunk, max(0, offset))
metadata["start_index"] = index
previous_chunk_len = len(normalized_chunk)
documents.append(
Document(
page_content=chunk.page_content,
metadata=metadata,
)
)
return documents
def _normalized_text(self, string: str) -> str:
"""Use regular expression to replace sequences of '\n'"""
return re.sub(r"\s+", "", string)
def _merge_splits(self, splits: Iterable[str], separator: str) -> List[str]:
"""This method overrides the default implementation of TextSplitter"""
return self._merge_splits_no_seperator(splits)
def _merge_splits_no_seperator(self, splits: Iterable[str]) -> List[str]:
"""Merge splits into chunks.
If the segment size is bigger than chunk_size,
it will be left as is (won't be cut to match to chunk_size).
If the segment size is smaller than chunk_size,
it will be merged with the next segment until the chunk_size is reached.
"""
chunks = []
current_chunk = ""
for split in splits:
split_len = self._length_function(split)
if split_len > self._chunk_size:
logger.warning(
f"Split of length {split_len}"
f"exceeds chunk size {self._chunk_size}."
)
if self._length_function(current_chunk) + split_len > self._chunk_size:
if current_chunk != "":
chunks.append(current_chunk)
current_chunk = ""
current_chunk += split
if current_chunk != "":
chunks.append(current_chunk)
return chunks

File diff suppressed because it is too large.


@ -1,90 +0,0 @@
[build-system]
requires = [ "poetry-core>=1.0.0",]
build-backend = "poetry.core.masonry.api"
[tool.poetry]
name = "langchain-ai21"
version = "0.1.8"
description = "An integration package connecting AI21 and LangChain"
authors = []
readme = "README.md"
repository = "https://github.com/langchain-ai/langchain"
license = "MIT"
[tool.mypy]
disallow_untyped_defs = "True"
[tool.poetry.urls]
"Source Code" = "https://github.com/langchain-ai/langchain/tree/master/libs/partners/ai21"
"Release Notes" = "https://github.com/langchain-ai/langchain/releases?q=tag%3A%22langchain-ai21%3D%3D0%22&expanded=true"
[tool.poetry.dependencies]
python = ">=3.8.1,<4.0"
langchain-core = "^0.2.4"
langchain-text-splitters = "^0.2.0"
ai21 = "^2.14.1"
[tool.ruff.lint]
select = [ "E", "F", "I",]
[tool.coverage.run]
omit = [ "tests/*",]
[tool.pytest.ini_options]
addopts = "--snapshot-warn-unused --strict-markers --strict-config --durations=5"
markers = [ "requires: mark tests as requiring a specific library", "asyncio: mark tests as requiring asyncio", "compile: mark placeholder test used to compile integration tests without running them", "scheduled: mark tests to run in scheduled testing",]
asyncio_mode = "auto"
[tool.poetry.group.test]
optional = true
[tool.poetry.group.codespell]
optional = true
[tool.poetry.group.test_integration]
optional = true
[tool.poetry.group.lint]
optional = true
[tool.poetry.group.dev]
optional = true
[tool.poetry.group.test.dependencies]
pytest = "^7.3.0"
freezegun = "^1.2.2"
pytest-mock = "^3.10.0"
syrupy = "^4.0.2"
pytest-watcher = "^0.3.4"
pytest-asyncio = "^0.21.1"
[tool.poetry.group.codespell.dependencies]
codespell = "^2.2.0"
[tool.poetry.group.test_integration.dependencies]
[tool.poetry.group.lint.dependencies]
ruff = "^0.5"
[tool.poetry.group.typing.dependencies]
mypy = "^1.10"
[tool.poetry.group.test.dependencies.langchain-core]
path = "../../core"
develop = true
[tool.poetry.group.test.dependencies.langchain-standard-tests]
path = "../../standard-tests"
develop = true
[tool.poetry.group.test.dependencies.langchain-text-splitters]
path = "../../text-splitters"
develop = true
[tool.poetry.group.dev.dependencies.langchain-core]
path = "../../core"
develop = true
[tool.poetry.group.typing.dependencies.langchain-core]
path = "../../core"
develop = true


@ -1,17 +0,0 @@
import sys
import traceback
from importlib.machinery import SourceFileLoader
if __name__ == "__main__":
files = sys.argv[1:]
has_failure = False
for file in files:
try:
SourceFileLoader("x", file).load_module()
except Exception:
has_failure = True
print(file)
traceback.print_exc()
print()
sys.exit(1 if has_failure else 0)


@ -1,27 +0,0 @@
#!/bin/bash
#
# This script searches for lines starting with "import pydantic" or "from pydantic"
# in tracked files within a Git repository.
#
# Usage: ./scripts/check_pydantic.sh /path/to/repository
# Check if a path argument is provided
if [ $# -ne 1 ]; then
echo "Usage: $0 /path/to/repository"
exit 1
fi
repository_path="$1"
# Search for lines matching the pattern within the specified repository
result=$(git -C "$repository_path" grep -E '^import pydantic|^from pydantic')
# Check if any matching lines were found
if [ -n "$result" ]; then
echo "ERROR: The following lines need to be updated:"
echo "$result"
echo "Please replace the code with an import from langchain_core.pydantic_v1."
echo "For example, replace 'from pydantic import BaseModel'"
echo "with 'from langchain_core.pydantic_v1 import BaseModel'"
exit 1
fi


@ -1,17 +0,0 @@
#!/bin/bash
set -eu
# Initialize a variable to keep track of errors
errors=0
# make sure not importing from langchain or langchain_experimental
git --no-pager grep '^from langchain\.' . && errors=$((errors+1))
git --no-pager grep '^from langchain_experimental\.' . && errors=$((errors+1))
# Decide on an exit status based on the errors
if [ "$errors" -gt 0 ]; then
exit 1
else
exit 0
fi


@ -1,188 +0,0 @@
"""Test ChatAI21 chat model."""
import pytest
from langchain_core.messages import (
AIMessageChunk,
HumanMessage,
SystemMessage,
ToolMessage,
)
from langchain_core.outputs import ChatGeneration
from langchain_core.rate_limiters import InMemoryRateLimiter
from langchain_core.tools import tool
from langchain_core.utils.function_calling import convert_to_openai_tool
from langchain_ai21.chat_models import ChatAI21
from tests.unit_tests.conftest import (
J2_CHAT_MODEL_NAME,
JAMBA_1_5_LARGE_CHAT_MODEL_NAME,
JAMBA_1_5_MINI_CHAT_MODEL_NAME,
JAMBA_CHAT_MODEL_NAME,
JAMBA_FAMILY_MODEL_NAMES,
)
rate_limiter = InMemoryRateLimiter(requests_per_second=0.5)
@pytest.mark.parametrize(
ids=[
"when_j2_model",
"when_jamba_model",
"when_jamba1.5-mini_model",
"when_jamba1.5-large_model",
],
argnames=["model"],
argvalues=[
(J2_CHAT_MODEL_NAME,),
(JAMBA_CHAT_MODEL_NAME,),
(JAMBA_1_5_MINI_CHAT_MODEL_NAME,),
(JAMBA_1_5_LARGE_CHAT_MODEL_NAME,),
],
)
def test_invoke(model: str) -> None:
"""Test invoke tokens from AI21."""
llm = ChatAI21(model=model, rate_limiter=rate_limiter) # type: ignore[call-arg]
result = llm.invoke("I'm Pickle Rick", config=dict(tags=["foo"]))
assert isinstance(result.content, str)
@pytest.mark.parametrize(
ids=[
"when_j2_model_num_results_is_1",
"when_j2_model_num_results_is_3",
"when_jamba_model_n_is_1",
"when_jamba_model_n_is_3",
"when_jamba1.5_mini_model_n_is_1",
"when_jamba1.5_mini_model_n_is_3",
"when_jamba1.5_large_model_n_is_1",
"when_jamba1.5_large_model_n_is_3",
],
argnames=["model", "num_results"],
argvalues=[
(J2_CHAT_MODEL_NAME, 1),
(J2_CHAT_MODEL_NAME, 3),
(JAMBA_CHAT_MODEL_NAME, 1),
(JAMBA_CHAT_MODEL_NAME, 3),
(JAMBA_1_5_MINI_CHAT_MODEL_NAME, 1),
(JAMBA_1_5_MINI_CHAT_MODEL_NAME, 3),
(JAMBA_1_5_LARGE_CHAT_MODEL_NAME, 1),
(JAMBA_1_5_LARGE_CHAT_MODEL_NAME, 3),
],
)
def test_generation(model: str, num_results: int) -> None:
"""Test generation with multiple models and different result counts."""
# Determine the configuration key based on the model type
config_key = "n" if model in JAMBA_FAMILY_MODEL_NAMES else "num_results"
# Create the model instance using the appropriate key for the result count
llm = ChatAI21(model=model, rate_limiter=rate_limiter, **{config_key: num_results}) # type: ignore[arg-type, arg-type, arg-type, arg-type, arg-type, arg-type, arg-type, arg-type, arg-type, arg-type]
message = HumanMessage(content="Hello, this is a test. Can you help me please?")
result = llm.generate([[message]], config=dict(tags=["foo"]))
for generations in result.generations:
assert len(generations) == num_results
for generation in generations:
assert isinstance(generation, ChatGeneration)
assert isinstance(generation.text, str)
assert generation.text == generation.message.content
@pytest.mark.parametrize(
ids=[
"when_j2_model",
"when_jamba_model",
"when_jamba1.5_mini_model",
"when_jamba1.5_large_model",
],
argnames=["model"],
argvalues=[
(J2_CHAT_MODEL_NAME,),
(JAMBA_CHAT_MODEL_NAME,),
(JAMBA_1_5_MINI_CHAT_MODEL_NAME,),
(JAMBA_1_5_LARGE_CHAT_MODEL_NAME,),
],
)
async def test_ageneration(model: str) -> None:
"""Test invoke tokens from AI21."""
llm = ChatAI21(model=model, rate_limiter=rate_limiter) # type: ignore[call-arg]
message = HumanMessage(content="Hello")
result = await llm.agenerate([[message], [message]], config=dict(tags=["foo"]))
for generations in result.generations:
assert len(generations) == 1
for generation in generations:
assert isinstance(generation, ChatGeneration)
assert isinstance(generation.text, str)
assert generation.text == generation.message.content
def test__chat_stream() -> None:
llm = ChatAI21(model="jamba-1.5-mini") # type: ignore[call-arg]
message = HumanMessage(content="What is the meaning of life?")
for chunk in llm.stream([message]):
assert isinstance(chunk, AIMessageChunk)
assert isinstance(chunk.content, str)
def test__j2_chat_stream__should_raise_error() -> None:
llm = ChatAI21(model="j2-ultra") # type: ignore[call-arg]
message = HumanMessage(content="What is the meaning of life?")
with pytest.raises(NotImplementedError):
for _ in llm.stream([message]):
pass
@pytest.mark.parametrize(
ids=[
"when_jamba1.5_mini_model",
"when_jamba1.5_large_model",
],
argnames=["model"],
argvalues=[
(JAMBA_1_5_MINI_CHAT_MODEL_NAME,),
(JAMBA_1_5_LARGE_CHAT_MODEL_NAME,),
],
)
def test_tool_calls(model: str) -> None:
@tool
def get_weather(location: str, date: str) -> str:
"""“Provide the weather for the specified location on the given date.”"""
if location == "New York" and date == "2024-12-05":
return "25 celsius"
return "32 celsius"
llm = ChatAI21(model=model, temperature=0) # type: ignore[call-arg]
llm_with_tools = llm.bind_tools([convert_to_openai_tool(get_weather)])
chat_messages = [
SystemMessage(
content="You are a helpful assistant. "
"You can use the provided tools "
"to assist with various tasks and provide "
"accurate information"
),
HumanMessage(
content="What is the forecast for the weather "
"in New York on December 5, 2024?"
),
]
response = llm_with_tools.invoke(chat_messages)
chat_messages.append(response)
assert response.tool_calls is not None # type: ignore[attr-defined]
tool_call = response.tool_calls[0] # type: ignore[attr-defined]
assert tool_call["name"] == "get_weather"
weather = get_weather.invoke( # type: ignore[attr-defined]
{"location": tool_call["args"]["location"], "date": tool_call["args"]["date"]}
)
chat_messages.append(ToolMessage(content=weather, tool_call_id=tool_call["id"]))
llm_answer = llm_with_tools.invoke(chat_messages)
content = llm_answer.content.lower() # type: ignore[union-attr]
assert "new york" in content and "25" in content and "celsius" in content


@ -1,7 +0,0 @@
import pytest
@pytest.mark.compile
def test_placeholder() -> None:
"""Used for compiling integration tests without running any real tests."""
pass


@ -1,61 +0,0 @@
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import Runnable
from langchain_ai21.contextual_answers import (
ANSWER_NOT_IN_CONTEXT_RESPONSE,
AI21ContextualAnswers,
)
context = """
Albert Einstein (German: 14 March 1879 – 18 April 1955)
was a German-born theoretical physicist who is widely held
to be one of the greatest and most influential scientists
"""
_GOOD_QUESTION = "When did Albert Einstein born?"
_BAD_QUESTION = "What color is Yoda's light saber?"
_EXPECTED_PARTIAL_RESPONSE = "March 14, 1879"
def test_invoke__when_good_question() -> None:
llm = AI21ContextualAnswers()
response = llm.invoke(
{"context": context, "question": _GOOD_QUESTION},
config={"metadata": {"name": "I AM A TEST"}},
)
assert response != ANSWER_NOT_IN_CONTEXT_RESPONSE
def test_invoke__when_bad_question__should_return_answer_not_in_context() -> None:
llm = AI21ContextualAnswers()
response = llm.invoke(input={"context": context, "question": _BAD_QUESTION})
assert response == ANSWER_NOT_IN_CONTEXT_RESPONSE
def test_invoke__when_response_if_no_answer_passed__should_use_it() -> None:
response_if_no_answer_found = "This should be the response"
llm = AI21ContextualAnswers()
response = llm.invoke(
input={"context": context, "question": _BAD_QUESTION},
response_if_no_answer_found=response_if_no_answer_found,
)
assert response == response_if_no_answer_found
def test_invoke_when_used_in_a_simple_chain_with_no_vectorstore() -> None:
tsm = AI21ContextualAnswers()
chain: Runnable = tsm | StrOutputParser()
response = chain.invoke(
{"context": context, "question": _GOOD_QUESTION},
)
assert response != ANSWER_NOT_IN_CONTEXT_RESPONSE


@ -1,37 +0,0 @@
"""Test AI21 embeddings."""
from langchain_ai21.embeddings import AI21Embeddings
def test_langchain_ai21_embedding_documents() -> None:
"""Test AI21 embeddings."""
documents = ["foo bar"]
embedding = AI21Embeddings()
output = embedding.embed_documents(documents)
assert len(output) == 1
assert len(output[0]) > 0
def test_langchain_ai21_embedding_query() -> None:
"""Test AI21 embeddings."""
document = "foo bar"
embedding = AI21Embeddings()
output = embedding.embed_query(document)
assert len(output) > 0
def test_langchain_ai21_embedding_documents__with_explicit_chunk_size() -> None:
"""Test AI21 embeddings with chunk size passed as an argument."""
documents = ["foo", "bar"]
embedding = AI21Embeddings()
output = embedding.embed_documents(documents, batch_size=1)
assert len(output) == 2
assert len(output[0]) > 0
def test_langchain_ai21_embedding_query__with_explicit_chunk_size() -> None:
"""Test AI21 embeddings with chunk size passed as an argument."""
documents = "foo bar"
embedding = AI21Embeddings()
output = embedding.embed_query(documents, batch_size=1)
assert len(output) > 0


@ -1,104 +0,0 @@
"""Test AI21LLM llm."""
from langchain_ai21.llms import AI21LLM
_MODEL_NAME = "j2-mid"
def _generate_llm() -> AI21LLM:
"""
Create an AI21LLM instance for testing, using non-default parameters.
"""
return AI21LLM(
model=_MODEL_NAME,
max_tokens=2, # Use less tokens for a faster response
temperature=0, # for a consistent response
epoch=1,
)
def test_stream() -> None:
"""Test streaming tokens from AI21."""
llm = AI21LLM(
model=_MODEL_NAME,
)
for token in llm.stream("I'm Pickle Rick"):
assert isinstance(token, str)
async def test_abatch() -> None:
"""Test streaming tokens from AI21LLM."""
llm = AI21LLM(
model=_MODEL_NAME,
)
result = await llm.abatch(["I'm Pickle Rick", "I'm not Pickle Rick"])
for token in result:
assert isinstance(token, str)
async def test_abatch_tags() -> None:
"""Test batch tokens from AI21LLM."""
llm = AI21LLM(
model=_MODEL_NAME,
)
result = await llm.abatch(
["I'm Pickle Rick", "I'm not Pickle Rick"], config={"tags": ["foo"]}
)
for token in result:
assert isinstance(token, str)
def test_batch() -> None:
"""Test batch tokens from AI21LLM."""
llm = AI21LLM(
model=_MODEL_NAME,
)
result = llm.batch(["I'm Pickle Rick", "I'm not Pickle Rick"])
for token in result:
assert isinstance(token, str)
async def test_ainvoke() -> None:
"""Test invoke tokens from AI21LLM."""
llm = AI21LLM(
model=_MODEL_NAME,
)
result = await llm.ainvoke("I'm Pickle Rick", config={"tags": ["foo"]})
assert isinstance(result, str)
def test_invoke() -> None:
"""Test invoke tokens from AI21LLM."""
llm = AI21LLM(
model=_MODEL_NAME,
)
result = llm.invoke("I'm Pickle Rick", config=dict(tags=["foo"]))
assert isinstance(result, str)
def test__generate() -> None:
llm = _generate_llm()
llm_result = llm.generate(
prompts=["Hey there, my name is Pickle Rick. What is your name?"],
stop=["##"],
)
assert len(llm_result.generations) > 0
assert llm_result.llm_output["token_count"] != 0 # type: ignore
async def test__agenerate() -> None:
llm = _generate_llm()
llm_result = await llm.agenerate(
prompts=["Hey there, my name is Pickle Rick. What is your name?"],
stop=["##"],
)
assert len(llm_result.generations) > 0
assert llm_result.llm_output["token_count"] != 0 # type: ignore


@ -1,130 +0,0 @@
from ai21 import AI21Client
from langchain_core.documents import Document
from langchain_ai21 import AI21SemanticTextSplitter
TEXT = (
"The original full name of the franchise is Pocket Monsters (ポケットモンスター, "
"Poketto Monsutā), which was abbreviated to "
"Pokemon during development of the original games.\n"
"When the franchise was released internationally, the short form of the title was "
"used, with an acute accent (´) "
"over the e to aid in pronunciation.\n"
"Pokémon refers to both the franchise itself and the creatures within its "
"fictional universe.\n"
"As a noun, it is identical in both the singular and plural, as is every "
"individual species name;[10] it is "
'grammatically correct to say "one Pokémon" and "many Pokémon", as well '
'as "one Pikachu" and "many Pikachu".\n'
"In English, Pokémon may be pronounced either /'powkɛmon/ (poe-keh-mon) or "
"/'powkɪmon/ (poe-key-mon).\n"
"The Pokémon franchise is set in a world in which humans coexist with creatures "
"known as Pokémon.\n"
"Pokémon Red and Blue contain 151 Pokémon species, with new ones being introduced "
"in subsequent games; as of December 2023, 1,025 Pokémon species have been "
"introduced.\n[b] Most Pokémon are inspired by real-world animals;[12] for example,"
"Pikachu are a yellow mouse-like species[13] with lightning bolt-shaped tails[14] "
"that possess electrical abilities.[15]\nThe player character takes the role of a "
"Pokémon Trainer.\nThe Trainer has three primary goals: travel and explore the "
"Pokémon world; discover and catch each Pokémon species in order to complete their"
"Pokédex; and train a team of up to six Pokémon at a time and have them engage "
"in battles.\nMost Pokémon can be caught with spherical devices known as Poké "
"Balls.\nOnce the opposing Pokémon is sufficiently weakened, the Trainer throws "
"the Poké Ball against the Pokémon, which is then transformed into a form of "
"energy and transported into the device.\nOnce the catch is successful, "
"the Pokémon is tamed and is under the Trainer's command from then on.\n"
"If the Poké Ball is thrown again, the Pokémon re-materializes into its "
"original state.\nThe Trainer's Pokémon can engage in battles against opposing "
"Pokémon, including those in the wild or owned by other Trainers.\nBecause the "
"franchise is aimed at children, these battles are never presented as overtly "
"violent and contain no blood or gore.[I]\nPokémon never die in battle, instead "
"fainting upon being defeated.[20][21][22]\nAfter a Pokémon wins a battle, it "
"gains experience and becomes stronger.[23] After gaining a certain amount of "
"experience points, its level increases, as well as one or more of its "
"statistics.\nAs its level increases, the Pokémon can learn new offensive "
"and defensive moves to use in battle.[24][25] Furthermore, many species can "
"undergo a form of spontaneous metamorphosis called Pokémon evolution, and "
"transform into stronger forms.[26] Most Pokémon will evolve at a certain level, "
"while others evolve through different means, such as exposure to a certain "
"item.[27]\n"
)
def test_split_text_to_document() -> None:
segmentation = AI21SemanticTextSplitter()
segments = segmentation.split_text_to_documents(source=TEXT)
assert len(segments) > 0
for segment in segments:
assert segment.page_content is not None
assert segment.metadata is not None
def test_split_text() -> None:
segmentation = AI21SemanticTextSplitter()
segments = segmentation.split_text(source=TEXT)
assert len(segments) > 0
def test_split_text__when_chunk_size_is_large__should_merge_segments() -> None:
segmentation_no_merge = AI21SemanticTextSplitter()
segments_no_merge = segmentation_no_merge.split_text(source=TEXT)
segmentation_merge = AI21SemanticTextSplitter(chunk_size=1000)
segments_merge = segmentation_merge.split_text(source=TEXT)
# Assert that a merge did happen
assert len(segments_no_merge) > len(segments_merge)
reconstructed_text_merged = "".join(segments_merge)
reconstructed_text_non_merged = "".join(segments_no_merge)
# Assert that the merge did not change the content
assert reconstructed_text_merged == reconstructed_text_non_merged
def test_split_text__chunk_size_is_too_small__should_return_non_merged_segments() -> (
None
):
segmentation_no_merge = AI21SemanticTextSplitter()
segments_no_merge = segmentation_no_merge.split_text(source=TEXT)
segmentation_merge = AI21SemanticTextSplitter(chunk_size=10)
segments_merge = segmentation_merge.split_text(source=TEXT)
    # Assert that no merge happened
assert len(segments_no_merge) == len(segments_merge)
reconstructed_text_merged = "".join(segments_merge)
reconstructed_text_non_merged = "".join(segments_no_merge)
    # Assert that the content is unchanged
assert reconstructed_text_merged == reconstructed_text_non_merged
def test_split_text__when_chunk_size_set_with_ai21_tokenizer() -> None:
segmentation_no_merge = AI21SemanticTextSplitter(
length_function=AI21Client().count_tokens
)
segments_no_merge = segmentation_no_merge.split_text(source=TEXT)
segmentation_merge = AI21SemanticTextSplitter(
chunk_size=1000, length_function=AI21Client().count_tokens
)
segments_merge = segmentation_merge.split_text(source=TEXT)
# Assert that a merge did happen
assert len(segments_no_merge) > len(segments_merge)
reconstructed_text_merged = "".join(segments_merge)
reconstructed_text_non_merged = "".join(segments_no_merge)
# Assert that the merge did not change the content
assert reconstructed_text_merged == reconstructed_text_non_merged
def test_create_documents() -> None:
texts = [TEXT]
segmentation = AI21SemanticTextSplitter()
documents = segmentation.create_documents(texts=texts)
assert len(documents) > 0
for document in documents:
assert document.page_content is not None
assert document.metadata is not None
def test_split_documents() -> None:
documents = [Document(page_content=TEXT, metadata={"foo": "bar"})]
segmentation = AI21SemanticTextSplitter()
segments = segmentation.split_documents(documents=documents)
assert len(segments) > 0
for segment in segments:
assert segment.page_content is not None
assert segment.metadata is not None


@ -1,89 +0,0 @@
"""Standard LangChain interface tests"""
import time
from typing import Optional, Type
import pytest
from langchain_core.language_models import BaseChatModel
from langchain_core.rate_limiters import InMemoryRateLimiter
from langchain_standard_tests.integration_tests import ChatModelIntegrationTests
from langchain_ai21 import ChatAI21
rate_limiter = InMemoryRateLimiter(requests_per_second=0.5)
class BaseTestAI21(ChatModelIntegrationTests):
def teardown(self) -> None:
# avoid getting rate limited
time.sleep(1)
@property
def chat_model_class(self) -> Type[BaseChatModel]:
return ChatAI21
@pytest.mark.xfail(reason="Not implemented.")
def test_usage_metadata(self, model: BaseChatModel) -> None:
super().test_usage_metadata(model)
class TestAI21J2(BaseTestAI21):
has_tool_calling = False
@property
def chat_model_params(self) -> dict:
return {
"model": "j2-ultra",
"rate_limiter": rate_limiter,
}
@pytest.mark.xfail(reason="Streaming is not supported for Jurassic models.")
def test_stream(self, model: BaseChatModel) -> None:
super().test_stream(model)
@pytest.mark.xfail(reason="Streaming is not supported for Jurassic models.")
async def test_astream(self, model: BaseChatModel) -> None:
await super().test_astream(model)
@pytest.mark.xfail(reason="Streaming is not supported for Jurassic models.")
def test_usage_metadata_streaming(self, model: BaseChatModel) -> None:
super().test_usage_metadata_streaming(model)
class TestAI21Jamba(BaseTestAI21):
has_tool_calling = False
@property
def chat_model_params(self) -> dict:
return {
"model": "jamba-instruct-preview",
}
class TestAI21Jamba1_5(BaseTestAI21):
@property
def tool_choice_value(self) -> Optional[str]:
"""Value to use for tool choice when used in tests."""
return "any"
@property
def chat_model_params(self) -> dict:
return {
"model": "jamba-1.5-mini",
}
@pytest.mark.xfail(reason="Prompt doesn't generate tool calls for Jamba 1.5.")
def test_tool_calling(self, model: BaseChatModel) -> None:
super().test_tool_calling(model)
@pytest.mark.xfail(reason="Prompt doesn't generate tool calls for Jamba 1.5.")
def test_tool_calling_with_no_arguments(self, model: BaseChatModel) -> None:
super().test_tool_calling_with_no_arguments(model)
@pytest.mark.xfail(reason="Requires tool calling & stream - still WIP")
def test_structured_output(self, model: BaseChatModel) -> None:
super().test_structured_output(model)
@pytest.mark.xfail(reason="Requires tool calling & stream - still WIP")
def test_structured_output_pydantic_2_v1(self, model: BaseChatModel) -> None:
super().test_structured_output_pydantic_2_v1(model)


@ -1,9 +0,0 @@
import pytest
from langchain_ai21.chat.chat_adapter import ChatAdapter
from langchain_ai21.chat.chat_factory import create_chat_adapter
@pytest.fixture
def chat_adapter(model: str) -> ChatAdapter:
return create_chat_adapter(model)


@ -1,248 +0,0 @@
from typing import List
import pytest
from ai21.models import ChatMessage as J2ChatMessage
from ai21.models import RoleType
from ai21.models.chat import (
AssistantMessage,
ChatMessage,
UserMessage,
)
from ai21.models.chat import (
SystemMessage as AI21SystemMessage,
)
from ai21.models.chat import ToolMessage as AI21ToolMessage
from langchain_core.messages import (
AIMessage,
BaseMessage,
HumanMessage,
SystemMessage,
ToolMessage,
)
from langchain_core.messages import (
ChatMessage as LangChainChatMessage,
)
from langchain_ai21.chat.chat_adapter import ChatAdapter
_J2_MODEL_NAME = "j2-ultra"
_JAMBA_MODEL_NAME = "jamba-instruct-preview"
_JAMBA_1_5_MINI_MODEL_NAME = "jamba-1.5-mini"
_JAMBA_1_5_LARGE_MODEL_NAME = "jamba-1.5-large"
@pytest.mark.parametrize(
ids=[
"when_human_message_j2_model",
"when_ai_message_j2_model",
"when_human_message_jamba_model",
"when_ai_message_jamba_model",
],
argnames=["model", "message", "expected_ai21_message"],
argvalues=[
(
_J2_MODEL_NAME,
HumanMessage(content="Human Message Content"),
J2ChatMessage(role=RoleType.USER, text="Human Message Content"),
),
(
_J2_MODEL_NAME,
AIMessage(content="AI Message Content"),
J2ChatMessage(role=RoleType.ASSISTANT, text="AI Message Content"),
),
(
_JAMBA_MODEL_NAME,
HumanMessage(content="Human Message Content"),
UserMessage(role="user", content="Human Message Content"),
),
(
_JAMBA_MODEL_NAME,
AIMessage(content="AI Message Content"),
AssistantMessage(
role="assistant", content="AI Message Content", tool_calls=[]
),
),
],
)
def test_convert_message_to_ai21_message(
message: BaseMessage,
expected_ai21_message: ChatMessage,
chat_adapter: ChatAdapter,
) -> None:
ai21_message = chat_adapter._convert_message_to_ai21_message(message)
assert ai21_message == expected_ai21_message
@pytest.mark.parametrize(
ids=[
"when_system_message_j2_model",
"when_langchain_chat_message_j2_model",
],
argnames=["model", "message"],
argvalues=[
(
_J2_MODEL_NAME,
AI21SystemMessage(content="System Message Content"),
),
(
_J2_MODEL_NAME,
LangChainChatMessage(content="Chat Message Content", role="human"),
),
],
)
def test_convert_message_to_ai21_message__when_invalid_role__should_raise_exception(
message: BaseMessage,
chat_adapter: ChatAdapter,
) -> None:
with pytest.raises(ValueError) as e:
chat_adapter._convert_message_to_ai21_message(message)
assert e.value.args[0] == (
f"Could not resolve role type from message {message}. "
f"Only support {HumanMessage.__name__} and {AIMessage.__name__}."
)
@pytest.mark.parametrize(
ids=[
"when_all_messages_are_human_messages__should_return_system_none_j2_model",
"when_first_message_is_system__should_return_system_j2_model",
"when_all_messages_are_human_messages__should_return_system_none_jamba_model",
"when_first_message_is_system__should_return_system_jamba_model",
"when_tool_calling_message__should_return_tool_jamba_mini_model",
"when_tool_calling_message__should_return_tool_jamba_large_model",
],
argnames=["model", "messages", "expected_messages"],
argvalues=[
(
_J2_MODEL_NAME,
[
HumanMessage(content="Human Message Content 1"),
HumanMessage(content="Human Message Content 2"),
],
{
"system": "",
"messages": [
J2ChatMessage(
role=RoleType.USER,
text="Human Message Content 1",
),
J2ChatMessage(
role=RoleType.USER,
text="Human Message Content 2",
),
],
},
),
(
_J2_MODEL_NAME,
[
SystemMessage(content="System Message Content 1"),
HumanMessage(content="Human Message Content 1"),
],
{
"system": "System Message Content 1",
"messages": [
J2ChatMessage(
role=RoleType.USER,
text="Human Message Content 1",
),
],
},
),
(
_JAMBA_MODEL_NAME,
[
HumanMessage(content="Human Message Content 1"),
HumanMessage(content="Human Message Content 2"),
],
{
"messages": [
UserMessage(
role="user",
content="Human Message Content 1",
),
UserMessage(
role="user",
content="Human Message Content 2",
),
]
},
),
(
_JAMBA_MODEL_NAME,
[
SystemMessage(content="System Message Content 1"),
HumanMessage(content="Human Message Content 1"),
],
{
"messages": [
AI21SystemMessage(
role="system", content="System Message Content 1"
),
UserMessage(role="user", content="Human Message Content 1"),
],
},
),
(
_JAMBA_1_5_MINI_MODEL_NAME,
[
ToolMessage(
content="42",
tool_call_id="call_Jja7J89XsjrOLA5r!MEOW!SL",
)
],
{
"messages": [
AI21ToolMessage(
role="tool",
tool_call_id="call_Jja7J89XsjrOLA5r!MEOW!SL",
content="42",
),
],
},
),
(
_JAMBA_1_5_LARGE_MODEL_NAME,
[
ToolMessage(
content="42",
tool_call_id="call_Jja7J89XsjrOLA5r!MEOW!SL",
)
],
{
"messages": [
AI21ToolMessage(
role="tool",
tool_call_id="call_Jja7J89XsjrOLA5r!MEOW!SL",
content="42",
),
],
},
),
],
)
def test_convert_messages(
chat_adapter: ChatAdapter,
messages: List[BaseMessage],
expected_messages: List[ChatMessage],
) -> None:
converted_messages = chat_adapter.convert_messages(messages)
assert converted_messages == expected_messages
@pytest.mark.parametrize(
ids=[
"when_j2_model",
],
argnames=["model"],
argvalues=[
(_J2_MODEL_NAME,),
],
)
def test_convert_messages__when_system_is_not_first(chat_adapter: ChatAdapter) -> None:
messages = [
HumanMessage(content="Human Message Content 1"),
SystemMessage(content="System Message Content 1"),
]
with pytest.raises(ValueError):
chat_adapter.convert_messages(messages)


@ -1,34 +0,0 @@
from typing import Type
import pytest
from langchain_ai21.chat.chat_adapter import (
ChatAdapter,
J2ChatAdapter,
JambaChatCompletionsAdapter,
)
from langchain_ai21.chat.chat_factory import create_chat_adapter
from tests.unit_tests.conftest import J2_CHAT_MODEL_NAME, JAMBA_CHAT_MODEL_NAME
@pytest.mark.parametrize(
ids=[
"when_j2_model",
"when_jamba_model",
],
argnames=["model", "expected_chat_type"],
argvalues=[
(J2_CHAT_MODEL_NAME, J2ChatAdapter),
(JAMBA_CHAT_MODEL_NAME, JambaChatCompletionsAdapter),
],
)
def test_create_chat_adapter_with_supported_models(
model: str, expected_chat_type: Type[ChatAdapter]
) -> None:
adapter = create_chat_adapter(model)
assert isinstance(adapter, expected_chat_type)
def test_create_chat_adapter__when_model_not_supported() -> None:
with pytest.raises(ValueError):
create_chat_adapter("unsupported-model")


@ -1,207 +0,0 @@
import os
from contextlib import contextmanager
from typing import Generator
from unittest.mock import Mock
import pytest
from ai21 import AI21Client
from ai21.models import (
AnswerResponse,
ChatOutput,
ChatResponse,
Completion,
CompletionData,
CompletionFinishReason,
CompletionsResponse,
FinishReason,
Penalty,
RoleType,
SegmentationResponse,
)
from ai21.models.responses.segmentation_response import Segment
from pytest_mock import MockerFixture
J2_CHAT_MODEL_NAME = "j2-ultra"
JAMBA_CHAT_MODEL_NAME = "jamba-instruct-preview"
JAMBA_1_5_MINI_CHAT_MODEL_NAME = "jamba-1.5-mini"
JAMBA_1_5_LARGE_CHAT_MODEL_NAME = "jamba-1.5-large"
DUMMY_API_KEY = "test_api_key"
JAMBA_FAMILY_MODEL_NAMES = [
JAMBA_CHAT_MODEL_NAME,
JAMBA_1_5_MINI_CHAT_MODEL_NAME,
JAMBA_1_5_LARGE_CHAT_MODEL_NAME,
]
BASIC_EXAMPLE_LLM_PARAMETERS = {
"num_results": 3,
"max_tokens": 20,
"min_tokens": 10,
"temperature": 0.5,
"top_p": 0.5,
"top_k_return": 0,
"frequency_penalty": Penalty(scale=0.2, apply_to_numbers=True), # type: ignore[call-arg]
"presence_penalty": Penalty(scale=0.2, apply_to_stopwords=True), # type: ignore[call-arg]
"count_penalty": Penalty( # type: ignore[call-arg]
scale=0.2,
apply_to_punctuation=True,
apply_to_emojis=True,
),
}
BASIC_EXAMPLE_CHAT_PARAMETERS = {
"num_results": 3,
"max_tokens": 20,
"min_tokens": 10,
"temperature": 0.5,
"top_p": 0.5,
"top_k_return": 0,
"frequency_penalty": Penalty(scale=0.2, apply_to_numbers=True), # type: ignore[call-arg]
"presence_penalty": Penalty(scale=0.2, apply_to_stopwords=True), # type: ignore[call-arg]
"count_penalty": Penalty( # type: ignore[call-arg]
scale=0.2,
apply_to_punctuation=True,
apply_to_emojis=True,
),
"n": 3,
}
SEGMENTS = [
Segment( # type: ignore[call-arg]
segment_type="normal_text",
segment_text=(
"The original full name of the franchise is Pocket Monsters "
"(ポケットモンスター, Poketto Monsutā), which was abbreviated to "
"Pokemon during development of the original games.\n\nWhen the "
"franchise was released internationally, the short form of the "
"title was used, with an acute accent (´) over the e to aid "
"in pronunciation."
),
),
Segment( # type: ignore[call-arg]
segment_type="normal_text",
segment_text=(
"Pokémon refers to both the franchise itself and the creatures "
"within its fictional universe.\n\nAs a noun, it is identical in "
"both the singular and plural, as is every individual species "
'name;[10] it is grammatically correct to say "one Pokémon" '
'and "many Pokémon", as well as "one Pikachu" and "many '
'Pikachu".\n\nIn English, Pokémon may be pronounced either '
"/'powkɛmon/ (poe-keh-mon) or /'powkɪmon/ (poe-key-mon)."
),
),
]
BASIC_EXAMPLE_LLM_PARAMETERS_AS_DICT = {
"num_results": 3,
"max_tokens": 20,
"min_tokens": 10,
"temperature": 0.5,
"top_p": 0.5,
"top_k_return": 0,
"frequency_penalty": Penalty(scale=0.2, apply_to_numbers=True).to_dict(), # type: ignore[call-arg]
"presence_penalty": Penalty(scale=0.2, apply_to_stopwords=True).to_dict(), # type: ignore[call-arg]
"count_penalty": Penalty( # type: ignore[call-arg]
scale=0.2,
apply_to_punctuation=True,
apply_to_emojis=True,
).to_dict(),
}
BASIC_EXAMPLE_CHAT_PARAMETERS_AS_DICT = {
"num_results": 3,
"max_tokens": 20,
"min_tokens": 10,
"temperature": 0.5,
"top_p": 0.5,
"top_k_return": 0,
"frequency_penalty": Penalty(scale=0.2, apply_to_numbers=True).to_dict(), # type: ignore[call-arg]
"presence_penalty": Penalty(scale=0.2, apply_to_stopwords=True).to_dict(), # type: ignore[call-arg]
"count_penalty": Penalty( # type: ignore[call-arg]
scale=0.2,
apply_to_punctuation=True,
apply_to_emojis=True,
).to_dict(),
"n": 3,
}
@pytest.fixture
def mocked_completion_response(mocker: MockerFixture) -> Mock:
mocked_response = mocker.MagicMock(spec=CompletionsResponse)
mocked_response.prompt = "this is a test prompt"
mocked_response.completions = [
Completion( # type: ignore[call-arg]
data=CompletionData(text="test", tokens=[]),
finish_reason=CompletionFinishReason(reason=None, length=None),
)
]
return mocked_response
@pytest.fixture
def mock_client_with_completion(
mocker: MockerFixture, mocked_completion_response: Mock
) -> Mock:
mock_client = mocker.MagicMock(spec=AI21Client)
mock_client.completion = mocker.MagicMock()
mock_client.completion.create.side_effect = [
mocked_completion_response,
mocked_completion_response,
]
mock_client.count_tokens.side_effect = [10, 20]
return mock_client
@pytest.fixture
def mock_client_with_chat(mocker: MockerFixture) -> Mock:
mock_client = mocker.MagicMock(spec=AI21Client)
mock_client.chat = mocker.MagicMock()
output = ChatOutput( # type: ignore[call-arg]
text="Hello Pickle Rick!",
role=RoleType.ASSISTANT,
finish_reason=FinishReason(reason="testing"),
)
mock_client.chat.create.return_value = ChatResponse(outputs=[output])
return mock_client
@contextmanager
def temporarily_unset_api_key() -> Generator:
"""
    Temporarily unset the AI21_API_KEY environment variable, restoring it afterwards, for tests that require the key to be absent.
"""
api_key = os.environ.pop("AI21_API_KEY", None)
yield
if api_key is not None:
os.environ["AI21_API_KEY"] = api_key
@pytest.fixture
def mock_client_with_contextual_answers(mocker: MockerFixture) -> Mock:
mock_client = mocker.MagicMock(spec=AI21Client)
mock_client.answer = mocker.MagicMock()
mock_client.answer.create.return_value = AnswerResponse( # type: ignore[call-arg]
id="some_id",
answer="some answer",
answer_in_context=False,
)
return mock_client
@pytest.fixture
def mock_client_with_semantic_text_splitter(mocker: MockerFixture) -> Mock:
mock_client = mocker.MagicMock(spec=AI21Client)
mock_client.segmentation = mocker.MagicMock()
mock_client.segmentation.create.return_value = SegmentationResponse(
id="12345",
segments=SEGMENTS,
)
return mock_client


@ -1,174 +0,0 @@
"""Test chat model integration."""
from typing import cast
from unittest.mock import Mock, call
import pytest
from ai21 import MissingApiKeyError
from ai21.models import ChatMessage, Penalty, RoleType
from langchain_core.messages import (
AIMessage,
HumanMessage,
SystemMessage,
)
from langchain_core.pydantic_v1 import SecretStr
from pytest import CaptureFixture, MonkeyPatch
from langchain_ai21.chat_models import (
ChatAI21,
)
from tests.unit_tests.conftest import (
BASIC_EXAMPLE_CHAT_PARAMETERS,
BASIC_EXAMPLE_CHAT_PARAMETERS_AS_DICT,
DUMMY_API_KEY,
temporarily_unset_api_key,
)
def test_initialization__when_no_api_key__should_raise_exception() -> None:
"""Test integration initialization."""
with temporarily_unset_api_key():
with pytest.raises(MissingApiKeyError):
ChatAI21(model="j2-ultra") # type: ignore[call-arg]
def test_initialization__when_default_parameters_in_init() -> None:
"""Test chat model initialization."""
ChatAI21(api_key=DUMMY_API_KEY, model="j2-ultra") # type: ignore[call-arg, arg-type]
def test_initialization__when_custom_parameters_in_init() -> None:
model = "j2-ultra"
num_results = 1
max_tokens = 10
min_tokens = 20
temperature = 0.1
top_p = 0.1
top_k_return = 0
frequency_penalty = Penalty(scale=0.2, apply_to_numbers=True) # type: ignore[call-arg]
presence_penalty = Penalty(scale=0.2, apply_to_stopwords=True) # type: ignore[call-arg]
count_penalty = Penalty(scale=0.2, apply_to_punctuation=True, apply_to_emojis=True) # type: ignore[call-arg]
llm = ChatAI21( # type: ignore[call-arg]
api_key=DUMMY_API_KEY, # type: ignore[arg-type]
model=model,
num_results=num_results,
max_tokens=max_tokens,
min_tokens=min_tokens,
temperature=temperature,
top_p=top_p,
top_k_return=top_k_return,
frequency_penalty=frequency_penalty,
presence_penalty=presence_penalty,
count_penalty=count_penalty,
)
assert llm.model == model
assert llm.num_results == num_results
assert llm.max_tokens == max_tokens
assert llm.min_tokens == min_tokens
assert llm.temperature == temperature
assert llm.top_p == top_p
assert llm.top_k_return == top_k_return
assert llm.frequency_penalty == frequency_penalty
assert llm.presence_penalty == presence_penalty
    assert llm.count_penalty == count_penalty
def test_invoke(mock_client_with_chat: Mock) -> None:
chat_input = "I'm Pickle Rick"
llm = ChatAI21(
model="j2-ultra",
api_key=DUMMY_API_KEY, # type: ignore[arg-type]
client=mock_client_with_chat,
        **BASIC_EXAMPLE_CHAT_PARAMETERS,  # type: ignore[arg-type]
)
llm.invoke(input=chat_input, config=dict(tags=["foo"]), stop=["\n"])
mock_client_with_chat.chat.create.assert_called_once_with(
model="j2-ultra",
messages=[ChatMessage(role=RoleType.USER, text=chat_input)],
system="",
stop_sequences=["\n"],
**BASIC_EXAMPLE_CHAT_PARAMETERS_AS_DICT,
)
def test_generate(mock_client_with_chat: Mock) -> None:
messages0 = [
HumanMessage(content="I'm Pickle Rick"),
AIMessage(content="Hello Pickle Rick! I am your AI Assistant"),
HumanMessage(content="Nice to meet you."),
]
messages1 = [
SystemMessage(content="system message"),
HumanMessage(content="What is 1 + 1"),
]
llm = ChatAI21(
model="j2-ultra",
client=mock_client_with_chat,
        **BASIC_EXAMPLE_CHAT_PARAMETERS,  # type: ignore[arg-type]
)
llm.generate(messages=[messages0, messages1])
mock_client_with_chat.chat.create.assert_has_calls(
[
call(
model="j2-ultra",
messages=[
ChatMessage(
role=RoleType.USER,
text=str(messages0[0].content),
),
ChatMessage(
role=RoleType.ASSISTANT, text=str(messages0[1].content)
),
ChatMessage(role=RoleType.USER, text=str(messages0[2].content)),
],
system="",
**BASIC_EXAMPLE_CHAT_PARAMETERS_AS_DICT,
),
call(
model="j2-ultra",
messages=[
ChatMessage(role=RoleType.USER, text=str(messages1[1].content)),
],
system="system message",
**BASIC_EXAMPLE_CHAT_PARAMETERS_AS_DICT,
),
]
)
def test_api_key_is_secret_string() -> None:
llm = ChatAI21(model="j2-ultra", api_key="secret-api-key") # type: ignore[call-arg, arg-type]
assert isinstance(llm.api_key, SecretStr)
def test_api_key_masked_when_passed_from_env(
monkeypatch: MonkeyPatch, capsys: CaptureFixture
) -> None:
"""Test initialization with an API key provided via an env variable"""
monkeypatch.setenv("AI21_API_KEY", "secret-api-key")
llm = ChatAI21(model="j2-ultra") # type: ignore[call-arg]
print(llm.api_key, end="")
captured = capsys.readouterr()
assert captured.out == "**********"
def test_api_key_masked_when_passed_via_constructor(
capsys: CaptureFixture,
) -> None:
"""Test initialization with an API key provided via the initializer"""
llm = ChatAI21(model="j2-ultra", api_key="secret-api-key") # type: ignore[call-arg, arg-type]
print(llm.api_key, end="")
captured = capsys.readouterr()
assert captured.out == "**********"
def test_uses_actual_secret_value_from_secretstr() -> None:
"""Test that actual secret is retrieved using `.get_secret_value()`."""
llm = ChatAI21(model="j2-ultra", api_key="secret-api-key") # type: ignore[call-arg, arg-type]
assert cast(SecretStr, llm.api_key).get_secret_value() == "secret-api-key"


@ -1,109 +0,0 @@
from unittest.mock import Mock
import pytest
from langchain_core.documents import Document
from langchain_ai21 import AI21ContextualAnswers
from langchain_ai21.contextual_answers import ContextualAnswerInput
from tests.unit_tests.conftest import DUMMY_API_KEY
@pytest.mark.parametrize(
ids=[
"when_no_context__should_raise_exception",
"when_no_question__should_raise_exception",
"when_question_is_an_empty_string__should_raise_exception",
"when_context_is_an_empty_string__should_raise_exception",
"when_context_is_an_empty_list",
],
argnames="input",
argvalues=[
({"question": "What is the capital of France?"}),
({"context": "Paris is the capital of France"}),
({"question": "", "context": "Paris is the capital of France"}),
({"context": "", "question": "some question?"}),
({"context": [], "question": "What is the capital of France?"}),
],
)
def test_invoke__on_bad_input(
input: ContextualAnswerInput,
mock_client_with_contextual_answers: Mock,
) -> None:
tsm = AI21ContextualAnswers(
api_key=DUMMY_API_KEY, # type: ignore[arg-type]
client=mock_client_with_contextual_answers, # type: ignore[arg-type]
)
with pytest.raises(ValueError) as error:
tsm.invoke(input)
assert (
error.value.args[0]
== f"Input must contain a 'context' and 'question' fields. Got {input}"
)
@pytest.mark.parametrize(
ids=[
"when_context_is_not_str_or_list_of_docs_or_str",
],
argnames="input",
argvalues=[
({"context": 1242, "question": "What is the capital of France?"}),
],
)
def test_invoke__on_context_bad_input(
input: ContextualAnswerInput, mock_client_with_contextual_answers: Mock
) -> None:
tsm = AI21ContextualAnswers(
api_key=DUMMY_API_KEY, # type: ignore[arg-type]
client=mock_client_with_contextual_answers,
)
with pytest.raises(ValueError) as error:
tsm.invoke(input)
assert (
error.value.args[0] == f"Expected input to be a list of strings or Documents."
f" Received {type(input)}"
)
@pytest.mark.parametrize(
ids=[
"when_context_is_a_list_of_strings",
"when_context_is_a_list_of_documents",
"when_context_is_a_string",
],
argnames="input",
argvalues=[
(
{
"context": ["Paris is the capital of france"],
"question": "What is the capital of France?",
}
),
(
{
"context": [Document(page_content="Paris is the capital of france")],
"question": "What is the capital of France?",
}
),
(
{
"context": "Paris is the capital of france",
"question": "What is the capital of France?",
}
),
],
)
def test_invoke__on_good_input(
input: ContextualAnswerInput, mock_client_with_contextual_answers: Mock
) -> None:
tsm = AI21ContextualAnswers(
api_key=DUMMY_API_KEY, # type: ignore[arg-type]
client=mock_client_with_contextual_answers,
)
response = tsm.invoke(input)
assert isinstance(response, str)


@ -1,102 +0,0 @@
"""Test embedding model integration."""
from typing import List
from unittest.mock import Mock
import pytest
from ai21 import AI21Client, MissingApiKeyError
from ai21.models import EmbedResponse, EmbedResult, EmbedType
from pytest_mock import MockerFixture
from langchain_ai21.embeddings import AI21Embeddings
from tests.unit_tests.conftest import DUMMY_API_KEY, temporarily_unset_api_key
_EXAMPLE_EMBEDDING_0 = [1.0, 2.0, 3.0]
_EXAMPLE_EMBEDDING_1 = [4.0, 5.0, 6.0]
_EXAMPLE_EMBEDDING_2 = [7.0, 8.0, 9.0]
_EXAMPLE_EMBEDDING_RESPONSE = EmbedResponse(
results=[
EmbedResult(embedding=_EXAMPLE_EMBEDDING_0),
EmbedResult(embedding=_EXAMPLE_EMBEDDING_1),
EmbedResult(embedding=_EXAMPLE_EMBEDDING_2),
],
id="test_id",
)
def test_initialization__when_no_api_key__should_raise_exception() -> None:
"""Test integration initialization."""
with temporarily_unset_api_key():
with pytest.raises(MissingApiKeyError):
AI21Embeddings()
@pytest.fixture
def mock_client_with_embeddings(mocker: MockerFixture) -> Mock:
mock_client = mocker.MagicMock(spec=AI21Client)
mock_client.embed = mocker.MagicMock()
mock_client.embed.create.return_value = _EXAMPLE_EMBEDDING_RESPONSE
return mock_client
def test_embed_query(mock_client_with_embeddings: Mock) -> None:
llm = AI21Embeddings(client=mock_client_with_embeddings, api_key=DUMMY_API_KEY) # type: ignore[arg-type]
text = "Hello embeddings world!"
response = llm.embed_query(text=text)
assert response == _EXAMPLE_EMBEDDING_0
mock_client_with_embeddings.embed.create.assert_called_once_with(
texts=[text],
type=EmbedType.QUERY,
)
def test_embed_documents(mock_client_with_embeddings: Mock) -> None:
llm = AI21Embeddings(client=mock_client_with_embeddings, api_key=DUMMY_API_KEY) # type: ignore[arg-type]
texts = ["Hello embeddings world!", "Some other text", "Some more text"]
response = llm.embed_documents(texts=texts)
assert response == [
_EXAMPLE_EMBEDDING_0,
_EXAMPLE_EMBEDDING_1,
_EXAMPLE_EMBEDDING_2,
]
mock_client_with_embeddings.embed.create.assert_called_once_with(
texts=texts,
type=EmbedType.SEGMENT,
)
@pytest.mark.parametrize(
ids=[
"empty_texts",
"chunk_size_greater_than_texts_length",
"chunk_size_equal_to_texts_length",
"chunk_size_less_than_texts_length",
"chunk_size_one_with_multiple_texts",
"chunk_size_greater_than_texts_length",
],
argnames=["texts", "chunk_size", "expected_internal_embeddings_calls"],
argvalues=[
([], 3, 0),
(["text1", "text2", "text3"], 5, 1),
(["text1", "text2", "text3"], 3, 1),
(["text1", "text2", "text3", "text4", "text5"], 2, 3),
(["text1", "text2", "text3"], 1, 3),
(["text1", "text2", "text3"], 10, 1),
],
)
def test_get_len_safe_embeddings(
mock_client_with_embeddings: Mock,
texts: List[str],
chunk_size: int,
expected_internal_embeddings_calls: int,
) -> None:
llm = AI21Embeddings(client=mock_client_with_embeddings, api_key=DUMMY_API_KEY) # type: ignore[arg-type]
llm.embed_documents(texts=texts, batch_size=chunk_size)
assert (
mock_client_with_embeddings.embed.create.call_count
== expected_internal_embeddings_calls
)


@ -1,13 +0,0 @@
from langchain_ai21 import __all__
EXPECTED_ALL = [
"AI21LLM",
"ChatAI21",
"AI21Embeddings",
"AI21ContextualAnswers",
"AI21SemanticTextSplitter",
]
def test_all_imports() -> None:
assert sorted(EXPECTED_ALL) == sorted(__all__)


@ -1,146 +0,0 @@
"""Test AI21 Chat API wrapper."""
from typing import cast
from unittest.mock import Mock, call
import pytest
from ai21 import MissingApiKeyError
from ai21.models import (
Penalty,
)
from langchain_core.pydantic_v1 import SecretStr
from pytest import CaptureFixture, MonkeyPatch
from langchain_ai21 import AI21LLM
from tests.unit_tests.conftest import (
BASIC_EXAMPLE_LLM_PARAMETERS,
BASIC_EXAMPLE_LLM_PARAMETERS_AS_DICT,
DUMMY_API_KEY,
temporarily_unset_api_key,
)
def test_initialization__when_no_api_key__should_raise_exception() -> None:
"""Test integration initialization."""
with temporarily_unset_api_key():
with pytest.raises(MissingApiKeyError):
AI21LLM(
model="j2-ultra",
)
def test_initialization__when_default_parameters() -> None:
"""Test integration initialization."""
AI21LLM(
api_key=DUMMY_API_KEY, # type: ignore[arg-type]
model="j2-ultra",
)
def test_initialization__when_custom_parameters_to_init() -> None:
"""Test integration initialization."""
AI21LLM( # type: ignore[call-arg]
api_key=DUMMY_API_KEY, # type: ignore[arg-type]
model="j2-mid",
num_results=2,
max_tokens=20,
min_tokens=10,
temperature=0.5,
top_p=0.5,
top_k_return=0,
stop_sequences=["\n"],
frequency_penalty=Penalty(scale=0.2, apply_to_numbers=True), # type: ignore[call-arg]
presence_penalty=Penalty(scale=0.2, apply_to_stopwords=True), # type: ignore[call-arg]
count_penalty=Penalty( # type: ignore[call-arg]
scale=0.2, apply_to_punctuation=True, apply_to_emojis=True
),
custom_model="test_model",
epoch=1,
)
def test_generate(mock_client_with_completion: Mock) -> None:
# Setup test
prompt0 = "Hi, my name is what?"
prompt1 = "My name is who?"
stop = ["\n"]
custom_model = "test_model"
epoch = 1
ai21 = AI21LLM(
model="j2-ultra",
api_key=DUMMY_API_KEY, # type: ignore[arg-type]
client=mock_client_with_completion,
custom_model=custom_model,
epoch=epoch,
        **BASIC_EXAMPLE_LLM_PARAMETERS,  # type: ignore[arg-type]
)
# Make call to testing function
ai21.generate(
[prompt0, prompt1],
stop=stop,
)
# Assertions
mock_client_with_completion.count_tokens.assert_has_calls(
[
call(prompt0),
call(prompt1),
],
)
mock_client_with_completion.completion.create.assert_has_calls(
[
call(
prompt=prompt0,
model="j2-ultra",
custom_model=custom_model,
stop_sequences=stop,
epoch=epoch,
**BASIC_EXAMPLE_LLM_PARAMETERS_AS_DICT,
),
call(
prompt=prompt1,
model="j2-ultra",
custom_model=custom_model,
stop_sequences=stop,
epoch=epoch,
**BASIC_EXAMPLE_LLM_PARAMETERS_AS_DICT,
),
]
)
def test_api_key_is_secret_string() -> None:
llm = AI21LLM(model="j2-ultra", api_key="secret-api-key") # type: ignore[arg-type]
assert isinstance(llm.api_key, SecretStr)
def test_api_key_masked_when_passed_from_env(
monkeypatch: MonkeyPatch, capsys: CaptureFixture
) -> None:
"""Test initialization with an API key provided via an env variable"""
monkeypatch.setenv("AI21_API_KEY", "secret-api-key")
llm = AI21LLM(model="j2-ultra")
print(llm.api_key, end="")
captured = capsys.readouterr()
assert captured.out == "**********"
def test_api_key_masked_when_passed_via_constructor(
capsys: CaptureFixture,
) -> None:
"""Test initialization with an API key provided via the initializer"""
llm = AI21LLM(model="j2-ultra", api_key="secret-api-key") # type: ignore[arg-type]
print(llm.api_key, end="")
captured = capsys.readouterr()
assert captured.out == "**********"
def test_uses_actual_secret_value_from_secretstr() -> None:
"""Test that actual secret is retrieved using `.get_secret_value()`."""
llm = AI21LLM(model="j2-ultra", api_key="secret-api-key") # type: ignore[arg-type]
assert cast(SecretStr, llm.api_key).get_secret_value() == "secret-api-key"


@ -1,129 +0,0 @@
from unittest.mock import Mock
import pytest
from langchain_ai21 import AI21SemanticTextSplitter
from tests.unit_tests.conftest import SEGMENTS
TEXT = (
"The original full name of the franchise is Pocket Monsters (ポケットモンスター, "
"Poketto Monsutā), which was abbreviated to "
"Pokemon during development of the original games.\n"
"When the franchise was released internationally, the short form of the title was "
"used, with an acute accent (´) "
"over the e to aid in pronunciation.\n"
"Pokémon refers to both the franchise itself and the creatures within its "
"fictional universe.\n"
"As a noun, it is identical in both the singular and plural, as is every "
"individual species name;[10] it is "
'grammatically correct to say "one Pokémon" and "many Pokémon", as well '
'as "one Pikachu" and "many Pikachu".\n'
"In English, Pokémon may be pronounced either /'powkɛmon/ (poe-keh-mon) or "
"/'powkɪmon/ (poe-key-mon).\n"
"The Pokémon franchise is set in a world in which humans coexist with creatures "
"known as Pokémon.\n"
"Pokémon Red and Blue contain 151 Pokémon species, with new ones being introduced "
"in subsequent games; as of December 2023, 1,025 Pokémon species have been "
"introduced.\n[b] Most Pokémon are inspired by real-world animals;[12] for example,"
"Pikachu are a yellow mouse-like species[13] with lightning bolt-shaped tails[14] "
"that possess electrical abilities.[15]"
)
@pytest.mark.parametrize(
ids=[
"when_chunk_size_is_zero",
"when_chunk_size_is_large",
"when_chunk_size_is_small",
],
argnames=["chunk_size", "expected_segmentation_len"],
argvalues=[
(0, 2),
(1000, 1),
(10, 2),
],
)
def test_split_text__on_chunk_size(
chunk_size: int,
expected_segmentation_len: int,
mock_client_with_semantic_text_splitter: Mock,
) -> None:
sts = AI21SemanticTextSplitter(
chunk_size=chunk_size,
client=mock_client_with_semantic_text_splitter,
)
segments = sts.split_text("This is a test")
assert len(segments) == expected_segmentation_len
def test_split_text__on_large_chunk_size__should_merge_chunks(
mock_client_with_semantic_text_splitter: Mock,
) -> None:
sts_no_merge = AI21SemanticTextSplitter(
client=mock_client_with_semantic_text_splitter
)
sts_merge = AI21SemanticTextSplitter(
client=mock_client_with_semantic_text_splitter,
chunk_size=1000,
)
segments_no_merge = sts_no_merge.split_text("This is a test")
segments_merge = sts_merge.split_text("This is a test")
assert len(segments_merge) > 0
assert len(segments_no_merge) > 0
assert len(segments_no_merge) > len(segments_merge)
def test_split_text__on_small_chunk_size__should_not_merge_chunks(
mock_client_with_semantic_text_splitter: Mock,
) -> None:
sts_no_merge = AI21SemanticTextSplitter(
client=mock_client_with_semantic_text_splitter
)
segments = sts_no_merge.split_text("This is a test")
assert len(segments) == 2
for index in range(2):
assert segments[index] == SEGMENTS[index].segment_text
def test_create_documents__on_start_index__should_add_start_index(
mock_client_with_semantic_text_splitter: Mock,
) -> None:
sts = AI21SemanticTextSplitter(
client=mock_client_with_semantic_text_splitter,
add_start_index=True,
)
response = sts.create_documents(texts=[TEXT])
assert len(response) > 0
for segment in response:
assert segment.page_content is not None
assert segment.metadata is not None
assert "start_index" in segment.metadata
assert segment.metadata["start_index"] > -1
def test_create_documents__when_metadata_from_user__should_add_metadata(
mock_client_with_semantic_text_splitter: Mock,
) -> None:
sts = AI21SemanticTextSplitter(client=mock_client_with_semantic_text_splitter)
metadatas = [{"hello": "world"}]
response = sts.create_documents(texts=[TEXT], metadatas=metadatas)
assert len(response) > 0
for index in range(len(response)):
assert response[index].page_content == SEGMENTS[index].segment_text
assert len(response[index].metadata) == 2
assert response[index].metadata["source_type"] == SEGMENTS[index].segment_type
assert response[index].metadata["hello"] == "world"
def test_split_text_to_documents__when_metadata_not_passed__should_contain_source_type(
mock_client_with_semantic_text_splitter: Mock,
) -> None:
sts = AI21SemanticTextSplitter(client=mock_client_with_semantic_text_splitter)
response = sts.split_text_to_documents(TEXT)
assert len(response) > 0
for segment in response:
assert segment.page_content is not None
assert segment.metadata is not None
assert "source_type" in segment.metadata
assert segment.metadata["source_type"] is not None


@ -1,36 +0,0 @@
"""Standard LangChain interface tests"""
from typing import Type
from langchain_core.language_models import BaseChatModel
from langchain_standard_tests.unit_tests import ( # type: ignore[import-not-found]
ChatModelUnitTests, # type: ignore[import-not-found]
)
from langchain_ai21 import ChatAI21
class TestAI21J2(ChatModelUnitTests):
@property
def chat_model_class(self) -> Type[BaseChatModel]:
return ChatAI21
@property
def chat_model_params(self) -> dict:
return {
"model": "j2-ultra",
"api_key": "test_api_key",
}
class TestAI21Jamba(ChatModelUnitTests):
@property
def chat_model_class(self) -> Type[BaseChatModel]:
return ChatAI21
@property
def chat_model_params(self) -> dict:
return {
"model": "jamba-instruct",
"api_key": "test_api_key",
}


@ -1,29 +0,0 @@
from typing import List
import pytest
from langchain_ai21.embeddings import _split_texts_into_batches
@pytest.mark.parametrize(
ids=[
"when_chunk_size_is_2__should_return_3_chunks",
"when_texts_is_empty__should_return_empty_list",
"when_chunk_size_is_1__should_return_10_chunks",
],
argnames=["input_texts", "chunk_size", "expected_output"],
argvalues=[
(["a", "b", "c", "d", "e"], 2, [["a", "b"], ["c", "d"], ["e"]]),
([], 3, []),
(
["1", "2", "3", "4", "5", "6", "7", "8", "9", "10"],
1,
[["1"], ["2"], ["3"], ["4"], ["5"], ["6"], ["7"], ["8"], ["9"], ["10"]],
),
],
)
def test_chunked_text_generator(
input_texts: List[str], chunk_size: int, expected_output: List[List[str]]
) -> None:
result = list(_split_texts_into_batches(input_texts, chunk_size))
assert result == expected_output


@ -1 +0,0 @@
__pycache__


@ -1,3 +0,0 @@
This package has moved!
https://github.com/langchain-ai/langchain-upstage/tree/main/libs/upstage