app_runner.py

import logging
from typing import cast

from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
from core.app.apps.base_app_runner import AppRunner
from core.app.apps.chat.app_config_manager import ChatAppConfig
from core.app.entities.app_invoke_entities import (
    ChatAppGenerateEntity,
)
from core.app.entities.queue_entities import QueueAnnotationReplyEvent
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance
from core.moderation.base import ModerationException
from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
from extensions.ext_database import db
from models.model import App, Conversation, Message

logger = logging.getLogger(__name__)


class ChatAppRunner(AppRunner):
    """
    Chat Application Runner
    """

    def run(
        self,
        application_generate_entity: ChatAppGenerateEntity,
        queue_manager: AppQueueManager,
        conversation: Conversation,
        message: Message,
    ) -> None:
        """
        Run application
        :param application_generate_entity: application generate entity
        :param queue_manager: application queue manager
        :param conversation: conversation
        :param message: message
        :return:
        """
        app_config = application_generate_entity.app_config
        app_config = cast(ChatAppConfig, app_config)

        app_record = db.session.query(App).filter(App.id == app_config.app_id).first()
        if not app_record:
            raise ValueError("App not found")

        inputs = application_generate_entity.inputs
        query = application_generate_entity.query
        files = application_generate_entity.files

        # Pre-calculate the number of tokens in the prompt messages and derive the
        # remaining token budget from the model's context size and max-tokens limits.
        # Raises an exception if the remaining budget is insufficient.
        # Includes: prompt template, inputs, query (optional), files (optional)
        # Excludes: memory, external data, dataset context
        self.get_pre_calculate_rest_tokens(
            app_record=app_record,
            model_config=application_generate_entity.model_conf,
            prompt_template_entity=app_config.prompt_template,
            inputs=inputs,
            files=files,
            query=query,
        )

        memory = None
        if application_generate_entity.conversation_id:
            # get memory of conversation (read-only)
            model_instance = ModelInstance(
                provider_model_bundle=application_generate_entity.model_conf.provider_model_bundle,
                model=application_generate_entity.model_conf.model,
            )

            memory = TokenBufferMemory(
                conversation=conversation,
                model_instance=model_instance,
            )

        # organize all inputs and template to prompt messages
        # Include: prompt template, inputs, query (optional), files (optional),
        #          memory (optional)
        prompt_messages, stop = self.organize_prompt_messages(
            app_record=app_record,
            model_config=application_generate_entity.model_conf,
            prompt_template_entity=app_config.prompt_template,
            inputs=inputs,
            files=files,
            query=query,
            memory=memory,
        )
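
        # Note: prompt messages are organized twice in this method. This first pass
        # exists so that the direct-output paths below (moderation rejection,
        # annotation reply) already have prompt messages to emit; a second pass
        # further down rebuilds them with external data inputs and dataset context.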

        # moderation
        try:
            # process sensitive_word_avoidance
            _, inputs, query = self.moderation_for_inputs(
                app_id=app_record.id,
                tenant_id=app_config.tenant_id,
                app_generate_entity=application_generate_entity,
                inputs=inputs,
                query=query,
                message_id=message.id,
            )
        except ModerationException as e:
            self.direct_output(
                queue_manager=queue_manager,
                app_generate_entity=application_generate_entity,
                prompt_messages=prompt_messages,
                text=str(e),
                stream=application_generate_entity.stream,
            )
            return

        if query:
            # annotation reply
            annotation_reply = self.query_app_annotations_to_reply(
                app_record=app_record,
                message=message,
                query=query,
                user_id=application_generate_entity.user_id,
                invoke_from=application_generate_entity.invoke_from,
            )

            if annotation_reply:
                queue_manager.publish(
                    QueueAnnotationReplyEvent(message_annotation_id=annotation_reply.id),
                    PublishFrom.APPLICATION_MANAGER,
                )

                self.direct_output(
                    queue_manager=queue_manager,
                    app_generate_entity=application_generate_entity,
                    prompt_messages=prompt_messages,
                    text=annotation_reply.content,
                    stream=application_generate_entity.stream,
                )
                return

        # fill in variable inputs from external data tools, if any are configured
        external_data_tools = app_config.external_data_variables
        if external_data_tools:
            inputs = self.fill_in_inputs_from_external_data_tools(
                tenant_id=app_record.tenant_id,
                app_id=app_record.id,
                external_data_tools=external_data_tools,
                inputs=inputs,
                query=query,
            )

        # get context from datasets
        context = None
        if app_config.dataset and app_config.dataset.dataset_ids:
            hit_callback = DatasetIndexToolCallbackHandler(
                queue_manager,
                app_record.id,
                message.id,
                application_generate_entity.user_id,
                application_generate_entity.invoke_from,
            )

            dataset_retrieval = DatasetRetrieval(application_generate_entity)
            context = dataset_retrieval.retrieve(
                app_id=app_record.id,
                user_id=application_generate_entity.user_id,
                tenant_id=app_record.tenant_id,
                model_config=application_generate_entity.model_conf,
                config=app_config.dataset,
                query=query,
                invoke_from=application_generate_entity.invoke_from,
                show_retrieve_source=app_config.additional_features.show_retrieve_source,
                hit_callback=hit_callback,
                memory=memory,
                message_id=message.id,
            )

        # reorganize all inputs and template to prompt messages
        # Include: prompt template, inputs, query (optional), files (optional),
        #          memory (optional), external data, dataset context (optional)
        prompt_messages, stop = self.organize_prompt_messages(
            app_record=app_record,
            model_config=application_generate_entity.model_conf,
            prompt_template_entity=app_config.prompt_template,
            inputs=inputs,
            files=files,
            query=query,
            context=context,
            memory=memory,
        )

        # check hosting moderation
        hosting_moderation_result = self.check_hosting_moderation(
            application_generate_entity=application_generate_entity,
            queue_manager=queue_manager,
            prompt_messages=prompt_messages,
        )

        if hosting_moderation_result:
            return

        # Re-calculate max_tokens if sum(prompt_tokens + max_tokens) exceeds the model's token limit
        self.recalc_llm_max_tokens(
            model_config=application_generate_entity.model_conf,
            prompt_messages=prompt_messages,
        )

        # Invoke model
        model_instance = ModelInstance(
            provider_model_bundle=application_generate_entity.model_conf.provider_model_bundle,
            model=application_generate_entity.model_conf.model,
        )

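        # Close the database session before the model call so the connection is not
        # held open during a potentially long-running (and possibly streaming) invocation.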
        db.session.close()

        invoke_result = model_instance.invoke_llm(
            prompt_messages=prompt_messages,
            model_parameters=application_generate_entity.model_conf.parameters,
            stop=stop,
            stream=application_generate_entity.stream,
            user=application_generate_entity.user_id,
        )

        # handle invoke result
        self._handle_invoke_result(
            invoke_result=invoke_result,
            queue_manager=queue_manager,
            stream=application_generate_entity.stream,
        )
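

# Usage sketch (hypothetical wiring; in the real application the chat app generator
# constructs these objects and dispatches to the runner):
#
#   runner = ChatAppRunner()
#   runner.run(
#       application_generate_entity=application_generate_entity,
#       queue_manager=queue_manager,
#       conversation=conversation,
#       message=message,
#   )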