app_runner.py

import logging
from typing import cast

from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.apps.base_app_runner import AppRunner
from core.app.apps.completion.app_config_manager import CompletionAppConfig
from core.app.entities.app_invoke_entities import (
    CompletionAppGenerateEntity,
)
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
from core.model_manager import ModelInstance
from core.model_runtime.entities.message_entities import ImagePromptMessageContent
from core.moderation.base import ModerationError
from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
from extensions.ext_database import db
from models.model import App, Message

logger = logging.getLogger(__name__)


class CompletionAppRunner(AppRunner):
    """
    Completion Application Runner
    """

    def run(
        self, application_generate_entity: CompletionAppGenerateEntity, queue_manager: AppQueueManager, message: Message
    ) -> None:
        """
        Run application
        :param application_generate_entity: application generate entity
        :param queue_manager: application queue manager
        :param message: message
        :return:
        """
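        # Narrow the generic app config to the completion-specific config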
        app_config = application_generate_entity.app_config
        app_config = cast(CompletionAppConfig, app_config)

        app_record = db.session.query(App).filter(App.id == app_config.app_id).first()
        if not app_record:
            raise ValueError("App not found")

        inputs = application_generate_entity.inputs
        query = application_generate_entity.query
        files = application_generate_entity.files
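
        # Resolve the requested image detail level from the file upload config; fall back to LOW when none is set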
        image_detail_config = (
            application_generate_entity.file_upload_config.image_config.detail
            if (
                application_generate_entity.file_upload_config
                and application_generate_entity.file_upload_config.image_config
            )
            else None
        )
        image_detail_config = image_detail_config or ImagePromptMessageContent.DETAIL.LOW

        # Pre-calculate the number of tokens of the prompt messages,
        # and return the rest number of tokens by model context token size limit and max token size limit.
        # If the rest number of tokens is not enough, raise exception.
        # Include: prompt template, inputs, query(optional), files(optional)
        # Not Include: memory, external data, dataset context
        self.get_pre_calculate_rest_tokens(
            app_record=app_record,
            model_config=application_generate_entity.model_conf,
            prompt_template_entity=app_config.prompt_template,
            inputs=inputs,
            files=files,
            query=query,
        )

        # organize all inputs and template to prompt messages
        # Include: prompt template, inputs, query(optional), files(optional)
        prompt_messages, stop = self.organize_prompt_messages(
            app_record=app_record,
            model_config=application_generate_entity.model_conf,
            prompt_template_entity=app_config.prompt_template,
            inputs=inputs,
            files=files,
            query=query,
            image_detail_config=image_detail_config,
        )

        # moderation
        try:
            # process sensitive_word_avoidance
            _, inputs, query = self.moderation_for_inputs(
                app_id=app_record.id,
                tenant_id=app_config.tenant_id,
                app_generate_entity=application_generate_entity,
                inputs=inputs,
                query=query or "",
                message_id=message.id,
            )
        except ModerationError as e:
            self.direct_output(
                queue_manager=queue_manager,
                app_generate_entity=application_generate_entity,
                prompt_messages=prompt_messages,
                text=str(e),
                stream=application_generate_entity.stream,
            )
            return

        # fill in variable inputs from external data tools if exists
        external_data_tools = app_config.external_data_variables
        if external_data_tools:
            inputs = self.fill_in_inputs_from_external_data_tools(
                tenant_id=app_record.tenant_id,
                app_id=app_record.id,
                external_data_tools=external_data_tools,
                inputs=inputs,
                query=query,
            )

        # get context from datasets
        context = None
        if app_config.dataset and app_config.dataset.dataset_ids:
            hit_callback = DatasetIndexToolCallbackHandler(
                queue_manager,
                app_record.id,
                message.id,
                application_generate_entity.user_id,
                application_generate_entity.invoke_from,
            )

            dataset_config = app_config.dataset
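            # When a query variable is configured for retrieval, use that input value as the retrieval query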
            if dataset_config and dataset_config.retrieve_config.query_variable:
                query = inputs.get(dataset_config.retrieve_config.query_variable, "")

            dataset_retrieval = DatasetRetrieval(application_generate_entity)
            context = dataset_retrieval.retrieve(
                app_id=app_record.id,
                user_id=application_generate_entity.user_id,
                tenant_id=app_record.tenant_id,
                model_config=application_generate_entity.model_conf,
                config=dataset_config,
                query=query or "",
                invoke_from=application_generate_entity.invoke_from,
                show_retrieve_source=app_config.additional_features.show_retrieve_source,
                hit_callback=hit_callback,
                message_id=message.id,
                inputs=inputs,
            )

        # reorganize all inputs and template to prompt messages
        # Include: prompt template, inputs, query(optional), files(optional)
        #          memory(optional), external data, dataset context(optional)
        prompt_messages, stop = self.organize_prompt_messages(
            app_record=app_record,
            model_config=application_generate_entity.model_conf,
            prompt_template_entity=app_config.prompt_template,
            inputs=inputs,
            files=files,
            query=query,
            context=context,
            image_detail_config=image_detail_config,
        )

        # check hosting moderation
        hosting_moderation_result = self.check_hosting_moderation(
            application_generate_entity=application_generate_entity,
            queue_manager=queue_manager,
            prompt_messages=prompt_messages,
        )

        if hosting_moderation_result:
            return

        # Re-calculate the max tokens if sum(prompt_token + max_tokens) over model token limit
        self.recalc_llm_max_tokens(model_config=application_generate_entity.model_conf, prompt_messages=prompt_messages)

        # Invoke model
        model_instance = ModelInstance(
            provider_model_bundle=application_generate_entity.model_conf.provider_model_bundle,
            model=application_generate_entity.model_conf.model,
        )
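
        # Close the DB session before invoking the model so no connection is held during (potentially streaming) generation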
        db.session.close()

        invoke_result = model_instance.invoke_llm(
            prompt_messages=prompt_messages,
            model_parameters=application_generate_entity.model_conf.parameters,
            stop=stop,
            stream=application_generate_entity.stream,
            user=application_generate_entity.user_id,
        )

        # handle invoke result
        self._handle_invoke_result(
            invoke_result=invoke_result, queue_manager=queue_manager, stream=application_generate_entity.stream
        )