Mirror of https://github.com/hwchase17/langchain.git
Synced 2026-02-09 02:33:34 +00:00

Compare commits: sr/wrap-mo ... cc/model_p (6 commits)

Commits: ca462b71dc, 33a91fdf9a, 8b05cb4522, 9982e28aaa, e53e91bcb2, 8797b167f5
@@ -310,6 +310,13 @@ class BaseChatModel(BaseLanguageModel[AIMessage], ABC):
     does not properly support streaming.
     """
 
+    model_provider: str | None = None
+    """The model provider name, e.g., 'openai', 'anthropic', etc.
+
+    Used to assign provenance on messages generated by the model, and to look up
+    model capabilities (e.g., context window sizes and feature support).
+    """
+
     output_version: str | None = Field(
         default_factory=from_env("LC_OUTPUT_VERSION", default=None)
     )
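The new field is declarative; its value is surfaced on every returned message through the metadata plumbing changed below. A minimal usage sketch of the intended effect, assuming `ChatOpenAI` (which this diff gives a default of `"openai"`) and an `OPENAI_API_KEY` in the environment; the model name is only illustrative:

from langchain_openai import ChatOpenAI

llm = ChatOpenAI(model="gpt-4o-mini")  # illustrative model name
response = llm.invoke("Hello")
# Provenance now travels on the message itself rather than on the class:
print(response.response_metadata.get("model_provider"))  # expected: "openai"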
@@ -517,7 +524,9 @@ class BaseChatModel(BaseLanguageModel[AIMessage], ABC):
         for chunk in self._stream(input_messages, stop=stop, **kwargs):
             if chunk.message.id is None:
                 chunk.message.id = run_id
-            chunk.message.response_metadata = _gen_info_and_msg_metadata(chunk)
+            chunk.message.response_metadata = _gen_info_and_msg_metadata(
+                chunk, model_provider=self.model_provider
+            )
             if self.output_version == "v1":
                 # Overwrite .content with .content_blocks
                 chunk.message = _update_message_content_to_blocks(
@@ -649,7 +658,9 @@ class BaseChatModel(BaseLanguageModel[AIMessage], ABC):
         ):
             if chunk.message.id is None:
                 chunk.message.id = run_id
-            chunk.message.response_metadata = _gen_info_and_msg_metadata(chunk)
+            chunk.message.response_metadata = _gen_info_and_msg_metadata(
+                chunk, model_provider=self.model_provider
+            )
             if self.output_version == "v1":
                 # Overwrite .content with .content_blocks
                 chunk.message = _update_message_content_to_blocks(
@@ -1147,7 +1158,9 @@ class BaseChatModel(BaseLanguageModel[AIMessage], ABC):
         index = -1
         index_type = ""
         for chunk in self._stream(messages, stop=stop, **kwargs):
-            chunk.message.response_metadata = _gen_info_and_msg_metadata(chunk)
+            chunk.message.response_metadata = _gen_info_and_msg_metadata(
+                chunk, model_provider=self.model_provider
+            )
             if self.output_version == "v1":
                 # Overwrite .content with .content_blocks
                 chunk.message = _update_message_content_to_blocks(
@@ -1195,6 +1208,20 @@ class BaseChatModel(BaseLanguageModel[AIMessage], ABC):
         else:
             result = self._generate(messages, stop=stop, **kwargs)
 
+        # Add response metadata to each generation
+        for idx, generation in enumerate(result.generations):
+            if run_manager and generation.message.id is None:
+                generation.message.id = f"{LC_ID_PREFIX}-{run_manager.run_id}-{idx}"
+            generation.message.response_metadata = _gen_info_and_msg_metadata(
+                generation, model_provider=self.model_provider
+            )
+
+        if len(result.generations) == 1 and result.llm_output is not None:
+            result.generations[0].message.response_metadata = {
+                **result.llm_output,
+                **result.generations[0].message.response_metadata,
+            }
+
         if self.output_version == "v1":
             # Overwrite .content with .content_blocks
             for generation in result.generations:
@@ -1202,18 +1229,6 @@ class BaseChatModel(BaseLanguageModel[AIMessage], ABC):
                     generation.message, "v1"
                 )
 
-        # Add response metadata to each generation
-        for idx, generation in enumerate(result.generations):
-            if run_manager and generation.message.id is None:
-                generation.message.id = f"{LC_ID_PREFIX}-{run_manager.run_id}-{idx}"
-            generation.message.response_metadata = _gen_info_and_msg_metadata(
-                generation
-            )
-        if len(result.generations) == 1 and result.llm_output is not None:
-            result.generations[0].message.response_metadata = {
-                **result.llm_output,
-                **result.generations[0].message.response_metadata,
-            }
         if check_cache and llm_cache:
             llm_cache.update(prompt, llm_string, result.generations)
         return result
@@ -1265,7 +1280,9 @@ class BaseChatModel(BaseLanguageModel[AIMessage], ABC):
         index = -1
         index_type = ""
         async for chunk in self._astream(messages, stop=stop, **kwargs):
-            chunk.message.response_metadata = _gen_info_and_msg_metadata(chunk)
+            chunk.message.response_metadata = _gen_info_and_msg_metadata(
+                chunk, model_provider=self.model_provider
+            )
             if self.output_version == "v1":
                 # Overwrite .content with .content_blocks
                 chunk.message = _update_message_content_to_blocks(
@@ -1313,6 +1330,19 @@ class BaseChatModel(BaseLanguageModel[AIMessage], ABC):
         else:
             result = await self._agenerate(messages, stop=stop, **kwargs)
 
+        # Add response metadata to each generation
+        for idx, generation in enumerate(result.generations):
+            if run_manager and generation.message.id is None:
+                generation.message.id = f"{LC_ID_PREFIX}-{run_manager.run_id}-{idx}"
+            generation.message.response_metadata = _gen_info_and_msg_metadata(
+                generation, model_provider=self.model_provider
+            )
+        if len(result.generations) == 1 and result.llm_output is not None:
+            result.generations[0].message.response_metadata = {
+                **result.llm_output,
+                **result.generations[0].message.response_metadata,
+            }
+
         if self.output_version == "v1":
             # Overwrite .content with .content_blocks
             for generation in result.generations:
@@ -1320,18 +1350,6 @@ class BaseChatModel(BaseLanguageModel[AIMessage], ABC):
                     generation.message, "v1"
                 )
 
-        # Add response metadata to each generation
-        for idx, generation in enumerate(result.generations):
-            if run_manager and generation.message.id is None:
-                generation.message.id = f"{LC_ID_PREFIX}-{run_manager.run_id}-{idx}"
-            generation.message.response_metadata = _gen_info_and_msg_metadata(
-                generation
-            )
-        if len(result.generations) == 1 and result.llm_output is not None:
-            result.generations[0].message.response_metadata = {
-                **result.llm_output,
-                **result.generations[0].message.response_metadata,
-            }
         if check_cache and llm_cache:
             await llm_cache.aupdate(prompt, llm_string, result.generations)
         return result
@@ -1719,11 +1737,15 @@ class SimpleChatModel(BaseChatModel):
 
 def _gen_info_and_msg_metadata(
     generation: ChatGeneration | ChatGenerationChunk,
+    model_provider: str | None = None,
 ) -> dict:
-    return {
+    metadata = {
         **(generation.generation_info or {}),
         **generation.message.response_metadata,
     }
+    if model_provider and "model_provider" not in metadata:
+        metadata["model_provider"] = model_provider
+    return metadata
 
 
 def _cleanup_llm_representation(serialized: Any, depth: int) -> None:
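The reworked helper treats the class-level provider only as a fallback: any `model_provider` already present in the merged metadata (for example one stamped by a provider integration) is kept. A standalone sketch of that precedence, with plain dicts standing in for `generation_info` and `response_metadata` (the function name here is hypothetical):

def merge_metadata(
    generation_info: dict | None,
    response_metadata: dict,
    model_provider: str | None = None,
) -> dict:
    # Same merge order as _gen_info_and_msg_metadata above: generation_info,
    # then message metadata, then the class-level provider as a last resort.
    metadata = {**(generation_info or {}), **response_metadata}
    if model_provider and "model_provider" not in metadata:
        metadata["model_provider"] = model_provider
    return metadata


assert merge_metadata(None, {}, "anthropic")["model_provider"] == "anthropic"
# A value the integration already set wins over the class-level fallback:
assert merge_metadata(None, {"model_provider": "openai"}, "xai")["model_provider"] == "openai"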
@@ -1217,3 +1217,57 @@ def test_get_ls_params() -> None:
 
     ls_params = llm._get_ls_params(stop=["stop"])
     assert ls_params["ls_stop"] == ["stop"]
+
+
+@pytest.mark.parametrize("output_version", ["v0", "v1"])
+def test_model_provider_on_metadata(output_version: str) -> None:
+    """Test we assign model_provider to response metadata."""
+    messages = [AIMessage("hello")]
+    chunks = [AIMessageChunk(content="good"), AIMessageChunk(content="bye")]
+
+    model = _AnotherFakeChatModel(
+        responses=iter(messages),
+        chunks=iter(chunks),
+        output_version=output_version,
+        model_provider="provider_foo",
+    )
+
+    response = model.invoke("hello")
+    assert response.response_metadata["model_provider"] == "provider_foo"
+
+    response = model.invoke("hello", stream=True)
+    assert response.response_metadata["model_provider"] == "provider_foo"
+
+    model.chunks = iter([AIMessageChunk(content="good"), AIMessageChunk(content="bye")])
+    full: AIMessageChunk | None = None
+    for chunk in model.stream("hello"):
+        full = chunk if full is None else full + chunk
+    assert full is not None
+    assert full.response_metadata["model_provider"] == "provider_foo"
+
+
+@pytest.mark.parametrize("output_version", ["v0", "v1"])
+async def test_model_provider_on_metadata_async(output_version: str) -> None:
+    """Test we assign model_provider to response metadata."""
+    messages = [AIMessage("hello")]
+    chunks = [AIMessageChunk(content="good"), AIMessageChunk(content="bye")]
+
+    model = _AnotherFakeChatModel(
+        responses=iter(messages),
+        chunks=iter(chunks),
+        output_version=output_version,
+        model_provider="provider_foo",
+    )
+
+    response = await model.ainvoke("hello")
+    assert response.response_metadata["model_provider"] == "provider_foo"
+
+    response = await model.ainvoke("hello", stream=True)
+    assert response.response_metadata["model_provider"] == "provider_foo"
+
+    model.chunks = iter([AIMessageChunk(content="good"), AIMessageChunk(content="bye")])
+    full: AIMessageChunk | None = None
+    async for chunk in model.astream("hello"):
+        full = chunk if full is None else full + chunk
+    assert full is not None
+    assert full.response_metadata["model_provider"] == "provider_foo"
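The fixture `_AnotherFakeChatModel` is defined elsewhere in the test module and is not part of this hunk. A hypothetical minimal stand-in with the same constructor surface, sketched only so the tests above read in isolation (field names and replay behavior are assumptions):

from collections.abc import Iterator
from typing import Any

from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.messages import AIMessage, AIMessageChunk, BaseMessage
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult


class _AnotherFakeChatModel(BaseChatModel):
    """Hypothetical fake model: replays canned responses and chunks."""

    responses: Iterator[AIMessage]
    chunks: Iterator[AIMessageChunk]

    @property
    def _llm_type(self) -> str:
        return "another-fake-chat-model"

    def _generate(
        self,
        messages: list[BaseMessage],
        stop: list[str] | None = None,
        run_manager: CallbackManagerForLLMRun | None = None,
        **kwargs: Any,
    ) -> ChatResult:
        # invoke() returns the next canned AIMessage.
        return ChatResult(generations=[ChatGeneration(message=next(self.responses))])

    def _stream(
        self,
        messages: list[BaseMessage],
        stop: list[str] | None = None,
        run_manager: CallbackManagerForLLMRun | None = None,
        **kwargs: Any,
    ) -> Iterator[ChatGenerationChunk]:
        # stream() replays the canned chunks one by one.
        for chunk in self.chunks:
            yield ChatGenerationChunk(message=chunk)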
@@ -135,6 +135,7 @@ def test_configurable() -> None:
         "model_name": "gpt-4o",
         "temperature": None,
         "model_kwargs": {},
+        "model_provider": "openai",
         "openai_api_key": SecretStr("foo"),
         "openai_api_base": None,
         "openai_organization": None,
@@ -268,6 +269,7 @@ def test_configurable_with_default() -> None:
         "betas": None,
         "default_headers": None,
         "model_kwargs": {},
+        "model_provider": None,
         "streaming": False,
         "stream_usage": True,
         "output_version": None,
@@ -136,6 +136,7 @@ def test_configurable() -> None:
         "model_name": "gpt-4o",
         "temperature": None,
         "model_kwargs": {},
+        "model_provider": "openai",
         "openai_api_key": SecretStr("foo"),
         "openai_api_base": None,
         "openai_organization": None,
@@ -268,6 +269,7 @@ def test_configurable_with_default() -> None:
         "betas": None,
         "default_headers": None,
         "model_kwargs": {},
+        "model_provider": None,
         "streaming": False,
         "stream_usage": True,
         "output_version": None,
@@ -550,6 +550,8 @@ class BaseChatOpenAI(BaseChatModel):
     !!! warning "Behavior changed in 0.3.35"
         Enabled for default base URL and client.
     """
+    model_provider: str | None = "openai"
+    """The model provider name (openai)."""
     max_retries: int | None = None
     """Maximum number of retries to make when generating."""
     presence_penalty: float | None = None
@@ -1034,7 +1036,6 @@ class BaseChatOpenAI(BaseChatModel):
         if usage_metadata and isinstance(message_chunk, AIMessageChunk):
             message_chunk.usage_metadata = usage_metadata
 
-        message_chunk.response_metadata["model_provider"] = "openai"
         return ChatGenerationChunk(
             message=message_chunk, generation_info=generation_info or None
         )
@@ -1398,7 +1399,6 @@ class BaseChatOpenAI(BaseChatModel):
             generations.append(gen)
         llm_output = {
             "token_usage": token_usage,
-            "model_provider": "openai",
             "model_name": response_dict.get("model", self.model_name),
             "system_fingerprint": response_dict.get("system_fingerprint", ""),
         }
@@ -4081,7 +4081,6 @@ def _construct_lc_result_from_responses_api(
     if metadata:
         response_metadata.update(metadata)
     # for compatibility with chat completion calls.
-    response_metadata["model_provider"] = "openai"
     response_metadata["model_name"] = response_metadata.get("model")
     if response.usage:
         usage_metadata = _create_usage_metadata_responses(
@@ -4277,7 +4276,6 @@ def _convert_responses_chunk_to_generation_chunk(
     tool_call_chunks: list = []
     additional_kwargs: dict = {}
     response_metadata = metadata or {}
-    response_metadata["model_provider"] = "openai"
     usage_metadata = None
    chunk_position: Literal["last"] | None = None
    id = None
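The four removals above drop the per-call-site `"model_provider": "openai"` stamping; the class-level default added to `BaseChatOpenAI`, together with the base-class helper, now covers the chat-completions and Responses API paths alike. A hedged sketch of the same opt-in for any other integration (the class and provider name are placeholders, not part of this diff):

from typing import Any

from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.messages import AIMessage, BaseMessage
from langchain_core.outputs import ChatGeneration, ChatResult


class ChatAcme(BaseChatModel):
    """Hypothetical integration: declare the provider once at the class level."""

    model_provider: str | None = "acme"

    @property
    def _llm_type(self) -> str:
        return "chat-acme"

    def _generate(
        self,
        messages: list[BaseMessage],
        stop: list[str] | None = None,
        run_manager: CallbackManagerForLLMRun | None = None,
        **kwargs: Any,
    ) -> ChatResult:
        # No manual response_metadata["model_provider"] assignment needed;
        # BaseChatModel adds it after _generate returns.
        return ChatResult(generations=[ChatGeneration(message=AIMessage("ok"))])


assert ChatAcme().invoke("hi").response_metadata["model_provider"] == "acme"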
@@ -15,6 +15,7 @@
     }),
     'max_retries': 2,
     'max_tokens': 100,
+    'model_provider': 'openai',
     'openai_api_key': dict({
       'id': list([
         'AZURE_OPENAI_API_KEY',
@@ -11,6 +11,7 @@
     'max_retries': 2,
     'max_tokens': 100,
     'model_name': 'gpt-3.5-turbo',
+    'model_provider': 'openai',
     'openai_api_key': dict({
       'id': list([
         'OPENAI_API_KEY',
@@ -11,6 +11,7 @@
     'max_retries': 2,
     'max_tokens': 100,
     'model_name': 'gpt-3.5-turbo',
+    'model_provider': 'openai',
     'openai_api_key': dict({
       'id': list([
         'OPENAI_API_KEY',
@@ -398,6 +398,8 @@ class ChatXAI(BaseChatOpenAI):  # type: ignore[override]
 
     model_name: str = Field(default="grok-4", alias="model")
     """Model name to use."""
+    model_provider: str | None = "xai"
+    """The model provider name (xai)."""
     xai_api_key: SecretStr | None = Field(
         alias="api_key",
         default_factory=secret_from_env("XAI_API_KEY", default=None),
@@ -10,6 +10,7 @@
     'max_retries': 2,
     'max_tokens': 100,
     'model_name': 'grok-4',
+    'model_provider': 'xai',
     'request_timeout': 60.0,
     'stop': list([
     ]),