"""
Client for interacting with the Text Generation WebUI API.
"""

import json
import time
import uuid
import logging
import asyncio
import aiohttp
from typing import Dict, List, Optional, Union, Any, AsyncGenerator, cast
from urllib.parse import urljoin

from .models import (
    ChatMessage,
    ChatCompletionRequest,
    ChatCompletionResponse,
    CompletionRequest,
    CompletionResponse,
    ModelInfo,
    ModelListResponse,
    ModelLoadRequest,
    LogitsRequest,
)

logger = logging.getLogger(__name__)


class TextGenClient:
    """Client for interacting with the Text Generation WebUI API.

    A single ``aiohttp.ClientSession`` is created lazily on the first request
    and reused afterwards. Call :meth:`close` when finished, or use the client
    as an async context manager::

        async with TextGenClient() as client:
            text = await client.simple_chat([{"role": "user", "content": "Hi"}])
    """

    def __init__(
        self,
        base_url: str = "http://textgen.localhost/v1",
        api_key: Optional[str] = None,
        timeout: int = 120,
    ):
        """
        Initialize the TextGen client.

        Args:
            base_url: Base URL for the TextGen API. Endpoints are resolved
                beneath this URL's path (e.g. ``/v1``), with or without a
                trailing slash.
            api_key: API key for authentication (optional); sent as a
                ``Bearer`` token when set.
            timeout: Total request timeout in seconds.
        """
        self.base_url = base_url
        self.api_key = api_key
        self.timeout = timeout
        self._session: Optional[aiohttp.ClientSession] = None

    async def __aenter__(self) -> "TextGenClient":
        """Enter an ``async with`` block; the session is opened on first use."""
        return self

    async def __aexit__(self, exc_type, exc, tb) -> None:
        """Close the HTTP session when leaving the ``async with`` block."""
        await self.close()

    async def _ensure_session(self) -> aiohttp.ClientSession:
        """Return the shared session, creating it if missing or closed."""
        if self._session is None or self._session.closed:
            # ``total`` bounds the whole operation, including reading the
            # response body, so streamed responses are also limited to
            # ``self.timeout`` seconds.
            self._session = aiohttp.ClientSession(
                timeout=aiohttp.ClientTimeout(total=self.timeout)
            )
        return self._session

    async def close(self) -> None:
        """Close the client session if it is open."""
        if self._session and not self._session.closed:
            await self._session.close()
        self._session = None

    def _get_headers(self) -> Dict[str, str]:
        """Get headers for API requests (JSON body, optional bearer auth)."""
        headers = {
            "Content-Type": "application/json",
        }
        if self.api_key:
            headers["Authorization"] = f"Bearer {self.api_key}"
        return headers

    def _build_url(self, endpoint: str) -> str:
        """
        Resolve ``endpoint`` beneath ``base_url`` without losing its path.

        Plain ``urljoin("http://host/v1", "chat/completions")`` returns
        ``http://host/chat/completions``: a base path that does not end in
        ``/`` has its last segment replaced, and an endpoint with a leading
        ``/`` would replace the whole path. Normalizing both sides yields
        ``http://host/v1/chat/completions``.

        Args:
            endpoint: API endpoint relative to ``base_url``.

        Returns:
            The absolute request URL.
        """
        return urljoin(self.base_url.rstrip("/") + "/", endpoint.lstrip("/"))

    async def _make_request(
        self, method: str, endpoint: str, data: Optional[Dict[str, Any]] = None
    ) -> Dict[str, Any]:
        """
        Make a request to the TextGen API.

        Args:
            method: HTTP method (GET, POST, etc.)
            endpoint: API endpoint, relative to ``base_url``
            data: Request data, sent as a JSON body when not None

        Returns:
            API response as a dictionary

        Raises:
            aiohttp.ClientResponseError: On a non-2xx HTTP status.
            aiohttp.ClientError: On other client/connection errors.
            asyncio.TimeoutError: If the request times out.
            All of these are logged before being re-raised.
        """
        session = await self._ensure_session()
        url = self._build_url(endpoint)
        try:
            async with session.request(
                method=method,
                url=url,
                headers=self._get_headers(),
                json=data,
                raise_for_status=True,
            ) as response:
                return await response.json()
        except aiohttp.ClientResponseError as e:
            logger.error("API request failed: %s %s", e.status, e.message)
            raise
        except aiohttp.ClientError as e:
            logger.error("Request error: %s", e)
            raise
        except asyncio.TimeoutError:
            logger.error("Request timed out after %s seconds", self.timeout)
            raise

    async def _stream_request(
        self, endpoint: str, data: Dict[str, Any]
    ) -> AsyncGenerator[Dict[str, Any], None]:
        """
        Make a streaming (server-sent events) request to the TextGen API.

        Args:
            endpoint: API endpoint, relative to ``base_url``
            data: Request data, sent as a JSON body

        Yields:
            Each JSON payload of the event stream, decoded to a dictionary.
            Lines that fail to decode are logged and skipped. The stream ends
            at the ``[DONE]`` sentinel or when the server closes the response.

        Raises:
            aiohttp.ClientResponseError: On a non-2xx HTTP status.
            aiohttp.ClientError: On other client/connection errors.
            asyncio.TimeoutError: If the request times out.
        """
        session = await self._ensure_session()
        url = self._build_url(endpoint)
        try:
            async with session.post(
                url=url,
                headers=self._get_headers(),
                json=data,
                raise_for_status=True,
            ) as response:
                async for raw_line in response.content:
                    line = raw_line.strip()
                    # Blank lines separate SSE events; lines starting with ":"
                    # are SSE comments (e.g. keep-alive pings), not payloads.
                    if not line or line.startswith(b":"):
                        continue
                    # SSE allows the space after "data:" to be omitted.
                    if line.startswith(b"data:"):
                        line = line[len(b"data:"):].lstrip()
                    if line == b"[DONE]":
                        break
                    try:
                        yield json.loads(line)
                    except json.JSONDecodeError:
                        logger.error("Failed to parse SSE data: %s", line)
        except aiohttp.ClientResponseError as e:
            logger.error("API request failed: %s %s", e.status, e.message)
            raise
        except aiohttp.ClientError as e:
            logger.error("Request error: %s", e)
            raise
        except asyncio.TimeoutError:
            logger.error("Request timed out after %s seconds", self.timeout)
            raise

    async def list_models(self) -> List[ModelInfo]:
        """
        List available models.

        Returns:
            List of available models
        """
        response = await self._make_request("GET", "internal/model/list")
        model_list = ModelListResponse(**response)
        return model_list.data

    async def load_model(self, model_name: str, **kwargs) -> Dict[str, Any]:
        """
        Load a model.

        Args:
            model_name: Name of the model to load
            **kwargs: Additional arguments for loading the model, sent as
                the request's ``args``

        Returns:
            Response from the API
        """
        request = ModelLoadRequest(model_name=model_name, args=kwargs)
        return await self._make_request("POST", "internal/model/load", request.dict())

    async def chat_completion(
        self, request: ChatCompletionRequest
    ) -> Union[ChatCompletionResponse, AsyncGenerator[Dict[str, Any], None]]:
        """
        Create a chat completion.

        Args:
            request: Chat completion request; fields left as None are omitted

        Returns:
            A ``ChatCompletionResponse``, or, when ``request.stream`` is set,
            an async generator of raw chunk dictionaries.
        """
        request_data = request.dict(exclude_none=True)
        if request.stream:
            return self._stream_request("chat/completions", request_data)
        response = await self._make_request("POST", "chat/completions", request_data)
        return ChatCompletionResponse(**response)

    async def completion(
        self, request: CompletionRequest
    ) -> Union[CompletionResponse, AsyncGenerator[Dict[str, Any], None]]:
        """
        Create a text completion.

        Args:
            request: Completion request; fields left as None are omitted

        Returns:
            A ``CompletionResponse``, or, when ``request.stream`` is set,
            an async generator of raw chunk dictionaries.
        """
        request_data = request.dict(exclude_none=True)
        if request.stream:
            return self._stream_request("completions", request_data)
        response = await self._make_request("POST", "completions", request_data)
        return CompletionResponse(**response)

    async def get_logits(self, request: LogitsRequest) -> Dict[str, Any]:
        """
        Get logits for a prompt.

        Args:
            request: Logits request; fields left as None are omitted

        Returns:
            Logits response
        """
        request_data = request.dict(exclude_none=True)
        return await self._make_request("POST", "internal/logits", request_data)

    async def simple_chat(
        self,
        messages: List[Dict[str, str]],
        model: Optional[str] = None,
        temperature: float = 0.7,
        top_p: float = 0.9,
        max_tokens: Optional[int] = None,
        stop: Optional[Union[str, List[str]]] = None,
        presence_penalty: float = 0.0,
        frequency_penalty: float = 0.0,
        stream: bool = False,
        mode: str = "instruct",
        character: Optional[str] = None,
        instruction_template: Optional[str] = None,
        seed: Optional[int] = None,
    ) -> Union[str, AsyncGenerator[str, None]]:
        """
        Simple interface for chat completions.

        Args:
            messages: List of message dictionaries with 'role' and 'content'
            model: Model to use
            temperature: Sampling temperature
            top_p: Nucleus sampling parameter
            max_tokens: Maximum tokens to generate
            stop: Stop sequences
            presence_penalty: Presence penalty
            frequency_penalty: Frequency penalty
            stream: Whether to stream the response
            mode: Mode (chat or instruct)
            character: Character to use (for chat mode)
            instruction_template: Instruction template (for instruct mode)
            seed: Random seed for reproducibility

        Returns:
            Generated text, or (when ``stream`` is True) an async generator
            of text chunks taken from each chunk's ``delta.content``.
        """
        chat_messages = [ChatMessage(**msg) for msg in messages]
        request = ChatCompletionRequest(
            messages=chat_messages,
            model=model,
            temperature=temperature,
            top_p=top_p,
            max_tokens=max_tokens,
            stop=stop,
            presence_penalty=presence_penalty,
            frequency_penalty=frequency_penalty,
            stream=stream,
            mode=mode,
            character=character,
            instruction_template=instruction_template,
            seed=seed,
        )

        if stream:
            async def text_stream() -> AsyncGenerator[str, None]:
                # With stream=True, chat_completion returns the chunk generator.
                chunks = cast(
                    AsyncGenerator[Dict[str, Any], None],
                    await self.chat_completion(request),
                )
                async for chunk in chunks:
                    choices = chunk.get("choices")
                    if not choices:
                        continue
                    delta = choices[0].get("delta") or {}
                    content = delta.get("content")
                    # Skip chunks whose delta carries no text (absent or null).
                    if content is not None:
                        yield content

            return text_stream()

        response = await self.chat_completion(request)
        if isinstance(response, ChatCompletionResponse):
            return response.choices[0].message.content
        # Unreachable with stream=False; kept to satisfy the type checker.
        raise TypeError("Expected ChatCompletionResponse but got stream response")

    async def simple_completion(
        self,
        prompt: str,
        model: Optional[str] = None,
        temperature: float = 0.7,
        top_p: float = 0.9,
        max_tokens: Optional[int] = None,
        stop: Optional[Union[str, List[str]]] = None,
        presence_penalty: float = 0.0,
        frequency_penalty: float = 0.0,
        stream: bool = False,
        seed: Optional[int] = None,
    ) -> Union[str, AsyncGenerator[str, None]]:
        """
        Simple interface for text completions.

        Args:
            prompt: Text prompt
            model: Model to use
            temperature: Sampling temperature
            top_p: Nucleus sampling parameter
            max_tokens: Maximum tokens to generate
            stop: Stop sequences
            presence_penalty: Presence penalty
            frequency_penalty: Frequency penalty
            stream: Whether to stream the response
            seed: Random seed for reproducibility

        Returns:
            Generated text, or (when ``stream`` is True) an async generator
            of text chunks taken from each chunk's ``text`` field.
        """
        request = CompletionRequest(
            prompt=prompt,
            model=model,
            temperature=temperature,
            top_p=top_p,
            max_tokens=max_tokens,
            stop=stop,
            presence_penalty=presence_penalty,
            frequency_penalty=frequency_penalty,
            stream=stream,
            seed=seed,
        )

        if stream:
            async def text_stream() -> AsyncGenerator[str, None]:
                # With stream=True, completion returns the chunk generator.
                chunks = cast(
                    AsyncGenerator[Dict[str, Any], None],
                    await self.completion(request),
                )
                async for chunk in chunks:
                    choices = chunk.get("choices")
                    if not choices:
                        continue
                    text = choices[0].get("text")
                    # Skip chunks with no text (absent or null).
                    if text is not None:
                        yield text

            return text_stream()

        response = await self.completion(request)
        if isinstance(response, CompletionResponse):
            return response.choices[0].text
        # Unreachable with stream=False; kept to satisfy the type checker.
        raise TypeError("Expected CompletionResponse but got stream response")