Browse Source

Feat/vdb migrate command (#2562)

Co-authored-by: jyong <jyong@dify.ai>
Jyong 1 year ago
parent
commit
0620fa3094

+ 103 - 54
api/commands.py

@@ -1,20 +1,21 @@
 import base64
 import json
 import secrets
+from typing import cast
 
 import click
 from flask import current_app
 from werkzeug.exceptions import NotFound
 
-from core.embedding.cached_embedding import CacheEmbedding
-from core.model_manager import ModelManager
-from core.model_runtime.entities.model_entities import ModelType
+from core.rag.datasource.vdb.vector_factory import Vector
+from core.rag.models.document import Document
 from extensions.ext_database import db
 from libs.helper import email as email_validate
 from libs.password import hash_password, password_pattern, valid_password
 from libs.rsa import generate_key_pair
 from models.account import Tenant
-from models.dataset import Dataset
+from models.dataset import Dataset, DatasetCollectionBinding, DocumentSegment
+from models.dataset import Document as DatasetDocument
 from models.model import Account
 from models.provider import Provider, ProviderModel
 
@@ -124,14 +125,15 @@ def reset_encrypt_key_pair():
                            'the asymmetric key pair of workspace {} has been reset.'.format(tenant.id), fg='green'))
 
 
-@click.command('create-qdrant-indexes', help='Create qdrant indexes.')
-def create_qdrant_indexes():
+@click.command('vdb-migrate', help='migrate vector db.')
+def vdb_migrate():
     """
-    Migrate other vector database datas to Qdrant.
+    Migrate vector database datas to target vector database .
     """
-    click.echo(click.style('Start create qdrant indexes.', fg='green'))
+    click.echo(click.style('Start migrate vector db.', fg='green'))
     create_count = 0
-
+    config = cast(dict, current_app.config)
+    vector_type = config.get('VECTOR_STORE')
     page = 1
     while True:
         try:
@@ -140,54 +142,101 @@ def create_qdrant_indexes():
         except NotFound:
             break
 
-        model_manager = ModelManager()
-
         page += 1
         for dataset in datasets:
-            if dataset.index_struct_dict:
-                if dataset.index_struct_dict['type'] != 'qdrant':
-                    try:
-                        click.echo('Create dataset qdrant index: {}'.format(dataset.id))
-                        try:
-                            embedding_model = model_manager.get_model_instance(
-                                tenant_id=dataset.tenant_id,
-                                provider=dataset.embedding_model_provider,
-                                model_type=ModelType.TEXT_EMBEDDING,
-                                model=dataset.embedding_model
-
-                            )
-                        except Exception:
-                            continue
-                        embeddings = CacheEmbedding(embedding_model)
-
-                        from core.index.vector_index.qdrant_vector_index import QdrantConfig, QdrantVectorIndex
-
-                        index = QdrantVectorIndex(
-                            dataset=dataset,
-                            config=QdrantConfig(
-                                endpoint=current_app.config.get('QDRANT_URL'),
-                                api_key=current_app.config.get('QDRANT_API_KEY'),
-                                root_path=current_app.root_path
-                            ),
-                            embeddings=embeddings
-                        )
-                        if index:
-                            index.create_qdrant_dataset(dataset)
-                            index_struct = {
-                                "type": 'qdrant',
-                                "vector_store": {
-                                    "class_prefix": dataset.index_struct_dict['vector_store']['class_prefix']}
-                            }
-                            dataset.index_struct = json.dumps(index_struct)
-                            db.session.commit()
-                            create_count += 1
+            try:
+                click.echo('Create dataset vdb index: {}'.format(dataset.id))
+                if dataset.index_struct_dict:
+                    if dataset.index_struct_dict['type'] == vector_type:
+                        continue
+                if vector_type == "weaviate":
+                    dataset_id = dataset.id
+                    collection_name = "Vector_index_" + dataset_id.replace("-", "_") + '_Node'
+                    index_struct_dict = {
+                        "type": 'weaviate',
+                        "vector_store": {"class_prefix": collection_name}
+                    }
+                    dataset.index_struct = json.dumps(index_struct_dict)
+                elif vector_type == "qdrant":
+                    if dataset.collection_binding_id:
+                        dataset_collection_binding = db.session.query(DatasetCollectionBinding). \
+                            filter(DatasetCollectionBinding.id == dataset.collection_binding_id). \
+                            one_or_none()
+                        if dataset_collection_binding:
+                            collection_name = dataset_collection_binding.collection_name
                         else:
-                            click.echo('passed.')
+                            raise ValueError('Dataset Collection Bindings is not exist!')
+                    else:
+                        dataset_id = dataset.id
+                        collection_name = "Vector_index_" + dataset_id.replace("-", "_") + '_Node'
+                    index_struct_dict = {
+                        "type": 'qdrant',
+                        "vector_store": {"class_prefix": collection_name}
+                    }
+                    dataset.index_struct = json.dumps(index_struct_dict)
+
+                elif vector_type == "milvus":
+                    dataset_id = dataset.id
+                    collection_name = "Vector_index_" + dataset_id.replace("-", "_") + '_Node'
+                    index_struct_dict = {
+                        "type": 'milvus',
+                        "vector_store": {"class_prefix": collection_name}
+                    }
+                    dataset.index_struct = json.dumps(index_struct_dict)
+                else:
+                    raise ValueError(f"Vector store {config.get('VECTOR_STORE')} is not supported.")
+
+                vector = Vector(dataset)
+                click.echo(f"vdb_migrate {dataset.id}")
+
+                try:
+                    vector.delete()
+                except Exception as e:
+                    raise e
+
+                dataset_documents = db.session.query(DatasetDocument).filter(
+                    DatasetDocument.dataset_id == dataset.id,
+                    DatasetDocument.indexing_status == 'completed',
+                    DatasetDocument.enabled == True,
+                    DatasetDocument.archived == False,
+                ).all()
+
+                documents = []
+                for dataset_document in dataset_documents:
+                    segments = db.session.query(DocumentSegment).filter(
+                        DocumentSegment.document_id == dataset_document.id,
+                        DocumentSegment.status == 'completed',
+                        DocumentSegment.enabled == True
+                    ).all()
+
+                    for segment in segments:
+                        document = Document(
+                            page_content=segment.content,
+                            metadata={
+                                "doc_id": segment.index_node_id,
+                                "doc_hash": segment.index_node_hash,
+                                "document_id": segment.document_id,
+                                "dataset_id": segment.dataset_id,
+                            }
+                        )
+
+                        documents.append(document)
+
+                if documents:
+                    try:
+                        vector.create(documents)
                     except Exception as e:
-                        click.echo(
-                            click.style('Create dataset index error: {} {}'.format(e.__class__.__name__, str(e)),
-                                        fg='red'))
-                        continue
+                        raise e
+                click.echo(f"Dataset {dataset.id} create successfully.")
+                db.session.add(dataset)
+                db.session.commit()
+                create_count += 1
+            except Exception as e:
+                db.session.rollback()
+                click.echo(
+                    click.style('Create dataset index error: {} {}'.format(e.__class__.__name__, str(e)),
+                                fg='red'))
+                continue
 
     click.echo(click.style('Congratulations! Create {} dataset indexes.'.format(create_count), fg='green'))
 
@@ -196,4 +245,4 @@ def register_commands(app):
     app.cli.add_command(reset_password)
     app.cli.add_command(reset_email)
     app.cli.add_command(reset_encrypt_key_pair)
-    app.cli.add_command(create_qdrant_indexes)
+    app.cli.add_command(vdb_migrate)

+ 1 - 0
api/core/indexing_runner.py

@@ -664,6 +664,7 @@ class IndexingRunner:
                 )
             # load index
             index_processor.load(dataset, chunk_documents)
+            db.session.add(dataset)
 
             document_ids = [document.metadata['doc_id'] for document in chunk_documents]
             db.session.query(DocumentSegment).filter(

+ 7 - 1
api/core/rag/datasource/vdb/milvus/milvus_vector.py

@@ -127,9 +127,15 @@ class MilvusVector(BaseVector):
         self._client.delete(collection_name=self._collection_name, pks=doc_ids)
 
     def delete(self) -> None:
+        alias = uuid4().hex
+        if self._client_config.secure:
+            uri = "https://" + str(self._client_config.host) + ":" + str(self._client_config.port)
+        else:
+            uri = "http://" + str(self._client_config.host) + ":" + str(self._client_config.port)
+        connections.connect(alias=alias, uri=uri, user=self._client_config.user, password=self._client_config.password)
 
         from pymilvus import utility
-        utility.drop_collection(self._collection_name, None)
+        utility.drop_collection(self._collection_name, None, using=alias)
 
     def text_exists(self, id: str) -> bool:
 

+ 18 - 0
api/core/rag/datasource/vdb/vector_factory.py

@@ -1,3 +1,4 @@
+import json
 from typing import Any, cast
 
 from flask import current_app
@@ -39,6 +40,11 @@ class Vector:
             else:
                 dataset_id = self._dataset.id
                 collection_name = "Vector_index_" + dataset_id.replace("-", "_") + '_Node'
+                index_struct_dict = {
+                    "type": 'weaviate',
+                    "vector_store": {"class_prefix": collection_name}
+                }
+                self._dataset.index_struct = json.dumps(index_struct_dict)
             return WeaviateVector(
                 collection_name=collection_name,
                 config=WeaviateConfig(
@@ -66,6 +72,13 @@ class Vector:
                     dataset_id = self._dataset.id
                     collection_name = "Vector_index_" + dataset_id.replace("-", "_") + '_Node'
 
+            if not self._dataset.index_struct_dict:
+                index_struct_dict = {
+                    "type": 'qdrant',
+                    "vector_store": {"class_prefix": collection_name}
+                }
+                self._dataset.index_struct = json.dumps(index_struct_dict)
+
             return QdrantVector(
                 collection_name=collection_name,
                 group_id=self._dataset.id,
@@ -84,6 +97,11 @@ class Vector:
             else:
                 dataset_id = self._dataset.id
                 collection_name = "Vector_index_" + dataset_id.replace("-", "_") + '_Node'
+                index_struct_dict = {
+                    "type": 'milvus',
+                    "vector_store": {"class_prefix": collection_name}
+                }
+                self._dataset.index_struct = json.dumps(index_struct_dict)
             return MilvusVector(
                 collection_name=collection_name,
                 config=MilvusConfig(

+ 4 - 1
api/core/rag/datasource/vdb/weaviate/weaviate_vector.py

@@ -127,7 +127,10 @@ class WeaviateVector(BaseVector):
         )
 
     def delete(self):
-        self._client.schema.delete_class(self._collection_name)
+        # check whether the index already exists
+        schema = self._default_schema(self._collection_name)
+        if self._client.schema.contains(schema):
+            self._client.schema.delete_class(self._collection_name)
 
     def text_exists(self, id: str) -> bool:
         collection_name = self._collection_name