Преглед на файлове

compatible with original provider name

takatost преди 8 месеца
родител
ревизия
f798add31c

+ 14 - 0
api/commands.py

@@ -25,6 +25,7 @@ from models.dataset import Document as DatasetDocument
 from models.model import Account, App, AppAnnotationSetting, AppMode, Conversation, MessageAnnotation
 from models.model import Account, App, AppAnnotationSetting, AppMode, Conversation, MessageAnnotation
 from models.provider import Provider, ProviderModel
 from models.provider import Provider, ProviderModel
 from services.account_service import RegisterService, TenantService
 from services.account_service import RegisterService, TenantService
+from services.plugin.data_migration import PluginDataMigration
 
 
 
 
 @click.command("reset-password", help="Reset the account password.")
 @click.command("reset-password", help="Reset the account password.")
@@ -639,6 +640,18 @@ where sites.id is null limit 1000"""
     click.echo(click.style("Fix for missing app-related sites completed successfully!", fg="green"))
     click.echo(click.style("Fix for missing app-related sites completed successfully!", fg="green"))
 
 
 
 
+@click.command("migrate-data-for-plugin", help="Migrate data for plugin.")
+def migrate_data_for_plugin():
+    """
+    Migrate data for plugin.
+    """
+    click.echo(click.style("Starting migrate data for plugin.", fg="white"))
+
+    PluginDataMigration.migrate()
+
+    click.echo(click.style("Migrate data for plugin completed.", fg="green"))
+
+
 def register_commands(app):
 def register_commands(app):
     app.cli.add_command(reset_password)
     app.cli.add_command(reset_password)
     app.cli.add_command(reset_email)
     app.cli.add_command(reset_email)
@@ -649,3 +662,4 @@ def register_commands(app):
     app.cli.add_command(create_tenant)
     app.cli.add_command(create_tenant)
     app.cli.add_command(upgrade_db)
     app.cli.add_command(upgrade_db)
     app.cli.add_command(fix_app_site_missing)
     app.cli.add_command(fix_app_site_missing)
+    app.cli.add_command(migrate_data_for_plugin)

+ 10 - 1
api/core/app/app_config/easy_ui_based_app/model_config/manager.py

@@ -1,4 +1,5 @@
 from core.app.app_config.entities import ModelConfigEntity
 from core.app.app_config.entities import ModelConfigEntity
+from core.entities import DEFAULT_PLUGIN_ID
 from core.model_runtime.entities.model_entities import ModelPropertyKey, ModelType
 from core.model_runtime.entities.model_entities import ModelPropertyKey, ModelType
 from core.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
 from core.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
 from core.provider_manager import ProviderManager
 from core.provider_manager import ProviderManager
@@ -53,7 +54,15 @@ class ModelConfigManager:
         model_provider_factory = ModelProviderFactory(tenant_id)
         model_provider_factory = ModelProviderFactory(tenant_id)
         provider_entities = model_provider_factory.get_providers()
         provider_entities = model_provider_factory.get_providers()
         model_provider_names = [provider.provider for provider in provider_entities]
         model_provider_names = [provider.provider for provider in provider_entities]
-        if "provider" not in config["model"] or config["model"]["provider"] not in model_provider_names:
+        if "provider" not in config["model"]:
+            raise ValueError(f"model.provider is required and must be in {str(model_provider_names)}")
+
+        if "/" not in config["model"]["provider"]:
+            config["model"]["provider"] = (
+                f"{DEFAULT_PLUGIN_ID}/{config['model']['provider']}/{config['model']['provider']}"
+            )
+
+        if config["model"]["provider"] not in model_provider_names:
             raise ValueError(f"model.provider is required and must be in {str(model_provider_names)}")
             raise ValueError(f"model.provider is required and must be in {str(model_provider_names)}")
 
 
         # model.name
         # model.name

+ 7 - 0
api/core/entities/provider_configuration.py

@@ -9,6 +9,7 @@ from typing import Optional
 from pydantic import BaseModel, ConfigDict
 from pydantic import BaseModel, ConfigDict
 
 
 from constants import HIDDEN_VALUE
 from constants import HIDDEN_VALUE
+from core.entities import DEFAULT_PLUGIN_ID
 from core.entities.model_entities import ModelStatus, ModelWithProviderEntity, SimpleModelProviderEntity
 from core.entities.model_entities import ModelStatus, ModelWithProviderEntity, SimpleModelProviderEntity
 from core.entities.provider_entities import (
 from core.entities.provider_entities import (
     CustomConfiguration,
     CustomConfiguration,
@@ -1047,6 +1048,9 @@ class ProviderConfigurations(BaseModel):
         return list(self.values())
         return list(self.values())
 
 
     def __getitem__(self, key):
     def __getitem__(self, key):
+        if "/" not in key:
+            key = f"{DEFAULT_PLUGIN_ID}/{key}/{key}"
+
         return self.configurations[key]
         return self.configurations[key]
 
 
     def __setitem__(self, key, value):
     def __setitem__(self, key, value):
@@ -1059,6 +1063,9 @@ class ProviderConfigurations(BaseModel):
         return iter(self.configurations.values())
         return iter(self.configurations.values())
 
 
     def get(self, key, default=None):
     def get(self, key, default=None):
+        if "/" not in key:
+            key = f"{DEFAULT_PLUGIN_ID}/{key}/{key}"
+
         return self.configurations.get(key, default)
         return self.configurations.get(key, default)
 
 
 
 

+ 0 - 0
api/services/plugin/__init__.py


+ 184 - 0
api/services/plugin/data_migration.py

@@ -0,0 +1,184 @@
+import json
+import logging
+
+import click
+
+from core.entities import DEFAULT_PLUGIN_ID
+from extensions.ext_database import db
+
+logger = logging.getLogger(__name__)
+
+
+class PluginDataMigration:
+    @classmethod
+    def migrate(cls) -> None:
+        cls.migrate_db_records("providers", "provider_name")  # large table
+        cls.migrate_db_records("provider_models", "provider_name")
+        cls.migrate_db_records("provider_orders", "provider_name")
+        cls.migrate_db_records("tenant_default_models", "provider_name")
+        cls.migrate_db_records("tenant_preferred_model_providers", "provider_name")
+        cls.migrate_db_records("provider_model_settings", "provider_name")
+        cls.migrate_db_records("load_balancing_model_configs", "provider_name")
+        cls.migrate_datasets()
+        cls.migrate_db_records("embeddings", "provider_name")  # large table
+        cls.migrate_db_records("dataset_collection_bindings", "provider_name")
+
+    @classmethod
+    def migrate_datasets(cls) -> None:
+        table_name = "datasets"
+        provider_column_name = "embedding_model_provider"
+
+        click.echo(click.style(f"Migrating [{table_name}] data for plugin", fg="white"))
+
+        processed_count = 0
+        failed_ids = []
+        while True:
+            sql = f"""select id, {provider_column_name} as provider_name, retrieval_model from {table_name}
+where {provider_column_name} not like '%/%' and {provider_column_name} is not null and {provider_column_name} != ''
+limit 1000"""
+            with db.engine.begin() as conn:
+                rs = conn.execute(db.text(sql))
+
+                current_iter_count = 0
+                for i in rs:
+                    record_id = str(i.id)
+                    provider_name = str(i.provider_name)
+                    retrieval_model = i.retrieval_model
+                    print(type(retrieval_model))
+
+                    if record_id in failed_ids:
+                        continue
+
+                    retrieval_model_changed = False
+                    if retrieval_model:
+                        if (
+                            "reranking_model" in retrieval_model
+                            and "reranking_provider_name" in retrieval_model["reranking_model"]
+                            and retrieval_model["reranking_model"]["reranking_provider_name"]
+                            and "/" not in retrieval_model["reranking_model"]["reranking_provider_name"]
+                        ):
+                            click.echo(
+                                click.style(
+                                    f"[{processed_count}] Migrating {table_name} {record_id} "
+                                    f"(reranking_provider_name: "
+                                    f"{retrieval_model['reranking_model']['reranking_provider_name']})",
+                                    fg="white",
+                                )
+                            )
+                            retrieval_model["reranking_model"]["reranking_provider_name"] = (
+                                f"{DEFAULT_PLUGIN_ID}/{retrieval_model['reranking_model']['reranking_provider_name']}/{retrieval_model['reranking_model']['reranking_provider_name']}"
+                            )
+                            retrieval_model_changed = True
+
+                    click.echo(
+                        click.style(
+                            f"[{processed_count}] Migrating [{table_name}] {record_id} ({provider_name})",
+                            fg="white",
+                        )
+                    )
+
+                    try:
+                        # update provider name append with "langgenius/{provider_name}/{provider_name}"
+                        params = {"record_id": record_id}
+                        update_retrieval_model_sql = ""
+                        if retrieval_model and retrieval_model_changed:
+                            update_retrieval_model_sql = ", retrieval_model = :retrieval_model"
+                            params["retrieval_model"] = json.dumps(retrieval_model)
+
+                        sql = f"""update {table_name} 
+                        set {provider_column_name} = 
+                        concat('{DEFAULT_PLUGIN_ID}/', {provider_column_name}, '/', {provider_column_name}) 
+                        {update_retrieval_model_sql}
+                        where id = :record_id"""
+                        conn.execute(db.text(sql), params)
+                        click.echo(
+                            click.style(
+                                f"[{processed_count}] Migrated [{table_name}] {record_id} ({provider_name})",
+                                fg="green",
+                            )
+                        )
+                    except Exception:
+                        failed_ids.append(record_id)
+                        click.echo(
+                            click.style(
+                                f"[{processed_count}] Failed to migrate [{table_name}] {record_id} ({provider_name})",
+                                fg="red",
+                            )
+                        )
+                        logger.exception(
+                            f"[{processed_count}] Failed to migrate [{table_name}] {record_id} ({provider_name})"
+                        )
+                        continue
+
+                    current_iter_count += 1
+                    processed_count += 1
+
+            if not current_iter_count:
+                break
+
+        click.echo(
+            click.style(f"Migrate [{table_name}] data for plugin completed, total: {processed_count}", fg="green")
+        )
+
+    @classmethod
+    def migrate_db_records(cls, table_name: str, provider_column_name: str) -> None:
+        click.echo(click.style(f"Migrating [{table_name}] data for plugin", fg="white"))
+
+        processed_count = 0
+        failed_ids = []
+        while True:
+            sql = f"""select id, {provider_column_name} as provider_name from {table_name}
+where {provider_column_name} not like '%/%' and {provider_column_name} is not null and {provider_column_name} != ''
+limit 1000"""
+            with db.engine.begin() as conn:
+                rs = conn.execute(db.text(sql))
+
+                current_iter_count = 0
+                for i in rs:
+                    current_iter_count += 1
+                    processed_count += 1
+                    record_id = str(i.id)
+                    provider_name = str(i.provider_name)
+
+                    if record_id in failed_ids:
+                        continue
+
+                    click.echo(
+                        click.style(
+                            f"[{processed_count}] Migrating [{table_name}] {record_id} ({provider_name})",
+                            fg="white",
+                        )
+                    )
+
+                    try:
+                        # update provider name append with "langgenius/{provider_name}/{provider_name}"
+                        sql = f"""update {table_name} 
+                        set {provider_column_name} = 
+                        concat('{DEFAULT_PLUGIN_ID}/', {provider_column_name}, '/', {provider_column_name}) 
+                        where id = :record_id"""
+                        conn.execute(db.text(sql), {"record_id": record_id})
+                        click.echo(
+                            click.style(
+                                f"[{processed_count}] Migrated [{table_name}] {record_id} ({provider_name})",
+                                fg="green",
+                            )
+                        )
+                    except Exception:
+                        failed_ids.append(record_id)
+                        click.echo(
+                            click.style(
+                                f"[{processed_count}] Failed to migrate [{table_name}] {record_id} ({provider_name})",
+                                fg="red",
+                            )
+                        )
+                        logger.exception(
+                            f"[{processed_count}] Failed to migrate [{table_name}] {record_id} ({provider_name})"
+                        )
+                        continue
+
+            if not current_iter_count:
+                break
+
+        click.echo(
+            click.style(f"Migrate [{table_name}] data for plugin completed, total: {processed_count}", fg="green")
+        )