import re from collections.abc import Generator from typing import Any, Literal, Optional, Union from _pytest.monkeypatch import MonkeyPatch from huggingface_hub import InferenceClient from huggingface_hub.inference._text_generation import ( Details, StreamDetails, TextGenerationResponse, TextGenerationStreamResponse, Token, ) from huggingface_hub.utils import BadRequestError class MockHuggingfaceChatClass: @staticmethod def generate_create_sync(model: str) -> TextGenerationResponse: response = TextGenerationResponse( generated_text="You can call me Miku Miku o~e~o~", details=Details( finish_reason="length", generated_tokens=6, tokens=[Token(id=0, text="You", logprob=0.0, special=False) for i in range(0, 6)], ), ) return response @staticmethod def generate_create_stream(model: str) -> Generator[TextGenerationStreamResponse, None, None]: full_text = "You can call me Miku Miku o~e~o~" for i in range(0, len(full_text)): response = TextGenerationStreamResponse( token=Token(id=i, text=full_text[i], logprob=0.0, special=False), ) response.generated_text = full_text[i] response.details = StreamDetails(finish_reason="stop_sequence", generated_tokens=1) yield response def text_generation( self: InferenceClient, prompt: str, *, stream: Literal[False] = ..., model: Optional[str] = None, **kwargs: Any ) -> Union[TextGenerationResponse, Generator[TextGenerationStreamResponse, None, None]]: # check if key is valid if not re.match(r"Bearer\shf\-[a-zA-Z0-9]{16,}", self.headers["authorization"]): raise BadRequestError("Invalid API key") if model is None: raise BadRequestError("Invalid model") if stream: return MockHuggingfaceChatClass.generate_create_stream(model) return MockHuggingfaceChatClass.generate_create_sync(model)