# Naveedai / inference.py
# (Hugging Face listing residue, kept as comments so the module parses:
#  uploaded by bilalnaveed, commit bb2212d verified —
#  "Ultra-fast preset: shorter context/output, reduced history, lower thinking overhead")
"""
Inference Engine for the Tiny Conversational AI.
Handles optimized text generation with streaming, caching, and performance tuning.
"""
import time
import threading
from typing import Optional, Iterator, Dict, Any, List, Callable, Union
from dataclasses import dataclass
from queue import Queue
from config import Config, get_config
from utils import (
get_logger, Timer, PerformanceTracker, LRUCache,
clean_response, get_memory_usage_gb
)
from model_loader import ModelLoader, get_loader
from conversation import Conversation, create_conversation
logger = get_logger(__name__)
@dataclass
class GenerationResult:
    """Result of text generation.

    Attributes:
        text: The cleaned generated text.
        tokens_generated: Number of completion tokens produced.
        total_time_seconds: Wall-clock time for the whole generation.
        first_token_time_seconds: Latency until the first token appeared.
        tokens_per_second: Generation throughput.
        prompt_tokens: Number of tokens in the input prompt.
        from_cache: True when the result was served from the response cache.
    """
    text: str
    tokens_generated: int
    total_time_seconds: float
    first_token_time_seconds: float
    tokens_per_second: float
    prompt_tokens: int
    from_cache: bool = False

    def __str__(self):
        # Human-readable one-line summary of the generation run.
        rate = f"{self.tokens_per_second:.1f} tokens/s"
        return (
            f"Generated {self.tokens_generated} tokens "
            f"in {self.total_time_seconds:.2f}s ({rate})"
        )
class InferenceEngine:
    """
    High-performance inference engine for conversational AI.

    Features:
    - Streaming generation for instant first-token response
    - Response caching for common queries
    - Performance tracking and optimization
    - Conversation context management
    """

    def __init__(
        self,
        model_loader: Optional[ModelLoader] = None,
        config: Optional[Config] = None
    ):
        """
        Args:
            model_loader: Loader that owns the underlying llama-cpp model
                (defaults to the global loader).
            config: Runtime configuration (defaults to the global config).
        """
        self.config = config or get_config()
        self.model_loader = model_loader or get_loader()

        # Response cache: hash of (prompt tail, sampling params) -> payload dict.
        self._response_cache = LRUCache(
            max_size=self.config.performance.cache_size,
            ttl_seconds=self.config.performance.cache_ttl_seconds
        )
        self._cache_enabled = self.config.performance.enable_response_cache

        # Rolling performance statistics (tokens/s, latencies).
        self._perf_tracker = PerformanceTracker()

        # Generation defaults; per-call overrides are layered on top of these.
        self._default_params = {
            "max_tokens": self.config.model.max_new_tokens,
            "temperature": self.config.model.temperature,
            "top_p": self.config.model.top_p,
            "top_k": self.config.model.top_k,
            "repeat_penalty": self.config.model.repeat_penalty,
        }

        # NOTE(review): neither attribute is used inside this class today;
        # kept so any external code inspecting generation state keeps working.
        self._lock = threading.Lock()
        self._generating = False

    def generate(
        self,
        prompt: str,
        max_tokens: Optional[int] = None,
        temperature: Optional[float] = None,
        top_p: Optional[float] = None,
        top_k: Optional[int] = None,
        repeat_penalty: Optional[float] = None,
        stop_sequences: Optional[List[str]] = None,
        stream: bool = False,
    ) -> Union[GenerationResult, Iterator[str]]:
        """
        Generate text from prompt.

        Args:
            prompt: The input prompt
            max_tokens: Maximum tokens to generate
            temperature: Sampling temperature (0.0 = deterministic, 1.0 = creative)
            top_p: Nucleus sampling parameter
            top_k: Top-k sampling parameter
            repeat_penalty: Penalty for repeating tokens
            stop_sequences: Sequences that stop generation
            stream: If True, yields tokens as they're generated

        Returns:
            GenerationResult if stream=False, Iterator[str] if stream=True

        Raises:
            RuntimeError: If no model has been loaded yet.
        """
        if not self.model_loader.is_loaded():
            raise RuntimeError("Model not loaded. Call load_model() first.")

        # Layer explicit per-call overrides over the configured defaults.
        params = self._default_params.copy()
        if max_tokens is not None:
            params["max_tokens"] = max_tokens
        if temperature is not None:
            params["temperature"] = temperature
        if top_p is not None:
            params["top_p"] = top_p
        if top_k is not None:
            params["top_k"] = top_k
        if repeat_penalty is not None:
            params["repeat_penalty"] = repeat_penalty

        # Default stop sequences for TinyLlama / Zephyr format
        if stop_sequences is None:
            stop_sequences = ["</s>", "<|user|>", "<|end|>", "<|system|>", "\n\nUser:", "\n\nHuman:"]

        # Check cache (only meaningful for non-streaming requests).
        if not stream and self._cache_enabled:
            cache_key = self._make_cache_key(prompt, params)
            cached = self._response_cache.get(cache_key)
            if cached:
                logger.debug("Cache hit")
                # Report a nominal near-zero latency, but echo the throughput
                # that was measured when the entry was generated instead of
                # fabricating one from the fake 0.001s latency (the old
                # tokens/0.001 figure was wildly inflated and polluted stats).
                # Older cache entries may lack the field; fall back to 0.0.
                return GenerationResult(
                    text=cached["text"],
                    tokens_generated=cached["tokens"],
                    total_time_seconds=0.001,
                    first_token_time_seconds=0.001,
                    tokens_per_second=cached.get("tokens_per_second", 0.0),
                    prompt_tokens=cached["prompt_tokens"],
                    from_cache=True,
                )

        if stream:
            return self._generate_streaming(prompt, params, stop_sequences)
        return self._generate_complete(prompt, params, stop_sequences)

    def _generate_complete(
        self,
        prompt: str,
        params: Dict[str, Any],
        stop_sequences: List[str]
    ) -> GenerationResult:
        """Generate a complete response (non-streaming) and record metrics."""
        model = self.model_loader.get_model()

        timer = Timer()
        timer.start()
        response = model(
            prompt,
            max_tokens=params["max_tokens"],
            temperature=params["temperature"],
            top_p=params["top_p"],
            top_k=params["top_k"],
            repeat_penalty=params["repeat_penalty"],
            stop=stop_sequences,
        )
        timer.stop()

        # Extract and normalize the generated text.
        text = clean_response(response["choices"][0]["text"])

        # Token accounting as reported by the llama-cpp completion response.
        tokens_generated = response["usage"]["completion_tokens"]
        prompt_tokens = response["usage"]["prompt_tokens"]
        tokens_per_second = tokens_generated / timer.elapsed if timer.elapsed > 0 else 0

        self._perf_tracker.record("tokens_per_second", tokens_per_second)
        self._perf_tracker.record("generation_time", timer.elapsed)

        # Cache the response together with its measured throughput so that a
        # later cache hit can report honest metrics.
        if self._cache_enabled:
            cache_key = self._make_cache_key(prompt, params)
            self._response_cache.set(cache_key, {
                "text": text,
                "tokens": tokens_generated,
                "prompt_tokens": prompt_tokens,
                "tokens_per_second": tokens_per_second,
            })

        return GenerationResult(
            text=text,
            tokens_generated=tokens_generated,
            total_time_seconds=timer.elapsed,
            # First-token latency is not observable without streaming; this is
            # a rough estimate only.
            first_token_time_seconds=timer.elapsed * 0.1,
            tokens_per_second=tokens_per_second,
            prompt_tokens=prompt_tokens,
        )

    def _generate_streaming(
        self,
        prompt: str,
        params: Dict[str, Any],
        stop_sequences: List[str]
    ) -> Iterator[str]:
        """Generate a response token-by-token, yielding text fragments."""
        model = self.model_loader.get_model()

        start_time = time.perf_counter()
        first_token_time = None
        tokens_generated = 0

        for output in model(
            prompt,
            max_tokens=params["max_tokens"],
            temperature=params["temperature"],
            top_p=params["top_p"],
            top_k=params["top_k"],
            repeat_penalty=params["repeat_penalty"],
            stop=stop_sequences,
            stream=True,
        ):
            # Record time-to-first-token once.
            if first_token_time is None:
                first_token_time = time.perf_counter() - start_time
                self._perf_tracker.record("first_token_time", first_token_time)
            token = output["choices"][0]["text"]
            tokens_generated += 1
            yield token

        # Record final throughput once the stream is exhausted.
        total_time = time.perf_counter() - start_time
        tokens_per_second = tokens_generated / total_time if total_time > 0 else 0
        self._perf_tracker.record("tokens_per_second", tokens_per_second)
        self._perf_tracker.record("generation_time", total_time)

    def chat_generate(
        self,
        messages: List[Dict[str, str]],
        stream: bool = False,
        thinking_mode: bool = False,
        **kwargs
    ) -> Union[GenerationResult, Iterator[str]]:
        """
        Generate using the proper chat completion API with message list.

        This uses llama-cpp's create_chat_completion which applies
        the model's built-in chat template correctly.
        Supports thinking mode for deeper reasoning.

        Args:
            messages: Chat messages as [{"role": ..., "content": ...}, ...].
            stream: If True, return an iterator of text fragments.
            thinking_mode: Allow more tokens / higher temperature for reasoning.
            **kwargs: Per-call sampling overrides (None values are ignored).

        Raises:
            RuntimeError: If no model has been loaded yet.
        """
        if not self.model_loader.is_loaded():
            raise RuntimeError("Model not loaded.")

        model = self.model_loader.get_model()
        params = self._default_params.copy()
        params.update({k: v for k, v in kwargs.items() if v is not None})

        # In thinking mode, allow more tokens for reasoning (capped at 256)
        # and raise the sampling temperature floor.
        if thinking_mode:
            params["max_tokens"] = min(int(params.get("max_tokens", 160) * 1.35), 256)
            params["temperature"] = max(params.get("temperature", 0.45), 0.55)

        if stream:
            return self._chat_generate_streaming(model, messages, params)
        return self._chat_generate_complete(model, messages, params)

    def _chat_generate_complete(
        self,
        model,
        messages: List[Dict[str, str]],
        params: Dict[str, Any]
    ) -> GenerationResult:
        """Non-streaming chat completion with metric recording."""
        timer = Timer()
        timer.start()
        response = model.create_chat_completion(
            messages=messages,
            max_tokens=params.get("max_tokens", 256),
            temperature=params.get("temperature", 0.7),
            top_p=params.get("top_p", 0.9),
            top_k=params.get("top_k", 40),
            repeat_penalty=params.get("repeat_penalty", 1.3),
        )
        timer.stop()

        # Chat completions return text under message.content (not "text").
        text = clean_response(response["choices"][0]["message"]["content"])
        tokens_generated = response["usage"]["completion_tokens"]
        prompt_tokens = response["usage"]["prompt_tokens"]
        tokens_per_second = tokens_generated / timer.elapsed if timer.elapsed > 0 else 0

        self._perf_tracker.record("tokens_per_second", tokens_per_second)
        self._perf_tracker.record("generation_time", timer.elapsed)

        return GenerationResult(
            text=text,
            tokens_generated=tokens_generated,
            total_time_seconds=timer.elapsed,
            # Estimate only; first-token latency is unmeasurable here.
            first_token_time_seconds=timer.elapsed * 0.1,
            tokens_per_second=tokens_per_second,
            prompt_tokens=prompt_tokens,
        )

    def _chat_generate_streaming(
        self,
        model,
        messages: List[Dict[str, str]],
        params: Dict[str, Any]
    ) -> Iterator[str]:
        """Streaming chat completion; yields content deltas as they arrive."""
        start_time = time.perf_counter()
        first_token_time = None
        tokens_generated = 0

        for chunk in model.create_chat_completion(
            messages=messages,
            max_tokens=params.get("max_tokens", 256),
            temperature=params.get("temperature", 0.7),
            top_p=params.get("top_p", 0.9),
            top_k=params.get("top_k", 40),
            repeat_penalty=params.get("repeat_penalty", 1.3),
            stream=True,
        ):
            if first_token_time is None:
                first_token_time = time.perf_counter() - start_time
                self._perf_tracker.record("first_token_time", first_token_time)
            # Streaming chunks carry incremental text in choices[0].delta.content;
            # role-only or empty deltas are skipped.
            delta = chunk["choices"][0].get("delta", {})
            token = delta.get("content", "")
            if token:
                tokens_generated += 1
                yield token

        total_time = time.perf_counter() - start_time
        if total_time > 0:
            self._perf_tracker.record("tokens_per_second", tokens_generated / total_time)
            self._perf_tracker.record("generation_time", total_time)

    def chat(
        self,
        message: str,
        conversation: Optional[Conversation] = None,
        stream: bool = True,
        **kwargs
    ) -> Union[GenerationResult, Iterator[str]]:
        """
        Have a conversation with the AI.

        Uses create_chat_completion for proper template handling. The user
        message is appended to the conversation before generation and the
        assistant reply is appended after it completes.

        Args:
            message: The user's message text.
            conversation: Conversation to continue (a new one is created
                if omitted).
            stream: If True, return an iterator of text fragments.
            **kwargs: Forwarded to chat_generate.
        """
        if conversation is None:
            conversation = create_conversation()

        conversation.add_user_message(message)
        messages = conversation.get_chat_messages()

        if stream:
            def stream_and_collect():
                # Yield tokens to the caller while accumulating the full reply
                # so the conversation history can be updated at the end.
                collected = []
                for token in self.chat_generate(messages, stream=True, **kwargs):
                    collected.append(token)
                    yield token
                full_response = "".join(collected)
                conversation.add_assistant_message(clean_response(full_response))
            return stream_and_collect()

        result = self.chat_generate(messages, stream=False, **kwargs)
        conversation.add_assistant_message(result.text)
        return result

    def _make_cache_key(self, prompt: str, params: Dict[str, Any]) -> str:
        """Create a deterministic cache key from prompt and parameters.

        Only the last 500 characters of the prompt are hashed: for chat
        prompts the tail (most recent turns) is what distinguishes requests.
        MD5 is used purely as a fast non-cryptographic fingerprint.
        """
        import hashlib
        import json
        key_data = {
            "prompt": prompt[-500:],  # Use last 500 chars for uniqueness
            "params": params,
        }
        key_str = json.dumps(key_data, sort_keys=True)
        return hashlib.md5(key_str.encode()).hexdigest()

    def get_stats(self) -> Dict[str, Any]:
        """Get performance statistics (perf tracker, cache, memory, model)."""
        stats = {
            "performance": self._perf_tracker.get_all_stats(),
            "cache": self._response_cache.stats(),
            "memory_usage_gb": get_memory_usage_gb(),
        }
        if self.model_loader.model_info:
            stats["model"] = {
                "name": self.model_loader.model_info.name,
                "model_id": self.model_loader.model_info.model_id,
                "load_time": self.model_loader.model_info.load_time_seconds,
            }
        return stats

    def clear_cache(self):
        """Clear the response cache."""
        self._response_cache.clear()
        logger.info("Cache cleared")

    def set_cache_enabled(self, enabled: bool):
        """Enable or disable response caching."""
        self._cache_enabled = enabled
        logger.info(f"Cache {'enabled' if enabled else 'disabled'}")
class ChatBot:
    """
    High-level chatbot interface combining all components.
    This is the main class users will interact with.
    """

    def __init__(self, config: Optional[Config] = None):
        """Wire together the loader, inference engine, and conversation state."""
        self.config = config or get_config()
        self.model_loader = ModelLoader(self.config)
        self.engine = InferenceEngine(self.model_loader, self.config)
        self.conversation = Conversation(config=self.config)
        self._initialized = False
        self._warmup_done = False

    def _require_ready(self) -> None:
        """Guard helper: raise unless initialize() has completed."""
        if not self._initialized:
            raise RuntimeError("ChatBot not initialized. Call initialize() first.")

    def initialize(
        self,
        model_id: Optional[str] = None,
        auto_download: bool = True,
        warmup: bool = True,
        progress_callback: Optional[Callable[[float, str], None]] = None
    ) -> "ChatBot":
        """
        Initialize the chatbot (load model).

        Args:
            model_id: Specific model to load (or auto-select)
            auto_download: Download model if not present
            warmup: Run warmup inference for faster first response
            progress_callback: Progress update callback

        Returns:
            self for chaining
        """
        self.model_loader.load(
            model_id=model_id,
            auto_download=auto_download,
            progress_callback=progress_callback
        )
        self._initialized = True

        # Optional warmup pass so the first real request responds quickly.
        if warmup and self.config.performance.enable_warmup:
            if progress_callback:
                progress_callback(0.9, "Warming up model...")
            self.model_loader.warmup(self.config.performance.warmup_prompt)
            self._warmup_done = True

        if progress_callback:
            progress_callback(1.0, "Ready!")
        return self

    def chat(
        self,
        message: str,
        stream: bool = True,
        **kwargs
    ) -> Union[str, Iterator[str]]:
        """
        Send a message and get a response.

        Args:
            message: User message
            stream: Stream response tokens
            **kwargs: Additional generation parameters

        Returns:
            Response string or streaming iterator
        """
        self._require_ready()
        if not stream:
            return self.engine.chat(message, self.conversation, stream=False, **kwargs).text
        return self.engine.chat(message, self.conversation, stream=True, **kwargs)

    def chat_simple(self, message: str) -> str:
        """Simple non-streaming chat. Returns complete response."""
        self._require_ready()
        return self.engine.chat(message, self.conversation, stream=False).text

    def reset_conversation(self):
        """Reset the conversation history."""
        self.conversation.reset()

    def get_context(self) -> str:
        """Get current conversation context summary."""
        return self.conversation.get_context_summary()

    def get_history(self) -> List[Dict[str, str]]:
        """Get conversation history."""
        return self.conversation.get_chat_messages()

    def get_stats(self) -> Dict[str, Any]:
        """Get performance statistics."""
        return self.engine.get_stats()

    @property
    def is_ready(self) -> bool:
        """Check if chatbot is ready for conversation."""
        return self._initialized
# =============================================================================
# CONVENIENCE FUNCTIONS
# =============================================================================
# Lazily-created singleton used by the convenience helpers below.
_global_chatbot: Optional[ChatBot] = None


def get_chatbot() -> ChatBot:
    """Return the process-wide ChatBot, creating it on first use."""
    global _global_chatbot
    if _global_chatbot is not None:
        return _global_chatbot
    _global_chatbot = ChatBot()
    return _global_chatbot
def quick_chat(message: str, stream: bool = False) -> Union[str, Iterator[str]]:
    """
    Quick chat function - initializes on first use.

    Args:
        message: User message
        stream: Stream response

    Returns:
        Response string or iterator
    """
    bot = get_chatbot()
    # Lazy initialization: the first call pays the model-load cost.
    if not bot.is_ready:
        print("Initializing chatbot (first run)...")
        bot.initialize()
    return bot.chat(message, stream=stream)
if __name__ == "__main__":
    # Manual smoke test: boot the chatbot, run a short scripted conversation
    # with streaming output, then dump the collected performance stats.
    from utils import print_banner, print_system_status

    print_banner()
    print_system_status()
    print("🚀 Initializing ChatBot...\n")

    def report_progress(pct: float, msg: str):
        print(f" [{pct*100:5.1f}%] {msg}")

    bot = ChatBot()
    bot.initialize(progress_callback=report_progress)
    print("\n✓ ChatBot ready!")
    print("\n" + "=" * 50)
    print("Testing conversation...\n")

    # Scripted multi-turn conversation exercising context carry-over.
    for user_msg in (
        "Hello! How are you today?",
        "I'm learning Python. Any tips?",
        "What about for data science specifically?",
        "Thanks! That's helpful.",
    ):
        print(f"You: {user_msg}")
        print("AI: ", end="", flush=True)
        for piece in bot.chat(user_msg, stream=True):
            print(piece, end="", flush=True)
        print("\n")

    print("\n" + "=" * 50)
    print("\n📊 Performance Stats:")
    stats = bot.get_stats()
    perf = stats.get("performance", {})
    if "tokens_per_second" in perf:
        print(f" Tokens/second: {perf['tokens_per_second'].get('avg', 0):.1f}")
    if "generation_time" in perf:
        print(f" Avg generation time: {perf['generation_time'].get('avg', 0):.2f}s")
    print(f" Memory usage: {stats.get('memory_usage_gb', 0):.2f} GB")