huggingface_tei.py 3.1 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495
  1. from core.model_runtime.model_providers.huggingface_tei.tei_helper import TeiModelExtraParameter
  2. class MockTEIClass:
  3. @staticmethod
  4. def get_tei_extra_parameter(server_url: str, model_name: str) -> TeiModelExtraParameter:
  5. # During mock, we don't have a real server to query, so we just return a dummy value
  6. if "rerank" in model_name:
  7. model_type = "reranker"
  8. else:
  9. model_type = "embedding"
  10. return TeiModelExtraParameter(model_type=model_type, max_input_length=512, max_client_batch_size=1)
  11. @staticmethod
  12. def invoke_tokenize(server_url: str, texts: list[str]) -> list[list[dict]]:
  13. # Use space as token separator, and split the text into tokens
  14. tokenized_texts = []
  15. for text in texts:
  16. tokens = text.split(" ")
  17. current_index = 0
  18. tokenized_text = []
  19. for idx, token in enumerate(tokens):
  20. s_token = {
  21. "id": idx,
  22. "text": token,
  23. "special": False,
  24. "start": current_index,
  25. "stop": current_index + len(token),
  26. }
  27. current_index += len(token) + 1
  28. tokenized_text.append(s_token)
  29. tokenized_texts.append(tokenized_text)
  30. return tokenized_texts
  31. @staticmethod
  32. def invoke_embeddings(server_url: str, texts: list[str]) -> dict:
  33. # {
  34. # "object": "list",
  35. # "data": [
  36. # {
  37. # "object": "embedding",
  38. # "embedding": [...],
  39. # "index": 0
  40. # }
  41. # ],
  42. # "model": "MODEL_NAME",
  43. # "usage": {
  44. # "prompt_tokens": 3,
  45. # "total_tokens": 3
  46. # }
  47. # }
  48. embeddings = []
  49. for idx in range(len(texts)):
  50. embedding = [0.1] * 768
  51. embeddings.append(
  52. {
  53. "object": "embedding",
  54. "embedding": embedding,
  55. "index": idx,
  56. }
  57. )
  58. return {
  59. "object": "list",
  60. "data": embeddings,
  61. "model": "MODEL_NAME",
  62. "usage": {
  63. "prompt_tokens": sum(len(text.split(" ")) for text in texts),
  64. "total_tokens": sum(len(text.split(" ")) for text in texts),
  65. },
  66. }
  67. @staticmethod
  68. def invoke_rerank(server_url: str, query: str, texts: list[str]) -> list[dict]:
  69. # Example response:
  70. # [
  71. # {
  72. # "index": 0,
  73. # "text": "Deep Learning is ...",
  74. # "score": 0.9950755
  75. # }
  76. # ]
  77. reranked_docs = []
  78. for idx, text in enumerate(texts):
  79. reranked_docs.append(
  80. {
  81. "index": idx,
  82. "text": text,
  83. "score": 0.9,
  84. }
  85. )
  86. # For mock, only return the first document
  87. break
  88. return reranked_docs