feat(add): LLM integration of Cloudflare Workers AI (#14322)

Add [Text Generation by Cloudflare Workers
AI](https://developers.cloudflare.com/workers-ai/models/text-generation/).
It's a new LLM integration.

- Dependencies: N/A
This commit is contained in:
cxumol
2023-12-06 08:24:19 -08:00
committed by GitHub
parent 5efaedf488
commit 06e3316f54
3 changed files with 300 additions and 0 deletions

View File

@@ -0,0 +1,127 @@
import json
import logging
from typing import Any, Dict, Iterator, List, Optional
import requests
from langchain_core.outputs import GenerationChunk
from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.llms.base import LLM
logger = logging.getLogger(__name__)
class CloudflareWorkersAI(LLM):
"""Langchain LLM class to help to access Cloudflare Workers AI service.
To use, you must provide an API token and
account ID to access Cloudflare Workers AI, and
pass it as a named parameter to the constructor.
Example:
.. code-block:: python
from langchain.llms.cloudflare_workersai import CloudflareWorkersAI
my_account_id = "my_account_id"
my_api_token = "my_secret_api_token"
llm_model = "@cf/meta/llama-2-7b-chat-int8"
cf_ai = CloudflareWorkersAI(
account_id=my_account_id,
api_token=my_api_token,
model=llm_model
)
"""
account_id: str
api_token: str
model: str = "@cf/meta/llama-2-7b-chat-int8"
base_url: str = "https://api.cloudflare.com/client/v4/accounts"
streaming: bool = False
endpoint_url: str = ""
def __init__(self, **kwargs: Any) -> None:
"""Initialize the Cloudflare Workers AI class."""
super().__init__(**kwargs)
self.endpoint_url = f"{self.base_url}/{self.account_id}/ai/run/{self.model}"
@property
def _llm_type(self) -> str:
"""Return type of LLM."""
return "cloudflare"
@property
def _default_params(self) -> Dict[str, Any]:
"""Default parameters"""
return {}
@property
def _identifying_params(self) -> Dict[str, Any]:
"""Identifying parameters"""
return {
"account_id": self.account_id,
"api_token": self.api_token,
"model": self.model,
"base_url": self.base_url,
}
def _call_api(self, prompt: str, params: Dict[str, Any]) -> requests.Response:
"""Call Cloudflare Workers API"""
headers = {"Authorization": f"Bearer {self.api_token}"}
data = {"prompt": prompt, "stream": self.streaming, **params}
response = requests.post(self.endpoint_url, headers=headers, json=data)
return response
def _process_response(self, response: requests.Response) -> str:
"""Process API response"""
if response.ok:
data = response.json()
return data["result"]["response"]
else:
raise ValueError(f"Request failed with status {response.status_code}")
def _stream(
self,
prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> Iterator[GenerationChunk]:
"""Streaming prediction"""
original_steaming: bool = self.streaming
self.streaming = True
_response_prefix_count = len("data: ")
_response_stream_end = b"data: [DONE]"
for chunk in self._call_api(prompt, kwargs).iter_lines():
if chunk == _response_stream_end:
break
if len(chunk) > _response_prefix_count:
try:
data = json.loads(chunk[_response_prefix_count:])
except Exception as e:
logger.debug(chunk)
raise e
if data is not None and "response" in data:
yield GenerationChunk(text=data["response"])
if run_manager:
run_manager.on_llm_new_token(data["response"])
logger.debug("stream end")
self.streaming = original_steaming
def _call(
self,
prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> str:
"""Regular prediction"""
if self.streaming:
return "".join(
[c.text for c in self._stream(prompt, stop, run_manager, **kwargs)]
)
else:
response = self._call_api(prompt, kwargs)
return self._process_response(response)

View File

@@ -0,0 +1,46 @@
import responses
from langchain.llms.cloudflare_workersai import CloudflareWorkersAI
@responses.activate
def test_cloudflare_workersai_call() -> None:
responses.add(
responses.POST,
"https://api.cloudflare.com/client/v4/accounts/my_account_id/ai/run/@cf/meta/llama-2-7b-chat-int8",
json={"result": {"response": "4"}},
status=200,
)
llm = CloudflareWorkersAI(
account_id="my_account_id",
api_token="my_api_token",
model="@cf/meta/llama-2-7b-chat-int8",
)
output = llm("What is 2 + 2?")
assert output == "4"
@responses.activate
def test_cloudflare_workersai_stream() -> None:
response_body = ['data: {"response": "Hello"}', "data: [DONE]"]
responses.add(
responses.POST,
"https://api.cloudflare.com/client/v4/accounts/my_account_id/ai/run/@cf/meta/llama-2-7b-chat-int8",
body="\n".join(response_body),
status=200,
)
llm = CloudflareWorkersAI(
account_id="my_account_id",
api_token="my_api_token",
model="@cf/meta/llama-2-7b-chat-int8",
streaming=True,
)
outputs = []
for chunk in llm.stream("Say Hello"):
outputs.append(chunk)
assert "".join(outputs) == "Hello"