# langfuse_trace.py

import json
import logging
import os
from datetime import datetime, timedelta
from typing import Optional

from langfuse import Langfuse  # type: ignore

from core.ops.base_trace_instance import BaseTraceInstance
from core.ops.entities.config_entity import LangfuseConfig
from core.ops.entities.trace_entity import (
    BaseTraceInfo,
    DatasetRetrievalTraceInfo,
    GenerateNameTraceInfo,
    MessageTraceInfo,
    ModerationTraceInfo,
    SuggestedQuestionTraceInfo,
    ToolTraceInfo,
    TraceTaskName,
    WorkflowTraceInfo,
)
from core.ops.langfuse_trace.entities.langfuse_trace_entity import (
    GenerationUsage,
    LangfuseGeneration,
    LangfuseSpan,
    LangfuseTrace,
    LevelEnum,
    UnitEnum,
)
from core.ops.utils import filter_none_values
from extensions.ext_database import db
from models.model import EndUser
from models.workflow import WorkflowNodeExecution

logger = logging.getLogger(__name__)


class LangFuseDataTrace(BaseTraceInstance):
    """Trace exporter that maps application trace entities onto Langfuse traces, spans, and generations."""

    def __init__(
        self,
        langfuse_config: LangfuseConfig,
    ):
        super().__init__(langfuse_config)
        self.langfuse_client = Langfuse(
            public_key=langfuse_config.public_key,
            secret_key=langfuse_config.secret_key,
            host=langfuse_config.host,
        )
        self.file_base_url = os.getenv("FILES_URL", "http://127.0.0.1:5001")

    def trace(self, trace_info: BaseTraceInfo):
        if isinstance(trace_info, WorkflowTraceInfo):
            self.workflow_trace(trace_info)
        if isinstance(trace_info, MessageTraceInfo):
            self.message_trace(trace_info)
        if isinstance(trace_info, ModerationTraceInfo):
            self.moderation_trace(trace_info)
        if isinstance(trace_info, SuggestedQuestionTraceInfo):
            self.suggested_question_trace(trace_info)
        if isinstance(trace_info, DatasetRetrievalTraceInfo):
            self.dataset_retrieval_trace(trace_info)
        if isinstance(trace_info, ToolTraceInfo):
            self.tool_trace(trace_info)
        if isinstance(trace_info, GenerateNameTraceInfo):
            self.generate_name_trace(trace_info)

    def workflow_trace(self, trace_info: WorkflowTraceInfo):
        trace_id = trace_info.workflow_run_id
        user_id = trace_info.metadata.get("user_id")
        metadata = trace_info.metadata
        metadata["workflow_app_log_id"] = trace_info.workflow_app_log_id

        # If the workflow run is tied to a chat message, use the message as the trace
        # and nest the workflow run under it as a span; otherwise the run is the trace itself.
        if trace_info.message_id:
            trace_id = trace_info.message_id
            name = TraceTaskName.MESSAGE_TRACE.value
            trace_data = LangfuseTrace(
                id=trace_id,
                user_id=user_id,
                name=name,
                input=dict(trace_info.workflow_run_inputs),
                output=dict(trace_info.workflow_run_outputs),
                metadata=metadata,
                session_id=trace_info.conversation_id,
                tags=["message", "workflow"],
            )
            self.add_trace(langfuse_trace_data=trace_data)

            workflow_span_data = LangfuseSpan(
                id=trace_info.workflow_run_id,
                name=TraceTaskName.WORKFLOW_TRACE.value,
                input=dict(trace_info.workflow_run_inputs),
                output=dict(trace_info.workflow_run_outputs),
                trace_id=trace_id,
                start_time=trace_info.start_time,
                end_time=trace_info.end_time,
                metadata=metadata,
                level=LevelEnum.DEFAULT if trace_info.error == "" else LevelEnum.ERROR,
                status_message=trace_info.error or "",
            )
            self.add_span(langfuse_span_data=workflow_span_data)
        else:
            trace_data = LangfuseTrace(
                id=trace_id,
                user_id=user_id,
                name=TraceTaskName.WORKFLOW_TRACE.value,
                input=dict(trace_info.workflow_run_inputs),
                output=dict(trace_info.workflow_run_outputs),
                metadata=metadata,
                session_id=trace_info.conversation_id,
                tags=["workflow"],
            )
            self.add_trace(langfuse_trace_data=trace_data)

        # fetch all node executions that belong to this workflow run
        workflow_nodes_execution_id_records = (
            db.session.query(WorkflowNodeExecution.id)
            .filter(WorkflowNodeExecution.workflow_run_id == trace_info.workflow_run_id)
            .all()
        )

        for node_execution_id_record in workflow_nodes_execution_id_records:
            node_execution = (
                db.session.query(
                    WorkflowNodeExecution.id,
                    WorkflowNodeExecution.tenant_id,
                    WorkflowNodeExecution.app_id,
                    WorkflowNodeExecution.title,
                    WorkflowNodeExecution.node_type,
                    WorkflowNodeExecution.status,
                    WorkflowNodeExecution.inputs,
                    WorkflowNodeExecution.outputs,
                    WorkflowNodeExecution.created_at,
                    WorkflowNodeExecution.elapsed_time,
                    WorkflowNodeExecution.process_data,
                    WorkflowNodeExecution.execution_metadata,
                )
                .filter(WorkflowNodeExecution.id == node_execution_id_record.id)
                .first()
            )

            if not node_execution:
                continue

            node_execution_id = node_execution.id
            tenant_id = node_execution.tenant_id
            app_id = node_execution.app_id
            node_name = node_execution.title
            node_type = node_execution.node_type
            status = node_execution.status
            if node_type == "llm":
                inputs = (
                    json.loads(node_execution.process_data).get("prompts", {}) if node_execution.process_data else {}
                )
            else:
                inputs = json.loads(node_execution.inputs) if node_execution.inputs else {}
            outputs = json.loads(node_execution.outputs) if node_execution.outputs else {}
            created_at = node_execution.created_at or datetime.now()
            elapsed_time = node_execution.elapsed_time
            finished_at = created_at + timedelta(seconds=elapsed_time)

            metadata = json.loads(node_execution.execution_metadata) if node_execution.execution_metadata else {}
            metadata.update(
                {
                    "workflow_run_id": trace_info.workflow_run_id,
                    "node_execution_id": node_execution_id,
                    "tenant_id": tenant_id,
                    "app_id": app_id,
                    "node_name": node_name,
                    "node_type": node_type,
                    "status": status,
                }
            )

            process_data = json.loads(node_execution.process_data) if node_execution.process_data else {}
            model_provider = process_data.get("model_provider", None)
            model_name = process_data.get("model_name", None)
            if model_provider is not None and model_name is not None:
                metadata.update(
                    {
                        "model_provider": model_provider,
                        "model_name": model_name,
                    }
                )

            # add span
            if trace_info.message_id:
                span_data = LangfuseSpan(
                    id=node_execution_id,
                    name=node_type,
                    input=inputs,
                    output=outputs,
                    trace_id=trace_id,
                    start_time=created_at,
                    end_time=finished_at,
                    metadata=metadata,
                    level=(LevelEnum.DEFAULT if status == "succeeded" else LevelEnum.ERROR),
                    status_message=trace_info.error or "",
                    parent_observation_id=trace_info.workflow_run_id,
                )
            else:
                span_data = LangfuseSpan(
                    id=node_execution_id,
                    name=node_type,
                    input=inputs,
                    output=outputs,
                    trace_id=trace_id,
                    start_time=created_at,
                    end_time=finished_at,
                    metadata=metadata,
                    level=(LevelEnum.DEFAULT if status == "succeeded" else LevelEnum.ERROR),
                    status_message=trace_info.error or "",
                )

            self.add_span(langfuse_span_data=span_data)

            if process_data and process_data.get("model_mode") == "chat":
                total_token = metadata.get("total_tokens", 0)
                # add generation
                generation_usage = GenerationUsage(
                    total=total_token,
                )

                node_generation_data = LangfuseGeneration(
                    name="llm",
                    trace_id=trace_id,
                    model=process_data.get("model_name"),
                    parent_observation_id=node_execution_id,
                    start_time=created_at,
                    end_time=finished_at,
                    input=inputs,
                    output=outputs,
                    metadata=metadata,
                    level=(LevelEnum.DEFAULT if status == "succeeded" else LevelEnum.ERROR),
                    status_message=trace_info.error or "",
                    usage=generation_usage,
                )

                self.add_generation(langfuse_generation_data=node_generation_data)

    def message_trace(self, trace_info: MessageTraceInfo, **kwargs):
        # get message file data
        file_list = trace_info.file_list
        metadata = trace_info.metadata
        message_data = trace_info.message_data
        if message_data is None:
            return
        message_id = message_data.id
        user_id = message_data.from_account_id
        if message_data.from_end_user_id:
            end_user_data: Optional[EndUser] = (
                db.session.query(EndUser).filter(EndUser.id == message_data.from_end_user_id).first()
            )
            if end_user_data is not None:
                user_id = end_user_data.session_id
                metadata["user_id"] = user_id

        trace_data = LangfuseTrace(
            id=message_id,
            user_id=user_id,
            name=TraceTaskName.MESSAGE_TRACE.value,
            input={
                "message": trace_info.inputs,
                "files": file_list,
                "message_tokens": trace_info.message_tokens,
                "answer_tokens": trace_info.answer_tokens,
                "total_tokens": trace_info.total_tokens,
                "error": trace_info.error,
                "provider_response_latency": message_data.provider_response_latency,
                "created_at": trace_info.start_time,
            },
            output=trace_info.outputs,
            metadata=metadata,
            session_id=message_data.conversation_id,
            tags=["message", str(trace_info.conversation_mode)],
            version=None,
            release=None,
            public=None,
        )
        self.add_trace(langfuse_trace_data=trace_data)

        # add an LLM generation observation for this message
        generation_usage = GenerationUsage(
            input=trace_info.message_tokens,
            output=trace_info.answer_tokens,
            total=trace_info.total_tokens,
            unit=UnitEnum.TOKENS,
            totalCost=message_data.total_price,
        )

        langfuse_generation_data = LangfuseGeneration(
            name="llm",
            trace_id=message_id,
            start_time=trace_info.start_time,
            end_time=trace_info.end_time,
            model=message_data.model_id,
            input=trace_info.inputs,
            output=message_data.answer,
            metadata=metadata,
            level=(LevelEnum.DEFAULT if message_data.status != "error" else LevelEnum.ERROR),
            status_message=message_data.error or "",
            usage=generation_usage,
        )

        self.add_generation(langfuse_generation_data)

    def moderation_trace(self, trace_info: ModerationTraceInfo):
        if trace_info.message_data is None:
            return
        span_data = LangfuseSpan(
            name=TraceTaskName.MODERATION_TRACE.value,
            input=trace_info.inputs,
            output={
                "action": trace_info.action,
                "flagged": trace_info.flagged,
                "preset_response": trace_info.preset_response,
                "inputs": trace_info.inputs,
            },
            trace_id=trace_info.message_id,
            start_time=trace_info.start_time or trace_info.message_data.created_at,
            end_time=trace_info.end_time or trace_info.message_data.created_at,
            metadata=trace_info.metadata,
        )

        self.add_span(langfuse_span_data=span_data)

    def suggested_question_trace(self, trace_info: SuggestedQuestionTraceInfo):
        message_data = trace_info.message_data
        if message_data is None:
            return
        generation_usage = GenerationUsage(
            total=len(str(trace_info.suggested_question)),
            input=len(trace_info.inputs) if trace_info.inputs else 0,
            output=len(trace_info.suggested_question),
            unit=UnitEnum.CHARACTERS,
        )

        generation_data = LangfuseGeneration(
            name=TraceTaskName.SUGGESTED_QUESTION_TRACE.value,
            input=trace_info.inputs,
            output=str(trace_info.suggested_question),
            trace_id=trace_info.message_id,
            start_time=trace_info.start_time,
            end_time=trace_info.end_time,
            metadata=trace_info.metadata,
            level=(LevelEnum.DEFAULT if message_data.status != "error" else LevelEnum.ERROR),
            status_message=message_data.error or "",
            usage=generation_usage,
        )

        self.add_generation(langfuse_generation_data=generation_data)

    def dataset_retrieval_trace(self, trace_info: DatasetRetrievalTraceInfo):
        if trace_info.message_data is None:
            return
        dataset_retrieval_span_data = LangfuseSpan(
            name=TraceTaskName.DATASET_RETRIEVAL_TRACE.value,
            input=trace_info.inputs,
            output={"documents": trace_info.documents},
            trace_id=trace_info.message_id,
            start_time=trace_info.start_time or trace_info.message_data.created_at,
            end_time=trace_info.end_time or trace_info.message_data.updated_at,
            metadata=trace_info.metadata,
        )

        self.add_span(langfuse_span_data=dataset_retrieval_span_data)

    def tool_trace(self, trace_info: ToolTraceInfo):
        tool_span_data = LangfuseSpan(
            name=trace_info.tool_name,
            input=trace_info.tool_inputs,
            output=trace_info.tool_outputs,
            trace_id=trace_info.message_id,
            start_time=trace_info.start_time,
            end_time=trace_info.end_time,
            metadata=trace_info.metadata,
            level=(LevelEnum.DEFAULT if trace_info.error == "" or trace_info.error is None else LevelEnum.ERROR),
            status_message=trace_info.error,
        )

        self.add_span(langfuse_span_data=tool_span_data)

    def generate_name_trace(self, trace_info: GenerateNameTraceInfo):
        name_generation_trace_data = LangfuseTrace(
            name=TraceTaskName.GENERATE_NAME_TRACE.value,
            input=trace_info.inputs,
            output=trace_info.outputs,
            user_id=trace_info.tenant_id,
            metadata=trace_info.metadata,
            session_id=trace_info.conversation_id,
        )

        self.add_trace(langfuse_trace_data=name_generation_trace_data)

        name_generation_span_data = LangfuseSpan(
            name=TraceTaskName.GENERATE_NAME_TRACE.value,
            input=trace_info.inputs,
            output=trace_info.outputs,
            trace_id=trace_info.conversation_id,
            start_time=trace_info.start_time,
            end_time=trace_info.end_time,
            metadata=trace_info.metadata,
        )

        self.add_span(langfuse_span_data=name_generation_span_data)

    def add_trace(self, langfuse_trace_data: Optional[LangfuseTrace] = None):
        format_trace_data = filter_none_values(langfuse_trace_data.model_dump()) if langfuse_trace_data else {}
        try:
            self.langfuse_client.trace(**format_trace_data)
            logger.debug("LangFuse Trace created successfully")
        except Exception as e:
            raise ValueError(f"LangFuse Failed to create trace: {str(e)}")

    def add_span(self, langfuse_span_data: Optional[LangfuseSpan] = None):
        format_span_data = filter_none_values(langfuse_span_data.model_dump()) if langfuse_span_data else {}
        try:
            self.langfuse_client.span(**format_span_data)
            logger.debug("LangFuse Span created successfully")
        except Exception as e:
            raise ValueError(f"LangFuse Failed to create span: {str(e)}")

    def update_span(self, span, langfuse_span_data: Optional[LangfuseSpan] = None):
        format_span_data = filter_none_values(langfuse_span_data.model_dump()) if langfuse_span_data else {}

        span.end(**format_span_data)

    def add_generation(self, langfuse_generation_data: Optional[LangfuseGeneration] = None):
        format_generation_data = (
            filter_none_values(langfuse_generation_data.model_dump()) if langfuse_generation_data else {}
        )
        try:
            self.langfuse_client.generation(**format_generation_data)
            logger.debug("LangFuse Generation created successfully")
        except Exception as e:
            raise ValueError(f"LangFuse Failed to create generation: {str(e)}")

    def update_generation(self, generation, langfuse_generation_data: Optional[LangfuseGeneration] = None):
        format_generation_data = (
            filter_none_values(langfuse_generation_data.model_dump()) if langfuse_generation_data else {}
        )

        generation.end(**format_generation_data)

    def api_check(self):
        try:
            return self.langfuse_client.auth_check()
        except Exception as e:
            logger.debug(f"LangFuse API check failed: {str(e)}")
            raise ValueError(f"LangFuse API check failed: {str(e)}")

    def get_project_key(self):
        try:
            projects = self.langfuse_client.client.projects.get()
            return projects.data[0].id
        except Exception as e:
            logger.debug(f"LangFuse get project key failed: {str(e)}")
            raise ValueError(f"LangFuse get project key failed: {str(e)}")
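

# Usage sketch: a minimal, hedged example of wiring the tracer up by hand and
# verifying credentials via api_check(). It assumes LangfuseConfig accepts the
# same public_key / secret_key / host fields that __init__ reads above, and the
# LANGFUSE_* environment variable names here are an illustrative convention,
# not something this module defines.
if __name__ == "__main__":
    config = LangfuseConfig(
        public_key=os.environ["LANGFUSE_PUBLIC_KEY"],  # assumed env var name
        secret_key=os.environ["LANGFUSE_SECRET_KEY"],  # assumed env var name
        host=os.getenv("LANGFUSE_HOST", "https://cloud.langfuse.com"),
    )
    tracer = LangFuseDataTrace(config)
    # api_check() wraps the Langfuse SDK's auth_check() and raises ValueError on failure.
    print("Langfuse credentials valid:", tracer.api_check())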