"""Data models for the TextGen API."""

from typing import Dict, List, Optional, Union, Any

from pydantic import BaseModel, Field


class ChatMessage(BaseModel):
    """A chat message in a conversation."""

    # "role" and "content" are required; "name" is optional sender metadata.
    role: str = Field(
        default=...,
        description="The role of the message sender (user, assistant, system)",
    )
    content: str = Field(default=..., description="The content of the message")
    name: Optional[str] = Field(
        default=None, description="The name of the sender (optional)"
    )


class ChatCompletionRequest(BaseModel):
    """Request model for chat completions."""

    # The conversation history is the only required field.
    messages: List[ChatMessage] = Field(
        default=..., description="The messages in the conversation"
    )
    model: Optional[str] = Field(
        default=None, description="The model to use for completion"
    )
    # Sampling controls.
    temperature: Optional[float] = Field(
        default=0.7, description="Sampling temperature"
    )
    top_p: Optional[float] = Field(
        default=0.9, description="Nucleus sampling parameter"
    )
    max_tokens: Optional[int] = Field(
        default=None, description="Maximum number of tokens to generate"
    )
    stream: Optional[bool] = Field(
        default=False, description="Whether to stream the response"
    )
    stop: Optional[Union[str, List[str]]] = Field(
        default=None, description="Stop sequences"
    )
    presence_penalty: Optional[float] = Field(
        default=0.0, description="Presence penalty"
    )
    frequency_penalty: Optional[float] = Field(
        default=0.0, description="Frequency penalty"
    )
    # Backend-specific options (chat vs. instruct behavior).
    mode: Optional[str] = Field(default="chat", description="Mode (chat or instruct)")
    character: Optional[str] = Field(
        default=None, description="Character to use (for chat mode)"
    )
    instruction_template: Optional[str] = Field(
        default=None, description="Instruction template (for instruct mode)"
    )
    seed: Optional[int] = Field(
        default=None, description="Random seed for reproducibility"
    )


class ChatCompletionResponseChoice(BaseModel):
    """A choice in a chat completion response."""

    # Position of this choice within the response's choice list.
    index: int = Field(default=..., description="Index of the choice")
    message: ChatMessage = Field(default=..., description="The message")
    finish_reason: Optional[str] = Field(
        default=None, description="Reason for finishing"
    )


class ChatCompletionResponse(BaseModel):
    """Response model for chat completions."""

    id: str = Field(default=..., description="Unique identifier for the completion")
    # Fixed discriminator string, mirroring the OpenAI-style response shape.
    object: str = Field(default="chat.completion", description="Object type")
    created: int = Field(default=..., description="Unix timestamp of creation")
    model: str = Field(default=..., description="Model used for completion")
    choices: List[ChatCompletionResponseChoice] = Field(
        default=..., description="Completion choices"
    )
    usage: Dict[str, int] = Field(
        default=..., description="Token usage information"
    )


class CompletionRequest(BaseModel):
    """Request model for text completions."""

    # The prompt is the only required field.
    prompt: str = Field(default=..., description="The prompt to complete")
    model: Optional[str] = Field(
        default=None, description="The model to use for completion"
    )
    # Sampling controls.
    temperature: Optional[float] = Field(
        default=0.7, description="Sampling temperature"
    )
    top_p: Optional[float] = Field(
        default=0.9, description="Nucleus sampling parameter"
    )
    max_tokens: Optional[int] = Field(
        default=None, description="Maximum number of tokens to generate"
    )
    stream: Optional[bool] = Field(
        default=False, description="Whether to stream the response"
    )
    stop: Optional[Union[str, List[str]]] = Field(
        default=None, description="Stop sequences"
    )
    presence_penalty: Optional[float] = Field(
        default=0.0, description="Presence penalty"
    )
    frequency_penalty: Optional[float] = Field(
        default=0.0, description="Frequency penalty"
    )
    seed: Optional[int] = Field(
        default=None, description="Random seed for reproducibility"
    )


class CompletionResponseChoice(BaseModel):
    """A choice in a completion response."""

    text: str = Field(default=..., description="The generated text")
    # Position of this choice within the response's choice list.
    index: int = Field(default=..., description="Index of the choice")
    logprobs: Optional[Any] = Field(default=None, description="Log probabilities")
    finish_reason: Optional[str] = Field(
        default=None, description="Reason for finishing"
    )


class CompletionResponse(BaseModel):
    """Response model for text completions."""

    id: str = Field(default=..., description="Unique identifier for the completion")
    # Fixed discriminator string, mirroring the OpenAI-style response shape.
    object: str = Field(default="text_completion", description="Object type")
    created: int = Field(default=..., description="Unix timestamp of creation")
    model: str = Field(default=..., description="Model used for completion")
    choices: List[CompletionResponseChoice] = Field(
        default=..., description="Completion choices"
    )
    usage: Dict[str, int] = Field(
        default=..., description="Token usage information"
    )


class ModelInfo(BaseModel):
    """Information about a model.

    Mirrors the OpenAI-style ``/models`` entry: identity, creation time,
    ownership, permissions, and lineage (``root``/``parent``).
    """

    id: str = Field(..., description="Model identifier")
    object: str = Field("model", description="Object type")
    created: int = Field(..., description="Unix timestamp of creation")
    owned_by: str = Field("user", description="Owner of the model")
    # Use default_factory rather than a literal [] so each instance gets its
    # own list — the documented pydantic idiom for mutable defaults.
    permission: List[Dict[str, Any]] = Field(
        default_factory=list, description="Permissions"
    )
    root: str = Field(..., description="Root model")
    parent: Optional[str] = Field(None, description="Parent model")


class LogitsRequest(BaseModel):
    """Request model for logits."""

    # The prompt is the only required field.
    prompt: str = Field(default=..., description="The prompt to get logits for")
    use_samplers: bool = Field(
        default=False, description="Whether to apply sampling parameters"
    )
    # Optional sampling parameters, honored when use_samplers is true.
    top_k: Optional[int] = Field(default=None, description="Top-k sampling parameter")
    top_p: Optional[float] = Field(
        default=None, description="Top-p sampling parameter"
    )
    temperature: Optional[float] = Field(
        default=None, description="Sampling temperature"
    )


class ModelListResponse(BaseModel):
    """Response model for model list."""

    # Fixed discriminator string, mirroring the OpenAI-style list envelope.
    object: str = Field(default="list", description="Object type")
    data: List[ModelInfo] = Field(default=..., description="List of models")


class ModelLoadRequest(BaseModel):
    """Request model for loading a model.

    ``model_name`` selects the model; ``args`` carries loader-specific
    keyword options passed through to the backend.
    """

    model_name: str = Field(..., description="Name of the model to load")
    # Use default_factory rather than a literal {} so each instance gets its
    # own dict — the documented pydantic idiom for mutable defaults.
    args: Dict[str, Any] = Field(
        default_factory=dict, description="Arguments for loading the model"
    )