Mirror of https://github.com/hwchase17/langchain.git (synced 2025-06-25 16:13:25 +00:00)
parent cda43c5a11
commit e0e688a277
docs/docs/modules/model_io/chat/logprobs.ipynb (new file, 174 lines)
@@ -0,0 +1,174 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "78b45321-7740-4399-b2ad-459811131de3",
"metadata": {},
"source": [
"# Get log probabilities\n",
"\n",
"Certain chat models can be configured to return token-level log probabilities. This guide walks through how to get logprobs for a number of models."
]
},
{
"cell_type": "markdown",
"id": "7f5016bf-2a7b-4140-9b80-8c35c7e5c0d5",
"metadata": {},
"source": [
"## OpenAI\n",
"\n",
"Install the LangChain x OpenAI package and set your API key"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "fe5143fe-84d3-4a91-bae8-629807bbe2cb",
"metadata": {},
"outputs": [],
"source": [
"%pip install -qU langchain-openai"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "fd1a2bff-7ac8-46cb-ab95-72c616b45f2c",
"metadata": {},
"outputs": [],
"source": [
"import getpass\n",
"import os\n",
"\n",
"os.environ[\"OPENAI_API_KEY\"] = getpass.getpass()"
]
},
{
"cell_type": "markdown",
"id": "f88ffa0d-f4a7-482c-88de-cbec501a79b1",
"metadata": {},
"source": [
"For the OpenAI API to return log probabilities we need to configure the `logprobs=True` param"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "d1bf0a9a-e402-4931-ab53-32899f8e0326",
"metadata": {},
"outputs": [],
"source": [
"from langchain_openai import ChatOpenAI\n",
"\n",
"llm = ChatOpenAI(model=\"gpt-3.5-turbo-0125\").bind(logprobs=True)\n",
"\n",
"msg = llm.invoke((\"human\", \"how are you today\"))"
]
},
{
"cell_type": "markdown",
"id": "e002c48a-af03-4796-a367-a69c5c8ae0c4",
"metadata": {},
"source": [
"The logprobs are included on each output Message as part of the `response_metadata`:"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "e3e17872-62df-4b17-a8d4-4cae713a301b",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[{'token': 'As',\n",
" 'bytes': [65, 115],\n",
" 'logprob': -1.5358024,\n",
" 'top_logprobs': []},\n",
" {'token': ' an',\n",
" 'bytes': [32, 97, 110],\n",
" 'logprob': -0.028062303,\n",
" 'top_logprobs': []},\n",
" {'token': ' AI',\n",
" 'bytes': [32, 65, 73],\n",
" 'logprob': -0.009415812,\n",
" 'top_logprobs': []},\n",
" {'token': ',', 'bytes': [44], 'logprob': -0.07371779, 'top_logprobs': []},\n",
" {'token': ' I',\n",
" 'bytes': [32, 73],\n",
" 'logprob': -4.298773e-05,\n",
" 'top_logprobs': []}]"
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"msg.response_metadata[\"logprobs\"][\"content\"][:5]"
]
},
{
"cell_type": "markdown",
"id": "d1ee1c29-d27e-4353-8c3c-2ed7e7f95ff5",
"metadata": {},
"source": [
"And are part of streamed Message chunks as well:"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "4bfaf309-3b23-43b7-b333-01fc4848992d",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[]\n",
"[{'token': 'As', 'bytes': [65, 115], 'logprob': -1.7523563, 'top_logprobs': []}]\n",
"[{'token': 'As', 'bytes': [65, 115], 'logprob': -1.7523563, 'top_logprobs': []}, {'token': ' an', 'bytes': [32, 97, 110], 'logprob': -0.019908238, 'top_logprobs': []}]\n",
"[{'token': 'As', 'bytes': [65, 115], 'logprob': -1.7523563, 'top_logprobs': []}, {'token': ' an', 'bytes': [32, 97, 110], 'logprob': -0.019908238, 'top_logprobs': []}, {'token': ' AI', 'bytes': [32, 65, 73], 'logprob': -0.0093033705, 'top_logprobs': []}]\n",
"[{'token': 'As', 'bytes': [65, 115], 'logprob': -1.7523563, 'top_logprobs': []}, {'token': ' an', 'bytes': [32, 97, 110], 'logprob': -0.019908238, 'top_logprobs': []}, {'token': ' AI', 'bytes': [32, 65, 73], 'logprob': -0.0093033705, 'top_logprobs': []}, {'token': ',', 'bytes': [44], 'logprob': -0.08852102, 'top_logprobs': []}]\n"
]
}
],
"source": [
"ct = 0\n",
"full = None\n",
"for chunk in llm.stream((\"human\", \"how are you today\")):\n",
" if ct < 5:\n",
" full = chunk if full is None else full + chunk\n",
" if \"logprobs\" in full.response_metadata:\n",
" print(full.response_metadata[\"logprobs\"][\"content\"])\n",
" else:\n",
" break\n",
" ct += 1"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "poetry-venv-2",
"language": "python",
"name": "poetry-venv-2"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.1"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
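For orientation, here is a minimal sketch of how the per-token entries that the notebook above reads from response_metadata["logprobs"]["content"] could be aggregated. The summarize_logprobs helper and the printed values are illustrative only and are not part of this commit.

import math
from typing import Any, Dict, List


def summarize_logprobs(content: List[Dict[str, Any]]) -> Dict[str, float]:
    """Aggregate OpenAI-style token logprobs into a total log-probability and perplexity."""
    logprobs = [entry["logprob"] for entry in content]
    total = sum(logprobs)
    # Perplexity is exp of the mean negative log-probability over the tokens seen.
    perplexity = math.exp(-total / len(logprobs)) if logprobs else float("nan")
    return {"total_logprob": total, "perplexity": perplexity}


# Using the structure shown in the notebook output above (values copied from it):
tokens = [
    {"token": "As", "bytes": [65, 115], "logprob": -1.5358024, "top_logprobs": []},
    {"token": " an", "bytes": [32, 97, 110], "logprob": -0.028062303, "top_logprobs": []},
]
print(summarize_logprobs(tokens))  # total_logprob ~ -1.564, perplexity ~ 2.19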
@@ -15,6 +15,7 @@ from typing import (
List,
Optional,
Sequence,
Union,
cast,
)
@@ -240,6 +241,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
for chunk in self._stream(
messages, stop=stop, run_manager=run_manager, **kwargs
):
chunk.message.response_metadata = _gen_info_and_msg_metadata(chunk)
yield chunk.message
if generation is None:
generation = chunk
@@ -317,6 +319,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
async for chunk in _stream_implementation(
messages, stop=stop, run_manager=run_manager, **kwargs
):
chunk.message.response_metadata = _gen_info_and_msg_metadata(chunk)
yield chunk.message
if generation is None:
generation = chunk
@@ -586,38 +589,35 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
new_arg_supported = inspect.signature(self._generate).parameters.get(
"run_manager"
)
disregard_cache = self.cache is not None and not self.cache
llm_cache = get_llm_cache()
if llm_cache is None or disregard_cache:
# This happens when langchain.cache is None, but self.cache is True
if self.cache is not None and self.cache:
check_cache = self.cache or self.cache is None
if check_cache:
if llm_cache:
llm_string = self._get_llm_string(stop=stop, **kwargs)
prompt = dumps(messages)
cache_val = llm_cache.lookup(prompt, llm_string)
if isinstance(cache_val, list):
return ChatResult(generations=cache_val)
elif self.cache is None:
pass
else:
raise ValueError(
"Asked to cache, but no cache found at `langchain.cache`."
)
if new_arg_supported:
return self._generate(
messages, stop=stop, run_manager=run_manager, **kwargs
)
else:
return self._generate(messages, stop=stop, **kwargs)
if inspect.signature(self._generate).parameters.get("run_manager"):
result = self._generate(
messages, stop=stop, run_manager=run_manager, **kwargs
)
else:
llm_string = self._get_llm_string(stop=stop, **kwargs)
prompt = dumps(messages)
cache_val = llm_cache.lookup(prompt, llm_string)
if isinstance(cache_val, list):
return ChatResult(generations=cache_val)
else:
if new_arg_supported:
result = self._generate(
messages, stop=stop, run_manager=run_manager, **kwargs
)
else:
result = self._generate(messages, stop=stop, **kwargs)
llm_cache.update(prompt, llm_string, result.generations)
return result
result = self._generate(messages, stop=stop, **kwargs)
for generation in result.generations:
generation.message.response_metadata = _gen_info_and_msg_metadata(
generation
)
if check_cache and llm_cache:
llm_cache.update(prompt, llm_string, result.generations)
return result
async def _agenerate_with_cache(
self,
@@ -626,38 +626,34 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
new_arg_supported = inspect.signature(self._agenerate).parameters.get(
"run_manager"
)
disregard_cache = self.cache is not None and not self.cache
llm_cache = get_llm_cache()
if llm_cache is None or disregard_cache:
# This happens when langchain.cache is None, but self.cache is True
if self.cache is not None and self.cache:
check_cache = self.cache or self.cache is None
if check_cache:
if llm_cache:
llm_string = self._get_llm_string(stop=stop, **kwargs)
prompt = dumps(messages)
cache_val = await llm_cache.alookup(prompt, llm_string)
if isinstance(cache_val, list):
return ChatResult(generations=cache_val)
elif self.cache is None:
pass
else:
raise ValueError(
"Asked to cache, but no cache found at `langchain.cache`."
)
if new_arg_supported:
return await self._agenerate(
messages, stop=stop, run_manager=run_manager, **kwargs
)
else:
return await self._agenerate(messages, stop=stop, **kwargs)
if inspect.signature(self._agenerate).parameters.get("run_manager"):
result = await self._agenerate(
messages, stop=stop, run_manager=run_manager, **kwargs
)
else:
llm_string = self._get_llm_string(stop=stop, **kwargs)
prompt = dumps(messages)
cache_val = await llm_cache.alookup(prompt, llm_string)
if isinstance(cache_val, list):
return ChatResult(generations=cache_val)
else:
if new_arg_supported:
result = await self._agenerate(
messages, stop=stop, run_manager=run_manager, **kwargs
)
else:
result = await self._agenerate(messages, stop=stop, **kwargs)
await llm_cache.aupdate(prompt, llm_string, result.generations)
return result
result = await self._agenerate(messages, stop=stop, **kwargs)
for generation in result.generations:
generation.message.response_metadata = _gen_info_and_msg_metadata(
generation
)
if check_cache and llm_cache:
await llm_cache.aupdate(prompt, llm_string, result.generations)
return result
@abstractmethod
def _generate(
@@ -852,3 +848,12 @@ class SimpleChatModel(BaseChatModel):
run_manager=run_manager.get_sync() if run_manager else None,
**kwargs,
)
def _gen_info_and_msg_metadata(
generation: Union[ChatGeneration, ChatGenerationChunk],
) -> dict:
return {
**(generation.generation_info or {}),
**generation.message.response_metadata,
}
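The _gen_info_and_msg_metadata helper added above is what the streaming and generate paths now use to populate response_metadata. A minimal sketch of the merge it performs, with made-up sample values and the merge inlined rather than importing the private helper:

from langchain_core.messages import AIMessage
from langchain_core.outputs import ChatGeneration

generation = ChatGeneration(
    message=AIMessage(
        content="hi", response_metadata={"model_name": "gpt-3.5-turbo-0125"}
    ),
    generation_info={"finish_reason": "stop", "logprobs": {"content": []}},
)
# Same merge as _gen_info_and_msg_metadata: generation_info first, then the
# message's own response_metadata layered on top.
merged = {
    **(generation.generation_info or {}),
    **generation.message.response_metadata,
}
# -> {'finish_reason': 'stop', 'logprobs': {'content': []}, 'model_name': 'gpt-3.5-turbo-0125'}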
@@ -5,6 +5,7 @@ from langchain_core.messages.base import (
BaseMessageChunk,
merge_content,
)
from langchain_core.utils._merge import merge_dicts
class AIMessage(BaseMessage):
@@ -49,9 +50,12 @@ class AIMessageChunk(AIMessage, BaseMessageChunk):
return self.__class__(
example=self.example,
content=merge_content(self.content, other.content),
additional_kwargs=self._merge_kwargs_dict(
additional_kwargs=merge_dicts(
self.additional_kwargs, other.additional_kwargs
),
response_metadata=merge_dicts(
self.response_metadata, other.response_metadata
),
)
return super().__add__(other)
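With the change above, adding two AIMessageChunks merges response_metadata via merge_dicts in addition to additional_kwargs. A small illustrative sketch (the sample values are made up):

from langchain_core.messages import AIMessageChunk

left = AIMessageChunk(
    content="Hello ",
    response_metadata={"logprobs": {"content": [{"token": "Hello"}]}},
)
right = AIMessageChunk(
    content="world",
    response_metadata={"logprobs": {"content": [{"token": " world"}]}},
)

combined = left + right
# Content is concatenated and the nested logprobs lists are merged by merge_dicts.
assert combined.content == "Hello world"
assert [e["token"] for e in combined.response_metadata["logprobs"]["content"]] == [
    "Hello",
    " world",
]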
@@ -5,6 +5,7 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Union
from langchain_core.load.serializable import Serializable
from langchain_core.pydantic_v1 import Extra, Field
from langchain_core.utils import get_bolded_text
from langchain_core.utils._merge import merge_dicts
from langchain_core.utils.interactive_env import is_interactive_env
if TYPE_CHECKING:
@@ -114,54 +115,6 @@ class BaseMessageChunk(BaseMessage):
"""Get the namespace of the langchain object."""
return ["langchain", "schema", "messages"]
def _merge_kwargs_dict(
self, left: Dict[str, Any], right: Dict[str, Any]
) -> Dict[str, Any]:
"""Merge additional_kwargs from another BaseMessageChunk into this one,
handling specific scenarios where a key exists in both dictionaries
but has a value of None in 'left'. In such cases, the method uses the
value from 'right' for that key in the merged dictionary.
Example:
If left = {"function_call": {"arguments": None}} and
right = {"function_call": {"arguments": "{\n"}}
then, after merging, for the key "function_call",
the value from 'right' is used,
resulting in merged = {"function_call": {"arguments": "{\n"}}.
"""
merged = left.copy()
for k, v in right.items():
if k not in merged:
merged[k] = v
elif merged[k] is None and v:
merged[k] = v
elif v is None:
continue
elif merged[k] == v:
continue
elif type(merged[k]) != type(v):
raise TypeError(
f'additional_kwargs["{k}"] already exists in this message,'
" but with a different type."
)
elif isinstance(merged[k], str):
merged[k] += v
elif isinstance(merged[k], dict):
merged[k] = self._merge_kwargs_dict(merged[k], v)
elif isinstance(merged[k], list):
merged[k] = merged[k].copy()
for i, e in enumerate(v):
if isinstance(e, dict) and isinstance(e.get("index"), int):
i = e["index"]
if i < len(merged[k]):
merged[k][i] = self._merge_kwargs_dict(merged[k][i], e)
else:
merged[k] = merged[k] + [e]
else:
raise TypeError(
f"Additional kwargs key {k} already exists in this message."
)
return merged
def __add__(self, other: Any) -> BaseMessageChunk:  # type: ignore
if isinstance(other, BaseMessageChunk):
# If both are (subclasses of) BaseMessageChunk,
@@ -170,9 +123,12 @@ class BaseMessageChunk(BaseMessage):
return self.__class__(  # type: ignore[call-arg]
id=self.id,
content=merge_content(self.content, other.content),
additional_kwargs=self._merge_kwargs_dict(
additional_kwargs=merge_dicts(
self.additional_kwargs, other.additional_kwargs
),
response_metadata=merge_dicts(
self.response_metadata, other.response_metadata
),
)
else:
raise TypeError(
@@ -5,6 +5,7 @@ from langchain_core.messages.base import (
BaseMessageChunk,
merge_content,
)
from langchain_core.utils._merge import merge_dicts
class ChatMessage(BaseMessage):
@@ -47,17 +48,23 @@ class ChatMessageChunk(ChatMessage, BaseMessageChunk):
return self.__class__(
role=self.role,
content=merge_content(self.content, other.content),
additional_kwargs=self._merge_kwargs_dict(
additional_kwargs=merge_dicts(
self.additional_kwargs, other.additional_kwargs
),
response_metadata=merge_dicts(
self.response_metadata, other.response_metadata
),
)
elif isinstance(other, BaseMessageChunk):
return self.__class__(
role=self.role,
content=merge_content(self.content, other.content),
additional_kwargs=self._merge_kwargs_dict(
additional_kwargs=merge_dicts(
self.additional_kwargs, other.additional_kwargs
),
response_metadata=merge_dicts(
self.response_metadata, other.response_metadata
),
)
else:
return super().__add__(other)
@@ -5,6 +5,7 @@ from langchain_core.messages.base import (
BaseMessageChunk,
merge_content,
)
from langchain_core.utils._merge import merge_dicts
class FunctionMessage(BaseMessage):
@@ -47,9 +48,12 @@ class FunctionMessageChunk(FunctionMessage, BaseMessageChunk):
return self.__class__(
name=self.name,
content=merge_content(self.content, other.content),
additional_kwargs=self._merge_kwargs_dict(
additional_kwargs=merge_dicts(
self.additional_kwargs, other.additional_kwargs
),
response_metadata=merge_dicts(
self.response_metadata, other.response_metadata
),
)
return super().__add__(other)
@@ -5,6 +5,7 @@ from langchain_core.messages.base import (
BaseMessageChunk,
merge_content,
)
from langchain_core.utils._merge import merge_dicts
class ToolMessage(BaseMessage):
@@ -47,9 +48,12 @@ class ToolMessageChunk(ToolMessage, BaseMessageChunk):
return self.__class__(
tool_call_id=self.tool_call_id,
content=merge_content(self.content, other.content),
additional_kwargs=self._merge_kwargs_dict(
additional_kwargs=merge_dicts(
self.additional_kwargs, other.additional_kwargs
),
response_metadata=merge_dicts(
self.response_metadata, other.response_metadata
),
)
return super().__add__(other)
@@ -16,27 +16,44 @@ def merge_dicts(left: Dict[str, Any], right: Dict[str, Any]) -> Dict[str, Any]:
resulting in merged = {"function_call": {"arguments": "{\n"}}.
"""
merged = left.copy()
for k, v in right.items():
if k not in merged:
merged[k] = v
elif v is not None and merged[k] is None:
merged[k] = v
elif v is None or merged[k] == v:
for right_k, right_v in right.items():
if right_k not in merged:
merged[right_k] = right_v
elif right_v is not None and merged[right_k] is None:
merged[right_k] = right_v
elif right_v is None:
continue
elif type(merged[k]) != type(v):
elif type(merged[right_k]) != type(right_v):
raise TypeError(
f'additional_kwargs["{k}"] already exists in this message,'
f'additional_kwargs["{right_k}"] already exists in this message,'
" but with a different type."
)
elif isinstance(merged[k], str):
merged[k] += v
elif isinstance(merged[k], dict):
merged[k] = merge_dicts(merged[k], v)
elif isinstance(merged[k], list):
merged[k] = merged[k] + v
elif isinstance(merged[right_k], str):
merged[right_k] += right_v
elif isinstance(merged[right_k], dict):
merged[right_k] = merge_dicts(merged[right_k], right_v)
elif isinstance(merged[right_k], list):
merged[right_k] = merged[right_k].copy()
for e in right_v:
if isinstance(e, dict) and "index" in e and isinstance(e["index"], int):
to_merge = [
i
for i, e_left in enumerate(merged[right_k])
if e_left["index"] == e["index"]
]
if to_merge:
merged[right_k][to_merge[0]] = merge_dicts(
merged[right_k][to_merge[0]], e
)
else:
merged[right_k] = merged[right_k] + [e]
else:
merged[right_k] = merged[right_k] + [e]
elif merged[right_k] == right_v:
continue
else:
raise TypeError(
f"Additional kwargs key {k} already exists in left dict and value has "
f"unsupported type {type(merged[k])}."
f"Additional kwargs key {right_k} already exists in left dict and "
f"value has unsupported type {type(merged[right_k])}."
)
return merged
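A brief sketch of the list-merging behavior merge_dicts gains above, mirroring the parametrized cases added to the tests below; the import path is the one used elsewhere in this commit:

from langchain_core.utils._merge import merge_dicts

# List elements that carry an integer "index" key are merged element-wise by index...
assert merge_dicts(
    {"a": [{"index": 0, "b": "{"}]},
    {"a": [{"index": 0, "b": "f"}]},
) == {"a": [{"index": 0, "b": "{f"}]}

# ...while elements without an "index" key are simply appended, as before.
assert merge_dicts(
    {"a": [{"idx": 0, "b": "{"}]},
    {"a": [{"idx": 0, "b": "f"}]},
) == {"a": [{"idx": 0, "b": "{"}, {"idx": 0, "b": "f"}]}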
@@ -48,9 +48,9 @@ def test_check_package_version(
({"a": 1.5}, {"a": 1.5}, {"a": 1.5}),
({"a": True}, {"a": True}, {"a": True}),
({"a": False}, {"a": False}, {"a": False}),
({"a": "txt"}, {"a": "txt"}, {"a": "txt"}),
({"a": [1, 2]}, {"a": [1, 2]}, {"a": [1, 2]}),
({"a": {"b": "txt"}}, {"a": {"b": "txt"}}, {"a": {"b": "txt"}}),
({"a": "txt"}, {"a": "txt"}, {"a": "txttxt"}),
({"a": [1, 2]}, {"a": [1, 2]}, {"a": [1, 2, 1, 2]}),
({"a": {"b": "txt"}}, {"a": {"b": "txt"}}, {"a": {"b": "txttxt"}}),
# Merge strings.
({"a": "one"}, {"a": "two"}, {"a": "onetwo"}),
# Merge dicts.
@@ -89,6 +89,17 @@ def test_check_package_version(
),
),
),
# 'index' keyword has special handling
(
{"a": [{"index": 0, "b": "{"}]},
{"a": [{"index": 0, "b": "f"}]},
{"a": [{"index": 0, "b": "{f"}]},
),
(
{"a": [{"idx": 0, "b": "{"}]},
{"a": [{"idx": 0, "b": "f"}]},
{"a": [{"idx": 0, "b": "{"}, {"idx": 0, "b": "f"}]},
),
),
)
def test_merge_dicts(
@@ -14,10 +14,7 @@ from langchain_core.outputs import ChatGeneration, Generation
from sqlalchemy import create_engine
from sqlalchemy.orm import Session
from langchain.cache import (
InMemoryCache,
SQLAlchemyCache,
)
from langchain.cache import InMemoryCache, SQLAlchemyCache
from langchain.globals import get_llm_cache, set_llm_cache
@@ -67,7 +64,7 @@ async def test_llm_caching() -> None:
llm_string=create_llm_string(llm),
return_val=[Generation(text=cached_response)],
)
assert llm(prompt) == cached_response
assert llm.invoke(prompt) == cached_response
# async test
await llm_cache.aupdate(
prompt=prompt,
@@ -1,10 +1,10 @@
"""Test AzureChatOpenAI wrapper."""
import os
from typing import Any
from typing import Any, Optional
import pytest
from langchain_core.callbacks import CallbackManager
from langchain_core.messages import BaseMessage, HumanMessage
from langchain_core.messages import BaseMessage, BaseMessageChunk, HumanMessage
from langchain_core.outputs import ChatGeneration, ChatResult, LLMResult
from langchain_openai import AzureChatOpenAI
@@ -164,16 +164,20 @@ async def test_async_chat_openai_streaming() -> None:
@pytest.mark.scheduled
def test_openai_streaming(llm: AzureChatOpenAI) -> None:
"""Test streaming tokens from OpenAI."""
for token in llm.stream("I'm Pickle Rick"):
assert isinstance(token.content, str)
full: Optional[BaseMessageChunk] = None
for chunk in llm.stream("I'm Pickle Rick"):
assert isinstance(chunk.content, str)
full = chunk if full is None else full + chunk
@pytest.mark.scheduled
async def test_openai_astream(llm: AzureChatOpenAI) -> None:
"""Test streaming tokens from OpenAI."""
async for token in llm.astream("I'm Pickle Rick"):
assert isinstance(token.content, str)
full: Optional[BaseMessageChunk] = None
async for chunk in llm.astream("I'm Pickle Rick"):
assert isinstance(chunk.content, str)
full = chunk if full is None else full + chunk
@pytest.mark.scheduled
@@ -1,9 +1,15 @@
"""Test ChatOpenAI chat model."""
from typing import Any, Optional
from typing import Any, Optional, cast
import pytest
from langchain_core.callbacks import CallbackManager
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage
from langchain_core.messages import (
AIMessage,
BaseMessage,
BaseMessageChunk,
HumanMessage,
SystemMessage,
)
from langchain_core.outputs import (
ChatGeneration,
ChatResult,
@@ -99,7 +105,7 @@ def test_chat_openai_streaming() -> None:
verbose=True,
)
message = HumanMessage(content="Hello")
response = chat([message])
response = chat.invoke([message])
assert callback_handler.llm_streams > 0
assert isinstance(response, BaseMessage)
@@ -336,16 +342,20 @@ def test_stream() -> None:
"""Test streaming tokens from OpenAI."""
llm = ChatOpenAI()
for token in llm.stream("I'm Pickle Rick"):
assert isinstance(token.content, str)
full: Optional[BaseMessageChunk] = None
for chunk in llm.stream("I'm Pickle Rick"):
assert isinstance(chunk.content, str)
full = chunk if full is None else full + chunk
async def test_astream() -> None:
"""Test streaming tokens from OpenAI."""
llm = ChatOpenAI()
async for token in llm.astream("I'm Pickle Rick"):
assert isinstance(token.content, str)
full: Optional[BaseMessageChunk] = None
async for chunk in llm.astream("I'm Pickle Rick"):
assert isinstance(chunk.content, str)
full = chunk if full is None else full + chunk
async def test_abatch() -> None:
@@ -395,33 +405,33 @@ def test_invoke() -> None:
def test_logprobs() -> None:
llm = ChatOpenAI()
result = llm.generate([[HumanMessage(content="I'm PickleRick")]], logprobs=True)
assert result.generations[0][0].generation_info
assert "content" in result.generations[0][0].generation_info["logprobs"]
result = llm.invoke([HumanMessage(content="I'm PickleRick")], logprobs=True)
assert result.response_metadata
assert "content" in result.response_metadata["logprobs"]
async def test_async_logprobs() -> None:
llm = ChatOpenAI()
result = await llm.agenerate(
[[HumanMessage(content="I'm PickleRick")]], logprobs=True
)
assert result.generations[0][0].generation_info
assert "content" in result.generations[0][0].generation_info["logprobs"]
result = await llm.ainvoke([HumanMessage(content="I'm PickleRick")], logprobs=True)
assert result.response_metadata
assert "content" in result.response_metadata["logprobs"]
def test_logprobs_streaming() -> None:
llm = ChatOpenAI()
result = llm.generate(
[[HumanMessage(content="I'm PickleRick")]], logprobs=True, stream=True
)
assert result.generations[0][0].generation_info
assert "content" in result.generations[0][0].generation_info["logprobs"]
full: Optional[BaseMessageChunk] = None
for chunk in llm.stream("I'm Pickle Rick", logprobs=True):
assert isinstance(chunk.content, str)
full = chunk if full is None else full + chunk
assert cast(BaseMessageChunk, full).response_metadata
assert "content" in cast(BaseMessageChunk, full).response_metadata["logprobs"]
async def test_async_logprobs_streaming() -> None:
llm = ChatOpenAI()
result = await llm.agenerate(
[[HumanMessage(content="I'm PickleRick")]], logprobs=True, stream=True
)
assert result.generations[0][0].generation_info
assert "content" in result.generations[0][0].generation_info["logprobs"]
full: Optional[BaseMessageChunk] = None
async for chunk in llm.astream("I'm Pickle Rick", logprobs=True):
assert isinstance(chunk.content, str)
full = chunk if full is None else full + chunk
assert cast(BaseMessageChunk, full).response_metadata
assert "content" in cast(BaseMessageChunk, full).response_metadata["logprobs"]