various templates improvements (#12500)

This commit is contained in:
Harrison Chase
2023-10-28 22:13:22 -07:00
committed by GitHub
parent d85d4d7822
commit 9e0ae56287
50 changed files with 462 additions and 282 deletions

View File

@@ -1,26 +1,47 @@
import os
from pathlib import Path
from elasticsearch import Elasticsearch
from langchain.chat_models import ChatOpenAI
from langchain.output_parsers.json import SimpleJsonOutputParser
from langchain.pydantic_v1 import BaseModel
from .elastic_index_info import get_indices_infos
from .prompts import DSL_PROMPT
# Setup Elasticsearch
# This shows how to set it up for a cloud hosted version

# Password for the 'elastic' user generated by Elasticsearch
ELASTIC_PASSWORD = "..."

# Found in the 'Manage Deployment' page
CLOUD_ID = "..."

# Create the client instance.
# The connection is identified by the cloud deployment ID and authenticated as
# the built-in 'elastic' user. A host URL is deliberately not passed as well:
# the client should be given either a host or a cloud_id, not both.
db = Elasticsearch(
    cloud_id=CLOUD_ID,
    basic_auth=("elastic", ELASTIC_PASSWORD),
)
# Specify indices to include
# If you want to use this on your own indices, you will need to change this.
INCLUDE_INDICES = ["customers"]

# With the Elasticsearch connection created, we can now move on to the chain
_model = ChatOpenAI(temperature=0, model="gpt-4")

# The leading mapping builds the prompt variables from the incoming dict; its
# output is piped through the DSL prompt, the chat model, and the JSON parser.
chain = (
    {
        "input": lambda x: x["input"],
        # Only index info for the indices listed in INCLUDE_INDICES is fetched.
        # If you are running this on your own data, you will want to change this.
        "indices_info": lambda _: get_indices_infos(
            db, include_indices=INCLUDE_INDICES
        ),
        # Fall back to 5 when the caller does not supply "top_k".
        "top_k": lambda x: x.get("top_k", 5),
    }
    | DSL_PROMPT
    | _model
    | SimpleJsonOutputParser()
)
# Nicely typed inputs for playground
class ChainInputs(BaseModel):
    # Required; the chain reads this value from the "input" key.
    input: str
    # Optional; the chain reads "top_k" and uses 5 when it is absent.
    top_k: int = 5


# Attach the input schema to the chain so its expected inputs are typed.
chain = chain.with_types(input_type=ChainInputs)

View File

@@ -13,8 +13,18 @@ def _list_indices(database, include_indices=None, ignore_indices=None) -> List[s
return all_indices
def get_indices_infos(database, sample_documents_in_index_info=5) -> str:
indices = _list_indices(database)
def get_indices_infos(
database,
sample_documents_in_index_info=5,
include_indices=None,
ignore_indices=None
) -> str:
indices = _list_indices(
database,
include_indices=include_indices,
ignore_indices=ignore_indices
)
mappings = database.indices.get_mapping(index=",".join(indices))
if sample_documents_in_index_info > 0:
for k, v in mappings.items():