Mirror of https://github.com/hwchase17/langchain.git (synced 2026-01-23 05:09:12 +00:00)
Compare commits: 1 commit, `langchain-` ... `harrison/a` (both ref names truncated in the mirror)

| Author | SHA1 | Date |
|---|---|---|
| | ee8203859c | |

This commit threads `**kwargs` through the chat-model call stack, from the `BaseLanguageModel` abstract surface down to each provider's request parameters, so that per-call options reach the underlying APIs. The file paths labeling the hunk groups below are not preserved in the mirror; they are inferred from the class names in the `@@` headers.
**`langchain/base_language.py`**: `BaseLanguageModel`

```diff
@@ -2,7 +2,7 @@
 from __future__ import annotations
 
 from abc import ABC, abstractmethod
-from typing import List, Optional, Sequence, Set
+from typing import Any, List, Optional, Sequence, Set
 
 from pydantic import BaseModel
 
@@ -36,6 +36,7 @@ class BaseLanguageModel(BaseModel, ABC):
         prompts: List[PromptValue],
         stop: Optional[List[str]] = None,
         callbacks: Callbacks = None,
+        **kwargs: Any,
     ) -> LLMResult:
         """Take in a list of prompt values and return an LLMResult."""
 
@@ -45,26 +46,39 @@ class BaseLanguageModel(BaseModel, ABC):
         prompts: List[PromptValue],
         stop: Optional[List[str]] = None,
         callbacks: Callbacks = None,
+        **kwargs: Any,
     ) -> LLMResult:
         """Take in a list of prompt values and return an LLMResult."""
 
     @abstractmethod
-    def predict(self, text: str, *, stop: Optional[Sequence[str]] = None) -> str:
+    def predict(
+        self, text: str, *, stop: Optional[Sequence[str]] = None, **kwargs: Any
+    ) -> str:
         """Predict text from text."""
 
     @abstractmethod
     def predict_messages(
-        self, messages: List[BaseMessage], *, stop: Optional[Sequence[str]] = None
+        self,
+        messages: List[BaseMessage],
+        *,
+        stop: Optional[Sequence[str]] = None,
+        **kwargs: Any,
     ) -> BaseMessage:
         """Predict message from messages."""
 
     @abstractmethod
-    async def apredict(self, text: str, *, stop: Optional[Sequence[str]] = None) -> str:
+    async def apredict(
+        self, text: str, *, stop: Optional[Sequence[str]] = None, **kwargs: Any
+    ) -> str:
         """Predict text from text."""
 
     @abstractmethod
     async def apredict_messages(
-        self, messages: List[BaseMessage], *, stop: Optional[Sequence[str]] = None
+        self,
+        messages: List[BaseMessage],
+        *,
+        stop: Optional[Sequence[str]] = None,
+        **kwargs: Any,
     ) -> BaseMessage:
         """Predict message from messages."""
```
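With `**kwargs` on the abstract surface, any per-call option a provider understands can ride along with a prediction. A minimal usage sketch; `ChatOpenAI` and `temperature` are illustrative stand-ins, not something this diff prescribes:

```python
# Hedged sketch: ChatOpenAI and temperature are assumptions chosen for
# illustration; any BaseLanguageModel subclass and provider kwarg works
# the same way through the new signature.
from langchain.chat_models import ChatOpenAI

model = ChatOpenAI()

# Before this commit, only `stop` could vary per call:
greeting = model.predict("Say hi", stop=["\n"])

# After it, extra keyword arguments are forwarded to the provider:
deterministic = model.predict("Say hi", stop=["\n"], temperature=0.0)
```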
**`langchain/chat_models/anthropic.py`**: `ChatAnthropic`

```diff
@@ -94,9 +94,10 @@ class ChatAnthropic(BaseChatModel, _AnthropicCommon):
         messages: List[BaseMessage],
         stop: Optional[List[str]] = None,
         run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
     ) -> ChatResult:
         prompt = self._convert_messages_to_prompt(messages)
-        params: Dict[str, Any] = {"prompt": prompt, **self._default_params}
+        params: Dict[str, Any] = {"prompt": prompt, **self._default_params, **kwargs}
         if stop:
             params["stop_sequences"] = stop
@@ -121,9 +122,10 @@ class ChatAnthropic(BaseChatModel, _AnthropicCommon):
         messages: List[BaseMessage],
         stop: Optional[List[str]] = None,
         run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+        **kwargs: Any,
     ) -> ChatResult:
         prompt = self._convert_messages_to_prompt(messages)
-        params: Dict[str, Any] = {"prompt": prompt, **self._default_params}
+        params: Dict[str, Any] = {"prompt": prompt, **self._default_params, **kwargs}
         if stop:
             params["stop_sequences"] = stop
```
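The merge order in `params` is doing real work here: in a dict literal, later sources win, so a per-call kwarg overrides a model-level default, and `stop` is assigned after the merge so it always lands. A small semantics sketch with illustrative names and values:

```python
# Later keys win in a dict literal, so a per-call kwarg overrides a
# model-level default (names/values illustrative of _default_params).
default_params = {"max_tokens_to_sample": 256, "temperature": 1.0}
kwargs = {"temperature": 0.0}  # per-call override

params = {"prompt": "Human: hi\n\nAssistant:", **default_params, **kwargs}
assert params["temperature"] == 0.0

# `stop` is assigned after the merge, so it always takes effect:
params["stop_sequences"] = ["\n\nHuman:"]
```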
**`langchain/chat_models/base.py`**: `BaseChatModel`

```diff
@@ -64,11 +64,13 @@ class BaseChatModel(BaseLanguageModel, ABC):
         messages: List[List[BaseMessage]],
         stop: Optional[List[str]] = None,
         callbacks: Callbacks = None,
+        **kwargs,
     ) -> LLMResult:
         """Top Level call"""
 
         params = self.dict()
         params["stop"] = stop
+        params.update(kwargs)
 
         callback_manager = CallbackManager.configure(
             callbacks, self.callbacks, self.verbose
@@ -82,9 +84,9 @@ class BaseChatModel(BaseLanguageModel, ABC):
         )
         try:
             results = [
-                self._generate(m, stop=stop, run_manager=run_manager)
+                self._generate(m, stop=stop, run_manager=run_manager, **kwargs)
                 if new_arg_supported
-                else self._generate(m, stop=stop)
+                else self._generate(m, stop=stop, **kwargs)
                 for m in messages
             ]
         except (KeyboardInterrupt, Exception) as e:
```
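The hunk doesn't show how `new_arg_supported` is computed; at the time it came from inspecting whether the subclass's `_generate` accepts `run_manager`. A self-contained sketch under that assumption:

```python
import inspect

def supports_run_manager(func) -> bool:
    """Mirror of the assumed `new_arg_supported` check: does the
    callable declare a `run_manager` parameter?"""
    return "run_manager" in inspect.signature(func).parameters

def _generate_legacy(messages, stop=None):  # old-style subclass
    ...

def _generate_current(messages, stop=None, run_manager=None, **kwargs):
    ...

assert not supports_run_manager(_generate_legacy)
assert supports_run_manager(_generate_current)
```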
*`langchain/chat_models/base.py`, continued:*

```diff
@@ -103,10 +105,12 @@ class BaseChatModel(BaseLanguageModel, ABC):
         messages: List[List[BaseMessage]],
         stop: Optional[List[str]] = None,
         callbacks: Callbacks = None,
+        **kwargs: Any
     ) -> LLMResult:
         """Top Level call"""
         params = self.dict()
         params["stop"] = stop
+        params.update(kwargs)
 
         callback_manager = AsyncCallbackManager.configure(
             callbacks, self.callbacks, self.verbose
@@ -121,9 +125,9 @@ class BaseChatModel(BaseLanguageModel, ABC):
         try:
             results = await asyncio.gather(
                 *[
-                    self._agenerate(m, stop=stop, run_manager=run_manager)
+                    self._agenerate(m, stop=stop, run_manager=run_manager, **kwargs)
                     if new_arg_supported
-                    else self._agenerate(m, stop=stop)
+                    else self._agenerate(m, stop=stop, **kwargs)
                     for m in messages
                 ]
             )
@@ -143,18 +147,22 @@ class BaseChatModel(BaseLanguageModel, ABC):
         prompts: List[PromptValue],
         stop: Optional[List[str]] = None,
         callbacks: Callbacks = None,
+        **kwargs
     ) -> LLMResult:
         prompt_messages = [p.to_messages() for p in prompts]
-        return self.generate(prompt_messages, stop=stop, callbacks=callbacks)
+        return self.generate(prompt_messages, stop=stop, callbacks=callbacks, **kwargs)
 
     async def agenerate_prompt(
         self,
         prompts: List[PromptValue],
         stop: Optional[List[str]] = None,
         callbacks: Callbacks = None,
+        **kwargs: Any
     ) -> LLMResult:
         prompt_messages = [p.to_messages() for p in prompts]
-        return await self.agenerate(prompt_messages, stop=stop, callbacks=callbacks)
+        return await self.agenerate(
+            prompt_messages, stop=stop, callbacks=callbacks, **kwargs
+        )
 
     @abstractmethod
     def _generate(
```
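`generate_prompt` flattens each `PromptValue` to messages and forwards everything, kwargs included, to `generate`. A hedged usage sketch assuming the prompt-template API of the same era; the template text and `temperature` are illustrative:

```python
# Hedged sketch; ChatOpenAI, the template text, and temperature are
# illustrative assumptions, not part of the diff.
from langchain.chat_models import ChatOpenAI
from langchain.prompts import ChatPromptTemplate, HumanMessagePromptTemplate

chat = ChatOpenAI()
template = ChatPromptTemplate.from_messages(
    [HumanMessagePromptTemplate.from_template("Summarize: {text}")]
)
prompt_value = template.format_prompt(text="kwargs now flow through")

# to_messages() happens inside generate_prompt; temperature rides along
# to the provider call via **kwargs.
result = chat.generate_prompt([prompt_value], temperature=0.0)
print(result.generations[0][0].text)
```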
*`langchain/chat_models/base.py`, continued:*

```diff
@@ -162,6 +170,7 @@ class BaseChatModel(BaseLanguageModel, ABC):
         messages: List[BaseMessage],
         stop: Optional[List[str]] = None,
         run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any
     ) -> ChatResult:
         """Top Level call"""
 
@@ -171,6 +180,7 @@ class BaseChatModel(BaseLanguageModel, ABC):
         messages: List[BaseMessage],
         stop: Optional[List[str]] = None,
         run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+        **kwargs: Any
     ) -> ChatResult:
         """Top Level call"""
 
@@ -179,9 +189,10 @@ class BaseChatModel(BaseLanguageModel, ABC):
         messages: List[BaseMessage],
         stop: Optional[List[str]] = None,
         callbacks: Callbacks = None,
+        **kwargs: Any
     ) -> BaseMessage:
         generation = self.generate(
-            [messages], stop=stop, callbacks=callbacks
+            [messages], stop=stop, callbacks=callbacks, **kwargs
         ).generations[0][0]
         if isinstance(generation, ChatGeneration):
             return generation.message
@@ -193,50 +204,69 @@ class BaseChatModel(BaseLanguageModel, ABC):
         messages: List[BaseMessage],
         stop: Optional[List[str]] = None,
         callbacks: Callbacks = None,
+        **kwargs: Any
     ) -> BaseMessage:
-        result = await self.agenerate([messages], stop=stop, callbacks=callbacks)
+        result = await self.agenerate(
+            [messages], stop=stop, callbacks=callbacks, **kwargs
+        )
         generation = result.generations[0][0]
         if isinstance(generation, ChatGeneration):
             return generation.message
         else:
             raise ValueError("Unexpected generation type")
 
-    def call_as_llm(self, message: str, stop: Optional[List[str]] = None) -> str:
-        return self.predict(message, stop=stop)
+    def call_as_llm(
+        self, message: str, stop: Optional[List[str]] = None, **kwargs: Any
+    ) -> str:
+        return self.predict(message, stop=stop, **kwargs)
 
-    def predict(self, text: str, *, stop: Optional[Sequence[str]] = None) -> str:
+    def predict(
+        self, text: str, *, stop: Optional[Sequence[str]] = None, **kwargs: Any
+    ) -> str:
         if stop is None:
             _stop = None
         else:
             _stop = list(stop)
-        result = self([HumanMessage(content=text)], stop=_stop)
+        result = self([HumanMessage(content=text)], stop=_stop, **kwargs)
         return result.content
 
     def predict_messages(
-        self, messages: List[BaseMessage], *, stop: Optional[Sequence[str]] = None
+        self,
+        messages: List[BaseMessage],
+        *,
+        stop: Optional[Sequence[str]] = None,
+        **kwargs: Any
     ) -> BaseMessage:
         if stop is None:
             _stop = None
         else:
             _stop = list(stop)
-        return self(messages, stop=_stop)
+        return self(messages, stop=_stop, **kwargs)
 
-    async def apredict(self, text: str, *, stop: Optional[Sequence[str]] = None) -> str:
+    async def apredict(
+        self, text: str, *, stop: Optional[Sequence[str]] = None, **kwargs: Any
+    ) -> str:
         if stop is None:
             _stop = None
         else:
             _stop = list(stop)
-        result = await self._call_async([HumanMessage(content=text)], stop=_stop)
+        result = await self._call_async(
+            [HumanMessage(content=text)], stop=_stop, **kwargs
+        )
         return result.content
 
     async def apredict_messages(
-        self, messages: List[BaseMessage], *, stop: Optional[Sequence[str]] = None
+        self,
+        messages: List[BaseMessage],
+        *,
+        stop: Optional[Sequence[str]] = None,
+        **kwargs: Any
     ) -> BaseMessage:
         if stop is None:
             _stop = None
         else:
             _stop = list(stop)
-        return await self._call_async(messages, stop=_stop)
+        return await self._call_async(messages, stop=_stop, **kwargs)
 
     @property
     def _identifying_params(self) -> Mapping[str, Any]:
```
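`predict` and its siblings accept any `Sequence` for `stop`, copy it to a list, and now forward `**kwargs` through `__call__` or `_call_async`. A hedged sketch of the message-level surface; the model class and keyword values are assumptions:

```python
# Hedged sketch; ChatOpenAI and the keyword values are assumptions.
from langchain.chat_models import ChatOpenAI
from langchain.schema import HumanMessage, SystemMessage

chat = ChatOpenAI()

# `stop` may be any Sequence (here a tuple); predict_messages copies it
# to a list before dispatch, and temperature is forwarded via **kwargs.
reply = chat.predict_messages(
    [
        SystemMessage(content="Answer in one word."),
        HumanMessage(content="What color is the sky?"),
    ],
    stop=(".",),
    temperature=0.0,
)
print(reply.content)
```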
**`langchain/chat_models/base.py`** (continued): `SimpleChatModel`

```diff
@@ -261,8 +291,9 @@ class SimpleChatModel(BaseChatModel):
         messages: List[BaseMessage],
         stop: Optional[List[str]] = None,
         run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any
     ) -> ChatResult:
-        output_str = self._call(messages, stop=stop, run_manager=run_manager)
+        output_str = self._call(messages, stop=stop, run_manager=run_manager, **kwargs)
         message = AIMessage(content=output_str)
         generation = ChatGeneration(message=message)
         return ChatResult(generations=[generation])
@@ -273,6 +304,7 @@ class SimpleChatModel(BaseChatModel):
         messages: List[BaseMessage],
         stop: Optional[List[str]] = None,
         run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any
     ) -> str:
         """Simpler interface."""
 
@@ -281,6 +313,9 @@ class SimpleChatModel(BaseChatModel):
         messages: List[BaseMessage],
         stop: Optional[List[str]] = None,
         run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+        **kwargs: Any
     ) -> ChatResult:
-        func = partial(self._generate, messages, stop=stop, run_manager=run_manager)
+        func = partial(
+            self._generate, messages, stop=stop, run_manager=run_manager, **kwargs
+        )
         return await asyncio.get_event_loop().run_in_executor(None, func)
```
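`SimpleChatModel._agenerate` runs the synchronous `_generate` on a thread, and `run_in_executor` forwards only positional arguments, which is why `stop`, `run_manager`, and the new `**kwargs` must be frozen in with `functools.partial`. A self-contained sketch of the pattern:

```python
import asyncio
from functools import partial

def generate(prompt: str, *, stop=None, temperature: float = 1.0) -> str:
    # Stand-in for a blocking _generate implementation.
    return f"echo({prompt!r}, stop={stop}, temperature={temperature})"

async def agenerate(prompt: str, **kwargs) -> str:
    loop = asyncio.get_event_loop()
    # run_in_executor(executor, func, *args) takes positionals only,
    # so keyword arguments are frozen into the callable first.
    func = partial(generate, prompt, **kwargs)
    return await loop.run_in_executor(None, func)

print(asyncio.run(agenerate("hi", stop=["\n"], temperature=0.0)))
```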
**`langchain/chat_models/google_palm.py`**: `ChatGooglePalm`

```diff
@@ -280,6 +280,7 @@ class ChatGooglePalm(BaseChatModel, BaseModel):
         messages: List[BaseMessage],
         stop: Optional[List[str]] = None,
         run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
     ) -> ChatResult:
         prompt = _messages_to_prompt_dict(messages)
 
@@ -291,6 +292,7 @@ class ChatGooglePalm(BaseChatModel, BaseModel):
             top_p=self.top_p,
             top_k=self.top_k,
             candidate_count=self.n,
+            **kwargs,
         )
 
         return _response_to_result(response, stop)
@@ -300,6 +302,7 @@ class ChatGooglePalm(BaseChatModel, BaseModel):
         messages: List[BaseMessage],
         stop: Optional[List[str]] = None,
         run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+        **kwargs: Any,
     ) -> ChatResult:
         prompt = _messages_to_prompt_dict(messages)
 
@@ -311,6 +314,7 @@ class ChatGooglePalm(BaseChatModel, BaseModel):
             top_p=self.top_p,
             top_k=self.top_k,
             candidate_count=self.n,
+            **kwargs,
         )
 
         return _response_to_result(response, stop)
```
**`langchain/chat_models/openai.py`**: `ChatOpenAI`

```diff
@@ -302,8 +302,9 @@ class ChatOpenAI(BaseChatModel):
         messages: List[BaseMessage],
         stop: Optional[List[str]] = None,
         run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
     ) -> ChatResult:
-        message_dicts, params = self._create_message_dicts(messages, stop)
+        message_dicts, params = self._create_message_dicts(messages, stop, **kwargs)
         if self.streaming:
             inner_completion = ""
             role = "assistant"
@@ -324,13 +325,14 @@ class ChatOpenAI(BaseChatModel):
         return self._create_chat_result(response)
 
     def _create_message_dicts(
-        self, messages: List[BaseMessage], stop: Optional[List[str]]
+        self, messages: List[BaseMessage], stop: Optional[List[str]], **kwargs: Any
     ) -> Tuple[List[Dict[str, Any]], Dict[str, Any]]:
         params = dict(self._invocation_params)
         if stop is not None:
             if "stop" in params:
                 raise ValueError("`stop` found in both the input and default params.")
             params["stop"] = stop
+        params.update(kwargs)
         message_dicts = [_convert_message_to_dict(m) for m in messages]
         return message_dicts, params
```
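`params.update(kwargs)` runs after the defaults are copied and after the duplicate-`stop` guard, so a per-call kwarg wins over a model-level setting (and a `stop` smuggled in through kwargs would bypass the guard entirely). An illustrative merge trace:

```python
# Illustrative values; mirrors the merge order in _create_message_dicts.
invocation_params = {"model": "gpt-3.5-turbo", "temperature": 0.7}

params = dict(invocation_params)      # model-level defaults
params["stop"] = ["\n"]               # explicit stop argument (guarded)
params.update({"temperature": 0.0})   # per-call kwargs applied last

assert params == {"model": "gpt-3.5-turbo", "temperature": 0.0, "stop": ["\n"]}
```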
*`langchain/chat_models/openai.py`, continued:*

```diff
@@ -348,8 +350,9 @@ class ChatOpenAI(BaseChatModel):
         messages: List[BaseMessage],
         stop: Optional[List[str]] = None,
         run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+        **kwargs: Any,
     ) -> ChatResult:
-        message_dicts, params = self._create_message_dicts(messages, stop)
+        message_dicts, params = self._create_message_dicts(messages, stop, **kwargs)
         if self.streaming:
             inner_completion = ""
             role = "assistant"
```
**`langchain/chat_models/promptlayer_openai.py`**: `PromptLayerChatOpenAI`

```diff
@@ -42,6 +42,7 @@ class PromptLayerChatOpenAI(ChatOpenAI):
         messages: List[BaseMessage],
         stop: Optional[List[str]] = None,
         run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any
     ) -> ChatResult:
         """Call ChatOpenAI generate and then call PromptLayer API to log the request."""
         from promptlayer.utils import get_api_key, promptlayer_api_request
@@ -52,7 +53,7 @@ class PromptLayerChatOpenAI(ChatOpenAI):
         message_dicts, params = super()._create_message_dicts(messages, stop)
         for i, generation in enumerate(generated_responses.generations):
             response_dict, params = super()._create_message_dicts(
-                [generation.message], stop
+                [generation.message], stop, **kwargs
             )
             pl_request_id = promptlayer_api_request(
                 "langchain.PromptLayerChatOpenAI",
@@ -79,6 +80,7 @@ class PromptLayerChatOpenAI(ChatOpenAI):
         messages: List[BaseMessage],
         stop: Optional[List[str]] = None,
         run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+        **kwargs: Any
     ) -> ChatResult:
         """Call ChatOpenAI agenerate and then call PromptLayer to log."""
         from promptlayer.utils import get_api_key, promptlayer_api_request_async
@@ -89,7 +91,7 @@ class PromptLayerChatOpenAI(ChatOpenAI):
         message_dicts, params = super()._create_message_dicts(messages, stop)
         for i, generation in enumerate(generated_responses.generations):
             response_dict, params = super()._create_message_dicts(
-                [generation.message], stop
+                [generation.message], stop, **kwargs
             )
             pl_request_id = await promptlayer_api_request_async(
                 "langchain.PromptLayerChatOpenAI.async",
```
**`langchain/chat_models/vertexai.py`**: `ChatVertexAI`

```diff
@@ -1,6 +1,6 @@
 """Wrapper around Google VertexAI chat-based models."""
 from dataclasses import dataclass, field
-from typing import Dict, List, Optional
+from typing import Any, Dict, List, Optional
 
 from pydantic import root_validator
 
@@ -93,6 +93,7 @@ class ChatVertexAI(_VertexAICommon, BaseChatModel):
         messages: List[BaseMessage],
         stop: Optional[List[str]] = None,
         run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
     ) -> ChatResult:
         """Generate next turn in the conversation.
 
@@ -122,7 +123,7 @@ class ChatVertexAI(_VertexAICommon, BaseChatModel):
         chat = self.client.start_chat(context=context, **self._default_params)
         for pair in history.history:
             chat._history.append((pair.question.content, pair.answer.content))
-        response = chat.send_message(question.content, **self._default_params)
+        response = chat.send_message(question.content, **self._default_params, **kwargs)
         text = self._enforce_stop_words(response.text, stop)
         return ChatResult(generations=[ChatGeneration(message=AIMessage(content=text))])
 
@@ -131,6 +132,7 @@ class ChatVertexAI(_VertexAICommon, BaseChatModel):
         messages: List[BaseMessage],
         stop: Optional[List[str]] = None,
         run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+        **kwargs: Any,
     ) -> ChatResult:
         raise NotImplementedError(
             """Vertex AI doesn't support async requests at the moment."""
```
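One subtlety worth noting: unlike the dict-literal merge in `ChatAnthropic`, double-splatting into a call, as `send_message(..., **self._default_params, **kwargs)` does here, raises `TypeError` on a duplicate key instead of letting the later value win. A sketch of both behaviors:

```python
def send_message(prompt: str, **options):
    return options

defaults = {"temperature": 0.2}
override = {"temperature": 0.0}

# Dict literal: the later source wins silently.
assert {**defaults, **override}["temperature"] == 0.0

# Double splat at a call site: a duplicate key raises instead.
try:
    send_message("hi", **defaults, **override)
except TypeError as err:
    print(err)  # ... got multiple values for keyword argument 'temperature'
```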