Mirror of https://github.com/hwchase17/langchain.git (synced 2025-06-28 09:28:48 +00:00)

ai21: migrate to external repo (#25827)

parent 095b712a26
commit 8fb594fd2a
@@ -1,21 +0,0 @@
MIT License

Copyright (c) 2023 LangChain, Inc.

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
@@ -1,55 +0,0 @@
.PHONY: all format lint test tests integration_tests docker_tests help extended_tests

# Default target executed when no arguments are given to make.
all: help

# Define a variable for the test file path.
TEST_FILE ?= tests/unit_tests/
integration_test integration_tests: TEST_FILE = tests/integration_tests/
test tests integration_test integration_tests:
	poetry run pytest $(TEST_FILE)


######################
# LINTING AND FORMATTING
######################

# Define a variable for Python and notebook files.
PYTHON_FILES=.
MYPY_CACHE=.mypy_cache
lint format: PYTHON_FILES=.
lint_diff format_diff: PYTHON_FILES=$(shell git diff --relative=libs/partners/ai21 --name-only --diff-filter=d master | grep -E '\.py$$|\.ipynb$$')
lint_package: PYTHON_FILES=langchain_ai21
lint_tests: PYTHON_FILES=tests
lint_tests: MYPY_CACHE=.mypy_cache_test

lint lint_diff lint_package lint_tests:
	[ "$(PYTHON_FILES)" = "" ] || poetry run ruff check $(PYTHON_FILES)
	[ "$(PYTHON_FILES)" = "" ] || poetry run ruff format $(PYTHON_FILES) --diff
	[ "$(PYTHON_FILES)" = "" ] || mkdir -p $(MYPY_CACHE) && poetry run mypy $(PYTHON_FILES) --cache-dir $(MYPY_CACHE)

format format_diff:
	[ "$(PYTHON_FILES)" = "" ] || poetry run ruff format $(PYTHON_FILES)
	[ "$(PYTHON_FILES)" = "" ] || poetry run ruff check --select I --fix $(PYTHON_FILES)

spell_check:
	poetry run codespell --toml pyproject.toml

spell_fix:
	poetry run codespell --toml pyproject.toml -w

check_imports: $(shell find langchain_ai21 -name '*.py')
	poetry run python ./scripts/check_imports.py $^

######################
# HELP
######################

help:
	@echo '----'
	@echo 'check_imports                - check imports'
	@echo 'format                       - run code formatters'
	@echo 'lint                         - run linters'
	@echo 'test                         - run unit tests'
	@echo 'tests                        - run unit tests'
	@echo 'test TEST_FILE=<test_file>   - run all tests in file'
@@ -1,217 +1,3 @@
# langchain-ai21

This package has moved!

This package contains the LangChain integrations for [AI21](https://docs.ai21.com/) models and tools.

## Installation and Setup

- Install the AI21 partner package
```bash
pip install langchain-ai21
```
- Get an AI21 API key and set it as an environment variable (`AI21_API_KEY`)


## Chat Models

This package contains the `ChatAI21` class, which is the recommended way to interface with AI21 chat models, including Jamba-Instruct
and any Jurassic chat models.

To use, install the requirements and configure your environment.

```bash
export AI21_API_KEY=your-api-key
```

Then initialize:

```python
from langchain_core.messages import HumanMessage
from langchain_ai21.chat_models import ChatAI21

chat = ChatAI21(model="jamba-instruct")
messages = [HumanMessage(content="Hello from AI21")]
chat.invoke(messages)
```

For a list of the supported models, see [this page](https://docs.ai21.com/reference/python-sdk#chat)

### Streaming in Chat

Streaming is supported by the latest models. To use streaming, set the `streaming` parameter to `True` when initializing the model.

```python
from langchain_core.messages import HumanMessage
from langchain_ai21.chat_models import ChatAI21

chat = ChatAI21(model="jamba-instruct", streaming=True)
messages = [HumanMessage(content="Hello from AI21")]

response = chat.invoke(messages)
```

Or use the `stream` method directly:

```python
from langchain_core.messages import HumanMessage
from langchain_ai21.chat_models import ChatAI21

chat = ChatAI21(model="jamba-instruct")
messages = [HumanMessage(content="Hello from AI21")]

for chunk in chat.stream(messages):
    print(chunk)
```

## LLMs

You can use AI21's Jurassic generative AI models as LangChain LLMs.
To use the newer Jamba model, use the [ChatAI21 chat model](#chat-models), which
supports single-turn instruction/question answering capabilities.

```python
from langchain_core.prompts import PromptTemplate
from langchain_ai21 import AI21LLM

llm = AI21LLM(model="j2-ultra")

template = """Question: {question}

Answer: Let's think step by step."""
prompt = PromptTemplate.from_template(template)

chain = prompt | llm

question = "Which scientist discovered relativity?"
print(chain.invoke({"question": question}))
```

## Embeddings

You can use AI21's [embeddings model](https://docs.ai21.com/reference/embeddings-ref) as shown here:

### Query

```python
from langchain_ai21 import AI21Embeddings

embeddings = AI21Embeddings()
embeddings.embed_query("Hello! This is some query")
```

### Document

```python
from langchain_ai21 import AI21Embeddings

embeddings = AI21Embeddings()
embeddings.embed_documents(["Hello! This is document 1", "And this is document 2!"])
```

## Task-Specific Models

### Contextual Answers

You can use AI21's [contextual answers model](https://docs.ai21.com/reference/contextual-answers-ref) to parse
given text and answer a question based entirely on the provided information.

This means that if the answer to your question is not in the document,
the model will indicate it (instead of providing a false answer).

```python
from langchain_ai21 import AI21ContextualAnswers

tsm = AI21ContextualAnswers()

response = tsm.invoke(input={"context": "Lots of information here", "question": "Your question about the context"})
```

You can also use it with chains, output parsers, and vector DBs:

```python
from langchain_ai21 import AI21ContextualAnswers
from langchain_core.output_parsers import StrOutputParser

tsm = AI21ContextualAnswers()
chain = tsm | StrOutputParser()

response = chain.invoke(
    {"context": "Your context", "question": "Your question"},
)
```

## Text Splitters

### Semantic Text Splitter

You can use AI21's semantic [text segmentation model](https://docs.ai21.com/reference/text-segmentation-ref) to split a text into segments by topic.
Text is split at each point where the topic changes.

For a list of examples, see [this page](https://github.com/langchain-ai/langchain/blob/master/docs/docs/modules/data_connection/document_transformers/semantic_text_splitter.ipynb).

```python
from langchain_ai21 import AI21SemanticTextSplitter

splitter = AI21SemanticTextSplitter()
response = splitter.split_text("Your text")
```

## Tool calls

### Function calling

AI21 models incorporate the Function Calling feature to support custom user functions. The models generate structured
data that includes the function name and proposed arguments. This data empowers applications to call external APIs and
incorporate the resulting information into subsequent model prompts, enriching responses with real-time data and
context. Through function calling, users can access and utilize various services like transportation APIs and financial
data providers to obtain more accurate and relevant answers. Here is an example of how to use function calling
with AI21 models in LangChain:

```python
import os
from getpass import getpass
from langchain_core.messages import HumanMessage, ToolMessage, SystemMessage
from langchain_core.tools import tool
from langchain_ai21.chat_models import ChatAI21
from langchain_core.utils.function_calling import convert_to_openai_tool

os.environ["AI21_API_KEY"] = getpass()


@tool
def get_weather(location: str, date: str) -> str:
    """Provide the weather for the specified location on the given date."""
    if location == "New York" and date == "2024-12-05":
        return "25 celsius"
    elif location == "New York" and date == "2024-12-06":
        return "27 celsius"
    elif location == "London" and date == "2024-12-05":
        return "22 celsius"
    return "32 celsius"


llm = ChatAI21(model="jamba-1.5-mini")

llm_with_tools = llm.bind_tools([convert_to_openai_tool(get_weather)])

chat_messages = [SystemMessage(content="You are a helpful assistant. You can use the provided tools "
                                       "to assist with various tasks and provide accurate information")]

human_messages = [
    HumanMessage(content="What is the forecast for the weather in New York on December 5, 2024?"),
    HumanMessage(content="And what about the 2024-12-06?"),
    HumanMessage(content="OK, thank you."),
    HumanMessage(content="What is the expected weather in London on December 5, 2024?")]


for human_message in human_messages:
    print(f"User: {human_message.content}")
    chat_messages.append(human_message)
    response = llm_with_tools.invoke(chat_messages)
    chat_messages.append(response)
    if response.tool_calls:
        tool_call = response.tool_calls[0]
        if tool_call["name"] == "get_weather":
            weather = get_weather.invoke(
                {"location": tool_call["args"]["location"], "date": tool_call["args"]["date"]})
            chat_messages.append(ToolMessage(content=weather, tool_call_id=tool_call["id"]))
            llm_answer = llm_with_tools.invoke(chat_messages)
            print(f"Assistant: {llm_answer.content}")
    else:
        print(f"Assistant: {response.content}")
```

https://github.com/langchain-ai/langchain-ai21/tree/main/libs/ai21
@@ -1,13 +0,0 @@
from langchain_ai21.chat_models import ChatAI21
from langchain_ai21.contextual_answers import AI21ContextualAnswers
from langchain_ai21.embeddings import AI21Embeddings
from langchain_ai21.llms import AI21LLM
from langchain_ai21.semantic_text_splitter import AI21SemanticTextSplitter

__all__ = [
    "AI21LLM",
    "ChatAI21",
    "AI21Embeddings",
    "AI21ContextualAnswers",
    "AI21SemanticTextSplitter",
]
@@ -1,64 +0,0 @@
import os
from typing import Any, Dict, Optional

from ai21 import AI21Client
from langchain_core.pydantic_v1 import BaseModel, Field, SecretStr, root_validator
from langchain_core.utils import convert_to_secret_str

_DEFAULT_TIMEOUT_SEC = 300


class AI21Base(BaseModel):
    """Base class for AI21 models."""

    class Config:
        arbitrary_types_allowed = True

    client: Any = Field(default=None, exclude=True)  #: :meta private:
    api_key: Optional[SecretStr] = None
    """API key for AI21 API."""
    api_host: Optional[str] = None
    """Host URL"""
    timeout_sec: Optional[float] = None
    """Timeout in seconds.

    If not set, it will default to the value of the environment
    variable `AI21_TIMEOUT_SEC` or 300 seconds.
    """
    num_retries: Optional[int] = None
    """Maximum number of retries for API requests before giving up."""

    @root_validator(pre=True)
    def validate_environment(cls, values: Dict) -> Dict:
        api_key = convert_to_secret_str(
            values.get("api_key") or os.getenv("AI21_API_KEY") or ""
        )
        values["api_key"] = api_key

        api_host = (
            values.get("api_host")
            or os.getenv("AI21_API_URL")
            or "https://api.ai21.com"
        )
        values["api_host"] = api_host

        timeout_sec = values.get("timeout_sec") or float(
            os.getenv("AI21_TIMEOUT_SEC", _DEFAULT_TIMEOUT_SEC)
        )
        values["timeout_sec"] = timeout_sec
        return values

    @root_validator(pre=False, skip_on_failure=True)
    def post_init(cls, values: Dict) -> Dict:
        api_key = values["api_key"]
        api_host = values["api_host"]
        timeout_sec = values["timeout_sec"]
        if values.get("client") is None:
            values["client"] = AI21Client(
                api_key=api_key.get_secret_value(),
                api_host=api_host,
                timeout_sec=None if timeout_sec is None else float(timeout_sec),
                via="langchain",
            )

        return values
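The two validators above mean every integration in this package can be configured entirely from environment variables. The following is a minimal sketch of that fallback behaviour, not part of the original module; it assumes the `ai21` SDK is installed and a valid `AI21_API_KEY` is already exported, and the timeout override is only illustrative.

```python
import os

from langchain_ai21.ai21_base import AI21Base

# Assumes AI21_API_KEY is already set; validate_environment() reads it, then
# post_init() builds an AI21Client tagged with via="langchain".
os.environ.setdefault("AI21_TIMEOUT_SEC", "60")  # illustrative override

base = AI21Base()
print(base.api_host)      # "https://api.ai21.com" unless AI21_API_URL overrides it
print(base.timeout_sec)   # 60.0, taken from AI21_TIMEOUT_SEC
print(type(base.client))  # an AI21Client instance
```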
@@ -1,324 +0,0 @@
from __future__ import annotations

from abc import ABC, abstractmethod
from typing import Any, Dict, Iterator, List, Literal, Optional, Union, cast, overload

from ai21.models import ChatMessage as J2ChatMessage
from ai21.models import RoleType
from ai21.models.chat import (
    AssistantMessage as AI21AssistantMessage,
)
from ai21.models.chat import ChatCompletionChunk, ChatMessageParam
from ai21.models.chat import ChatMessage as AI21ChatMessage
from ai21.models.chat import SystemMessage as AI21SystemMessage
from ai21.models.chat import ToolCall as AI21ToolCall
from ai21.models.chat import ToolFunction as AI21ToolFunction
from ai21.models.chat import ToolMessage as AI21ToolMessage
from ai21.models.chat import UserMessage as AI21UserMessage
from ai21.stream.stream import Stream as AI21Stream
from langchain_core.messages import (
    AIMessage,
    AIMessageChunk,
    BaseMessage,
    BaseMessageChunk,
    HumanMessage,
    SystemMessage,
    ToolCall,
    ToolMessage,
)
from langchain_core.messages.ai import UsageMetadata
from langchain_core.output_parsers.openai_tools import parse_tool_call
from langchain_core.outputs import ChatGenerationChunk

_ChatMessageTypes = Union[AI21ChatMessage, J2ChatMessage]
_SYSTEM_ERR_MESSAGE = "System message must be at beginning of message list."
_ROLE_TYPE = Union[str, RoleType]


class ChatAdapter(ABC):
    """Common interface for the different Chat models available in AI21.

    It converts LangChain messages to AI21 messages.
    Calls the appropriate AI21 model API with the converted messages.
    """

    @abstractmethod
    def convert_messages(
        self,
        messages: List[BaseMessage],
    ) -> Dict[str, Any]:
        pass

    def _convert_message_to_ai21_message(
        self,
        message: BaseMessage,
    ) -> _ChatMessageTypes:
        role = self._parse_role(message)
        return self._chat_message(role=role, message=message)

    def _parse_role(self, message: BaseMessage) -> _ROLE_TYPE:
        role = None

        if isinstance(message, SystemMessage):
            return RoleType.SYSTEM

        if isinstance(message, HumanMessage):
            return RoleType.USER

        if isinstance(message, AIMessage):
            return RoleType.ASSISTANT

        if isinstance(message, ToolMessage):
            return RoleType.TOOL

        if isinstance(self, J2ChatAdapter):
            if not role:
                raise ValueError(
                    f"Could not resolve role type from message {message}. "
                    f"Only support {HumanMessage.__name__} and {AIMessage.__name__}."
                )

        # if it gets here, we rely on the server to handle the role type
        return message.type

    @abstractmethod
    def _chat_message(
        self,
        role: _ROLE_TYPE,
        message: BaseMessage,
    ) -> _ChatMessageTypes:
        pass

    @overload
    def call(
        self,
        client: Any,
        stream: Literal[True],
        **params: Any,
    ) -> Iterator[ChatGenerationChunk]:
        pass

    @overload
    def call(
        self,
        client: Any,
        stream: Literal[False],
        **params: Any,
    ) -> List[BaseMessage]:
        pass

    @abstractmethod
    def call(
        self,
        client: Any,
        stream: Literal[True] | Literal[False],
        **params: Any,
    ) -> List[BaseMessage] | Iterator[ChatGenerationChunk]:
        pass

    def _get_system_message_from_message(self, message: BaseMessage) -> str:
        if not isinstance(message.content, str):
            raise ValueError(
                f"System Message must be of type str. Got {type(message.content)}"
            )

        return message.content


class J2ChatAdapter(ChatAdapter):
    """Adapter for J2Chat models."""

    def convert_messages(self, messages: List[BaseMessage]) -> Dict[str, Any]:
        system_message = ""
        converted_messages = []  # type: ignore

        for i, message in enumerate(messages):
            if message.type == "system":
                if i != 0:
                    raise ValueError(_SYSTEM_ERR_MESSAGE)
                else:
                    system_message = self._get_system_message_from_message(message)
            else:
                converted_message = self._convert_message_to_ai21_message(message)
                converted_messages.append(converted_message)

        return {"system": system_message, "messages": converted_messages}

    def _chat_message(
        self,
        role: _ROLE_TYPE,
        message: BaseMessage,
    ) -> J2ChatMessage:
        return J2ChatMessage(role=RoleType(role), text=cast(str, message.content))

    @overload
    def call(
        self,
        client: Any,
        stream: Literal[True],
        **params: Any,
    ) -> Iterator[ChatGenerationChunk]: ...

    @overload
    def call(
        self,
        client: Any,
        stream: Literal[False],
        **params: Any,
    ) -> List[BaseMessage]: ...

    def call(
        self,
        client: Any,
        stream: Literal[True] | Literal[False],
        **params: Any,
    ) -> List[BaseMessage] | Iterator[ChatGenerationChunk]:
        if stream:
            raise NotImplementedError("Streaming is not supported for Jurassic models.")

        response = client.chat.create(**params)

        return [AIMessage(output.text) for output in response.outputs]


class JambaChatCompletionsAdapter(ChatAdapter):
    """Adapter for Jamba Chat Completions."""

    def convert_messages(self, messages: List[BaseMessage]) -> Dict[str, Any]:
        return {
            "messages": [
                self._convert_message_to_ai21_message(message) for message in messages
            ],
        }

    def _convert_lc_tool_calls_to_ai21_tool_calls(
        self, tool_calls: List[ToolCall]
    ) -> Optional[List[AI21ToolCall]]:
        """
        Convert Langchain ToolCalls to AI21 ToolCalls.
        """
        ai21_tool_calls: List[AI21ToolCall] = []
        for lc_tool_call in tool_calls:
            if "id" not in lc_tool_call or not lc_tool_call["id"]:
                raise ValueError("Tool call ID is missing or empty.")

            ai21_tool_call = AI21ToolCall(
                id=lc_tool_call["id"],
                type="function",
                function=AI21ToolFunction(
                    name=lc_tool_call["name"],
                    arguments=str(lc_tool_call["args"]),
                ),
            )
            ai21_tool_calls.append(ai21_tool_call)

        return ai21_tool_calls

    def _get_content_as_string(self, base_message: BaseMessage) -> str:
        if isinstance(base_message.content, str):
            return base_message.content
        elif isinstance(base_message.content, list):
            return "\n".join(str(item) for item in base_message.content)
        else:
            raise ValueError("Unsupported content type")

    def _chat_message(
        self,
        role: _ROLE_TYPE,
        message: BaseMessage,
    ) -> ChatMessageParam:
        content = self._get_content_as_string(message)

        if isinstance(message, AIMessage):
            return AI21AssistantMessage(
                tool_calls=self._convert_lc_tool_calls_to_ai21_tool_calls(
                    message.tool_calls
                ),
                content=content or None,
            )
        if isinstance(message, ToolMessage):
            return AI21ToolMessage(
                tool_call_id=message.tool_call_id,
                content=content,
            )
        if isinstance(message, HumanMessage):
            return AI21UserMessage(
                content=content,
            )
        if isinstance(message, SystemMessage):
            return AI21SystemMessage(
                content=content,
            )
        return AI21ChatMessage(
            role=role.value if isinstance(role, RoleType) else role,
            content=content,
        )

    @overload
    def call(
        self,
        client: Any,
        stream: Literal[True],
        **params: Any,
    ) -> Iterator[ChatGenerationChunk]: ...

    @overload
    def call(
        self,
        client: Any,
        stream: Literal[False],
        **params: Any,
    ) -> List[BaseMessage]: ...

    def call(
        self,
        client: Any,
        stream: Literal[True] | Literal[False],
        **params: Any,
    ) -> List[BaseMessage] | Iterator[ChatGenerationChunk]:
        response = client.chat.completions.create(stream=stream, **params)

        if stream:
            return self._stream_response(response)

        ai_messages: List[BaseMessage] = []
        for message in response.choices:
            if message.message.tool_calls:
                tool_calls = [
                    parse_tool_call(tool_call.model_dump(), return_id=True)
                    for tool_call in message.message.tool_calls
                ]
                ai_messages.append(AIMessage("", tool_calls=tool_calls))
            else:
                ai_messages.append(AIMessage(message.message.content))

        return ai_messages

    def _stream_response(
        self,
        response: AI21Stream[ChatCompletionChunk],
    ) -> Iterator[ChatGenerationChunk]:
        for chunk in response:
            converted_message = self._convert_ai21_chunk_to_chunk(chunk)
            yield ChatGenerationChunk(message=converted_message)

    def _convert_ai21_chunk_to_chunk(
        self,
        chunk: ChatCompletionChunk,
    ) -> BaseMessageChunk:
        usage = chunk.usage
        content = chunk.choices[0].delta.content or ""

        if usage is None:
            return AIMessageChunk(
                content=content,
            )

        return AIMessageChunk(
            content=content,
            usage_metadata=UsageMetadata(
                input_tokens=usage.prompt_tokens,
                output_tokens=usage.completion_tokens,
                total_tokens=usage.total_tokens,
            ),
        )
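To make the adapter contract concrete, here is a small offline sketch (not part of the original module) of how `JambaChatCompletionsAdapter.convert_messages` turns LangChain messages into the request payload that is later merged with the chat model's default parameters; it only assumes `ai21` and `langchain-core` are installed and makes no API call.

```python
from langchain_core.messages import HumanMessage, SystemMessage

from langchain_ai21.chat.chat_adapter import JambaChatCompletionsAdapter

adapter = JambaChatCompletionsAdapter()

# Each LangChain message becomes an AI21 ChatMessageParam; ChatAI21 merges this
# dict with its default params before calling client.chat.completions.create().
params = adapter.convert_messages(
    [
        SystemMessage(content="You are a terse assistant."),
        HumanMessage(content="Hello from AI21"),
    ]
)
print(params["messages"])  # [AI21 SystemMessage(...), AI21 UserMessage(...)]
```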
@@ -1,23 +0,0 @@
from langchain_ai21.chat.chat_adapter import (
    ChatAdapter,
    J2ChatAdapter,
    JambaChatCompletionsAdapter,
)


def create_chat_adapter(model: str) -> ChatAdapter:
    """Create a chat adapter based on the model.

    Args:
        model: The model to create the chat adapter for.

    Returns:
        The chat adapter.
    """
    if "j2" in model:
        return J2ChatAdapter()

    if "jamba" in model:
        return JambaChatCompletionsAdapter()

    raise ValueError(f"Model {model} not supported.")
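A quick sketch of the substring-based dispatch above, using model names that appear in the README; any name without "j2" or "jamba" in it raises.

```python
from langchain_ai21.chat.chat_factory import create_chat_adapter

print(type(create_chat_adapter("j2-ultra")).__name__)        # J2ChatAdapter
print(type(create_chat_adapter("jamba-instruct")).__name__)  # JambaChatCompletionsAdapter
print(type(create_chat_adapter("jamba-1.5-mini")).__name__)  # JambaChatCompletionsAdapter

try:
    create_chat_adapter("some-other-model")
except ValueError as err:
    print(err)  # Model some-other-model not supported.
```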
@@ -1,271 +0,0 @@
import asyncio
from functools import partial
from typing import (
    Any,
    Callable,
    Dict,
    Iterator,
    List,
    Mapping,
    Optional,
    Sequence,
    Type,
    Union,
)

from langchain_core.callbacks import (
    AsyncCallbackManagerForLLMRun,
    CallbackManagerForLLMRun,
)
from langchain_core.language_models import LanguageModelInput
from langchain_core.language_models.chat_models import (
    BaseChatModel,
    LangSmithParams,
    generate_from_stream,
)
from langchain_core.messages import (
    BaseMessage,
)
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
from langchain_core.pydantic_v1 import root_validator
from langchain_core.runnables import Runnable
from langchain_core.tools import BaseTool
from langchain_core.utils.function_calling import convert_to_openai_tool

from langchain_ai21.ai21_base import AI21Base
from langchain_ai21.chat.chat_adapter import ChatAdapter
from langchain_ai21.chat.chat_factory import create_chat_adapter


class ChatAI21(BaseChatModel, AI21Base):
    """ChatAI21 chat model. Different model types support different parameters and
    different parameter values. Please read the [AI21 reference documentation]
    (https://docs.ai21.com/reference) for your model to understand which parameters
    are available.

    Example:
        .. code-block:: python

            from langchain_ai21 import ChatAI21


            model = ChatAI21(
                # defaults to os.environ.get("AI21_API_KEY")
                api_key="my_api_key"
            )
    """

    model: str
    """Model type you wish to interact with.
    You can view the options at https://github.com/AI21Labs/ai21-python?tab=readme-ov-file#model-types"""
    num_results: int = 1
    """The number of responses to generate for a given prompt."""
    stop: Optional[List[str]] = None
    """Default stop sequences."""

    max_tokens: int = 512
    """The maximum number of tokens to generate for each response."""

    min_tokens: int = 0
    """The minimum number of tokens to generate for each response.
    _Not supported for all models._"""

    temperature: float = 0.4
    """A value controlling the "creativity" of the model's responses."""

    top_p: float = 1
    """A value controlling the diversity of the model's responses."""

    top_k_return: int = 0
    """The number of top-scoring tokens to consider for each generation step.
    _Not supported for all models._"""

    frequency_penalty: Optional[Any] = None
    """A penalty applied to tokens that are frequently generated.
    _Not supported for all models._"""

    presence_penalty: Optional[Any] = None
    """A penalty applied to tokens that are already present in the prompt.
    _Not supported for all models._"""

    count_penalty: Optional[Any] = None
    """A penalty applied to tokens based on their frequency
    in the generated responses. _Not supported for all models._"""

    n: int = 1
    """Number of chat completions to generate for each prompt."""
    streaming: bool = False

    _chat_adapter: ChatAdapter

    @root_validator(pre=False, skip_on_failure=True)
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate the environment."""
        model = values["model"]
        values["_chat_adapter"] = create_chat_adapter(model)
        return values

    class Config:
        """Configuration for this pydantic object."""

        arbitrary_types_allowed = True

    @property
    def _llm_type(self) -> str:
        """Return type of chat model."""
        return "chat-ai21"

    @property
    def _default_params(self) -> Mapping[str, Any]:
        base_params = {
            "model": self.model,
            "num_results": self.num_results,
            "max_tokens": self.max_tokens,
            "min_tokens": self.min_tokens,
            "temperature": self.temperature,
            "top_p": self.top_p,
            "top_k_return": self.top_k_return,
            "n": self.n,
        }
        if self.stop:
            base_params["stop_sequences"] = self.stop

        if self.count_penalty is not None:
            base_params["count_penalty"] = self.count_penalty.to_dict()

        if self.frequency_penalty is not None:
            base_params["frequency_penalty"] = self.frequency_penalty.to_dict()

        if self.presence_penalty is not None:
            base_params["presence_penalty"] = self.presence_penalty.to_dict()

        return base_params

    def _get_ls_params(
        self, stop: Optional[List[str]] = None, **kwargs: Any
    ) -> LangSmithParams:
        """Get standard params for tracing."""
        params = self._get_invocation_params(stop=stop, **kwargs)
        ls_params = LangSmithParams(
            ls_provider="ai21",
            ls_model_name=self.model,
            ls_model_type="chat",
            ls_temperature=params.get("temperature", self.temperature),
        )
        if ls_max_tokens := params.get("max_tokens", self.max_tokens):
            ls_params["ls_max_tokens"] = ls_max_tokens
        if ls_stop := stop or params.get("stop", None) or self.stop:
            ls_params["ls_stop"] = ls_stop
        return ls_params

    def _build_params_for_request(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> Mapping[str, Any]:
        params = {}
        converted_messages = self._chat_adapter.convert_messages(messages)

        if stop is not None:
            if "stop" in kwargs:
                raise ValueError("stop is defined in both stop and kwargs")
            params["stop_sequences"] = stop

        return {
            **converted_messages,
            **self._default_params,
            **params,
            **kwargs,
        }

    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        stream: Optional[bool] = None,
        **kwargs: Any,
    ) -> ChatResult:
        should_stream = stream or self.streaming

        if should_stream:
            return self._handle_stream_from_generate(
                messages=messages,
                stop=stop,
                run_manager=run_manager,
                **kwargs,
            )

        params = self._build_params_for_request(
            messages=messages,
            stop=stop,
            stream=should_stream,
            **kwargs,
        )

        messages = self._chat_adapter.call(self.client, **params)
        generations = [ChatGeneration(message=message) for message in messages]

        return ChatResult(generations=generations)

    def _handle_stream_from_generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        stream_iter = self._stream(
            messages=messages,
            stop=stop,
            run_manager=run_manager,
            **kwargs,
        )
        return generate_from_stream(stream_iter)

    def _stream(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> Iterator[ChatGenerationChunk]:
        params = self._build_params_for_request(
            messages=messages,
            stop=stop,
            stream=True,
            **kwargs,
        )

        for chunk in self._chat_adapter.call(self.client, **params):
            if run_manager and isinstance(chunk.message.content, str):
                run_manager.on_llm_new_token(token=chunk.message.content, chunk=chunk)
            yield chunk

    async def _agenerate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        return await asyncio.get_running_loop().run_in_executor(
            None, partial(self._generate, **kwargs), messages, stop, run_manager
        )

    def _get_system_message_from_message(self, message: BaseMessage) -> str:
        if not isinstance(message.content, str):
            raise ValueError(
                f"System Message must be of type str. Got {type(message.content)}"
            )

        return message.content

    def bind_tools(
        self,
        tools: Sequence[Union[Dict[str, Any], Type, Callable, BaseTool]],
        **kwargs: Any,
    ) -> Runnable[LanguageModelInput, BaseMessage]:
        formatted_tools = [convert_to_openai_tool(tool) for tool in tools]
        return super().bind(tools=formatted_tools, **kwargs)
@@ -1,112 +0,0 @@
from typing import (
    Any,
    List,
    Optional,
    Tuple,
    Type,
    TypedDict,
    Union,
)

from langchain_core.documents import Document
from langchain_core.runnables import RunnableConfig, RunnableSerializable, ensure_config

from langchain_ai21.ai21_base import AI21Base

ANSWER_NOT_IN_CONTEXT_RESPONSE = "Answer not in context"

ContextType = Union[str, List[Union[Document, str]]]


class ContextualAnswerInput(TypedDict):
    """Input for the ContextualAnswers runnable."""

    context: ContextType
    question: str


class AI21ContextualAnswers(RunnableSerializable[ContextualAnswerInput, str], AI21Base):
    """Runnable for the AI21 Contextual Answers API."""

    class Config:
        """Configuration for this pydantic object."""

        arbitrary_types_allowed = True

    @property
    def InputType(self) -> Type[ContextualAnswerInput]:
        """Get the input type for this runnable."""
        return ContextualAnswerInput

    @property
    def OutputType(self) -> Type[str]:
        """Get the output type for this runnable."""
        return str

    def invoke(
        self,
        input: ContextualAnswerInput,
        config: Optional[RunnableConfig] = None,
        response_if_no_answer_found: str = ANSWER_NOT_IN_CONTEXT_RESPONSE,
        **kwargs: Any,
    ) -> str:
        config = ensure_config(config)
        return self._call_with_config(
            func=lambda inner_input: self._call_contextual_answers(
                inner_input, response_if_no_answer_found
            ),
            input=input,
            config=config,
            run_type="llm",
        )

    def _call_contextual_answers(
        self,
        input: ContextualAnswerInput,
        response_if_no_answer_found: str,
    ) -> str:
        context, question = self._convert_input(input)
        response = self.client.answer.create(context=context, question=question)

        if response.answer is None:
            return response_if_no_answer_found

        return response.answer

    def _convert_input(self, input: ContextualAnswerInput) -> Tuple[str, str]:
        context, question = self._extract_context_and_question(input)

        context = self._parse_context(context)

        return context, question

    def _extract_context_and_question(
        self,
        input: ContextualAnswerInput,
    ) -> Tuple[ContextType, str]:
        context = input.get("context")
        question = input.get("question")

        if not context or not question:
            raise ValueError(
                f"Input must contain a 'context' and 'question' fields. Got {input}"
            )

        if not isinstance(context, list) and not isinstance(context, str):
            raise ValueError(
                f"Expected input to be a list of strings or Documents."
                f" Received {type(input)}"
            )

        return context, question

    def _parse_context(self, context: ContextType) -> str:
        if isinstance(context, str):
            return context

        docs = [
            item.page_content if isinstance(item, Document) else item
            for item in context
        ]

        return "\n".join(docs)
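Because `ContextType` accepts either a raw string or a list of strings and `Document`s, the runnable can consume retriever output directly. A hedged sketch (the context text is invented for illustration, and a valid `AI21_API_KEY` must be set for the call to succeed):

```python
from langchain_core.documents import Document

from langchain_ai21 import AI21ContextualAnswers

tsm = AI21ContextualAnswers()

# _parse_context() joins Document page_content (or plain strings) with newlines
# before calling client.answer.create().
answer = tsm.invoke(
    {
        "context": [
            Document(page_content="The package was released under the MIT license."),
            "It integrates AI21 models with LangChain.",
        ],
        "question": "Under which license was the package released?",
    },
    response_if_no_answer_found="Answer not in context",
)
print(answer)
```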
@@ -1,126 +0,0 @@
from itertools import islice
from typing import Any, Iterator, List, Optional

from ai21.models import EmbedType
from langchain_core.embeddings import Embeddings

from langchain_ai21.ai21_base import AI21Base

_DEFAULT_BATCH_SIZE = 128


def _split_texts_into_batches(texts: List[str], batch_size: int) -> Iterator[List[str]]:
    texts_itr = iter(texts)
    return iter(lambda: list(islice(texts_itr, batch_size)), [])


class AI21Embeddings(Embeddings, AI21Base):
    """AI21 embedding model integration.

    Install ``langchain_ai21`` and set environment variable ``AI21_API_KEY``.

    .. code-block:: bash

        pip install -U langchain_ai21
        export AI21_API_KEY="your-api-key"

    Key init args — client params:
        api_key: Optional[SecretStr]
        batch_size: int
            The number of texts that will be sent to the API in each batch.
            Use larger batch sizes if working with many short texts. This will reduce
            the number of API calls made, and can improve the time it takes to embed
            a large number of texts.
        num_retries: Optional[int]
            Maximum number of retries for API requests before giving up.
        timeout_sec: Optional[float]
            Timeout in seconds for API requests. If not set, it will default to the
            value of the environment variable `AI21_TIMEOUT_SEC` or 300 seconds.

    See full list of supported init args and their descriptions in the params section.

    Instantiate:
        .. code-block:: python

            from langchain_ai21 import AI21Embeddings

            embed = AI21Embeddings(
                # api_key="...",
                # batch_size=128,
            )

    Embed single text:
        .. code-block:: python

            input_text = "The meaning of life is 42"
            vector = embed.embed_query(input_text)
            print(vector[:3])

        .. code-block:: python

            [-0.024603435769677162, -0.007543657906353474, 0.0039630369283258915]

    Embed multiple texts:
        .. code-block:: python

            input_texts = ["Document 1...", "Document 2..."]
            vectors = embed.embed_documents(input_texts)
            print(len(vectors))
            # The first 3 coordinates for the first vector
            print(vectors[0][:3])

        .. code-block:: python

            2
            [-0.024603435769677162, -0.007543657906353474, 0.0039630369283258915]
    """

    batch_size: int = _DEFAULT_BATCH_SIZE
    """Maximum number of texts to embed in each batch"""

    def embed_documents(
        self,
        texts: List[str],
        *,
        batch_size: Optional[int] = None,
        **kwargs: Any,
    ) -> List[List[float]]:
        """Embed search docs."""
        return self._send_embeddings(
            texts=texts,
            batch_size=batch_size or self.batch_size,
            embed_type=EmbedType.SEGMENT,
            **kwargs,
        )

    def embed_query(
        self,
        text: str,
        *,
        batch_size: Optional[int] = None,
        **kwargs: Any,
    ) -> List[float]:
        """Embed query text."""
        return self._send_embeddings(
            texts=[text],
            batch_size=batch_size or self.batch_size,
            embed_type=EmbedType.QUERY,
            **kwargs,
        )[0]

    def _send_embeddings(
        self, texts: List[str], *, batch_size: int, embed_type: EmbedType, **kwargs: Any
    ) -> List[List[float]]:
        chunks = _split_texts_into_batches(texts, batch_size)
        responses = [
            self.client.embed.create(
                texts=chunk,
                type=embed_type,
                **kwargs,
            )
            for chunk in chunks
        ]

        return [
            result.embedding for response in responses for result in response.results
        ]
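The batching helper is what keeps request sizes bounded: `_split_texts_into_batches` lazily yields successive slices of at most `batch_size` texts, and `_send_embeddings` issues one `embed.create` call per slice. A small offline sketch of that behaviour (it imports a private helper purely for illustration and makes no API call):

```python
from langchain_ai21.embeddings import _split_texts_into_batches

texts = [f"text-{i}" for i in range(5)]

# With batch_size=2, five texts translate into three API calls: 2 + 2 + 1 texts.
for batch in _split_texts_into_batches(texts, batch_size=2):
    print(batch)
# ['text-0', 'text-1']
# ['text-2', 'text-3']
# ['text-4']
```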
@@ -1,188 +0,0 @@
import asyncio
from functools import partial
from typing import (
    Any,
    List,
    Mapping,
    Optional,
)

from ai21.models import CompletionsResponse
from langchain_core.callbacks import (
    AsyncCallbackManagerForLLMRun,
    CallbackManagerForLLMRun,
)
from langchain_core.language_models import BaseLLM
from langchain_core.outputs import Generation, LLMResult

from langchain_ai21.ai21_base import AI21Base


class AI21LLM(BaseLLM, AI21Base):
    """AI21 large language models. Different model types support different parameters
    and different parameter values. Please read the [AI21 reference documentation]
    (https://docs.ai21.com/reference) for your model to understand which parameters
    are available.

    AI21LLM supports only the older Jurassic models.
    We recommend using ChatAI21 with the newest models, for better results and more
    features.

    Example:
        .. code-block:: python

            from langchain_ai21 import AI21LLM

            model = AI21LLM(
                # defaults to os.environ.get("AI21_API_KEY")
                api_key="my_api_key"
            )
    """

    model: str
    """Model type you wish to interact with.
    You can view the options at https://github.com/AI21Labs/ai21-python?tab=readme-ov-file#model-types"""

    num_results: int = 1
    """The number of responses to generate for a given prompt."""

    max_tokens: int = 16
    """The maximum number of tokens to generate for each response."""

    min_tokens: int = 0
    """The minimum number of tokens to generate for each response.
    _Not supported for all models._"""

    temperature: float = 0.7
    """A value controlling the "creativity" of the model's responses."""

    top_p: float = 1
    """A value controlling the diversity of the model's responses."""

    top_k_return: int = 0
    """The number of top-scoring tokens to consider for each generation step.
    _Not supported for all models._"""

    frequency_penalty: Optional[Any] = None
    """A penalty applied to tokens that are frequently generated.
    _Not supported for all models._"""

    presence_penalty: Optional[Any] = None
    """A penalty applied to tokens that are already present in the prompt.
    _Not supported for all models._"""

    count_penalty: Optional[Any] = None
    """A penalty applied to tokens based on their frequency
    in the generated responses. _Not supported for all models._"""

    custom_model: Optional[str] = None
    epoch: Optional[int] = None

    class Config:
        """Configuration for this pydantic object."""

        allow_population_by_field_name = True

    @property
    def _llm_type(self) -> str:
        """Return type of LLM."""
        return "ai21-llm"

    @property
    def _default_params(self) -> Mapping[str, Any]:
        base_params = {
            "model": self.model,
            "num_results": self.num_results,
            "max_tokens": self.max_tokens,
            "min_tokens": self.min_tokens,
            "temperature": self.temperature,
            "top_p": self.top_p,
            "top_k_return": self.top_k_return,
        }

        if self.count_penalty is not None:
            base_params["count_penalty"] = self.count_penalty.to_dict()

        if self.custom_model is not None:
            base_params["custom_model"] = self.custom_model

        if self.epoch is not None:
            base_params["epoch"] = self.epoch

        if self.frequency_penalty is not None:
            base_params["frequency_penalty"] = self.frequency_penalty.to_dict()

        if self.presence_penalty is not None:
            base_params["presence_penalty"] = self.presence_penalty.to_dict()

        return base_params

    def _build_params_for_request(
        self, stop: Optional[List[str]] = None, **kwargs: Any
    ) -> Mapping[str, Any]:
        params = {}

        if stop is not None:
            if "stop" in kwargs:
                raise ValueError("stop is defined in both stop and kwargs")
            params["stop_sequences"] = stop

        return {
            **self._default_params,
            **params,
            **kwargs,
        }

    def _generate(
        self,
        prompts: List[str],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> LLMResult:
        generations: List[List[Generation]] = []
        token_count = 0

        params = self._build_params_for_request(stop=stop, **kwargs)

        for prompt in prompts:
            response = self._invoke_completion(prompt=prompt, **params)
            generation = self._response_to_generation(response)
            generations.append(generation)
            token_count += self.client.count_tokens(prompt)

        llm_output = {"token_count": token_count, "model_name": self.model}
        return LLMResult(generations=generations, llm_output=llm_output)

    async def _agenerate(
        self,
        prompts: List[str],
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> LLMResult:
        # Change implementation if integration natively supports async generation.
        return await asyncio.get_running_loop().run_in_executor(
            None, partial(self._generate, **kwargs), prompts, stop, run_manager
        )

    def _invoke_completion(
        self,
        prompt: str,
        **kwargs: Any,
    ) -> CompletionsResponse:
        return self.client.completion.create(
            prompt=prompt,
            **kwargs,
        )

    def _response_to_generation(
        self, response: CompletionsResponse
    ) -> List[Generation]:
        return [
            Generation(
                text=completion.data.text,  # type: ignore[arg-type]
                generation_info=completion.to_dict(),
            )
            for completion in response.completions
        ]
@@ -1,158 +0,0 @@
import copy
import logging
import re
from typing import (
    Any,
    Iterable,
    List,
    Optional,
)

from ai21.models import DocumentType
from langchain_core.documents import Document
from langchain_core.pydantic_v1 import SecretStr
from langchain_text_splitters import TextSplitter

from langchain_ai21.ai21_base import AI21Base

logger = logging.getLogger(__name__)


class AI21SemanticTextSplitter(TextSplitter):
    """Splitting text into coherent and readable units,
    based on distinct topics and lines.
    """

    def __init__(
        self,
        chunk_size: int = 0,
        chunk_overlap: int = 0,
        client: Optional[Any] = None,
        api_key: Optional[SecretStr] = None,
        api_host: Optional[str] = None,
        timeout_sec: Optional[float] = None,
        num_retries: Optional[int] = None,
        **kwargs: Any,
    ) -> None:
        """Create a new TextSplitter."""
        super().__init__(
            chunk_size=chunk_size,
            chunk_overlap=chunk_overlap,
            **kwargs,
        )

        self._segmentation = AI21Base(
            client=client,
            api_key=api_key,
            api_host=api_host,
            timeout_sec=timeout_sec,
            num_retries=num_retries,
        ).client.segmentation

    def split_text(self, source: str) -> List[str]:
        """Split text into multiple components.

        Args:
            source: Specifies the text input for text segmentation
        """
        response = self._segmentation.create(
            source=source, source_type=DocumentType.TEXT
        )

        segments = [segment.segment_text for segment in response.segments]

        if self._chunk_size > 0:
            return self._merge_splits_no_seperator(segments)

        return segments

    def split_text_to_documents(self, source: str) -> List[Document]:
        """Split text into multiple documents.

        Args:
            source: Specifies the text input for text segmentation
        """
        response = self._segmentation.create(
            source=source, source_type=DocumentType.TEXT
        )

        return [
            Document(
                page_content=segment.segment_text,
                metadata={"source_type": segment.segment_type},
            )
            for segment in response.segments
        ]

    def create_documents(
        self, texts: List[str], metadatas: Optional[List[dict]] = None
    ) -> List[Document]:
        """Create documents from a list of texts."""
        _metadatas = metadatas or [{}] * len(texts)
        documents = []

        for i, text in enumerate(texts):
            normalized_text = self._normalized_text(text)
            index = 0
            previous_chunk_len = 0

            for chunk in self.split_text_to_documents(text):
                # merge metadata from user (if exists) and from segmentation api
                metadata = copy.deepcopy(_metadatas[i])
                metadata.update(chunk.metadata)

                if self._add_start_index:
                    # find the start index of the chunk
                    offset = index + previous_chunk_len - self._chunk_overlap
                    normalized_chunk = self._normalized_text(chunk.page_content)
                    index = normalized_text.find(normalized_chunk, max(0, offset))
                    metadata["start_index"] = index
                    previous_chunk_len = len(normalized_chunk)

                documents.append(
                    Document(
                        page_content=chunk.page_content,
                        metadata=metadata,
                    )
                )

        return documents

    def _normalized_text(self, string: str) -> str:
        """Use regular expression to remove whitespace sequences (including '\n')."""
        return re.sub(r"\s+", "", string)

    def _merge_splits(self, splits: Iterable[str], separator: str) -> List[str]:
        """This method overrides the default implementation of TextSplitter"""
        return self._merge_splits_no_seperator(splits)

    def _merge_splits_no_seperator(self, splits: Iterable[str]) -> List[str]:
        """Merge splits into chunks.
        If the segment size is bigger than chunk_size,
        it will be left as is (won't be cut to match to chunk_size).
        If the segment size is smaller than chunk_size,
        it will be merged with the next segment until the chunk_size is reached.
        """
        chunks = []
        current_chunk = ""

        for split in splits:
            split_len = self._length_function(split)

            if split_len > self._chunk_size:
                logger.warning(
                    f"Split of length {split_len}"
                    f" exceeds chunk size {self._chunk_size}."
                )

            if self._length_function(current_chunk) + split_len > self._chunk_size:
                if current_chunk != "":
                    chunks.append(current_chunk)
                    current_chunk = ""

            current_chunk += split

        if current_chunk != "":
            chunks.append(current_chunk)

        return chunks
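The merge step above only packs whole segments; it never cuts one. A small offline sketch of that packing logic, reimplemented with a plain `len()` length function so it runs without the AI21 client (the segment strings and helper name are invented for illustration):

```python
# Stand-in for _merge_splits_no_seperator(), using len() as the length function.
def merge_segments(segments, chunk_size):
    chunks, current = [], ""
    for segment in segments:
        # Close the current chunk before it would exceed chunk_size.
        if len(current) + len(segment) > chunk_size and current:
            chunks.append(current)
            current = ""
        current += segment  # an over-long segment still becomes its own chunk
    if current:
        chunks.append(current)
    return chunks


segments = ["aaaa", "bb", "cccccccccc", "d"]
print(merge_segments(segments, chunk_size=6))
# ['aaaabb', 'cccccccccc', 'd']
```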
libs/partners/ai21/poetry.lock (generated, 1322 lines changed)
File diff suppressed because it is too large
@@ -1,90 +0,0 @@
[build-system]
requires = [ "poetry-core>=1.0.0",]
build-backend = "poetry.core.masonry.api"

[tool.poetry]
name = "langchain-ai21"
version = "0.1.8"
description = "An integration package connecting AI21 and LangChain"
authors = []
readme = "README.md"
repository = "https://github.com/langchain-ai/langchain"
license = "MIT"

[tool.mypy]
disallow_untyped_defs = "True"

[tool.poetry.urls]
"Source Code" = "https://github.com/langchain-ai/langchain/tree/master/libs/partners/ai21"
"Release Notes" = "https://github.com/langchain-ai/langchain/releases?q=tag%3A%22langchain-ai21%3D%3D0%22&expanded=true"

[tool.poetry.dependencies]
python = ">=3.8.1,<4.0"
langchain-core = "^0.2.4"
langchain-text-splitters = "^0.2.0"
ai21 = "^2.14.1"

[tool.ruff.lint]
select = [ "E", "F", "I",]

[tool.coverage.run]
omit = [ "tests/*",]

[tool.pytest.ini_options]
addopts = "--snapshot-warn-unused --strict-markers --strict-config --durations=5"
markers = [ "requires: mark tests as requiring a specific library", "asyncio: mark tests as requiring asyncio", "compile: mark placeholder test used to compile integration tests without running them", "scheduled: mark tests to run in scheduled testing",]
asyncio_mode = "auto"

[tool.poetry.group.test]
optional = true

[tool.poetry.group.codespell]
optional = true

[tool.poetry.group.test_integration]
optional = true

[tool.poetry.group.lint]
optional = true

[tool.poetry.group.dev]
optional = true

[tool.poetry.group.test.dependencies]
pytest = "^7.3.0"
freezegun = "^1.2.2"
pytest-mock = "^3.10.0"
syrupy = "^4.0.2"
pytest-watcher = "^0.3.4"
pytest-asyncio = "^0.21.1"

[tool.poetry.group.codespell.dependencies]
codespell = "^2.2.0"

[tool.poetry.group.test_integration.dependencies]

[tool.poetry.group.lint.dependencies]
ruff = "^0.5"

[tool.poetry.group.typing.dependencies]
mypy = "^1.10"

[tool.poetry.group.test.dependencies.langchain-core]
path = "../../core"
develop = true

[tool.poetry.group.test.dependencies.langchain-standard-tests]
path = "../../standard-tests"
develop = true

[tool.poetry.group.test.dependencies.langchain-text-splitters]
path = "../../text-splitters"
develop = true

[tool.poetry.group.dev.dependencies.langchain-core]
path = "../../core"
develop = true

[tool.poetry.group.typing.dependencies.langchain-core]
path = "../../core"
develop = true
@@ -1,17 +0,0 @@
import sys
import traceback
from importlib.machinery import SourceFileLoader

if __name__ == "__main__":
    files = sys.argv[1:]
    has_failure = False
    for file in files:
        try:
            SourceFileLoader("x", file).load_module()
        except Exception:
            has_failure = True
            print(file)
            traceback.print_exc()
            print()

    sys.exit(1 if has_failure else 0)
@@ -1,27 +0,0 @@
#!/bin/bash
#
# This script searches for lines starting with "import pydantic" or "from pydantic"
# in tracked files within a Git repository.
#
# Usage: ./scripts/check_pydantic.sh /path/to/repository

# Check if a path argument is provided
if [ $# -ne 1 ]; then
  echo "Usage: $0 /path/to/repository"
  exit 1
fi

repository_path="$1"

# Search for lines matching the pattern within the specified repository
result=$(git -C "$repository_path" grep -E '^import pydantic|^from pydantic')

# Check if any matching lines were found
if [ -n "$result" ]; then
  echo "ERROR: The following lines need to be updated:"
  echo "$result"
  echo "Please replace the code with an import from langchain_core.pydantic_v1."
  echo "For example, replace 'from pydantic import BaseModel'"
  echo "with 'from langchain_core.pydantic_v1 import BaseModel'"
  exit 1
fi
@ -1,17 +0,0 @@
|
||||
#!/bin/bash
|
||||
|
||||
set -eu
|
||||
|
||||
# Initialize a variable to keep track of errors
|
||||
errors=0
|
||||
|
||||
# make sure not importing from langchain or langchain_experimental
|
||||
git --no-pager grep '^from langchain\.' . && errors=$((errors+1))
|
||||
git --no-pager grep '^from langchain_experimental\.' . && errors=$((errors+1))
|
||||
|
||||
# Decide on an exit status based on the errors
|
||||
if [ "$errors" -gt 0 ]; then
|
||||
exit 1
|
||||
else
|
||||
exit 0
|
||||
fi
|
@ -1,188 +0,0 @@
|
||||
"""Test ChatAI21 chat model."""
|
||||
|
||||
import pytest
|
||||
from langchain_core.messages import (
|
||||
AIMessageChunk,
|
||||
HumanMessage,
|
||||
SystemMessage,
|
||||
ToolMessage,
|
||||
)
|
||||
from langchain_core.outputs import ChatGeneration
|
||||
from langchain_core.rate_limiters import InMemoryRateLimiter
|
||||
from langchain_core.tools import tool
|
||||
from langchain_core.utils.function_calling import convert_to_openai_tool
|
||||
|
||||
from langchain_ai21.chat_models import ChatAI21
|
||||
from tests.unit_tests.conftest import (
|
||||
J2_CHAT_MODEL_NAME,
|
||||
JAMBA_1_5_LARGE_CHAT_MODEL_NAME,
|
||||
JAMBA_1_5_MINI_CHAT_MODEL_NAME,
|
||||
JAMBA_CHAT_MODEL_NAME,
|
||||
JAMBA_FAMILY_MODEL_NAMES,
|
||||
)
|
||||
|
||||
rate_limiter = InMemoryRateLimiter(requests_per_second=0.5)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
ids=[
|
||||
"when_j2_model",
|
||||
"when_jamba_model",
|
||||
"when_jamba1.5-mini_model",
|
||||
"when_jamba1.5-large_model",
|
||||
],
|
||||
argnames=["model"],
|
||||
argvalues=[
|
||||
(J2_CHAT_MODEL_NAME,),
|
||||
(JAMBA_CHAT_MODEL_NAME,),
|
||||
(JAMBA_1_5_MINI_CHAT_MODEL_NAME,),
|
||||
(JAMBA_1_5_LARGE_CHAT_MODEL_NAME,),
|
||||
],
|
||||
)
|
||||
def test_invoke(model: str) -> None:
|
||||
"""Test invoke tokens from AI21."""
|
||||
llm = ChatAI21(model=model, rate_limiter=rate_limiter) # type: ignore[call-arg]
|
||||
|
||||
result = llm.invoke("I'm Pickle Rick", config=dict(tags=["foo"]))
|
||||
assert isinstance(result.content, str)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
ids=[
|
||||
"when_j2_model_num_results_is_1",
|
||||
"when_j2_model_num_results_is_3",
|
||||
"when_jamba_model_n_is_1",
|
||||
"when_jamba_model_n_is_3",
|
||||
"when_jamba1.5_mini_model_n_is_1",
|
||||
"when_jamba1.5_mini_model_n_is_3",
|
||||
"when_jamba1.5_large_model_n_is_1",
|
||||
"when_jamba1.5_large_model_n_is_3",
|
||||
],
|
||||
argnames=["model", "num_results"],
|
||||
argvalues=[
|
||||
(J2_CHAT_MODEL_NAME, 1),
|
||||
(J2_CHAT_MODEL_NAME, 3),
|
||||
(JAMBA_CHAT_MODEL_NAME, 1),
|
||||
(JAMBA_CHAT_MODEL_NAME, 3),
|
||||
(JAMBA_1_5_MINI_CHAT_MODEL_NAME, 1),
|
||||
(JAMBA_1_5_MINI_CHAT_MODEL_NAME, 3),
|
||||
(JAMBA_1_5_LARGE_CHAT_MODEL_NAME, 1),
|
||||
(JAMBA_1_5_LARGE_CHAT_MODEL_NAME, 3),
|
||||
],
|
||||
)
|
||||
def test_generation(model: str, num_results: int) -> None:
|
||||
"""Test generation with multiple models and different result counts."""
|
||||
# Determine the configuration key based on the model type
|
||||
config_key = "n" if model in JAMBA_FAMILY_MODEL_NAMES else "num_results"
|
||||
|
||||
# Create the model instance using the appropriate key for the result count
|
||||
llm = ChatAI21(model=model, rate_limiter=rate_limiter, **{config_key: num_results}) # type: ignore[arg-type, arg-type, arg-type, arg-type, arg-type, arg-type, arg-type, arg-type, arg-type, arg-type]
|
||||
|
||||
message = HumanMessage(content="Hello, this is a test. Can you help me please?")
|
||||
|
||||
result = llm.generate([[message]], config=dict(tags=["foo"]))
|
||||
|
||||
for generations in result.generations:
|
||||
assert len(generations) == num_results
|
||||
for generation in generations:
|
||||
assert isinstance(generation, ChatGeneration)
|
||||
assert isinstance(generation.text, str)
|
||||
assert generation.text == generation.message.content
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
ids=[
|
||||
"when_j2_model",
|
||||
"when_jamba_model",
|
||||
"when_jamba1.5_mini_model",
|
||||
"when_jamba1.5_large_model",
|
||||
],
|
||||
argnames=["model"],
|
||||
argvalues=[
|
||||
(J2_CHAT_MODEL_NAME,),
|
||||
(JAMBA_CHAT_MODEL_NAME,),
|
||||
(JAMBA_1_5_MINI_CHAT_MODEL_NAME,),
|
||||
(JAMBA_1_5_LARGE_CHAT_MODEL_NAME,),
|
||||
],
|
||||
)
|
||||
async def test_ageneration(model: str) -> None:
|
||||
"""Test async generation from AI21."""
|
||||
llm = ChatAI21(model=model, rate_limiter=rate_limiter) # type: ignore[call-arg]
|
||||
message = HumanMessage(content="Hello")
|
||||
|
||||
result = await llm.agenerate([[message], [message]], config=dict(tags=["foo"]))
|
||||
|
||||
for generations in result.generations:
|
||||
assert len(generations) == 1
|
||||
for generation in generations:
|
||||
assert isinstance(generation, ChatGeneration)
|
||||
assert isinstance(generation.text, str)
|
||||
assert generation.text == generation.message.content
|
||||
|
||||
|
||||
def test__chat_stream() -> None:
|
||||
llm = ChatAI21(model="jamba-1.5-mini") # type: ignore[call-arg]
|
||||
message = HumanMessage(content="What is the meaning of life?")
|
||||
|
||||
for chunk in llm.stream([message]):
|
||||
assert isinstance(chunk, AIMessageChunk)
|
||||
assert isinstance(chunk.content, str)
|
||||
|
||||
|
||||
def test__j2_chat_stream__should_raise_error() -> None:
|
||||
llm = ChatAI21(model="j2-ultra") # type: ignore[call-arg]
|
||||
message = HumanMessage(content="What is the meaning of life?")
|
||||
|
||||
with pytest.raises(NotImplementedError):
|
||||
for _ in llm.stream([message]):
|
||||
pass
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
ids=[
|
||||
"when_jamba1.5_mini_model",
|
||||
"when_jamba1.5_large_model",
|
||||
],
|
||||
argnames=["model"],
|
||||
argvalues=[
|
||||
(JAMBA_1_5_MINI_CHAT_MODEL_NAME,),
|
||||
(JAMBA_1_5_LARGE_CHAT_MODEL_NAME,),
|
||||
],
|
||||
)
|
||||
def test_tool_calls(model: str) -> None:
|
||||
@tool
|
||||
def get_weather(location: str, date: str) -> str:
|
||||
"""Provide the weather for the specified location on the given date."""
|
||||
if location == "New York" and date == "2024-12-05":
|
||||
return "25 celsius"
|
||||
return "32 celsius"
|
||||
|
||||
llm = ChatAI21(model=model, temperature=0) # type: ignore[call-arg]
|
||||
llm_with_tools = llm.bind_tools([convert_to_openai_tool(get_weather)])
|
||||
|
||||
chat_messages = [
|
||||
SystemMessage(
|
||||
content="You are a helpful assistant. "
|
||||
"You can use the provided tools "
|
||||
"to assist with various tasks and provide "
|
||||
"accurate information"
|
||||
),
|
||||
HumanMessage(
|
||||
content="What is the forecast for the weather "
|
||||
"in New York on December 5, 2024?"
|
||||
),
|
||||
]
|
||||
|
||||
response = llm_with_tools.invoke(chat_messages)
|
||||
chat_messages.append(response)
|
||||
assert response.tool_calls is not None # type: ignore[attr-defined]
|
||||
tool_call = response.tool_calls[0] # type: ignore[attr-defined]
|
||||
assert tool_call["name"] == "get_weather"
|
||||
|
||||
weather = get_weather.invoke( # type: ignore[attr-defined]
|
||||
{"location": tool_call["args"]["location"], "date": tool_call["args"]["date"]}
|
||||
)
|
||||
chat_messages.append(ToolMessage(content=weather, tool_call_id=tool_call["id"]))
|
||||
llm_answer = llm_with_tools.invoke(chat_messages)
|
||||
content = llm_answer.content.lower() # type: ignore[union-attr]
|
||||
assert "new york" in content and "25" in content and "celsius" in content
|
@ -1,7 +0,0 @@
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.mark.compile
|
||||
def test_placeholder() -> None:
|
||||
"""Used for compiling integration tests without running any real tests."""
|
||||
pass
|
@ -1,61 +0,0 @@
|
||||
from langchain_core.output_parsers import StrOutputParser
|
||||
from langchain_core.runnables import Runnable
|
||||
|
||||
from langchain_ai21.contextual_answers import (
|
||||
ANSWER_NOT_IN_CONTEXT_RESPONSE,
|
||||
AI21ContextualAnswers,
|
||||
)
|
||||
|
||||
context = """
|
||||
Albert Einstein (German: 14 March 1879 – 18 April 1955)
|
||||
was a German-born theoretical physicist who is widely held
|
||||
to be one of the greatest and most influential scientists
|
||||
"""
|
||||
|
||||
|
||||
_GOOD_QUESTION = "When was Albert Einstein born?"
|
||||
_BAD_QUESTION = "What color is Yoda's light saber?"
|
||||
_EXPECTED_PARTIAL_RESPONSE = "March 14, 1879"
|
||||
|
||||
|
||||
def test_invoke__when_good_question() -> None:
|
||||
llm = AI21ContextualAnswers()
|
||||
|
||||
response = llm.invoke(
|
||||
{"context": context, "question": _GOOD_QUESTION},
|
||||
config={"metadata": {"name": "I AM A TEST"}},
|
||||
)
|
||||
|
||||
assert response != ANSWER_NOT_IN_CONTEXT_RESPONSE
|
||||
|
||||
|
||||
def test_invoke__when_bad_question__should_return_answer_not_in_context() -> None:
|
||||
llm = AI21ContextualAnswers()
|
||||
|
||||
response = llm.invoke(input={"context": context, "question": _BAD_QUESTION})
|
||||
|
||||
assert response == ANSWER_NOT_IN_CONTEXT_RESPONSE
|
||||
|
||||
|
||||
def test_invoke__when_response_if_no_answer_passed__should_use_it() -> None:
|
||||
response_if_no_answer_found = "This should be the response"
|
||||
llm = AI21ContextualAnswers()
|
||||
|
||||
response = llm.invoke(
|
||||
input={"context": context, "question": _BAD_QUESTION},
|
||||
response_if_no_answer_found=response_if_no_answer_found,
|
||||
)
|
||||
|
||||
assert response == response_if_no_answer_found
|
||||
|
||||
|
||||
def test_invoke_when_used_in_a_simple_chain_with_no_vectorstore() -> None:
|
||||
tsm = AI21ContextualAnswers()
|
||||
|
||||
chain: Runnable = tsm | StrOutputParser()
|
||||
|
||||
response = chain.invoke(
|
||||
{"context": context, "question": _GOOD_QUESTION},
|
||||
)
|
||||
|
||||
assert response != ANSWER_NOT_IN_CONTEXT_RESPONSE
|
@ -1,37 +0,0 @@
|
||||
"""Test AI21 embeddings."""
|
||||
|
||||
from langchain_ai21.embeddings import AI21Embeddings
|
||||
|
||||
|
||||
def test_langchain_ai21_embedding_documents() -> None:
|
||||
"""Test AI21 embeddings."""
|
||||
documents = ["foo bar"]
|
||||
embedding = AI21Embeddings()
|
||||
output = embedding.embed_documents(documents)
|
||||
assert len(output) == 1
|
||||
assert len(output[0]) > 0
|
||||
|
||||
|
||||
def test_langchain_ai21_embedding_query() -> None:
|
||||
"""Test AI21 embeddings."""
|
||||
document = "foo bar"
|
||||
embedding = AI21Embeddings()
|
||||
output = embedding.embed_query(document)
|
||||
assert len(output) > 0
|
||||
|
||||
|
||||
def test_langchain_ai21_embedding_documents__with_explicit_chunk_size() -> None:
|
||||
"""Test AI21 document embeddings with an explicit batch size passed as an argument."""
|
||||
documents = ["foo", "bar"]
|
||||
embedding = AI21Embeddings()
|
||||
output = embedding.embed_documents(documents, batch_size=1)
|
||||
assert len(output) == 2
|
||||
assert len(output[0]) > 0
|
||||
|
||||
|
||||
def test_langchain_ai21_embedding_query__with_explicit_chunk_size() -> None:
|
||||
"""Test AI21 query embeddings with an explicit batch size passed as an argument."""
|
||||
documents = "foo bar"
|
||||
embedding = AI21Embeddings()
|
||||
output = embedding.embed_query(documents, batch_size=1)
|
||||
assert len(output) > 0
|
@ -1,104 +0,0 @@
|
||||
"""Test AI21LLM llm."""
|
||||
|
||||
from langchain_ai21.llms import AI21LLM
|
||||
|
||||
_MODEL_NAME = "j2-mid"
|
||||
|
||||
|
||||
def _generate_llm() -> AI21LLM:
|
||||
"""
|
||||
Create an AI21LLM instance configured with non-default parameters for testing.
|
||||
"""
|
||||
return AI21LLM(
|
||||
model=_MODEL_NAME,
|
||||
max_tokens=2, # Use less tokens for a faster response
|
||||
temperature=0, # for a consistent response
|
||||
epoch=1,
|
||||
)
|
||||
|
||||
|
||||
def test_stream() -> None:
|
||||
"""Test streaming tokens from AI21."""
|
||||
llm = AI21LLM(
|
||||
model=_MODEL_NAME,
|
||||
)
|
||||
|
||||
for token in llm.stream("I'm Pickle Rick"):
|
||||
assert isinstance(token, str)
|
||||
|
||||
|
||||
async def test_abatch() -> None:
|
||||
"""Test abatch tokens from AI21LLM."""
|
||||
llm = AI21LLM(
|
||||
model=_MODEL_NAME,
|
||||
)
|
||||
|
||||
result = await llm.abatch(["I'm Pickle Rick", "I'm not Pickle Rick"])
|
||||
for token in result:
|
||||
assert isinstance(token, str)
|
||||
|
||||
|
||||
async def test_abatch_tags() -> None:
|
||||
"""Test batch tokens from AI21LLM."""
|
||||
llm = AI21LLM(
|
||||
model=_MODEL_NAME,
|
||||
)
|
||||
|
||||
result = await llm.abatch(
|
||||
["I'm Pickle Rick", "I'm not Pickle Rick"], config={"tags": ["foo"]}
|
||||
)
|
||||
for token in result:
|
||||
assert isinstance(token, str)
|
||||
|
||||
|
||||
def test_batch() -> None:
|
||||
"""Test batch tokens from AI21LLM."""
|
||||
llm = AI21LLM(
|
||||
model=_MODEL_NAME,
|
||||
)
|
||||
|
||||
result = llm.batch(["I'm Pickle Rick", "I'm not Pickle Rick"])
|
||||
for token in result:
|
||||
assert isinstance(token, str)
|
||||
|
||||
|
||||
async def test_ainvoke() -> None:
|
||||
"""Test invoke tokens from AI21LLM."""
|
||||
llm = AI21LLM(
|
||||
model=_MODEL_NAME,
|
||||
)
|
||||
|
||||
result = await llm.ainvoke("I'm Pickle Rick", config={"tags": ["foo"]})
|
||||
assert isinstance(result, str)
|
||||
|
||||
|
||||
def test_invoke() -> None:
|
||||
"""Test invoke tokens from AI21LLM."""
|
||||
llm = AI21LLM(
|
||||
model=_MODEL_NAME,
|
||||
)
|
||||
|
||||
result = llm.invoke("I'm Pickle Rick", config=dict(tags=["foo"]))
|
||||
assert isinstance(result, str)
|
||||
|
||||
|
||||
def test__generate() -> None:
|
||||
llm = _generate_llm()
|
||||
llm_result = llm.generate(
|
||||
prompts=["Hey there, my name is Pickle Rick. What is your name?"],
|
||||
stop=["##"],
|
||||
)
|
||||
|
||||
assert len(llm_result.generations) > 0
|
||||
assert llm_result.llm_output["token_count"] != 0 # type: ignore
|
||||
|
||||
|
||||
async def test__agenerate() -> None:
|
||||
llm = _generate_llm()
|
||||
llm_result = await llm.agenerate(
|
||||
prompts=["Hey there, my name is Pickle Rick. What is your name?"],
|
||||
stop=["##"],
|
||||
)
|
||||
|
||||
assert len(llm_result.generations) > 0
|
||||
assert llm_result.llm_output["token_count"] != 0 # type: ignore
|
@ -1,130 +0,0 @@
|
||||
from ai21 import AI21Client
|
||||
from langchain_core.documents import Document
|
||||
|
||||
from langchain_ai21 import AI21SemanticTextSplitter
|
||||
|
||||
TEXT = (
|
||||
"The original full name of the franchise is Pocket Monsters (ポケットモンスター, "
|
||||
"Poketto Monsutā), which was abbreviated to "
|
||||
"Pokemon during development of the original games.\n"
|
||||
"When the franchise was released internationally, the short form of the title was "
|
||||
"used, with an acute accent (´) "
|
||||
"over the e to aid in pronunciation.\n"
|
||||
"Pokémon refers to both the franchise itself and the creatures within its "
|
||||
"fictional universe.\n"
|
||||
"As a noun, it is identical in both the singular and plural, as is every "
|
||||
"individual species name;[10] it is "
|
||||
'grammatically correct to say "one Pokémon" and "many Pokémon", as well '
|
||||
'as "one Pikachu" and "many Pikachu".\n'
|
||||
"In English, Pokémon may be pronounced either /'powkɛmon/ (poe-keh-mon) or "
|
||||
"/'powkɪmon/ (poe-key-mon).\n"
|
||||
"The Pokémon franchise is set in a world in which humans coexist with creatures "
|
||||
"known as Pokémon.\n"
|
||||
"Pokémon Red and Blue contain 151 Pokémon species, with new ones being introduced "
|
||||
"in subsequent games; as of December 2023, 1,025 Pokémon species have been "
|
||||
"introduced.\n[b] Most Pokémon are inspired by real-world animals;[12] for example,"
|
||||
"Pikachu are a yellow mouse-like species[13] with lightning bolt-shaped tails[14] "
|
||||
"that possess electrical abilities.[15]\nThe player character takes the role of a "
|
||||
"Pokémon Trainer.\nThe Trainer has three primary goals: travel and explore the "
|
||||
"Pokémon world; discover and catch each Pokémon species in order to complete their"
|
||||
"Pokédex; and train a team of up to six Pokémon at a time and have them engage "
|
||||
"in battles.\nMost Pokémon can be caught with spherical devices known as Poké "
|
||||
"Balls.\nOnce the opposing Pokémon is sufficiently weakened, the Trainer throws "
|
||||
"the Poké Ball against the Pokémon, which is then transformed into a form of "
|
||||
"energy and transported into the device.\nOnce the catch is successful, "
|
||||
"the Pokémon is tamed and is under the Trainer's command from then on.\n"
|
||||
"If the Poké Ball is thrown again, the Pokémon re-materializes into its "
|
||||
"original state.\nThe Trainer's Pokémon can engage in battles against opposing "
|
||||
"Pokémon, including those in the wild or owned by other Trainers.\nBecause the "
|
||||
"franchise is aimed at children, these battles are never presented as overtly "
|
||||
"violent and contain no blood or gore.[I]\nPokémon never die in battle, instead "
|
||||
"fainting upon being defeated.[20][21][22]\nAfter a Pokémon wins a battle, it "
|
||||
"gains experience and becomes stronger.[23] After gaining a certain amount of "
|
||||
"experience points, its level increases, as well as one or more of its "
|
||||
"statistics.\nAs its level increases, the Pokémon can learn new offensive "
|
||||
"and defensive moves to use in battle.[24][25] Furthermore, many species can "
|
||||
"undergo a form of spontaneous metamorphosis called Pokémon evolution, and "
|
||||
"transform into stronger forms.[26] Most Pokémon will evolve at a certain level, "
|
||||
"while others evolve through different means, such as exposure to a certain "
|
||||
"item.[27]\n"
|
||||
)
|
||||
|
||||
|
||||
def test_split_text_to_document() -> None:
|
||||
segmentation = AI21SemanticTextSplitter()
|
||||
segments = segmentation.split_text_to_documents(source=TEXT)
|
||||
assert len(segments) > 0
|
||||
for segment in segments:
|
||||
assert segment.page_content is not None
|
||||
assert segment.metadata is not None
|
||||
|
||||
|
||||
def test_split_text() -> None:
|
||||
segmentation = AI21SemanticTextSplitter()
|
||||
segments = segmentation.split_text(source=TEXT)
|
||||
assert len(segments) > 0
|
||||
|
||||
|
||||
def test_split_text__when_chunk_size_is_large__should_merge_segments() -> None:
|
||||
segmentation_no_merge = AI21SemanticTextSplitter()
|
||||
segments_no_merge = segmentation_no_merge.split_text(source=TEXT)
|
||||
segmentation_merge = AI21SemanticTextSplitter(chunk_size=1000)
|
||||
segments_merge = segmentation_merge.split_text(source=TEXT)
|
||||
# Assert that a merge did happen
|
||||
assert len(segments_no_merge) > len(segments_merge)
|
||||
reconstructed_text_merged = "".join(segments_merge)
|
||||
reconstructed_text_non_merged = "".join(segments_no_merge)
|
||||
# Assert that the merge did not change the content
|
||||
assert reconstructed_text_merged == reconstructed_text_non_merged
|
||||
|
||||
|
||||
def test_split_text__chunk_size_is_too_small__should_return_non_merged_segments() -> (
|
||||
None
|
||||
):
|
||||
segmentation_no_merge = AI21SemanticTextSplitter()
|
||||
segments_no_merge = segmentation_no_merge.split_text(source=TEXT)
|
||||
segmentation_merge = AI21SemanticTextSplitter(chunk_size=10)
|
||||
segments_merge = segmentation_merge.split_text(source=TEXT)
|
||||
# Assert that no merge happened
|
||||
assert len(segments_no_merge) == len(segments_merge)
|
||||
reconstructed_text_merged = "".join(segments_merge)
|
||||
reconstructed_text_non_merged = "".join(segments_no_merge)
|
||||
# Assert that the content was not changed
|
||||
assert reconstructed_text_merged == reconstructed_text_non_merged
|
||||
|
||||
|
||||
def test_split_text__when_chunk_size_set_with_ai21_tokenizer() -> None:
|
||||
segmentation_no_merge = AI21SemanticTextSplitter(
|
||||
length_function=AI21Client().count_tokens
|
||||
)
|
||||
segments_no_merge = segmentation_no_merge.split_text(source=TEXT)
|
||||
segmentation_merge = AI21SemanticTextSplitter(
|
||||
chunk_size=1000, length_function=AI21Client().count_tokens
|
||||
)
|
||||
segments_merge = segmentation_merge.split_text(source=TEXT)
|
||||
# Assert that a merge did happen
|
||||
assert len(segments_no_merge) > len(segments_merge)
|
||||
reconstructed_text_merged = "".join(segments_merge)
|
||||
reconstructed_text_non_merged = "".join(segments_no_merge)
|
||||
# Assert that the merge did not change the content
|
||||
assert reconstructed_text_merged == reconstructed_text_non_merged
|
||||
|
||||
|
||||
def test_create_documents() -> None:
|
||||
texts = [TEXT]
|
||||
segmentation = AI21SemanticTextSplitter()
|
||||
documents = segmentation.create_documents(texts=texts)
|
||||
assert len(documents) > 0
|
||||
for document in documents:
|
||||
assert document.page_content is not None
|
||||
assert document.metadata is not None
|
||||
|
||||
|
||||
def test_split_documents() -> None:
|
||||
documents = [Document(page_content=TEXT, metadata={"foo": "bar"})]
|
||||
segmentation = AI21SemanticTextSplitter()
|
||||
segments = segmentation.split_documents(documents=documents)
|
||||
assert len(segments) > 0
|
||||
for segment in segments:
|
||||
assert segment.page_content is not None
|
||||
assert segment.metadata is not None
|
@ -1,89 +0,0 @@
|
||||
"""Standard LangChain interface tests"""
|
||||
|
||||
import time
|
||||
from typing import Optional, Type
|
||||
|
||||
import pytest
|
||||
from langchain_core.language_models import BaseChatModel
|
||||
from langchain_core.rate_limiters import InMemoryRateLimiter
|
||||
from langchain_standard_tests.integration_tests import ChatModelIntegrationTests
|
||||
|
||||
from langchain_ai21 import ChatAI21
|
||||
|
||||
rate_limiter = InMemoryRateLimiter(requests_per_second=0.5)
|
||||
|
||||
|
||||
class BaseTestAI21(ChatModelIntegrationTests):
|
||||
def teardown(self) -> None:
|
||||
# avoid getting rate limited
|
||||
time.sleep(1)
|
||||
|
||||
@property
|
||||
def chat_model_class(self) -> Type[BaseChatModel]:
|
||||
return ChatAI21
|
||||
|
||||
@pytest.mark.xfail(reason="Not implemented.")
|
||||
def test_usage_metadata(self, model: BaseChatModel) -> None:
|
||||
super().test_usage_metadata(model)
|
||||
|
||||
|
||||
class TestAI21J2(BaseTestAI21):
|
||||
has_tool_calling = False
|
||||
|
||||
@property
|
||||
def chat_model_params(self) -> dict:
|
||||
return {
|
||||
"model": "j2-ultra",
|
||||
"rate_limiter": rate_limiter,
|
||||
}
|
||||
|
||||
@pytest.mark.xfail(reason="Streaming is not supported for Jurassic models.")
|
||||
def test_stream(self, model: BaseChatModel) -> None:
|
||||
super().test_stream(model)
|
||||
|
||||
@pytest.mark.xfail(reason="Streaming is not supported for Jurassic models.")
|
||||
async def test_astream(self, model: BaseChatModel) -> None:
|
||||
await super().test_astream(model)
|
||||
|
||||
@pytest.mark.xfail(reason="Streaming is not supported for Jurassic models.")
|
||||
def test_usage_metadata_streaming(self, model: BaseChatModel) -> None:
|
||||
super().test_usage_metadata_streaming(model)
|
||||
|
||||
|
||||
class TestAI21Jamba(BaseTestAI21):
|
||||
has_tool_calling = False
|
||||
|
||||
@property
|
||||
def chat_model_params(self) -> dict:
|
||||
return {
|
||||
"model": "jamba-instruct-preview",
|
||||
}
|
||||
|
||||
|
||||
class TestAI21Jamba1_5(BaseTestAI21):
|
||||
@property
|
||||
def tool_choice_value(self) -> Optional[str]:
|
||||
"""Value to use for tool choice when used in tests."""
|
||||
return "any"
|
||||
|
||||
@property
|
||||
def chat_model_params(self) -> dict:
|
||||
return {
|
||||
"model": "jamba-1.5-mini",
|
||||
}
|
||||
|
||||
@pytest.mark.xfail(reason="Prompt doesn't generate tool calls for Jamba 1.5.")
|
||||
def test_tool_calling(self, model: BaseChatModel) -> None:
|
||||
super().test_tool_calling(model)
|
||||
|
||||
@pytest.mark.xfail(reason="Prompt doesn't generate tool calls for Jamba 1.5.")
|
||||
def test_tool_calling_with_no_arguments(self, model: BaseChatModel) -> None:
|
||||
super().test_tool_calling_with_no_arguments(model)
|
||||
|
||||
@pytest.mark.xfail(reason="Requires tool calling & stream - still WIP")
|
||||
def test_structured_output(self, model: BaseChatModel) -> None:
|
||||
super().test_structured_output(model)
|
||||
|
||||
@pytest.mark.xfail(reason="Requires tool calling & stream - still WIP")
|
||||
def test_structured_output_pydantic_2_v1(self, model: BaseChatModel) -> None:
|
||||
super().test_structured_output_pydantic_2_v1(model)
|
@ -1,9 +0,0 @@
|
||||
import pytest
|
||||
|
||||
from langchain_ai21.chat.chat_adapter import ChatAdapter
|
||||
from langchain_ai21.chat.chat_factory import create_chat_adapter
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def chat_adapter(model: str) -> ChatAdapter:
|
||||
return create_chat_adapter(model)
|
@ -1,248 +0,0 @@
|
||||
from typing import List
|
||||
|
||||
import pytest
|
||||
from ai21.models import ChatMessage as J2ChatMessage
|
||||
from ai21.models import RoleType
|
||||
from ai21.models.chat import (
|
||||
AssistantMessage,
|
||||
ChatMessage,
|
||||
UserMessage,
|
||||
)
|
||||
from ai21.models.chat import (
|
||||
SystemMessage as AI21SystemMessage,
|
||||
)
|
||||
from ai21.models.chat import ToolMessage as AI21ToolMessage
|
||||
from langchain_core.messages import (
|
||||
AIMessage,
|
||||
BaseMessage,
|
||||
HumanMessage,
|
||||
SystemMessage,
|
||||
ToolMessage,
|
||||
)
|
||||
from langchain_core.messages import (
|
||||
ChatMessage as LangChainChatMessage,
|
||||
)
|
||||
|
||||
from langchain_ai21.chat.chat_adapter import ChatAdapter
|
||||
|
||||
_J2_MODEL_NAME = "j2-ultra"
|
||||
_JAMBA_MODEL_NAME = "jamba-instruct-preview"
|
||||
_JAMBA_1_5_MINI_MODEL_NAME = "jamba-1.5-mini"
|
||||
_JAMBA_1_5_LARGE_MODEL_NAME = "jamba-1.5-large"
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
ids=[
|
||||
"when_human_message_j2_model",
|
||||
"when_ai_message_j2_model",
|
||||
"when_human_message_jamba_model",
|
||||
"when_ai_message_jamba_model",
|
||||
],
|
||||
argnames=["model", "message", "expected_ai21_message"],
|
||||
argvalues=[
|
||||
(
|
||||
_J2_MODEL_NAME,
|
||||
HumanMessage(content="Human Message Content"),
|
||||
J2ChatMessage(role=RoleType.USER, text="Human Message Content"),
|
||||
),
|
||||
(
|
||||
_J2_MODEL_NAME,
|
||||
AIMessage(content="AI Message Content"),
|
||||
J2ChatMessage(role=RoleType.ASSISTANT, text="AI Message Content"),
|
||||
),
|
||||
(
|
||||
_JAMBA_MODEL_NAME,
|
||||
HumanMessage(content="Human Message Content"),
|
||||
UserMessage(role="user", content="Human Message Content"),
|
||||
),
|
||||
(
|
||||
_JAMBA_MODEL_NAME,
|
||||
AIMessage(content="AI Message Content"),
|
||||
AssistantMessage(
|
||||
role="assistant", content="AI Message Content", tool_calls=[]
|
||||
),
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_convert_message_to_ai21_message(
|
||||
message: BaseMessage,
|
||||
expected_ai21_message: ChatMessage,
|
||||
chat_adapter: ChatAdapter,
|
||||
) -> None:
|
||||
ai21_message = chat_adapter._convert_message_to_ai21_message(message)
|
||||
assert ai21_message == expected_ai21_message
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
ids=[
|
||||
"when_system_message_j2_model",
|
||||
"when_langchain_chat_message_j2_model",
|
||||
],
|
||||
argnames=["model", "message"],
|
||||
argvalues=[
|
||||
(
|
||||
_J2_MODEL_NAME,
|
||||
AI21SystemMessage(content="System Message Content"),
|
||||
),
|
||||
(
|
||||
_J2_MODEL_NAME,
|
||||
LangChainChatMessage(content="Chat Message Content", role="human"),
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_convert_message_to_ai21_message__when_invalid_role__should_raise_exception(
|
||||
message: BaseMessage,
|
||||
chat_adapter: ChatAdapter,
|
||||
) -> None:
|
||||
with pytest.raises(ValueError) as e:
|
||||
chat_adapter._convert_message_to_ai21_message(message)
|
||||
assert e.value.args[0] == (
|
||||
f"Could not resolve role type from message {message}. "
|
||||
f"Only support {HumanMessage.__name__} and {AIMessage.__name__}."
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
ids=[
|
||||
"when_all_messages_are_human_messages__should_return_system_none_j2_model",
|
||||
"when_first_message_is_system__should_return_system_j2_model",
|
||||
"when_all_messages_are_human_messages__should_return_system_none_jamba_model",
|
||||
"when_first_message_is_system__should_return_system_jamba_model",
|
||||
"when_tool_calling_message__should_return_tool_jamba_mini_model",
|
||||
"when_tool_calling_message__should_return_tool_jamba_large_model",
|
||||
],
|
||||
argnames=["model", "messages", "expected_messages"],
|
||||
argvalues=[
|
||||
(
|
||||
_J2_MODEL_NAME,
|
||||
[
|
||||
HumanMessage(content="Human Message Content 1"),
|
||||
HumanMessage(content="Human Message Content 2"),
|
||||
],
|
||||
{
|
||||
"system": "",
|
||||
"messages": [
|
||||
J2ChatMessage(
|
||||
role=RoleType.USER,
|
||||
text="Human Message Content 1",
|
||||
),
|
||||
J2ChatMessage(
|
||||
role=RoleType.USER,
|
||||
text="Human Message Content 2",
|
||||
),
|
||||
],
|
||||
},
|
||||
),
|
||||
(
|
||||
_J2_MODEL_NAME,
|
||||
[
|
||||
SystemMessage(content="System Message Content 1"),
|
||||
HumanMessage(content="Human Message Content 1"),
|
||||
],
|
||||
{
|
||||
"system": "System Message Content 1",
|
||||
"messages": [
|
||||
J2ChatMessage(
|
||||
role=RoleType.USER,
|
||||
text="Human Message Content 1",
|
||||
),
|
||||
],
|
||||
},
|
||||
),
|
||||
(
|
||||
_JAMBA_MODEL_NAME,
|
||||
[
|
||||
HumanMessage(content="Human Message Content 1"),
|
||||
HumanMessage(content="Human Message Content 2"),
|
||||
],
|
||||
{
|
||||
"messages": [
|
||||
UserMessage(
|
||||
role="user",
|
||||
content="Human Message Content 1",
|
||||
),
|
||||
UserMessage(
|
||||
role="user",
|
||||
content="Human Message Content 2",
|
||||
),
|
||||
]
|
||||
},
|
||||
),
|
||||
(
|
||||
_JAMBA_MODEL_NAME,
|
||||
[
|
||||
SystemMessage(content="System Message Content 1"),
|
||||
HumanMessage(content="Human Message Content 1"),
|
||||
],
|
||||
{
|
||||
"messages": [
|
||||
AI21SystemMessage(
|
||||
role="system", content="System Message Content 1"
|
||||
),
|
||||
UserMessage(role="user", content="Human Message Content 1"),
|
||||
],
|
||||
},
|
||||
),
|
||||
(
|
||||
_JAMBA_1_5_MINI_MODEL_NAME,
|
||||
[
|
||||
ToolMessage(
|
||||
content="42",
|
||||
tool_call_id="call_Jja7J89XsjrOLA5r!MEOW!SL",
|
||||
)
|
||||
],
|
||||
{
|
||||
"messages": [
|
||||
AI21ToolMessage(
|
||||
role="tool",
|
||||
tool_call_id="call_Jja7J89XsjrOLA5r!MEOW!SL",
|
||||
content="42",
|
||||
),
|
||||
],
|
||||
},
|
||||
),
|
||||
(
|
||||
_JAMBA_1_5_LARGE_MODEL_NAME,
|
||||
[
|
||||
ToolMessage(
|
||||
content="42",
|
||||
tool_call_id="call_Jja7J89XsjrOLA5r!MEOW!SL",
|
||||
)
|
||||
],
|
||||
{
|
||||
"messages": [
|
||||
AI21ToolMessage(
|
||||
role="tool",
|
||||
tool_call_id="call_Jja7J89XsjrOLA5r!MEOW!SL",
|
||||
content="42",
|
||||
),
|
||||
],
|
||||
},
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_convert_messages(
|
||||
chat_adapter: ChatAdapter,
|
||||
messages: List[BaseMessage],
|
||||
expected_messages: List[ChatMessage],
|
||||
) -> None:
|
||||
converted_messages = chat_adapter.convert_messages(messages)
|
||||
assert converted_messages == expected_messages
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
ids=[
|
||||
"when_j2_model",
|
||||
],
|
||||
argnames=["model"],
|
||||
argvalues=[
|
||||
(_J2_MODEL_NAME,),
|
||||
],
|
||||
)
|
||||
def test_convert_messages__when_system_is_not_first(chat_adapter: ChatAdapter) -> None:
|
||||
messages = [
|
||||
HumanMessage(content="Human Message Content 1"),
|
||||
SystemMessage(content="System Message Content 1"),
|
||||
]
|
||||
with pytest.raises(ValueError):
|
||||
chat_adapter.convert_messages(messages)
|
@ -1,34 +0,0 @@
|
||||
from typing import Type
|
||||
|
||||
import pytest
|
||||
|
||||
from langchain_ai21.chat.chat_adapter import (
|
||||
ChatAdapter,
|
||||
J2ChatAdapter,
|
||||
JambaChatCompletionsAdapter,
|
||||
)
|
||||
from langchain_ai21.chat.chat_factory import create_chat_adapter
|
||||
from tests.unit_tests.conftest import J2_CHAT_MODEL_NAME, JAMBA_CHAT_MODEL_NAME
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
ids=[
|
||||
"when_j2_model",
|
||||
"when_jamba_model",
|
||||
],
|
||||
argnames=["model", "expected_chat_type"],
|
||||
argvalues=[
|
||||
(J2_CHAT_MODEL_NAME, J2ChatAdapter),
|
||||
(JAMBA_CHAT_MODEL_NAME, JambaChatCompletionsAdapter),
|
||||
],
|
||||
)
|
||||
def test_create_chat_adapter_with_supported_models(
|
||||
model: str, expected_chat_type: Type[ChatAdapter]
|
||||
) -> None:
|
||||
adapter = create_chat_adapter(model)
|
||||
assert isinstance(adapter, expected_chat_type)
|
||||
|
||||
|
||||
def test_create_chat_adapter__when_model_not_supported() -> None:
|
||||
with pytest.raises(ValueError):
|
||||
create_chat_adapter("unsupported-model")
|
@ -1,207 +0,0 @@
|
||||
import os
|
||||
from contextlib import contextmanager
|
||||
from typing import Generator
|
||||
from unittest.mock import Mock
|
||||
|
||||
import pytest
|
||||
from ai21 import AI21Client
|
||||
from ai21.models import (
|
||||
AnswerResponse,
|
||||
ChatOutput,
|
||||
ChatResponse,
|
||||
Completion,
|
||||
CompletionData,
|
||||
CompletionFinishReason,
|
||||
CompletionsResponse,
|
||||
FinishReason,
|
||||
Penalty,
|
||||
RoleType,
|
||||
SegmentationResponse,
|
||||
)
|
||||
from ai21.models.responses.segmentation_response import Segment
|
||||
from pytest_mock import MockerFixture
|
||||
|
||||
J2_CHAT_MODEL_NAME = "j2-ultra"
|
||||
JAMBA_CHAT_MODEL_NAME = "jamba-instruct-preview"
|
||||
JAMBA_1_5_MINI_CHAT_MODEL_NAME = "jamba-1.5-mini"
|
||||
JAMBA_1_5_LARGE_CHAT_MODEL_NAME = "jamba-1.5-large"
|
||||
DUMMY_API_KEY = "test_api_key"
|
||||
|
||||
JAMBA_FAMILY_MODEL_NAMES = [
|
||||
JAMBA_CHAT_MODEL_NAME,
|
||||
JAMBA_1_5_MINI_CHAT_MODEL_NAME,
|
||||
JAMBA_1_5_LARGE_CHAT_MODEL_NAME,
|
||||
]
|
||||
|
||||
BASIC_EXAMPLE_LLM_PARAMETERS = {
|
||||
"num_results": 3,
|
||||
"max_tokens": 20,
|
||||
"min_tokens": 10,
|
||||
"temperature": 0.5,
|
||||
"top_p": 0.5,
|
||||
"top_k_return": 0,
|
||||
"frequency_penalty": Penalty(scale=0.2, apply_to_numbers=True), # type: ignore[call-arg]
|
||||
"presence_penalty": Penalty(scale=0.2, apply_to_stopwords=True), # type: ignore[call-arg]
|
||||
"count_penalty": Penalty( # type: ignore[call-arg]
|
||||
scale=0.2,
|
||||
apply_to_punctuation=True,
|
||||
apply_to_emojis=True,
|
||||
),
|
||||
}
|
||||
|
||||
BASIC_EXAMPLE_CHAT_PARAMETERS = {
|
||||
"num_results": 3,
|
||||
"max_tokens": 20,
|
||||
"min_tokens": 10,
|
||||
"temperature": 0.5,
|
||||
"top_p": 0.5,
|
||||
"top_k_return": 0,
|
||||
"frequency_penalty": Penalty(scale=0.2, apply_to_numbers=True), # type: ignore[call-arg]
|
||||
"presence_penalty": Penalty(scale=0.2, apply_to_stopwords=True), # type: ignore[call-arg]
|
||||
"count_penalty": Penalty( # type: ignore[call-arg]
|
||||
scale=0.2,
|
||||
apply_to_punctuation=True,
|
||||
apply_to_emojis=True,
|
||||
),
|
||||
"n": 3,
|
||||
}
|
||||
|
||||
SEGMENTS = [
|
||||
Segment( # type: ignore[call-arg]
|
||||
segment_type="normal_text",
|
||||
segment_text=(
|
||||
"The original full name of the franchise is Pocket Monsters "
|
||||
"(ポケットモンスター, Poketto Monsutā), which was abbreviated to "
|
||||
"Pokemon during development of the original games.\n\nWhen the "
|
||||
"franchise was released internationally, the short form of the "
|
||||
"title was used, with an acute accent (´) over the e to aid "
|
||||
"in pronunciation."
|
||||
),
|
||||
),
|
||||
Segment( # type: ignore[call-arg]
|
||||
segment_type="normal_text",
|
||||
segment_text=(
|
||||
"Pokémon refers to both the franchise itself and the creatures "
|
||||
"within its fictional universe.\n\nAs a noun, it is identical in "
|
||||
"both the singular and plural, as is every individual species "
|
||||
'name;[10] it is grammatically correct to say "one Pokémon" '
|
||||
'and "many Pokémon", as well as "one Pikachu" and "many '
|
||||
'Pikachu".\n\nIn English, Pokémon may be pronounced either '
|
||||
"/'powkɛmon/ (poe-keh-mon) or /'powkɪmon/ (poe-key-mon)."
|
||||
),
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
BASIC_EXAMPLE_LLM_PARAMETERS_AS_DICT = {
|
||||
"num_results": 3,
|
||||
"max_tokens": 20,
|
||||
"min_tokens": 10,
|
||||
"temperature": 0.5,
|
||||
"top_p": 0.5,
|
||||
"top_k_return": 0,
|
||||
"frequency_penalty": Penalty(scale=0.2, apply_to_numbers=True).to_dict(), # type: ignore[call-arg]
|
||||
"presence_penalty": Penalty(scale=0.2, apply_to_stopwords=True).to_dict(), # type: ignore[call-arg]
|
||||
"count_penalty": Penalty( # type: ignore[call-arg]
|
||||
scale=0.2,
|
||||
apply_to_punctuation=True,
|
||||
apply_to_emojis=True,
|
||||
).to_dict(),
|
||||
}
|
||||
|
||||
BASIC_EXAMPLE_CHAT_PARAMETERS_AS_DICT = {
|
||||
"num_results": 3,
|
||||
"max_tokens": 20,
|
||||
"min_tokens": 10,
|
||||
"temperature": 0.5,
|
||||
"top_p": 0.5,
|
||||
"top_k_return": 0,
|
||||
"frequency_penalty": Penalty(scale=0.2, apply_to_numbers=True).to_dict(), # type: ignore[call-arg]
|
||||
"presence_penalty": Penalty(scale=0.2, apply_to_stopwords=True).to_dict(), # type: ignore[call-arg]
|
||||
"count_penalty": Penalty( # type: ignore[call-arg]
|
||||
scale=0.2,
|
||||
apply_to_punctuation=True,
|
||||
apply_to_emojis=True,
|
||||
).to_dict(),
|
||||
"n": 3,
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mocked_completion_response(mocker: MockerFixture) -> Mock:
|
||||
mocked_response = mocker.MagicMock(spec=CompletionsResponse)
|
||||
mocked_response.prompt = "this is a test prompt"
|
||||
mocked_response.completions = [
|
||||
Completion( # type: ignore[call-arg]
|
||||
data=CompletionData(text="test", tokens=[]),
|
||||
finish_reason=CompletionFinishReason(reason=None, length=None),
|
||||
)
|
||||
]
|
||||
return mocked_response
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_client_with_completion(
|
||||
mocker: MockerFixture, mocked_completion_response: Mock
|
||||
) -> Mock:
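|
||||
# Mocked AI21 client whose completion.create returns the canned completion response for two consecutive calls.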
|
||||
mock_client = mocker.MagicMock(spec=AI21Client)
|
||||
mock_client.completion = mocker.MagicMock()
|
||||
mock_client.completion.create.side_effect = [
|
||||
mocked_completion_response,
|
||||
mocked_completion_response,
|
||||
]
|
||||
mock_client.count_tokens.side_effect = [10, 20]
|
||||
|
||||
return mock_client
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_client_with_chat(mocker: MockerFixture) -> Mock:
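|
||||
# Mocked AI21 client whose chat.create returns a single assistant reply.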
|
||||
mock_client = mocker.MagicMock(spec=AI21Client)
|
||||
mock_client.chat = mocker.MagicMock()
|
||||
|
||||
output = ChatOutput( # type: ignore[call-arg]
|
||||
text="Hello Pickle Rick!",
|
||||
role=RoleType.ASSISTANT,
|
||||
finish_reason=FinishReason(reason="testing"),
|
||||
)
|
||||
mock_client.chat.create.return_value = ChatResponse(outputs=[output])
|
||||
|
||||
return mock_client
|
||||
|
||||
|
||||
@contextmanager
|
||||
def temporarily_unset_api_key() -> Generator:
|
||||
"""
|
||||
Temporarily unset the AI21_API_KEY environment variable for tests that require no API key, restoring it afterwards
|
||||
"""
|
||||
api_key = os.environ.pop("AI21_API_KEY", None)
|
||||
yield
|
||||
|
||||
if api_key is not None:
|
||||
os.environ["AI21_API_KEY"] = api_key
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_client_with_contextual_answers(mocker: MockerFixture) -> Mock:
|
||||
mock_client = mocker.MagicMock(spec=AI21Client)
|
||||
mock_client.answer = mocker.MagicMock()
|
||||
mock_client.answer.create.return_value = AnswerResponse( # type: ignore[call-arg]
|
||||
id="some_id",
|
||||
answer="some answer",
|
||||
answer_in_context=False,
|
||||
)
|
||||
|
||||
return mock_client
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_client_with_semantic_text_splitter(mocker: MockerFixture) -> Mock:
|
||||
mock_client = mocker.MagicMock(spec=AI21Client)
|
||||
mock_client.segmentation = mocker.MagicMock()
|
||||
mock_client.segmentation.create.return_value = SegmentationResponse(
|
||||
id="12345",
|
||||
segments=SEGMENTS,
|
||||
)
|
||||
|
||||
return mock_client
|
@ -1,174 +0,0 @@
|
||||
"""Test chat model integration."""
|
||||
|
||||
from typing import cast
|
||||
from unittest.mock import Mock, call
|
||||
|
||||
import pytest
|
||||
from ai21 import MissingApiKeyError
|
||||
from ai21.models import ChatMessage, Penalty, RoleType
|
||||
from langchain_core.messages import (
|
||||
AIMessage,
|
||||
HumanMessage,
|
||||
SystemMessage,
|
||||
)
|
||||
from langchain_core.pydantic_v1 import SecretStr
|
||||
from pytest import CaptureFixture, MonkeyPatch
|
||||
|
||||
from langchain_ai21.chat_models import (
|
||||
ChatAI21,
|
||||
)
|
||||
from tests.unit_tests.conftest import (
|
||||
BASIC_EXAMPLE_CHAT_PARAMETERS,
|
||||
BASIC_EXAMPLE_CHAT_PARAMETERS_AS_DICT,
|
||||
DUMMY_API_KEY,
|
||||
temporarily_unset_api_key,
|
||||
)
|
||||
|
||||
|
||||
def test_initialization__when_no_api_key__should_raise_exception() -> None:
|
||||
"""Test integration initialization."""
|
||||
with temporarily_unset_api_key():
|
||||
with pytest.raises(MissingApiKeyError):
|
||||
ChatAI21(model="j2-ultra") # type: ignore[call-arg]
|
||||
|
||||
|
||||
def test_initialization__when_default_parameters_in_init() -> None:
|
||||
"""Test chat model initialization."""
|
||||
ChatAI21(api_key=DUMMY_API_KEY, model="j2-ultra") # type: ignore[call-arg, arg-type]
|
||||
|
||||
|
||||
def test_initialization__when_custom_parameters_in_init() -> None:
|
||||
model = "j2-ultra"
|
||||
num_results = 1
|
||||
max_tokens = 10
|
||||
min_tokens = 20
|
||||
temperature = 0.1
|
||||
top_p = 0.1
|
||||
top_k_return = 0
|
||||
frequency_penalty = Penalty(scale=0.2, apply_to_numbers=True) # type: ignore[call-arg]
|
||||
presence_penalty = Penalty(scale=0.2, apply_to_stopwords=True) # type: ignore[call-arg]
|
||||
count_penalty = Penalty(scale=0.2, apply_to_punctuation=True, apply_to_emojis=True) # type: ignore[call-arg]
|
||||
|
||||
llm = ChatAI21( # type: ignore[call-arg]
|
||||
api_key=DUMMY_API_KEY, # type: ignore[arg-type]
|
||||
model=model,
|
||||
num_results=num_results,
|
||||
max_tokens=max_tokens,
|
||||
min_tokens=min_tokens,
|
||||
temperature=temperature,
|
||||
top_p=top_p,
|
||||
top_k_return=top_k_return,
|
||||
frequency_penalty=frequency_penalty,
|
||||
presence_penalty=presence_penalty,
|
||||
count_penalty=count_penalty,
|
||||
)
|
||||
assert llm.model == model
|
||||
assert llm.num_results == num_results
|
||||
assert llm.max_tokens == max_tokens
|
||||
assert llm.min_tokens == min_tokens
|
||||
assert llm.temperature == temperature
|
||||
assert llm.top_p == top_p
|
||||
assert llm.top_k_return == top_k_return
|
||||
assert llm.frequency_penalty == frequency_penalty
|
||||
assert llm.presence_penalty == presence_penalty
|
||||
assert llm.count_penalty == count_penalty
|
||||
|
||||
|
||||
def test_invoke(mock_client_with_chat: Mock) -> None:
|
||||
chat_input = "I'm Pickle Rick"
|
||||
|
||||
llm = ChatAI21(
|
||||
model="j2-ultra",
|
||||
api_key=DUMMY_API_KEY, # type: ignore[arg-type]
|
||||
client=mock_client_with_chat,
|
||||
**BASIC_EXAMPLE_CHAT_PARAMETERS, # type: ignore[arg-type, arg-type, arg-type, arg-type, arg-type, arg-type, arg-type, arg-type, arg-type, arg-type, arg-type, arg-type, arg-type]
|
||||
)
|
||||
llm.invoke(input=chat_input, config=dict(tags=["foo"]), stop=["\n"])
|
||||
|
||||
mock_client_with_chat.chat.create.assert_called_once_with(
|
||||
model="j2-ultra",
|
||||
messages=[ChatMessage(role=RoleType.USER, text=chat_input)],
|
||||
system="",
|
||||
stop_sequences=["\n"],
|
||||
**BASIC_EXAMPLE_CHAT_PARAMETERS_AS_DICT,
|
||||
)
|
||||
|
||||
|
||||
def test_generate(mock_client_with_chat: Mock) -> None:
|
||||
messages0 = [
|
||||
HumanMessage(content="I'm Pickle Rick"),
|
||||
AIMessage(content="Hello Pickle Rick! I am your AI Assistant"),
|
||||
HumanMessage(content="Nice to meet you."),
|
||||
]
|
||||
messages1 = [
|
||||
SystemMessage(content="system message"),
|
||||
HumanMessage(content="What is 1 + 1"),
|
||||
]
|
||||
llm = ChatAI21(
|
||||
model="j2-ultra",
|
||||
client=mock_client_with_chat,
|
||||
**BASIC_EXAMPLE_CHAT_PARAMETERS, # type: ignore[arg-type, arg-type, arg-type, arg-type, arg-type, arg-type, arg-type, arg-type, arg-type, arg-type, arg-type, arg-type, arg-type, arg-type]
|
||||
)
|
||||
|
||||
llm.generate(messages=[messages0, messages1])
|
||||
mock_client_with_chat.chat.create.assert_has_calls(
|
||||
[
|
||||
call(
|
||||
model="j2-ultra",
|
||||
messages=[
|
||||
ChatMessage(
|
||||
role=RoleType.USER,
|
||||
text=str(messages0[0].content),
|
||||
),
|
||||
ChatMessage(
|
||||
role=RoleType.ASSISTANT, text=str(messages0[1].content)
|
||||
),
|
||||
ChatMessage(role=RoleType.USER, text=str(messages0[2].content)),
|
||||
],
|
||||
system="",
|
||||
**BASIC_EXAMPLE_CHAT_PARAMETERS_AS_DICT,
|
||||
),
|
||||
call(
|
||||
model="j2-ultra",
|
||||
messages=[
|
||||
ChatMessage(role=RoleType.USER, text=str(messages1[1].content)),
|
||||
],
|
||||
system="system message",
|
||||
**BASIC_EXAMPLE_CHAT_PARAMETERS_AS_DICT,
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
def test_api_key_is_secret_string() -> None:
|
||||
llm = ChatAI21(model="j2-ultra", api_key="secret-api-key") # type: ignore[call-arg, arg-type]
|
||||
assert isinstance(llm.api_key, SecretStr)
|
||||
|
||||
|
||||
def test_api_key_masked_when_passed_from_env(
|
||||
monkeypatch: MonkeyPatch, capsys: CaptureFixture
|
||||
) -> None:
|
||||
"""Test initialization with an API key provided via an env variable"""
|
||||
monkeypatch.setenv("AI21_API_KEY", "secret-api-key")
|
||||
llm = ChatAI21(model="j2-ultra") # type: ignore[call-arg]
|
||||
print(llm.api_key, end="")
|
||||
captured = capsys.readouterr()
|
||||
|
||||
assert captured.out == "**********"
|
||||
|
||||
|
||||
def test_api_key_masked_when_passed_via_constructor(
|
||||
capsys: CaptureFixture,
|
||||
) -> None:
|
||||
"""Test initialization with an API key provided via the initializer"""
|
||||
llm = ChatAI21(model="j2-ultra", api_key="secret-api-key") # type: ignore[call-arg, arg-type]
|
||||
print(llm.api_key, end="")
|
||||
captured = capsys.readouterr()
|
||||
|
||||
assert captured.out == "**********"
|
||||
|
||||
|
||||
def test_uses_actual_secret_value_from_secretstr() -> None:
|
||||
"""Test that actual secret is retrieved using `.get_secret_value()`."""
|
||||
llm = ChatAI21(model="j2-ultra", api_key="secret-api-key") # type: ignore[call-arg, arg-type]
|
||||
assert cast(SecretStr, llm.api_key).get_secret_value() == "secret-api-key"
|
@ -1,109 +0,0 @@
|
||||
from unittest.mock import Mock
|
||||
|
||||
import pytest
|
||||
from langchain_core.documents import Document
|
||||
|
||||
from langchain_ai21 import AI21ContextualAnswers
|
||||
from langchain_ai21.contextual_answers import ContextualAnswerInput
|
||||
from tests.unit_tests.conftest import DUMMY_API_KEY
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
ids=[
|
||||
"when_no_context__should_raise_exception",
|
||||
"when_no_question__should_raise_exception",
|
||||
"when_question_is_an_empty_string__should_raise_exception",
|
||||
"when_context_is_an_empty_string__should_raise_exception",
|
||||
"when_context_is_an_empty_list",
|
||||
],
|
||||
argnames="input",
|
||||
argvalues=[
|
||||
({"question": "What is the capital of France?"}),
|
||||
({"context": "Paris is the capital of France"}),
|
||||
({"question": "", "context": "Paris is the capital of France"}),
|
||||
({"context": "", "question": "some question?"}),
|
||||
({"context": [], "question": "What is the capital of France?"}),
|
||||
],
|
||||
)
|
||||
def test_invoke__on_bad_input(
|
||||
input: ContextualAnswerInput,
|
||||
mock_client_with_contextual_answers: Mock,
|
||||
) -> None:
|
||||
tsm = AI21ContextualAnswers(
|
||||
api_key=DUMMY_API_KEY, # type: ignore[arg-type]
|
||||
client=mock_client_with_contextual_answers, # type: ignore[arg-type]
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError) as error:
|
||||
tsm.invoke(input)
|
||||
|
||||
assert (
|
||||
error.value.args[0]
|
||||
== f"Input must contain a 'context' and 'question' fields. Got {input}"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
ids=[
|
||||
"when_context_is_not_str_or_list_of_docs_or_str",
|
||||
],
|
||||
argnames="input",
|
||||
argvalues=[
|
||||
({"context": 1242, "question": "What is the capital of France?"}),
|
||||
],
|
||||
)
|
||||
def test_invoke__on_context_bad_input(
|
||||
input: ContextualAnswerInput, mock_client_with_contextual_answers: Mock
|
||||
) -> None:
|
||||
tsm = AI21ContextualAnswers(
|
||||
api_key=DUMMY_API_KEY, # type: ignore[arg-type]
|
||||
client=mock_client_with_contextual_answers,
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError) as error:
|
||||
tsm.invoke(input)
|
||||
|
||||
assert (
|
||||
error.value.args[0] == "Expected input to be a list of strings or Documents."
|
||||
f" Received {type(input)}"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
ids=[
|
||||
"when_context_is_a_list_of_strings",
|
||||
"when_context_is_a_list_of_documents",
|
||||
"when_context_is_a_string",
|
||||
],
|
||||
argnames="input",
|
||||
argvalues=[
|
||||
(
|
||||
{
|
||||
"context": ["Paris is the capital of france"],
|
||||
"question": "What is the capital of France?",
|
||||
}
|
||||
),
|
||||
(
|
||||
{
|
||||
"context": [Document(page_content="Paris is the capital of france")],
|
||||
"question": "What is the capital of France?",
|
||||
}
|
||||
),
|
||||
(
|
||||
{
|
||||
"context": "Paris is the capital of france",
|
||||
"question": "What is the capital of France?",
|
||||
}
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_invoke__on_good_input(
|
||||
input: ContextualAnswerInput, mock_client_with_contextual_answers: Mock
|
||||
) -> None:
|
||||
tsm = AI21ContextualAnswers(
|
||||
api_key=DUMMY_API_KEY, # type: ignore[arg-type]
|
||||
client=mock_client_with_contextual_answers,
|
||||
)
|
||||
|
||||
response = tsm.invoke(input)
|
||||
assert isinstance(response, str)
|
@ -1,102 +0,0 @@
|
||||
"""Test embedding model integration."""
|
||||
|
||||
from typing import List
|
||||
from unittest.mock import Mock
|
||||
|
||||
import pytest
|
||||
from ai21 import AI21Client, MissingApiKeyError
|
||||
from ai21.models import EmbedResponse, EmbedResult, EmbedType
|
||||
from pytest_mock import MockerFixture
|
||||
|
||||
from langchain_ai21.embeddings import AI21Embeddings
|
||||
from tests.unit_tests.conftest import DUMMY_API_KEY, temporarily_unset_api_key
|
||||
|
||||
_EXAMPLE_EMBEDDING_0 = [1.0, 2.0, 3.0]
|
||||
_EXAMPLE_EMBEDDING_1 = [4.0, 5.0, 6.0]
|
||||
_EXAMPLE_EMBEDDING_2 = [7.0, 8.0, 9.0]
|
||||
|
||||
_EXAMPLE_EMBEDDING_RESPONSE = EmbedResponse(
|
||||
results=[
|
||||
EmbedResult(embedding=_EXAMPLE_EMBEDDING_0),
|
||||
EmbedResult(embedding=_EXAMPLE_EMBEDDING_1),
|
||||
EmbedResult(embedding=_EXAMPLE_EMBEDDING_2),
|
||||
],
|
||||
id="test_id",
|
||||
)
|
||||
|
||||
|
||||
def test_initialization__when_no_api_key__should_raise_exception() -> None:
|
||||
"""Test integration initialization."""
|
||||
with temporarily_unset_api_key():
|
||||
with pytest.raises(MissingApiKeyError):
|
||||
AI21Embeddings()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_client_with_embeddings(mocker: MockerFixture) -> Mock:
|
||||
mock_client = mocker.MagicMock(spec=AI21Client)
|
||||
mock_client.embed = mocker.MagicMock()
|
||||
mock_client.embed.create.return_value = _EXAMPLE_EMBEDDING_RESPONSE
|
||||
|
||||
return mock_client
|
||||
|
||||
|
||||
def test_embed_query(mock_client_with_embeddings: Mock) -> None:
|
||||
llm = AI21Embeddings(client=mock_client_with_embeddings, api_key=DUMMY_API_KEY) # type: ignore[arg-type]
|
||||
|
||||
text = "Hello embeddings world!"
|
||||
response = llm.embed_query(text=text)
|
||||
assert response == _EXAMPLE_EMBEDDING_0
|
||||
mock_client_with_embeddings.embed.create.assert_called_once_with(
|
||||
texts=[text],
|
||||
type=EmbedType.QUERY,
|
||||
)
|
||||
|
||||
|
||||
def test_embed_documents(mock_client_with_embeddings: Mock) -> None:
|
||||
llm = AI21Embeddings(client=mock_client_with_embeddings, api_key=DUMMY_API_KEY) # type: ignore[arg-type]
|
||||
|
||||
texts = ["Hello embeddings world!", "Some other text", "Some more text"]
|
||||
response = llm.embed_documents(texts=texts)
|
||||
assert response == [
|
||||
_EXAMPLE_EMBEDDING_0,
|
||||
_EXAMPLE_EMBEDDING_1,
|
||||
_EXAMPLE_EMBEDDING_2,
|
||||
]
|
||||
mock_client_with_embeddings.embed.create.assert_called_once_with(
|
||||
texts=texts,
|
||||
type=EmbedType.SEGMENT,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
ids=[
|
||||
"empty_texts",
|
||||
"chunk_size_greater_than_texts_length",
|
||||
"chunk_size_equal_to_texts_length",
|
||||
"chunk_size_less_than_texts_length",
|
||||
"chunk_size_one_with_multiple_texts",
|
||||
"chunk_size_greater_than_texts_length",
|
||||
],
|
||||
argnames=["texts", "chunk_size", "expected_internal_embeddings_calls"],
|
||||
argvalues=[
|
||||
([], 3, 0),
|
||||
(["text1", "text2", "text3"], 5, 1),
|
||||
(["text1", "text2", "text3"], 3, 1),
|
||||
(["text1", "text2", "text3", "text4", "text5"], 2, 3),
|
||||
(["text1", "text2", "text3"], 1, 3),
|
||||
(["text1", "text2", "text3"], 10, 1),
|
||||
],
|
||||
)
|
||||
def test_get_len_safe_embeddings(
|
||||
mock_client_with_embeddings: Mock,
|
||||
texts: List[str],
|
||||
chunk_size: int,
|
||||
expected_internal_embeddings_calls: int,
|
||||
) -> None:
|
||||
llm = AI21Embeddings(client=mock_client_with_embeddings, api_key=DUMMY_API_KEY) # type: ignore[arg-type]
|
||||
llm.embed_documents(texts=texts, batch_size=chunk_size)
|
||||
assert (
|
||||
mock_client_with_embeddings.embed.create.call_count
|
||||
== expected_internal_embeddings_calls
|
||||
)
|
@ -1,13 +0,0 @@
|
||||
from langchain_ai21 import __all__
|
||||
|
||||
EXPECTED_ALL = [
|
||||
"AI21LLM",
|
||||
"ChatAI21",
|
||||
"AI21Embeddings",
|
||||
"AI21ContextualAnswers",
|
||||
"AI21SemanticTextSplitter",
|
||||
]
|
||||
|
||||
|
||||
def test_all_imports() -> None:
|
||||
assert sorted(EXPECTED_ALL) == sorted(__all__)
|
@ -1,146 +0,0 @@
|
||||
"""Test AI21 LLM wrapper."""
|
||||
|
||||
from typing import cast
|
||||
from unittest.mock import Mock, call
|
||||
|
||||
import pytest
|
||||
from ai21 import MissingApiKeyError
|
||||
from ai21.models import (
|
||||
Penalty,
|
||||
)
|
||||
from langchain_core.pydantic_v1 import SecretStr
|
||||
from pytest import CaptureFixture, MonkeyPatch
|
||||
|
||||
from langchain_ai21 import AI21LLM
|
||||
from tests.unit_tests.conftest import (
|
||||
BASIC_EXAMPLE_LLM_PARAMETERS,
|
||||
BASIC_EXAMPLE_LLM_PARAMETERS_AS_DICT,
|
||||
DUMMY_API_KEY,
|
||||
temporarily_unset_api_key,
|
||||
)
|
||||
|
||||
|
||||
def test_initialization__when_no_api_key__should_raise_exception() -> None:
|
||||
"""Test integration initialization."""
|
||||
with temporarily_unset_api_key():
|
||||
with pytest.raises(MissingApiKeyError):
|
||||
AI21LLM(
|
||||
model="j2-ultra",
|
||||
)
|
||||
|
||||
|
||||
def test_initialization__when_default_parameters() -> None:
|
||||
"""Test integration initialization."""
|
||||
AI21LLM(
|
||||
api_key=DUMMY_API_KEY, # type: ignore[arg-type]
|
||||
model="j2-ultra",
|
||||
)
|
||||
|
||||
|
||||
def test_initialization__when_custom_parameters_to_init() -> None:
|
||||
"""Test integration initialization."""
|
||||
AI21LLM( # type: ignore[call-arg]
|
||||
api_key=DUMMY_API_KEY, # type: ignore[arg-type]
|
||||
model="j2-mid",
|
||||
num_results=2,
|
||||
max_tokens=20,
|
||||
min_tokens=10,
|
||||
temperature=0.5,
|
||||
top_p=0.5,
|
||||
top_k_return=0,
|
||||
stop_sequences=["\n"],
|
||||
frequency_penalty=Penalty(scale=0.2, apply_to_numbers=True), # type: ignore[call-arg]
|
||||
presence_penalty=Penalty(scale=0.2, apply_to_stopwords=True), # type: ignore[call-arg]
|
||||
count_penalty=Penalty( # type: ignore[call-arg]
|
||||
scale=0.2, apply_to_punctuation=True, apply_to_emojis=True
|
||||
),
|
||||
custom_model="test_model",
|
||||
epoch=1,
|
||||
)
|
||||
|
||||
|
||||
def test_generate(mock_client_with_completion: Mock) -> None:
|
||||
# Setup test
|
||||
prompt0 = "Hi, my name is what?"
|
||||
prompt1 = "My name is who?"
|
||||
stop = ["\n"]
|
||||
custom_model = "test_model"
|
||||
epoch = 1
|
||||
|
||||
ai21 = AI21LLM(
|
||||
model="j2-ultra",
|
||||
api_key=DUMMY_API_KEY, # type: ignore[arg-type]
|
||||
client=mock_client_with_completion,
|
||||
custom_model=custom_model,
|
||||
epoch=epoch,
|
||||
**BASIC_EXAMPLE_LLM_PARAMETERS, # type: ignore[arg-type, arg-type, arg-type, arg-type, arg-type, arg-type, arg-type, arg-type, arg-type, arg-type, arg-type, arg-type]
|
||||
)
|
||||
|
||||
# Make call to testing function
|
||||
ai21.generate(
|
||||
[prompt0, prompt1],
|
||||
stop=stop,
|
||||
)
|
||||
|
||||
# Assertions
|
||||
mock_client_with_completion.count_tokens.assert_has_calls(
|
||||
[
|
||||
call(prompt0),
|
||||
call(prompt1),
|
||||
],
|
||||
)
|
||||
|
||||
mock_client_with_completion.completion.create.assert_has_calls(
|
||||
[
|
||||
call(
|
||||
prompt=prompt0,
|
||||
model="j2-ultra",
|
||||
custom_model=custom_model,
|
||||
stop_sequences=stop,
|
||||
epoch=epoch,
|
||||
**BASIC_EXAMPLE_LLM_PARAMETERS_AS_DICT,
|
||||
),
|
||||
call(
|
||||
prompt=prompt1,
|
||||
model="j2-ultra",
|
||||
custom_model=custom_model,
|
||||
stop_sequences=stop,
|
||||
epoch=epoch,
|
||||
**BASIC_EXAMPLE_LLM_PARAMETERS_AS_DICT,
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
def test_api_key_is_secret_string() -> None:
|
||||
llm = AI21LLM(model="j2-ultra", api_key="secret-api-key") # type: ignore[arg-type]
|
||||
assert isinstance(llm.api_key, SecretStr)
|
||||
|
||||
|
||||
def test_api_key_masked_when_passed_from_env(
|
||||
monkeypatch: MonkeyPatch, capsys: CaptureFixture
|
||||
) -> None:
|
||||
"""Test initialization with an API key provided via an env variable"""
|
||||
monkeypatch.setenv("AI21_API_KEY", "secret-api-key")
|
||||
llm = AI21LLM(model="j2-ultra")
|
||||
print(llm.api_key, end="")
|
||||
captured = capsys.readouterr()
|
||||
|
||||
assert captured.out == "**********"
|
||||
|
||||
|
||||
def test_api_key_masked_when_passed_via_constructor(
|
||||
capsys: CaptureFixture,
|
||||
) -> None:
|
||||
"""Test initialization with an API key provided via the initializer"""
|
||||
llm = AI21LLM(model="j2-ultra", api_key="secret-api-key") # type: ignore[arg-type]
|
||||
print(llm.api_key, end="")
|
||||
captured = capsys.readouterr()
|
||||
|
||||
assert captured.out == "**********"
|
||||
|
||||
|
||||
def test_uses_actual_secret_value_from_secretstr() -> None:
|
||||
"""Test that actual secret is retrieved using `.get_secret_value()`."""
|
||||
llm = AI21LLM(model="j2-ultra", api_key="secret-api-key") # type: ignore[arg-type]
|
||||
assert cast(SecretStr, llm.api_key).get_secret_value() == "secret-api-key"
|
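The conftest helpers referenced above (DUMMY_API_KEY, the mock_client_* fixtures, temporarily_unset_api_key) are defined outside this hunk and are not shown. As a rough sketch only, assuming temporarily_unset_api_key merely hides the AI21_API_KEY environment variable (the same variable the env-based test sets), such a helper could look like the following; names and details are illustrative, not the removed conftest code:

import os
from contextlib import contextmanager
from typing import Iterator


@contextmanager
def temporarily_unset_api_key() -> Iterator[None]:
    """Remove AI21_API_KEY from the environment for the duration of the block."""
    original = os.environ.pop("AI21_API_KEY", None)
    try:
        yield
    finally:
        # Restore the original value so other tests are unaffected.
        if original is not None:
            os.environ["AI21_API_KEY"] = original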
@ -1,129 +0,0 @@
from unittest.mock import Mock

import pytest

from langchain_ai21 import AI21SemanticTextSplitter
from tests.unit_tests.conftest import SEGMENTS

TEXT = (
    "The original full name of the franchise is Pocket Monsters (ポケットモンスター, "
    "Poketto Monsutā), which was abbreviated to "
    "Pokemon during development of the original games.\n"
    "When the franchise was released internationally, the short form of the title was "
    "used, with an acute accent (´) "
    "over the e to aid in pronunciation.\n"
    "Pokémon refers to both the franchise itself and the creatures within its "
    "fictional universe.\n"
    "As a noun, it is identical in both the singular and plural, as is every "
    "individual species name;[10] it is "
    'grammatically correct to say "one Pokémon" and "many Pokémon", as well '
    'as "one Pikachu" and "many Pikachu".\n'
    "In English, Pokémon may be pronounced either /'powkɛmon/ (poe-keh-mon) or "
    "/'powkɪmon/ (poe-key-mon).\n"
    "The Pokémon franchise is set in a world in which humans coexist with creatures "
    "known as Pokémon.\n"
    "Pokémon Red and Blue contain 151 Pokémon species, with new ones being introduced "
    "in subsequent games; as of December 2023, 1,025 Pokémon species have been "
    "introduced.\n[b] Most Pokémon are inspired by real-world animals;[12] for example, "
    "Pikachu are a yellow mouse-like species[13] with lightning bolt-shaped tails[14] "
    "that possess electrical abilities.[15]"
)


@pytest.mark.parametrize(
    ids=[
        "when_chunk_size_is_zero",
        "when_chunk_size_is_large",
        "when_chunk_size_is_small",
    ],
    argnames=["chunk_size", "expected_segmentation_len"],
    argvalues=[
        (0, 2),
        (1000, 1),
        (10, 2),
    ],
)
def test_split_text__on_chunk_size(
    chunk_size: int,
    expected_segmentation_len: int,
    mock_client_with_semantic_text_splitter: Mock,
) -> None:
    sts = AI21SemanticTextSplitter(
        chunk_size=chunk_size,
        client=mock_client_with_semantic_text_splitter,
    )
    segments = sts.split_text("This is a test")
    assert len(segments) == expected_segmentation_len


def test_split_text__on_large_chunk_size__should_merge_chunks(
    mock_client_with_semantic_text_splitter: Mock,
) -> None:
    sts_no_merge = AI21SemanticTextSplitter(
        client=mock_client_with_semantic_text_splitter
    )
    sts_merge = AI21SemanticTextSplitter(
        client=mock_client_with_semantic_text_splitter,
        chunk_size=1000,
    )
    segments_no_merge = sts_no_merge.split_text("This is a test")
    segments_merge = sts_merge.split_text("This is a test")
    assert len(segments_merge) > 0
    assert len(segments_no_merge) > 0
    assert len(segments_no_merge) > len(segments_merge)


def test_split_text__on_small_chunk_size__should_not_merge_chunks(
    mock_client_with_semantic_text_splitter: Mock,
) -> None:
    sts_no_merge = AI21SemanticTextSplitter(
        client=mock_client_with_semantic_text_splitter
    )
    segments = sts_no_merge.split_text("This is a test")
    assert len(segments) == 2
    for index in range(2):
        assert segments[index] == SEGMENTS[index].segment_text


def test_create_documents__on_start_index__should_add_start_index(
    mock_client_with_semantic_text_splitter: Mock,
) -> None:
    sts = AI21SemanticTextSplitter(
        client=mock_client_with_semantic_text_splitter,
        add_start_index=True,
    )

    response = sts.create_documents(texts=[TEXT])
    assert len(response) > 0
    for segment in response:
        assert segment.page_content is not None
        assert segment.metadata is not None
        assert "start_index" in segment.metadata
        assert segment.metadata["start_index"] > -1


def test_create_documents__when_metadata_from_user__should_add_metadata(
    mock_client_with_semantic_text_splitter: Mock,
) -> None:
    sts = AI21SemanticTextSplitter(client=mock_client_with_semantic_text_splitter)
    metadatas = [{"hello": "world"}]
    response = sts.create_documents(texts=[TEXT], metadatas=metadatas)
    assert len(response) > 0
    for index in range(len(response)):
        assert response[index].page_content == SEGMENTS[index].segment_text
        assert len(response[index].metadata) == 2
        assert response[index].metadata["source_type"] == SEGMENTS[index].segment_type
        assert response[index].metadata["hello"] == "world"


def test_split_text_to_documents__when_metadata_not_passed__should_contain_source_type(
    mock_client_with_semantic_text_splitter: Mock,
) -> None:
    sts = AI21SemanticTextSplitter(client=mock_client_with_semantic_text_splitter)
    response = sts.split_text_to_documents(TEXT)
    assert len(response) > 0
    for segment in response:
        assert segment.page_content is not None
        assert segment.metadata is not None
        assert "source_type" in segment.metadata
        assert segment.metadata["source_type"] is not None
@ -1,36 +0,0 @@
"""Standard LangChain interface tests"""

from typing import Type

from langchain_core.language_models import BaseChatModel
from langchain_standard_tests.unit_tests import (  # type: ignore[import-not-found]
    ChatModelUnitTests,  # type: ignore[import-not-found]
)

from langchain_ai21 import ChatAI21


class TestAI21J2(ChatModelUnitTests):
    @property
    def chat_model_class(self) -> Type[BaseChatModel]:
        return ChatAI21

    @property
    def chat_model_params(self) -> dict:
        return {
            "model": "j2-ultra",
            "api_key": "test_api_key",
        }


class TestAI21Jamba(ChatModelUnitTests):
    @property
    def chat_model_class(self) -> Type[BaseChatModel]:
        return ChatAI21

    @property
    def chat_model_params(self) -> dict:
        return {
            "model": "jamba-instruct",
            "api_key": "test_api_key",
        }
@ -1,29 +0,0 @@
from typing import List

import pytest

from langchain_ai21.embeddings import _split_texts_into_batches


@pytest.mark.parametrize(
    ids=[
        "when_chunk_size_is_2__should_return_3_chunks",
        "when_texts_is_empty__should_return_empty_list",
        "when_chunk_size_is_1__should_return_10_chunks",
    ],
    argnames=["input_texts", "chunk_size", "expected_output"],
    argvalues=[
        (["a", "b", "c", "d", "e"], 2, [["a", "b"], ["c", "d"], ["e"]]),
        ([], 3, []),
        (
            ["1", "2", "3", "4", "5", "6", "7", "8", "9", "10"],
            1,
            [["1"], ["2"], ["3"], ["4"], ["5"], ["6"], ["7"], ["8"], ["9"], ["10"]],
        ),
    ],
)
def test_chunked_text_generator(
    input_texts: List[str], chunk_size: int, expected_output: List[List[str]]
) -> None:
    result = list(_split_texts_into_batches(input_texts, chunk_size))
    assert result == expected_output
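For context, the parametrized cases above pin down the batching behaviour exactly: consecutive slices of at most chunk_size texts, with an empty input yielding no batches. A minimal sketch of a generator with that behaviour follows; the name and signature mirror the test's usage of _split_texts_into_batches, but this is an illustrative reconstruction under those assumptions, not the removed implementation:

from typing import Iterator, List


def _split_texts_into_batches(texts: List[str], batch_size: int) -> Iterator[List[str]]:
    """Yield consecutive slices of `texts`, each containing at most `batch_size` items."""
    for start in range(0, len(texts), batch_size):
        yield texts[start : start + batch_size]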
1 libs/partners/upstage/.gitignore vendored
@ -1 +0,0 @@
__pycache__
@ -1,3 +0,0 @@
This package has moved!

https://github.com/langchain-ai/langchain-upstage/tree/main/libs/upstage