2025-07-11 13:52:19 +00:00

367 lines
12 KiB
Python

"""
Client for interacting with the Text Generation WebUI API.
"""
import json
import time
import uuid
import logging
import asyncio
import aiohttp
from typing import Dict, List, Optional, Union, Any, AsyncGenerator, cast
from urllib.parse import urljoin
from .models import (
ChatMessage,
ChatCompletionRequest,
ChatCompletionResponse,
CompletionRequest,
CompletionResponse,
ModelInfo,
ModelListResponse,
ModelLoadRequest,
LogitsRequest,
)
logger = logging.getLogger(__name__)
class TextGenClient:
    """Async client for the Text Generation WebUI (OpenAI-compatible) API."""

    def __init__(
        self,
        base_url: str = "http://textgen.localhost/v1",
        api_key: Optional[str] = None,
        timeout: int = 120,
    ):
        """
        Create a new client.

        Args:
            base_url: Base URL for the TextGen API
            api_key: API key for authentication (optional)
            timeout: Request timeout in seconds
        """
        self.base_url = base_url
        self.api_key = api_key
        self.timeout = timeout
        # HTTP session is created lazily on first request; see _ensure_session().
        self._session: Optional[aiohttp.ClientSession] = None
async def _ensure_session(self) -> aiohttp.ClientSession:
"""Ensure that an aiohttp session exists."""
if self._session is None or self._session.closed:
self._session = aiohttp.ClientSession(
timeout=aiohttp.ClientTimeout(total=self.timeout)
)
return self._session
async def close(self):
"""Close the client session."""
if self._session and not self._session.closed:
await self._session.close()
self._session = None
def _get_headers(self) -> Dict[str, str]:
"""Get headers for API requests."""
headers = {
"Content-Type": "application/json",
}
if self.api_key:
headers["Authorization"] = f"Bearer {self.api_key}"
return headers
async def _make_request(
self, method: str, endpoint: str, data: Optional[Dict[str, Any]] = None
) -> Dict[str, Any]:
"""
Make a request to the TextGen API.
Args:
method: HTTP method (GET, POST, etc.)
endpoint: API endpoint
data: Request data
Returns:
API response as a dictionary
"""
session = await self._ensure_session()
url = urljoin(self.base_url, endpoint)
try:
async with session.request(
method=method,
url=url,
headers=self._get_headers(),
json=data,
raise_for_status=True,
) as response:
return await response.json()
except aiohttp.ClientResponseError as e:
logger.error(f"API request failed: {e.status} {e.message}")
raise
except aiohttp.ClientError as e:
logger.error(f"Request error: {str(e)}")
raise
except asyncio.TimeoutError:
logger.error(f"Request timed out after {self.timeout} seconds")
raise
async def _stream_request(
self, endpoint: str, data: Dict[str, Any]
) -> AsyncGenerator[Dict[str, Any], None]:
"""
Make a streaming request to the TextGen API.
Args:
endpoint: API endpoint
data: Request data
Yields:
Chunks of the API response
"""
session = await self._ensure_session()
url = urljoin(self.base_url, endpoint)
try:
async with session.post(
url=url,
headers=self._get_headers(),
json=data,
raise_for_status=True,
) as response:
async for line in response.content:
line = line.strip()
if not line or line == b"data: [DONE]":
continue
if line.startswith(b"data: "):
line = line[6:] # Remove "data: " prefix
try:
yield json.loads(line)
except json.JSONDecodeError:
logger.error(f"Failed to parse SSE data: {line}")
except aiohttp.ClientResponseError as e:
logger.error(f"API request failed: {e.status} {e.message}")
raise
except aiohttp.ClientError as e:
logger.error(f"Request error: {str(e)}")
raise
except asyncio.TimeoutError:
logger.error(f"Request timed out after {self.timeout} seconds")
raise
async def list_models(self) -> List[ModelInfo]:
"""
List available models.
Returns:
List of available models
"""
response = await self._make_request("GET", "internal/model/list")
model_list = ModelListResponse(**response)
return model_list.data
async def load_model(self, model_name: str, **kwargs) -> Dict[str, Any]:
"""
Load a model.
Args:
model_name: Name of the model to load
**kwargs: Additional arguments for loading the model
Returns:
Response from the API
"""
request = ModelLoadRequest(model_name=model_name, args=kwargs)
return await self._make_request("POST", "internal/model/load", request.dict())
async def chat_completion(
self, request: ChatCompletionRequest
) -> Union[ChatCompletionResponse, AsyncGenerator[Dict[str, Any], None]]:
"""
Create a chat completion.
Args:
request: Chat completion request
Returns:
Chat completion response or a stream of responses
"""
request_data = request.dict(exclude_none=True)
if request.stream:
return self._stream_request("chat/completions", request_data)
response = await self._make_request("POST", "chat/completions", request_data)
return ChatCompletionResponse(**response)
async def completion(
self, request: CompletionRequest
) -> Union[CompletionResponse, AsyncGenerator[Dict[str, Any], None]]:
"""
Create a text completion.
Args:
request: Completion request
Returns:
Completion response or a stream of responses
"""
request_data = request.dict(exclude_none=True)
if request.stream:
return self._stream_request("completions", request_data)
response = await self._make_request("POST", "completions", request_data)
return CompletionResponse(**response)
async def get_logits(self, request: LogitsRequest) -> Dict[str, Any]:
"""
Get logits for a prompt.
Args:
request: Logits request
Returns:
Logits response
"""
request_data = request.dict(exclude_none=True)
return await self._make_request("POST", "internal/logits", request_data)
async def simple_chat(
self,
messages: List[Dict[str, str]],
model: Optional[str] = None,
temperature: float = 0.7,
top_p: float = 0.9,
max_tokens: Optional[int] = None,
stop: Optional[Union[str, List[str]]] = None,
presence_penalty: float = 0.0,
frequency_penalty: float = 0.0,
stream: bool = False,
mode: str = "instruct",
character: Optional[str] = None,
instruction_template: Optional[str] = None,
seed: Optional[int] = None,
) -> Union[str, AsyncGenerator[str, None]]:
"""
Simple interface for chat completions.
Args:
messages: List of message dictionaries with 'role' and 'content'
model: Model to use
temperature: Sampling temperature
top_p: Nucleus sampling parameter
max_tokens: Maximum tokens to generate
stop: Stop sequences
presence_penalty: Presence penalty
frequency_penalty: Frequency penalty
stream: Whether to stream the response
mode: Mode (chat or instruct)
character: Character to use (for chat mode)
instruction_template: Instruction template (for instruct mode)
seed: Random seed for reproducibility
Returns:
Generated text or a stream of text chunks
"""
chat_messages = [ChatMessage(**msg) for msg in messages]
request = ChatCompletionRequest(
messages=chat_messages,
model=model,
temperature=temperature,
top_p=top_p,
max_tokens=max_tokens,
stop=stop,
presence_penalty=presence_penalty,
frequency_penalty=frequency_penalty,
stream=stream,
mode=mode,
character=character,
instruction_template=instruction_template,
seed=seed,
)
if stream:
async def text_stream() -> AsyncGenerator[str, None]:
stream_response = await self.chat_completion(request)
if isinstance(stream_response, AsyncGenerator):
async for chunk in stream_response:
if "choices" in chunk and chunk["choices"]:
if (
"delta" in chunk["choices"][0]
and "content" in chunk["choices"][0]["delta"]
):
yield chunk["choices"][0]["delta"]["content"]
return text_stream()
else:
response = await self.chat_completion(request)
if isinstance(response, ChatCompletionResponse):
return response.choices[0].message.content
# This should never happen due to the if/else structure, but satisfies the type checker
raise TypeError("Expected ChatCompletionResponse but got stream response")
async def simple_completion(
self,
prompt: str,
model: Optional[str] = None,
temperature: float = 0.7,
top_p: float = 0.9,
max_tokens: Optional[int] = None,
stop: Optional[Union[str, List[str]]] = None,
presence_penalty: float = 0.0,
frequency_penalty: float = 0.0,
stream: bool = False,
seed: Optional[int] = None,
) -> Union[str, AsyncGenerator[str, None]]:
"""
Simple interface for text completions.
Args:
prompt: Text prompt
model: Model to use
temperature: Sampling temperature
top_p: Nucleus sampling parameter
max_tokens: Maximum tokens to generate
stop: Stop sequences
presence_penalty: Presence penalty
frequency_penalty: Frequency penalty
stream: Whether to stream the response
seed: Random seed for reproducibility
Returns:
Generated text or a stream of text chunks
"""
request = CompletionRequest(
prompt=prompt,
model=model,
temperature=temperature,
top_p=top_p,
max_tokens=max_tokens,
stop=stop,
presence_penalty=presence_penalty,
frequency_penalty=frequency_penalty,
stream=stream,
seed=seed,
)
if stream:
async def text_stream() -> AsyncGenerator[str, None]:
stream_response = await self.completion(request)
if isinstance(stream_response, AsyncGenerator):
async for chunk in stream_response:
if "choices" in chunk and chunk["choices"]:
if "text" in chunk["choices"][0]:
yield chunk["choices"][0]["text"]
return text_stream()
else:
response = await self.completion(request)
if isinstance(response, CompletionResponse):
return response.choices[0].text
# This should never happen due to the if/else structure, but satisfies the type checker
raise TypeError("Expected CompletionResponse but got stream response")