mirror of
https://github.com/hwchase17/langchain.git
synced 2025-09-22 19:09:57 +00:00
feat(add): LLM integration of Cloudflare Workers AI (#14322)
Add [Text Generation by Cloudflare Workers AI](https://developers.cloudflare.com/workers-ai/models/text-generation/). It's a new LLM integration. - Dependencies: N/A
This commit is contained in:
127
libs/langchain/langchain/llms/cloudflare_workersai.py
Normal file
127
libs/langchain/langchain/llms/cloudflare_workersai.py
Normal file
@@ -0,0 +1,127 @@
|
||||
import json
|
||||
import logging
|
||||
from typing import Any, Dict, Iterator, List, Optional
|
||||
|
||||
import requests
|
||||
from langchain_core.outputs import GenerationChunk
|
||||
|
||||
from langchain.callbacks.manager import CallbackManagerForLLMRun
|
||||
from langchain.llms.base import LLM
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class CloudflareWorkersAI(LLM):
|
||||
"""Langchain LLM class to help to access Cloudflare Workers AI service.
|
||||
|
||||
To use, you must provide an API token and
|
||||
account ID to access Cloudflare Workers AI, and
|
||||
pass it as a named parameter to the constructor.
|
||||
|
||||
Example:
|
||||
.. code-block:: python
|
||||
|
||||
from langchain.llms.cloudflare_workersai import CloudflareWorkersAI
|
||||
|
||||
my_account_id = "my_account_id"
|
||||
my_api_token = "my_secret_api_token"
|
||||
llm_model = "@cf/meta/llama-2-7b-chat-int8"
|
||||
|
||||
cf_ai = CloudflareWorkersAI(
|
||||
account_id=my_account_id,
|
||||
api_token=my_api_token,
|
||||
model=llm_model
|
||||
)
|
||||
"""
|
||||
|
||||
account_id: str
|
||||
api_token: str
|
||||
model: str = "@cf/meta/llama-2-7b-chat-int8"
|
||||
base_url: str = "https://api.cloudflare.com/client/v4/accounts"
|
||||
streaming: bool = False
|
||||
endpoint_url: str = ""
|
||||
|
||||
def __init__(self, **kwargs: Any) -> None:
|
||||
"""Initialize the Cloudflare Workers AI class."""
|
||||
super().__init__(**kwargs)
|
||||
|
||||
self.endpoint_url = f"{self.base_url}/{self.account_id}/ai/run/{self.model}"
|
||||
|
||||
@property
|
||||
def _llm_type(self) -> str:
|
||||
"""Return type of LLM."""
|
||||
return "cloudflare"
|
||||
|
||||
@property
|
||||
def _default_params(self) -> Dict[str, Any]:
|
||||
"""Default parameters"""
|
||||
return {}
|
||||
|
||||
@property
|
||||
def _identifying_params(self) -> Dict[str, Any]:
|
||||
"""Identifying parameters"""
|
||||
return {
|
||||
"account_id": self.account_id,
|
||||
"api_token": self.api_token,
|
||||
"model": self.model,
|
||||
"base_url": self.base_url,
|
||||
}
|
||||
|
||||
def _call_api(self, prompt: str, params: Dict[str, Any]) -> requests.Response:
|
||||
"""Call Cloudflare Workers API"""
|
||||
headers = {"Authorization": f"Bearer {self.api_token}"}
|
||||
data = {"prompt": prompt, "stream": self.streaming, **params}
|
||||
response = requests.post(self.endpoint_url, headers=headers, json=data)
|
||||
return response
|
||||
|
||||
def _process_response(self, response: requests.Response) -> str:
|
||||
"""Process API response"""
|
||||
if response.ok:
|
||||
data = response.json()
|
||||
return data["result"]["response"]
|
||||
else:
|
||||
raise ValueError(f"Request failed with status {response.status_code}")
|
||||
|
||||
def _stream(
|
||||
self,
|
||||
prompt: str,
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> Iterator[GenerationChunk]:
|
||||
"""Streaming prediction"""
|
||||
original_steaming: bool = self.streaming
|
||||
self.streaming = True
|
||||
_response_prefix_count = len("data: ")
|
||||
_response_stream_end = b"data: [DONE]"
|
||||
for chunk in self._call_api(prompt, kwargs).iter_lines():
|
||||
if chunk == _response_stream_end:
|
||||
break
|
||||
if len(chunk) > _response_prefix_count:
|
||||
try:
|
||||
data = json.loads(chunk[_response_prefix_count:])
|
||||
except Exception as e:
|
||||
logger.debug(chunk)
|
||||
raise e
|
||||
if data is not None and "response" in data:
|
||||
yield GenerationChunk(text=data["response"])
|
||||
if run_manager:
|
||||
run_manager.on_llm_new_token(data["response"])
|
||||
logger.debug("stream end")
|
||||
self.streaming = original_steaming
|
||||
|
||||
def _call(
|
||||
self,
|
||||
prompt: str,
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> str:
|
||||
"""Regular prediction"""
|
||||
if self.streaming:
|
||||
return "".join(
|
||||
[c.text for c in self._stream(prompt, stop, run_manager, **kwargs)]
|
||||
)
|
||||
else:
|
||||
response = self._call_api(prompt, kwargs)
|
||||
return self._process_response(response)
|
@@ -0,0 +1,46 @@
|
||||
import responses
|
||||
|
||||
from langchain.llms.cloudflare_workersai import CloudflareWorkersAI
|
||||
|
||||
|
||||
@responses.activate
|
||||
def test_cloudflare_workersai_call() -> None:
|
||||
responses.add(
|
||||
responses.POST,
|
||||
"https://api.cloudflare.com/client/v4/accounts/my_account_id/ai/run/@cf/meta/llama-2-7b-chat-int8",
|
||||
json={"result": {"response": "4"}},
|
||||
status=200,
|
||||
)
|
||||
|
||||
llm = CloudflareWorkersAI(
|
||||
account_id="my_account_id",
|
||||
api_token="my_api_token",
|
||||
model="@cf/meta/llama-2-7b-chat-int8",
|
||||
)
|
||||
output = llm("What is 2 + 2?")
|
||||
|
||||
assert output == "4"
|
||||
|
||||
|
||||
@responses.activate
|
||||
def test_cloudflare_workersai_stream() -> None:
|
||||
response_body = ['data: {"response": "Hello"}', "data: [DONE]"]
|
||||
responses.add(
|
||||
responses.POST,
|
||||
"https://api.cloudflare.com/client/v4/accounts/my_account_id/ai/run/@cf/meta/llama-2-7b-chat-int8",
|
||||
body="\n".join(response_body),
|
||||
status=200,
|
||||
)
|
||||
|
||||
llm = CloudflareWorkersAI(
|
||||
account_id="my_account_id",
|
||||
api_token="my_api_token",
|
||||
model="@cf/meta/llama-2-7b-chat-int8",
|
||||
streaming=True,
|
||||
)
|
||||
|
||||
outputs = []
|
||||
for chunk in llm.stream("Say Hello"):
|
||||
outputs.append(chunk)
|
||||
|
||||
assert "".join(outputs) == "Hello"
|
Reference in New Issue
Block a user