commands.py 32 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734
  1. import datetime
  2. import json
  3. import math
  4. import random
  5. import string
  6. import threading
  7. import time
  8. import uuid
  9. import click
  10. from tqdm import tqdm
  11. from flask import current_app, Flask
  12. from langchain.embeddings import OpenAIEmbeddings
  13. from werkzeug.exceptions import NotFound
  14. from core.embedding.cached_embedding import CacheEmbedding
  15. from core.index.index import IndexBuilder
  16. from core.model_providers.model_factory import ModelFactory
  17. from core.model_providers.models.embedding.openai_embedding import OpenAIEmbedding
  18. from core.model_providers.models.entity.model_params import ModelType
  19. from core.model_providers.providers.hosted import hosted_model_providers
  20. from core.model_providers.providers.openai_provider import OpenAIProvider
  21. from libs.password import password_pattern, valid_password, hash_password
  22. from libs.helper import email as email_validate
  23. from extensions.ext_database import db
  24. from libs.rsa import generate_key_pair
  25. from models.account import InvitationCode, Tenant, TenantAccountJoin
  26. from models.dataset import Dataset, DatasetQuery, Document, DatasetCollectionBinding
  27. from models.model import Account, AppModelConfig, App
  28. import secrets
  29. import base64
  30. from models.provider import Provider, ProviderType, ProviderQuotaType, ProviderModel
  31. @click.command('reset-password', help='Reset the account password.')
  32. @click.option('--email', prompt=True, help='The email address of the account whose password you need to reset')
  33. @click.option('--new-password', prompt=True, help='the new password.')
  34. @click.option('--password-confirm', prompt=True, help='the new password confirm.')
  35. def reset_password(email, new_password, password_confirm):
  36. if str(new_password).strip() != str(password_confirm).strip():
  37. click.echo(click.style('sorry. The two passwords do not match.', fg='red'))
  38. return
  39. account = db.session.query(Account). \
  40. filter(Account.email == email). \
  41. one_or_none()
  42. if not account:
  43. click.echo(click.style('sorry. the account: [{}] not exist .'.format(email), fg='red'))
  44. return
  45. try:
  46. valid_password(new_password)
  47. except:
  48. click.echo(
  49. click.style('sorry. The passwords must match {} '.format(password_pattern), fg='red'))
  50. return
  51. # generate password salt
  52. salt = secrets.token_bytes(16)
  53. base64_salt = base64.b64encode(salt).decode()
  54. # encrypt password with salt
  55. password_hashed = hash_password(new_password, salt)
  56. base64_password_hashed = base64.b64encode(password_hashed).decode()
  57. account.password = base64_password_hashed
  58. account.password_salt = base64_salt
  59. db.session.commit()
  60. click.echo(click.style('Congratulations!, password has been reset.', fg='green'))
  61. @click.command('reset-email', help='Reset the account email.')
  62. @click.option('--email', prompt=True, help='The old email address of the account whose email you need to reset')
  63. @click.option('--new-email', prompt=True, help='the new email.')
  64. @click.option('--email-confirm', prompt=True, help='the new email confirm.')
  65. def reset_email(email, new_email, email_confirm):
  66. if str(new_email).strip() != str(email_confirm).strip():
  67. click.echo(click.style('Sorry, new email and confirm email do not match.', fg='red'))
  68. return
  69. account = db.session.query(Account). \
  70. filter(Account.email == email). \
  71. one_or_none()
  72. if not account:
  73. click.echo(click.style('sorry. the account: [{}] not exist .'.format(email), fg='red'))
  74. return
  75. try:
  76. email_validate(new_email)
  77. except:
  78. click.echo(
  79. click.style('sorry. {} is not a valid email. '.format(email), fg='red'))
  80. return
  81. account.email = new_email
  82. db.session.commit()
  83. click.echo(click.style('Congratulations!, email has been reset.', fg='green'))
  84. @click.command('reset-encrypt-key-pair', help='Reset the asymmetric key pair of workspace for encrypt LLM credentials. '
  85. 'After the reset, all LLM credentials will become invalid, '
  86. 'requiring re-entry.'
  87. 'Only support SELF_HOSTED mode.')
  88. @click.confirmation_option(prompt=click.style('Are you sure you want to reset encrypt key pair?'
  89. ' this operation cannot be rolled back!', fg='red'))
  90. def reset_encrypt_key_pair():
  91. if current_app.config['EDITION'] != 'SELF_HOSTED':
  92. click.echo(click.style('Sorry, only support SELF_HOSTED mode.', fg='red'))
  93. return
  94. tenant = db.session.query(Tenant).first()
  95. if not tenant:
  96. click.echo(click.style('Sorry, no workspace found. Please enter /install to initialize.', fg='red'))
  97. return
  98. tenant.encrypt_public_key = generate_key_pair(tenant.id)
  99. db.session.query(Provider).filter(Provider.provider_type == 'custom').delete()
  100. db.session.query(ProviderModel).delete()
  101. db.session.commit()
  102. click.echo(click.style('Congratulations! '
  103. 'the asymmetric key pair of workspace {} has been reset.'.format(tenant.id), fg='green'))
  104. @click.command('generate-invitation-codes', help='Generate invitation codes.')
  105. @click.option('--batch', help='The batch of invitation codes.')
  106. @click.option('--count', prompt=True, help='Invitation codes count.')
  107. def generate_invitation_codes(batch, count):
  108. if not batch:
  109. now = datetime.datetime.now()
  110. batch = now.strftime('%Y%m%d%H%M%S')
  111. if not count or int(count) <= 0:
  112. click.echo(click.style('sorry. the count must be greater than 0.', fg='red'))
  113. return
  114. count = int(count)
  115. click.echo('Start generate {} invitation codes for batch {}.'.format(count, batch))
  116. codes = ''
  117. for i in range(count):
  118. code = generate_invitation_code()
  119. invitation_code = InvitationCode(
  120. code=code,
  121. batch=batch
  122. )
  123. db.session.add(invitation_code)
  124. click.echo(code)
  125. codes += code + "\n"
  126. db.session.commit()
  127. filename = 'storage/invitation-codes-{}.txt'.format(batch)
  128. with open(filename, 'w') as f:
  129. f.write(codes)
  130. click.echo(click.style(
  131. 'Congratulations! Generated {} invitation codes for batch {} and saved to the file \'{}\''.format(count, batch,
  132. filename),
  133. fg='green'))
  134. def generate_invitation_code():
  135. code = generate_upper_string()
  136. while db.session.query(InvitationCode).filter(InvitationCode.code == code).count() > 0:
  137. code = generate_upper_string()
  138. return code
  139. def generate_upper_string():
  140. letters_digits = string.ascii_uppercase + string.digits
  141. result = ""
  142. for i in range(8):
  143. result += random.choice(letters_digits)
  144. return result
  145. @click.command('recreate-all-dataset-indexes', help='Recreate all dataset indexes.')
  146. def recreate_all_dataset_indexes():
  147. click.echo(click.style('Start recreate all dataset indexes.', fg='green'))
  148. recreate_count = 0
  149. page = 1
  150. while True:
  151. try:
  152. datasets = db.session.query(Dataset).filter(Dataset.indexing_technique == 'high_quality') \
  153. .order_by(Dataset.created_at.desc()).paginate(page=page, per_page=50)
  154. except NotFound:
  155. break
  156. page += 1
  157. for dataset in datasets:
  158. try:
  159. click.echo('Recreating dataset index: {}'.format(dataset.id))
  160. index = IndexBuilder.get_index(dataset, 'high_quality')
  161. if index and index._is_origin():
  162. index.recreate_dataset(dataset)
  163. recreate_count += 1
  164. else:
  165. click.echo('passed.')
  166. except Exception as e:
  167. click.echo(
  168. click.style('Recreate dataset index error: {} {}'.format(e.__class__.__name__, str(e)), fg='red'))
  169. continue
  170. click.echo(click.style('Congratulations! Recreate {} dataset indexes.'.format(recreate_count), fg='green'))
  171. @click.command('clean-unused-dataset-indexes', help='Clean unused dataset indexes.')
  172. def clean_unused_dataset_indexes():
  173. click.echo(click.style('Start clean unused dataset indexes.', fg='green'))
  174. clean_days = int(current_app.config.get('CLEAN_DAY_SETTING'))
  175. start_at = time.perf_counter()
  176. thirty_days_ago = datetime.datetime.now() - datetime.timedelta(days=clean_days)
  177. page = 1
  178. while True:
  179. try:
  180. datasets = db.session.query(Dataset).filter(Dataset.created_at < thirty_days_ago) \
  181. .order_by(Dataset.created_at.desc()).paginate(page=page, per_page=50)
  182. except NotFound:
  183. break
  184. page += 1
  185. for dataset in datasets:
  186. dataset_query = db.session.query(DatasetQuery).filter(
  187. DatasetQuery.created_at > thirty_days_ago,
  188. DatasetQuery.dataset_id == dataset.id
  189. ).all()
  190. if not dataset_query or len(dataset_query) == 0:
  191. documents = db.session.query(Document).filter(
  192. Document.dataset_id == dataset.id,
  193. Document.indexing_status == 'completed',
  194. Document.enabled == True,
  195. Document.archived == False,
  196. Document.updated_at > thirty_days_ago
  197. ).all()
  198. if not documents or len(documents) == 0:
  199. try:
  200. # remove index
  201. vector_index = IndexBuilder.get_index(dataset, 'high_quality')
  202. kw_index = IndexBuilder.get_index(dataset, 'economy')
  203. # delete from vector index
  204. if vector_index:
  205. if dataset.collection_binding_id:
  206. vector_index.delete_by_group_id(dataset.id)
  207. else:
  208. if dataset.collection_binding_id:
  209. vector_index.delete_by_group_id(dataset.id)
  210. else:
  211. vector_index.delete()
  212. kw_index.delete()
  213. # update document
  214. update_params = {
  215. Document.enabled: False
  216. }
  217. Document.query.filter_by(dataset_id=dataset.id).update(update_params)
  218. db.session.commit()
  219. click.echo(click.style('Cleaned unused dataset {} from db success!'.format(dataset.id),
  220. fg='green'))
  221. except Exception as e:
  222. click.echo(
  223. click.style('clean dataset index error: {} {}'.format(e.__class__.__name__, str(e)),
  224. fg='red'))
  225. end_at = time.perf_counter()
  226. click.echo(click.style('Cleaned unused dataset from db success latency: {}'.format(end_at - start_at), fg='green'))
  227. @click.command('sync-anthropic-hosted-providers', help='Sync anthropic hosted providers.')
  228. def sync_anthropic_hosted_providers():
  229. if not hosted_model_providers.anthropic:
  230. click.echo(click.style('Anthropic hosted provider is not configured.', fg='red'))
  231. return
  232. click.echo(click.style('Start sync anthropic hosted providers.', fg='green'))
  233. count = 0
  234. new_quota_limit = hosted_model_providers.anthropic.quota_limit
  235. page = 1
  236. while True:
  237. try:
  238. providers = db.session.query(Provider).filter(
  239. Provider.provider_name == 'anthropic',
  240. Provider.provider_type == ProviderType.SYSTEM.value,
  241. Provider.quota_type == ProviderQuotaType.TRIAL.value,
  242. Provider.quota_limit != new_quota_limit
  243. ).order_by(Provider.created_at.desc()).paginate(page=page, per_page=100)
  244. except NotFound:
  245. break
  246. page += 1
  247. for provider in providers:
  248. try:
  249. click.echo('Syncing tenant anthropic hosted provider: {}, origin: limit {}, used {}'
  250. .format(provider.tenant_id, provider.quota_limit, provider.quota_used))
  251. original_quota_limit = provider.quota_limit
  252. division = math.ceil(new_quota_limit / 1000)
  253. provider.quota_limit = new_quota_limit if original_quota_limit == 1000 \
  254. else original_quota_limit * division
  255. provider.quota_used = division * provider.quota_used
  256. db.session.commit()
  257. count += 1
  258. except Exception as e:
  259. click.echo(click.style(
  260. 'Sync tenant anthropic hosted provider error: {} {}'.format(e.__class__.__name__, str(e)),
  261. fg='red'))
  262. continue
  263. click.echo(click.style('Congratulations! Synced {} anthropic hosted providers.'.format(count), fg='green'))
  264. @click.command('create-qdrant-indexes', help='Create qdrant indexes.')
  265. def create_qdrant_indexes():
  266. click.echo(click.style('Start create qdrant indexes.', fg='green'))
  267. create_count = 0
  268. page = 1
  269. while True:
  270. try:
  271. datasets = db.session.query(Dataset).filter(Dataset.indexing_technique == 'high_quality') \
  272. .order_by(Dataset.created_at.desc()).paginate(page=page, per_page=50)
  273. except NotFound:
  274. break
  275. page += 1
  276. for dataset in datasets:
  277. if dataset.index_struct_dict:
  278. if dataset.index_struct_dict['type'] != 'qdrant':
  279. try:
  280. click.echo('Create dataset qdrant index: {}'.format(dataset.id))
  281. try:
  282. embedding_model = ModelFactory.get_embedding_model(
  283. tenant_id=dataset.tenant_id,
  284. model_provider_name=dataset.embedding_model_provider,
  285. model_name=dataset.embedding_model
  286. )
  287. except Exception:
  288. try:
  289. embedding_model = ModelFactory.get_embedding_model(
  290. tenant_id=dataset.tenant_id
  291. )
  292. dataset.embedding_model = embedding_model.name
  293. dataset.embedding_model_provider = embedding_model.model_provider.provider_name
  294. except Exception:
  295. provider = Provider(
  296. id='provider_id',
  297. tenant_id=dataset.tenant_id,
  298. provider_name='openai',
  299. provider_type=ProviderType.SYSTEM.value,
  300. encrypted_config=json.dumps({'openai_api_key': 'TEST'}),
  301. is_valid=True,
  302. )
  303. model_provider = OpenAIProvider(provider=provider)
  304. embedding_model = OpenAIEmbedding(name="text-embedding-ada-002",
  305. model_provider=model_provider)
  306. embeddings = CacheEmbedding(embedding_model)
  307. from core.index.vector_index.qdrant_vector_index import QdrantVectorIndex, QdrantConfig
  308. index = QdrantVectorIndex(
  309. dataset=dataset,
  310. config=QdrantConfig(
  311. endpoint=current_app.config.get('QDRANT_URL'),
  312. api_key=current_app.config.get('QDRANT_API_KEY'),
  313. root_path=current_app.root_path
  314. ),
  315. embeddings=embeddings
  316. )
  317. if index:
  318. index.create_qdrant_dataset(dataset)
  319. index_struct = {
  320. "type": 'qdrant',
  321. "vector_store": {
  322. "class_prefix": dataset.index_struct_dict['vector_store']['class_prefix']}
  323. }
  324. dataset.index_struct = json.dumps(index_struct)
  325. db.session.commit()
  326. create_count += 1
  327. else:
  328. click.echo('passed.')
  329. except Exception as e:
  330. click.echo(
  331. click.style('Create dataset index error: {} {}'.format(e.__class__.__name__, str(e)),
  332. fg='red'))
  333. continue
  334. click.echo(click.style('Congratulations! Create {} dataset indexes.'.format(create_count), fg='green'))
  335. @click.command('update-qdrant-indexes', help='Update qdrant indexes.')
  336. def update_qdrant_indexes():
  337. click.echo(click.style('Start Update qdrant indexes.', fg='green'))
  338. create_count = 0
  339. page = 1
  340. while True:
  341. try:
  342. datasets = db.session.query(Dataset).filter(Dataset.indexing_technique == 'high_quality') \
  343. .order_by(Dataset.created_at.desc()).paginate(page=page, per_page=50)
  344. except NotFound:
  345. break
  346. page += 1
  347. for dataset in datasets:
  348. if dataset.index_struct_dict:
  349. if dataset.index_struct_dict['type'] != 'qdrant':
  350. try:
  351. click.echo('Update dataset qdrant index: {}'.format(dataset.id))
  352. try:
  353. embedding_model = ModelFactory.get_embedding_model(
  354. tenant_id=dataset.tenant_id,
  355. model_provider_name=dataset.embedding_model_provider,
  356. model_name=dataset.embedding_model
  357. )
  358. except Exception:
  359. provider = Provider(
  360. id='provider_id',
  361. tenant_id=dataset.tenant_id,
  362. provider_name='openai',
  363. provider_type=ProviderType.CUSTOM.value,
  364. encrypted_config=json.dumps({'openai_api_key': 'TEST'}),
  365. is_valid=True,
  366. )
  367. model_provider = OpenAIProvider(provider=provider)
  368. embedding_model = OpenAIEmbedding(name="text-embedding-ada-002",
  369. model_provider=model_provider)
  370. embeddings = CacheEmbedding(embedding_model)
  371. from core.index.vector_index.qdrant_vector_index import QdrantVectorIndex, QdrantConfig
  372. index = QdrantVectorIndex(
  373. dataset=dataset,
  374. config=QdrantConfig(
  375. endpoint=current_app.config.get('QDRANT_URL'),
  376. api_key=current_app.config.get('QDRANT_API_KEY'),
  377. root_path=current_app.root_path
  378. ),
  379. embeddings=embeddings
  380. )
  381. if index:
  382. index.update_qdrant_dataset(dataset)
  383. create_count += 1
  384. else:
  385. click.echo('passed.')
  386. except Exception as e:
  387. click.echo(
  388. click.style('Create dataset index error: {} {}'.format(e.__class__.__name__, str(e)),
  389. fg='red'))
  390. continue
  391. click.echo(click.style('Congratulations! Update {} dataset indexes.'.format(create_count), fg='green'))
  392. @click.command('normalization-collections', help='restore all collections in one')
  393. def normalization_collections():
  394. click.echo(click.style('Start normalization collections.', fg='green'))
  395. normalization_count = []
  396. page = 1
  397. while True:
  398. try:
  399. datasets = db.session.query(Dataset).filter(Dataset.indexing_technique == 'high_quality') \
  400. .order_by(Dataset.created_at.desc()).paginate(page=page, per_page=100)
  401. except NotFound:
  402. break
  403. datasets_result = datasets.items
  404. page += 1
  405. for i in range(0, len(datasets_result), 5):
  406. threads = []
  407. sub_datasets = datasets_result[i:i + 5]
  408. for dataset in sub_datasets:
  409. document_format_thread = threading.Thread(target=deal_dataset_vector, kwargs={
  410. 'flask_app': current_app._get_current_object(),
  411. 'dataset': dataset,
  412. 'normalization_count': normalization_count
  413. })
  414. threads.append(document_format_thread)
  415. document_format_thread.start()
  416. for thread in threads:
  417. thread.join()
  418. click.echo(click.style('Congratulations! restore {} dataset indexes.'.format(len(normalization_count)), fg='green'))
  419. def deal_dataset_vector(flask_app: Flask, dataset: Dataset, normalization_count: list):
  420. with flask_app.app_context():
  421. try:
  422. click.echo('restore dataset index: {}'.format(dataset.id))
  423. try:
  424. embedding_model = ModelFactory.get_embedding_model(
  425. tenant_id=dataset.tenant_id,
  426. model_provider_name=dataset.embedding_model_provider,
  427. model_name=dataset.embedding_model
  428. )
  429. except Exception:
  430. provider = Provider(
  431. id='provider_id',
  432. tenant_id=dataset.tenant_id,
  433. provider_name='openai',
  434. provider_type=ProviderType.CUSTOM.value,
  435. encrypted_config=json.dumps({'openai_api_key': 'TEST'}),
  436. is_valid=True,
  437. )
  438. model_provider = OpenAIProvider(provider=provider)
  439. embedding_model = OpenAIEmbedding(name="text-embedding-ada-002",
  440. model_provider=model_provider)
  441. embeddings = CacheEmbedding(embedding_model)
  442. dataset_collection_binding = db.session.query(DatasetCollectionBinding). \
  443. filter(DatasetCollectionBinding.provider_name == embedding_model.model_provider.provider_name,
  444. DatasetCollectionBinding.model_name == embedding_model.name). \
  445. order_by(DatasetCollectionBinding.created_at). \
  446. first()
  447. if not dataset_collection_binding:
  448. dataset_collection_binding = DatasetCollectionBinding(
  449. provider_name=embedding_model.model_provider.provider_name,
  450. model_name=embedding_model.name,
  451. collection_name="Vector_index_" + str(uuid.uuid4()).replace("-", "_") + '_Node'
  452. )
  453. db.session.add(dataset_collection_binding)
  454. db.session.commit()
  455. from core.index.vector_index.qdrant_vector_index import QdrantVectorIndex, QdrantConfig
  456. index = QdrantVectorIndex(
  457. dataset=dataset,
  458. config=QdrantConfig(
  459. endpoint=current_app.config.get('QDRANT_URL'),
  460. api_key=current_app.config.get('QDRANT_API_KEY'),
  461. root_path=current_app.root_path
  462. ),
  463. embeddings=embeddings
  464. )
  465. if index:
  466. # index.delete_by_group_id(dataset.id)
  467. index.restore_dataset_in_one(dataset, dataset_collection_binding)
  468. else:
  469. click.echo('passed.')
  470. normalization_count.append(1)
  471. except Exception as e:
  472. click.echo(
  473. click.style('Create dataset index error: {} {}'.format(e.__class__.__name__, str(e)),
  474. fg='red'))
  475. @click.command('update_app_model_configs', help='Migrate data to support paragraph variable.')
  476. @click.option("--batch-size", default=500, help="Number of records to migrate in each batch.")
  477. def update_app_model_configs(batch_size):
  478. pre_prompt_template = '{{default_input}}'
  479. user_input_form_template = {
  480. "en-US": [
  481. {
  482. "paragraph": {
  483. "label": "Query",
  484. "variable": "default_input",
  485. "required": False,
  486. "default": ""
  487. }
  488. }
  489. ],
  490. "zh-Hans": [
  491. {
  492. "paragraph": {
  493. "label": "查询内容",
  494. "variable": "default_input",
  495. "required": False,
  496. "default": ""
  497. }
  498. }
  499. ]
  500. }
  501. click.secho("Start migrate old data that the text generator can support paragraph variable.", fg='green')
  502. total_records = db.session.query(AppModelConfig) \
  503. .join(App, App.app_model_config_id == AppModelConfig.id) \
  504. .filter(App.mode == 'completion') \
  505. .count()
  506. if total_records == 0:
  507. click.secho("No data to migrate.", fg='green')
  508. return
  509. num_batches = (total_records + batch_size - 1) // batch_size
  510. with tqdm(total=total_records, desc="Migrating Data") as pbar:
  511. for i in range(num_batches):
  512. offset = i * batch_size
  513. limit = min(batch_size, total_records - offset)
  514. click.secho(f"Fetching batch {i + 1}/{num_batches} from source database...", fg='green')
  515. data_batch = db.session.query(AppModelConfig) \
  516. .join(App, App.app_model_config_id == AppModelConfig.id) \
  517. .filter(App.mode == 'completion') \
  518. .order_by(App.created_at) \
  519. .offset(offset).limit(limit).all()
  520. if not data_batch:
  521. click.secho("No more data to migrate.", fg='green')
  522. break
  523. try:
  524. click.secho(f"Migrating {len(data_batch)} records...", fg='green')
  525. for data in data_batch:
  526. # click.secho(f"Migrating data {data.id}, pre_prompt: {data.pre_prompt}, user_input_form: {data.user_input_form}", fg='green')
  527. if data.pre_prompt is None:
  528. data.pre_prompt = pre_prompt_template
  529. else:
  530. if pre_prompt_template in data.pre_prompt:
  531. continue
  532. data.pre_prompt += pre_prompt_template
  533. app_data = db.session.query(App) \
  534. .filter(App.id == data.app_id) \
  535. .one()
  536. account_data = db.session.query(Account) \
  537. .join(TenantAccountJoin, Account.id == TenantAccountJoin.account_id) \
  538. .filter(TenantAccountJoin.role == 'owner') \
  539. .filter(TenantAccountJoin.tenant_id == app_data.tenant_id) \
  540. .one_or_none()
  541. if not account_data:
  542. continue
  543. if data.user_input_form is None or data.user_input_form == 'null':
  544. data.user_input_form = json.dumps(user_input_form_template[account_data.interface_language])
  545. else:
  546. raw_json_data = json.loads(data.user_input_form)
  547. raw_json_data.append(user_input_form_template[account_data.interface_language][0])
  548. data.user_input_form = json.dumps(raw_json_data)
  549. # click.secho(f"Updated data {data.id}, pre_prompt: {data.pre_prompt}, user_input_form: {data.user_input_form}", fg='green')
  550. db.session.commit()
  551. except Exception as e:
  552. click.secho(f"Error while migrating data: {e}, app_id: {data.app_id}, app_model_config_id: {data.id}",
  553. fg='red')
  554. continue
  555. click.secho(f"Successfully migrated batch {i + 1}/{num_batches}.", fg='green')
  556. pbar.update(len(data_batch))
  557. @click.command('migrate_default_input_to_dataset_query_variable')
  558. @click.option("--batch-size", default=500, help="Number of records to migrate in each batch.")
  559. def migrate_default_input_to_dataset_query_variable(batch_size):
  560. click.secho("Starting...", fg='green')
  561. total_records = db.session.query(AppModelConfig) \
  562. .join(App, App.app_model_config_id == AppModelConfig.id) \
  563. .filter(App.mode == 'completion') \
  564. .filter(AppModelConfig.dataset_query_variable == None) \
  565. .count()
  566. if total_records == 0:
  567. click.secho("No data to migrate.", fg='green')
  568. return
  569. num_batches = (total_records + batch_size - 1) // batch_size
  570. with tqdm(total=total_records, desc="Migrating Data") as pbar:
  571. for i in range(num_batches):
  572. offset = i * batch_size
  573. limit = min(batch_size, total_records - offset)
  574. click.secho(f"Fetching batch {i + 1}/{num_batches} from source database...", fg='green')
  575. data_batch = db.session.query(AppModelConfig) \
  576. .join(App, App.app_model_config_id == AppModelConfig.id) \
  577. .filter(App.mode == 'completion') \
  578. .filter(AppModelConfig.dataset_query_variable == None) \
  579. .order_by(App.created_at) \
  580. .offset(offset).limit(limit).all()
  581. if not data_batch:
  582. click.secho("No more data to migrate.", fg='green')
  583. break
  584. try:
  585. click.secho(f"Migrating {len(data_batch)} records...", fg='green')
  586. for data in data_batch:
  587. config = AppModelConfig.to_dict(data)
  588. tools = config["agent_mode"]["tools"]
  589. dataset_exists = "dataset" in str(tools)
  590. if not dataset_exists:
  591. continue
  592. user_input_form = config.get("user_input_form", [])
  593. for form in user_input_form:
  594. paragraph = form.get('paragraph')
  595. if paragraph \
  596. and paragraph.get('variable') == 'query':
  597. data.dataset_query_variable = 'query'
  598. break
  599. if paragraph \
  600. and paragraph.get('variable') == 'default_input':
  601. data.dataset_query_variable = 'default_input'
  602. break
  603. db.session.commit()
  604. except Exception as e:
  605. click.secho(f"Error while migrating data: {e}, app_id: {data.app_id}, app_model_config_id: {data.id}",
  606. fg='red')
  607. continue
  608. click.secho(f"Successfully migrated batch {i + 1}/{num_batches}.", fg='green')
  609. pbar.update(len(data_batch))
  610. def register_commands(app):
  611. app.cli.add_command(reset_password)
  612. app.cli.add_command(reset_email)
  613. app.cli.add_command(generate_invitation_codes)
  614. app.cli.add_command(reset_encrypt_key_pair)
  615. app.cli.add_command(recreate_all_dataset_indexes)
  616. app.cli.add_command(sync_anthropic_hosted_providers)
  617. app.cli.add_command(clean_unused_dataset_indexes)
  618. app.cli.add_command(create_qdrant_indexes)
  619. app.cli.add_command(update_qdrant_indexes)
  620. app.cli.add_command(update_app_model_configs)
  621. app.cli.add_command(normalization_collections)
  622. app.cli.add_command(migrate_default_input_to_dataset_query_variable)