瀏覽代碼

compatible with original provider name

takatost 8 月之前
父節點
當前提交
f798add31c

+ 14 - 0
api/commands.py

@@ -25,6 +25,7 @@ from models.dataset import Document as DatasetDocument
 from models.model import Account, App, AppAnnotationSetting, AppMode, Conversation, MessageAnnotation
 from models.provider import Provider, ProviderModel
 from services.account_service import RegisterService, TenantService
+from services.plugin.data_migration import PluginDataMigration
 
 
 @click.command("reset-password", help="Reset the account password.")
@@ -639,6 +640,18 @@ where sites.id is null limit 1000"""
     click.echo(click.style("Fix for missing app-related sites completed successfully!", fg="green"))
 
 
+@click.command("migrate-data-for-plugin", help="Migrate data for plugin.")
+def migrate_data_for_plugin():
+    """
+    CLI entry point: announce the start, run PluginDataMigration.migrate(), then report completion.
+    """
+    start_banner = click.style("Starting migrate data for plugin.", fg="white")
+    click.echo(start_banner)
+
+    PluginDataMigration.migrate()
+
+    done_banner = click.style("Migrate data for plugin completed.", fg="green")
+    click.echo(done_banner)
+
+
 def register_commands(app):
     app.cli.add_command(reset_password)
     app.cli.add_command(reset_email)
@@ -649,3 +662,4 @@ def register_commands(app):
     app.cli.add_command(create_tenant)
     app.cli.add_command(upgrade_db)
     app.cli.add_command(fix_app_site_missing)
+    app.cli.add_command(migrate_data_for_plugin)

+ 10 - 1
api/core/app/app_config/easy_ui_based_app/model_config/manager.py

@@ -1,4 +1,5 @@
 from core.app.app_config.entities import ModelConfigEntity
+from core.entities import DEFAULT_PLUGIN_ID
 from core.model_runtime.entities.model_entities import ModelPropertyKey, ModelType
 from core.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
 from core.provider_manager import ProviderManager
@@ -53,7 +54,15 @@ class ModelConfigManager:
         model_provider_factory = ModelProviderFactory(tenant_id)
         provider_entities = model_provider_factory.get_providers()
         model_provider_names = [provider.provider for provider in provider_entities]
-        if "provider" not in config["model"] or config["model"]["provider"] not in model_provider_names:
+        if "provider" not in config["model"]:
+            raise ValueError(f"model.provider is required and must be in {str(model_provider_names)}")
+
+        if "/" not in config["model"]["provider"]:
+            config["model"]["provider"] = (
+                f"{DEFAULT_PLUGIN_ID}/{config['model']['provider']}/{config['model']['provider']}"
+            )
+
+        if config["model"]["provider"] not in model_provider_names:
             raise ValueError(f"model.provider is required and must be in {str(model_provider_names)}")
 
         # model.name

+ 7 - 0
api/core/entities/provider_configuration.py

@@ -9,6 +9,7 @@ from typing import Optional
 from pydantic import BaseModel, ConfigDict
 
 from constants import HIDDEN_VALUE
+from core.entities import DEFAULT_PLUGIN_ID
 from core.entities.model_entities import ModelStatus, ModelWithProviderEntity, SimpleModelProviderEntity
 from core.entities.provider_entities import (
     CustomConfiguration,
@@ -1047,6 +1048,9 @@ class ProviderConfigurations(BaseModel):
         return list(self.values())
 
     def __getitem__(self, key):
+        # A bare provider name (no "/") is the pre-plugin spelling; expand it to
+        # "{DEFAULT_PLUGIN_ID}/{name}/{name}" before indexing self.configurations.
+        # Non-string keys are passed through untouched so the lookup below reports
+        # the miss, instead of the "/" membership test raising TypeError first.
+        if isinstance(key, str) and "/" not in key:
+            key = f"{DEFAULT_PLUGIN_ID}/{key}/{key}"
+
         return self.configurations[key]
 
     def __setitem__(self, key, value):
@@ -1059,6 +1063,9 @@ class ProviderConfigurations(BaseModel):
         return iter(self.configurations.values())
 
     def get(self, key, default=None):
+        # A bare provider name (no "/") is the pre-plugin spelling; expand it to
+        # "{DEFAULT_PLUGIN_ID}/{name}/{name}" before the lookup.
+        # Non-string keys (e.g. None) skip the expansion so that get() falls back to
+        # `default` instead of raising TypeError from the "/" membership test.
+        if isinstance(key, str) and "/" not in key:
+            key = f"{DEFAULT_PLUGIN_ID}/{key}/{key}"
+
         return self.configurations.get(key, default)
 
 

+ 0 - 0
api/services/plugin/__init__.py


+ 184 - 0
api/services/plugin/data_migration.py

@@ -0,0 +1,184 @@
+import json
+import logging
+
+import click
+
+from core.entities import DEFAULT_PLUGIN_ID
+from extensions.ext_database import db
+
+logger = logging.getLogger(__name__)
+
+
+class PluginDataMigration:
+    @classmethod
+    def migrate(cls) -> None:
+        """
+        Run every migration step in a fixed order.
+
+        The generic per-table step is applied to each table that stores the provider
+        in a ``provider_name`` column; the datasets step runs between the two groups.
+        """
+        provider_column = "provider_name"
+
+        # Tables handled before the datasets step ("providers" is a large table).
+        tables_before_datasets = (
+            "providers",
+            "provider_models",
+            "provider_orders",
+            "tenant_default_models",
+            "tenant_preferred_model_providers",
+            "provider_model_settings",
+            "load_balancing_model_configs",
+        )
+        for table in tables_before_datasets:
+            cls.migrate_db_records(table, provider_column)
+
+        cls.migrate_datasets()
+
+        # Tables handled after the datasets step ("embeddings" is a large table).
+        tables_after_datasets = ("embeddings", "dataset_collection_bindings")
+        for table in tables_after_datasets:
+            cls.migrate_db_records(table, provider_column)
+
+    @classmethod
+    def migrate_datasets(cls) -> None:
+        """
+        Rewrite legacy provider names stored on the ``datasets`` table.
+
+        Rows whose ``embedding_model_provider`` is non-empty and contains no "/" are
+        selected in batches of 1000 and updated to ``{DEFAULT_PLUGIN_ID}/{name}/{name}``.
+        When a row's ``retrieval_model`` holds a legacy
+        ``reranking_model.reranking_provider_name``, that value is rewritten the same
+        way and the JSON is written back in the same UPDATE.
+
+        A row that fails is logged, remembered and skipped for the rest of the run.
+        The loop ends as soon as a batch migrates nothing, so rows that keep failing
+        cannot make it spin forever; re-run the command after fixing them.
+        """
+        table_name = "datasets"
+        provider_column_name = "embedding_model_provider"
+
+        click.echo(click.style(f"Migrating [{table_name}] data for plugin", fg="white"))
+
+        # Loop-invariant: migrated rows gain a "/" and drop out of this filter.
+        select_sql = f"""select id, {provider_column_name} as provider_name, retrieval_model from {table_name}
+where {provider_column_name} not like '%/%' and {provider_column_name} is not null and {provider_column_name} != ''
+limit 1000"""
+
+        processed_count = 0
+        failed_ids = set()
+        while True:
+            # Number of rows successfully migrated in this batch.
+            current_iter_count = 0
+            with db.engine.begin() as conn:
+                # Materialise the batch before issuing further statements on this connection.
+                rows = conn.execute(db.text(select_sql)).fetchall()
+
+                for row in rows:
+                    record_id = str(row.id)
+                    provider_name = str(row.provider_name)
+
+                    if record_id in failed_ids:
+                        continue
+
+                    try:
+                        params = {"record_id": record_id}
+                        update_retrieval_model_sql = ""
+
+                        # The JSON payload is inspected inside the try block so that a
+                        # malformed value marks this row as failed instead of aborting
+                        # the whole migration.
+                        retrieval_model = row.retrieval_model
+                        if isinstance(retrieval_model, str):
+                            # NOTE(review): some drivers/column types hand JSON back as text - confirm.
+                            retrieval_model = json.loads(retrieval_model)
+
+                        reranking_model = (
+                            retrieval_model.get("reranking_model") if isinstance(retrieval_model, dict) else None
+                        )
+                        reranking_provider_name = (
+                            reranking_model.get("reranking_provider_name")
+                            if isinstance(reranking_model, dict)
+                            else None
+                        )
+                        if reranking_provider_name and "/" not in reranking_provider_name:
+                            click.echo(
+                                click.style(
+                                    f"[{processed_count}] Migrating {table_name} {record_id} "
+                                    f"(reranking_provider_name: "
+                                    f"{reranking_provider_name})",
+                                    fg="white",
+                                )
+                            )
+                            reranking_model["reranking_provider_name"] = (
+                                f"{DEFAULT_PLUGIN_ID}/{reranking_provider_name}/{reranking_provider_name}"
+                            )
+                            update_retrieval_model_sql = ", retrieval_model = :retrieval_model"
+                            params["retrieval_model"] = json.dumps(retrieval_model)
+
+                        click.echo(
+                            click.style(
+                                f"[{processed_count}] Migrating [{table_name}] {record_id} ({provider_name})",
+                                fg="white",
+                            )
+                        )
+
+                        # Turn the provider name into "{DEFAULT_PLUGIN_ID}/{provider_name}/{provider_name}".
+                        sql = f"""update {table_name}
+                        set {provider_column_name} =
+                        concat('{DEFAULT_PLUGIN_ID}/', {provider_column_name}, '/', {provider_column_name})
+                        {update_retrieval_model_sql}
+                        where id = :record_id"""
+                        # SAVEPOINT per row: a failed UPDATE is rolled back on its own
+                        # and leaves the batch transaction usable for the remaining rows.
+                        with conn.begin_nested():
+                            conn.execute(db.text(sql), params)
+                        click.echo(
+                            click.style(
+                                f"[{processed_count}] Migrated [{table_name}] {record_id} ({provider_name})",
+                                fg="green",
+                            )
+                        )
+                    except Exception:
+                        failed_ids.add(record_id)
+                        click.echo(
+                            click.style(
+                                f"[{processed_count}] Failed to migrate [{table_name}] {record_id} ({provider_name})",
+                                fg="red",
+                            )
+                        )
+                        logger.exception(
+                            f"[{processed_count}] Failed to migrate [{table_name}] {record_id} ({provider_name})"
+                        )
+                        continue
+
+                    current_iter_count += 1
+                    processed_count += 1
+
+            if not current_iter_count:
+                break
+
+        click.echo(
+            click.style(f"Migrate [{table_name}] data for plugin completed, total: {processed_count}", fg="green")
+        )
+
+    @classmethod
+    def migrate_db_records(cls, table_name: str, provider_column_name: str) -> None:
+        """
+        Rewrite legacy provider names in one table to the plugin-qualified form.
+
+        Rows whose provider column is non-empty and contains no "/" are selected in
+        batches of 1000 and updated to ``{DEFAULT_PLUGIN_ID}/{name}/{name}``.
+
+        A row that fails is logged, remembered and skipped for the rest of the run.
+        Only successful updates count as progress, so the loop ends as soon as a
+        batch migrates nothing; rows that keep failing cannot make it spin forever.
+
+        Both arguments are interpolated directly into the SQL text, so they must be
+        trusted identifiers, never user input.
+
+        :param table_name: table to migrate; it must expose an ``id`` column
+        :param provider_column_name: column holding the provider name
+        """
+        click.echo(click.style(f"Migrating [{table_name}] data for plugin", fg="white"))
+
+        # Loop-invariant statements: migrated rows gain a "/" and drop out of the select filter.
+        select_sql = f"""select id, {provider_column_name} as provider_name from {table_name}
+where {provider_column_name} not like '%/%' and {provider_column_name} is not null and {provider_column_name} != ''
+limit 1000"""
+        # Turn the provider name into "{DEFAULT_PLUGIN_ID}/{provider_name}/{provider_name}".
+        update_sql = f"""update {table_name}
+                        set {provider_column_name} =
+                        concat('{DEFAULT_PLUGIN_ID}/', {provider_column_name}, '/', {provider_column_name})
+                        where id = :record_id"""
+
+        processed_count = 0
+        failed_ids = set()
+        while True:
+            # Number of rows successfully migrated in this batch.
+            current_iter_count = 0
+            with db.engine.begin() as conn:
+                # Materialise the batch before issuing further statements on this connection.
+                rows = conn.execute(db.text(select_sql)).fetchall()
+
+                for row in rows:
+                    record_id = str(row.id)
+                    provider_name = str(row.provider_name)
+
+                    # Rows that already failed are still returned by the select; skip
+                    # them without counting them as progress.
+                    if record_id in failed_ids:
+                        continue
+
+                    # 1-based position shown in the progress messages.
+                    position = processed_count + 1
+                    click.echo(
+                        click.style(
+                            f"[{position}] Migrating [{table_name}] {record_id} ({provider_name})",
+                            fg="white",
+                        )
+                    )
+
+                    try:
+                        # SAVEPOINT per row: a failed UPDATE is rolled back on its own
+                        # and leaves the batch transaction usable for the remaining rows.
+                        with conn.begin_nested():
+                            conn.execute(db.text(update_sql), {"record_id": record_id})
+                        click.echo(
+                            click.style(
+                                f"[{position}] Migrated [{table_name}] {record_id} ({provider_name})",
+                                fg="green",
+                            )
+                        )
+                    except Exception:
+                        failed_ids.add(record_id)
+                        click.echo(
+                            click.style(
+                                f"[{position}] Failed to migrate [{table_name}] {record_id} ({provider_name})",
+                                fg="red",
+                            )
+                        )
+                        logger.exception(
+                            f"[{position}] Failed to migrate [{table_name}] {record_id} ({provider_name})"
+                        )
+                        continue
+
+                    current_iter_count += 1
+                    processed_count += 1
+
+            if not current_iter_count:
+                break
+
+        click.echo(
+            click.style(f"Migrate [{table_name}] data for plugin completed, total: {processed_count}", fg="green")
+        )