Mirror of https://github.com/imartinez/privateGPT.git
Commit: Hybrid search
.github/workflows/actions/install_dependencies/action.yml (new file, 30 lines)
@@ -0,0 +1,30 @@
name: "Install Dependencies"
description: "Action to build the project dependencies from the main versions"
inputs:
  python_version:
    required: true
    type: string
    default: "3.11.4"
  poetry_version:
    required: true
    type: string
    default: "1.5.1"

runs:
  using: composite
  steps:
    - name: Install Poetry
      uses: snok/install-poetry@v1
      with:
        version: ${{ inputs.poetry_version }}
        virtualenvs-create: true
        virtualenvs-in-project: false
        installer-parallel: true
    - uses: actions/setup-python@v4
      with:
        python-version: ${{ inputs.python_version }}
        cache: "poetry"
    - name: Install Dependencies
      run: poetry install --extras "ui vector-stores-qdrant" --no-root
      shell: bash
.github/workflows/docker.yml (new file, 45 lines)
@@ -0,0 +1,45 @@
name: docker

on:
  release:
    types: [ published ]
  workflow_dispatch:

env:
  REGISTRY: ghcr.io
  IMAGE_NAME: ${{ github.repository }}

jobs:
  build-and-push-image:
    runs-on: ubuntu-latest
    permissions:
      contents: read
      packages: write
    steps:
      - name: Checkout repository
        uses: actions/checkout@v4
      - name: Log in to the Container registry
        uses: docker/login-action@v3
        with:
          registry: ${{ env.REGISTRY }}
          username: ${{ github.actor }}
          password: ${{ secrets.GITHUB_TOKEN }}
      - name: Extract metadata (tags, labels) for Docker
        id: meta
        uses: docker/metadata-action@v5
        with:
          images: ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}
          tags: |
            type=ref,event=branch
            type=ref,event=pr
            type=semver,pattern={{version}}
            type=semver,pattern={{major}}.{{minor}}
            type=sha
      - name: Build and push Docker image
        uses: docker/build-push-action@v5
        with:
          context: .
          file: Dockerfile.external
          push: true
          tags: ${{ steps.meta.outputs.tags }}
          labels: ${{ steps.meta.outputs.labels }}
.github/workflows/fern-check.yml (new file, 21 lines)
@@ -0,0 +1,21 @@
name: fern check

on:
  pull_request:
    branches:
      - main
    paths:
      - "fern/**"

jobs:
  fern-check:
    runs-on: ubuntu-latest
    steps:
      - name: Checkout repo
        uses: actions/checkout@v4

      - name: Install Fern
        run: npm install -g fern-api

      - name: Check Fern API is valid
        run: fern check
.github/workflows/preview-docs.yml (new file, 50 lines)
@@ -0,0 +1,50 @@
name: deploy preview docs

on:
  pull_request_target:
    branches:
      - main
    paths:
      - "fern/**"

jobs:
  preview-docs:
    runs-on: ubuntu-latest

    steps:
      - name: Checkout repository
        uses: actions/checkout@v4
        with:
          ref: refs/pull/${{ github.event.pull_request.number }}/merge

      - name: Setup Node.js
        uses: actions/setup-node@v4
        with:
          node-version: "18"

      - name: Install Fern
        run: npm install -g fern-api

      - name: Generate Documentation Preview with Fern
        id: generate_docs
        env:
          FERN_TOKEN: ${{ secrets.FERN_TOKEN }}
        run: |
          output=$(fern generate --docs --preview --log-level debug)
          echo "$output"
          # Extract the URL
          preview_url=$(echo "$output" | grep -oP '(?<=Published docs to )https://[^\s]*')
          # Set the output for the step
          echo "::set-output name=preview_url::$preview_url"

      - name: Comment PR with URL using github-actions bot
        uses: actions/github-script@v4
        if: ${{ steps.generate_docs.outputs.preview_url }}
        with:
          script: |
            const preview_url = '${{ steps.generate_docs.outputs.preview_url }}';
            const issue_number = context.issue.number;
            github.issues.createComment({
              ...context.repo,
              issue_number: issue_number,
              body: `Published docs preview URL: ${preview_url}`
            })
.github/workflows/publish-docs.yml (new file, 26 lines)
@@ -0,0 +1,26 @@
name: publish docs

on:
  push:
    branches:
      - main
    paths:
      - "fern/**"

jobs:
  publish-docs:
    runs-on: ubuntu-latest
    steps:
      - name: Checkout repo
        uses: actions/checkout@v4

      - name: Setup node
        uses: actions/setup-node@v3

      - name: Download Fern
        run: npm install -g fern-api

      - name: Generate and Publish Docs
        env:
          FERN_TOKEN: ${{ secrets.FERN_TOKEN }}
        run: fern generate --docs --log-level debug
.github/workflows/release-please.yml (new file, 19 lines)
@@ -0,0 +1,19 @@
name: release-please

on:
  push:
    branches:
      - main

permissions:
  contents: write
  pull-requests: write

jobs:
  release-please:
    runs-on: ubuntu-latest
    steps:
      - uses: google-github-actions/release-please-action@v3
        with:
          release-type: simple
          version-file: version.txt
.github/workflows/stale.yml (new file, 30 lines)
@@ -0,0 +1,30 @@
# This workflow warns and then closes issues and PRs that have had no activity for a specified amount of time.
#
# You can adjust the behavior by modifying this file.
# For more information, see:
# https://github.com/actions/stale
name: Mark stale issues and pull requests

on:
  schedule:
    - cron: '42 5 * * *'

jobs:
  stale:
    runs-on: ubuntu-latest
    permissions:
      issues: write
      pull-requests: write

    steps:
      - uses: actions/stale@v8
        with:
          repo-token: ${{ secrets.GITHUB_TOKEN }}
          days-before-stale: 15
          stale-issue-message: 'Stale issue'
          stale-pr-message: 'Stale pull request'
          stale-issue-label: 'stale'
          stale-pr-label: 'stale'
          exempt-issue-labels: 'autorelease: pending'
          exempt-pr-labels: 'autorelease: pending'
.github/workflows/tests.yml (new file, 67 lines)
@@ -0,0 +1,67 @@
name: tests

on:
  push:
    branches:
      - main
  pull_request:

concurrency:
  group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.head_ref || github.ref }}
  cancel-in-progress: ${{ github.event_name == 'pull_request' }}

jobs:
  setup:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v3
      - uses: ./.github/workflows/actions/install_dependencies

  checks:
    needs: setup
    runs-on: ubuntu-latest
    name: ${{ matrix.quality-command }}
    strategy:
      matrix:
        quality-command:
          - black
          - ruff
          - mypy
    steps:
      - uses: actions/checkout@v3
      - uses: ./.github/workflows/actions/install_dependencies
      - name: run ${{ matrix.quality-command }}
        run: make ${{ matrix.quality-command }}

  test:
    needs: setup
    runs-on: ubuntu-latest
    name: test
    steps:
      - uses: actions/checkout@v3
      - uses: ./.github/workflows/actions/install_dependencies
      - name: run test
        run: make test-coverage
      # Run even if make test fails for coverage reports
      # TODO: select a better xml results displayer
      - name: Archive test results coverage results
        uses: actions/upload-artifact@v3
        if: always()
        with:
          name: test_results
          path: tests-results.xml
      - name: Archive code coverage results
        uses: actions/upload-artifact@v3
        if: always()
        with:
          name: code-coverage-report
          path: htmlcov/

  all_checks_passed:
    # Used to easily force requirements checks in GitHub
    needs:
      - checks
      - test
    runs-on: ubuntu-latest
    steps:
      - run: echo "All checks passed"
@@ -0,0 +1,32 @@
"""Update is_enabled to false by default

Revision ID: 14281ff34686
Revises: b7b896502e8e
Create Date: 2024-03-18 16:33:43.133458

"""
from typing import Sequence, Union

from alembic import op
import sqlalchemy as sa


# revision identifiers, used by Alembic.
revision: str = '14281ff34686'
down_revision: Union[str, None] = 'b7b896502e8e'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade() -> None:
    # ### commands auto generated by Alembic - please adjust! ###
    # op.create_unique_constraint('unique_user_role', 'user_roles', ['user_id', 'role_id', 'company_id'])
    # ### end Alembic commands ###
    pass


def downgrade() -> None:
    # ### commands auto generated by Alembic - please adjust! ###
    # op.drop_constraint('unique_user_role', 'user_roles', type_='unique')
    pass
    # ### end Alembic commands ###
alembic/versions/b7b896502e8e_update.py (new file, 38 lines)
@@ -0,0 +1,38 @@
"""update

Revision ID: b7b896502e8e
Revises:
Create Date: 2024-03-17 15:07:10.795935

"""
from typing import Sequence, Union

from alembic import op
import sqlalchemy as sa


# revision identifiers, used by Alembic.
revision: str = 'b7b896502e8e'
down_revision: Union[str, None] = None
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade() -> None:
    # ### commands auto generated by Alembic - please adjust! ###
    op.drop_constraint('document_department_association_department_id_fkey', 'document_department_association', type_='foreignkey')
    op.drop_constraint('document_department_association_document_id_fkey', 'document_department_association', type_='foreignkey')
    op.create_foreign_key(None, 'document_department_association', 'document', ['document_id'], ['id'], onupdate='CASCADE', ondelete='CASCADE')
    op.create_foreign_key(None, 'document_department_association', 'departments', ['department_id'], ['id'], onupdate='CASCADE', ondelete='CASCADE')
    # op.create_unique_constraint('unique_user_role', 'user_roles', ['user_id', 'role_id', 'company_id'])
    # ### end Alembic commands ###


def downgrade() -> None:
    # ### commands auto generated by Alembic - please adjust! ###
    # op.drop_constraint('unique_user_role', 'user_roles', type_='unique')
    op.drop_constraint(None, 'document_department_association', type_='foreignkey')
    op.drop_constraint(None, 'document_department_association', type_='foreignkey')
    op.create_foreign_key('document_department_association_document_id_fkey', 'document_department_association', 'document', ['document_id'], ['id'])
    op.create_foreign_key('document_department_association_department_id_fkey', 'document_department_association', 'departments', ['department_id'], ['id'])
    # ### end Alembic commands ###
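These two revisions form a linear chain (b7b896502e8e first, then 14281ff34686). A minimal sketch of applying them programmatically follows, assuming an alembic.ini at the project root that points at the alembic/ directory and the target database; the file path is an assumption, not part of the commit.

from alembic import command
from alembic.config import Config

# Load the Alembic configuration and upgrade the schema to the latest revision
alembic_cfg = Config("alembic.ini")  # assumed location of the Alembic config
command.upgrade(alembic_cfg, "head")

# Roll back a single revision if needed
# command.downgrade(alembic_cfg, "-1")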
@@ -11,7 +11,8 @@ from typing import Any
 from llama_index.core.data_structs import IndexDict
 from llama_index.core.embeddings.utils import EmbedType
-from llama_index.core.indices import VectorStoreIndex, load_index_from_storage
+from llama_index.core.indices import VectorStoreIndex, load_index_from_storage, SimpleKeywordTableIndex
+from private_gpt.utils.vector_store import VectorStoreIndex1
 from llama_index.core.indices.base import BaseIndex
 from llama_index.core.ingestion import run_transformations
 from llama_index.core.schema import BaseNode, Document, TransformComponent
@@ -83,7 +84,7 @@ class BaseIngestComponentWithIndex(BaseIngestComponent, abc.ABC):
         except ValueError:
             # There are no index in the storage context, creating a new one
             logger.info("Creating a new vector store index")
-            index = VectorStoreIndex.from_documents(
+            index = VectorStoreIndex1.from_documents(
                 [],
                 storage_context=self.storage_context,
                 store_nodes_override=True,  # Force store nodes in index and document stores
@@ -92,6 +93,16 @@ class BaseIngestComponentWithIndex(BaseIngestComponent, abc.ABC):
                 transformations=self.transformations,
             )
             index.storage_context.persist(persist_dir=local_data_path)
+
+            # Build a keyword index over the same storage context for hybrid search
+            keyword_index = SimpleKeywordTableIndex.from_documents(
+                [],
+                storage_context=self.storage_context,
+                store_nodes_override=True,  # Force store nodes in index and document stores
+                show_progress=self.show_progress,
+                transformations=self.transformations,
+            )
+            # Store the keyword index in the vector store
+            index.keyword_index = keyword_index
         return index
 
     def _save_index(self) -> None:
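With the vector index and the keyword index persisted into one shared storage context, reloading them later needs the plural loader rather than load_index_from_storage, which raises when more than one index is stored. A minimal sketch, assuming the same local_data_path persist directory used above; this is illustrative and not part of the commit.

from llama_index.core import StorageContext
from llama_index.core.indices import load_indices_from_storage

from private_gpt.paths import local_data_path

# Reload every index that was persisted into the shared storage context
storage_context = StorageContext.from_defaults(persist_dir=str(local_data_path))
indices = load_indices_from_storage(storage_context)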
private_gpt/components/vector_store/retriever.py (new file, 55 lines)
@@ -0,0 +1,55 @@
# import QueryBundle
from llama_index.core import QueryBundle

# import NodeWithScore
from llama_index.core.schema import NodeWithScore

# Retrievers
from llama_index.core.retrievers import (
    BaseRetriever,
    VectorIndexRetriever,
    KeywordTableSimpleRetriever,
)

from typing import List


class CustomRetriever(BaseRetriever):
    """Custom retriever that combines semantic (vector) search and keyword search for hybrid retrieval."""

    def __init__(
        self,
        vector_retriever: VectorIndexRetriever,
        keyword_retriever: KeywordTableSimpleRetriever,
        mode: str = "AND",
    ) -> None:
        """Init params."""

        self._vector_retriever = vector_retriever
        self._keyword_retriever = keyword_retriever
        if mode not in ("AND", "OR"):
            raise ValueError("Invalid mode.")
        self._mode = mode
        super().__init__()

    def _retrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]:
        """Retrieve nodes given query."""

        vector_nodes = self._vector_retriever.retrieve(query_bundle)
        keyword_nodes = self._keyword_retriever.retrieve(query_bundle)

        vector_ids = {n.node.node_id for n in vector_nodes}
        keyword_ids = {n.node.node_id for n in keyword_nodes}

        combined_dict = {n.node.node_id: n for n in vector_nodes}
        combined_dict.update({n.node.node_id: n for n in keyword_nodes})

        if self._mode == "AND":
            # "AND": keep only nodes returned by both retrievers
            retrieve_ids = vector_ids.intersection(keyword_ids)
        else:
            # "OR": keep nodes returned by either retriever
            retrieve_ids = vector_ids.union(keyword_ids)

        retrieve_nodes = [combined_dict[rid] for rid in retrieve_ids]
        return retrieve_nodes
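A minimal usage sketch for the retriever above, assuming an embedding model and LLM are already configured through llama_index Settings; the document text and variable names are illustrative and not part of the commit.

from llama_index.core import Document, StorageContext
from llama_index.core.indices import SimpleKeywordTableIndex, VectorStoreIndex
from llama_index.core.retrievers import KeywordTableSimpleRetriever, VectorIndexRetriever

# Build a vector index and a keyword index over the same documents and storage context
documents = [Document(text="PrivateGPT can combine vector search with keyword search.")]
storage_context = StorageContext.from_defaults()
vector_index = VectorStoreIndex.from_documents(documents, storage_context=storage_context)
keyword_index = SimpleKeywordTableIndex.from_documents(documents, storage_context=storage_context)

# Wrap each index in its retriever and merge them with the hybrid retriever
vector_retriever = VectorIndexRetriever(index=vector_index, similarity_top_k=2)
keyword_retriever = KeywordTableSimpleRetriever(index=keyword_index)
hybrid_retriever = CustomRetriever(vector_retriever, keyword_retriever, mode="OR")

for node_with_score in hybrid_retriever.retrieve("keyword search"):
    print(node_with_score.node.node_id, node_with_score.score)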
@@ -3,12 +3,14 @@ import typing
 from injector import inject, singleton
 from llama_index.core.indices.vector_store import VectorIndexRetriever, VectorStoreIndex
+from llama_index.core.indices import SimpleKeywordTableIndex
 from llama_index.core.vector_stores.types import (
     FilterCondition,
     MetadataFilter,
     MetadataFilters,
     VectorStore,
 )
+from private_gpt.utils.vector_store import VectorStoreIndex1
 
 from private_gpt.open_ai.extensions.context_filter import ContextFilter
 from private_gpt.paths import local_data_path
@@ -130,21 +132,43 @@ class VectorStoreComponent:
     def get_retriever(
         self,
-        index: VectorStoreIndex,
+        index: VectorStoreIndex1,
         context_filter: ContextFilter | None = None,
         similarity_top_k: int = 2,
     ) -> VectorIndexRetriever:
         # This way we support qdrant (using doc_ids) and the rest (using filters)
-        return VectorIndexRetriever(
+        # from llama_index.core import get_response_synthesizer
+        # from llama_index.core.query_engine import RetrieverQueryEngine
+        from .retriever import CustomRetriever
+        from llama_index.core.retrievers import (
+            VectorIndexRetriever,
+            KeywordTableSimpleRetriever,
+        )
+
+        vector_retriever = VectorIndexRetriever(
             index=index,
             similarity_top_k=similarity_top_k,
             doc_ids=context_filter.docs_ids if context_filter else None,
             filters=(
                 _doc_id_metadata_filter(context_filter)
                 if self.settings.vectorstore.database != "qdrant"
                 else None
             ),
         )
+        keyword_retriever = KeywordTableSimpleRetriever(index=index.keyword_index)
+        custom_retriever = CustomRetriever(vector_retriever, keyword_retriever)
+
+        # define response synthesizer
+        # response_synthesizer = get_response_synthesizer()
+
+        # # assemble query engine
+        # custom_query_engine = RetrieverQueryEngine(
+        #     retriever=custom_retriever,
+        #     response_synthesizer=response_synthesizer,
+        # )
+
+        return custom_retriever
+
     def close(self) -> None:
         if hasattr(self.vector_store.client, "close"):
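The commented-out lines above sketch an alternative where the hybrid retriever is wrapped in a query engine instead of being returned directly. A hedged sketch of that assembly, continuing from the hybrid_retriever built in the earlier example; an LLM must be configured for the response synthesizer, and nothing here is part of the commit.

from llama_index.core import get_response_synthesizer
from llama_index.core.query_engine import RetrieverQueryEngine

# Turn the retrieved nodes into an answer with a response synthesizer
response_synthesizer = get_response_synthesizer()
query_engine = RetrieverQueryEngine(
    retriever=hybrid_retriever,  # hybrid retriever from the earlier sketch
    response_synthesizer=response_synthesizer,
)
print(query_engine.query("What does hybrid search combine?"))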
@@ -5,7 +5,7 @@ from llama_index.core.chat_engine import ContextChatEngine, SimpleChatEngine
 from llama_index.core.chat_engine.types import (
     BaseChatEngine,
 )
-from llama_index.core.indices import VectorStoreIndex
+from llama_index.core.indices import VectorStoreIndex, SimpleKeywordTableIndex
 from llama_index.core.indices.postprocessor import MetadataReplacementPostProcessor
 from llama_index.core.llms import ChatMessage, MessageRole
 from llama_index.core.postprocessor import (
@@ -99,6 +99,12 @@ class ChatService:
             embed_model=embedding_component.embedding_model,
             show_progress=True,
         )
+        self.keyword_index = SimpleKeywordTableIndex.from_documents(
+            vector_store_component.vector_store,
+            storage_context=self.storage_context,
+            embed_model=embedding_component.embedding_model,
+            show_progress=True,
+        )
 
     def _chat_engine(
         self,
@@ -110,6 +116,7 @@ class ChatService:
         if use_context:
             vector_index_retriever = self.vector_store_component.get_retriever(
                 index=self.index,
+                keyword_index=self.keyword_index,
                 context_filter=context_filter,
                 similarity_top_k=self.settings.rag.similarity_top_k,
             )
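Inside _chat_engine, the retriever returned above ultimately feeds a context chat engine. A rough sketch of that pattern with the hybrid retriever, assuming an LLM is configured; the prompt and names are illustrative, and the real wiring lives in ChatService.

from llama_index.core.chat_engine import ContextChatEngine

# Let the chat engine pull context through the hybrid retriever on every turn
chat_engine = ContextChatEngine.from_defaults(
    retriever=hybrid_retriever,  # hybrid retriever from the earlier sketch
    system_prompt="Answer strictly from the retrieved context.",
)
print(chat_engine.chat("Which documents mention hybrid search?"))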
private_gpt/utils/vector_store.py (new file, 13 lines)
@@ -0,0 +1,13 @@
from llama_index.core.indices.vector_store import VectorStoreIndex
from llama_index.core.indices import SimpleKeywordTableIndex


class VectorStoreIndex1(VectorStoreIndex):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.keyword_index = None

    def set_keyword_index(self, keyword_index: SimpleKeywordTableIndex):
        self.keyword_index = keyword_index

    def get_keyword_index(self) -> SimpleKeywordTableIndex:
        return self.keyword_index
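A short sketch of how this wrapper ties the two indexes together, continuing the earlier example's documents, storage_context and keyword_index; the names are illustrative and not part of the commit.

from private_gpt.utils.vector_store import VectorStoreIndex1

# Build the vector index through the subclass so a keyword index can ride along with it
hybrid_index = VectorStoreIndex1.from_documents(documents, storage_context=storage_context)
hybrid_index.set_keyword_index(keyword_index)

# Consumers such as get_retriever() can then reach both indexes from one object
assert hybrid_index.get_keyword_index() is keyword_index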