# huggingface_tei.py — mock TEI (Text Embeddings Inference) client used by tests.
from core.model_runtime.model_providers.huggingface_tei.tei_helper import TeiModelExtraParameter
  2. class MockTEIClass:
  3. @staticmethod
  4. def get_tei_extra_parameter(server_url: str, model_name: str) -> TeiModelExtraParameter:
  5. # During mock, we don't have a real server to query, so we just return a dummy value
  6. if 'rerank' in model_name:
  7. model_type = 'reranker'
  8. else:
  9. model_type = 'embedding'
  10. return TeiModelExtraParameter(model_type=model_type, max_input_length=512, max_client_batch_size=1)
  11. @staticmethod
  12. def invoke_tokenize(server_url: str, texts: list[str]) -> list[list[dict]]:
  13. # Use space as token separator, and split the text into tokens
  14. tokenized_texts = []
  15. for text in texts:
  16. tokens = text.split(' ')
  17. current_index = 0
  18. tokenized_text = []
  19. for idx, token in enumerate(tokens):
  20. s_token = {
  21. 'id': idx,
  22. 'text': token,
  23. 'special': False,
  24. 'start': current_index,
  25. 'stop': current_index + len(token),
  26. }
  27. current_index += len(token) + 1
  28. tokenized_text.append(s_token)
  29. tokenized_texts.append(tokenized_text)
  30. return tokenized_texts
  31. @staticmethod
  32. def invoke_embeddings(server_url: str, texts: list[str]) -> dict:
  33. # {
  34. # "object": "list",
  35. # "data": [
  36. # {
  37. # "object": "embedding",
  38. # "embedding": [...],
  39. # "index": 0
  40. # }
  41. # ],
  42. # "model": "MODEL_NAME",
  43. # "usage": {
  44. # "prompt_tokens": 3,
  45. # "total_tokens": 3
  46. # }
  47. # }
  48. embeddings = []
  49. for idx, text in enumerate(texts):
  50. embedding = [0.1] * 768
  51. embeddings.append(
  52. {
  53. 'object': 'embedding',
  54. 'embedding': embedding,
  55. 'index': idx,
  56. }
  57. )
  58. return {
  59. 'object': 'list',
  60. 'data': embeddings,
  61. 'model': 'MODEL_NAME',
  62. 'usage': {
  63. 'prompt_tokens': sum(len(text.split(' ')) for text in texts),
  64. 'total_tokens': sum(len(text.split(' ')) for text in texts),
  65. },
  66. }
  67. def invoke_rerank(server_url: str, query: str, texts: list[str]) -> list[dict]:
  68. # Example response:
  69. # [
  70. # {
  71. # "index": 0,
  72. # "text": "Deep Learning is ...",
  73. # "score": 0.9950755
  74. # }
  75. # ]
  76. reranked_docs = []
  77. for idx, text in enumerate(texts):
  78. reranked_docs.append(
  79. {
  80. 'index': idx,
  81. 'text': text,
  82. 'score': 0.9,
  83. }
  84. )
  85. # For mock, only return the first document
  86. break
  87. return reranked_docs