completion.py 6.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160
  1. # -*- coding:utf-8 -*-
  2. import json
  3. import logging
  4. from typing import Generator, Union
  5. import services
  6. from controllers.web import api
  7. from controllers.web.error import (AppUnavailableError, CompletionRequestError, ConversationCompletedError,
  8. NotChatAppError, NotCompletionAppError, ProviderModelCurrentlyNotSupportError,
  9. ProviderNotInitializeError, ProviderQuotaExceededError)
  10. from controllers.web.wraps import WebApiResource
  11. from core.application_queue_manager import ApplicationQueueManager
  12. from core.entities.application_entities import InvokeFrom
  13. from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
  14. from core.model_runtime.errors.invoke import InvokeError
  15. from flask import Response, stream_with_context
  16. from flask_restful import reqparse
  17. from libs.helper import uuid_value
  18. from services.completion_service import CompletionService
  19. from werkzeug.exceptions import InternalServerError, NotFound
  20. # define completion api for user
  21. class CompletionApi(WebApiResource):
  22. def post(self, app_model, end_user):
  23. if app_model.mode != 'completion':
  24. raise NotCompletionAppError()
  25. parser = reqparse.RequestParser()
  26. parser.add_argument('inputs', type=dict, required=True, location='json')
  27. parser.add_argument('query', type=str, location='json', default='')
  28. parser.add_argument('files', type=list, required=False, location='json')
  29. parser.add_argument('response_mode', type=str, choices=['blocking', 'streaming'], location='json')
  30. parser.add_argument('retriever_from', type=str, required=False, default='web_app', location='json')
  31. args = parser.parse_args()
  32. streaming = args['response_mode'] == 'streaming'
  33. args['auto_generate_name'] = False
  34. try:
  35. response = CompletionService.completion(
  36. app_model=app_model,
  37. user=end_user,
  38. args=args,
  39. invoke_from=InvokeFrom.WEB_APP,
  40. streaming=streaming
  41. )
  42. return compact_response(response)
  43. except services.errors.conversation.ConversationNotExistsError:
  44. raise NotFound("Conversation Not Exists.")
  45. except services.errors.conversation.ConversationCompletedError:
  46. raise ConversationCompletedError()
  47. except services.errors.app_model_config.AppModelConfigBrokenError:
  48. logging.exception("App model config broken.")
  49. raise AppUnavailableError()
  50. except ProviderTokenNotInitError as ex:
  51. raise ProviderNotInitializeError(ex.description)
  52. except QuotaExceededError:
  53. raise ProviderQuotaExceededError()
  54. except ModelCurrentlyNotSupportError:
  55. raise ProviderModelCurrentlyNotSupportError()
  56. except InvokeError as e:
  57. raise CompletionRequestError(e.description)
  58. except ValueError as e:
  59. raise e
  60. except Exception as e:
  61. logging.exception("internal server error.")
  62. raise InternalServerError()
  63. class CompletionStopApi(WebApiResource):
  64. def post(self, app_model, end_user, task_id):
  65. if app_model.mode != 'completion':
  66. raise NotCompletionAppError()
  67. ApplicationQueueManager.set_stop_flag(task_id, InvokeFrom.WEB_APP, end_user.id)
  68. return {'result': 'success'}, 200
  69. class ChatApi(WebApiResource):
  70. def post(self, app_model, end_user):
  71. if app_model.mode != 'chat':
  72. raise NotChatAppError()
  73. parser = reqparse.RequestParser()
  74. parser.add_argument('inputs', type=dict, required=True, location='json')
  75. parser.add_argument('query', type=str, required=True, location='json')
  76. parser.add_argument('files', type=list, required=False, location='json')
  77. parser.add_argument('response_mode', type=str, choices=['blocking', 'streaming'], location='json')
  78. parser.add_argument('conversation_id', type=uuid_value, location='json')
  79. parser.add_argument('retriever_from', type=str, required=False, default='web_app', location='json')
  80. args = parser.parse_args()
  81. streaming = args['response_mode'] == 'streaming'
  82. args['auto_generate_name'] = False
  83. try:
  84. response = CompletionService.completion(
  85. app_model=app_model,
  86. user=end_user,
  87. args=args,
  88. invoke_from=InvokeFrom.WEB_APP,
  89. streaming=streaming
  90. )
  91. return compact_response(response)
  92. except services.errors.conversation.ConversationNotExistsError:
  93. raise NotFound("Conversation Not Exists.")
  94. except services.errors.conversation.ConversationCompletedError:
  95. raise ConversationCompletedError()
  96. except services.errors.app_model_config.AppModelConfigBrokenError:
  97. logging.exception("App model config broken.")
  98. raise AppUnavailableError()
  99. except ProviderTokenNotInitError as ex:
  100. raise ProviderNotInitializeError(ex.description)
  101. except QuotaExceededError:
  102. raise ProviderQuotaExceededError()
  103. except ModelCurrentlyNotSupportError:
  104. raise ProviderModelCurrentlyNotSupportError()
  105. except InvokeError as e:
  106. raise CompletionRequestError(e.description)
  107. except ValueError as e:
  108. raise e
  109. except Exception as e:
  110. logging.exception("internal server error.")
  111. raise InternalServerError()
  112. class ChatStopApi(WebApiResource):
  113. def post(self, app_model, end_user, task_id):
  114. if app_model.mode != 'chat':
  115. raise NotChatAppError()
  116. ApplicationQueueManager.set_stop_flag(task_id, InvokeFrom.WEB_APP, end_user.id)
  117. return {'result': 'success'}, 200
  118. def compact_response(response: Union[dict, Generator]) -> Response:
  119. if isinstance(response, dict):
  120. return Response(response=json.dumps(response), status=200, mimetype='application/json')
  121. else:
  122. def generate() -> Generator:
  123. for chunk in response:
  124. yield chunk
  125. return Response(stream_with_context(generate()), status=200,
  126. mimetype='text/event-stream')
  127. api.add_resource(CompletionApi, '/completion-messages')
  128. api.add_resource(CompletionStopApi, '/completion-messages/<string:task_id>/stop')
  129. api.add_resource(ChatApi, '/chat-messages')
  130. api.add_resource(ChatStopApi, '/chat-messages/<string:task_id>/stop')