浏览代码

Feat/vector db manage (#997)

Co-authored-by: jyong <jyong@dify.ai>
Jyong 1 年之前
父节点
当前提交
b1fd1b3ab3

+ 68 - 0
api/commands.py

@@ -1,4 +1,5 @@
 import datetime
+import json
 import math
 import random
 import string
@@ -6,10 +7,16 @@ import time
 
 import click
 from flask import current_app
+from langchain.embeddings import OpenAIEmbeddings
 from werkzeug.exceptions import NotFound
 
+from core.embedding.cached_embedding import CacheEmbedding
 from core.index.index import IndexBuilder
+from core.model_providers.model_factory import ModelFactory
+from core.model_providers.models.embedding.openai_embedding import OpenAIEmbedding
+from core.model_providers.models.entity.model_params import ModelType
 from core.model_providers.providers.hosted import hosted_model_providers
+from core.model_providers.providers.openai_provider import OpenAIProvider
 from libs.password import password_pattern, valid_password, hash_password
 from libs.helper import email as email_validate
 from extensions.ext_database import db
@@ -296,6 +303,66 @@ def sync_anthropic_hosted_providers():
     click.echo(click.style('Congratulations! Synced {} anthropic hosted providers.'.format(count), fg='green'))
 
 
+@click.command('create-qdrant-indexes', help='Create qdrant indexes.')
+def create_qdrant_indexes():
+    click.echo(click.style('Start create qdrant indexes.', fg='green'))
+    create_count = 0
+
+    page = 1
+    while True:
+        try:
+            datasets = db.session.query(Dataset).filter(Dataset.indexing_technique == 'high_quality') \
+                .order_by(Dataset.created_at.desc()).paginate(page=page, per_page=50)
+        except NotFound:
+            break
+
+        page += 1
+        for dataset in datasets:
+            try:
+                click.echo('Create dataset qdrant index: {}'.format(dataset.id))
+                try:
+                    embedding_model = ModelFactory.get_embedding_model(
+                        tenant_id=dataset.tenant_id,
+                        model_provider_name=dataset.embedding_model_provider,
+                        model_name=dataset.embedding_model
+                    )
+                except Exception:
+                    provider = Provider(
+                        id='provider_id',
+                        tenant_id='tenant_id',
+                        provider_name='openai',
+                        provider_type=ProviderType.CUSTOM.value,
+                        encrypted_config=json.dumps({'openai_api_key': 'TEST'}),
+                        is_valid=True,
+                    )
+                    model_provider = OpenAIProvider(provider=provider)
+                    embedding_model = OpenAIEmbedding(name="text-embedding-ada-002", model_provider=model_provider)
+                embeddings = CacheEmbedding(embedding_model)
+
+                from core.index.vector_index.qdrant_vector_index import QdrantVectorIndex, QdrantConfig
+
+                index = QdrantVectorIndex(
+                    dataset=dataset,
+                    config=QdrantConfig(
+                        endpoint=current_app.config.get('QDRANT_URL'),
+                        api_key=current_app.config.get('QDRANT_API_KEY'),
+                        root_path=current_app.root_path
+                    ),
+                    embeddings=embeddings
+                )
+                if index:
+                    index.create_qdrant_dataset(dataset)
+                    create_count += 1
+                else:
+                    click.echo('passed.')
+            except Exception as e:
+                click.echo(
+                    click.style('Create dataset index error: {} {}'.format(e.__class__.__name__, str(e)), fg='red'))
+                continue
+
+    click.echo(click.style('Congratulations! Create {} dataset indexes.'.format(create_count), fg='green'))
+
+
 def register_commands(app):
     app.cli.add_command(reset_password)
     app.cli.add_command(reset_email)
@@ -304,3 +371,4 @@ def register_commands(app):
     app.cli.add_command(recreate_all_dataset_indexes)
     app.cli.add_command(sync_anthropic_hosted_providers)
     app.cli.add_command(clean_unused_dataset_indexes)
+    app.cli.add_command(create_qdrant_indexes)

+ 1 - 1
api/core/data_loader/loader/excel.py

@@ -38,7 +38,7 @@ class ExcelLoader(BaseLoader):
                 else:
                     row_dict = dict(zip(keys, list(map(str, row))))
                     row_dict = {k: v for k, v in row_dict.items() if v}
-                    item = ''.join(f'{k}:{v}\n' for k, v in row_dict.items())
+                    item = ''.join(f'{k}:{v};' for k, v in row_dict.items())
                     document = Document(page_content=item, metadata={'source': self._file_path})
                     data.append(document)
 

+ 46 - 0
api/core/index/vector_index/base.py

@@ -173,3 +173,49 @@ class BaseVectorIndex(BaseIndex):
 
         self.dataset = dataset
         logging.info(f"Dataset {dataset.id} recreate successfully.")
+
+    def create_qdrant_dataset(self, dataset: Dataset):
+        logging.info(f"create_qdrant_dataset {dataset.id}")
+
+        try:
+            self.delete()
+        except UnexpectedStatusCodeException as e:
+            if e.status_code != 400:
+                # 400 means index not exists
+                raise e
+
+        dataset_documents = db.session.query(DatasetDocument).filter(
+            DatasetDocument.dataset_id == dataset.id,
+            DatasetDocument.indexing_status == 'completed',
+            DatasetDocument.enabled == True,
+            DatasetDocument.archived == False,
+        ).all()
+
+        documents = []
+        for dataset_document in dataset_documents:
+            segments = db.session.query(DocumentSegment).filter(
+                DocumentSegment.document_id == dataset_document.id,
+                DocumentSegment.status == 'completed',
+                DocumentSegment.enabled == True
+            ).all()
+
+            for segment in segments:
+                document = Document(
+                    page_content=segment.content,
+                    metadata={
+                        "doc_id": segment.index_node_id,
+                        "doc_hash": segment.index_node_hash,
+                        "document_id": segment.document_id,
+                        "dataset_id": segment.dataset_id,
+                    }
+                )
+
+                documents.append(document)
+
+        if documents:
+            try:
+                self.create(documents)
+            except Exception as e:
+                raise e
+
+        logging.info(f"Dataset {dataset.id} recreate successfully.")

+ 114 - 0
api/core/index/vector_index/milvus_vector_index.py

@@ -0,0 +1,114 @@
+from typing import Optional, cast
+
+from langchain.embeddings.base import Embeddings
+from langchain.schema import Document, BaseRetriever
+from langchain.vectorstores import VectorStore, milvus
+from pydantic import BaseModel, root_validator
+
+from core.index.base import BaseIndex
+from core.index.vector_index.base import BaseVectorIndex
+from core.vector_store.milvus_vector_store import MilvusVectorStore
+from core.vector_store.weaviate_vector_store import WeaviateVectorStore
+from models.dataset import Dataset
+
+
+class MilvusConfig(BaseModel):
+    endpoint: str
+    user: str
+    password: str
+    batch_size: int = 100
+
+    @root_validator()
+    def validate_config(cls, values: dict) -> dict:
+        if not values['endpoint']:
+            raise ValueError("config MILVUS_ENDPOINT is required")
+        if not values['user']:
+            raise ValueError("config MILVUS_USER is required")
+        if not values['password']:
+            raise ValueError("config MILVUS_PASSWORD is required")
+        return values
+
+
+class MilvusVectorIndex(BaseVectorIndex):
+    def __init__(self, dataset: Dataset, config: MilvusConfig, embeddings: Embeddings):
+        super().__init__(dataset, embeddings)
+        self._client = self._init_client(config)
+
+    def get_type(self) -> str:
+        return 'milvus'
+
+    def get_index_name(self, dataset: Dataset) -> str:
+        if self.dataset.index_struct_dict:
+            class_prefix: str = self.dataset.index_struct_dict['vector_store']['class_prefix']
+            if not class_prefix.endswith('_Node'):
+                # original class_prefix
+                class_prefix += '_Node'
+
+            return class_prefix
+
+        dataset_id = dataset.id
+        return "Vector_index_" + dataset_id.replace("-", "_") + '_Node'
+
+
+    def to_index_struct(self) -> dict:
+        return {
+            "type": self.get_type(),
+            "vector_store": {"class_prefix": self.get_index_name(self.dataset)}
+        }
+
+    def create(self, texts: list[Document], **kwargs) -> BaseIndex:
+        uuids = self._get_uuids(texts)
+        self._vector_store = WeaviateVectorStore.from_documents(
+            texts,
+            self._embeddings,
+            client=self._client,
+            index_name=self.get_index_name(self.dataset),
+            uuids=uuids,
+            by_text=False
+        )
+
+        return self
+
+    def _get_vector_store(self) -> VectorStore:
+        """Only for created index."""
+        if self._vector_store:
+            return self._vector_store
+
+        attributes = ['doc_id', 'dataset_id', 'document_id']
+        if self._is_origin():
+            attributes = ['doc_id']
+
+        return WeaviateVectorStore(
+            client=self._client,
+            index_name=self.get_index_name(self.dataset),
+            text_key='text',
+            embedding=self._embeddings,
+            attributes=attributes,
+            by_text=False
+        )
+
+    def _get_vector_store_class(self) -> type:
+        return MilvusVectorStore
+
+    def delete_by_document_id(self, document_id: str):
+        if self._is_origin():
+            self.recreate_dataset(self.dataset)
+            return
+
+        vector_store = self._get_vector_store()
+        vector_store = cast(self._get_vector_store_class(), vector_store)
+
+        vector_store.del_texts({
+            "operator": "Equal",
+            "path": ["document_id"],
+            "valueText": document_id
+        })
+
+    def _is_origin(self):
+        if self.dataset.index_struct_dict:
+            class_prefix: str = self.dataset.index_struct_dict['vector_store']['class_prefix']
+            if not class_prefix.endswith('_Node'):
+                # original class_prefix
+                return True
+
+        return False

文件差异内容过多而无法显示
+ 1691 - 0
api/core/index/vector_index/qdrant.py


+ 15 - 8
api/core/index/vector_index/qdrant_vector_index.py

@@ -44,15 +44,20 @@ class QdrantVectorIndex(BaseVectorIndex):
 
     def get_index_name(self, dataset: Dataset) -> str:
         if self.dataset.index_struct_dict:
-            return self.dataset.index_struct_dict['vector_store']['collection_name']
+            class_prefix: str = self.dataset.index_struct_dict['vector_store']['class_prefix']
+            if not class_prefix.endswith('_Node'):
+                # original class_prefix
+                class_prefix += '_Node'
+
+            return class_prefix
 
         dataset_id = dataset.id
-        return "Index_" + dataset_id.replace("-", "_")
+        return "Vector_index_" + dataset_id.replace("-", "_") + '_Node'
 
     def to_index_struct(self) -> dict:
         return {
             "type": self.get_type(),
-            "vector_store": {"collection_name": self.get_index_name(self.dataset)}
+            "vector_store": {"class_prefix": self.get_index_name(self.dataset)}
         }
 
     def create(self, texts: list[Document], **kwargs) -> BaseIndex:
@@ -62,7 +67,7 @@ class QdrantVectorIndex(BaseVectorIndex):
             self._embeddings,
             collection_name=self.get_index_name(self.dataset),
             ids=uuids,
-            content_payload_key='text',
+            content_payload_key='page_content',
             **self._client_config.to_qdrant_params()
         )
 
@@ -72,7 +77,9 @@ class QdrantVectorIndex(BaseVectorIndex):
         """Only for created index."""
         if self._vector_store:
             return self._vector_store
-        
+        attributes = ['doc_id', 'dataset_id', 'document_id']
+        if self._is_origin():
+            attributes = ['doc_id']
         client = qdrant_client.QdrantClient(
             **self._client_config.to_qdrant_params()
         )
@@ -81,7 +88,7 @@ class QdrantVectorIndex(BaseVectorIndex):
             client=client,
             collection_name=self.get_index_name(self.dataset),
             embeddings=self._embeddings,
-            content_payload_key='text'
+            content_payload_key='page_content'
         )
 
     def _get_vector_store_class(self) -> type:
@@ -108,8 +115,8 @@ class QdrantVectorIndex(BaseVectorIndex):
 
     def _is_origin(self):
         if self.dataset.index_struct_dict:
-            class_prefix: str = self.dataset.index_struct_dict['vector_store']['collection_name']
-            if class_prefix.startswith('Vector_'):
+            class_prefix: str = self.dataset.index_struct_dict['vector_store']['class_prefix']
+            if not class_prefix.endswith('_Node'):
                 # original class_prefix
                 return True
 

+ 38 - 0
api/core/vector_store/milvus_vector_store.py

@@ -0,0 +1,38 @@
+from langchain.vectorstores import  Milvus
+
+
+class MilvusVectorStore(Milvus):
+    def del_texts(self, where_filter: dict):
+        if not where_filter:
+            raise ValueError('where_filter must not be empty')
+
+        self._client.batch.delete_objects(
+            class_name=self._index_name,
+            where=where_filter,
+            output='minimal'
+        )
+
+    def del_text(self, uuid: str) -> None:
+        self._client.data_object.delete(
+            uuid,
+            class_name=self._index_name
+        )
+
+    def text_exists(self, uuid: str) -> bool:
+        result = self._client.query.get(self._index_name).with_additional(["id"]).with_where({
+            "path": ["doc_id"],
+            "operator": "Equal",
+            "valueText": uuid,
+        }).with_limit(1).do()
+
+        if "errors" in result:
+            raise ValueError(f"Error during query: {result['errors']}")
+
+        entries = result["data"]["Get"][self._index_name]
+        if len(entries) == 0:
+            return False
+
+        return True
+
+    def delete(self):
+        self._client.schema.delete_class(self._index_name)

+ 2 - 1
api/core/vector_store/qdrant_vector_store.py

@@ -1,10 +1,11 @@
 from typing import cast, Any
 
 from langchain.schema import Document
-from langchain.vectorstores import Qdrant
 from qdrant_client.http.models import Filter, PointIdsList, FilterSelector
 from qdrant_client.local.qdrant_local import QdrantLocal
 
+from core.index.vector_index.qdrant import Qdrant
+
 
 class QdrantVectorStore(Qdrant):
     def del_texts(self, filter: Filter):