chat.py 4.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121
  1. import json
  2. import logging
  3. from typing import Generator, Union
  4. import services
  5. from controllers.console import api
  6. from controllers.console.app.error import (AppUnavailableError, CompletionRequestError, ConversationCompletedError,
  7. ProviderModelCurrentlyNotSupportError, ProviderNotInitializeError,
  8. ProviderQuotaExceededError)
  9. from controllers.console.universal_chat.wraps import UniversalChatResource
  10. from core.application_queue_manager import ApplicationQueueManager
  11. from core.entities.application_entities import InvokeFrom
  12. from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
  13. from core.model_runtime.errors.invoke import InvokeError
  14. from flask import Response, stream_with_context
  15. from flask_login import current_user
  16. from flask_restful import reqparse
  17. from libs.helper import uuid_value
  18. from services.completion_service import CompletionService
  19. from werkzeug.exceptions import InternalServerError, NotFound
  class UniversalChatApi(UniversalChatResource):
      """Console endpoint that sends a message through the universal chat app.

      Each request overrides the app's stored model config with the
      caller-chosen provider, model and tool list, and the completion is
      always requested in streaming mode.
      """

      def post(self, universal_app):
          """Run a streaming completion against ``universal_app``.

          JSON body fields: ``query`` (required), ``files`` (optional list),
          ``conversation_id`` (optional UUID), ``provider`` (required),
          ``model`` (required), ``tools`` (required; an empty value is
          accepted) and ``retriever_from`` (defaults to ``'universal_app'``).

          Returns:
              Whatever ``compact_response`` builds from the completion output.

          Raises:
              NotFound: the referenced conversation does not exist.
              ConversationCompletedError, AppUnavailableError,
              ProviderNotInitializeError, ProviderQuotaExceededError,
              ProviderModelCurrentlyNotSupportError, CompletionRequestError:
                  translated from the matching service / model-runtime errors.
              ValueError: propagated unchanged.
              InternalServerError: any other unexpected failure (logged).
          """
          app_model = universal_app
          parser = reqparse.RequestParser()
          parser.add_argument('query', type=str, required=True, location='json')
          parser.add_argument('files', type=list, required=False, location='json')
          parser.add_argument('conversation_id', type=uuid_value, location='json')
          parser.add_argument('provider', type=str, required=True, location='json')
          parser.add_argument('model', type=str, required=True, location='json')
          parser.add_argument('tools', type=list, required=True, location='json')
          parser.add_argument('retriever_from', type=str, required=False, default='universal_app', location='json')
          args = parser.parse_args()

          # Start from the app's stored config and override the model and the
          # tool list with the values supplied in this request.
          model_config = app_model.app_model_config.to_dict()
          model_config['model']['name'] = args['model']
          model_config['model']['provider'] = args['provider']
          # current_datetime is always enabled on top of the requested tools;
          # an empty (or null) `tools` value therefore yields just that tool.
          tools = list(args['tools'] or [])
          tools.append({"current_datetime": {"enabled": True}})
          model_config['agent_mode']['tools'] = tools
          args['model_config'] = model_config

          args['inputs'] = {}
          # `model` and `tools` now live inside model_config; drop the raw copies.
          del args['model']
          del args['tools']
          args['auto_generate_name'] = False

          try:
              response = CompletionService.completion(
                  app_model=app_model,
                  user=current_user,
                  args=args,
                  invoke_from=InvokeFrom.EXPLORE,
                  streaming=True,
                  is_model_config_override=True,
              )
              return compact_response(response)
          except services.errors.conversation.ConversationNotExistsError:
              raise NotFound("Conversation Not Exists.")
          except services.errors.conversation.ConversationCompletedError:
              raise ConversationCompletedError()
          except services.errors.app_model_config.AppModelConfigBrokenError:
              logging.exception("App model config broken.")
              raise AppUnavailableError()
          except ProviderTokenNotInitError:
              raise ProviderNotInitializeError()
          except QuotaExceededError:
              raise ProviderQuotaExceededError()
          except ModelCurrentlyNotSupportError:
              raise ProviderModelCurrentlyNotSupportError()
          except InvokeError as e:
              raise CompletionRequestError(e.description)
          except ValueError:
              # Listed before the catch-all so ValueErrors are not converted
              # into a 500; a bare raise keeps the original traceback intact.
              raise
          except Exception:
              logging.exception("internal server error.")
              raise InternalServerError()
  class UniversalChatStopApi(UniversalChatResource):
      """Console endpoint that asks a running universal-chat task to stop."""

      def post(self, universal_app, task_id):
          """Raise the stop flag for ``task_id`` on behalf of the current user.

          ``universal_app`` is received (presumably injected via the
          ``UniversalChatResource`` base) but is not used here.

          Returns:
              The tuple ``({'result': 'success'}, 200)``.
          """
          requester_id = current_user.id
          ApplicationQueueManager.set_stop_flag(task_id, InvokeFrom.EXPLORE, requester_id)
          return {'result': 'success'}, 200
  def compact_response(response: Union[dict, Generator]) -> Response:
      """Wrap a completion result in a Flask ``Response``.

      A ``dict`` result is serialised with ``json.dumps`` and returned as
      ``application/json``. Anything else is treated as an iterable of
      chunks and streamed as ``text/event-stream``, wrapped in
      ``stream_with_context``.
      """
      if not isinstance(response, dict):
          def _relay_chunks() -> Generator:
              # Lazily re-yield each chunk; `response` is not iterated until
              # the streamed body is actually consumed.
              for piece in response:
                  yield piece

          return Response(stream_with_context(_relay_chunks()), status=200,
                          mimetype='text/event-stream')

      payload = json.dumps(response)
      return Response(response=payload, status=200, mimetype='application/json')
  # Register the universal-chat endpoints on the console `api`, in order:
  # the message endpoint, then the per-task stop endpoint.
  for _resource_cls, _url_rule in (
      (UniversalChatApi, '/universal-chat/messages'),
      (UniversalChatStopApi, '/universal-chat/messages/<string:task_id>/stop'),
  ):
      api.add_resource(_resource_cls, _url_rule)
  # Drop the loop names so the module namespace is unchanged.
  del _resource_cls, _url_rule