Ver código fonte

add qdrant migration (#1046)

Co-authored-by: jyong <jyong@dify.ai>
Jyong 1 ano atrás
pai
commit
a43e80dd9c
2 arquivos alterados com 91 adições e 5 exclusões
  1. 63 1
      api/commands.py
  2. 28 4
      api/core/index/vector_index/base.py

+ 63 - 1
api/commands.py

@@ -329,7 +329,7 @@ def create_qdrant_indexes():
                 except Exception:
                     provider = Provider(
                         id='provider_id',
-                        tenant_id='tenant_id',
+                        tenant_id=dataset.tenant_id,
                         provider_name='openai',
                         provider_type=ProviderType.CUSTOM.value,
                         encrypted_config=json.dumps({'openai_api_key': 'TEST'}),
@@ -369,6 +369,67 @@ def create_qdrant_indexes():
     click.echo(click.style('Congratulations! Create {} dataset indexes.'.format(create_count), fg='green'))
 
 
+@click.command('update-qdrant-indexes', help='Update qdrant indexes.')
+def update_qdrant_indexes():
+    click.echo(click.style('Start Update qdrant indexes.', fg='green'))
+    create_count = 0
+
+    page = 1
+    while True:
+        try:
+            datasets = db.session.query(Dataset).filter(Dataset.indexing_technique == 'high_quality') \
+                .order_by(Dataset.created_at.desc()).paginate(page=page, per_page=50)
+        except NotFound:
+            break
+
+        page += 1
+        for dataset in datasets:
+            if dataset.index_struct_dict:
+                if dataset.index_struct_dict['type'] != 'qdrant':
+                    try:
+                        click.echo('Update dataset qdrant index: {}'.format(dataset.id))
+                        try:
+                            embedding_model = ModelFactory.get_embedding_model(
+                                tenant_id=dataset.tenant_id,
+                                model_provider_name=dataset.embedding_model_provider,
+                                model_name=dataset.embedding_model
+                            )
+                        except Exception:
+                            provider = Provider(
+                                id='provider_id',
+                                tenant_id=dataset.tenant_id,
+                                provider_name='openai',
+                                provider_type=ProviderType.CUSTOM.value,
+                                encrypted_config=json.dumps({'openai_api_key': 'TEST'}),
+                                is_valid=True,
+                            )
+                            model_provider = OpenAIProvider(provider=provider)
+                            embedding_model = OpenAIEmbedding(name="text-embedding-ada-002", model_provider=model_provider)
+                        embeddings = CacheEmbedding(embedding_model)
+
+                        from core.index.vector_index.qdrant_vector_index import QdrantVectorIndex, QdrantConfig
+
+                        index = QdrantVectorIndex(
+                            dataset=dataset,
+                            config=QdrantConfig(
+                                endpoint=current_app.config.get('QDRANT_URL'),
+                                api_key=current_app.config.get('QDRANT_API_KEY'),
+                                root_path=current_app.root_path
+                            ),
+                            embeddings=embeddings
+                        )
+                        if index:
+                            index.update_qdrant_dataset(dataset)
+                            create_count += 1
+                        else:
+                            click.echo('passed.')
+                    except Exception as e:
+                        click.echo(
+                            click.style('Create dataset index error: {} {}'.format(e.__class__.__name__, str(e)), fg='red'))
+                        continue
+
+    click.echo(click.style('Congratulations! Update {} dataset indexes.'.format(create_count), fg='green'))
+
 def register_commands(app):
     app.cli.add_command(reset_password)
     app.cli.add_command(reset_email)
@@ -378,3 +439,4 @@ def register_commands(app):
     app.cli.add_command(sync_anthropic_hosted_providers)
     app.cli.add_command(clean_unused_dataset_indexes)
     app.cli.add_command(create_qdrant_indexes)
+    app.cli.add_command(update_qdrant_indexes)

+ 28 - 4
api/core/index/vector_index/base.py

@@ -15,12 +15,12 @@ from models.dataset import Document as DatasetDocument
 
 
 class BaseVectorIndex(BaseIndex):
-    
+
     def __init__(self, dataset: Dataset, embeddings: Embeddings):
         super().__init__(dataset)
         self._embeddings = embeddings
         self._vector_store = None
-        
+
     def get_type(self) -> str:
         raise NotImplementedError
 
@@ -143,7 +143,7 @@ class BaseVectorIndex(BaseIndex):
                 DocumentSegment.status == 'completed',
                 DocumentSegment.enabled == True
             ).all()
-            
+
             for segment in segments:
                 document = Document(
                     page_content=segment.content,
@@ -218,4 +218,28 @@ class BaseVectorIndex(BaseIndex):
             except Exception as e:
                 raise e
 
-        logging.info(f"Dataset {dataset.id} recreate successfully.")
+        logging.info(f"Dataset {dataset.id} recreate successfully.")
+
+    def update_qdrant_dataset(self, dataset: Dataset):
+        logging.info(f"update_qdrant_dataset {dataset.id}")
+
+        segment = db.session.query(DocumentSegment).filter(
+            DocumentSegment.dataset_id == dataset.id,
+            DocumentSegment.status == 'completed',
+            DocumentSegment.enabled == True
+        ).first()
+
+        if segment:
+            try:
+                exist = self.text_exists(segment.index_node_id)
+                if exist:
+                    index_struct = {
+                        "type": 'qdrant',
+                        "vector_store": {"class_prefix": dataset.index_struct_dict['vector_store']['class_prefix']}
+                    }
+                    dataset.index_struct = json.dumps(index_struct)
+                    db.session.commit()
+            except Exception as e:
+                raise e
+
+        logging.info(f"Dataset {dataset.id} recreate successfully.")