mirror of
https://github.com/hwchase17/langchain.git
synced 2025-08-13 22:59:05 +00:00
Thank you for contributing to LangChain! - [x] **PR title**: "package: description" - [x] **PR message**: ***Delete this entire checklist*** and replace with - **Description:** Added support for streaming in AI21 Jamba Model - **Twitter handle:** https://github.com/AI21Labs - [x] **Add tests and docs**: If you're adding a new integration, please include - [x] **Lint and test**: Run `make format`, `make lint` and `make test` from the root of the package(s) you've modified. See contribution guidelines for more: https://python.langchain.com/docs/contributing/ Additional guidelines: - Make sure optional dependencies are imported within a function. - Please do not add dependencies to pyproject.toml files (even optional ones) unless they are required for unit tests. - Most PRs should not touch more than one package. - Changes should be backwards compatible. - If you are adding something to community, do not re-import it in langchain. --------- Co-authored-by: Asaf Gardin <asafg@ai21.com> Co-authored-by: Erick Friis <erick@langchain.dev> Co-authored-by: Chester Curme <chester.curme@gmail.com>
251 lines
8.0 KiB
Python
251 lines
8.0 KiB
Python
import asyncio
|
|
from functools import partial
|
|
from typing import Any, Dict, Iterator, List, Mapping, Optional
|
|
|
|
from langchain_core.callbacks import (
|
|
AsyncCallbackManagerForLLMRun,
|
|
CallbackManagerForLLMRun,
|
|
)
|
|
from langchain_core.language_models.chat_models import (
|
|
BaseChatModel,
|
|
LangSmithParams,
|
|
generate_from_stream,
|
|
)
|
|
from langchain_core.messages import (
|
|
BaseMessage,
|
|
)
|
|
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
|
|
from langchain_core.pydantic_v1 import root_validator
|
|
|
|
from langchain_ai21.ai21_base import AI21Base
|
|
from langchain_ai21.chat.chat_adapter import ChatAdapter
|
|
from langchain_ai21.chat.chat_factory import create_chat_adapter
|
|
|
|
|
|
class ChatAI21(BaseChatModel, AI21Base):
    """ChatAI21 chat model.

    Different model types support different parameters and different parameter
    values. Please read the [AI21 reference documentation]
    (https://docs.ai21.com/reference) for your model to understand which
    parameters are available.

    Example:
        .. code-block:: python

            from langchain_ai21 import ChatAI21


            model = ChatAI21(
                # defaults to os.environ.get("AI21_API_KEY")
                api_key="my_api_key"
            )
    """

    model: str
    """Model type you wish to interact with.
    You can view the options at https://github.com/AI21Labs/ai21-python?tab=readme-ov-file#model-types"""
    num_results: int = 1
    """The number of responses to generate for a given prompt."""
    stop: Optional[List[str]] = None
    """Default stop sequences."""

    max_tokens: int = 16
    """The maximum number of tokens to generate for each response."""

    min_tokens: int = 0
    """The minimum number of tokens to generate for each response.
    _Not supported for all models._"""

    temperature: float = 0.7
    """A value controlling the "creativity" of the model's responses."""

    top_p: float = 1
    """A value controlling the diversity of the model's responses."""

    top_k_return: int = 0
    """The number of top-scoring tokens to consider for each generation step.
    _Not supported for all models._"""

    frequency_penalty: Optional[Any] = None
    """A penalty applied to tokens that are frequently generated.
    _Not supported for all models._"""

    presence_penalty: Optional[Any] = None
    """A penalty applied to tokens that are already present in the prompt.
    _Not supported for all models._"""

    count_penalty: Optional[Any] = None
    """A penalty applied to tokens based on their frequency
    in the generated responses. _Not supported for all models._"""

    n: int = 1
    """Number of chat completions to generate for each prompt."""
    streaming: bool = False
    """Whether ``_generate`` should stream by default when the per-call
    ``stream`` argument is not given."""

    # Adapter chosen from the model name in ``validate_environment``; it converts
    # LangChain messages into request parameters and performs the client call.
    _chat_adapter: ChatAdapter

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """Run the parent validation and attach the chat adapter for ``model``.

        Args:
            values: Field values collected by pydantic.

        Returns:
            The values returned by the parent validator, with ``_chat_adapter``
            set to the result of ``create_chat_adapter(model)``.
        """
        values = super().validate_environment(values)
        model = values.get("model")

        values["_chat_adapter"] = create_chat_adapter(model)  # type: ignore

        return values

    class Config:
        """Configuration for this pydantic object."""

        arbitrary_types_allowed = True

    @property
    def _llm_type(self) -> str:
        """Return type of chat model."""
        return "chat-ai21"

    @property
    def _default_params(self) -> Mapping[str, Any]:
        """Build the request parameters derived from the instance fields.

        ``stop_sequences`` is only included when ``self.stop`` is non-empty, and
        each penalty is only included when it is set; penalties are serialized
        through their ``to_dict()`` method.
        """
        base_params = {
            "model": self.model,
            "num_results": self.num_results,
            "max_tokens": self.max_tokens,
            "min_tokens": self.min_tokens,
            "temperature": self.temperature,
            "top_p": self.top_p,
            "top_k_return": self.top_k_return,
            "n": self.n,
        }
        if self.stop:
            base_params["stop_sequences"] = self.stop

        if self.count_penalty is not None:
            base_params["count_penalty"] = self.count_penalty.to_dict()

        if self.frequency_penalty is not None:
            base_params["frequency_penalty"] = self.frequency_penalty.to_dict()

        if self.presence_penalty is not None:
            base_params["presence_penalty"] = self.presence_penalty.to_dict()

        return base_params

    def _get_ls_params(
        self, stop: Optional[List[str]] = None, **kwargs: Any
    ) -> LangSmithParams:
        """Get standard params for tracing."""
        params = self._get_invocation_params(stop=stop, **kwargs)
        ls_params = LangSmithParams(
            ls_provider="ai21",
            ls_model_name=self.model,
            ls_model_type="chat",
            ls_temperature=params.get("temperature", self.temperature),
        )
        # Optional entries are only recorded when they carry a truthy value.
        if ls_max_tokens := params.get("max_tokens", self.max_tokens):
            ls_params["ls_max_tokens"] = ls_max_tokens
        if ls_stop := stop or params.get("stop", None) or self.stop:
            ls_params["ls_stop"] = ls_stop
        return ls_params

    def _build_params_for_request(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> Mapping[str, Any]:
        """Assemble the keyword arguments passed to the chat adapter call.

        Args:
            messages: Conversation to send; converted by the chat adapter.
            stop: Per-call stop sequences, sent as ``stop_sequences``.
            **kwargs: Extra request parameters (e.g. ``stream``).

        Returns:
            The merged mapping. Later sources win on key collisions: converted
            messages, then instance defaults, then the per-call stop sequences,
            then ``kwargs``.

        Raises:
            ValueError: If ``stop`` is given both as an argument and in ``kwargs``.
        """
        params = {}
        converted_messages = self._chat_adapter.convert_messages(messages)

        if stop is not None:
            if "stop" in kwargs:
                raise ValueError("stop is defined in both stop and kwargs")
            params["stop_sequences"] = stop

        return {
            **converted_messages,
            **self._default_params,
            **params,
            **kwargs,
        }

    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        stream: Optional[bool] = None,
        **kwargs: Any,
    ) -> ChatResult:
        """Generate a chat result, optionally by aggregating a stream.

        Args:
            messages: Conversation to send.
            stop: Per-call stop sequences.
            run_manager: Callback manager for this run.
            stream: Per-call streaming switch. ``None`` defers to
                ``self.streaming``; an explicit ``True``/``False`` takes
                precedence over the instance default.
            **kwargs: Extra request parameters.

        Returns:
            A ``ChatResult`` with one generation per message returned by the
            chat adapter, or the aggregated stream when streaming.
        """
        # An explicit per-call value must win over the instance default, so
        # ``stream=False`` disables streaming even when ``self.streaming`` is set.
        should_stream = stream if stream is not None else self.streaming

        if should_stream:
            return self._handle_stream_from_generate(
                messages=messages,
                stop=stop,
                run_manager=run_manager,
                **kwargs,
            )

        params = self._build_params_for_request(
            messages=messages,
            stop=stop,
            stream=False,
            **kwargs,
        )

        response_messages = self._chat_adapter.call(self.client, **params)
        generations = [
            ChatGeneration(message=message) for message in response_messages
        ]

        return ChatResult(generations=generations)

    def _handle_stream_from_generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        """Run ``_stream`` and fold its chunks into a single ``ChatResult``."""
        stream_iter = self._stream(
            messages=messages,
            stop=stop,
            run_manager=run_manager,
            **kwargs,
        )
        return generate_from_stream(stream_iter)

    def _stream(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> Iterator[ChatGenerationChunk]:
        """Yield generation chunks from a streaming chat adapter call.

        Each chunk whose message content is a string is also reported to
        ``run_manager.on_llm_new_token`` before being yielded.
        """
        params = self._build_params_for_request(
            messages=messages,
            stop=stop,
            stream=True,
            **kwargs,
        )

        for chunk in self._chat_adapter.call(self.client, **params):
            if run_manager and isinstance(chunk.message.content, str):
                run_manager.on_llm_new_token(token=chunk.message.content, chunk=chunk)
            yield chunk

    async def _agenerate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        """Run the synchronous ``_generate`` in the default executor.

        The async callback manager is converted to its synchronous counterpart
        first: ``_generate`` (and ``_stream`` behind it) invokes callbacks
        without awaiting them, so handing it the async manager would create
        coroutines that are never awaited and drop the token callbacks.
        """
        sync_run_manager = run_manager.get_sync() if run_manager else None
        return await asyncio.get_running_loop().run_in_executor(
            None,
            partial(self._generate, **kwargs),
            messages,
            stop,
            sync_run_manager,
        )

    def _get_system_message_from_message(self, message: BaseMessage) -> str:
        """Return the message content, requiring it to be a plain string.

        Raises:
            ValueError: If ``message.content`` is not a ``str``.
        """
        if not isinstance(message.content, str):
            raise ValueError(
                f"System Message must be of type str. Got {type(message.content)}"
            )

        return message.content
|