# data_migration.py
  1. import json
  2. import logging
  3. import click
  4. from core.entities import DEFAULT_PLUGIN_ID
  5. from models.engine import db
  6. logger = logging.getLogger(__name__)
  7. class PluginDataMigration:
  8. @classmethod
  9. def migrate(cls) -> None:
  10. cls.migrate_db_records("providers", "provider_name") # large table
  11. cls.migrate_db_records("provider_models", "provider_name")
  12. cls.migrate_db_records("provider_orders", "provider_name")
  13. cls.migrate_db_records("tenant_default_models", "provider_name")
  14. cls.migrate_db_records("tenant_preferred_model_providers", "provider_name")
  15. cls.migrate_db_records("provider_model_settings", "provider_name")
  16. cls.migrate_db_records("load_balancing_model_configs", "provider_name")
  17. cls.migrate_datasets()
  18. cls.migrate_db_records("embeddings", "provider_name") # large table
  19. cls.migrate_db_records("dataset_collection_bindings", "provider_name")
  20. cls.migrate_db_records("tool_builtin_providers", "provider")
  21. @classmethod
  22. def migrate_datasets(cls) -> None:
  23. table_name = "datasets"
  24. provider_column_name = "embedding_model_provider"
  25. click.echo(click.style(f"Migrating [{table_name}] data for plugin", fg="white"))
  26. processed_count = 0
  27. failed_ids = []
  28. while True:
  29. sql = f"""select id, {provider_column_name} as provider_name, retrieval_model from {table_name}
  30. where {provider_column_name} not like '%/%' and {provider_column_name} is not null and {provider_column_name} != ''
  31. limit 1000"""
  32. with db.engine.begin() as conn:
  33. rs = conn.execute(db.text(sql))
  34. current_iter_count = 0
  35. for i in rs:
  36. record_id = str(i.id)
  37. provider_name = str(i.provider_name)
  38. retrieval_model = i.retrieval_model
  39. print(type(retrieval_model))
  40. if record_id in failed_ids:
  41. continue
  42. retrieval_model_changed = False
  43. if retrieval_model:
  44. if (
  45. "reranking_model" in retrieval_model
  46. and "reranking_provider_name" in retrieval_model["reranking_model"]
  47. and retrieval_model["reranking_model"]["reranking_provider_name"]
  48. and "/" not in retrieval_model["reranking_model"]["reranking_provider_name"]
  49. ):
  50. click.echo(
  51. click.style(
  52. f"[{processed_count}] Migrating {table_name} {record_id} "
  53. f"(reranking_provider_name: "
  54. f"{retrieval_model['reranking_model']['reranking_provider_name']})",
  55. fg="white",
  56. )
  57. )
  58. retrieval_model["reranking_model"]["reranking_provider_name"] = (
  59. f"{DEFAULT_PLUGIN_ID}/{retrieval_model['reranking_model']['reranking_provider_name']}/{retrieval_model['reranking_model']['reranking_provider_name']}"
  60. )
  61. retrieval_model_changed = True
  62. click.echo(
  63. click.style(
  64. f"[{processed_count}] Migrating [{table_name}] {record_id} ({provider_name})",
  65. fg="white",
  66. )
  67. )
  68. try:
  69. # update provider name append with "langgenius/{provider_name}/{provider_name}"
  70. params = {"record_id": record_id}
  71. update_retrieval_model_sql = ""
  72. if retrieval_model and retrieval_model_changed:
  73. update_retrieval_model_sql = ", retrieval_model = :retrieval_model"
  74. params["retrieval_model"] = json.dumps(retrieval_model)
  75. sql = f"""update {table_name}
  76. set {provider_column_name} =
  77. concat('{DEFAULT_PLUGIN_ID}/', {provider_column_name}, '/', {provider_column_name})
  78. {update_retrieval_model_sql}
  79. where id = :record_id"""
  80. conn.execute(db.text(sql), params)
  81. click.echo(
  82. click.style(
  83. f"[{processed_count}] Migrated [{table_name}] {record_id} ({provider_name})",
  84. fg="green",
  85. )
  86. )
  87. except Exception:
  88. failed_ids.append(record_id)
  89. click.echo(
  90. click.style(
  91. f"[{processed_count}] Failed to migrate [{table_name}] {record_id} ({provider_name})",
  92. fg="red",
  93. )
  94. )
  95. logger.exception(
  96. f"[{processed_count}] Failed to migrate [{table_name}] {record_id} ({provider_name})"
  97. )
  98. continue
  99. current_iter_count += 1
  100. processed_count += 1
  101. if not current_iter_count:
  102. break
  103. click.echo(
  104. click.style(f"Migrate [{table_name}] data for plugin completed, total: {processed_count}", fg="green")
  105. )
  106. @classmethod
  107. def migrate_db_records(cls, table_name: str, provider_column_name: str) -> None:
  108. click.echo(click.style(f"Migrating [{table_name}] data for plugin", fg="white"))
  109. processed_count = 0
  110. failed_ids = []
  111. while True:
  112. sql = f"""select id, {provider_column_name} as provider_name from {table_name}
  113. where {provider_column_name} not like '%/%' and {provider_column_name} is not null and {provider_column_name} != ''
  114. limit 1000"""
  115. with db.engine.begin() as conn:
  116. rs = conn.execute(db.text(sql))
  117. current_iter_count = 0
  118. for i in rs:
  119. current_iter_count += 1
  120. processed_count += 1
  121. record_id = str(i.id)
  122. provider_name = str(i.provider_name)
  123. if record_id in failed_ids:
  124. continue
  125. click.echo(
  126. click.style(
  127. f"[{processed_count}] Migrating [{table_name}] {record_id} ({provider_name})",
  128. fg="white",
  129. )
  130. )
  131. try:
  132. # update provider name append with "langgenius/{provider_name}/{provider_name}"
  133. sql = f"""update {table_name}
  134. set {provider_column_name} =
  135. concat('{DEFAULT_PLUGIN_ID}/', {provider_column_name}, '/', {provider_column_name})
  136. where id = :record_id"""
  137. conn.execute(db.text(sql), {"record_id": record_id})
  138. click.echo(
  139. click.style(
  140. f"[{processed_count}] Migrated [{table_name}] {record_id} ({provider_name})",
  141. fg="green",
  142. )
  143. )
  144. except Exception:
  145. failed_ids.append(record_id)
  146. click.echo(
  147. click.style(
  148. f"[{processed_count}] Failed to migrate [{table_name}] {record_id} ({provider_name})",
  149. fg="red",
  150. )
  151. )
  152. logger.exception(
  153. f"[{processed_count}] Failed to migrate [{table_name}] {record_id} ({provider_name})"
  154. )
  155. continue
  156. if not current_iter_count:
  157. break
  158. click.echo(
  159. click.style(f"Migrate [{table_name}] data for plugin completed, total: {processed_count}", fg="green")
  160. )