mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-29 18:08:36 +00:00
Harrison/ibm (#14133)
Co-authored-by: Mateusz Szewczyk <139469471+MateuszOssGit@users.noreply.github.com>
This commit is contained in:
parent
943aa01c14
commit
ae646701c4
297
docs/docs/integrations/llms/watsonxllm.ipynb
Normal file
297
docs/docs/integrations/llms/watsonxllm.ipynb
Normal file
@ -0,0 +1,297 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "70996d8a",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# WatsonxLLM\n",
|
||||
"\n",
|
||||
"[WatsonxLLM](https://ibm.github.io/watson-machine-learning-sdk/fm_extensions.html) is a wrapper for IBM [watsonx.ai](https://www.ibm.com/products/watsonx-ai) foundation models.\n",
|
||||
"This example shows how to communicate with watsonx.ai models using LangChain."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "ea35b2b7",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Install the package [`ibm_watson_machine_learning`](https://ibm.github.io/watson-machine-learning-sdk/install.html)."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "2f1fff4e",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"%pip install ibm_watson_machine_learning"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "f406e092",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"This cell defines the WML credentials required to work with watsonx Foundation Model inferencing.\n",
|
||||
"\n",
|
||||
"**Action:** Provide the IBM Cloud user API key. For details, see\n",
|
||||
"[documentation](https://cloud.ibm.com/docs/account?topic=account-userapikey&interface=ui)."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"id": "11d572a1",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import os\n",
|
||||
"from getpass import getpass\n",
|
||||
"\n",
|
||||
"watsonx_api_key = getpass()\n",
|
||||
"os.environ[\"WATSONX_APIKEY\"] = watsonx_api_key"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "e36acbef",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Load the model\n",
|
||||
"You might need to adjust model `parameters` for different models or tasks. To do so, please refer to the [documentation](https://ibm.github.io/watson-machine-learning-sdk/model.html#metanames.GenTextParamsMetaNames)."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"id": "407cd500",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from ibm_watson_machine_learning.metanames import GenTextParamsMetaNames as GenParams\n",
|
||||
"\n",
|
||||
"parameters = {\n",
|
||||
" GenParams.DECODING_METHOD: \"sample\",\n",
|
||||
" GenParams.MAX_NEW_TOKENS: 100,\n",
|
||||
" GenParams.MIN_NEW_TOKENS: 1,\n",
|
||||
" GenParams.TEMPERATURE: 0.5,\n",
|
||||
" GenParams.TOP_K: 50,\n",
|
||||
" GenParams.TOP_P: 1,\n",
|
||||
"}"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "2b586538",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Initialize the `WatsonxLLM` class with the previously set parameters."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 25,
|
||||
"id": "359898de",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from langchain.llms import WatsonxLLM\n",
|
||||
"\n",
|
||||
"watsonx_llm = WatsonxLLM(\n",
|
||||
" model_id=\"google/flan-ul2\",\n",
|
||||
" url=\"https://us-south.ml.cloud.ibm.com\",\n",
|
||||
" project_id=\"***\",\n",
|
||||
" params=parameters,\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "2202f4e0",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Alternatively you can use Cloud Pak for Data credentials. For details, see [documentation](https://ibm.github.io/watson-machine-learning-sdk/setup_cpd.html).\n",
|
||||
"```\n",
|
||||
"watsonx_llm = WatsonxLLM(\n",
|
||||
" model_id='google/flan-ul2',\n",
|
||||
" url=\"***\",\n",
|
||||
" username=\"***\",\n",
|
||||
" password=\"***\",\n",
|
||||
" instance_id=\"openshift\",\n",
|
||||
" version=\"4.8\",\n",
|
||||
" project_id='***',\n",
|
||||
" params=parameters\n",
|
||||
")\n",
|
||||
"``` "
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "c25ecbd1",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Create Chain\n",
|
||||
"Create `PromptTemplate` objects which will be responsible for creating a random question."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"id": "c7d80c05",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from langchain.prompts import PromptTemplate\n",
|
||||
"\n",
|
||||
"template = \"Generate a random question about {topic}: Question: \"\n",
|
||||
"prompt = PromptTemplate.from_template(template)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "79056d8e",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Provide a topic and run the `LLMChain`."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 5,
|
||||
"id": "dc076c56",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"'How many breeds of dog are there?'"
|
||||
]
|
||||
},
|
||||
"execution_count": 5,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"from langchain.chains import LLMChain\n",
|
||||
"\n",
|
||||
"llm_chain = LLMChain(prompt=prompt, llm=watsonx_llm)\n",
|
||||
"llm_chain.run(\"dog\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "f571001d",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Calling the Model Directly\n",
|
||||
"To obtain completions, you can call the model directly using a string prompt."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 6,
|
||||
"id": "beea2b5b",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"'dog'"
|
||||
]
|
||||
},
|
||||
"execution_count": 6,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"# Calling a single prompt\n",
|
||||
"\n",
|
||||
"watsonx_llm(\"Who is man's best friend?\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 9,
|
||||
"id": "8ab1a25a",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"LLMResult(generations=[[Generation(text='greyhounds', generation_info={'generated_token_count': 4, 'input_token_count': 8, 'finish_reason': 'eos_token'})], [Generation(text='The Basenji is a dog breed from South Africa.', generation_info={'generated_token_count': 13, 'input_token_count': 7, 'finish_reason': 'eos_token'})]], llm_output={'model_id': 'google/flan-ul2'}, run=[RunInfo(run_id=UUID('03c73a42-db68-428e-ab8d-8ae10abc84fc')), RunInfo(run_id=UUID('c289f67a-87d6-4c8b-a8b7-0b5012c94ca8'))])"
|
||||
]
|
||||
},
|
||||
"execution_count": 9,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"# Calling multiple prompts\n",
|
||||
"\n",
|
||||
"watsonx_llm.generate(\n",
|
||||
" [\n",
|
||||
" \"The fastest dog in the world?\",\n",
|
||||
" \"Describe your chosen dog breed\",\n",
|
||||
" ]\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "d2c9da33",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Streaming the Model output \n",
|
||||
"\n",
|
||||
"You can stream the model output."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 45,
|
||||
"id": "3f63166a",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"The golden retriever is my favorite dog because it is very friendly and good with children."
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"for chunk in watsonx_llm.stream(\n",
|
||||
" \"Describe your favorite breed of dog and why it is your favorite.\"\n",
|
||||
"):\n",
|
||||
" print(chunk, end=\"\")"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.9.18"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
@ -504,6 +504,12 @@ def _import_vllm_openai() -> Any:
|
||||
return VLLMOpenAI
|
||||
|
||||
|
||||
def _import_watsonxllm() -> Any:
    """Import ``WatsonxLLM`` on demand and return the class object."""
    # The import lives inside the function so the watsonxllm module is only
    # loaded when this helper is actually called.
    from langchain.llms.watsonxllm import WatsonxLLM as watsonx_llm_cls

    return watsonx_llm_cls
|
||||
|
||||
|
||||
def _import_writer() -> Any:
|
||||
from langchain.llms.writer import Writer
|
||||
|
||||
@ -685,6 +691,8 @@ def __getattr__(name: str) -> Any:
|
||||
return _import_vllm()
|
||||
elif name == "VLLMOpenAI":
|
||||
return _import_vllm_openai()
|
||||
elif name == "WatsonxLLM":
|
||||
return _import_watsonxllm()
|
||||
elif name == "Writer":
|
||||
return _import_writer()
|
||||
elif name == "Xinference":
|
||||
@ -777,6 +785,7 @@ __all__ = [
|
||||
"VertexAIModelGarden",
|
||||
"VLLM",
|
||||
"VLLMOpenAI",
|
||||
"WatsonxLLM",
|
||||
"Writer",
|
||||
"OctoAIEndpoint",
|
||||
"Xinference",
|
||||
@ -861,6 +870,7 @@ def get_type_to_cls_dict() -> Dict[str, Callable[[], Type[BaseLLM]]]:
|
||||
"openllm_client": _import_openllm,
|
||||
"vllm": _import_vllm,
|
||||
"vllm_openai": _import_vllm_openai,
|
||||
"watsonxllm": _import_watsonxllm,
|
||||
"writer": _import_writer,
|
||||
"xinference": _import_xinference,
|
||||
"javelin-ai-gateway": _import_javelin_ai_gateway,
|
||||
|
354
libs/langchain/langchain/llms/watsonxllm.py
Normal file
354
libs/langchain/langchain/llms/watsonxllm.py
Normal file
@ -0,0 +1,354 @@
|
||||
import logging
|
||||
import os
|
||||
from typing import Any, Dict, Iterator, List, Mapping, Optional, Union
|
||||
|
||||
from langchain.callbacks.manager import CallbackManagerForLLMRun
|
||||
from langchain.llms.base import BaseLLM
|
||||
from langchain.pydantic_v1 import Extra, SecretStr, root_validator
|
||||
from langchain.schema import LLMResult
|
||||
from langchain.schema.output import Generation, GenerationChunk
|
||||
from langchain.utils import convert_to_secret_str, get_from_dict_or_env
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class WatsonxLLM(BaseLLM):
    """
    IBM watsonx.ai large language models.

    To use, you should have ``ibm_watson_machine_learning`` python package installed,
    and the environment variable ``WATSONX_APIKEY`` set with your API key, or pass
    it as a named parameter to the constructor.


    Example:
        .. code-block:: python

            from ibm_watson_machine_learning.metanames import GenTextParamsMetaNames
            parameters = {
                GenTextParamsMetaNames.DECODING_METHOD: "sample",
                GenTextParamsMetaNames.MAX_NEW_TOKENS: 100,
                GenTextParamsMetaNames.MIN_NEW_TOKENS: 1,
                GenTextParamsMetaNames.TEMPERATURE: 0.5,
                GenTextParamsMetaNames.TOP_K: 50,
                GenTextParamsMetaNames.TOP_P: 1,
            }

            from langchain.llms import WatsonxLLM
            llm = WatsonxLLM(
                model_id="google/flan-ul2",
                url="https://us-south.ml.cloud.ibm.com",
                apikey="*****",
                project_id="*****",
                params=parameters,
            )
    """

    model_id: str = ""
    """Type of model to use."""

    project_id: str = ""
    """ID of the Watson Studio project."""

    space_id: str = ""
    """ID of the Watson Studio space."""

    url: Optional[SecretStr] = None
    """Url to Watson Machine Learning instance"""

    apikey: Optional[SecretStr] = None
    """Apikey to Watson Machine Learning instance"""

    token: Optional[SecretStr] = None
    """Token to Watson Machine Learning instance"""

    password: Optional[SecretStr] = None
    """Password to Watson Machine Learning instance"""

    username: Optional[SecretStr] = None
    """Username to Watson Machine Learning instance"""

    instance_id: Optional[SecretStr] = None
    """Instance_id of Watson Machine Learning instance"""

    version: Optional[SecretStr] = None
    """Version of Watson Machine Learning instance"""

    params: Optional[dict] = None
    """Model parameters to use during generate requests."""

    verify: Union[str, bool] = ""
    """User can pass as verify one of following:
        the path to a CA_BUNDLE file
        the path of directory with certificates of trusted CAs
        True - default path to truststore will be taken
        False - no verification will be made"""

    streaming: bool = False
    """ Whether to stream the results or not. """

    # Populated by ``validate_environment`` with the SDK ``Model`` instance.
    watsonx_model: Any

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid

    @classmethod
    def is_lc_serializable(cls) -> bool:
        """Report that this class can be serialized by LangChain."""
        return True

    @property
    def lc_secrets(self) -> Dict[str, str]:
        """Map secret constructor arguments to their environment variable names."""
        return {
            "url": "WATSONX_URL",
            "apikey": "WATSONX_APIKEY",
            "token": "WATSONX_TOKEN",
            "password": "WATSONX_PASSWORD",
            "username": "WATSONX_USERNAME",
            "instance_id": "WATSONX_INSTANCE_ID",
        }

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate that credentials and python package exists in environment.

        Each credential is taken from ``values`` first and from its
        ``WATSONX_*`` environment variable otherwise. URLs containing
        ``cloud.ibm.com`` require an API key; any other URL requires one of
        token, password + username, or apikey + username, plus an instance id.

        Args:
            values: Field values collected by pydantic for this model.

        Returns:
            ``values`` with credentials normalized to ``SecretStr`` and
            ``watsonx_model`` set to the constructed SDK ``Model``.

        Raises:
            ValueError: If a required credential is found neither in
                ``values`` nor in the environment.
            ImportError: If ``ibm_watson_machine_learning`` is not installed.
        """
        values["url"] = convert_to_secret_str(
            get_from_dict_or_env(values, "url", "WATSONX_URL")
        )
        if "cloud.ibm.com" in values["url"].get_secret_value():
            values["apikey"] = convert_to_secret_str(
                get_from_dict_or_env(values, "apikey", "WATSONX_APIKEY")
            )
        else:
            # First available credential wins, in the order
            # token -> password -> apikey.
            if values["token"] or "WATSONX_TOKEN" in os.environ:
                values["token"] = convert_to_secret_str(
                    get_from_dict_or_env(values, "token", "WATSONX_TOKEN")
                )
            elif values["password"] or "WATSONX_PASSWORD" in os.environ:
                values["password"] = convert_to_secret_str(
                    get_from_dict_or_env(values, "password", "WATSONX_PASSWORD")
                )
                values["username"] = convert_to_secret_str(
                    get_from_dict_or_env(values, "username", "WATSONX_USERNAME")
                )
            elif values["apikey"] or "WATSONX_APIKEY" in os.environ:
                values["apikey"] = convert_to_secret_str(
                    get_from_dict_or_env(values, "apikey", "WATSONX_APIKEY")
                )
                values["username"] = convert_to_secret_str(
                    get_from_dict_or_env(values, "username", "WATSONX_USERNAME")
                )
            else:
                raise ValueError(
                    "Did not find 'token', 'password' or 'apikey',"
                    " please add an environment variable"
                    " `WATSONX_TOKEN`, 'WATSONX_PASSWORD' or 'WATSONX_APIKEY' "
                    "which contains it,"
                    " or pass 'token', 'password' or 'apikey'"
                    " as a named parameter."
                )
            # An explicitly passed instance_id always takes precedence, so the
            # environment only needs to be consulted when none was given.
            if not values["instance_id"]:
                values["instance_id"] = convert_to_secret_str(
                    get_from_dict_or_env(values, "instance_id", "WATSONX_INSTANCE_ID")
                )

        # Only the import itself is guarded, so that an ImportError raised
        # while building the model is not misreported as a missing package.
        try:
            from ibm_watson_machine_learning.foundation_models import Model
        except ImportError:
            raise ImportError(
                "Could not import ibm_watson_machine_learning python package. "
                "Please install it with `pip install ibm_watson_machine_learning`."
            )

        credential_keys = (
            "url",
            "apikey",
            "token",
            "password",
            "username",
            "instance_id",
            "version",
        )
        # Unset (or empty) credentials are left out entirely.
        credentials = {
            key: values[key].get_secret_value()
            for key in credential_keys
            if values[key]
        }

        values["watsonx_model"] = Model(
            model_id=values["model_id"],
            credentials=credentials,
            params=values["params"],
            project_id=values["project_id"],
            space_id=values["space_id"],
            verify=values["verify"],
        )
        return values

    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        """Get the identifying parameters."""
        return {
            "model_id": self.model_id,
            "params": self.params,
            "project_id": self.project_id,
            "space_id": self.space_id,
        }

    @property
    def _llm_type(self) -> str:
        """Return type of llm."""
        return "IBM watsonx.ai"

    @staticmethod
    def _extract_token_usage(
        response: Optional[List[Dict[str, Any]]] = None
    ) -> Dict[str, Any]:
        """Sum input and generated token counts over all responses.

        Only the first entry of each response's ``results`` list is counted;
        responses without ``results`` and missing or ``None`` counts add 0.

        Args:
            response: List of raw response dicts, or ``None``.

        Returns:
            Dict with ``generated_token_count`` and ``input_token_count``.
        """
        if response is None:
            return {"generated_token_count": 0, "input_token_count": 0}

        input_token_count = 0
        generated_token_count = 0

        def get_count_value(key: str, result: Dict[str, Any]) -> int:
            # ``or 0`` also covers keys that are present but set to None.
            return result.get(key, 0) or 0

        for res in response:
            results = res.get("results")
            if results:
                input_token_count += get_count_value("input_token_count", results[0])
                generated_token_count += get_count_value(
                    "generated_token_count", results[0]
                )

        return {
            "generated_token_count": generated_token_count,
            "input_token_count": input_token_count,
        }

    def _create_llm_result(self, response: List[dict]) -> LLMResult:
        """Create the LLMResult from the raw model responses.

        Args:
            response: List of raw response dicts returned by the model.

        Returns:
            An ``LLMResult`` holding one generation per response that has
            ``results``, with token usage and model id in ``llm_output``.
        """
        generations = []
        for res in response:
            results = res.get("results")
            if results:
                finish_reason = results[0].get("stop_reason")
                gen = Generation(
                    text=results[0].get("generated_text"),
                    generation_info={"finish_reason": finish_reason},
                )
                generations.append([gen])
        final_token_usage = self._extract_token_usage(response)
        llm_output = {"token_usage": final_token_usage, "model_id": self.model_id}
        return LLMResult(generations=generations, llm_output=llm_output)

    def _params_with_stop(self, stop: Optional[List[str]]) -> Optional[dict]:
        """Build per-request parameters that carry the given stop words.

        Args:
            stop: Optional list of stop words for this request.

        Returns:
            ``None`` when ``stop`` is empty, so the model's own parameters are
            used unchanged; otherwise a new dict made of ``self.params`` plus
            ``stop_sequences``. ``self.params`` itself is never modified.
        """
        if not stop:
            return None
        return {**(self.params or {}), "stop_sequences": stop}

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        """Call the IBM watsonx.ai inference endpoint.
        Args:
            prompt: The prompt to pass into the model.
            stop: Optional list of stop words to use when generating.
            run_manager: Optional callback manager.
        Returns:
            The string generated by the model.
        Example:
            .. code-block:: python

                response = watsonxllm("What is a molecule")
        """
        result = self._generate(
            prompts=[prompt], stop=stop, run_manager=run_manager, **kwargs
        )
        return result.generations[0][0].text

    def _generate(
        self,
        prompts: List[str],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        stream: Optional[bool] = None,
        **kwargs: Any,
    ) -> LLMResult:
        """Call the IBM watsonx.ai inference endpoint which then generates the response.
        Args:
            prompts: List of strings (prompts) to pass into the model.
            stop: Optional list of stop words to use when generating.
            run_manager: Optional callback manager.
            stream: Overrides ``self.streaming`` for this call when not None.
        Returns:
            The full LLMResult output.
        Raises:
            ValueError: If streaming is requested with more than one prompt.
        Example:
            .. code-block:: python

                response = watsonxllm.generate(["What is a molecule"])
        """
        should_stream = stream if stream is not None else self.streaming
        if should_stream:
            if len(prompts) > 1:
                raise ValueError(
                    f"WatsonxLLM currently only supports single prompt, got {prompts}"
                )
            # Start from an empty chunk and concatenate streamed chunks onto it.
            generation = GenerationChunk(text="")
            stream_iter = self._stream(
                prompts[0], stop=stop, run_manager=run_manager, **kwargs
            )
            for chunk in stream_iter:
                generation += chunk
            return LLMResult(generations=[[generation]])

        request_params = self._params_with_stop(stop)
        if request_params is None:
            response = self.watsonx_model.generate(prompt=prompts)
        else:
            response = self.watsonx_model.generate(
                prompt=prompts, params=request_params
            )
        return self._create_llm_result(response)

    def _stream(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> Iterator[GenerationChunk]:
        """Call the IBM watsonx.ai inference endpoint which then streams the response.
        Args:
            prompt: The prompt to pass into the model.
            stop: Optional list of stop words to use when generating.
            run_manager: Optional callback manager.
        Returns:
            The iterator which yields generation chunks.
        Example:
            .. code-block:: python

                response = watsonxllm.stream("What is a molecule")
                for chunk in response:
                    print(chunk, end='')
        """
        request_params = self._params_with_stop(stop)
        if request_params is None:
            text_stream = self.watsonx_model.generate_text_stream(prompt=prompt)
        else:
            text_stream = self.watsonx_model.generate_text_stream(
                prompt=prompt, params=request_params
            )
        for chunk in text_stream:
            # Empty chunks are skipped: nothing is yielded or reported for them.
            if chunk:
                yield GenerationChunk(text=chunk)
                if run_manager:
                    run_manager.on_llm_new_token(chunk)
|
@ -0,0 +1,14 @@
|
||||
"""Test WatsonxLLM API wrapper."""
|
||||
|
||||
from langchain.llms import WatsonxLLM
|
||||
|
||||
|
||||
def test_watsonxllm_call() -> None:
    """Call the model once with a single prompt and check a string comes back."""
    llm = WatsonxLLM(
        model_id="google/flan-ul2",
        url="https://us-south.ml.cloud.ibm.com",
        apikey="***",
        project_id="***",
    )
    answer = llm("What color sunflower is?")
    assert isinstance(answer, str)
|
@ -82,6 +82,7 @@ EXPECT_ALL = [
|
||||
"QianfanLLMEndpoint",
|
||||
"YandexGPT",
|
||||
"VolcEngineMaasLLM",
|
||||
"WatsonxLLM",
|
||||
]
|
||||
|
||||
|
||||
|
55
libs/langchain/tests/unit_tests/llms/test_watsonxllm.py
Normal file
55
libs/langchain/tests/unit_tests/llms/test_watsonxllm.py
Normal file
@ -0,0 +1,55 @@
|
||||
"""Test WatsonxLLM API wrapper."""
|
||||
|
||||
from langchain.llms import WatsonxLLM
|
||||
|
||||
|
||||
def test_initialize_watsonxllm_bad_path_without_url() -> None:
    """Construction without a url must fail with an error naming WATSONX_URL."""
    try:
        WatsonxLLM(
            model_id="google/flan-ul2",
        )
    except ValueError as e:
        assert "WATSONX_URL" in str(e)
    else:
        # Without this branch the test would pass when nothing is raised.
        raise AssertionError("WatsonxLLM without a url should raise ValueError")
|
||||
|
||||
|
||||
def test_initialize_watsonxllm_cloud_bad_path() -> None:
    """A cloud url without an apikey must fail with an error naming WATSONX_APIKEY."""
    try:
        WatsonxLLM(model_id="google/flan-ul2", url="https://us-south.ml.cloud.ibm.com")
    except ValueError as e:
        assert "WATSONX_APIKEY" in str(e)
    else:
        # Without this branch the test would pass when nothing is raised.
        raise AssertionError(
            "WatsonxLLM with a cloud url and no apikey should raise ValueError"
        )
|
||||
|
||||
|
||||
def test_initialize_watsonxllm_cpd_bad_path_without_all() -> None:
    """A non-cloud url with no credentials must fail, naming all three options."""
    try:
        WatsonxLLM(
            model_id="google/flan-ul2",
            url="https://cpd-zen.apps.cpd48.cp.fyre.ibm.com",
        )
    except ValueError as e:
        message = str(e)
        assert (
            "WATSONX_APIKEY" in message
            and "WATSONX_PASSWORD" in message
            and "WATSONX_TOKEN" in message
        )
    else:
        # Without this branch the test would pass when nothing is raised.
        raise AssertionError(
            "WatsonxLLM with a non-cloud url and no credentials should raise ValueError"
        )
|
||||
|
||||
|
||||
def test_initialize_watsonxllm_cpd_bad_path_password_without_username() -> None:
    """A password without a username must fail with an error naming WATSONX_USERNAME."""
    try:
        WatsonxLLM(
            model_id="google/flan-ul2",
            url="https://cpd-zen.apps.cpd48.cp.fyre.ibm.com",
            password="test_password",
        )
    except ValueError as e:
        assert "WATSONX_USERNAME" in str(e)
    else:
        # Without this branch the test would pass when nothing is raised.
        raise AssertionError(
            "WatsonxLLM with a password and no username should raise ValueError"
        )
|
||||
|
||||
|
||||
def test_initialize_watsonxllm_cpd_bad_path_apikey_without_username() -> None:
    """An apikey without a username on a non-cloud url must name WATSONX_USERNAME."""
    try:
        WatsonxLLM(
            model_id="google/flan-ul2",
            url="https://cpd-zen.apps.cpd48.cp.fyre.ibm.com",
            apikey="test_apikey",
        )
    except ValueError as e:
        assert "WATSONX_USERNAME" in str(e)
    else:
        # Without this branch the test would pass when nothing is raised.
        raise AssertionError(
            "WatsonxLLM with an apikey and no username should raise ValueError"
        )
|
Loading…
Reference in New Issue
Block a user