enable_annotation_reply_task.py 4.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107
  1. import datetime
  2. import logging
  3. import time
  4. import click
  5. from celery import shared_task
  6. from werkzeug.exceptions import NotFound
  7. from core.rag.datasource.vdb.vector_factory import Vector
  8. from core.rag.models.document import Document
  9. from extensions.ext_database import db
  10. from extensions.ext_redis import redis_client
  11. from models.dataset import Dataset
  12. from models.model import App, AppAnnotationSetting, MessageAnnotation
  13. from services.dataset_service import DatasetCollectionBindingService
  14. @shared_task(queue='dataset')
  15. def enable_annotation_reply_task(job_id: str, app_id: str, user_id: str, tenant_id: str, score_threshold: float,
  16. embedding_provider_name: str, embedding_model_name: str):
  17. """
  18. Async enable annotation reply task
  19. """
  20. logging.info(click.style('Start add app annotation to index: {}'.format(app_id), fg='green'))
  21. start_at = time.perf_counter()
  22. # get app info
  23. app = db.session.query(App).filter(
  24. App.id == app_id,
  25. App.tenant_id == tenant_id,
  26. App.status == 'normal'
  27. ).first()
  28. if not app:
  29. raise NotFound("App not found")
  30. annotations = db.session.query(MessageAnnotation).filter(MessageAnnotation.app_id == app_id).all()
  31. enable_app_annotation_key = 'enable_app_annotation_{}'.format(str(app_id))
  32. enable_app_annotation_job_key = 'enable_app_annotation_job_{}'.format(str(job_id))
  33. try:
  34. documents = []
  35. dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding(
  36. embedding_provider_name,
  37. embedding_model_name,
  38. 'annotation'
  39. )
  40. annotation_setting = db.session.query(AppAnnotationSetting).filter(
  41. AppAnnotationSetting.app_id == app_id).first()
  42. if annotation_setting:
  43. annotation_setting.score_threshold = score_threshold
  44. annotation_setting.collection_binding_id = dataset_collection_binding.id
  45. annotation_setting.updated_user_id = user_id
  46. annotation_setting.updated_at = datetime.datetime.utcnow()
  47. db.session.add(annotation_setting)
  48. else:
  49. new_app_annotation_setting = AppAnnotationSetting(
  50. app_id=app_id,
  51. score_threshold=score_threshold,
  52. collection_binding_id=dataset_collection_binding.id,
  53. created_user_id=user_id,
  54. updated_user_id=user_id
  55. )
  56. db.session.add(new_app_annotation_setting)
  57. dataset = Dataset(
  58. id=app_id,
  59. tenant_id=tenant_id,
  60. indexing_technique='high_quality',
  61. embedding_model_provider=embedding_provider_name,
  62. embedding_model=embedding_model_name,
  63. collection_binding_id=dataset_collection_binding.id
  64. )
  65. if annotations:
  66. for annotation in annotations:
  67. document = Document(
  68. page_content=annotation.question,
  69. metadata={
  70. "annotation_id": annotation.id,
  71. "app_id": app_id,
  72. "doc_id": annotation.id
  73. }
  74. )
  75. documents.append(document)
  76. vector = Vector(dataset, attributes=['doc_id', 'annotation_id', 'app_id'])
  77. try:
  78. vector.delete_by_metadata_field('app_id', app_id)
  79. except Exception as e:
  80. logging.info(
  81. click.style('Delete annotation index error: {}'.format(str(e)),
  82. fg='red'))
  83. vector.create(documents)
  84. db.session.commit()
  85. redis_client.setex(enable_app_annotation_job_key, 600, 'completed')
  86. end_at = time.perf_counter()
  87. logging.info(
  88. click.style('App annotations added to index: {} latency: {}'.format(app_id, end_at - start_at),
  89. fg='green'))
  90. except Exception as e:
  91. logging.exception("Annotation batch created index failed:{}".format(str(e)))
  92. redis_client.setex(enable_app_annotation_job_key, 600, 'error')
  93. enable_app_annotation_error_key = 'enable_app_annotation_error_{}'.format(str(job_id))
  94. redis_client.setex(enable_app_annotation_error_key, 600, str(e))
  95. db.session.rollback()
  96. finally:
  97. redis_client.delete(enable_app_annotation_key)