mirror of
https://github.com/hwchase17/langchain.git
synced 2025-10-22 09:41:52 +00:00
Thank you for contributing to LangChain! - [x] **PR title**: "IBM: Added WatsonxChat to chat models preview, update passing params to invoke method" - [x] **PR message**: - **Description:** Added support for passing params to the WatsonxChat invoke method; added integration tests - **Dependencies:** `ibm_watsonx_ai` - [x] **Add tests and docs**: If you're adding a new integration, please include 1. a test for the integration, preferably unit tests that do not rely on network access, 2. an example notebook showing its use. It lives in the `docs/docs/integrations` directory. - [x] **Lint and test**: Run `make format`, `make lint` and `make test` from the root of the package(s) you've modified. See contribution guidelines for more: https://python.langchain.com/docs/contributing/ --------- Co-authored-by: Erick Friis <erick@langchain.dev>
244 lines
7.7 KiB
Python
244 lines
7.7 KiB
Python
import json
|
|
import os
|
|
|
|
from ibm_watsonx_ai.metanames import GenTextParamsMetaNames # type: ignore
|
|
from langchain_core.messages import (
|
|
AIMessage,
|
|
BaseMessage,
|
|
HumanMessage,
|
|
SystemMessage,
|
|
)
|
|
from langchain_core.prompts import ChatPromptTemplate
|
|
from langchain_core.pydantic_v1 import BaseModel
|
|
|
|
from langchain_ibm import ChatWatsonx
|
|
|
|
# Credentials for the live watsonx.ai service, read from the environment.
# Each lookup falls back to an empty string when the variable is unset.
# NOTE(review): WX_APIKEY is not referenced by the tests in this module;
# presumably ChatWatsonx reads WATSONX_APIKEY from the environment itself.
WX_APIKEY = os.getenv("WATSONX_APIKEY", "")
WX_PROJECT_ID = os.getenv("WATSONX_PROJECT_ID", "")

# Endpoint and model id handed to every ChatWatsonx instance below.
URL = "https://us-south.ml.cloud.ibm.com"
MODEL_ID = "mistralai/mixtral-8x7b-instruct-v01"


def test_01_generate_chat() -> None:
    """Invoke the chat model on a system + human message pair."""
    chat = ChatWatsonx(model_id=MODEL_ID, url=URL, project_id=WX_PROJECT_ID)  # type: ignore[arg-type]
    messages = [
        ("system", "You are a helpful assistant that translates English to French."),
        (
            "human",
            "Translate this sentence from English to French. I love programming.",
        ),
    ]
    response = chat.invoke(messages)
    # A message model instance is always truthy, so ``assert response`` could
    # never fail; check the type and payload instead.
    assert isinstance(response, BaseMessage)
    assert isinstance(response.content, str)


def test_01a_generate_chat_with_invoke_params() -> None:
    """Invoke the chat model with generation params supplied per call."""
    # GenTextParamsMetaNames comes from the module-level import, which carries
    # the ``# type: ignore`` pragma that a local re-import would lack.
    params = {
        GenTextParamsMetaNames.MIN_NEW_TOKENS: 1,
        GenTextParamsMetaNames.MAX_NEW_TOKENS: 10,
    }
    chat = ChatWatsonx(model_id=MODEL_ID, url=URL, project_id=WX_PROJECT_ID)  # type: ignore[arg-type]
    messages = [
        ("system", "You are a helpful assistant that translates English to French."),
        (
            "human",
            "Translate this sentence from English to French. I love programming.",
        ),
    ]
    response = chat.invoke(messages, params=params)
    # A message model instance is always truthy; check type and payload.
    assert isinstance(response, BaseMessage)
    assert isinstance(response.content, str)


def test_01b_generate_chat_with_invoke_params() -> None:
    """Invoke with generation params given both at construction and per call."""
    # GenTextParamsMetaNames comes from the module-level import, which carries
    # the ``# type: ignore`` pragma that a local re-import would lack.
    params_1 = {
        GenTextParamsMetaNames.MIN_NEW_TOKENS: 1,
        GenTextParamsMetaNames.MAX_NEW_TOKENS: 10,
    }
    params_2 = {
        GenTextParamsMetaNames.MIN_NEW_TOKENS: 1,
        GenTextParamsMetaNames.MAX_NEW_TOKENS: 10,
    }
    chat = ChatWatsonx(
        model_id=MODEL_ID,
        url=URL,  # type: ignore[arg-type]
        project_id=WX_PROJECT_ID,
        params=params_1,  # type: ignore[arg-type]
    )
    messages = [
        ("system", "You are a helpful assistant that translates English to French."),
        (
            "human",
            "Translate this sentence from English to French. I love programming.",
        ),
    ]
    # NOTE(review): params_1 and params_2 are identical, so this only checks
    # that supplying both sets is accepted; it cannot show which one wins.
    response = chat.invoke(messages, params=params_2)
    # A message model instance is always truthy; check type and payload.
    assert isinstance(response, BaseMessage)
    assert isinstance(response.content, str)


def test_02_generate_chat_with_few_inputs() -> None:
    """Batch-generate over two single-message conversations."""
    chat = ChatWatsonx(model_id=MODEL_ID, url=URL, project_id=WX_PROJECT_ID)  # type: ignore[arg-type]
    message = HumanMessage(content="Hello")
    response = chat.generate([[message], [message]])
    # The result object is always truthy, so ``assert response`` proved
    # nothing; ``generate`` yields one generation list per input conversation.
    assert len(response.generations) == 2
    for generations in response.generations:
        assert generations


def test_03_generate_chat_with_few_various_inputs() -> None:
    """A system prompt followed by a user turn yields a text message."""
    conversation = [
        SystemMessage(content="You are to chat with the user."),
        HumanMessage(content="Hello"),
    ]
    chat = ChatWatsonx(model_id=MODEL_ID, url=URL, project_id=WX_PROJECT_ID)  # type: ignore[arg-type]
    reply = chat.invoke(conversation)
    assert isinstance(reply, BaseMessage)
    assert isinstance(reply.content, str)


def test_05_generate_chat_with_stream() -> None:
    """Stream a reply and check that every chunk carries string content."""
    chat = ChatWatsonx(model_id=MODEL_ID, url=URL, project_id=WX_PROJECT_ID)  # type: ignore[arg-type]
    chunks = list(chat.stream("What's the weather in san francisco"))
    # Without this, an empty stream would make the loop below pass vacuously.
    assert chunks
    for chunk in chunks:
        assert isinstance(chunk.content, str)


def test_05a_generate_chat_with_stream_with_param() -> None:
    """Stream a reply with generation params supplied per call."""
    # GenTextParamsMetaNames comes from the module-level import, which carries
    # the ``# type: ignore`` pragma that a local re-import would lack.
    params = {
        GenTextParamsMetaNames.MIN_NEW_TOKENS: 1,
        GenTextParamsMetaNames.MAX_NEW_TOKENS: 10,
    }
    chat = ChatWatsonx(model_id=MODEL_ID, url=URL, project_id=WX_PROJECT_ID)  # type: ignore[arg-type]
    chunks = list(chat.stream("What's the weather in san francisco", params=params))
    # Without this, an empty stream would make the loop below pass vacuously.
    assert chunks
    for chunk in chunks:
        assert isinstance(chunk.content, str)


def test_10_chaining() -> None:
    """Pipe a ChatPromptTemplate into the chat model and invoke the chain."""
    chat = ChatWatsonx(model_id=MODEL_ID, url=URL, project_id=WX_PROJECT_ID)  # type: ignore[arg-type]
    prompt = ChatPromptTemplate.from_messages(
        [
            (
                "system",
                "You are a helpful assistant that "
                "translates {input_language} to {output_language}.",
            ),
            ("human", "{input}"),
        ]
    )
    chain = prompt | chat

    response = chain.invoke(
        {
            "input_language": "English",
            "output_language": "German",
            "input": "I love programming.",
        }
    )
    # A message model instance is always truthy; check type and payload.
    assert isinstance(response, BaseMessage)
    assert isinstance(response.content, str)


def test_11_chaining_with_params() -> None:
    """Run a prompt | chat chain on a model built with generation params."""
    parameters = {
        GenTextParamsMetaNames.DECODING_METHOD: "sample",
        GenTextParamsMetaNames.MIN_NEW_TOKENS: 5,
        GenTextParamsMetaNames.MAX_NEW_TOKENS: 10,
    }
    chat = ChatWatsonx(
        model_id=MODEL_ID,
        url=URL,  # type: ignore[arg-type]
        project_id=WX_PROJECT_ID,
        params=parameters,  # type: ignore[arg-type]
    )
    prompt = ChatPromptTemplate.from_messages(
        [
            (
                "system",
                "You are a helpful assistant that translates "
                "{input_language} to {output_language}.",
            ),
            ("human", "{input}"),
        ]
    )
    chain = prompt | chat

    response = chain.invoke(
        {
            "input_language": "English",
            "output_language": "German",
            "input": "I love programming.",
        }
    )
    # A message model instance is always truthy; check type and payload.
    assert isinstance(response, BaseMessage)
    assert isinstance(response.content, str)


def test_20_tool_choice() -> None:
    """Test that tool choice is respected."""
    # GenTextParamsMetaNames comes from the module-level import, which carries
    # the ``# type: ignore`` pragma that a local re-import would lack.
    params = {GenTextParamsMetaNames.MAX_NEW_TOKENS: 500}
    chat = ChatWatsonx(
        model_id=MODEL_ID,
        url=URL,  # type: ignore[arg-type]
        project_id=WX_PROJECT_ID,
        params=params,  # type: ignore[arg-type]
    )

    class MyTool(BaseModel):
        name: str
        age: int

    # Force the model to call MyTool by naming it explicitly.
    with_tool = chat.bind_tools([MyTool], tool_choice="MyTool")

    resp = with_tool.invoke("Who was the 27 year old named Erick?")
    assert isinstance(resp, AIMessage)
    assert resp.content == ""  # should just be tool call

    # Raw tool call as carried in additional_kwargs (arguments are a JSON string).
    tool_calls = resp.additional_kwargs["tool_calls"]
    assert len(tool_calls) == 1
    tool_call = tool_calls[0]
    assert tool_call["function"]["name"] == "MyTool"
    assert json.loads(tool_call["function"]["arguments"]) == {
        "age": 27,
        "name": "Erick",
    }
    assert tool_call["type"] == "function"

    # Parsed tool call exposed on the message (args already decoded).
    assert isinstance(resp.tool_calls, list)
    assert len(resp.tool_calls) == 1
    tool_call = resp.tool_calls[0]
    assert tool_call["name"] == "MyTool"
    assert tool_call["args"] == {"age": 27, "name": "Erick"}


def test_21_tool_choice_bool() -> None:
    """Test that tool choice is respected just passing in True."""
    # GenTextParamsMetaNames comes from the module-level import, which carries
    # the ``# type: ignore`` pragma that a local re-import would lack.
    params = {GenTextParamsMetaNames.MAX_NEW_TOKENS: 500}
    chat = ChatWatsonx(
        model_id=MODEL_ID,
        url=URL,  # type: ignore[arg-type]
        project_id=WX_PROJECT_ID,
        params=params,  # type: ignore[arg-type]
    )

    class MyTool(BaseModel):
        name: str
        age: int

    # tool_choice=True with a single bound tool should still force that tool.
    with_tool = chat.bind_tools([MyTool], tool_choice=True)

    resp = with_tool.invoke("Who was the 27 year old named Erick?")
    assert isinstance(resp, AIMessage)
    assert resp.content == ""  # should just be tool call

    # Raw tool call as carried in additional_kwargs (arguments are a JSON string).
    tool_calls = resp.additional_kwargs["tool_calls"]
    assert len(tool_calls) == 1
    tool_call = tool_calls[0]
    assert tool_call["function"]["name"] == "MyTool"
    assert json.loads(tool_call["function"]["arguments"]) == {
        "age": 27,
        "name": "Erick",
    }
    assert tool_call["type"] == "function"

    # Parsed tool call exposed on the message, checked as in test_20_tool_choice.
    assert isinstance(resp.tool_calls, list)
    assert len(resp.tool_calls) == 1
    parsed_call = resp.tool_calls[0]
    assert parsed_call["name"] == "MyTool"
    assert parsed_call["args"] == {"age": 27, "name": "Erick"}