# langsmith_trace.py (15 KB)
  1. import json
  2. import logging
  3. import os
  4. import uuid
  5. from datetime import datetime, timedelta
  6. from langsmith import Client
  7. from langsmith.schemas import RunBase
  8. from core.ops.base_trace_instance import BaseTraceInstance
  9. from core.ops.entities.config_entity import LangSmithConfig
  10. from core.ops.entities.trace_entity import (
  11. BaseTraceInfo,
  12. DatasetRetrievalTraceInfo,
  13. GenerateNameTraceInfo,
  14. MessageTraceInfo,
  15. ModerationTraceInfo,
  16. SuggestedQuestionTraceInfo,
  17. ToolTraceInfo,
  18. TraceTaskName,
  19. WorkflowTraceInfo,
  20. )
  21. from core.ops.langsmith_trace.entities.langsmith_trace_entity import (
  22. LangSmithRunModel,
  23. LangSmithRunType,
  24. LangSmithRunUpdateModel,
  25. )
  26. from core.ops.utils import filter_none_values
  27. from extensions.ext_database import db
  28. from models.model import EndUser, MessageFile
  29. from models.workflow import WorkflowNodeExecution
  30. logger = logging.getLogger(__name__)
  31. class LangSmithDataTrace(BaseTraceInstance):
  32. def __init__(
  33. self,
  34. langsmith_config: LangSmithConfig,
  35. ):
  36. super().__init__(langsmith_config)
  37. self.langsmith_key = langsmith_config.api_key
  38. self.project_name = langsmith_config.project
  39. self.project_id = None
  40. self.langsmith_client = Client(api_key=langsmith_config.api_key, api_url=langsmith_config.endpoint)
  41. self.file_base_url = os.getenv("FILES_URL", "http://127.0.0.1:5001")
  42. def trace(self, trace_info: BaseTraceInfo):
  43. if isinstance(trace_info, WorkflowTraceInfo):
  44. self.workflow_trace(trace_info)
  45. if isinstance(trace_info, MessageTraceInfo):
  46. self.message_trace(trace_info)
  47. if isinstance(trace_info, ModerationTraceInfo):
  48. self.moderation_trace(trace_info)
  49. if isinstance(trace_info, SuggestedQuestionTraceInfo):
  50. self.suggested_question_trace(trace_info)
  51. if isinstance(trace_info, DatasetRetrievalTraceInfo):
  52. self.dataset_retrieval_trace(trace_info)
  53. if isinstance(trace_info, ToolTraceInfo):
  54. self.tool_trace(trace_info)
  55. if isinstance(trace_info, GenerateNameTraceInfo):
  56. self.generate_name_trace(trace_info)
  57. def workflow_trace(self, trace_info: WorkflowTraceInfo):
  58. if trace_info.message_id:
  59. message_run = LangSmithRunModel(
  60. id=trace_info.message_id,
  61. name=TraceTaskName.MESSAGE_TRACE.value,
  62. inputs=trace_info.workflow_run_inputs,
  63. outputs=trace_info.workflow_run_outputs,
  64. run_type=LangSmithRunType.chain,
  65. start_time=trace_info.start_time,
  66. end_time=trace_info.end_time,
  67. extra={
  68. "metadata": trace_info.metadata,
  69. },
  70. tags=["message", "workflow"],
  71. error=trace_info.error,
  72. )
  73. self.add_run(message_run)
  74. langsmith_run = LangSmithRunModel(
  75. file_list=trace_info.file_list,
  76. total_tokens=trace_info.total_tokens,
  77. id=trace_info.workflow_app_log_id if trace_info.workflow_app_log_id else trace_info.workflow_run_id,
  78. name=TraceTaskName.WORKFLOW_TRACE.value,
  79. inputs=trace_info.workflow_run_inputs,
  80. run_type=LangSmithRunType.tool,
  81. start_time=trace_info.workflow_data.created_at,
  82. end_time=trace_info.workflow_data.finished_at,
  83. outputs=trace_info.workflow_run_outputs,
  84. extra={
  85. "metadata": trace_info.metadata,
  86. },
  87. error=trace_info.error,
  88. tags=["workflow"],
  89. parent_run_id=trace_info.message_id if trace_info.message_id else None,
  90. )
  91. self.add_run(langsmith_run)
  92. # through workflow_run_id get all_nodes_execution
  93. workflow_nodes_executions = (
  94. db.session.query(
  95. WorkflowNodeExecution.id,
  96. WorkflowNodeExecution.tenant_id,
  97. WorkflowNodeExecution.app_id,
  98. WorkflowNodeExecution.title,
  99. WorkflowNodeExecution.node_type,
  100. WorkflowNodeExecution.status,
  101. WorkflowNodeExecution.inputs,
  102. WorkflowNodeExecution.outputs,
  103. WorkflowNodeExecution.created_at,
  104. WorkflowNodeExecution.elapsed_time,
  105. WorkflowNodeExecution.process_data,
  106. WorkflowNodeExecution.execution_metadata,
  107. )
  108. .filter(WorkflowNodeExecution.workflow_run_id == trace_info.workflow_run_id)
  109. .all()
  110. )
  111. for node_execution in workflow_nodes_executions:
  112. node_execution_id = node_execution.id
  113. tenant_id = node_execution.tenant_id
  114. app_id = node_execution.app_id
  115. node_name = node_execution.title
  116. node_type = node_execution.node_type
  117. status = node_execution.status
  118. if node_type == "llm":
  119. inputs = (
  120. json.loads(node_execution.process_data).get("prompts", {}) if node_execution.process_data else {}
  121. )
  122. else:
  123. inputs = json.loads(node_execution.inputs) if node_execution.inputs else {}
  124. outputs = json.loads(node_execution.outputs) if node_execution.outputs else {}
  125. created_at = node_execution.created_at if node_execution.created_at else datetime.now()
  126. elapsed_time = node_execution.elapsed_time
  127. finished_at = created_at + timedelta(seconds=elapsed_time)
  128. execution_metadata = (
  129. json.loads(node_execution.execution_metadata) if node_execution.execution_metadata else {}
  130. )
  131. node_total_tokens = execution_metadata.get("total_tokens", 0)
  132. metadata = execution_metadata.copy()
  133. metadata.update(
  134. {
  135. "workflow_run_id": trace_info.workflow_run_id,
  136. "node_execution_id": node_execution_id,
  137. "tenant_id": tenant_id,
  138. "app_id": app_id,
  139. "app_name": node_name,
  140. "node_type": node_type,
  141. "status": status,
  142. }
  143. )
  144. process_data = json.loads(node_execution.process_data) if node_execution.process_data else {}
  145. if process_data and process_data.get("model_mode") == "chat":
  146. run_type = LangSmithRunType.llm
  147. metadata.update(
  148. {
  149. "ls_provider": process_data.get("model_provider", ""),
  150. "ls_model_name": process_data.get("model_name", ""),
  151. }
  152. )
  153. elif node_type == "knowledge-retrieval":
  154. run_type = LangSmithRunType.retriever
  155. else:
  156. run_type = LangSmithRunType.tool
  157. langsmith_run = LangSmithRunModel(
  158. total_tokens=node_total_tokens,
  159. name=node_type,
  160. inputs=inputs,
  161. run_type=run_type,
  162. start_time=created_at,
  163. end_time=finished_at,
  164. outputs=outputs,
  165. file_list=trace_info.file_list,
  166. extra={
  167. "metadata": metadata,
  168. },
  169. parent_run_id=trace_info.workflow_app_log_id
  170. if trace_info.workflow_app_log_id
  171. else trace_info.workflow_run_id,
  172. tags=["node_execution"],
  173. )
  174. self.add_run(langsmith_run)
  175. def message_trace(self, trace_info: MessageTraceInfo):
  176. # get message file data
  177. file_list = trace_info.file_list
  178. message_file_data: MessageFile = trace_info.message_file_data
  179. file_url = f"{self.file_base_url}/{message_file_data.url}" if message_file_data else ""
  180. file_list.append(file_url)
  181. metadata = trace_info.metadata
  182. message_data = trace_info.message_data
  183. message_id = message_data.id
  184. user_id = message_data.from_account_id
  185. metadata["user_id"] = user_id
  186. if message_data.from_end_user_id:
  187. end_user_data: EndUser = (
  188. db.session.query(EndUser).filter(EndUser.id == message_data.from_end_user_id).first()
  189. )
  190. if end_user_data is not None:
  191. end_user_id = end_user_data.session_id
  192. metadata["end_user_id"] = end_user_id
  193. message_run = LangSmithRunModel(
  194. input_tokens=trace_info.message_tokens,
  195. output_tokens=trace_info.answer_tokens,
  196. total_tokens=trace_info.total_tokens,
  197. id=message_id,
  198. name=TraceTaskName.MESSAGE_TRACE.value,
  199. inputs=trace_info.inputs,
  200. run_type=LangSmithRunType.chain,
  201. start_time=trace_info.start_time,
  202. end_time=trace_info.end_time,
  203. outputs=message_data.answer,
  204. extra={
  205. "metadata": metadata,
  206. },
  207. tags=["message", str(trace_info.conversation_mode)],
  208. error=trace_info.error,
  209. file_list=file_list,
  210. )
  211. self.add_run(message_run)
  212. # create llm run parented to message run
  213. llm_run = LangSmithRunModel(
  214. input_tokens=trace_info.message_tokens,
  215. output_tokens=trace_info.answer_tokens,
  216. total_tokens=trace_info.total_tokens,
  217. name="llm",
  218. inputs=trace_info.inputs,
  219. run_type=LangSmithRunType.llm,
  220. start_time=trace_info.start_time,
  221. end_time=trace_info.end_time,
  222. outputs=message_data.answer,
  223. extra={
  224. "metadata": metadata,
  225. },
  226. parent_run_id=message_id,
  227. tags=["llm", str(trace_info.conversation_mode)],
  228. error=trace_info.error,
  229. file_list=file_list,
  230. )
  231. self.add_run(llm_run)
  232. def moderation_trace(self, trace_info: ModerationTraceInfo):
  233. langsmith_run = LangSmithRunModel(
  234. name=TraceTaskName.MODERATION_TRACE.value,
  235. inputs=trace_info.inputs,
  236. outputs={
  237. "action": trace_info.action,
  238. "flagged": trace_info.flagged,
  239. "preset_response": trace_info.preset_response,
  240. "inputs": trace_info.inputs,
  241. },
  242. run_type=LangSmithRunType.tool,
  243. extra={
  244. "metadata": trace_info.metadata,
  245. },
  246. tags=["moderation"],
  247. parent_run_id=trace_info.message_id,
  248. start_time=trace_info.start_time or trace_info.message_data.created_at,
  249. end_time=trace_info.end_time or trace_info.message_data.updated_at,
  250. )
  251. self.add_run(langsmith_run)
  252. def suggested_question_trace(self, trace_info: SuggestedQuestionTraceInfo):
  253. message_data = trace_info.message_data
  254. suggested_question_run = LangSmithRunModel(
  255. name=TraceTaskName.SUGGESTED_QUESTION_TRACE.value,
  256. inputs=trace_info.inputs,
  257. outputs=trace_info.suggested_question,
  258. run_type=LangSmithRunType.tool,
  259. extra={
  260. "metadata": trace_info.metadata,
  261. },
  262. tags=["suggested_question"],
  263. parent_run_id=trace_info.message_id,
  264. start_time=trace_info.start_time or message_data.created_at,
  265. end_time=trace_info.end_time or message_data.updated_at,
  266. )
  267. self.add_run(suggested_question_run)
  268. def dataset_retrieval_trace(self, trace_info: DatasetRetrievalTraceInfo):
  269. dataset_retrieval_run = LangSmithRunModel(
  270. name=TraceTaskName.DATASET_RETRIEVAL_TRACE.value,
  271. inputs=trace_info.inputs,
  272. outputs={"documents": trace_info.documents},
  273. run_type=LangSmithRunType.retriever,
  274. extra={
  275. "metadata": trace_info.metadata,
  276. },
  277. tags=["dataset_retrieval"],
  278. parent_run_id=trace_info.message_id,
  279. start_time=trace_info.start_time or trace_info.message_data.created_at,
  280. end_time=trace_info.end_time or trace_info.message_data.updated_at,
  281. )
  282. self.add_run(dataset_retrieval_run)
  283. def tool_trace(self, trace_info: ToolTraceInfo):
  284. tool_run = LangSmithRunModel(
  285. name=trace_info.tool_name,
  286. inputs=trace_info.tool_inputs,
  287. outputs=trace_info.tool_outputs,
  288. run_type=LangSmithRunType.tool,
  289. extra={
  290. "metadata": trace_info.metadata,
  291. },
  292. tags=["tool", trace_info.tool_name],
  293. parent_run_id=trace_info.message_id,
  294. start_time=trace_info.start_time,
  295. end_time=trace_info.end_time,
  296. file_list=[trace_info.file_url],
  297. )
  298. self.add_run(tool_run)
  299. def generate_name_trace(self, trace_info: GenerateNameTraceInfo):
  300. name_run = LangSmithRunModel(
  301. name=TraceTaskName.GENERATE_NAME_TRACE.value,
  302. inputs=trace_info.inputs,
  303. outputs=trace_info.outputs,
  304. run_type=LangSmithRunType.tool,
  305. extra={
  306. "metadata": trace_info.metadata,
  307. },
  308. tags=["generate_name"],
  309. start_time=trace_info.start_time or datetime.now(),
  310. end_time=trace_info.end_time or datetime.now(),
  311. )
  312. self.add_run(name_run)
  313. def add_run(self, run_data: LangSmithRunModel):
  314. data = run_data.model_dump()
  315. if self.project_id:
  316. data["session_id"] = self.project_id
  317. elif self.project_name:
  318. data["session_name"] = self.project_name
  319. data = filter_none_values(data)
  320. try:
  321. self.langsmith_client.create_run(**data)
  322. logger.debug("LangSmith Run created successfully.")
  323. except Exception as e:
  324. raise ValueError(f"LangSmith Failed to create run: {str(e)}")
  325. def update_run(self, update_run_data: LangSmithRunUpdateModel):
  326. data = update_run_data.model_dump()
  327. data = filter_none_values(data)
  328. try:
  329. self.langsmith_client.update_run(**data)
  330. logger.debug("LangSmith Run updated successfully.")
  331. except Exception as e:
  332. raise ValueError(f"LangSmith Failed to update run: {str(e)}")
  333. def api_check(self):
  334. try:
  335. random_project_name = f"test_project_{datetime.now().strftime('%Y%m%d%H%M%S')}"
  336. self.langsmith_client.create_project(project_name=random_project_name)
  337. self.langsmith_client.delete_project(project_name=random_project_name)
  338. return True
  339. except Exception as e:
  340. logger.debug(f"LangSmith API check failed: {str(e)}")
  341. raise ValueError(f"LangSmith API check failed: {str(e)}")
  342. def get_project_url(self):
  343. try:
  344. run_data = RunBase(
  345. id=uuid.uuid4(),
  346. name="tool",
  347. inputs={"input": "test"},
  348. outputs={"output": "test"},
  349. run_type=LangSmithRunType.tool,
  350. start_time=datetime.now(),
  351. )
  352. project_url = self.langsmith_client.get_run_url(
  353. run=run_data, project_id=self.project_id, project_name=self.project_name
  354. )
  355. return project_url.split("/r/")[0]
  356. except Exception as e:
  357. logger.debug(f"LangSmith get run url failed: {str(e)}")
  358. raise ValueError(f"LangSmith get run url failed: {str(e)}")