# langsmith_trace.py
  1. import json
  2. import logging
  3. import os
  4. from datetime import datetime, timedelta
  5. from langsmith import Client
  6. from core.ops.base_trace_instance import BaseTraceInstance
  7. from core.ops.entities.config_entity import LangSmithConfig
  8. from core.ops.entities.trace_entity import (
  9. BaseTraceInfo,
  10. DatasetRetrievalTraceInfo,
  11. GenerateNameTraceInfo,
  12. MessageTraceInfo,
  13. ModerationTraceInfo,
  14. SuggestedQuestionTraceInfo,
  15. ToolTraceInfo,
  16. WorkflowTraceInfo,
  17. )
  18. from core.ops.langsmith_trace.entities.langsmith_trace_entity import (
  19. LangSmithRunModel,
  20. LangSmithRunType,
  21. LangSmithRunUpdateModel,
  22. )
  23. from core.ops.utils import filter_none_values
  24. from extensions.ext_database import db
  25. from models.model import EndUser, MessageFile
  26. from models.workflow import WorkflowNodeExecution
  27. logger = logging.getLogger(__name__)
  28. class LangSmithDataTrace(BaseTraceInstance):
  29. def __init__(
  30. self,
  31. langsmith_config: LangSmithConfig,
  32. ):
  33. super().__init__(langsmith_config)
  34. self.langsmith_key = langsmith_config.api_key
  35. self.project_name = langsmith_config.project
  36. self.project_id = None
  37. self.langsmith_client = Client(
  38. api_key=langsmith_config.api_key, api_url=langsmith_config.endpoint
  39. )
  40. self.file_base_url = os.getenv("FILES_URL", "http://127.0.0.1:5001")
  41. def trace(self, trace_info: BaseTraceInfo):
  42. if isinstance(trace_info, WorkflowTraceInfo):
  43. self.workflow_trace(trace_info)
  44. if isinstance(trace_info, MessageTraceInfo):
  45. self.message_trace(trace_info)
  46. if isinstance(trace_info, ModerationTraceInfo):
  47. self.moderation_trace(trace_info)
  48. if isinstance(trace_info, SuggestedQuestionTraceInfo):
  49. self.suggested_question_trace(trace_info)
  50. if isinstance(trace_info, DatasetRetrievalTraceInfo):
  51. self.dataset_retrieval_trace(trace_info)
  52. if isinstance(trace_info, ToolTraceInfo):
  53. self.tool_trace(trace_info)
  54. if isinstance(trace_info, GenerateNameTraceInfo):
  55. self.generate_name_trace(trace_info)
  56. def workflow_trace(self, trace_info: WorkflowTraceInfo):
  57. if trace_info.message_id:
  58. message_run = LangSmithRunModel(
  59. id=trace_info.message_id,
  60. name=f"message_{trace_info.message_id}",
  61. inputs=trace_info.workflow_run_inputs,
  62. outputs=trace_info.workflow_run_outputs,
  63. run_type=LangSmithRunType.chain,
  64. start_time=trace_info.start_time,
  65. end_time=trace_info.end_time,
  66. extra={
  67. "metadata": trace_info.metadata,
  68. },
  69. tags=["message"],
  70. error=trace_info.error
  71. )
  72. self.add_run(message_run)
  73. langsmith_run = LangSmithRunModel(
  74. file_list=trace_info.file_list,
  75. total_tokens=trace_info.total_tokens,
  76. id=trace_info.workflow_app_log_id if trace_info.workflow_app_log_id else trace_info.workflow_run_id,
  77. name=f"workflow_{trace_info.workflow_app_log_id}" if trace_info.workflow_app_log_id else f"workflow_{trace_info.workflow_run_id}",
  78. inputs=trace_info.workflow_run_inputs,
  79. run_type=LangSmithRunType.tool,
  80. start_time=trace_info.workflow_data.created_at,
  81. end_time=trace_info.workflow_data.finished_at,
  82. outputs=trace_info.workflow_run_outputs,
  83. extra={
  84. "metadata": trace_info.metadata,
  85. },
  86. error=trace_info.error,
  87. tags=["workflow"],
  88. parent_run_id=trace_info.message_id if trace_info.message_id else None,
  89. )
  90. self.add_run(langsmith_run)
  91. # through workflow_run_id get all_nodes_execution
  92. workflow_nodes_executions = (
  93. db.session.query(
  94. WorkflowNodeExecution.id,
  95. WorkflowNodeExecution.tenant_id,
  96. WorkflowNodeExecution.app_id,
  97. WorkflowNodeExecution.title,
  98. WorkflowNodeExecution.node_type,
  99. WorkflowNodeExecution.status,
  100. WorkflowNodeExecution.inputs,
  101. WorkflowNodeExecution.outputs,
  102. WorkflowNodeExecution.created_at,
  103. WorkflowNodeExecution.elapsed_time,
  104. WorkflowNodeExecution.process_data,
  105. WorkflowNodeExecution.execution_metadata,
  106. )
  107. .filter(WorkflowNodeExecution.workflow_run_id == trace_info.workflow_run_id)
  108. .all()
  109. )
  110. for node_execution in workflow_nodes_executions:
  111. node_execution_id = node_execution.id
  112. tenant_id = node_execution.tenant_id
  113. app_id = node_execution.app_id
  114. node_name = node_execution.title
  115. node_type = node_execution.node_type
  116. status = node_execution.status
  117. if node_type == "llm":
  118. inputs = json.loads(node_execution.process_data).get(
  119. "prompts", {}
  120. ) if node_execution.process_data else {}
  121. else:
  122. inputs = json.loads(node_execution.inputs) if node_execution.inputs else {}
  123. outputs = (
  124. json.loads(node_execution.outputs) if node_execution.outputs else {}
  125. )
  126. created_at = node_execution.created_at if node_execution.created_at else datetime.now()
  127. elapsed_time = node_execution.elapsed_time
  128. finished_at = created_at + timedelta(seconds=elapsed_time)
  129. execution_metadata = (
  130. json.loads(node_execution.execution_metadata)
  131. if node_execution.execution_metadata
  132. else {}
  133. )
  134. node_total_tokens = execution_metadata.get("total_tokens", 0)
  135. metadata = json.loads(node_execution.execution_metadata) if node_execution.execution_metadata else {}
  136. metadata.update(
  137. {
  138. "workflow_run_id": trace_info.workflow_run_id,
  139. "node_execution_id": node_execution_id,
  140. "tenant_id": tenant_id,
  141. "app_id": app_id,
  142. "app_name": node_name,
  143. "node_type": node_type,
  144. "status": status,
  145. }
  146. )
  147. process_data = json.loads(node_execution.process_data) if node_execution.process_data else {}
  148. if process_data and process_data.get("model_mode") == "chat":
  149. run_type = LangSmithRunType.llm
  150. elif node_type == "knowledge-retrieval":
  151. run_type = LangSmithRunType.retriever
  152. else:
  153. run_type = LangSmithRunType.tool
  154. langsmith_run = LangSmithRunModel(
  155. total_tokens=node_total_tokens,
  156. name=f"{node_name}_{node_execution_id}",
  157. inputs=inputs,
  158. run_type=run_type,
  159. start_time=created_at,
  160. end_time=finished_at,
  161. outputs=outputs,
  162. file_list=trace_info.file_list,
  163. extra={
  164. "metadata": metadata,
  165. },
  166. parent_run_id=trace_info.workflow_app_log_id if trace_info.workflow_app_log_id else trace_info.workflow_run_id,
  167. tags=["node_execution"],
  168. )
  169. self.add_run(langsmith_run)
  170. def message_trace(self, trace_info: MessageTraceInfo):
  171. # get message file data
  172. file_list = trace_info.file_list
  173. message_file_data: MessageFile = trace_info.message_file_data
  174. file_url = f"{self.file_base_url}/{message_file_data.url}" if message_file_data else ""
  175. file_list.append(file_url)
  176. metadata = trace_info.metadata
  177. message_data = trace_info.message_data
  178. message_id = message_data.id
  179. user_id = message_data.from_account_id
  180. metadata["user_id"] = user_id
  181. if message_data.from_end_user_id:
  182. end_user_data: EndUser = db.session.query(EndUser).filter(
  183. EndUser.id == message_data.from_end_user_id
  184. ).first()
  185. if end_user_data is not None:
  186. end_user_id = end_user_data.session_id
  187. metadata["end_user_id"] = end_user_id
  188. message_run = LangSmithRunModel(
  189. input_tokens=trace_info.message_tokens,
  190. output_tokens=trace_info.answer_tokens,
  191. total_tokens=trace_info.total_tokens,
  192. id=message_id,
  193. name=f"message_{message_id}",
  194. inputs=trace_info.inputs,
  195. run_type=LangSmithRunType.chain,
  196. start_time=trace_info.start_time,
  197. end_time=trace_info.end_time,
  198. outputs=message_data.answer,
  199. extra={
  200. "metadata": metadata,
  201. },
  202. tags=["message", str(trace_info.conversation_mode)],
  203. error=trace_info.error,
  204. file_list=file_list,
  205. )
  206. self.add_run(message_run)
  207. # create llm run parented to message run
  208. llm_run = LangSmithRunModel(
  209. input_tokens=trace_info.message_tokens,
  210. output_tokens=trace_info.answer_tokens,
  211. total_tokens=trace_info.total_tokens,
  212. name=f"llm_{message_id}",
  213. inputs=trace_info.inputs,
  214. run_type=LangSmithRunType.llm,
  215. start_time=trace_info.start_time,
  216. end_time=trace_info.end_time,
  217. outputs=message_data.answer,
  218. extra={
  219. "metadata": metadata,
  220. },
  221. parent_run_id=message_id,
  222. tags=["llm", str(trace_info.conversation_mode)],
  223. error=trace_info.error,
  224. file_list=file_list,
  225. )
  226. self.add_run(llm_run)
  227. def moderation_trace(self, trace_info: ModerationTraceInfo):
  228. langsmith_run = LangSmithRunModel(
  229. name="moderation",
  230. inputs=trace_info.inputs,
  231. outputs={
  232. "action": trace_info.action,
  233. "flagged": trace_info.flagged,
  234. "preset_response": trace_info.preset_response,
  235. "inputs": trace_info.inputs,
  236. },
  237. run_type=LangSmithRunType.tool,
  238. extra={
  239. "metadata": trace_info.metadata,
  240. },
  241. tags=["moderation"],
  242. parent_run_id=trace_info.message_id,
  243. start_time=trace_info.start_time or trace_info.message_data.created_at,
  244. end_time=trace_info.end_time or trace_info.message_data.updated_at,
  245. )
  246. self.add_run(langsmith_run)
  247. def suggested_question_trace(self, trace_info: SuggestedQuestionTraceInfo):
  248. message_data = trace_info.message_data
  249. suggested_question_run = LangSmithRunModel(
  250. name="suggested_question",
  251. inputs=trace_info.inputs,
  252. outputs=trace_info.suggested_question,
  253. run_type=LangSmithRunType.tool,
  254. extra={
  255. "metadata": trace_info.metadata,
  256. },
  257. tags=["suggested_question"],
  258. parent_run_id=trace_info.message_id,
  259. start_time=trace_info.start_time or message_data.created_at,
  260. end_time=trace_info.end_time or message_data.updated_at,
  261. )
  262. self.add_run(suggested_question_run)
  263. def dataset_retrieval_trace(self, trace_info: DatasetRetrievalTraceInfo):
  264. dataset_retrieval_run = LangSmithRunModel(
  265. name="dataset_retrieval",
  266. inputs=trace_info.inputs,
  267. outputs={"documents": trace_info.documents},
  268. run_type=LangSmithRunType.retriever,
  269. extra={
  270. "metadata": trace_info.metadata,
  271. },
  272. tags=["dataset_retrieval"],
  273. parent_run_id=trace_info.message_id,
  274. start_time=trace_info.start_time or trace_info.message_data.created_at,
  275. end_time=trace_info.end_time or trace_info.message_data.updated_at,
  276. )
  277. self.add_run(dataset_retrieval_run)
  278. def tool_trace(self, trace_info: ToolTraceInfo):
  279. tool_run = LangSmithRunModel(
  280. name=trace_info.tool_name,
  281. inputs=trace_info.tool_inputs,
  282. outputs=trace_info.tool_outputs,
  283. run_type=LangSmithRunType.tool,
  284. extra={
  285. "metadata": trace_info.metadata,
  286. },
  287. tags=["tool", trace_info.tool_name],
  288. parent_run_id=trace_info.message_id,
  289. start_time=trace_info.start_time,
  290. end_time=trace_info.end_time,
  291. file_list=[trace_info.file_url],
  292. )
  293. self.add_run(tool_run)
  294. def generate_name_trace(self, trace_info: GenerateNameTraceInfo):
  295. name_run = LangSmithRunModel(
  296. name="generate_name",
  297. inputs=trace_info.inputs,
  298. outputs=trace_info.outputs,
  299. run_type=LangSmithRunType.tool,
  300. extra={
  301. "metadata": trace_info.metadata,
  302. },
  303. tags=["generate_name"],
  304. start_time=trace_info.start_time or datetime.now(),
  305. end_time=trace_info.end_time or datetime.now(),
  306. )
  307. self.add_run(name_run)
  308. def add_run(self, run_data: LangSmithRunModel):
  309. data = run_data.model_dump()
  310. if self.project_id:
  311. data["session_id"] = self.project_id
  312. elif self.project_name:
  313. data["session_name"] = self.project_name
  314. data = filter_none_values(data)
  315. try:
  316. self.langsmith_client.create_run(**data)
  317. logger.debug("LangSmith Run created successfully.")
  318. except Exception as e:
  319. raise ValueError(f"LangSmith Failed to create run: {str(e)}")
  320. def update_run(self, update_run_data: LangSmithRunUpdateModel):
  321. data = update_run_data.model_dump()
  322. data = filter_none_values(data)
  323. try:
  324. self.langsmith_client.update_run(**data)
  325. logger.debug("LangSmith Run updated successfully.")
  326. except Exception as e:
  327. raise ValueError(f"LangSmith Failed to update run: {str(e)}")
  328. def api_check(self):
  329. try:
  330. random_project_name = f"test_project_{datetime.now().strftime('%Y%m%d%H%M%S')}"
  331. self.langsmith_client.create_project(project_name=random_project_name)
  332. self.langsmith_client.delete_project(project_name=random_project_name)
  333. return True
  334. except Exception as e:
  335. logger.debug(f"LangSmith API check failed: {str(e)}")
  336. raise ValueError(f"LangSmith API check failed: {str(e)}")