mirror of
https://github.com/hwchase17/langchain.git
synced 2025-08-01 09:04:03 +00:00
core[minor], langchain[patch], experimental[patch]: Added missing py.typed
to langchain_core
(#14143)
See PR title. From what I can see, `poetry` will auto-include this. Please let me know if I am missing something here. --------- Co-authored-by: Bagatur <baskaryan@gmail.com>
This commit is contained in:
parent
f7c257553d
commit
24385a00de
0
libs/core/langchain_core/py.typed
Normal file
0
libs/core/langchain_core/py.typed
Normal file
@ -1,5 +1,5 @@
|
||||
import time
|
||||
from typing import Any, Callable, List
|
||||
from typing import Any, Callable, List, cast
|
||||
|
||||
from langchain.prompts.chat import (
|
||||
BaseChatPromptTemplate,
|
||||
@ -68,9 +68,9 @@ class AutoGPTPrompt(BaseChatPromptTemplate, BaseModel): # type: ignore[misc]
|
||||
time_prompt = SystemMessage(
|
||||
content=f"The current time and date is {time.strftime('%c')}"
|
||||
)
|
||||
used_tokens = self.token_counter(base_prompt.content) + self.token_counter(
|
||||
time_prompt.content
|
||||
)
|
||||
used_tokens = self.token_counter(
|
||||
cast(str, base_prompt.content)
|
||||
) + self.token_counter(cast(str, time_prompt.content))
|
||||
memory: VectorStoreRetriever = kwargs["memory"]
|
||||
previous_messages = kwargs["messages"]
|
||||
relevant_docs = memory.get_relevant_documents(str(previous_messages[-10:]))
|
||||
@ -88,7 +88,7 @@ class AutoGPTPrompt(BaseChatPromptTemplate, BaseModel): # type: ignore[misc]
|
||||
f"from your past:\n{relevant_memory}\n\n"
|
||||
)
|
||||
memory_message = SystemMessage(content=content_format)
|
||||
used_tokens += self.token_counter(memory_message.content)
|
||||
used_tokens += self.token_counter(cast(str, memory_message.content))
|
||||
historical_messages: List[BaseMessage] = []
|
||||
for message in previous_messages[-10:][::-1]:
|
||||
message_tokens = self.token_counter(message.content)
|
||||
|
@ -1,7 +1,7 @@
|
||||
"""Generic Wrapper for chat LLMs, with sample implementations
|
||||
for Llama-2-chat, Llama-2-instruct and Vicuna models.
|
||||
"""
|
||||
from typing import Any, List, Optional
|
||||
from typing import Any, List, Optional, cast
|
||||
|
||||
from langchain.callbacks.manager import (
|
||||
AsyncCallbackManagerForLLMRun,
|
||||
@ -90,8 +90,12 @@ class ChatWrapper(BaseChatModel):
|
||||
if self.usr_0_end is None:
|
||||
self.usr_0_end = self.usr_n_end
|
||||
|
||||
prompt_parts.append(self.sys_beg + messages[0].content + self.sys_end)
|
||||
prompt_parts.append(self.usr_0_beg + messages[1].content + self.usr_0_end)
|
||||
prompt_parts.append(
|
||||
self.sys_beg + cast(str, messages[0].content) + self.sys_end
|
||||
)
|
||||
prompt_parts.append(
|
||||
self.usr_0_beg + cast(str, messages[1].content) + self.usr_0_end
|
||||
)
|
||||
|
||||
for ai_message, human_message in zip(messages[2::2], messages[3::2]):
|
||||
if not isinstance(ai_message, AIMessage) or not isinstance(
|
||||
@ -102,8 +106,12 @@ class ChatWrapper(BaseChatModel):
|
||||
"optionally prepended by a system message"
|
||||
)
|
||||
|
||||
prompt_parts.append(self.ai_n_beg + ai_message.content + self.ai_n_end)
|
||||
prompt_parts.append(self.usr_n_beg + human_message.content + self.usr_n_end)
|
||||
prompt_parts.append(
|
||||
self.ai_n_beg + cast(str, ai_message.content) + self.ai_n_end
|
||||
)
|
||||
prompt_parts.append(
|
||||
self.usr_n_beg + cast(str, human_message.content) + self.usr_n_end
|
||||
)
|
||||
|
||||
return "".join(prompt_parts)
|
||||
|
||||
|
@ -1,5 +1,5 @@
|
||||
import uuid
|
||||
from typing import Any, Callable, Optional
|
||||
from typing import Any, Callable, Optional, cast
|
||||
|
||||
from langchain.callbacks.manager import CallbackManagerForChainRun
|
||||
from langchain.schema import AIMessage, HumanMessage
|
||||
@ -54,10 +54,10 @@ class BaseModeration:
|
||||
message = prompt.messages[-1]
|
||||
self.chat_message_index = len(prompt.messages) - 1
|
||||
if isinstance(message, HumanMessage):
|
||||
input_text = message.content
|
||||
input_text = cast(str, message.content)
|
||||
|
||||
if isinstance(message, AIMessage):
|
||||
input_text = message.content
|
||||
input_text = cast(str, message.content)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Invalid input type {type(input_text)}. "
|
||||
|
@ -1,7 +1,7 @@
|
||||
import json
|
||||
from collections import defaultdict
|
||||
from html.parser import HTMLParser
|
||||
from typing import Any, DefaultDict, Dict, List, Optional
|
||||
from typing import Any, DefaultDict, Dict, List, Optional, cast
|
||||
|
||||
from langchain.callbacks.manager import (
|
||||
CallbackManagerForLLMRun,
|
||||
@ -176,7 +176,7 @@ class AnthropicFunctions(BaseChatModel):
|
||||
response = self.model.predict_messages(
|
||||
messages, stop=stop, callbacks=run_manager, **kwargs
|
||||
)
|
||||
completion = response.content
|
||||
completion = cast(str, response.content)
|
||||
if forced:
|
||||
tag_parser = TagParser()
|
||||
|
||||
@ -210,7 +210,7 @@ class AnthropicFunctions(BaseChatModel):
|
||||
message = AIMessage(content=msg, additional_kwargs=kwargs)
|
||||
return ChatResult(generations=[ChatGeneration(message=message)])
|
||||
else:
|
||||
response.content = response.content.strip()
|
||||
response.content = cast(str, response.content).strip()
|
||||
return ChatResult(generations=[ChatGeneration(message=response)])
|
||||
|
||||
@property
|
||||
|
@ -239,7 +239,7 @@ def __getattr__(name: str) -> Any:
|
||||
|
||||
return FewShotPromptTemplate
|
||||
elif name == "Prompt":
|
||||
from langchain_core.prompts import Prompt
|
||||
from langchain.prompts import Prompt
|
||||
|
||||
_warn_on_import(name, replacement="langchain.prompts.Prompt")
|
||||
|
||||
|
@ -1,6 +1,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Dict, Iterator, List, Mapping, Optional
|
||||
from typing import Any, Dict, Iterator, List, Mapping, Optional, cast
|
||||
|
||||
from langchain_core.messages import (
|
||||
AIMessage,
|
||||
@ -33,9 +33,7 @@ def _convert_message_to_dict(message: BaseMessage) -> dict:
|
||||
|
||||
def convert_dict_to_message(_dict: Mapping[str, Any]) -> AIMessage:
|
||||
content = _dict.get("choice", {}).get("message", {}).get("content", "")
|
||||
return AIMessage(
|
||||
content=content,
|
||||
)
|
||||
return AIMessage(content=content)
|
||||
|
||||
|
||||
class VolcEngineMaasChat(BaseChatModel, VolcEngineMaasBase):
|
||||
@ -118,7 +116,7 @@ class VolcEngineMaasChat(BaseChatModel, VolcEngineMaasBase):
|
||||
msg = convert_dict_to_message(res)
|
||||
yield ChatGenerationChunk(message=AIMessageChunk(content=msg.content))
|
||||
if run_manager:
|
||||
run_manager.on_llm_new_token(msg.content)
|
||||
run_manager.on_llm_new_token(cast(str, msg.content))
|
||||
|
||||
def _generate(
|
||||
self,
|
||||
@ -135,7 +133,7 @@ class VolcEngineMaasChat(BaseChatModel, VolcEngineMaasBase):
|
||||
params = self._convert_prompt_msg_params(messages, **kwargs)
|
||||
res = self.client.chat(params)
|
||||
msg = convert_dict_to_message(res)
|
||||
completion = msg.content
|
||||
completion = cast(str, msg.content)
|
||||
|
||||
message = AIMessage(content=completion)
|
||||
return ChatResult(generations=[ChatGeneration(message=message)])
|
||||
|
@ -41,7 +41,7 @@ class ForefrontAI(LLM):
|
||||
repetition_penalty: int = 1
|
||||
"""Penalizes repeated tokens according to frequency."""
|
||||
|
||||
forefrontai_api_key: SecretStr = None
|
||||
forefrontai_api_key: SecretStr
|
||||
|
||||
base_url: Optional[str] = None
|
||||
"""Base url to use, if None decides based on model name."""
|
||||
@ -51,7 +51,7 @@ class ForefrontAI(LLM):
|
||||
|
||||
extra = Extra.forbid
|
||||
|
||||
@root_validator()
|
||||
@root_validator(pre=True)
|
||||
def validate_environment(cls, values: Dict) -> Dict:
|
||||
"""Validate that api key exists in environment."""
|
||||
values["forefrontai_api_key"] = convert_to_secret_str(
|
||||
|
@ -159,6 +159,7 @@ class _VertexAIBase(BaseModel):
|
||||
|
||||
class _VertexAICommon(_VertexAIBase):
|
||||
client: "_LanguageModel" = None #: :meta private:
|
||||
client_preview: "_LanguageModel" = None #: :meta private:
|
||||
model_name: str
|
||||
"Underlying model name."
|
||||
temperature: float = 0.0
|
||||
@ -406,13 +407,16 @@ class VertexAIModelGarden(_VertexAIBase, BaseLLM):
|
||||
values["async_client"] = PredictionServiceAsyncClient(
|
||||
client_options=client_options, client_info=client_info
|
||||
)
|
||||
values["endpoint_path"] = values["client"].endpoint_path(
|
||||
project=values["project"],
|
||||
location=values["location"],
|
||||
endpoint=values["endpoint_id"],
|
||||
)
|
||||
return values
|
||||
|
||||
@property
|
||||
def endpoint_path(self) -> str:
|
||||
return self.client.endpoint_path(
|
||||
project=self.project,
|
||||
location=self.location,
|
||||
endpoint=self.endpoint_id,
|
||||
)
|
||||
|
||||
@property
|
||||
def _llm_type(self) -> str:
|
||||
return "vertexai_model_garden"
|
||||
|
Loading…
Reference in New Issue
Block a user