"""Client for interacting with the Text Generation WebUI API."""
|
|
|
|
import json
|
|
import time
|
|
import uuid
|
|
import logging
|
|
import asyncio
|
|
import aiohttp
|
|
from typing import Dict, List, Optional, Union, Any, AsyncGenerator, cast
|
|
from urllib.parse import urljoin
|
|
|
|
from .models import (
|
|
ChatMessage,
|
|
ChatCompletionRequest,
|
|
ChatCompletionResponse,
|
|
CompletionRequest,
|
|
CompletionResponse,
|
|
ModelInfo,
|
|
ModelListResponse,
|
|
ModelLoadRequest,
|
|
LogitsRequest,
|
|
)
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class TextGenClient:
    """Async client for the Text Generation WebUI OpenAI-compatible API.

    Wraps the chat-completion, completion, model-management, and logits
    endpoints over a lazily created :mod:`aiohttp` session.  The session is
    opened on first request and should be released with :meth:`close`.
    """

    def __init__(
        self,
        base_url: str = "http://textgen.localhost/v1",
        api_key: Optional[str] = None,
        timeout: int = 120,
    ):
        """
        Initialize the TextGen client.

        Args:
            base_url: Base URL for the TextGen API (with or without a
                trailing slash).
            api_key: API key for authentication (optional)
            timeout: Request timeout in seconds
        """
        self.base_url = base_url
        self.api_key = api_key
        self.timeout = timeout
        # Lazily created by _ensure_session(); recreated if closed.
        self._session: Optional["aiohttp.ClientSession"] = None

    def _url(self, endpoint: str) -> str:
        """Join *endpoint* onto the base URL, preserving the base path.

        Bug fix: a bare ``urljoin("http://host/v1", "chat/completions")``
        resolves to ``http://host/chat/completions`` — RFC 3986 relative
        resolution drops the final path segment of a base with no trailing
        slash.  Normalizing the base to end with ``/`` keeps ``/v1``.
        """
        base = self.base_url if self.base_url.endswith("/") else self.base_url + "/"
        return urljoin(base, endpoint)

    async def _ensure_session(self) -> "aiohttp.ClientSession":
        """Return the shared aiohttp session, creating it if missing/closed."""
        if self._session is None or self._session.closed:
            self._session = aiohttp.ClientSession(
                timeout=aiohttp.ClientTimeout(total=self.timeout)
            )
        return self._session

    async def close(self):
        """Close the client session (safe to call multiple times)."""
        if self._session and not self._session.closed:
            await self._session.close()
        self._session = None

    def _get_headers(self) -> Dict[str, str]:
        """Build request headers, adding Authorization when an API key is set."""
        headers = {
            "Content-Type": "application/json",
        }
        if self.api_key:
            headers["Authorization"] = f"Bearer {self.api_key}"
        return headers

    async def _make_request(
        self, method: str, endpoint: str, data: Optional[Dict[str, Any]] = None
    ) -> Dict[str, Any]:
        """
        Make a non-streaming request to the TextGen API.

        Args:
            method: HTTP method (GET, POST, etc.)
            endpoint: API endpoint, relative to the base URL
            data: JSON-serializable request body

        Returns:
            API response parsed as a dictionary

        Raises:
            aiohttp.ClientResponseError: on a non-2xx status
            aiohttp.ClientError: on connection-level failures
            asyncio.TimeoutError: when the request exceeds ``timeout``
        """
        session = await self._ensure_session()
        url = self._url(endpoint)

        try:
            async with session.request(
                method=method,
                url=url,
                headers=self._get_headers(),
                json=data,
                raise_for_status=True,
            ) as response:
                return await response.json()
        except aiohttp.ClientResponseError as e:
            logger.error(f"API request failed: {e.status} {e.message}")
            raise
        except aiohttp.ClientError as e:
            logger.error(f"Request error: {str(e)}")
            raise
        except asyncio.TimeoutError:
            logger.error(f"Request timed out after {self.timeout} seconds")
            raise

    async def _stream_request(
        self, endpoint: str, data: Dict[str, Any]
    ) -> AsyncGenerator[Dict[str, Any], None]:
        """
        Make a streaming (SSE) POST request to the TextGen API.

        Args:
            endpoint: API endpoint, relative to the base URL
            data: JSON-serializable request body

        Yields:
            Parsed JSON chunks, one per ``data:`` SSE line; malformed
            lines are logged and skipped rather than aborting the stream.
        """
        session = await self._ensure_session()
        url = self._url(endpoint)

        try:
            async with session.post(
                url=url,
                headers=self._get_headers(),
                json=data,
                raise_for_status=True,
            ) as response:
                async for line in response.content:
                    line = line.strip()
                    # Skip keep-alive blanks and the end-of-stream sentinel.
                    if not line or line == b"data: [DONE]":
                        continue
                    if line.startswith(b"data: "):
                        line = line[6:]  # Remove "data: " prefix
                    try:
                        yield json.loads(line)
                    except json.JSONDecodeError:
                        logger.error(f"Failed to parse SSE data: {line}")
        except aiohttp.ClientResponseError as e:
            logger.error(f"API request failed: {e.status} {e.message}")
            raise
        except aiohttp.ClientError as e:
            logger.error(f"Request error: {str(e)}")
            raise
        except asyncio.TimeoutError:
            logger.error(f"Request timed out after {self.timeout} seconds")
            raise

    async def list_models(self) -> List["ModelInfo"]:
        """
        List available models.

        Returns:
            List of available models
        """
        response = await self._make_request("GET", "internal/model/list")
        model_list = ModelListResponse(**response)
        return model_list.data

    async def load_model(self, model_name: str, **kwargs) -> Dict[str, Any]:
        """
        Load a model.

        Args:
            model_name: Name of the model to load
            **kwargs: Additional arguments for loading the model

        Returns:
            Response from the API
        """
        request = ModelLoadRequest(model_name=model_name, args=kwargs)
        return await self._make_request("POST", "internal/model/load", request.dict())

    async def chat_completion(
        self, request: "ChatCompletionRequest"
    ) -> Union["ChatCompletionResponse", AsyncGenerator[Dict[str, Any], None]]:
        """
        Create a chat completion.

        Args:
            request: Chat completion request

        Returns:
            Chat completion response, or an async generator of raw chunk
            dicts when ``request.stream`` is true
        """
        request_data = request.dict(exclude_none=True)

        if request.stream:
            return self._stream_request("chat/completions", request_data)

        response = await self._make_request("POST", "chat/completions", request_data)
        return ChatCompletionResponse(**response)

    async def completion(
        self, request: "CompletionRequest"
    ) -> Union["CompletionResponse", AsyncGenerator[Dict[str, Any], None]]:
        """
        Create a text completion.

        Args:
            request: Completion request

        Returns:
            Completion response, or an async generator of raw chunk dicts
            when ``request.stream`` is true
        """
        request_data = request.dict(exclude_none=True)

        if request.stream:
            return self._stream_request("completions", request_data)

        response = await self._make_request("POST", "completions", request_data)
        return CompletionResponse(**response)

    async def get_logits(self, request: "LogitsRequest") -> Dict[str, Any]:
        """
        Get logits for a prompt.

        Args:
            request: Logits request

        Returns:
            Logits response
        """
        request_data = request.dict(exclude_none=True)
        return await self._make_request("POST", "internal/logits", request_data)

    async def simple_chat(
        self,
        messages: List[Dict[str, str]],
        model: Optional[str] = None,
        temperature: float = 0.7,
        top_p: float = 0.9,
        max_tokens: Optional[int] = None,
        stop: Optional[Union[str, List[str]]] = None,
        presence_penalty: float = 0.0,
        frequency_penalty: float = 0.0,
        stream: bool = False,
        mode: str = "instruct",
        character: Optional[str] = None,
        instruction_template: Optional[str] = None,
        seed: Optional[int] = None,
    ) -> Union[str, AsyncGenerator[str, None]]:
        """
        Simple interface for chat completions.

        Args:
            messages: List of message dictionaries with 'role' and 'content'
            model: Model to use
            temperature: Sampling temperature
            top_p: Nucleus sampling parameter
            max_tokens: Maximum tokens to generate
            stop: Stop sequences
            presence_penalty: Presence penalty
            frequency_penalty: Frequency penalty
            stream: Whether to stream the response
            mode: Mode (chat or instruct)
            character: Character to use (for chat mode)
            instruction_template: Instruction template (for instruct mode)
            seed: Random seed for reproducibility

        Returns:
            Generated text, or an async generator of text deltas when
            ``stream`` is true
        """
        chat_messages = [ChatMessage(**msg) for msg in messages]
        request = ChatCompletionRequest(
            messages=chat_messages,
            model=model,
            temperature=temperature,
            top_p=top_p,
            max_tokens=max_tokens,
            stop=stop,
            presence_penalty=presence_penalty,
            frequency_penalty=frequency_penalty,
            stream=stream,
            mode=mode,
            character=character,
            instruction_template=instruction_template,
            seed=seed,
        )

        if stream:

            async def text_stream() -> AsyncGenerator[str, None]:
                stream_response = await self.chat_completion(request)
                if isinstance(stream_response, AsyncGenerator):
                    async for chunk in stream_response:
                        # Surface only the incremental text content of each delta.
                        if "choices" in chunk and chunk["choices"]:
                            if (
                                "delta" in chunk["choices"][0]
                                and "content" in chunk["choices"][0]["delta"]
                            ):
                                yield chunk["choices"][0]["delta"]["content"]

            return text_stream()
        else:
            response = await self.chat_completion(request)
            if isinstance(response, ChatCompletionResponse):
                return response.choices[0].message.content
            # This should never happen due to the if/else structure, but satisfies the type checker
            raise TypeError("Expected ChatCompletionResponse but got stream response")

    async def simple_completion(
        self,
        prompt: str,
        model: Optional[str] = None,
        temperature: float = 0.7,
        top_p: float = 0.9,
        max_tokens: Optional[int] = None,
        stop: Optional[Union[str, List[str]]] = None,
        presence_penalty: float = 0.0,
        frequency_penalty: float = 0.0,
        stream: bool = False,
        seed: Optional[int] = None,
    ) -> Union[str, AsyncGenerator[str, None]]:
        """
        Simple interface for text completions.

        Args:
            prompt: Text prompt
            model: Model to use
            temperature: Sampling temperature
            top_p: Nucleus sampling parameter
            max_tokens: Maximum tokens to generate
            stop: Stop sequences
            presence_penalty: Presence penalty
            frequency_penalty: Frequency penalty
            stream: Whether to stream the response
            seed: Random seed for reproducibility

        Returns:
            Generated text, or an async generator of text chunks when
            ``stream`` is true
        """
        request = CompletionRequest(
            prompt=prompt,
            model=model,
            temperature=temperature,
            top_p=top_p,
            max_tokens=max_tokens,
            stop=stop,
            presence_penalty=presence_penalty,
            frequency_penalty=frequency_penalty,
            stream=stream,
            seed=seed,
        )

        if stream:

            async def text_stream() -> AsyncGenerator[str, None]:
                stream_response = await self.completion(request)
                if isinstance(stream_response, AsyncGenerator):
                    async for chunk in stream_response:
                        if "choices" in chunk and chunk["choices"]:
                            if "text" in chunk["choices"][0]:
                                yield chunk["choices"][0]["text"]

            return text_stream()
        else:
            response = await self.completion(request)
            if isinstance(response, CompletionResponse):
                return response.choices[0].text
            # This should never happen due to the if/else structure, but satisfies the type checker
            raise TypeError("Expected CompletionResponse but got stream response")