huggingface_chat.py 2.1 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455
  1. from huggingface_hub import InferenceClient
  2. from huggingface_hub.inference._text_generation import TextGenerationResponse, TextGenerationStreamResponse, Details, StreamDetails, Token
  3. from huggingface_hub.utils import BadRequestError
  4. from typing import Literal, Optional, List, Generator, Union, Any
  5. from _pytest.monkeypatch import MonkeyPatch
  6. import re
  7. class MockHuggingfaceChatClass(object):
  8. @staticmethod
  9. def generate_create_sync(model: str) -> TextGenerationResponse:
  10. response = TextGenerationResponse(
  11. generated_text="You can call me Miku Miku o~e~o~",
  12. details=Details(
  13. finish_reason="length",
  14. generated_tokens=6,
  15. tokens=[
  16. Token(id=0, text="You", logprob=0.0, special=False) for i in range(0, 6)
  17. ]
  18. )
  19. )
  20. return response
  21. @staticmethod
  22. def generate_create_stream(model: str) -> Generator[TextGenerationStreamResponse, None, None]:
  23. full_text = "You can call me Miku Miku o~e~o~"
  24. for i in range(0, len(full_text)):
  25. response = TextGenerationStreamResponse(
  26. token = Token(id=i, text=full_text[i], logprob=0.0, special=False),
  27. )
  28. response.generated_text = full_text[i]
  29. response.details = StreamDetails(finish_reason='stop_sequence', generated_tokens=1)
  30. yield response
  31. def text_generation(self: InferenceClient, prompt: str, *,
  32. stream: Literal[False] = ...,
  33. model: Optional[str] = None,
  34. **kwargs: Any
  35. ) -> Union[TextGenerationResponse, Generator[TextGenerationStreamResponse, None, None]]:
  36. # check if key is valid
  37. if not re.match(r'Bearer\shf\-[a-zA-Z0-9]{16,}', self.headers['authorization']):
  38. raise BadRequestError('Invalid API key')
  39. if model is None:
  40. raise BadRequestError('Invalid model')
  41. if stream:
  42. return MockHuggingfaceChatClass.generate_create_stream(model)
  43. return MockHuggingfaceChatClass.generate_create_sync(model)