core[minor]: generation info on msg (#18592)

Related to #16403 and #17188.
Bagatur 2024-03-11 21:43:17 -07:00 committed by GitHub
parent cda43c5a11
commit e0e688a277
12 changed files with 357 additions and 164 deletions


@@ -0,0 +1,174 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "78b45321-7740-4399-b2ad-459811131de3",
"metadata": {},
"source": [
"# Get log probabilities\n",
"\n",
"Certain chat models can be configured to return token-level log probabilities. This guide walks through how to get logprobs for a number of models."
]
},
{
"cell_type": "markdown",
"id": "7f5016bf-2a7b-4140-9b80-8c35c7e5c0d5",
"metadata": {},
"source": [
"## OpenAI\n",
"\n",
"Install the LangChain x OpenAI package and set your API key"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "fe5143fe-84d3-4a91-bae8-629807bbe2cb",
"metadata": {},
"outputs": [],
"source": [
"%pip install -qU langchain-openai"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "fd1a2bff-7ac8-46cb-ab95-72c616b45f2c",
"metadata": {},
"outputs": [],
"source": [
"import getpass\n",
"import os\n",
"\n",
"os.environ[\"OPENAI_API_KEY\"] = getpass.getpass()"
]
},
{
"cell_type": "markdown",
"id": "f88ffa0d-f4a7-482c-88de-cbec501a79b1",
"metadata": {},
"source": [
"For the OpenAI API to return log probabilities we need to configure the `logprobs=True` param"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "d1bf0a9a-e402-4931-ab53-32899f8e0326",
"metadata": {},
"outputs": [],
"source": [
"from langchain_openai import ChatOpenAI\n",
"\n",
"llm = ChatOpenAI(model=\"gpt-3.5-turbo-0125\").bind(logprobs=True)\n",
"\n",
"msg = llm.invoke((\"human\", \"how are you today\"))"
]
},
{
"cell_type": "markdown",
"id": "e002c48a-af03-4796-a367-a69c5c8ae0c4",
"metadata": {},
"source": [
"The logprobs are included on each output Message as part of the `response_metadata`:"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "e3e17872-62df-4b17-a8d4-4cae713a301b",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[{'token': 'As',\n",
" 'bytes': [65, 115],\n",
" 'logprob': -1.5358024,\n",
" 'top_logprobs': []},\n",
" {'token': ' an',\n",
" 'bytes': [32, 97, 110],\n",
" 'logprob': -0.028062303,\n",
" 'top_logprobs': []},\n",
" {'token': ' AI',\n",
" 'bytes': [32, 65, 73],\n",
" 'logprob': -0.009415812,\n",
" 'top_logprobs': []},\n",
" {'token': ',', 'bytes': [44], 'logprob': -0.07371779, 'top_logprobs': []},\n",
" {'token': ' I',\n",
" 'bytes': [32, 73],\n",
" 'logprob': -4.298773e-05,\n",
" 'top_logprobs': []}]"
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"msg.response_metadata[\"logprobs\"][\"content\"][:5]"
]
},
{
"cell_type": "markdown",
"id": "d1ee1c29-d27e-4353-8c3c-2ed7e7f95ff5",
"metadata": {},
"source": [
"And are part of streamed Message chunks as well:"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "4bfaf309-3b23-43b7-b333-01fc4848992d",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[]\n",
"[{'token': 'As', 'bytes': [65, 115], 'logprob': -1.7523563, 'top_logprobs': []}]\n",
"[{'token': 'As', 'bytes': [65, 115], 'logprob': -1.7523563, 'top_logprobs': []}, {'token': ' an', 'bytes': [32, 97, 110], 'logprob': -0.019908238, 'top_logprobs': []}]\n",
"[{'token': 'As', 'bytes': [65, 115], 'logprob': -1.7523563, 'top_logprobs': []}, {'token': ' an', 'bytes': [32, 97, 110], 'logprob': -0.019908238, 'top_logprobs': []}, {'token': ' AI', 'bytes': [32, 65, 73], 'logprob': -0.0093033705, 'top_logprobs': []}]\n",
"[{'token': 'As', 'bytes': [65, 115], 'logprob': -1.7523563, 'top_logprobs': []}, {'token': ' an', 'bytes': [32, 97, 110], 'logprob': -0.019908238, 'top_logprobs': []}, {'token': ' AI', 'bytes': [32, 65, 73], 'logprob': -0.0093033705, 'top_logprobs': []}, {'token': ',', 'bytes': [44], 'logprob': -0.08852102, 'top_logprobs': []}]\n"
]
}
],
"source": [
"ct = 0\n",
"full = None\n",
"for chunk in llm.stream((\"human\", \"how are you today\")):\n",
" if ct < 5:\n",
" full = chunk if full is None else full + chunk\n",
" if \"logprobs\" in full.response_metadata:\n",
" print(full.response_metadata[\"logprobs\"][\"content\"])\n",
" else:\n",
" break\n",
" ct += 1"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "poetry-venv-2",
"language": "python",
"name": "poetry-venv-2"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.1"
}
},
"nbformat": 4,
"nbformat_minor": 5
}


@@ -15,6 +15,7 @@ from typing import (
     List,
     Optional,
     Sequence,
+    Union,
     cast,
 )
@@ -240,6 +241,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
                 for chunk in self._stream(
                     messages, stop=stop, run_manager=run_manager, **kwargs
                 ):
+                    chunk.message.response_metadata = _gen_info_and_msg_metadata(chunk)
                     yield chunk.message
                     if generation is None:
                         generation = chunk
@@ -317,6 +319,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
                 async for chunk in _stream_implementation(
                     messages, stop=stop, run_manager=run_manager, **kwargs
                 ):
+                    chunk.message.response_metadata = _gen_info_and_msg_metadata(chunk)
                     yield chunk.message
                     if generation is None:
                         generation = chunk
@@ -586,38 +589,35 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
         run_manager: Optional[CallbackManagerForLLMRun] = None,
         **kwargs: Any,
     ) -> ChatResult:
-        new_arg_supported = inspect.signature(self._generate).parameters.get(
-            "run_manager"
-        )
-        disregard_cache = self.cache is not None and not self.cache
         llm_cache = get_llm_cache()
-        if llm_cache is None or disregard_cache:
-            # This happens when langchain.cache is None, but self.cache is True
-            if self.cache is not None and self.cache:
-                raise ValueError(
-                    "Asked to cache, but no cache found at `langchain.cache`."
-                )
-            if new_arg_supported:
-                return self._generate(
-                    messages, stop=stop, run_manager=run_manager, **kwargs
-                )
-            else:
-                return self._generate(messages, stop=stop, **kwargs)
-        else:
-            llm_string = self._get_llm_string(stop=stop, **kwargs)
-            prompt = dumps(messages)
-            cache_val = llm_cache.lookup(prompt, llm_string)
-            if isinstance(cache_val, list):
-                return ChatResult(generations=cache_val)
-            else:
-                if new_arg_supported:
-                    result = self._generate(
-                        messages, stop=stop, run_manager=run_manager, **kwargs
-                    )
-                else:
-                    result = self._generate(messages, stop=stop, **kwargs)
-                llm_cache.update(prompt, llm_string, result.generations)
-                return result
+        check_cache = self.cache or self.cache is None
+        if check_cache:
+            if llm_cache:
+                llm_string = self._get_llm_string(stop=stop, **kwargs)
+                prompt = dumps(messages)
+                cache_val = llm_cache.lookup(prompt, llm_string)
+                if isinstance(cache_val, list):
+                    return ChatResult(generations=cache_val)
+            elif self.cache is None:
+                pass
+            else:
+                raise ValueError(
+                    "Asked to cache, but no cache found at `langchain.cache`."
+                )
+        if inspect.signature(self._generate).parameters.get("run_manager"):
+            result = self._generate(
+                messages, stop=stop, run_manager=run_manager, **kwargs
+            )
+        else:
+            result = self._generate(messages, stop=stop, **kwargs)
+
+        for generation in result.generations:
+            generation.message.response_metadata = _gen_info_and_msg_metadata(
+                generation
+            )
+        if check_cache and llm_cache:
+            llm_cache.update(prompt, llm_string, result.generations)
+        return result
    async def _agenerate_with_cache(
        self,
@@ -626,38 +626,34 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
         run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
         **kwargs: Any,
     ) -> ChatResult:
-        new_arg_supported = inspect.signature(self._agenerate).parameters.get(
-            "run_manager"
-        )
-        disregard_cache = self.cache is not None and not self.cache
         llm_cache = get_llm_cache()
-        if llm_cache is None or disregard_cache:
-            # This happens when langchain.cache is None, but self.cache is True
-            if self.cache is not None and self.cache:
-                raise ValueError(
-                    "Asked to cache, but no cache found at `langchain.cache`."
-                )
-            if new_arg_supported:
-                return await self._agenerate(
-                    messages, stop=stop, run_manager=run_manager, **kwargs
-                )
-            else:
-                return await self._agenerate(messages, stop=stop, **kwargs)
-        else:
-            llm_string = self._get_llm_string(stop=stop, **kwargs)
-            prompt = dumps(messages)
-            cache_val = await llm_cache.alookup(prompt, llm_string)
-            if isinstance(cache_val, list):
-                return ChatResult(generations=cache_val)
-            else:
-                if new_arg_supported:
-                    result = await self._agenerate(
-                        messages, stop=stop, run_manager=run_manager, **kwargs
-                    )
-                else:
-                    result = await self._agenerate(messages, stop=stop, **kwargs)
-                await llm_cache.aupdate(prompt, llm_string, result.generations)
-                return result
+        check_cache = self.cache or self.cache is None
+        if check_cache:
+            if llm_cache:
+                llm_string = self._get_llm_string(stop=stop, **kwargs)
+                prompt = dumps(messages)
+                cache_val = await llm_cache.alookup(prompt, llm_string)
+                if isinstance(cache_val, list):
+                    return ChatResult(generations=cache_val)
+            elif self.cache is None:
+                pass
+            else:
+                raise ValueError(
+                    "Asked to cache, but no cache found at `langchain.cache`."
+                )
+        if inspect.signature(self._agenerate).parameters.get("run_manager"):
+            result = await self._agenerate(
+                messages, stop=stop, run_manager=run_manager, **kwargs
+            )
+        else:
+            result = await self._agenerate(messages, stop=stop, **kwargs)
+        for generation in result.generations:
+            generation.message.response_metadata = _gen_info_and_msg_metadata(
+                generation
+            )
+        if check_cache and llm_cache:
+            await llm_cache.aupdate(prompt, llm_string, result.generations)
+        return result
    @abstractmethod
    def _generate(
@@ -852,3 +848,12 @@ class SimpleChatModel(BaseChatModel):
             run_manager=run_manager.get_sync() if run_manager else None,
             **kwargs,
         )
+
+
+def _gen_info_and_msg_metadata(
+    generation: Union[ChatGeneration, ChatGenerationChunk],
+) -> dict:
+    return {
+        **(generation.generation_info or {}),
+        **generation.message.response_metadata,
+    }
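The new `_gen_info_and_msg_metadata` helper is just a dict overlay: the generation's `generation_info` goes in first, then whatever `response_metadata` the message already carries wins on key collisions. A minimal sketch of that merge with plain dicts (illustrative values only, not code from this commit):

# Illustrative stand-ins for what a provider returns per generation.
generation_info = {"finish_reason": "stop", "logprobs": {"content": []}}
response_metadata = {"model_name": "gpt-3.5-turbo-0125"}

# Same overlay the helper performs.
merged = {**(generation_info or {}), **response_metadata}
print(merged)
# {'finish_reason': 'stop', 'logprobs': {'content': []}, 'model_name': 'gpt-3.5-turbo-0125'}

This overlay is what makes `msg.response_metadata["logprobs"]` in the notebook above work for non-streaming calls.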


@@ -5,6 +5,7 @@ from langchain_core.messages.base import (
     BaseMessageChunk,
     merge_content,
 )
+from langchain_core.utils._merge import merge_dicts


 class AIMessage(BaseMessage):
@@ -49,9 +50,12 @@ class AIMessageChunk(AIMessage, BaseMessageChunk):
             return self.__class__(
                 example=self.example,
                 content=merge_content(self.content, other.content),
-                additional_kwargs=self._merge_kwargs_dict(
+                additional_kwargs=merge_dicts(
                     self.additional_kwargs, other.additional_kwargs
                 ),
+                response_metadata=merge_dicts(
+                    self.response_metadata, other.response_metadata
+                ),
             )

         return super().__add__(other)
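Because `__add__` now merges `response_metadata` as well, metadata accumulates across streamed chunks the same way `additional_kwargs` always has. A rough illustration (assumes a langchain-core version with the `response_metadata` field; the metadata values are invented):

from langchain_core.messages import AIMessageChunk

left = AIMessageChunk(
    content="Hello",
    response_metadata={"logprobs": {"content": [{"token": "Hello"}]}},
)
right = AIMessageChunk(
    content=" there",
    response_metadata={"logprobs": {"content": [{"token": " there"}]}},
)

# merge_dicts appends the per-token logprob lists instead of overwriting them.
merged = left + right
print(merged.content)            # Hello there
print(merged.response_metadata)  # {'logprobs': {'content': [{'token': 'Hello'}, {'token': ' there'}]}}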


@@ -5,6 +5,7 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Union
 from langchain_core.load.serializable import Serializable
 from langchain_core.pydantic_v1 import Extra, Field
 from langchain_core.utils import get_bolded_text
+from langchain_core.utils._merge import merge_dicts
 from langchain_core.utils.interactive_env import is_interactive_env

 if TYPE_CHECKING:
@@ -114,54 +115,6 @@ class BaseMessageChunk(BaseMessage):
         """Get the namespace of the langchain object."""
         return ["langchain", "schema", "messages"]

-    def _merge_kwargs_dict(
-        self, left: Dict[str, Any], right: Dict[str, Any]
-    ) -> Dict[str, Any]:
-        """Merge additional_kwargs from another BaseMessageChunk into this one,
-        handling specific scenarios where a key exists in both dictionaries
-        but has a value of None in 'left'. In such cases, the method uses the
-        value from 'right' for that key in the merged dictionary.
-
-        Example:
-            If left = {"function_call": {"arguments": None}} and
-            right = {"function_call": {"arguments": "{\n"}}
-            then, after merging, for the key "function_call",
-            the value from 'right' is used,
-            resulting in merged = {"function_call": {"arguments": "{\n"}}.
-        """
-        merged = left.copy()
-        for k, v in right.items():
-            if k not in merged:
-                merged[k] = v
-            elif merged[k] is None and v:
-                merged[k] = v
-            elif v is None:
-                continue
-            elif merged[k] == v:
-                continue
-            elif type(merged[k]) != type(v):
-                raise TypeError(
-                    f'additional_kwargs["{k}"] already exists in this message,'
-                    " but with a different type."
-                )
-            elif isinstance(merged[k], str):
-                merged[k] += v
-            elif isinstance(merged[k], dict):
-                merged[k] = self._merge_kwargs_dict(merged[k], v)
-            elif isinstance(merged[k], list):
-                merged[k] = merged[k].copy()
-                for i, e in enumerate(v):
-                    if isinstance(e, dict) and isinstance(e.get("index"), int):
-                        i = e["index"]
-                    if i < len(merged[k]):
-                        merged[k][i] = self._merge_kwargs_dict(merged[k][i], e)
-                    else:
-                        merged[k] = merged[k] + [e]
-            else:
-                raise TypeError(
-                    f"Additional kwargs key {k} already exists in this message."
-                )
-        return merged

    def __add__(self, other: Any) -> BaseMessageChunk:  # type: ignore
        if isinstance(other, BaseMessageChunk):
            # If both are (subclasses of) BaseMessageChunk,
@@ -170,9 +123,12 @@ class BaseMessageChunk(BaseMessage):
             return self.__class__(  # type: ignore[call-arg]
                 id=self.id,
                 content=merge_content(self.content, other.content),
-                additional_kwargs=self._merge_kwargs_dict(
+                additional_kwargs=merge_dicts(
                     self.additional_kwargs, other.additional_kwargs
                 ),
+                response_metadata=merge_dicts(
+                    self.response_metadata, other.response_metadata
+                ),
             )
         else:
             raise TypeError(


@@ -5,6 +5,7 @@ from langchain_core.messages.base import (
     BaseMessageChunk,
     merge_content,
 )
+from langchain_core.utils._merge import merge_dicts


 class ChatMessage(BaseMessage):
@@ -47,17 +48,23 @@ class ChatMessageChunk(ChatMessage, BaseMessageChunk):
             return self.__class__(
                 role=self.role,
                 content=merge_content(self.content, other.content),
-                additional_kwargs=self._merge_kwargs_dict(
+                additional_kwargs=merge_dicts(
                     self.additional_kwargs, other.additional_kwargs
                 ),
+                response_metadata=merge_dicts(
+                    self.response_metadata, other.response_metadata
+                ),
             )
         elif isinstance(other, BaseMessageChunk):
             return self.__class__(
                 role=self.role,
                 content=merge_content(self.content, other.content),
-                additional_kwargs=self._merge_kwargs_dict(
+                additional_kwargs=merge_dicts(
                     self.additional_kwargs, other.additional_kwargs
                 ),
+                response_metadata=merge_dicts(
+                    self.response_metadata, other.response_metadata
+                ),
             )
         else:
             return super().__add__(other)


@@ -5,6 +5,7 @@ from langchain_core.messages.base import (
     BaseMessageChunk,
     merge_content,
 )
+from langchain_core.utils._merge import merge_dicts


 class FunctionMessage(BaseMessage):
@@ -47,9 +48,12 @@ class FunctionMessageChunk(FunctionMessage, BaseMessageChunk):
             return self.__class__(
                 name=self.name,
                 content=merge_content(self.content, other.content),
-                additional_kwargs=self._merge_kwargs_dict(
+                additional_kwargs=merge_dicts(
                     self.additional_kwargs, other.additional_kwargs
                 ),
+                response_metadata=merge_dicts(
+                    self.response_metadata, other.response_metadata
+                ),
             )

         return super().__add__(other)


@@ -5,6 +5,7 @@ from langchain_core.messages.base import (
     BaseMessageChunk,
     merge_content,
 )
+from langchain_core.utils._merge import merge_dicts


 class ToolMessage(BaseMessage):
@@ -47,9 +48,12 @@ class ToolMessageChunk(ToolMessage, BaseMessageChunk):
             return self.__class__(
                 tool_call_id=self.tool_call_id,
                 content=merge_content(self.content, other.content),
-                additional_kwargs=self._merge_kwargs_dict(
+                additional_kwargs=merge_dicts(
                     self.additional_kwargs, other.additional_kwargs
                 ),
+                response_metadata=merge_dicts(
+                    self.response_metadata, other.response_metadata
+                ),
             )

         return super().__add__(other)


@@ -16,27 +16,44 @@ def merge_dicts(left: Dict[str, Any], right: Dict[str, Any]) -> Dict[str, Any]:
         resulting in merged = {"function_call": {"arguments": "{\n"}}.
     """
     merged = left.copy()
-    for k, v in right.items():
-        if k not in merged:
-            merged[k] = v
-        elif v is not None and merged[k] is None:
-            merged[k] = v
-        elif v is None or merged[k] == v:
-            continue
-        elif type(merged[k]) != type(v):
-            raise TypeError(
-                f'additional_kwargs["{k}"] already exists in this message,'
-                " but with a different type."
-            )
-        elif isinstance(merged[k], str):
-            merged[k] += v
-        elif isinstance(merged[k], dict):
-            merged[k] = merge_dicts(merged[k], v)
-        elif isinstance(merged[k], list):
-            merged[k] = merged[k] + v
-        else:
-            raise TypeError(
-                f"Additional kwargs key {k} already exists in left dict and value has "
-                f"unsupported type {type(merged[k])}."
-            )
+    for right_k, right_v in right.items():
+        if right_k not in merged:
+            merged[right_k] = right_v
+        elif right_v is not None and merged[right_k] is None:
+            merged[right_k] = right_v
+        elif right_v is None:
+            continue
+        elif type(merged[right_k]) != type(right_v):
+            raise TypeError(
+                f'additional_kwargs["{right_k}"] already exists in this message,'
+                " but with a different type."
+            )
+        elif isinstance(merged[right_k], str):
+            merged[right_k] += right_v
+        elif isinstance(merged[right_k], dict):
+            merged[right_k] = merge_dicts(merged[right_k], right_v)
+        elif isinstance(merged[right_k], list):
+            merged[right_k] = merged[right_k].copy()
+            for e in right_v:
+                if isinstance(e, dict) and "index" in e and isinstance(e["index"], int):
+                    to_merge = [
+                        i
+                        for i, e_left in enumerate(merged[right_k])
+                        if e_left["index"] == e["index"]
+                    ]
+                    if to_merge:
+                        merged[right_k][to_merge[0]] = merge_dicts(
+                            merged[right_k][to_merge[0]], e
+                        )
+                    else:
+                        merged[right_k] = merged[right_k] + [e]
+                else:
+                    merged[right_k] = merged[right_k] + [e]
+        elif merged[right_k] == right_v:
+            continue
+        else:
+            raise TypeError(
+                f"Additional kwargs key {right_k} already exists in left dict and "
+                f"value has unsupported type {type(merged[right_k])}."
+            )
     return merged
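The behavioral change here is in the list branch: elements carrying an integer "index" are now merged into the matching element instead of being blindly appended, which is what lets streamed tool/function-call fragments be stitched back together. A small sketch (merge_dicts lives in a private module, so treat this as illustrative, not a supported API):

from langchain_core.utils._merge import merge_dicts

# Two streamed fragments of the same tool call, identified by "index".
left = {"tool_calls": [{"index": 0, "args": '{"a"'}]}
right = {"tool_calls": [{"index": 0, "args": ': 1}'}]}

# Matching "index" values merge in place, so the partial JSON concatenates.
print(merge_dicts(left, right))
# {'tool_calls': [{'index': 0, 'args': '{"a": 1}'}]}

# Elements without an "index" key are still appended.
print(merge_dicts({"a": [1]}, {"a": [2]}))
# {'a': [1, 2]}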


@@ -48,9 +48,9 @@ def test_check_package_version(
         ({"a": 1.5}, {"a": 1.5}, {"a": 1.5}),
         ({"a": True}, {"a": True}, {"a": True}),
         ({"a": False}, {"a": False}, {"a": False}),
-        ({"a": "txt"}, {"a": "txt"}, {"a": "txt"}),
-        ({"a": [1, 2]}, {"a": [1, 2]}, {"a": [1, 2]}),
-        ({"a": {"b": "txt"}}, {"a": {"b": "txt"}}, {"a": {"b": "txt"}}),
+        ({"a": "txt"}, {"a": "txt"}, {"a": "txttxt"}),
+        ({"a": [1, 2]}, {"a": [1, 2]}, {"a": [1, 2, 1, 2]}),
+        ({"a": {"b": "txt"}}, {"a": {"b": "txt"}}, {"a": {"b": "txttxt"}}),
         # Merge strings.
         ({"a": "one"}, {"a": "two"}, {"a": "onetwo"}),
         # Merge dicts.
@@ -89,6 +89,17 @@ def test_check_package_version(
                 ),
             ),
         ),
+        # 'index' keyword has special handling
+        (
+            {"a": [{"index": 0, "b": "{"}]},
+            {"a": [{"index": 0, "b": "f"}]},
+            {"a": [{"index": 0, "b": "{f"}]},
+        ),
+        (
+            {"a": [{"idx": 0, "b": "{"}]},
+            {"a": [{"idx": 0, "b": "f"}]},
+            {"a": [{"idx": 0, "b": "{"}, {"idx": 0, "b": "f"}]},
+        ),
     ),
 )
 def test_merge_dicts(
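A side effect visible in these updated expectations: merge_dicts no longer short-circuits when the two values are equal strings, lists, or nested dicts, so equal containers are concatenated or merged like any others. A quick check (illustrative, using the private helper directly):

from langchain_core.utils._merge import merge_dicts

print(merge_dicts({"a": "txt"}, {"a": "txt"}))    # {'a': 'txttxt'}
print(merge_dicts({"a": [1, 2]}, {"a": [1, 2]}))  # {'a': [1, 2, 1, 2]}

# Equal scalars (ints, floats, bools) still pass through unchanged.
print(merge_dicts({"a": 1}, {"a": 1}))            # {'a': 1}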


@@ -14,10 +14,7 @@ from langchain_core.outputs import ChatGeneration, Generation
 from sqlalchemy import create_engine
 from sqlalchemy.orm import Session

-from langchain.cache import (
-    InMemoryCache,
-    SQLAlchemyCache,
-)
+from langchain.cache import InMemoryCache, SQLAlchemyCache
 from langchain.globals import get_llm_cache, set_llm_cache
@@ -67,7 +64,7 @@ async def test_llm_caching() -> None:
         llm_string=create_llm_string(llm),
         return_val=[Generation(text=cached_response)],
     )
-    assert llm(prompt) == cached_response
+    assert llm.invoke(prompt) == cached_response
     # async test
     await llm_cache.aupdate(
         prompt=prompt,


@@ -1,10 +1,10 @@
 """Test AzureChatOpenAI wrapper."""
 import os
-from typing import Any
+from typing import Any, Optional

 import pytest
 from langchain_core.callbacks import CallbackManager
-from langchain_core.messages import BaseMessage, HumanMessage
+from langchain_core.messages import BaseMessage, BaseMessageChunk, HumanMessage
 from langchain_core.outputs import ChatGeneration, ChatResult, LLMResult
 from langchain_openai import AzureChatOpenAI
@@ -164,16 +164,20 @@ async def test_async_chat_openai_streaming() -> None:
 @pytest.mark.scheduled
 def test_openai_streaming(llm: AzureChatOpenAI) -> None:
     """Test streaming tokens from OpenAI."""
-    for token in llm.stream("I'm Pickle Rick"):
-        assert isinstance(token.content, str)
+    full: Optional[BaseMessageChunk] = None
+    for chunk in llm.stream("I'm Pickle Rick"):
+        assert isinstance(chunk.content, str)
+        full = chunk if full is None else full + chunk


 @pytest.mark.scheduled
 async def test_openai_astream(llm: AzureChatOpenAI) -> None:
     """Test streaming tokens from OpenAI."""
-    async for token in llm.astream("I'm Pickle Rick"):
-        assert isinstance(token.content, str)
+    full: Optional[BaseMessageChunk] = None
+    async for chunk in llm.astream("I'm Pickle Rick"):
+        assert isinstance(chunk.content, str)
+        full = chunk if full is None else full + chunk


 @pytest.mark.scheduled


@@ -1,9 +1,15 @@
 """Test ChatOpenAI chat model."""
-from typing import Any, Optional
+from typing import Any, Optional, cast

 import pytest
 from langchain_core.callbacks import CallbackManager
-from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage
+from langchain_core.messages import (
+    AIMessage,
+    BaseMessage,
+    BaseMessageChunk,
+    HumanMessage,
+    SystemMessage,
+)
 from langchain_core.outputs import (
     ChatGeneration,
     ChatResult,
@@ -99,7 +105,7 @@ def test_chat_openai_streaming() -> None:
         verbose=True,
     )
     message = HumanMessage(content="Hello")
-    response = chat([message])
+    response = chat.invoke([message])
     assert callback_handler.llm_streams > 0
     assert isinstance(response, BaseMessage)
@@ -336,16 +342,20 @@ def test_stream() -> None:
     """Test streaming tokens from OpenAI."""
     llm = ChatOpenAI()

-    for token in llm.stream("I'm Pickle Rick"):
-        assert isinstance(token.content, str)
+    full: Optional[BaseMessageChunk] = None
+    for chunk in llm.stream("I'm Pickle Rick"):
+        assert isinstance(chunk.content, str)
+        full = chunk if full is None else full + chunk


 async def test_astream() -> None:
     """Test streaming tokens from OpenAI."""
     llm = ChatOpenAI()

-    async for token in llm.astream("I'm Pickle Rick"):
-        assert isinstance(token.content, str)
+    full: Optional[BaseMessageChunk] = None
+    async for chunk in llm.astream("I'm Pickle Rick"):
+        assert isinstance(chunk.content, str)
+        full = chunk if full is None else full + chunk


 async def test_abatch() -> None:
@@ -395,33 +405,33 @@ def test_invoke() -> None:

 def test_logprobs() -> None:
     llm = ChatOpenAI()
-    result = llm.generate([[HumanMessage(content="I'm PickleRick")]], logprobs=True)
-    assert result.generations[0][0].generation_info
-    assert "content" in result.generations[0][0].generation_info["logprobs"]
+    result = llm.invoke([HumanMessage(content="I'm PickleRick")], logprobs=True)
+    assert result.response_metadata
+    assert "content" in result.response_metadata["logprobs"]


 async def test_async_logprobs() -> None:
     llm = ChatOpenAI()
-    result = await llm.agenerate(
-        [[HumanMessage(content="I'm PickleRick")]], logprobs=True
-    )
-    assert result.generations[0][0].generation_info
-    assert "content" in result.generations[0][0].generation_info["logprobs"]
+    result = await llm.ainvoke([HumanMessage(content="I'm PickleRick")], logprobs=True)
+    assert result.response_metadata
+    assert "content" in result.response_metadata["logprobs"]


 def test_logprobs_streaming() -> None:
     llm = ChatOpenAI()
-    result = llm.generate(
-        [[HumanMessage(content="I'm PickleRick")]], logprobs=True, stream=True
-    )
-    assert result.generations[0][0].generation_info
-    assert "content" in result.generations[0][0].generation_info["logprobs"]
+    full: Optional[BaseMessageChunk] = None
+    for chunk in llm.stream("I'm Pickle Rick", logprobs=True):
+        assert isinstance(chunk.content, str)
+        full = chunk if full is None else full + chunk
+    assert cast(BaseMessageChunk, full).response_metadata
+    assert "content" in cast(BaseMessageChunk, full).response_metadata["logprobs"]


 async def test_async_logprobs_streaming() -> None:
     llm = ChatOpenAI()
-    result = await llm.agenerate(
-        [[HumanMessage(content="I'm PickleRick")]], logprobs=True, stream=True
-    )
-    assert result.generations[0][0].generation_info
-    assert "content" in result.generations[0][0].generation_info["logprobs"]
+    full: Optional[BaseMessageChunk] = None
+    async for chunk in llm.astream("I'm Pickle Rick", logprobs=True):
+        assert isinstance(chunk.content, str)
+        full = chunk if full is None else full + chunk
+    assert cast(BaseMessageChunk, full).response_metadata
+    assert "content" in cast(BaseMessageChunk, full).response_metadata["logprobs"]