advanced_prompt_template_service.py

import copy

from core.prompt.prompt_templates.advanced_prompt_templates import (
    BAICHUAN_CHAT_APP_CHAT_PROMPT_CONFIG,
    BAICHUAN_CHAT_APP_COMPLETION_PROMPT_CONFIG,
    BAICHUAN_COMPLETION_APP_CHAT_PROMPT_CONFIG,
    BAICHUAN_COMPLETION_APP_COMPLETION_PROMPT_CONFIG,
    BAICHUAN_CONTEXT,
    CHAT_APP_CHAT_PROMPT_CONFIG,
    CHAT_APP_COMPLETION_PROMPT_CONFIG,
    COMPLETION_APP_CHAT_PROMPT_CONFIG,
    COMPLETION_APP_COMPLETION_PROMPT_CONFIG,
    CONTEXT,
)
from models.model import AppMode


class AdvancedPromptTemplateService:
    @classmethod
    def get_prompt(cls, args: dict) -> dict:
        # Dispatch to the Baichuan-specific templates when the model name
        # indicates a Baichuan model; otherwise use the common templates.
        app_mode = args["app_mode"]
        model_mode = args["model_mode"]
        model_name = args["model_name"]
        has_context = args["has_context"]

        if "baichuan" in model_name.lower():
            return cls.get_baichuan_prompt(app_mode, model_mode, has_context)
        else:
            return cls.get_common_prompt(app_mode, model_mode, has_context)

    @classmethod
    def get_common_prompt(cls, app_mode: str, model_mode: str, has_context: str) -> dict:
        # Select the default prompt template by app mode (chat/completion)
        # and model mode (chat/completion).
        context_prompt = copy.deepcopy(CONTEXT)

        if app_mode == AppMode.CHAT.value:
            if model_mode == "completion":
                return cls.get_completion_prompt(
                    copy.deepcopy(CHAT_APP_COMPLETION_PROMPT_CONFIG), has_context, context_prompt
                )
            elif model_mode == "chat":
                return cls.get_chat_prompt(copy.deepcopy(CHAT_APP_CHAT_PROMPT_CONFIG), has_context, context_prompt)
        elif app_mode == AppMode.COMPLETION.value:
            if model_mode == "completion":
                return cls.get_completion_prompt(
                    copy.deepcopy(COMPLETION_APP_COMPLETION_PROMPT_CONFIG), has_context, context_prompt
                )
            elif model_mode == "chat":
                return cls.get_chat_prompt(
                    copy.deepcopy(COMPLETION_APP_CHAT_PROMPT_CONFIG), has_context, context_prompt
                )
        # default return empty dict
        return {}

    @classmethod
    def get_completion_prompt(cls, prompt_template: dict, has_context: str, context: str) -> dict:
        # Prepend the context block to the completion prompt text when context is enabled.
        if has_context == "true":
            prompt_template["completion_prompt_config"]["prompt"]["text"] = (
                context + prompt_template["completion_prompt_config"]["prompt"]["text"]
            )

        return prompt_template

    @classmethod
    def get_chat_prompt(cls, prompt_template: dict, has_context: str, context: str) -> dict:
        # Prepend the context block to the first message of the chat prompt when context is enabled.
        if has_context == "true":
            prompt_template["chat_prompt_config"]["prompt"][0]["text"] = (
                context + prompt_template["chat_prompt_config"]["prompt"][0]["text"]
            )

        return prompt_template

    @classmethod
    def get_baichuan_prompt(cls, app_mode: str, model_mode: str, has_context: str) -> dict:
        # Same selection logic as get_common_prompt, but using the Baichuan-specific templates and context.
        baichuan_context_prompt = copy.deepcopy(BAICHUAN_CONTEXT)

        if app_mode == AppMode.CHAT.value:
            if model_mode == "completion":
                return cls.get_completion_prompt(
                    copy.deepcopy(BAICHUAN_CHAT_APP_COMPLETION_PROMPT_CONFIG), has_context, baichuan_context_prompt
                )
            elif model_mode == "chat":
                return cls.get_chat_prompt(
                    copy.deepcopy(BAICHUAN_CHAT_APP_CHAT_PROMPT_CONFIG), has_context, baichuan_context_prompt
                )
        elif app_mode == AppMode.COMPLETION.value:
            if model_mode == "completion":
                return cls.get_completion_prompt(
                    copy.deepcopy(BAICHUAN_COMPLETION_APP_COMPLETION_PROMPT_CONFIG),
                    has_context,
                    baichuan_context_prompt,
                )
            elif model_mode == "chat":
                return cls.get_chat_prompt(
                    copy.deepcopy(BAICHUAN_COMPLETION_APP_CHAT_PROMPT_CONFIG), has_context, baichuan_context_prompt
                )
        # default return empty dict
        return {}
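

# Illustrative usage sketch only, not part of the service: the argument values
# below are assumptions chosen for demonstration (the real values come from the
# caller's app/model configuration). Note that has_context is passed as the
# string "true"/"false", not a bool.
if __name__ == "__main__":
    example_args = {
        "app_mode": "chat",
        "model_mode": "completion",
        "model_name": "gpt-3.5-turbo",  # any non-Baichuan name selects the common templates
        "has_context": "true",
    }

    # Expected result under these assumptions: a deep copy of
    # CHAT_APP_COMPLETION_PROMPT_CONFIG with CONTEXT prepended to the
    # completion prompt text.
    prompt_config = AdvancedPromptTemplateService.get_prompt(example_args)
    print(prompt_config)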