langsmith_trace.py 20 KB


  1. import json
  2. import logging
  3. import os
  4. import uuid
  5. from datetime import datetime, timedelta
  6. from typing import Optional, cast
  7. from langsmith import Client
  8. from langsmith.schemas import RunBase
  9. from core.ops.base_trace_instance import BaseTraceInstance
  10. from core.ops.entities.config_entity import LangSmithConfig
  11. from core.ops.entities.trace_entity import (
  12. BaseTraceInfo,
  13. DatasetRetrievalTraceInfo,
  14. GenerateNameTraceInfo,
  15. MessageTraceInfo,
  16. ModerationTraceInfo,
  17. SuggestedQuestionTraceInfo,
  18. ToolTraceInfo,
  19. TraceTaskName,
  20. WorkflowTraceInfo,
  21. )
  22. from core.ops.langsmith_trace.entities.langsmith_trace_entity import (
  23. LangSmithRunModel,
  24. LangSmithRunType,
  25. LangSmithRunUpdateModel,
  26. )
  27. from core.ops.utils import filter_none_values, generate_dotted_order
  28. from extensions.ext_database import db
  29. from models.model import EndUser, MessageFile
  30. from models.workflow import WorkflowNodeExecution
# Module-level logger; LangSmithDataTrace uses it for debug messages about
# run creation/update and failed API checks.
logger = logging.getLogger(__name__)
  32. class LangSmithDataTrace(BaseTraceInstance):
  33. def __init__(
  34. self,
  35. langsmith_config: LangSmithConfig,
  36. ):
  37. super().__init__(langsmith_config)
  38. self.langsmith_key = langsmith_config.api_key
  39. self.project_name = langsmith_config.project
  40. self.project_id = None
  41. self.langsmith_client = Client(api_key=langsmith_config.api_key, api_url=langsmith_config.endpoint)
  42. self.file_base_url = os.getenv("FILES_URL", "http://127.0.0.1:5001")
  43. def trace(self, trace_info: BaseTraceInfo):
  44. if isinstance(trace_info, WorkflowTraceInfo):
  45. self.workflow_trace(trace_info)
  46. if isinstance(trace_info, MessageTraceInfo):
  47. self.message_trace(trace_info)
  48. if isinstance(trace_info, ModerationTraceInfo):
  49. self.moderation_trace(trace_info)
  50. if isinstance(trace_info, SuggestedQuestionTraceInfo):
  51. self.suggested_question_trace(trace_info)
  52. if isinstance(trace_info, DatasetRetrievalTraceInfo):
  53. self.dataset_retrieval_trace(trace_info)
  54. if isinstance(trace_info, ToolTraceInfo):
  55. self.tool_trace(trace_info)
  56. if isinstance(trace_info, GenerateNameTraceInfo):
  57. self.generate_name_trace(trace_info)
  58. def workflow_trace(self, trace_info: WorkflowTraceInfo):
  59. trace_id = trace_info.message_id or trace_info.workflow_run_id
  60. if trace_info.start_time is None:
  61. trace_info.start_time = datetime.now()
  62. message_dotted_order = (
  63. generate_dotted_order(trace_info.message_id, trace_info.start_time) if trace_info.message_id else None
  64. )
  65. workflow_dotted_order = generate_dotted_order(
  66. trace_info.workflow_run_id,
  67. trace_info.workflow_data.created_at,
  68. message_dotted_order,
  69. )
  70. metadata = trace_info.metadata
  71. metadata["workflow_app_log_id"] = trace_info.workflow_app_log_id
  72. if trace_info.message_id:
  73. message_run = LangSmithRunModel(
  74. id=trace_info.message_id,
  75. name=TraceTaskName.MESSAGE_TRACE.value,
  76. inputs=dict(trace_info.workflow_run_inputs),
  77. outputs=dict(trace_info.workflow_run_outputs),
  78. run_type=LangSmithRunType.chain,
  79. start_time=trace_info.start_time,
  80. end_time=trace_info.end_time,
  81. extra={
  82. "metadata": metadata,
  83. },
  84. tags=["message", "workflow"],
  85. error=trace_info.error,
  86. trace_id=trace_id,
  87. dotted_order=message_dotted_order,
  88. file_list=[],
  89. serialized=None,
  90. parent_run_id=None,
  91. events=[],
  92. session_id=None,
  93. session_name=None,
  94. reference_example_id=None,
  95. input_attachments={},
  96. output_attachments={},
  97. )
  98. self.add_run(message_run)
  99. langsmith_run = LangSmithRunModel(
  100. file_list=trace_info.file_list,
  101. total_tokens=trace_info.total_tokens,
  102. id=trace_info.workflow_run_id,
  103. name=TraceTaskName.WORKFLOW_TRACE.value,
  104. inputs=dict(trace_info.workflow_run_inputs),
  105. run_type=LangSmithRunType.tool,
  106. start_time=trace_info.workflow_data.created_at,
  107. end_time=trace_info.workflow_data.finished_at,
  108. outputs=dict(trace_info.workflow_run_outputs),
  109. extra={
  110. "metadata": metadata,
  111. },
  112. error=trace_info.error,
  113. tags=["workflow"],
  114. parent_run_id=trace_info.message_id or None,
  115. trace_id=trace_id,
  116. dotted_order=workflow_dotted_order,
  117. serialized=None,
  118. events=[],
  119. session_id=None,
  120. session_name=None,
  121. reference_example_id=None,
  122. input_attachments={},
  123. output_attachments={},
  124. )
  125. self.add_run(langsmith_run)
  126. # through workflow_run_id get all_nodes_execution
  127. workflow_nodes_execution_id_records = (
  128. db.session.query(WorkflowNodeExecution.id)
  129. .filter(WorkflowNodeExecution.workflow_run_id == trace_info.workflow_run_id)
  130. .all()
  131. )
  132. for node_execution_id_record in workflow_nodes_execution_id_records:
  133. node_execution = (
  134. db.session.query(
  135. WorkflowNodeExecution.id,
  136. WorkflowNodeExecution.tenant_id,
  137. WorkflowNodeExecution.app_id,
  138. WorkflowNodeExecution.title,
  139. WorkflowNodeExecution.node_type,
  140. WorkflowNodeExecution.status,
  141. WorkflowNodeExecution.inputs,
  142. WorkflowNodeExecution.outputs,
  143. WorkflowNodeExecution.created_at,
  144. WorkflowNodeExecution.elapsed_time,
  145. WorkflowNodeExecution.process_data,
  146. WorkflowNodeExecution.execution_metadata,
  147. )
  148. .filter(WorkflowNodeExecution.id == node_execution_id_record.id)
  149. .first()
  150. )
  151. if not node_execution:
  152. continue
  153. node_execution_id = node_execution.id
  154. tenant_id = node_execution.tenant_id
  155. app_id = node_execution.app_id
  156. node_name = node_execution.title
  157. node_type = node_execution.node_type
  158. status = node_execution.status
  159. if node_type == "llm":
  160. inputs = (
  161. json.loads(node_execution.process_data).get("prompts", {}) if node_execution.process_data else {}
  162. )
  163. else:
  164. inputs = json.loads(node_execution.inputs) if node_execution.inputs else {}
  165. outputs = json.loads(node_execution.outputs) if node_execution.outputs else {}
  166. created_at = node_execution.created_at or datetime.now()
  167. elapsed_time = node_execution.elapsed_time
  168. finished_at = created_at + timedelta(seconds=elapsed_time)
  169. execution_metadata = (
  170. json.loads(node_execution.execution_metadata) if node_execution.execution_metadata else {}
  171. )
  172. node_total_tokens = execution_metadata.get("total_tokens", 0)
  173. metadata = execution_metadata.copy()
  174. metadata.update(
  175. {
  176. "workflow_run_id": trace_info.workflow_run_id,
  177. "node_execution_id": node_execution_id,
  178. "tenant_id": tenant_id,
  179. "app_id": app_id,
  180. "app_name": node_name,
  181. "node_type": node_type,
  182. "status": status,
  183. }
  184. )
  185. process_data = json.loads(node_execution.process_data) if node_execution.process_data else {}
  186. if process_data and process_data.get("model_mode") == "chat":
  187. run_type = LangSmithRunType.llm
  188. metadata.update(
  189. {
  190. "ls_provider": process_data.get("model_provider", ""),
  191. "ls_model_name": process_data.get("model_name", ""),
  192. }
  193. )
  194. elif node_type == "knowledge-retrieval":
  195. run_type = LangSmithRunType.retriever
  196. else:
  197. run_type = LangSmithRunType.tool
  198. node_dotted_order = generate_dotted_order(node_execution_id, created_at, workflow_dotted_order)
  199. langsmith_run = LangSmithRunModel(
  200. total_tokens=node_total_tokens,
  201. name=node_type,
  202. inputs=inputs,
  203. run_type=run_type,
  204. start_time=created_at,
  205. end_time=finished_at,
  206. outputs=outputs,
  207. file_list=trace_info.file_list,
  208. extra={
  209. "metadata": metadata,
  210. },
  211. parent_run_id=trace_info.workflow_run_id,
  212. tags=["node_execution"],
  213. id=node_execution_id,
  214. trace_id=trace_id,
  215. dotted_order=node_dotted_order,
  216. error="",
  217. serialized=None,
  218. events=[],
  219. session_id=None,
  220. session_name=None,
  221. reference_example_id=None,
  222. input_attachments={},
  223. output_attachments={},
  224. )
  225. self.add_run(langsmith_run)
  226. def message_trace(self, trace_info: MessageTraceInfo):
  227. # get message file data
  228. file_list = cast(list[str], trace_info.file_list) or []
  229. message_file_data: Optional[MessageFile] = trace_info.message_file_data
  230. file_url = f"{self.file_base_url}/{message_file_data.url}" if message_file_data else ""
  231. file_list.append(file_url)
  232. metadata = trace_info.metadata
  233. message_data = trace_info.message_data
  234. if message_data is None:
  235. return
  236. message_id = message_data.id
  237. user_id = message_data.from_account_id
  238. metadata["user_id"] = user_id
  239. if message_data.from_end_user_id:
  240. end_user_data: Optional[EndUser] = (
  241. db.session.query(EndUser).filter(EndUser.id == message_data.from_end_user_id).first()
  242. )
  243. if end_user_data is not None:
  244. end_user_id = end_user_data.session_id
  245. metadata["end_user_id"] = end_user_id
  246. message_run = LangSmithRunModel(
  247. input_tokens=trace_info.message_tokens,
  248. output_tokens=trace_info.answer_tokens,
  249. total_tokens=trace_info.total_tokens,
  250. id=message_id,
  251. name=TraceTaskName.MESSAGE_TRACE.value,
  252. inputs=trace_info.inputs,
  253. run_type=LangSmithRunType.chain,
  254. start_time=trace_info.start_time,
  255. end_time=trace_info.end_time,
  256. outputs=message_data.answer,
  257. extra={"metadata": metadata},
  258. tags=["message", str(trace_info.conversation_mode)],
  259. error=trace_info.error,
  260. file_list=file_list,
  261. serialized=None,
  262. events=[],
  263. session_id=None,
  264. session_name=None,
  265. reference_example_id=None,
  266. input_attachments={},
  267. output_attachments={},
  268. trace_id=None,
  269. dotted_order=None,
  270. parent_run_id=None,
  271. )
  272. self.add_run(message_run)
  273. # create llm run parented to message run
  274. llm_run = LangSmithRunModel(
  275. input_tokens=trace_info.message_tokens,
  276. output_tokens=trace_info.answer_tokens,
  277. total_tokens=trace_info.total_tokens,
  278. name="llm",
  279. inputs=trace_info.inputs,
  280. run_type=LangSmithRunType.llm,
  281. start_time=trace_info.start_time,
  282. end_time=trace_info.end_time,
  283. outputs=message_data.answer,
  284. extra={"metadata": metadata},
  285. parent_run_id=message_id,
  286. tags=["llm", str(trace_info.conversation_mode)],
  287. error=trace_info.error,
  288. file_list=file_list,
  289. serialized=None,
  290. events=[],
  291. session_id=None,
  292. session_name=None,
  293. reference_example_id=None,
  294. input_attachments={},
  295. output_attachments={},
  296. trace_id=None,
  297. dotted_order=None,
  298. id=str(uuid.uuid4()),
  299. )
  300. self.add_run(llm_run)
  301. def moderation_trace(self, trace_info: ModerationTraceInfo):
  302. if trace_info.message_data is None:
  303. return
  304. langsmith_run = LangSmithRunModel(
  305. name=TraceTaskName.MODERATION_TRACE.value,
  306. inputs=trace_info.inputs,
  307. outputs={
  308. "action": trace_info.action,
  309. "flagged": trace_info.flagged,
  310. "preset_response": trace_info.preset_response,
  311. "inputs": trace_info.inputs,
  312. },
  313. run_type=LangSmithRunType.tool,
  314. extra={"metadata": trace_info.metadata},
  315. tags=["moderation"],
  316. parent_run_id=trace_info.message_id,
  317. start_time=trace_info.start_time or trace_info.message_data.created_at,
  318. end_time=trace_info.end_time or trace_info.message_data.updated_at,
  319. id=str(uuid.uuid4()),
  320. serialized=None,
  321. events=[],
  322. session_id=None,
  323. session_name=None,
  324. reference_example_id=None,
  325. input_attachments={},
  326. output_attachments={},
  327. trace_id=None,
  328. dotted_order=None,
  329. error="",
  330. file_list=[],
  331. )
  332. self.add_run(langsmith_run)
  333. def suggested_question_trace(self, trace_info: SuggestedQuestionTraceInfo):
  334. message_data = trace_info.message_data
  335. if message_data is None:
  336. return
  337. suggested_question_run = LangSmithRunModel(
  338. name=TraceTaskName.SUGGESTED_QUESTION_TRACE.value,
  339. inputs=trace_info.inputs,
  340. outputs=trace_info.suggested_question,
  341. run_type=LangSmithRunType.tool,
  342. extra={"metadata": trace_info.metadata},
  343. tags=["suggested_question"],
  344. parent_run_id=trace_info.message_id,
  345. start_time=trace_info.start_time or message_data.created_at,
  346. end_time=trace_info.end_time or message_data.updated_at,
  347. id=str(uuid.uuid4()),
  348. serialized=None,
  349. events=[],
  350. session_id=None,
  351. session_name=None,
  352. reference_example_id=None,
  353. input_attachments={},
  354. output_attachments={},
  355. trace_id=None,
  356. dotted_order=None,
  357. error="",
  358. file_list=[],
  359. )
  360. self.add_run(suggested_question_run)
  361. def dataset_retrieval_trace(self, trace_info: DatasetRetrievalTraceInfo):
  362. if trace_info.message_data is None:
  363. return
  364. dataset_retrieval_run = LangSmithRunModel(
  365. name=TraceTaskName.DATASET_RETRIEVAL_TRACE.value,
  366. inputs=trace_info.inputs,
  367. outputs={"documents": trace_info.documents},
  368. run_type=LangSmithRunType.retriever,
  369. extra={"metadata": trace_info.metadata},
  370. tags=["dataset_retrieval"],
  371. parent_run_id=trace_info.message_id,
  372. start_time=trace_info.start_time or trace_info.message_data.created_at,
  373. end_time=trace_info.end_time or trace_info.message_data.updated_at,
  374. id=str(uuid.uuid4()),
  375. serialized=None,
  376. events=[],
  377. session_id=None,
  378. session_name=None,
  379. reference_example_id=None,
  380. input_attachments={},
  381. output_attachments={},
  382. trace_id=None,
  383. dotted_order=None,
  384. error="",
  385. file_list=[],
  386. )
  387. self.add_run(dataset_retrieval_run)
  388. def tool_trace(self, trace_info: ToolTraceInfo):
  389. tool_run = LangSmithRunModel(
  390. name=trace_info.tool_name,
  391. inputs=trace_info.tool_inputs,
  392. outputs=trace_info.tool_outputs,
  393. run_type=LangSmithRunType.tool,
  394. extra={
  395. "metadata": trace_info.metadata,
  396. },
  397. tags=["tool", trace_info.tool_name],
  398. parent_run_id=trace_info.message_id,
  399. start_time=trace_info.start_time,
  400. end_time=trace_info.end_time,
  401. file_list=[cast(str, trace_info.file_url)],
  402. id=str(uuid.uuid4()),
  403. serialized=None,
  404. events=[],
  405. session_id=None,
  406. session_name=None,
  407. reference_example_id=None,
  408. input_attachments={},
  409. output_attachments={},
  410. trace_id=None,
  411. dotted_order=None,
  412. error=trace_info.error or "",
  413. )
  414. self.add_run(tool_run)
  415. def generate_name_trace(self, trace_info: GenerateNameTraceInfo):
  416. name_run = LangSmithRunModel(
  417. name=TraceTaskName.GENERATE_NAME_TRACE.value,
  418. inputs=trace_info.inputs,
  419. outputs=trace_info.outputs,
  420. run_type=LangSmithRunType.tool,
  421. extra={"metadata": trace_info.metadata},
  422. tags=["generate_name"],
  423. start_time=trace_info.start_time or datetime.now(),
  424. end_time=trace_info.end_time or datetime.now(),
  425. id=str(uuid.uuid4()),
  426. serialized=None,
  427. events=[],
  428. session_id=None,
  429. session_name=None,
  430. reference_example_id=None,
  431. input_attachments={},
  432. output_attachments={},
  433. trace_id=None,
  434. dotted_order=None,
  435. error="",
  436. file_list=[],
  437. parent_run_id=None,
  438. )
  439. self.add_run(name_run)
  440. def add_run(self, run_data: LangSmithRunModel):
  441. data = run_data.model_dump()
  442. if self.project_id:
  443. data["session_id"] = self.project_id
  444. elif self.project_name:
  445. data["session_name"] = self.project_name
  446. data = filter_none_values(data)
  447. try:
  448. self.langsmith_client.create_run(**data)
  449. logger.debug("LangSmith Run created successfully.")
  450. except Exception as e:
  451. raise ValueError(f"LangSmith Failed to create run: {str(e)}")
  452. def update_run(self, update_run_data: LangSmithRunUpdateModel):
  453. data = update_run_data.model_dump()
  454. data = filter_none_values(data)
  455. try:
  456. self.langsmith_client.update_run(**data)
  457. logger.debug("LangSmith Run updated successfully.")
  458. except Exception as e:
  459. raise ValueError(f"LangSmith Failed to update run: {str(e)}")
  460. def api_check(self):
  461. try:
  462. random_project_name = f"test_project_{datetime.now().strftime('%Y%m%d%H%M%S')}"
  463. self.langsmith_client.create_project(project_name=random_project_name)
  464. self.langsmith_client.delete_project(project_name=random_project_name)
  465. return True
  466. except Exception as e:
  467. logger.debug(f"LangSmith API check failed: {str(e)}")
  468. raise ValueError(f"LangSmith API check failed: {str(e)}")
  469. def get_project_url(self):
  470. try:
  471. run_data = RunBase(
  472. id=uuid.uuid4(),
  473. name="tool",
  474. inputs={"input": "test"},
  475. outputs={"output": "test"},
  476. run_type=LangSmithRunType.tool,
  477. start_time=datetime.now(),
  478. )
  479. project_url = self.langsmith_client.get_run_url(
  480. run=run_data, project_id=self.project_id, project_name=self.project_name
  481. )
  482. return project_url.split("/r/")[0]
  483. except Exception as e:
  484. logger.debug(f"LangSmith get run url failed: {str(e)}")
  485. raise ValueError(f"LangSmith get run url failed: {str(e)}")