huggingface_chat.py 2.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657
  1. import re
  2. from collections.abc import Generator
  3. from typing import Any, Literal, Optional, Union
  4. from _pytest.monkeypatch import MonkeyPatch
  5. from huggingface_hub import InferenceClient
  6. from huggingface_hub.inference._text_generation import (
  7. Details,
  8. StreamDetails,
  9. TextGenerationResponse,
  10. TextGenerationStreamResponse,
  11. Token,
  12. )
  13. from huggingface_hub.utils import BadRequestError
  class MockHuggingfaceChatClass:
      """Canned replacements for ``InferenceClient.text_generation`` in tests.

      ``text_generation`` takes ``self: InferenceClient``, so it appears to be
      meant for monkeypatching onto ``huggingface_hub.InferenceClient``. It
      validates the client's bearer token and the requested model, then returns
      either a fixed synchronous response or a character-by-character stream
      produced by the static helpers below.
      """

      @staticmethod
      def generate_create_sync(model: str) -> TextGenerationResponse:
          """Return a fixed, non-streamed completion.

          Args:
              model: Accepted for signature parity with the stream helper; unused.

          Returns:
              A ``TextGenerationResponse`` whose details report
              ``finish_reason="length"`` and six identical placeholder tokens.
          """
          token_count = 6
          return TextGenerationResponse(
              generated_text="You can call me Miku Miku o~e~o~",
              details=Details(
                  finish_reason="length",
                  generated_tokens=token_count,
                  # Placeholder tokens: all identical, sized to match generated_tokens.
                  tokens=[Token(id=0, text="You", logprob=0.0, special=False) for _ in range(token_count)],
              ),
          )

      @staticmethod
      def generate_create_stream(model: str) -> Generator[TextGenerationStreamResponse, None, None]:
          """Yield the fixed completion one character per chunk.

          Each chunk's token id is the character's index, and both
          ``token.text`` and ``generated_text`` hold that single character.
          Every chunk also carries ``StreamDetails(finish_reason="stop_sequence",
          generated_tokens=1)``.

          Args:
              model: Accepted for signature parity with the sync helper; unused.

          Yields:
              One ``TextGenerationStreamResponse`` per character of the text.
          """
          full_text = "You can call me Miku Miku o~e~o~"
          for index, char in enumerate(full_text):
              yield TextGenerationStreamResponse(
                  token=Token(id=index, text=char, logprob=0.0, special=False),
                  generated_text=char,
                  details=StreamDetails(finish_reason="stop_sequence", generated_tokens=1),
              )

      def text_generation(
          self: InferenceClient, prompt: str, *, stream: bool = False, model: Optional[str] = None, **kwargs: Any
      ) -> Union[TextGenerationResponse, Generator[TextGenerationStreamResponse, None, None]]:
          """Mock ``InferenceClient.text_generation``.

          Args:
              prompt: Ignored by the mock.
              stream: When true, return the streaming generator; otherwise the
                  synchronous response. Defaults to ``False``. The previous
                  ``= ...`` default was truthy and silently selected streaming.
              model: Required; ``None`` is rejected.
              **kwargs: Ignored by the mock.

          Returns:
              The result of ``generate_create_stream`` or ``generate_create_sync``.

          Raises:
              BadRequestError: If the ``authorization`` header is missing or does
                  not look like ``Bearer hf-<16+ alphanumerics>``, or if ``model``
                  is ``None``.
          """
          # Treat a missing header like a malformed one, so callers see
          # BadRequestError instead of a KeyError from the mock itself.
          # NOTE(review): assumes self.headers is dict-like, as the original subscript implied.
          authorization = self.headers.get("authorization") or ""
          if not re.match(r"Bearer\shf-[a-zA-Z0-9]{16,}", authorization):
              raise BadRequestError("Invalid API key")
          if model is None:
              raise BadRequestError("Invalid model")
          if stream:
              return MockHuggingfaceChatClass.generate_create_stream(model)
          return MockHuggingfaceChatClass.generate_create_sync(model)