generate_task_pipeline.py

import json
import logging
import time
from collections.abc import Generator
from typing import Optional, Union, cast

from pydantic import BaseModel

from core.app_runner.moderation_handler import ModerationRule, OutputModerationHandler
from core.application_queue_manager import ApplicationQueueManager, PublishFrom
from core.entities.application_entities import ApplicationGenerateEntity, InvokeFrom
from core.entities.queue_entities import (
    AnnotationReplyEvent,
    QueueAgentMessageEvent,
    QueueAgentThoughtEvent,
    QueueErrorEvent,
    QueueMessageEndEvent,
    QueueMessageEvent,
    QueueMessageFileEvent,
    QueueMessageReplaceEvent,
    QueuePingEvent,
    QueueRetrieverResourcesEvent,
    QueueStopEvent,
)
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
from core.model_runtime.entities.message_entities import (
    AssistantPromptMessage,
    ImagePromptMessageContent,
    PromptMessage,
    PromptMessageContentType,
    PromptMessageRole,
    TextPromptMessageContent,
)
from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.model_runtime.utils.encoders import jsonable_encoder
from core.prompt.prompt_template import PromptTemplateParser
from core.tools.tool_file_manager import ToolFileManager
from events.message_event import message_was_created
from extensions.ext_database import db
from models.model import Conversation, Message, MessageAgentThought, MessageFile
from services.annotation_service import AppAnnotationService

logger = logging.getLogger(__name__)


class TaskState(BaseModel):
    """
    TaskState entity
    """
    llm_result: LLMResult
    metadata: dict = {}


class GenerateTaskPipeline:
    """
    GenerateTaskPipeline generates the stream output and manages state for an application.
    """

    def __init__(self, application_generate_entity: ApplicationGenerateEntity,
                 queue_manager: ApplicationQueueManager,
                 conversation: Conversation,
                 message: Message) -> None:
        """
        Initialize GenerateTaskPipeline.
        :param application_generate_entity: application generate entity
        :param queue_manager: queue manager
        :param conversation: conversation
        :param message: message
        """
        self._application_generate_entity = application_generate_entity
        self._queue_manager = queue_manager
        self._conversation = conversation
        self._message = message
        self._task_state = TaskState(
            llm_result=LLMResult(
                model=self._application_generate_entity.app_orchestration_config_entity.model_config.model,
                prompt_messages=[],
                message=AssistantPromptMessage(content=""),
                usage=LLMUsage.empty_usage()
            )
        )
        self._start_at = time.perf_counter()
        self._output_moderation_handler = self._init_output_moderation()
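
    # Illustrative usage (a sketch; the caller-side names `entity`,
    # `queue_manager` and `transport` are assumptions, not part of this module):
    #
    #   pipeline = GenerateTaskPipeline(entity, queue_manager, conversation, message)
    #   for frame in pipeline.process(stream=True):  # each frame is "data: {...}\n\n"
    #       transport.write(frame)
    #
    # With stream=False, process() returns a single response dict instead.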

    def process(self, stream: bool) -> Union[dict, Generator]:
        """
        Process generate task pipeline.
        :return:
        """
        db.session.refresh(self._conversation)
        db.session.refresh(self._message)
        db.session.close()

        if stream:
            return self._process_stream_response()
        else:
            return self._process_blocking_response()

    def _process_blocking_response(self) -> dict:
        """
        Process blocking response.
        :return:
        """
        for queue_message in self._queue_manager.listen():
            event = queue_message.event

            if isinstance(event, QueueErrorEvent):
                raise self._handle_error(event)
            elif isinstance(event, QueueRetrieverResourcesEvent):
                self._task_state.metadata['retriever_resources'] = event.retriever_resources
            elif isinstance(event, AnnotationReplyEvent):
                annotation = AppAnnotationService.get_annotation_by_id(event.message_annotation_id)
                if annotation:
                    account = annotation.account
                    self._task_state.metadata['annotation_reply'] = {
                        'id': annotation.id,
                        'account': {
                            'id': annotation.account_id,
                            'name': account.name if account else 'Dify user'
                        }
                    }

                    self._task_state.llm_result.message.content = annotation.content
            elif isinstance(event, QueueStopEvent | QueueMessageEndEvent):
                if isinstance(event, QueueMessageEndEvent):
                    self._task_state.llm_result = event.llm_result
                else:
                    model_config = self._application_generate_entity.app_orchestration_config_entity.model_config
                    model = model_config.model
                    model_type_instance = model_config.provider_model_bundle.model_type_instance
                    model_type_instance = cast(LargeLanguageModel, model_type_instance)

                    # calculate num tokens
                    prompt_tokens = 0
                    if event.stopped_by != QueueStopEvent.StopBy.ANNOTATION_REPLY:
                        prompt_tokens = model_type_instance.get_num_tokens(
                            model,
                            model_config.credentials,
                            self._task_state.llm_result.prompt_messages
                        )

                    completion_tokens = 0
                    if event.stopped_by == QueueStopEvent.StopBy.USER_MANUAL:
                        completion_tokens = model_type_instance.get_num_tokens(
                            model,
                            model_config.credentials,
                            [self._task_state.llm_result.message]
                        )

                    credentials = model_config.credentials

                    # transform usage
                    self._task_state.llm_result.usage = model_type_instance._calc_response_usage(
                        model,
                        credentials,
                        prompt_tokens,
                        completion_tokens
                    )

                self._task_state.metadata['usage'] = jsonable_encoder(self._task_state.llm_result.usage)

                # response moderation
                if self._output_moderation_handler:
                    self._output_moderation_handler.stop_thread()

                    self._task_state.llm_result.message.content = self._output_moderation_handler.moderation_completion(
                        completion=self._task_state.llm_result.message.content,
                        public_event=False
                    )

                # Save message
                self._save_message(self._task_state.llm_result)

                response = {
                    'event': 'message',
                    'task_id': self._application_generate_entity.task_id,
                    'id': self._message.id,
                    'message_id': self._message.id,
                    'mode': self._conversation.mode,
                    'answer': self._task_state.llm_result.message.content,
                    'metadata': {},
                    'created_at': int(self._message.created_at.timestamp())
                }

                if self._conversation.mode == 'chat':
                    response['conversation_id'] = self._conversation.id

                if self._task_state.metadata:
                    response['metadata'] = self._get_response_metadata()

                return response
            else:
                continue
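
    # For reference, the blocking response above resembles (values illustrative):
    #   {'event': 'message', 'task_id': '...', 'id': '...', 'message_id': '...',
    #    'mode': 'chat', 'answer': 'Hello!', 'metadata': {...},
    #    'created_at': 1700000000, 'conversation_id': '...'}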

    def _process_stream_response(self) -> Generator:
        """
        Process stream response.
        :return:
        """
        for message in self._queue_manager.listen():
            event = message.event

            if isinstance(event, QueueErrorEvent):
                data = self._error_to_stream_response_data(self._handle_error(event))
                yield self._yield_response(data)
                break
            elif isinstance(event, QueueStopEvent | QueueMessageEndEvent):
                if isinstance(event, QueueMessageEndEvent):
                    self._task_state.llm_result = event.llm_result
                else:
                    model_config = self._application_generate_entity.app_orchestration_config_entity.model_config
                    model = model_config.model
                    model_type_instance = model_config.provider_model_bundle.model_type_instance
                    model_type_instance = cast(LargeLanguageModel, model_type_instance)

                    # calculate num tokens
                    prompt_tokens = 0
                    if event.stopped_by != QueueStopEvent.StopBy.ANNOTATION_REPLY:
                        prompt_tokens = model_type_instance.get_num_tokens(
                            model,
                            model_config.credentials,
                            self._task_state.llm_result.prompt_messages
                        )

                    completion_tokens = 0
                    if event.stopped_by == QueueStopEvent.StopBy.USER_MANUAL:
                        completion_tokens = model_type_instance.get_num_tokens(
                            model,
                            model_config.credentials,
                            [self._task_state.llm_result.message]
                        )

                    credentials = model_config.credentials

                    # transform usage
                    self._task_state.llm_result.usage = model_type_instance._calc_response_usage(
                        model,
                        credentials,
                        prompt_tokens,
                        completion_tokens
                    )

                self._task_state.metadata['usage'] = jsonable_encoder(self._task_state.llm_result.usage)

                # response moderation
                if self._output_moderation_handler:
                    self._output_moderation_handler.stop_thread()

                    self._task_state.llm_result.message.content = self._output_moderation_handler.moderation_completion(
                        completion=self._task_state.llm_result.message.content,
                        public_event=False
                    )

                    self._output_moderation_handler = None

                    replace_response = {
                        'event': 'message_replace',
                        'task_id': self._application_generate_entity.task_id,
                        'message_id': self._message.id,
                        'answer': self._task_state.llm_result.message.content,
                        'created_at': int(self._message.created_at.timestamp())
                    }

                    if self._conversation.mode == 'chat':
                        replace_response['conversation_id'] = self._conversation.id

                    yield self._yield_response(replace_response)

                # Save message
                self._save_message(self._task_state.llm_result)

                response = {
                    'event': 'message_end',
                    'task_id': self._application_generate_entity.task_id,
                    'id': self._message.id,
                    'message_id': self._message.id,
                }

                if self._conversation.mode == 'chat':
                    response['conversation_id'] = self._conversation.id

                if self._task_state.metadata:
                    response['metadata'] = self._get_response_metadata()

                yield self._yield_response(response)
            elif isinstance(event, QueueRetrieverResourcesEvent):
                self._task_state.metadata['retriever_resources'] = event.retriever_resources
            elif isinstance(event, AnnotationReplyEvent):
                annotation = AppAnnotationService.get_annotation_by_id(event.message_annotation_id)
                if annotation:
                    account = annotation.account
                    self._task_state.metadata['annotation_reply'] = {
                        'id': annotation.id,
                        'account': {
                            'id': annotation.account_id,
                            'name': account.name if account else 'Dify user'
                        }
                    }

                    self._task_state.llm_result.message.content = annotation.content
            elif isinstance(event, QueueAgentThoughtEvent):
                agent_thought: Optional[MessageAgentThought] = (
                    db.session.query(MessageAgentThought)
                    .filter(MessageAgentThought.id == event.agent_thought_id)
                    .first()
                )

                # guard before refresh: .first() returns None when the row is missing
                if agent_thought:
                    db.session.refresh(agent_thought)

                db.session.close()

                if agent_thought:
                    response = {
                        'event': 'agent_thought',
                        'id': agent_thought.id,
                        'task_id': self._application_generate_entity.task_id,
                        'message_id': self._message.id,
                        'position': agent_thought.position,
                        'thought': agent_thought.thought,
                        'observation': agent_thought.observation,
                        'tool': agent_thought.tool,
                        'tool_labels': agent_thought.tool_labels,
                        'tool_input': agent_thought.tool_input,
                        'created_at': int(self._message.created_at.timestamp()),
                        'message_files': agent_thought.files
                    }

                    if self._conversation.mode == 'chat':
                        response['conversation_id'] = self._conversation.id

                    yield self._yield_response(response)
            elif isinstance(event, QueueMessageFileEvent):
                message_file: Optional[MessageFile] = (
                    db.session.query(MessageFile)
                    .filter(MessageFile.id == event.message_file_id)
                    .first()
                )
                db.session.close()

                # only touch the file row once it is known to exist
                if message_file:
                    # get extension
                    if '.' in message_file.url:
                        extension = f'.{message_file.url.split(".")[-1]}'
                        if len(extension) > 10:
                            extension = '.bin'
                    else:
                        extension = '.bin'

                    # add sign url
                    url = ToolFileManager.sign_file(file_id=message_file.id, extension=extension)

                    response = {
                        'event': 'message_file',
                        'id': message_file.id,
                        'type': message_file.type,
                        'belongs_to': message_file.belongs_to or 'user',
                        'url': url
                    }

                    if self._conversation.mode == 'chat':
                        response['conversation_id'] = self._conversation.id

                    yield self._yield_response(response)
            elif isinstance(event, QueueMessageEvent | QueueAgentMessageEvent):
                chunk = event.chunk
                delta_text = chunk.delta.message.content
                if delta_text is None:
                    continue

                if not self._task_state.llm_result.prompt_messages:
                    self._task_state.llm_result.prompt_messages = chunk.prompt_messages

                if self._output_moderation_handler:
                    if self._output_moderation_handler.should_direct_output():
                        # stop subscribing to new tokens once output moderation decides to direct output
                        self._task_state.llm_result.message.content = self._output_moderation_handler.get_final_output()
                        self._queue_manager.publish_chunk_message(LLMResultChunk(
                            model=self._task_state.llm_result.model,
                            prompt_messages=self._task_state.llm_result.prompt_messages,
                            delta=LLMResultChunkDelta(
                                index=0,
                                message=AssistantPromptMessage(content=self._task_state.llm_result.message.content)
                            )
                        ), PublishFrom.TASK_PIPELINE)

                        self._queue_manager.publish(
                            QueueStopEvent(stopped_by=QueueStopEvent.StopBy.OUTPUT_MODERATION),
                            PublishFrom.TASK_PIPELINE
                        )
                        continue
                    else:
                        self._output_moderation_handler.append_new_token(delta_text)

                self._task_state.llm_result.message.content += delta_text
                response = self._handle_chunk(delta_text, agent=isinstance(event, QueueAgentMessageEvent))
                yield self._yield_response(response)
            elif isinstance(event, QueueMessageReplaceEvent):
                response = {
                    'event': 'message_replace',
                    'task_id': self._application_generate_entity.task_id,
                    'message_id': self._message.id,
                    'answer': event.text,
                    'created_at': int(self._message.created_at.timestamp())
                }

                if self._conversation.mode == 'chat':
                    response['conversation_id'] = self._conversation.id

                yield self._yield_response(response)
            elif isinstance(event, QueuePingEvent):
                yield "event: ping\n\n"
            else:
                continue

    def _save_message(self, llm_result: LLMResult) -> None:
        """
        Save message.
        :param llm_result: llm result
        :return:
        """
        usage = llm_result.usage

        self._message = db.session.query(Message).filter(Message.id == self._message.id).first()
        self._conversation = db.session.query(Conversation).filter(Conversation.id == self._conversation.id).first()

        self._message.message = self._prompt_messages_to_prompt_for_saving(self._task_state.llm_result.prompt_messages)
        self._message.message_tokens = usage.prompt_tokens
        self._message.message_unit_price = usage.prompt_unit_price
        self._message.message_price_unit = usage.prompt_price_unit
        self._message.answer = PromptTemplateParser.remove_template_variables(llm_result.message.content.strip()) \
            if llm_result.message.content else ''
        self._message.answer_tokens = usage.completion_tokens
        self._message.answer_unit_price = usage.completion_unit_price
        self._message.answer_price_unit = usage.completion_price_unit
        self._message.provider_response_latency = time.perf_counter() - self._start_at
        self._message.total_price = usage.total_price

        db.session.commit()

        message_was_created.send(
            self._message,
            application_generate_entity=self._application_generate_entity,
            conversation=self._conversation,
            is_first_message=self._application_generate_entity.conversation_id is None,
            extras=self._application_generate_entity.extras
        )
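
    # Illustrative consumer of the signal sent above (assuming message_was_created
    # is a blinker-style signal, as the .send(sender, **kwargs) call suggests;
    # the handler name is hypothetical):
    #
    #   @message_was_created.connect
    #   def on_message_created(message, **kwargs):
    #       ...  # e.g. record statistics for the finished message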

    def _handle_chunk(self, text: str, agent: bool = False) -> dict:
        """
        Build the stream response for a message chunk.
        :param text: text
        :return:
        """
        response = {
            'event': 'message' if not agent else 'agent_message',
            'id': self._message.id,
            'task_id': self._application_generate_entity.task_id,
            'message_id': self._message.id,
            'answer': text,
            'created_at': int(self._message.created_at.timestamp())
        }

        if self._conversation.mode == 'chat':
            response['conversation_id'] = self._conversation.id

        return response

    def _handle_error(self, event: QueueErrorEvent) -> Exception:
        """
        Handle error event.
        :param event: event
        :return:
        """
        logger.debug("error: %s", event.error)
        e = event.error

        if isinstance(e, InvokeAuthorizationError):
            return InvokeAuthorizationError('Incorrect API key provided')
        elif isinstance(e, (InvokeError, ValueError)):
            return e
        else:
            return Exception(e.description if getattr(e, 'description', None) is not None else str(e))

    def _error_to_stream_response_data(self, e: Exception) -> dict:
        """
        Error to stream response.
        :param e: exception
        :return:
        """
        error_responses = {
            ValueError: {'code': 'invalid_param', 'status': 400},
            ProviderTokenNotInitError: {'code': 'provider_not_initialize', 'status': 400},
            QuotaExceededError: {
                'code': 'provider_quota_exceeded',
                'message': "Your quota for Dify Hosted Model Provider has been exhausted. "
                           "Please go to Settings -> Model Provider to complete your own provider credentials.",
                'status': 400
            },
            ModelCurrentlyNotSupportError: {'code': 'model_currently_not_support', 'status': 400},
            InvokeError: {'code': 'completion_request_error', 'status': 400}
        }

        # Determine the response based on the type of exception
        data = None
        for k, v in error_responses.items():
            if isinstance(e, k):
                data = v

        if data:
            data.setdefault('message', getattr(e, 'description', str(e)))
        else:
            logger.error(e)
            data = {
                'code': 'internal_server_error',
                'message': 'Internal Server Error, please contact support.',
                'status': 500
            }

        return {
            'event': 'error',
            'task_id': self._application_generate_entity.task_id,
            'message_id': self._message.id,
            **data
        }
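
    # Example error frame (illustrative): a QuotaExceededError yields
    #   {'event': 'error', 'task_id': '...', 'message_id': '...',
    #    'code': 'provider_quota_exceeded', 'message': '...', 'status': 400}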

    def _get_response_metadata(self) -> dict:
        """
        Get response metadata by invoke from.
        :return:
        """
        metadata = {}

        # show_retrieve_source
        if 'retriever_resources' in self._task_state.metadata:
            if self._application_generate_entity.invoke_from in [InvokeFrom.DEBUGGER, InvokeFrom.SERVICE_API]:
                metadata['retriever_resources'] = self._task_state.metadata['retriever_resources']
            else:
                metadata['retriever_resources'] = []
                for resource in self._task_state.metadata['retriever_resources']:
                    metadata['retriever_resources'].append({
                        'segment_id': resource['segment_id'],
                        'position': resource['position'],
                        'document_name': resource['document_name'],
                        'score': resource['score'],
                        'content': resource['content'],
                    })

        # show annotation reply
        if 'annotation_reply' in self._task_state.metadata:
            if self._application_generate_entity.invoke_from in [InvokeFrom.DEBUGGER, InvokeFrom.SERVICE_API]:
                metadata['annotation_reply'] = self._task_state.metadata['annotation_reply']

        # show usage
        if self._application_generate_entity.invoke_from in [InvokeFrom.DEBUGGER, InvokeFrom.SERVICE_API]:
            metadata['usage'] = self._task_state.metadata['usage']

        return metadata
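
    # Example (illustrative): for a non-debugger, non-service-API invocation the
    # retriever resources are trimmed to the whitelisted fields above, e.g.
    #   {'retriever_resources': [{'segment_id': '...', 'position': 1,
    #    'document_name': 'doc.pdf', 'score': 0.82, 'content': '...'}]}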

    def _yield_response(self, response: dict) -> str:
        """
        Yield response.
        :param response: response
        :return:
        """
        return "data: " + json.dumps(response) + "\n\n"
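
    # Despite its name, _yield_response returns the serialized frame; callers
    # yield it. Each response dict becomes one server-sent-events data frame:
    #   data: {"event": "message", "answer": "Hi", ...}\n\n
    # The trailing blank line terminates the event, per the SSE format.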

    def _prompt_messages_to_prompt_for_saving(self, prompt_messages: list[PromptMessage]) -> list[dict]:
        """
        Prompt messages to prompt for saving.
        :param prompt_messages: prompt messages
        :return:
        """
        prompts = []
        if self._application_generate_entity.app_orchestration_config_entity.model_config.mode == 'chat':
            for prompt_message in prompt_messages:
                if prompt_message.role == PromptMessageRole.USER:
                    role = 'user'
                elif prompt_message.role == PromptMessageRole.ASSISTANT:
                    role = 'assistant'
                elif prompt_message.role == PromptMessageRole.SYSTEM:
                    role = 'system'
                else:
                    continue

                text = ''
                files = []
                if isinstance(prompt_message.content, list):
                    for content in prompt_message.content:
                        if content.type == PromptMessageContentType.TEXT:
                            content = cast(TextPromptMessageContent, content)
                            text += content.data
                        else:
                            content = cast(ImagePromptMessageContent, content)
                            files.append({
                                "type": 'image',
                                "data": content.data[:10] + '...[TRUNCATED]...' + content.data[-10:],
                                "detail": content.detail.value
                            })
                else:
                    text = prompt_message.content

                prompts.append({
                    "role": role,
                    "text": text,
                    "files": files
                })
        else:
            prompt_message = prompt_messages[0]
            text = ''
            files = []
            if isinstance(prompt_message.content, list):
                for content in prompt_message.content:
                    if content.type == PromptMessageContentType.TEXT:
                        content = cast(TextPromptMessageContent, content)
                        text += content.data
                    else:
                        content = cast(ImagePromptMessageContent, content)
                        files.append({
                            "type": 'image',
                            "data": content.data[:10] + '...[TRUNCATED]...' + content.data[-10:],
                            "detail": content.detail.value
                        })
            else:
                text = prompt_message.content

            params = {
                "role": 'user',
                "text": text,
            }

            if files:
                params['files'] = files

            prompts.append(params)

        return prompts
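
    # Example saved prompt (illustrative), chat mode with one image attachment;
    # image data is truncated to its first and last 10 characters:
    #   [{'role': 'user', 'text': 'Describe this picture.',
    #     'files': [{'type': 'image', 'data': 'iVBORw0KGg...[TRUNCATED]...AAElFTkSuQ',
    #                'detail': 'low'}]}]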

    def _init_output_moderation(self) -> Optional[OutputModerationHandler]:
        """
        Init output moderation.
        :return:
        """
        app_orchestration_config_entity = self._application_generate_entity.app_orchestration_config_entity
        sensitive_word_avoidance = app_orchestration_config_entity.sensitive_word_avoidance

        if sensitive_word_avoidance:
            return OutputModerationHandler(
                tenant_id=self._application_generate_entity.tenant_id,
                app_id=self._application_generate_entity.app_id,
                rule=ModerationRule(
                    type=sensitive_word_avoidance.type,
                    config=sensitive_word_avoidance.config
                ),
                on_message_replace_func=self._queue_manager.publish_message_replace
            )

        return None