mirror of
https://github.com/hwchase17/langchain.git
synced 2025-07-16 17:53:37 +00:00
community[patch]: update use of deprecated llm methods (#20393)
.predict and .predict_messages for BaseLanguageModel and BaseChatModel
This commit is contained in:
parent
3a068b26f3
commit
38faa74c23
@ -98,7 +98,7 @@ class LabelStudioCallbackHandler(BaseCallbackHandler):
|
|||||||
... mode='prompt'
|
... mode='prompt'
|
||||||
... )
|
... )
|
||||||
>>> llm = OpenAI(callbacks=[handler])
|
>>> llm = OpenAI(callbacks=[handler])
|
||||||
>>> llm.predict('Tell me a story about a dog.')
|
>>> llm.invoke('Tell me a story about a dog.')
|
||||||
"""
|
"""
|
||||||
|
|
||||||
DEFAULT_PROJECT_NAME: str = "LangChain-%Y-%m-%d"
|
DEFAULT_PROJECT_NAME: str = "LangChain-%Y-%m-%d"
|
||||||
|
@ -204,7 +204,7 @@ class LLMonitorCallbackHandler(BaseCallbackHandler):
|
|||||||
llmonitor_callback = LLMonitorCallbackHandler()
|
llmonitor_callback = LLMonitorCallbackHandler()
|
||||||
llm = OpenAI(callbacks=[llmonitor_callback],
|
llm = OpenAI(callbacks=[llmonitor_callback],
|
||||||
metadata={"userId": "user-123"})
|
metadata={"userId": "user-123"})
|
||||||
llm.predict("Hello, how are you?")
|
llm.invoke("Hello, how are you?")
|
||||||
```
|
```
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
@ -4,6 +4,7 @@ from typing import Any, Dict, List, Optional
|
|||||||
from langchain_core.callbacks import CallbackManagerForLLMRun
|
from langchain_core.callbacks import CallbackManagerForLLMRun
|
||||||
from langchain_core.language_models import BaseLanguageModel
|
from langchain_core.language_models import BaseLanguageModel
|
||||||
from langchain_core.language_models.llms import LLM
|
from langchain_core.language_models.llms import LLM
|
||||||
|
from langchain_core.messages import AIMessage
|
||||||
from langchain_core.pydantic_v1 import Extra, root_validator
|
from langchain_core.pydantic_v1 import Extra, root_validator
|
||||||
from langchain_core.utils import get_from_dict_or_env
|
from langchain_core.utils import get_from_dict_or_env
|
||||||
|
|
||||||
@ -95,10 +96,11 @@ class OpaquePrompts(LLM):
|
|||||||
|
|
||||||
# TODO: Add in callbacks once child runs for LLMs are supported by LangSmith.
|
# TODO: Add in callbacks once child runs for LLMs are supported by LangSmith.
|
||||||
# call the LLM with the sanitized prompt and get the response
|
# call the LLM with the sanitized prompt and get the response
|
||||||
llm_response = self.base_llm.predict(
|
llm_response = self.base_llm.bind(stop=stop).invoke(
|
||||||
sanitized_prompt_value_str,
|
sanitized_prompt_value_str,
|
||||||
stop=stop,
|
|
||||||
)
|
)
|
||||||
|
if isinstance(llm_response, AIMessage):
|
||||||
|
llm_response = llm_response.content
|
||||||
|
|
||||||
# desanitize the response by restoring the original sensitive information
|
# desanitize the response by restoring the original sensitive information
|
||||||
desanitize_response: op.DesanitizeResponse = op.desanitize(
|
desanitize_response: op.DesanitizeResponse = op.desanitize(
|
||||||
|
@ -96,8 +96,8 @@ def test_openai_predict(mock_completion: dict) -> None:
|
|||||||
"client",
|
"client",
|
||||||
mock_client,
|
mock_client,
|
||||||
):
|
):
|
||||||
res = llm.predict("bar")
|
res = llm.invoke("bar")
|
||||||
assert res == "Bar Baz"
|
assert res.content == "Bar Baz"
|
||||||
assert completed
|
assert completed
|
||||||
|
|
||||||
|
|
||||||
|
@ -60,7 +60,7 @@ def test_batch_size() -> None:
|
|||||||
|
|
||||||
llm = FakeListLLM(responses=["foo"] * 1)
|
llm = FakeListLLM(responses=["foo"] * 1)
|
||||||
with collect_runs() as cb:
|
with collect_runs() as cb:
|
||||||
llm.predict("foo")
|
llm.invoke("foo")
|
||||||
assert len(cb.traced_runs) == 1
|
assert len(cb.traced_runs) == 1
|
||||||
assert (cb.traced_runs[0].extra or {}).get("batch_size") == 1
|
assert (cb.traced_runs[0].extra or {}).get("batch_size") == 1
|
||||||
|
|
||||||
|
@ -9,7 +9,7 @@ from langchain_core.tracers.context import collect_runs
|
|||||||
def test_collect_runs() -> None:
|
def test_collect_runs() -> None:
|
||||||
llm = FakeListLLM(responses=["hello"])
|
llm = FakeListLLM(responses=["hello"])
|
||||||
with collect_runs() as cb:
|
with collect_runs() as cb:
|
||||||
llm.predict("hi")
|
llm.invoke("hi")
|
||||||
assert cb.traced_runs
|
assert cb.traced_runs
|
||||||
assert len(cb.traced_runs) == 1
|
assert len(cb.traced_runs) == 1
|
||||||
assert isinstance(cb.traced_runs[0].id, uuid.UUID)
|
assert isinstance(cb.traced_runs[0].id, uuid.UUID)
|
||||||
|
@ -183,7 +183,7 @@ class AnthropicFunctions(BaseChatModel):
|
|||||||
raise ValueError(
|
raise ValueError(
|
||||||
"if `function_call` provided, `functions` must also be"
|
"if `function_call` provided, `functions` must also be"
|
||||||
)
|
)
|
||||||
response = self.model.predict_messages(
|
response = self.model.invoke(
|
||||||
messages, stop=stop, callbacks=run_manager, **kwargs
|
messages, stop=stop, callbacks=run_manager, **kwargs
|
||||||
)
|
)
|
||||||
completion = cast(str, response.content)
|
completion = cast(str, response.content)
|
||||||
|
@ -89,7 +89,7 @@ function in "functions".'
|
|||||||
)
|
)
|
||||||
if "functions" in kwargs:
|
if "functions" in kwargs:
|
||||||
del kwargs["functions"]
|
del kwargs["functions"]
|
||||||
response_message = self.llm.predict_messages(
|
response_message = self.llm.invoke(
|
||||||
[system_message] + messages, stop=stop, callbacks=run_manager, **kwargs
|
[system_message] + messages, stop=stop, callbacks=run_manager, **kwargs
|
||||||
)
|
)
|
||||||
chat_generation_content = response_message.content
|
chat_generation_content = response_message.content
|
||||||
|
@ -49,7 +49,7 @@ def model_cfg_sys_msg() -> Llama2Chat:
|
|||||||
def test_default_system_message(model: Llama2Chat) -> None:
|
def test_default_system_message(model: Llama2Chat) -> None:
|
||||||
messages = [HumanMessage(content="usr-msg-1")]
|
messages = [HumanMessage(content="usr-msg-1")]
|
||||||
|
|
||||||
actual = model.predict_messages(messages).content # type: ignore
|
actual = model.invoke(messages).content # type: ignore
|
||||||
expected = (
|
expected = (
|
||||||
f"<s>[INST] <<SYS>>\n{DEFAULT_SYSTEM_PROMPT}\n<</SYS>>\n\nusr-msg-1 [/INST]"
|
f"<s>[INST] <<SYS>>\n{DEFAULT_SYSTEM_PROMPT}\n<</SYS>>\n\nusr-msg-1 [/INST]"
|
||||||
)
|
)
|
||||||
@ -62,7 +62,7 @@ def test_configured_system_message(
|
|||||||
) -> None:
|
) -> None:
|
||||||
messages = [HumanMessage(content="usr-msg-1")]
|
messages = [HumanMessage(content="usr-msg-1")]
|
||||||
|
|
||||||
actual = model_cfg_sys_msg.predict_messages(messages).content # type: ignore
|
actual = model_cfg_sys_msg.invoke(messages).content # type: ignore
|
||||||
expected = "<s>[INST] <<SYS>>\nsys-msg\n<</SYS>>\n\nusr-msg-1 [/INST]"
|
expected = "<s>[INST] <<SYS>>\nsys-msg\n<</SYS>>\n\nusr-msg-1 [/INST]"
|
||||||
|
|
||||||
assert actual == expected
|
assert actual == expected
|
||||||
@ -73,7 +73,7 @@ async def test_configured_system_message_async(
|
|||||||
) -> None:
|
) -> None:
|
||||||
messages = [HumanMessage(content="usr-msg-1")]
|
messages = [HumanMessage(content="usr-msg-1")]
|
||||||
|
|
||||||
actual = await model_cfg_sys_msg.apredict_messages(messages) # type: ignore
|
actual = await model_cfg_sys_msg.ainvoke(messages) # type: ignore
|
||||||
expected = "<s>[INST] <<SYS>>\nsys-msg\n<</SYS>>\n\nusr-msg-1 [/INST]"
|
expected = "<s>[INST] <<SYS>>\nsys-msg\n<</SYS>>\n\nusr-msg-1 [/INST]"
|
||||||
|
|
||||||
assert actual.content == expected
|
assert actual.content == expected
|
||||||
@ -87,7 +87,7 @@ def test_provided_system_message(
|
|||||||
HumanMessage(content="usr-msg-1"),
|
HumanMessage(content="usr-msg-1"),
|
||||||
]
|
]
|
||||||
|
|
||||||
actual = model_cfg_sys_msg.predict_messages(messages).content
|
actual = model_cfg_sys_msg.invoke(messages).content
|
||||||
expected = "<s>[INST] <<SYS>>\ncustom-sys-msg\n<</SYS>>\n\nusr-msg-1 [/INST]"
|
expected = "<s>[INST] <<SYS>>\ncustom-sys-msg\n<</SYS>>\n\nusr-msg-1 [/INST]"
|
||||||
|
|
||||||
assert actual == expected
|
assert actual == expected
|
||||||
@ -102,7 +102,7 @@ def test_human_ai_dialogue(model_cfg_sys_msg: Llama2Chat) -> None:
|
|||||||
HumanMessage(content="usr-msg-3"),
|
HumanMessage(content="usr-msg-3"),
|
||||||
]
|
]
|
||||||
|
|
||||||
actual = model_cfg_sys_msg.predict_messages(messages).content
|
actual = model_cfg_sys_msg.invoke(messages).content
|
||||||
expected = (
|
expected = (
|
||||||
"<s>[INST] <<SYS>>\nsys-msg\n<</SYS>>\n\nusr-msg-1 [/INST] ai-msg-1 </s>"
|
"<s>[INST] <<SYS>>\nsys-msg\n<</SYS>>\n\nusr-msg-1 [/INST] ai-msg-1 </s>"
|
||||||
"<s>[INST] usr-msg-2 [/INST] ai-msg-2 </s><s>[INST] usr-msg-3 [/INST]"
|
"<s>[INST] usr-msg-2 [/INST] ai-msg-2 </s><s>[INST] usr-msg-3 [/INST]"
|
||||||
@ -113,14 +113,14 @@ def test_human_ai_dialogue(model_cfg_sys_msg: Llama2Chat) -> None:
|
|||||||
|
|
||||||
def test_no_message(model: Llama2Chat) -> None:
|
def test_no_message(model: Llama2Chat) -> None:
|
||||||
with pytest.raises(ValueError) as info:
|
with pytest.raises(ValueError) as info:
|
||||||
model.predict_messages([])
|
model.invoke([])
|
||||||
|
|
||||||
assert info.value.args[0] == "at least one HumanMessage must be provided"
|
assert info.value.args[0] == "at least one HumanMessage must be provided"
|
||||||
|
|
||||||
|
|
||||||
def test_ai_message_first(model: Llama2Chat) -> None:
|
def test_ai_message_first(model: Llama2Chat) -> None:
|
||||||
with pytest.raises(ValueError) as info:
|
with pytest.raises(ValueError) as info:
|
||||||
model.predict_messages([AIMessage(content="ai-msg-1")])
|
model.invoke([AIMessage(content="ai-msg-1")])
|
||||||
|
|
||||||
assert (
|
assert (
|
||||||
info.value.args[0]
|
info.value.args[0]
|
||||||
@ -136,7 +136,7 @@ def test_human_ai_messages_not_alternating(model: Llama2Chat) -> None:
|
|||||||
]
|
]
|
||||||
|
|
||||||
with pytest.raises(ValueError) as info:
|
with pytest.raises(ValueError) as info:
|
||||||
model.predict_messages(messages) # type: ignore
|
model.invoke(messages) # type: ignore
|
||||||
|
|
||||||
assert info.value.args[0] == (
|
assert info.value.args[0] == (
|
||||||
"messages must be alternating human- and ai-messages, "
|
"messages must be alternating human- and ai-messages, "
|
||||||
@ -151,6 +151,6 @@ def test_last_message_not_human_message(model: Llama2Chat) -> None:
|
|||||||
]
|
]
|
||||||
|
|
||||||
with pytest.raises(ValueError) as info:
|
with pytest.raises(ValueError) as info:
|
||||||
model.predict_messages(messages)
|
model.invoke(messages)
|
||||||
|
|
||||||
assert info.value.args[0] == "last message must be a HumanMessage"
|
assert info.value.args[0] == "last message must be a HumanMessage"
|
||||||
|
@ -23,7 +23,7 @@ def test_prompt(model: Mixtral) -> None:
|
|||||||
HumanMessage(content="usr-msg-2"),
|
HumanMessage(content="usr-msg-2"),
|
||||||
]
|
]
|
||||||
|
|
||||||
actual = model.predict_messages(messages).content # type: ignore
|
actual = model.invoke(messages).content # type: ignore
|
||||||
expected = (
|
expected = (
|
||||||
"<s>[INST] sys-msg\nusr-msg-1 [/INST] ai-msg-1 </s> [INST] usr-msg-2 [/INST]" # noqa: E501
|
"<s>[INST] sys-msg\nusr-msg-1 [/INST] ai-msg-1 </s> [INST] usr-msg-2 [/INST]" # noqa: E501
|
||||||
)
|
)
|
||||||
|
@ -23,7 +23,7 @@ def test_prompt(model: Orca) -> None:
|
|||||||
HumanMessage(content="usr-msg-2"),
|
HumanMessage(content="usr-msg-2"),
|
||||||
]
|
]
|
||||||
|
|
||||||
actual = model.predict_messages(messages).content # type: ignore
|
actual = model.invoke(messages).content # type: ignore
|
||||||
expected = "### System:\nsys-msg\n\n### User:\nusr-msg-1\n\n### Assistant:\nai-msg-1\n\n### User:\nusr-msg-2\n\n" # noqa: E501
|
expected = "### System:\nsys-msg\n\n### User:\nusr-msg-1\n\n### Assistant:\nai-msg-1\n\n### User:\nusr-msg-2\n\n" # noqa: E501
|
||||||
|
|
||||||
assert actual == expected
|
assert actual == expected
|
||||||
|
@ -23,7 +23,7 @@ def test_prompt(model: Vicuna) -> None:
|
|||||||
HumanMessage(content="usr-msg-2"),
|
HumanMessage(content="usr-msg-2"),
|
||||||
]
|
]
|
||||||
|
|
||||||
actual = model.predict_messages(messages).content # type: ignore
|
actual = model.invoke(messages).content # type: ignore
|
||||||
expected = "sys-msg USER: usr-msg-1 ASSISTANT: ai-msg-1 </s>USER: usr-msg-2 "
|
expected = "sys-msg USER: usr-msg-1 ASSISTANT: ai-msg-1 </s>USER: usr-msg-2 "
|
||||||
|
|
||||||
assert actual == expected
|
assert actual == expected
|
||||||
|
@ -86,6 +86,5 @@ def test_redis_cache_chat() -> None:
|
|||||||
llm = FakeChatModel()
|
llm = FakeChatModel()
|
||||||
params = llm.dict()
|
params = llm.dict()
|
||||||
params["stop"] = None
|
params["stop"] = None
|
||||||
with pytest.warns():
|
llm.invoke("foo")
|
||||||
llm.predict("foo")
|
|
||||||
langchain.llm_cache.redis.flushall()
|
langchain.llm_cache.redis.flushall()
|
||||||
|
Loading…
Reference in New Issue
Block a user