# huggingface_chat.py
  1. import re
  2. from collections.abc import Generator
  3. from typing import Any, Literal, Optional, Union
  4. from _pytest.monkeypatch import MonkeyPatch
  5. from huggingface_hub import InferenceClient
  6. from huggingface_hub.inference._text_generation import (
  7. Details,
  8. StreamDetails,
  9. TextGenerationResponse,
  10. TextGenerationStreamResponse,
  11. Token,
  12. )
  13. from huggingface_hub.utils import BadRequestError
  14. class MockHuggingfaceChatClass:
  15. @staticmethod
  16. def generate_create_sync(model: str) -> TextGenerationResponse:
  17. response = TextGenerationResponse(
  18. generated_text="You can call me Miku Miku o~e~o~",
  19. details=Details(
  20. finish_reason="length",
  21. generated_tokens=6,
  22. tokens=[
  23. Token(id=0, text="You", logprob=0.0, special=False) for i in range(0, 6)
  24. ]
  25. )
  26. )
  27. return response
  28. @staticmethod
  29. def generate_create_stream(model: str) -> Generator[TextGenerationStreamResponse, None, None]:
  30. full_text = "You can call me Miku Miku o~e~o~"
  31. for i in range(0, len(full_text)):
  32. response = TextGenerationStreamResponse(
  33. token = Token(id=i, text=full_text[i], logprob=0.0, special=False),
  34. )
  35. response.generated_text = full_text[i]
  36. response.details = StreamDetails(finish_reason='stop_sequence', generated_tokens=1)
  37. yield response
  38. def text_generation(self: InferenceClient, prompt: str, *,
  39. stream: Literal[False] = ...,
  40. model: Optional[str] = None,
  41. **kwargs: Any
  42. ) -> Union[TextGenerationResponse, Generator[TextGenerationStreamResponse, None, None]]:
  43. # check if key is valid
  44. if not re.match(r'Bearer\shf\-[a-zA-Z0-9]{16,}', self.headers['authorization']):
  45. raise BadRequestError('Invalid API key')
  46. if model is None:
  47. raise BadRequestError('Invalid model')
  48. if stream:
  49. return MockHuggingfaceChatClass.generate_create_stream(model)
  50. return MockHuggingfaceChatClass.generate_create_sync(model)