@@ -1,17 +1,24 @@
+import json
+import os
+import re
 from abc import abstractmethod
-from typing import List, Optional, Any, Union
+from typing import List, Optional, Any, Union, Tuple
 import decimal
 
 from langchain.callbacks.manager import Callbacks
+from langchain.memory.chat_memory import BaseChatMemory
 from langchain.schema import LLMResult, SystemMessage, AIMessage, HumanMessage, BaseMessage, ChatGeneration
 
 from core.callback_handler.std_out_callback_handler import DifyStreamingStdOutCallbackHandler, DifyStdOutCallbackHandler
 from core.model_providers.models.base import BaseProviderModel
-from core.model_providers.models.entity.message import PromptMessage, MessageType, LLMRunResult
+from core.model_providers.models.entity.message import PromptMessage, MessageType, LLMRunResult, to_prompt_messages
 from core.model_providers.models.entity.model_params import ModelType, ModelKwargs, ModelMode, ModelKwargsRules
 from core.model_providers.providers.base import BaseModelProvider
+from core.prompt.prompt_builder import PromptBuilder
+from core.prompt.prompt_template import JinjaPromptTemplate
 from core.third_party.langchain.llms.fake import FakeLLM
 import logging
+
 logger = logging.getLogger(__name__)
 
 
@@ -76,13 +83,14 @@ class BaseLLM(BaseProviderModel):
     def price_config(self) -> dict:
         def get_or_default():
             default_price_config = {
-                'prompt': decimal.Decimal('0'),
-                'completion': decimal.Decimal('0'),
-                'unit': decimal.Decimal('0'),
-                'currency': 'USD'
-            }
+                'prompt': decimal.Decimal('0'),
+                'completion': decimal.Decimal('0'),
+                'unit': decimal.Decimal('0'),
+                'currency': 'USD'
+            }
             rules = self.model_provider.get_rules()
-            price_config = rules['price_config'][self.base_model_name] if 'price_config' in rules else default_price_config
+            price_config = rules['price_config'][
+                self.base_model_name] if 'price_config' in rules else default_price_config
             price_config = {
                 'prompt': decimal.Decimal(price_config['prompt']),
                 'completion': decimal.Decimal(price_config['completion']),
@@ -90,7 +98,7 @@ class BaseLLM(BaseProviderModel):
                 'currency': price_config['currency']
             }
             return price_config
-
+
         self._price_config = self._price_config if hasattr(self, '_price_config') else get_or_default()
 
         logger.debug(f"model: {self.name} price_config: {self._price_config}")
@@ -158,7 +166,8 @@ class BaseLLM(BaseProviderModel):
             total_tokens = result.llm_output['token_usage']['total_tokens']
         else:
             prompt_tokens = self.get_num_tokens(messages)
-            completion_tokens = self.get_num_tokens([PromptMessage(content=completion_content, type=MessageType.ASSISTANT)])
+            completion_tokens = self.get_num_tokens(
+                [PromptMessage(content=completion_content, type=MessageType.ASSISTANT)])
             total_tokens = prompt_tokens + completion_tokens
 
         self.model_provider.update_last_used()
@@ -293,6 +302,119 @@ class BaseLLM(BaseProviderModel):
     def support_streaming(cls):
         return False
 
+    def get_prompt(self, mode: str,
+                   pre_prompt: str, inputs: dict,
+                   query: str,
+                   context: Optional[str],
+                   memory: Optional[BaseChatMemory]) -> \
+            Tuple[List[PromptMessage], Optional[List[str]]]:
+        prompt_rules = self._read_prompt_rules_from_file(self.prompt_file_name(mode))
+        prompt, stops = self._get_prompt_and_stop(prompt_rules, pre_prompt, inputs, query, context, memory)
+        return [PromptMessage(content=prompt)], stops
+
+    def prompt_file_name(self, mode: str) -> str:
+        if mode == 'completion':
+            return 'common_completion'
+        else:
+            return 'common_chat'
+
+    def _get_prompt_and_stop(self, prompt_rules: dict, pre_prompt: str, inputs: dict,
+                             query: str,
+                             context: Optional[str],
+                             memory: Optional[BaseChatMemory]) -> Tuple[str, Optional[list]]:
+        context_prompt_content = ''
+        if context and 'context_prompt' in prompt_rules:
+            prompt_template = JinjaPromptTemplate.from_template(template=prompt_rules['context_prompt'])
+            context_prompt_content = prompt_template.format(
+                context=context
+            )
+
+        pre_prompt_content = ''
+        if pre_prompt:
+            prompt_template = JinjaPromptTemplate.from_template(template=pre_prompt)
+            prompt_inputs = {k: inputs[k] for k in prompt_template.input_variables if k in inputs}
+            pre_prompt_content = prompt_template.format(
+                **prompt_inputs
+            )
+
+        prompt = ''
+        for order in prompt_rules['system_prompt_orders']:
+            if order == 'context_prompt':
+                prompt += context_prompt_content
+            elif order == 'pre_prompt':
+                prompt += (pre_prompt_content + '\n\n') if pre_prompt_content else ''
+
+        query_prompt = prompt_rules['query_prompt'] if 'query_prompt' in prompt_rules else '{{query}}'
+
+        if memory and 'histories_prompt' in prompt_rules:
+            # append chat histories
+            tmp_human_message = PromptBuilder.to_human_message(
+                prompt_content=prompt + query_prompt,
+                inputs={
+                    'query': query
+                }
+            )
+
+            if self.model_rules.max_tokens.max:
+                curr_message_tokens = self.get_num_tokens(to_prompt_messages([tmp_human_message]))
+                max_tokens = self.model_kwargs.max_tokens
+                rest_tokens = self.model_rules.max_tokens.max - max_tokens - curr_message_tokens
+                rest_tokens = max(rest_tokens, 0)
+            else:
+                rest_tokens = 2000
+
+            memory.human_prefix = prompt_rules['human_prefix'] if 'human_prefix' in prompt_rules else 'Human'
+            memory.ai_prefix = prompt_rules['assistant_prefix'] if 'assistant_prefix' in prompt_rules else 'Assistant'
+
+            histories = self._get_history_messages_from_memory(memory, rest_tokens)
+            prompt_template = JinjaPromptTemplate.from_template(template=prompt_rules['histories_prompt'])
+            histories_prompt_content = prompt_template.format(
+                histories=histories
+            )
+
+            prompt = ''
+            for order in prompt_rules['system_prompt_orders']:
+                if order == 'context_prompt':
+                    prompt += context_prompt_content
+                elif order == 'pre_prompt':
+                    prompt += (pre_prompt_content + '\n') if pre_prompt_content else ''
+                elif order == 'histories_prompt':
+                    prompt += histories_prompt_content
+
+        prompt_template = JinjaPromptTemplate.from_template(template=query_prompt)
+        query_prompt_content = prompt_template.format(
+            query=query
+        )
+
+        prompt += query_prompt_content
+
+        prompt = re.sub(r'<\|.*?\|>', '', prompt)
+
+        stops = prompt_rules.get('stops')
+        if stops is not None and len(stops) == 0:
+            stops = None
+
+        return prompt, stops
+
+    def _read_prompt_rules_from_file(self, prompt_name: str) -> dict:
+        # Get the absolute path of the subdirectory
+        prompt_path = os.path.join(
+            os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(os.path.realpath(__file__))))),
+            'prompt/generate_prompts')
+
+        json_file_path = os.path.join(prompt_path, f'{prompt_name}.json')
+        # Open the JSON file and read its content
+        with open(json_file_path, 'r') as json_file:
+            return json.load(json_file)
+
+    def _get_history_messages_from_memory(self, memory: BaseChatMemory,
+                                          max_token_limit: int) -> str:
+        """Get memory messages."""
+        memory.max_token_limit = max_token_limit
+        memory_key = memory.memory_variables[0]
+        external_context = memory.load_memory_variables({})
+        return external_context[memory_key]
+
     def _get_prompt_from_messages(self, messages: List[PromptMessage],
                                   model_mode: Optional[ModelMode] = None) -> Union[str | List[BaseMessage]]:
         if not model_mode: