# data_migration.py (8.2 KB)
  1. import json
  2. import logging
  3. import click
  4. from core.entities import DEFAULT_PLUGIN_ID
  5. from extensions.ext_database import db
# Module-level logger named after this module; the migration uses it to record
# failed row updates together with their traceback.
logger = logging.getLogger(__name__)
  7. class PluginDataMigration:
  8. @classmethod
  9. def migrate(cls) -> None:
  10. cls.migrate_db_records("providers", "provider_name") # large table
  11. cls.migrate_db_records("provider_models", "provider_name")
  12. cls.migrate_db_records("provider_orders", "provider_name")
  13. cls.migrate_db_records("tenant_default_models", "provider_name")
  14. cls.migrate_db_records("tenant_preferred_model_providers", "provider_name")
  15. cls.migrate_db_records("provider_model_settings", "provider_name")
  16. cls.migrate_db_records("load_balancing_model_configs", "provider_name")
  17. cls.migrate_datasets()
  18. cls.migrate_db_records("embeddings", "provider_name") # large table
  19. cls.migrate_db_records("dataset_collection_bindings", "provider_name")
  20. @classmethod
  21. def migrate_datasets(cls) -> None:
  22. table_name = "datasets"
  23. provider_column_name = "embedding_model_provider"
  24. click.echo(click.style(f"Migrating [{table_name}] data for plugin", fg="white"))
  25. processed_count = 0
  26. failed_ids = []
  27. while True:
  28. sql = f"""select id, {provider_column_name} as provider_name, retrieval_model from {table_name}
  29. where {provider_column_name} not like '%/%' and {provider_column_name} is not null and {provider_column_name} != ''
  30. limit 1000"""
  31. with db.engine.begin() as conn:
  32. rs = conn.execute(db.text(sql))
  33. current_iter_count = 0
  34. for i in rs:
  35. record_id = str(i.id)
  36. provider_name = str(i.provider_name)
  37. retrieval_model = i.retrieval_model
  38. print(type(retrieval_model))
  39. if record_id in failed_ids:
  40. continue
  41. retrieval_model_changed = False
  42. if retrieval_model:
  43. if (
  44. "reranking_model" in retrieval_model
  45. and "reranking_provider_name" in retrieval_model["reranking_model"]
  46. and retrieval_model["reranking_model"]["reranking_provider_name"]
  47. and "/" not in retrieval_model["reranking_model"]["reranking_provider_name"]
  48. ):
  49. click.echo(
  50. click.style(
  51. f"[{processed_count}] Migrating {table_name} {record_id} "
  52. f"(reranking_provider_name: "
  53. f"{retrieval_model['reranking_model']['reranking_provider_name']})",
  54. fg="white",
  55. )
  56. )
  57. retrieval_model["reranking_model"]["reranking_provider_name"] = (
  58. f"{DEFAULT_PLUGIN_ID}/{retrieval_model['reranking_model']['reranking_provider_name']}/{retrieval_model['reranking_model']['reranking_provider_name']}"
  59. )
  60. retrieval_model_changed = True
  61. click.echo(
  62. click.style(
  63. f"[{processed_count}] Migrating [{table_name}] {record_id} ({provider_name})",
  64. fg="white",
  65. )
  66. )
  67. try:
  68. # update provider name append with "langgenius/{provider_name}/{provider_name}"
  69. params = {"record_id": record_id}
  70. update_retrieval_model_sql = ""
  71. if retrieval_model and retrieval_model_changed:
  72. update_retrieval_model_sql = ", retrieval_model = :retrieval_model"
  73. params["retrieval_model"] = json.dumps(retrieval_model)
  74. sql = f"""update {table_name}
  75. set {provider_column_name} =
  76. concat('{DEFAULT_PLUGIN_ID}/', {provider_column_name}, '/', {provider_column_name})
  77. {update_retrieval_model_sql}
  78. where id = :record_id"""
  79. conn.execute(db.text(sql), params)
  80. click.echo(
  81. click.style(
  82. f"[{processed_count}] Migrated [{table_name}] {record_id} ({provider_name})",
  83. fg="green",
  84. )
  85. )
  86. except Exception:
  87. failed_ids.append(record_id)
  88. click.echo(
  89. click.style(
  90. f"[{processed_count}] Failed to migrate [{table_name}] {record_id} ({provider_name})",
  91. fg="red",
  92. )
  93. )
  94. logger.exception(
  95. f"[{processed_count}] Failed to migrate [{table_name}] {record_id} ({provider_name})"
  96. )
  97. continue
  98. current_iter_count += 1
  99. processed_count += 1
  100. if not current_iter_count:
  101. break
  102. click.echo(
  103. click.style(f"Migrate [{table_name}] data for plugin completed, total: {processed_count}", fg="green")
  104. )
  105. @classmethod
  106. def migrate_db_records(cls, table_name: str, provider_column_name: str) -> None:
  107. click.echo(click.style(f"Migrating [{table_name}] data for plugin", fg="white"))
  108. processed_count = 0
  109. failed_ids = []
  110. while True:
  111. sql = f"""select id, {provider_column_name} as provider_name from {table_name}
  112. where {provider_column_name} not like '%/%' and {provider_column_name} is not null and {provider_column_name} != ''
  113. limit 1000"""
  114. with db.engine.begin() as conn:
  115. rs = conn.execute(db.text(sql))
  116. current_iter_count = 0
  117. for i in rs:
  118. current_iter_count += 1
  119. processed_count += 1
  120. record_id = str(i.id)
  121. provider_name = str(i.provider_name)
  122. if record_id in failed_ids:
  123. continue
  124. click.echo(
  125. click.style(
  126. f"[{processed_count}] Migrating [{table_name}] {record_id} ({provider_name})",
  127. fg="white",
  128. )
  129. )
  130. try:
  131. # update provider name append with "langgenius/{provider_name}/{provider_name}"
  132. sql = f"""update {table_name}
  133. set {provider_column_name} =
  134. concat('{DEFAULT_PLUGIN_ID}/', {provider_column_name}, '/', {provider_column_name})
  135. where id = :record_id"""
  136. conn.execute(db.text(sql), {"record_id": record_id})
  137. click.echo(
  138. click.style(
  139. f"[{processed_count}] Migrated [{table_name}] {record_id} ({provider_name})",
  140. fg="green",
  141. )
  142. )
  143. except Exception:
  144. failed_ids.append(record_id)
  145. click.echo(
  146. click.style(
  147. f"[{processed_count}] Failed to migrate [{table_name}] {record_id} ({provider_name})",
  148. fg="red",
  149. )
  150. )
  151. logger.exception(
  152. f"[{processed_count}] Failed to migrate [{table_name}] {record_id} ({provider_name})"
  153. )
  154. continue
  155. if not current_iter_count:
  156. break
  157. click.echo(
  158. click.style(f"Migrate [{table_name}] data for plugin completed, total: {processed_count}", fg="green")
  159. )