langchain/libs/community/langchain_community/llms/titan_takeoff.py
pjb157 479be3cc91
community[minor]: Unify Titan Takeoff Integrations and Adding Embedding Support (#18775)
**Community: Unify Titan Takeoff Integrations and Adding Embedding
Support**

 **Description:** 
Titan Takeoff has moved on, and neither of the integrations in the
community folder reflects it any longer. The two integrations (TitanTakeoffPro and
TitanTakeoff) were causing confusion for clients, so the code has been moved
into one place, with an alias kept for backwards compatibility. The Takeoff
Client Python package has been added to do the bulk of the work with the
requests; because this package is actively updated with new versions of
Takeoff, the integration will be far more robust and will not degrade as badly
over time.
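
As a rough illustration (a sketch based on this description, not taken from the diff), the old import path is expected to keep working via the alias:

```python
# Both names should import from the community package; TitanTakeoffPro is
# retained purely as a backwards-compatible alias for the unified class.
from langchain_community.llms import TitanTakeoff, TitanTakeoffPro
```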

**Issue:**
Fixes bugs in the old Titan integrations and unifies the code, with added
unit test coverage to avoid future problems.

**Dependencies:**
Added the optional dependency takeoff-client. All imports still work without
the dependency, including the Titan Takeoff classes, but initialisation will
fail unless takeoff-client has been pip installed.
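
A rough sketch of the behaviour described above (illustrative, not taken from the diff):

```python
# Importing the class works even when takeoff-client is not installed...
from langchain_community.llms import TitanTakeoff

try:
    llm = TitanTakeoff()
except ImportError as err:
    # ...but initialisation fails until the SDK is installed, e.g. via
    # pip install 'takeoff-client>=0.4.0'.
    print(err)
```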

**Twitter**
@MeryemArik9

Thanks all :)

---------

Co-authored-by: Bagatur <baskaryan@gmail.com>
Co-authored-by: Bagatur <22008038+baskaryan@users.noreply.github.com>
2024-04-17 01:43:35 +00:00


from enum import Enum
from typing import Any, Iterator, List, Optional

from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.language_models.llms import LLM
from langchain_core.outputs import GenerationChunk
from langchain_core.pydantic_v1 import BaseModel

from langchain_community.llms.utils import enforce_stop_tokens


class Device(str, Enum):
    """The device to use for inference, cuda or cpu"""

    cuda = "cuda"
    cpu = "cpu"


class ReaderConfig(BaseModel):
    class Config:
        protected_namespaces = ()

    model_name: str
    """The name of the model to use"""

    device: Device = Device.cuda
    """The device to use for inference, cuda or cpu"""

    consumer_group: str = "primary"
    """The consumer group to place the reader into"""

    tensor_parallel: Optional[int] = None
    """The number of gpus you would like your model to be split across"""

    max_seq_length: int = 512
    """The maximum sequence length to use for inference, defaults to 512"""

    max_batch_size: int = 4
    """The max batch size for continuous batching of requests"""
class TitanTakeoff(LLM):
    """Titan Takeoff API LLMs.

    Titan Takeoff is a wrapper to interface with Takeoff Inference API for
    generative text to text language models.

    You can use this wrapper to send requests to a generative language model
    and to deploy readers with Takeoff.

    Examples:
        This is an example how to deploy a generative language model and send
        requests.

        .. code-block:: python

            # Import the TitanTakeoff class from community package
            import time
            from langchain_community.llms import TitanTakeoff

            # Specify the generative reader you'd like to deploy
            reader_1 = {
                "model_name": "TheBloke/Llama-2-7b-Chat-AWQ",
                "device": "cuda",
                "tensor_parallel": 1,
                "consumer_group": "llama"
            }

            # For every reader you pass into models arg Takeoff will spin
            # up a reader according to the specs you provide. If you don't
            # specify the arg no models are spun up and it assumes you have
            # already done this separately.
            llm = TitanTakeoff(models=[reader_1])

            # Wait for the reader to be deployed, time needed depends on the
            # model size and your internet speed
            time.sleep(60)

            # Returns the generated text for the query sent to the `llama`
            # consumer group where we just spun up the Llama 7B model
            print(llm.invoke(
                "Where can I see football?", consumer_group="llama"
            ))

            # You can also send generation parameters to the model, any of the
            # following can be passed in as kwargs:
            # https://docs.titanml.co/docs/next/apis/Takeoff%20inference_REST_API/generate#request
            # for instance:
            print(llm.invoke(
                "Where can I see football?", consumer_group="llama",
                max_new_tokens=100
            ))
    """

    base_url: str = "http://localhost"
    """The base URL of the Titan Takeoff (Pro) server. Default = "http://localhost"."""

    port: int = 3000
    """The port of the Titan Takeoff (Pro) server. Default = 3000."""

    mgmt_port: int = 3001
    """The management port of the Titan Takeoff (Pro) server. Default = 3001."""

    streaming: bool = False
    """Whether to stream the output. Default = False."""

    client: Any = None
    """Takeoff Client Python SDK used to interact with Takeoff API"""

    def __init__(
        self,
        base_url: str = "http://localhost",
        port: int = 3000,
        mgmt_port: int = 3001,
        streaming: bool = False,
        models: List[ReaderConfig] = [],
    ):
        """Initialize the Titan Takeoff language wrapper.

        Args:
            base_url (str, optional): The base URL where the Takeoff
                Inference Server is listening. Defaults to `http://localhost`.
            port (int, optional): What port is Takeoff Inference API
                listening on. Defaults to 3000.
            mgmt_port (int, optional): What port is Takeoff Management API
                listening on. Defaults to 3001.
            streaming (bool, optional): Whether you want to by default use the
                generate_stream endpoint over generate to stream responses.
                Defaults to False. In reality, this is not significantly different
                as the streamed response is buffered and returned similar to the
                non-streamed response, but the run manager is applied per token
                generated.
            models (List[ReaderConfig], optional): Any readers you'd like to
                spin up on initialisation. Defaults to [].

        Raises:
            ImportError: If you haven't installed takeoff-client, you will
                get an ImportError. To remedy run
                `pip install 'takeoff-client==0.4.0'`.
        """
        super().__init__(
            base_url=base_url, port=port, mgmt_port=mgmt_port, streaming=streaming
        )
        try:
            from takeoff_client import TakeoffClient
        except ImportError:
            raise ImportError(
                "takeoff-client is required for TitanTakeoff. "
                "Please install it with `pip install 'takeoff-client>=0.4.0'`."
            )
        self.client = TakeoffClient(
            self.base_url, port=self.port, mgmt_port=self.mgmt_port
        )
        for model in models:
            self.client.create_reader(model)

    @property
    def _llm_type(self) -> str:
        """Return type of llm."""
        return "titan_takeoff"

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        """Call out to Titan Takeoff (Pro) generate endpoint.

        Args:
            prompt: The prompt to pass into the model.
            stop: Optional list of stop words to use when generating.
            run_manager: Optional callback manager to use when streaming.

        Returns:
            The string generated by the model.

        Example:
            .. code-block:: python

                model = TitanTakeoff()

                prompt = "What is the capital of the United Kingdom?"

                # Use of model(prompt), ie `__call__` was deprecated in
                # LangChain 0.1.7, use model.invoke(prompt) instead.
                response = model.invoke(prompt)
        """
        if self.streaming:
            text_output = ""
            for chunk in self._stream(
                prompt=prompt,
                stop=stop,
                run_manager=run_manager,
            ):
                text_output += chunk.text
            return text_output

        response = self.client.generate(prompt, **kwargs)
        text = response["text"]

        if stop is not None:
            text = enforce_stop_tokens(text, stop)
        return text

    def _stream(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> Iterator[GenerationChunk]:
        """Call out to Titan Takeoff (Pro) stream endpoint.

        Args:
            prompt: The prompt to pass into the model.
            stop: Optional list of stop words to use when generating.
            run_manager: Optional callback manager to use when streaming.

        Yields:
            A dictionary like object containing a string token.

        Example:
            .. code-block:: python

                model = TitanTakeoff()

                prompt = "What is the capital of the United Kingdom?"
                response = model.stream(prompt)

                # OR

                model = TitanTakeoff(streaming=True)
                response = model.invoke(prompt)
        """
        response = self.client.generate_stream(prompt, **kwargs)
        buffer = ""
        for text in response:
            buffer += text.data
            if "data:" in buffer:
                # Remove the first instance of "data:" from the buffer.
                if buffer.startswith("data:"):
                    buffer = ""
                if len(buffer.split("data:", 1)) == 2:
                    content, _ = buffer.split("data:", 1)
                    buffer = content.rstrip("\n")
                # Trim the buffer to only have content after the "data:" part.
                if buffer:  # Ensure that there's content to process.
                    chunk = GenerationChunk(text=buffer)
                    buffer = ""  # Reset buffer for the next set of data.
                    yield chunk
                    if run_manager:
                        run_manager.on_llm_new_token(token=chunk.text)
        # Yield any remaining content in the buffer.
        if buffer:
            chunk = GenerationChunk(text=buffer.replace("</s>", ""))
            yield chunk
            if run_manager:
                run_manager.on_llm_new_token(token=chunk.text)
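

if __name__ == "__main__":
    # Illustrative usage sketch (editor's addition, not part of the original
    # module). Assumes a Takeoff server is running locally on the default
    # ports and that the takeoff-client package is installed. The model name
    # and consumer group mirror the placeholders from the class docstring.
    import time

    reader = {
        "model_name": "TheBloke/Llama-2-7b-Chat-AWQ",
        "device": "cuda",
        "tensor_parallel": 1,
        "consumer_group": "llama",
    }

    # Deploy the reader and give it time to load, as in the class docstring.
    llm = TitanTakeoff(models=[reader])
    time.sleep(60)

    # Plain generation, routed to the `llama` consumer group.
    print(llm.invoke("Where can I see football?", consumer_group="llama"))

    # Streaming generation, printed token by token.
    for token in llm.stream("Where can I see football?", consumer_group="llama"):
        print(token, end="", flush=True)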