app_runner.py

import logging
from typing import cast

from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
from core.app.apps.base_app_runner import AppRunner
from core.app.apps.chat.app_config_manager import ChatAppConfig
from core.app.entities.app_invoke_entities import (
    ChatAppGenerateEntity,
)
from core.app.entities.queue_entities import QueueAnnotationReplyEvent
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance
from core.model_runtime.entities.message_entities import ImagePromptMessageContent
from core.moderation.base import ModerationError
from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
from extensions.ext_database import db
from models.model import App, Conversation, Message

logger = logging.getLogger(__name__)


class ChatAppRunner(AppRunner):
    """
    Chat Application Runner
    """

    def run(
        self,
        application_generate_entity: ChatAppGenerateEntity,
        queue_manager: AppQueueManager,
        conversation: Conversation,
        message: Message,
    ) -> None:
        """
        Run application
        :param application_generate_entity: application generate entity
        :param queue_manager: application queue manager
        :param conversation: conversation
        :param message: message
        :return:
        """
        app_config = application_generate_entity.app_config
        app_config = cast(ChatAppConfig, app_config)

        app_record = db.session.query(App).filter(App.id == app_config.app_id).first()
        if not app_record:
            raise ValueError("App not found")

        inputs = application_generate_entity.inputs
        query = application_generate_entity.query
        files = application_generate_entity.files

        image_detail_config = (
            application_generate_entity.file_upload_config.image_config.detail
            if (
                application_generate_entity.file_upload_config
                and application_generate_entity.file_upload_config.image_config
            )
            else None
        )
        image_detail_config = image_detail_config or ImagePromptMessageContent.DETAIL.LOW
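
        # NOTE: when the request carries no file-upload or image config, image detail
        # falls back to LOW; presumably this keeps vision-token usage at the cheaper
        # setting unless the app explicitly opts into HIGH detail.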

        # Pre-calculate the number of tokens in the prompt messages and check the
        # tokens remaining under the model's context size and max-token limits;
        # raise an exception if there are not enough tokens left.
        # Includes: prompt template, inputs, query (optional), files (optional)
        # Does not include: memory, external data, dataset context
        self.get_pre_calculate_rest_tokens(
            app_record=app_record,
            model_config=application_generate_entity.model_conf,
            prompt_template_entity=app_config.prompt_template,
            inputs=inputs,
            files=files,
            query=query,
        )

        memory = None
        if application_generate_entity.conversation_id:
            # get memory of conversation (read-only)
            model_instance = ModelInstance(
                provider_model_bundle=application_generate_entity.model_conf.provider_model_bundle,
                model=application_generate_entity.model_conf.model,
            )

            memory = TokenBufferMemory(conversation=conversation, model_instance=model_instance)
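
        # NOTE: TokenBufferMemory wraps the persisted conversation together with the
        # model instance, presumably so the history can be trimmed against the
        # model's context window when prompts are assembled below.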

        # Organize all inputs and the template into prompt messages.
        # Includes: prompt template, inputs, query (optional), files (optional),
        # memory (optional)
        prompt_messages, stop = self.organize_prompt_messages(
            app_record=app_record,
            model_config=application_generate_entity.model_conf,
            prompt_template_entity=app_config.prompt_template,
            inputs=inputs,
            files=files,
            query=query,
            memory=memory,
            image_detail_config=image_detail_config,
        )
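
        # NOTE: prompt messages are organized twice in this method: this first pass
        # exists so the moderation and annotation short-circuits below can call
        # direct_output() with usable prompt_messages; a second pass re-runs after
        # external data and dataset context have been filled in.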

        # moderation
        try:
            # process sensitive_word_avoidance
            _, inputs, query = self.moderation_for_inputs(
                app_id=app_record.id,
                tenant_id=app_config.tenant_id,
                app_generate_entity=application_generate_entity,
                inputs=inputs,
                query=query,
                message_id=message.id,
            )
        except ModerationError as e:
            self.direct_output(
                queue_manager=queue_manager,
                app_generate_entity=application_generate_entity,
                prompt_messages=prompt_messages,
                text=str(e),
                stream=application_generate_entity.stream,
            )
            return
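
        # NOTE: on ModerationError the exception message is returned directly as the
        # answer via direct_output() and the model is never invoked.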

        if query:
            # annotation reply
            annotation_reply = self.query_app_annotations_to_reply(
                app_record=app_record,
                message=message,
                query=query,
                user_id=application_generate_entity.user_id,
                invoke_from=application_generate_entity.invoke_from,
            )

            if annotation_reply:
                queue_manager.publish(
                    QueueAnnotationReplyEvent(message_annotation_id=annotation_reply.id),
                    PublishFrom.APPLICATION_MANAGER,
                )

                self.direct_output(
                    queue_manager=queue_manager,
                    app_generate_entity=application_generate_entity,
                    prompt_messages=prompt_messages,
                    text=annotation_reply.content,
                    stream=application_generate_entity.stream,
                )
                return
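
        # NOTE: a matched annotation likewise short-circuits generation: the event is
        # published for downstream bookkeeping and the stored annotation content is
        # streamed back as the answer.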

        # fill in variable inputs from external data tools, if any exist
        external_data_tools = app_config.external_data_variables
        if external_data_tools:
            inputs = self.fill_in_inputs_from_external_data_tools(
                tenant_id=app_record.tenant_id,
                app_id=app_record.id,
                external_data_tools=external_data_tools,
                inputs=inputs,
                query=query,
            )
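
        # NOTE: external data tools (e.g. API-backed variables) presumably supplement
        # or overwrite the user-provided inputs before the final prompt assembly.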

        # get context from datasets
        context = None
        if app_config.dataset and app_config.dataset.dataset_ids:
            hit_callback = DatasetIndexToolCallbackHandler(
                queue_manager,
                app_record.id,
                message.id,
                application_generate_entity.user_id,
                application_generate_entity.invoke_from,
            )

            dataset_retrieval = DatasetRetrieval(application_generate_entity)
            context = dataset_retrieval.retrieve(
                app_id=app_record.id,
                user_id=application_generate_entity.user_id,
                tenant_id=app_record.tenant_id,
                model_config=application_generate_entity.model_conf,
                config=app_config.dataset,
                query=query,
                invoke_from=application_generate_entity.invoke_from,
                show_retrieve_source=app_config.additional_features.show_retrieve_source,
                hit_callback=hit_callback,
                memory=memory,
                message_id=message.id,
                inputs=inputs,
            )
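
        # NOTE: `context` is presumably the retrieved dataset content rendered as
        # text, to be injected into the prompt template's context slot in the second
        # organize_prompt_messages() pass below.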

        # Reorganize all inputs and the template into prompt messages.
        # Includes: prompt template, inputs, query (optional), files (optional),
        # memory (optional), external data, dataset context (optional)
        prompt_messages, stop = self.organize_prompt_messages(
            app_record=app_record,
            model_config=application_generate_entity.model_conf,
            prompt_template_entity=app_config.prompt_template,
            inputs=inputs,
            files=files,
            query=query,
            context=context,
            memory=memory,
            image_detail_config=image_detail_config,
        )

        # check hosting moderation
        hosting_moderation_result = self.check_hosting_moderation(
            application_generate_entity=application_generate_entity,
            queue_manager=queue_manager,
            prompt_messages=prompt_messages,
        )

        if hosting_moderation_result:
            return

        # Re-calculate max_tokens if prompt_tokens + max_tokens exceeds the model's
        # token limit.
        self.recalc_llm_max_tokens(
            model_config=application_generate_entity.model_conf, prompt_messages=prompt_messages
        )

        # Invoke model
        model_instance = ModelInstance(
            provider_model_bundle=application_generate_entity.model_conf.provider_model_bundle,
            model=application_generate_entity.model_conf.model,
        )

        db.session.close()
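
        # NOTE: the DB session is closed before invoking the model, presumably so a
        # connection is not held open for the duration of a potentially long,
        # streaming LLM call.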

        invoke_result = model_instance.invoke_llm(
            prompt_messages=prompt_messages,
            model_parameters=application_generate_entity.model_conf.parameters,
            stop=stop,
            stream=application_generate_entity.stream,
            user=application_generate_entity.user_id,
        )

        # handle invoke result
        self._handle_invoke_result(
            invoke_result=invoke_result, queue_manager=queue_manager, stream=application_generate_entity.stream
        )
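

# Hypothetical usage sketch (not part of the original file): in the surrounding
# codebase, a chat app generator's worker thread is what typically drives this
# runner. The caller names below are assumptions; the keyword arguments mirror
# run() above.
#
#     runner = ChatAppRunner()
#     runner.run(
#         application_generate_entity=application_generate_entity,
#         queue_manager=queue_manager,
#         conversation=conversation,
#         message=message,
#     )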