Compare commits

...

6 Commits

Author         SHA1        Message                    Date
Chester Curme  ca462b71dc  typo                       2025-10-29 11:10:38 -04:00
Chester Curme  33a91fdf9a  update snapshots for xai   2025-10-29 11:10:29 -04:00
Chester Curme  8b05cb4522  fix snapshot               2025-10-29 11:04:07 -04:00
Chester Curme  9982e28aaa  update some snapshots      2025-10-29 10:52:32 -04:00
Chester Curme  e53e91bcb2  update openai              2025-10-29 10:50:25 -04:00
Chester Curme  8797b167f5  update core                2025-10-29 10:41:16 -04:00
10 changed files with 117 additions and 33 deletions

View File

@@ -310,6 +310,13 @@ class BaseChatModel(BaseLanguageModel[AIMessage], ABC):
does not properly support streaming.
"""
model_provider: str | None = None
"""The model provider name, e.g., 'openai', 'anthropic', etc.
Used to assign provenance on messages generated by the model, and to look up
model capabilities (e.g., context window sizes and feature support).
"""
output_version: str | None = Field(
default_factory=from_env("LC_OUTPUT_VERSION", default=None)
)
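
For illustration, a minimal sketch of how the new field surfaces on responses, using GenericFakeChatModel from langchain_core's fake chat models (assumes a langchain-core build that includes this change; the provider string is an arbitrary example):

from langchain_core.language_models.fake_chat_models import GenericFakeChatModel
from langchain_core.messages import AIMessage

# Hedged sketch: model_provider is a plain BaseChatModel field, so any
# subclass (here a fake model) can set it at construction time.
model = GenericFakeChatModel(
    messages=iter([AIMessage(content="hi")]),
    model_provider="provider_foo",  # arbitrary example value
)
response = model.invoke("hello")
# The provider is stamped onto response metadata during generation.
assert response.response_metadata["model_provider"] == "provider_foo"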
@@ -517,7 +524,9 @@ class BaseChatModel(BaseLanguageModel[AIMessage], ABC):
for chunk in self._stream(input_messages, stop=stop, **kwargs):
if chunk.message.id is None:
chunk.message.id = run_id
chunk.message.response_metadata = _gen_info_and_msg_metadata(chunk)
chunk.message.response_metadata = _gen_info_and_msg_metadata(
chunk, model_provider=self.model_provider
)
if self.output_version == "v1":
# Overwrite .content with .content_blocks
chunk.message = _update_message_content_to_blocks(
@@ -649,7 +658,9 @@ class BaseChatModel(BaseLanguageModel[AIMessage], ABC):
):
if chunk.message.id is None:
chunk.message.id = run_id
chunk.message.response_metadata = _gen_info_and_msg_metadata(chunk)
chunk.message.response_metadata = _gen_info_and_msg_metadata(
chunk, model_provider=self.model_provider
)
if self.output_version == "v1":
# Overwrite .content with .content_blocks
chunk.message = _update_message_content_to_blocks(
@@ -1147,7 +1158,9 @@ class BaseChatModel(BaseLanguageModel[AIMessage], ABC):
index = -1
index_type = ""
for chunk in self._stream(messages, stop=stop, **kwargs):
chunk.message.response_metadata = _gen_info_and_msg_metadata(chunk)
chunk.message.response_metadata = _gen_info_and_msg_metadata(
chunk, model_provider=self.model_provider
)
if self.output_version == "v1":
# Overwrite .content with .content_blocks
chunk.message = _update_message_content_to_blocks(
@@ -1195,6 +1208,20 @@ class BaseChatModel(BaseLanguageModel[AIMessage], ABC):
else:
result = self._generate(messages, stop=stop, **kwargs)
# Add response metadata to each generation
for idx, generation in enumerate(result.generations):
if run_manager and generation.message.id is None:
generation.message.id = f"{LC_ID_PREFIX}-{run_manager.run_id}-{idx}"
generation.message.response_metadata = _gen_info_and_msg_metadata(
generation, model_provider=self.model_provider
)
if len(result.generations) == 1 and result.llm_output is not None:
result.generations[0].message.response_metadata = {
**result.llm_output,
**result.generations[0].message.response_metadata,
}
if self.output_version == "v1":
# Overwrite .content with .content_blocks
for generation in result.generations:
@@ -1202,18 +1229,6 @@ class BaseChatModel(BaseLanguageModel[AIMessage], ABC):
generation.message, "v1"
)
# Add response metadata to each generation
for idx, generation in enumerate(result.generations):
if run_manager and generation.message.id is None:
generation.message.id = f"{LC_ID_PREFIX}-{run_manager.run_id}-{idx}"
generation.message.response_metadata = _gen_info_and_msg_metadata(
generation
)
if len(result.generations) == 1 and result.llm_output is not None:
result.generations[0].message.response_metadata = {
**result.llm_output,
**result.generations[0].message.response_metadata,
}
if check_cache and llm_cache:
llm_cache.update(prompt, llm_string, result.generations)
return result
@@ -1265,7 +1280,9 @@ class BaseChatModel(BaseLanguageModel[AIMessage], ABC):
index = -1
index_type = ""
async for chunk in self._astream(messages, stop=stop, **kwargs):
chunk.message.response_metadata = _gen_info_and_msg_metadata(chunk)
chunk.message.response_metadata = _gen_info_and_msg_metadata(
chunk, model_provider=self.model_provider
)
if self.output_version == "v1":
# Overwrite .content with .content_blocks
chunk.message = _update_message_content_to_blocks(
@@ -1313,6 +1330,19 @@ class BaseChatModel(BaseLanguageModel[AIMessage], ABC):
else:
result = await self._agenerate(messages, stop=stop, **kwargs)
# Add response metadata to each generation
for idx, generation in enumerate(result.generations):
if run_manager and generation.message.id is None:
generation.message.id = f"{LC_ID_PREFIX}-{run_manager.run_id}-{idx}"
generation.message.response_metadata = _gen_info_and_msg_metadata(
generation, model_provider=self.model_provider
)
if len(result.generations) == 1 and result.llm_output is not None:
result.generations[0].message.response_metadata = {
**result.llm_output,
**result.generations[0].message.response_metadata,
}
if self.output_version == "v1":
# Overwrite .content with .content_blocks
for generation in result.generations:
@@ -1320,18 +1350,6 @@ class BaseChatModel(BaseLanguageModel[AIMessage], ABC):
generation.message, "v1"
)
# Add response metadata to each generation
for idx, generation in enumerate(result.generations):
if run_manager and generation.message.id is None:
generation.message.id = f"{LC_ID_PREFIX}-{run_manager.run_id}-{idx}"
generation.message.response_metadata = _gen_info_and_msg_metadata(
generation
)
if len(result.generations) == 1 and result.llm_output is not None:
result.generations[0].message.response_metadata = {
**result.llm_output,
**result.generations[0].message.response_metadata,
}
if check_cache and llm_cache:
await llm_cache.aupdate(prompt, llm_string, result.generations)
return result
@@ -1719,11 +1737,15 @@ class SimpleChatModel(BaseChatModel):
def _gen_info_and_msg_metadata(
generation: ChatGeneration | ChatGenerationChunk,
model_provider: str | None = None,
) -> dict:
return {
metadata = {
**(generation.generation_info or {}),
**generation.message.response_metadata,
}
if model_provider and "model_provider" not in metadata:
metadata["model_provider"] = model_provider
return metadata


def _cleanup_llm_representation(serialized: Any, depth: int) -> None:
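
The helper's merge-and-stamp behavior, reimplemented standalone as a sketch (the real function operates on ChatGeneration objects; plain dicts are used here for brevity):

def gen_info_and_msg_metadata(generation_info, message_metadata, model_provider=None):
    # Merge generation_info with the message's own response_metadata,
    # the latter taking precedence, mirroring the diff above.
    metadata = {**(generation_info or {}), **message_metadata}
    # Only stamp the provider when nothing upstream has set one already,
    # so provider-specific integrations keep control of the value.
    if model_provider and "model_provider" not in metadata:
        metadata["model_provider"] = model_provider
    return metadata

assert gen_info_and_msg_metadata({"finish_reason": "stop"}, {}, "openai") == {
    "finish_reason": "stop",
    "model_provider": "openai",
}
# A value already present in the metadata wins over the model-level default.
assert gen_info_and_msg_metadata({}, {"model_provider": "xai"}, "openai") == {
    "model_provider": "xai",
}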

View File

@@ -1217,3 +1217,57 @@ def test_get_ls_params() -> None:
ls_params = llm._get_ls_params(stop=["stop"])
assert ls_params["ls_stop"] == ["stop"]


@pytest.mark.parametrize("output_version", ["v0", "v1"])
def test_model_provider_on_metadata(output_version: str) -> None:
"""Test we assign model_provider to response metadata."""
messages = [AIMessage("hello")]
chunks = [AIMessageChunk(content="good"), AIMessageChunk(content="bye")]
model = _AnotherFakeChatModel(
responses=iter(messages),
chunks=iter(chunks),
output_version=output_version,
model_provider="provider_foo",
)
response = model.invoke("hello")
assert response.response_metadata["model_provider"] == "provider_foo"
response = model.invoke("hello", stream=True)
assert response.response_metadata["model_provider"] == "provider_foo"
model.chunks = iter([AIMessageChunk(content="good"), AIMessageChunk(content="bye")])
full: AIMessageChunk | None = None
for chunk in model.stream("hello"):
full = chunk if full is None else full + chunk
assert full is not None
assert full.response_metadata["model_provider"] == "provider_foo"


@pytest.mark.parametrize("output_version", ["v0", "v1"])
async def test_model_provider_on_metadata_async(output_version: str) -> None:
"""Test we assign model_provider to response metadata."""
messages = [AIMessage("hello")]
chunks = [AIMessageChunk(content="good"), AIMessageChunk(content="bye")]
model = _AnotherFakeChatModel(
responses=iter(messages),
chunks=iter(chunks),
output_version=output_version,
model_provider="provider_foo",
)
response = await model.ainvoke("hello")
assert response.response_metadata["model_provider"] == "provider_foo"
response = await model.ainvoke("hello", stream=True)
assert response.response_metadata["model_provider"] == "provider_foo"
model.chunks = iter([AIMessageChunk(content="good"), AIMessageChunk(content="bye")])
full: AIMessageChunk | None = None
async for chunk in model.astream("hello"):
full = chunk if full is None else full + chunk
assert full is not None
assert full.response_metadata["model_provider"] == "provider_foo"

View File

@@ -135,6 +135,7 @@ def test_configurable() -> None:
"model_name": "gpt-4o",
"temperature": None,
"model_kwargs": {},
"model_provider": "openai",
"openai_api_key": SecretStr("foo"),
"openai_api_base": None,
"openai_organization": None,
@@ -268,6 +269,7 @@ def test_configurable_with_default() -> None:
"betas": None,
"default_headers": None,
"model_kwargs": {},
"model_provider": None,
"streaming": False,
"stream_usage": True,
"output_version": None,

View File

@@ -136,6 +136,7 @@ def test_configurable() -> None:
"model_name": "gpt-4o",
"temperature": None,
"model_kwargs": {},
"model_provider": "openai",
"openai_api_key": SecretStr("foo"),
"openai_api_base": None,
"openai_organization": None,
@@ -268,6 +269,7 @@ def test_configurable_with_default() -> None:
"betas": None,
"default_headers": None,
"model_kwargs": {},
"model_provider": None,
"streaming": False,
"stream_usage": True,
"output_version": None,

View File

@@ -550,6 +550,8 @@ class BaseChatOpenAI(BaseChatModel):
!!! warning "Behavior changed in 0.3.35"
Enabled for default base URL and client.
"""
model_provider: str | None = "openai"
"""The model provider name (openai)."""
max_retries: int | None = None
"""Maximum number of retries to make when generating."""
presence_penalty: float | None = None
@@ -1034,7 +1036,6 @@ class BaseChatOpenAI(BaseChatModel):
if usage_metadata and isinstance(message_chunk, AIMessageChunk):
message_chunk.usage_metadata = usage_metadata
message_chunk.response_metadata["model_provider"] = "openai"
return ChatGenerationChunk(
message=message_chunk, generation_info=generation_info or None
)
@@ -1398,7 +1399,6 @@ class BaseChatOpenAI(BaseChatModel):
generations.append(gen)
llm_output = {
"token_usage": token_usage,
"model_provider": "openai",
"model_name": response_dict.get("model", self.model_name),
"system_fingerprint": response_dict.get("system_fingerprint", ""),
}
@@ -4081,7 +4081,6 @@ def _construct_lc_result_from_responses_api(
if metadata:
response_metadata.update(metadata)
# for compatibility with chat completion calls.
response_metadata["model_provider"] = "openai"
response_metadata["model_name"] = response_metadata.get("model")
if response.usage:
usage_metadata = _create_usage_metadata_responses(
@@ -4277,7 +4276,6 @@ def _convert_responses_chunk_to_generation_chunk(
tool_call_chunks: list = []
additional_kwargs: dict = {}
response_metadata = metadata or {}
response_metadata["model_provider"] = "openai"
usage_metadata = None
chunk_position: Literal["last"] | None = None
id = None
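
Net effect for the OpenAI integration: the provider is now a declarative class-level default on BaseChatOpenAI instead of being stamped in each response-parsing path. A hedged sketch (constructor only, no API request is made; the key is a placeholder):

from langchain_openai import ChatOpenAI

# Sketch assuming a langchain-openai build that includes this change.
llm = ChatOpenAI(model="gpt-4o", api_key="placeholder-key")
assert llm.model_provider == "openai"
# Generated messages then carry response_metadata["model_provider"] == "openai"
# via the base-class _gen_info_and_msg_metadata path shown above.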

View File

@@ -15,6 +15,7 @@
}),
'max_retries': 2,
'max_tokens': 100,
'model_provider': 'openai',
'openai_api_key': dict({
'id': list([
'AZURE_OPENAI_API_KEY',

View File

@@ -11,6 +11,7 @@
'max_retries': 2,
'max_tokens': 100,
'model_name': 'gpt-3.5-turbo',
'model_provider': 'openai',
'openai_api_key': dict({
'id': list([
'OPENAI_API_KEY',

View File

@@ -11,6 +11,7 @@
'max_retries': 2,
'max_tokens': 100,
'model_name': 'gpt-3.5-turbo',
'model_provider': 'openai',
'openai_api_key': dict({
'id': list([
'OPENAI_API_KEY',

View File

@@ -398,6 +398,8 @@ class ChatXAI(BaseChatOpenAI): # type: ignore[override]
model_name: str = Field(default="grok-4", alias="model")
"""Model name to use."""
model_provider: str | None = "xai"
"""The model provider name (xai)."""
xai_api_key: SecretStr | None = Field(
alias="api_key",
default_factory=secret_from_env("XAI_API_KEY", default=None),

View File

@@ -10,6 +10,7 @@
'max_retries': 2,
'max_tokens': 100,
'model_name': 'grok-4',
'model_provider': 'xai',
'request_timeout': 60.0,
'stop': list([
]),