"""
Inference Engine for the Tiny Conversational AI.

Handles optimized text generation with streaming, caching, and performance tuning.
"""

import hashlib
import json
import threading
import time
from dataclasses import dataclass
from queue import Queue
from typing import Optional, Iterator, Dict, Any, List, Callable, Union

from config import Config, get_config
from utils import (
    get_logger,
    Timer,
    PerformanceTracker,
    LRUCache,
    clean_response,
    get_memory_usage_gb,
)
from model_loader import ModelLoader, get_loader
from conversation import Conversation, create_conversation

logger = get_logger(__name__)

# Default stop sequences for the TinyLlama / Zephyr prompt format.
# "</s>" is the end-of-sequence marker. Every entry must be non-empty: an empty
# string is a substring of any text, so a substring-based stop check would end
# generation immediately.
DEFAULT_STOP_SEQUENCES = (
    "</s>",
    "<|user|>",
    "<|end|>",
    "<|system|>",
    "\n\nUser:",
    "\n\nHuman:",
)


@dataclass
class GenerationResult:
    """Result of a single text-generation call (text plus timing/token metrics)."""

    text: str                          # cleaned generated text
    tokens_generated: int              # completion tokens produced
    total_time_seconds: float          # wall-clock generation time
    first_token_time_seconds: float    # estimated for non-streaming calls
    tokens_per_second: float
    prompt_tokens: int                 # tokens in the prompt, as reported by the model
    from_cache: bool = False           # True when served from the response cache

    def __str__(self):
        return (
            f"Generated {self.tokens_generated} tokens in {self.total_time_seconds:.2f}s "
            f"({self.tokens_per_second:.1f} tokens/s)"
        )


class InferenceEngine:
    """
    High-performance inference engine for conversational AI.

    Features:
    - Streaming generation for instant first-token response
    - Response caching for common queries
    - Performance tracking and optimization
    - Conversation context management
    """

    def __init__(
        self,
        model_loader: Optional[ModelLoader] = None,
        config: Optional[Config] = None
    ):
        """
        Create an engine bound to a model loader and configuration.

        Args:
            model_loader: Loader providing the model; the global loader is used if None.
            config: Configuration; the global config is used if None.
        """
        self.config = config or get_config()
        self.model_loader = model_loader or get_loader()

        # Caching (sized and expired according to config.performance)
        self._response_cache = LRUCache(
            max_size=self.config.performance.cache_size,
            ttl_seconds=self.config.performance.cache_ttl_seconds
        )
        self._cache_enabled = self.config.performance.enable_response_cache

        # Performance tracking
        self._perf_tracker = PerformanceTracker()

        # Generation defaults; per-call arguments override these.
        self._default_params = {
            "max_tokens": self.config.model.max_new_tokens,
            "temperature": self.config.model.temperature,
            "top_p": self.config.model.top_p,
            "top_k": self.config.model.top_k,
            "repeat_penalty": self.config.model.repeat_penalty,
        }

        # NOTE(review): neither attribute is used by any method in this module,
        # so generation calls are not serialized by this class.
        self._lock = threading.Lock()
        self._generating = False

    def generate(
        self,
        prompt: str,
        max_tokens: Optional[int] = None,
        temperature: Optional[float] = None,
        top_p: Optional[float] = None,
        top_k: Optional[int] = None,
        repeat_penalty: Optional[float] = None,
        stop_sequences: Optional[List[str]] = None,
        stream: bool = False,
    ) -> Union[GenerationResult, Iterator[str]]:
        """
        Generate text from a raw prompt.

        Args:
            prompt: The input prompt
            max_tokens: Maximum tokens to generate
            temperature: Sampling temperature (0.0 = deterministic, 1.0 = creative)
            top_p: Nucleus sampling parameter
            top_k: Top-k sampling parameter
            repeat_penalty: Penalty for repeating tokens
            stop_sequences: Sequences that stop generation. None selects
                DEFAULT_STOP_SEQUENCES. Empty strings are ignored.
            stream: If True, yields tokens as they're generated

        Returns:
            GenerationResult if stream=False, Iterator[str] if stream=True.
            Non-streaming results may be served from the response cache.

        Raises:
            RuntimeError: If the model has not been loaded.
        """
        if not self.model_loader.is_loaded():
            raise RuntimeError("Model not loaded. Call load_model() first.")

        # Start from configured defaults and apply only explicitly given overrides.
        overrides = {
            "max_tokens": max_tokens,
            "temperature": temperature,
            "top_p": top_p,
            "top_k": top_k,
            "repeat_penalty": repeat_penalty,
        }
        params = self._default_params.copy()
        params.update({name: value for name, value in overrides.items() if value is not None})

        if stop_sequences is None:
            stop_sequences = list(DEFAULT_STOP_SEQUENCES)
        else:
            # An empty stop string would match immediately and truncate output.
            stop_sequences = [seq for seq in stop_sequences if seq]

        # Check cache (only for non-streaming)
        if not stream and self._cache_enabled:
            cache_key = self._make_cache_key(prompt, params, stop_sequences)
            cached = self._response_cache.get(cache_key)
            if cached:
                logger.debug("Cache hit")
                # Timings are nominal placeholders for cached responses.
                return GenerationResult(
                    text=cached["text"],
                    tokens_generated=cached["tokens"],
                    total_time_seconds=0.001,
                    first_token_time_seconds=0.001,
                    tokens_per_second=cached["tokens"] / 0.001,
                    prompt_tokens=cached["prompt_tokens"],
                    from_cache=True,
                )

        if stream:
            return self._generate_streaming(prompt, params, stop_sequences)
        return self._generate_complete(prompt, params, stop_sequences)

    def _generate_complete(
        self,
        prompt: str,
        params: Dict[str, Any],
        stop_sequences: List[str]
    ) -> GenerationResult:
        """
        Run a blocking completion and return cleaned text with metrics.

        Records tokens/s and generation time, and stores the result in the
        response cache when caching is enabled.
        """
        model = self.model_loader.get_model()

        timer = Timer()
        timer.start()
        response = model(
            prompt,
            max_tokens=params["max_tokens"],
            temperature=params["temperature"],
            top_p=params["top_p"],
            top_k=params["top_k"],
            repeat_penalty=params["repeat_penalty"],
            stop=stop_sequences,
        )
        timer.stop()

        text = clean_response(response["choices"][0]["text"])

        usage = response["usage"]
        tokens_generated = usage["completion_tokens"]
        prompt_tokens = usage["prompt_tokens"]

        elapsed = timer.elapsed
        tokens_per_second = tokens_generated / elapsed if elapsed > 0 else 0

        self._perf_tracker.record("tokens_per_second", tokens_per_second)
        self._perf_tracker.record("generation_time", elapsed)

        if self._cache_enabled:
            cache_key = self._make_cache_key(prompt, params, stop_sequences)
            self._response_cache.set(cache_key, {
                "text": text,
                "tokens": tokens_generated,
                "prompt_tokens": prompt_tokens,
            })

        return GenerationResult(
            text=text,
            tokens_generated=tokens_generated,
            total_time_seconds=elapsed,
            # A blocking call cannot observe the first token; this is a rough estimate.
            first_token_time_seconds=elapsed * 0.1,
            tokens_per_second=tokens_per_second,
            prompt_tokens=prompt_tokens,
        )

    def _generate_streaming(
        self,
        prompt: str,
        params: Dict[str, Any],
        stop_sequences: List[str]
    ) -> Iterator[str]:
        """
        Yield raw token text as the model produces it.

        Each streamed chunk is counted as one token. First-token latency is
        recorded on the first chunk. Totals are recorded only if the iterator
        is run to completion.
        """
        model = self.model_loader.get_model()

        start_time = time.perf_counter()
        first_token_time = None
        tokens_generated = 0

        for output in model(
            prompt,
            max_tokens=params["max_tokens"],
            temperature=params["temperature"],
            top_p=params["top_p"],
            top_k=params["top_k"],
            repeat_penalty=params["repeat_penalty"],
            stop=stop_sequences,
            stream=True,
        ):
            if first_token_time is None:
                first_token_time = time.perf_counter() - start_time
                self._perf_tracker.record("first_token_time", first_token_time)

            token = output["choices"][0]["text"]
            tokens_generated += 1
            yield token

        total_time = time.perf_counter() - start_time
        tokens_per_second = tokens_generated / total_time if total_time > 0 else 0
        self._perf_tracker.record("tokens_per_second", tokens_per_second)
        self._perf_tracker.record("generation_time", total_time)

    def chat_generate(
        self,
        messages: List[Dict[str, str]],
        stream: bool = False,
        thinking_mode: bool = False,
        **kwargs
    ) -> Union[GenerationResult, Iterator[str]]:
        """
        Generate a reply to a message list via the model's chat-completion API.

        Uses ``create_chat_completion`` so the model's built-in chat template is
        applied.

        Args:
            messages: Chat messages to send to the model.
            stream: If True, return an iterator of text tokens.
            thinking_mode: Raise the token budget by ~35% (capped at 256, but
                never below the requested budget) and use a temperature of at
                least 0.55.
            **kwargs: Generation parameter overrides; None values are ignored.

        Returns:
            GenerationResult if stream=False, Iterator[str] if stream=True.

        Raises:
            RuntimeError: If the model has not been loaded.
        """
        if not self.model_loader.is_loaded():
            raise RuntimeError("Model not loaded.")

        model = self.model_loader.get_model()
        params = self._default_params.copy()
        params.update({k: v for k, v in kwargs.items() if v is not None})

        if thinking_mode:
            base_tokens = params.get("max_tokens", 160)
            # Allow more room for reasoning, but never shrink a budget that is
            # already above the 256 cap.
            params["max_tokens"] = max(base_tokens, min(int(base_tokens * 1.35), 256))
            params["temperature"] = max(params.get("temperature", 0.45), 0.55)

        if stream:
            return self._chat_generate_streaming(model, messages, params)
        return self._chat_generate_complete(model, messages, params)

    def _chat_generate_complete(
        self,
        model,
        messages: List[Dict[str, str]],
        params: Dict[str, Any]
    ) -> GenerationResult:
        """Run a blocking chat completion and return cleaned text with metrics."""
        timer = Timer()
        timer.start()
        response = model.create_chat_completion(
            messages=messages,
            max_tokens=params.get("max_tokens", 256),
            temperature=params.get("temperature", 0.7),
            top_p=params.get("top_p", 0.9),
            top_k=params.get("top_k", 40),
            repeat_penalty=params.get("repeat_penalty", 1.3),
        )
        timer.stop()

        # The message content may be missing/None; treat that as an empty reply.
        content = response["choices"][0]["message"].get("content") or ""
        text = clean_response(content)

        usage = response["usage"]
        tokens_generated = usage["completion_tokens"]
        prompt_tokens = usage["prompt_tokens"]

        elapsed = timer.elapsed
        tokens_per_second = tokens_generated / elapsed if elapsed > 0 else 0
        self._perf_tracker.record("tokens_per_second", tokens_per_second)
        self._perf_tracker.record("generation_time", elapsed)

        return GenerationResult(
            text=text,
            tokens_generated=tokens_generated,
            total_time_seconds=elapsed,
            # A blocking call cannot observe the first token; this is a rough estimate.
            first_token_time_seconds=elapsed * 0.1,
            tokens_per_second=tokens_per_second,
            prompt_tokens=prompt_tokens,
        )

    def _chat_generate_streaming(
        self,
        model,
        messages: List[Dict[str, str]],
        params: Dict[str, Any]
    ) -> Iterator[str]:
        """
        Yield text tokens from a streaming chat completion.

        Chunks whose delta carries no text are skipped. First-token latency is
        measured at the first chunk that actually contains text. Totals are
        recorded only if the iterator is run to completion.
        """
        start_time = time.perf_counter()
        first_token_time = None
        tokens_generated = 0

        for chunk in model.create_chat_completion(
            messages=messages,
            max_tokens=params.get("max_tokens", 256),
            temperature=params.get("temperature", 0.7),
            top_p=params.get("top_p", 0.9),
            top_k=params.get("top_k", 40),
            repeat_penalty=params.get("repeat_penalty", 1.3),
            stream=True,
        ):
            delta = chunk["choices"][0].get("delta", {})
            token = delta.get("content", "")
            if not token:
                # e.g. a delta without content (presumably role/finish chunks)
                continue

            if first_token_time is None:
                first_token_time = time.perf_counter() - start_time
                self._perf_tracker.record("first_token_time", first_token_time)

            tokens_generated += 1
            yield token

        total_time = time.perf_counter() - start_time
        if total_time > 0:
            self._perf_tracker.record("tokens_per_second", tokens_generated / total_time)
        self._perf_tracker.record("generation_time", total_time)

    def chat(
        self,
        message: str,
        conversation: Optional[Conversation] = None,
        stream: bool = True,
        **kwargs
    ) -> Union[GenerationResult, Iterator[str]]:
        """
        Run one conversation turn with the AI.

        Appends ``message`` as a user turn and builds the chat message list
        from ``conversation``. It then generates a reply via ``chat_generate``
        and appends the cleaned reply as the assistant turn.

        Args:
            message: The user's message.
            conversation: Conversation to extend; a new one is created if None.
            stream: If True, return an iterator of tokens. The assistant turn is
                recorded when iteration finishes, or with the partial text if
                the consumer stops early.
            **kwargs: Extra generation parameters forwarded to chat_generate.

        Returns:
            Iterator[str] when streaming, otherwise a GenerationResult.
        """
        if conversation is None:
            conversation = create_conversation()

        conversation.add_user_message(message)
        messages = conversation.get_chat_messages()

        if stream:
            def stream_and_collect():
                collected: List[str] = []
                finished = False
                try:
                    for token in self.chat_generate(messages, stream=True, **kwargs):
                        collected.append(token)
                        yield token
                    finished = True
                finally:
                    # Also runs when the consumer closes the iterator early, so
                    # the user turn added above does not end up without a reply.
                    if finished or collected:
                        conversation.add_assistant_message(clean_response("".join(collected)))

            return stream_and_collect()

        result = self.chat_generate(messages, stream=False, **kwargs)
        conversation.add_assistant_message(result.text)
        return result

    def _make_cache_key(
        self,
        prompt: str,
        params: Dict[str, Any],
        stop_sequences: Optional[List[str]] = None,
    ) -> str:
        """
        Build a deterministic cache key for a completion request.

        The full prompt is hashed, not just a suffix, so prompts that share
        their trailing characters do not collide. Stop sequences are included
        because they change the generated text.
        """
        key_data = {
            "prompt": prompt,
            "params": params,
            "stop": list(stop_sequences) if stop_sequences is not None else None,
        }
        key_str = json.dumps(key_data, sort_keys=True)
        return hashlib.sha256(key_str.encode("utf-8")).hexdigest()

    def get_stats(self) -> Dict[str, Any]:
        """Return performance, cache, memory and (if loaded) model statistics."""
        stats = {
            "performance": self._perf_tracker.get_all_stats(),
            "cache": self._response_cache.stats(),
            "memory_usage_gb": get_memory_usage_gb(),
        }

        model_info = self.model_loader.model_info
        if model_info:
            stats["model"] = {
                "name": model_info.name,
                "model_id": model_info.model_id,
                "load_time": model_info.load_time_seconds,
            }

        return stats

    def clear_cache(self):
        """Clear the response cache."""
        self._response_cache.clear()
        logger.info("Cache cleared")

    def set_cache_enabled(self, enabled: bool):
        """Enable or disable response caching for non-streaming generate() calls."""
        self._cache_enabled = enabled
        logger.info(f"Cache {'enabled' if enabled else 'disabled'}")


class ChatBot:
    """
    High-level chatbot interface combining all components.

    This is the main class users will interact with. It owns a model loader,
    an inference engine and a single running conversation.
    """

    def __init__(self, config: Optional[Config] = None):
        """Create the components; the model is not loaded until initialize()."""
        self.config = config or get_config()
        self.model_loader = ModelLoader(self.config)
        self.engine = InferenceEngine(self.model_loader, self.config)
        self.conversation = Conversation(config=self.config)
        self._initialized = False
        self._warmup_done = False

    def initialize(
        self,
        model_id: Optional[str] = None,
        auto_download: bool = True,
        warmup: bool = True,
        progress_callback: Optional[Callable[[float, str], None]] = None
    ) -> "ChatBot":
        """
        Initialize the chatbot (load model).

        Args:
            model_id: Specific model to load (or auto-select)
            auto_download: Download model if not present
            warmup: Run warmup inference for faster first response (also
                requires config.performance.enable_warmup)
            progress_callback: Progress update callback receiving (fraction, message)

        Returns:
            self for chaining
        """
        self.model_loader.load(
            model_id=model_id,
            auto_download=auto_download,
            progress_callback=progress_callback,
        )
        self._initialized = True

        if warmup and self.config.performance.enable_warmup:
            if progress_callback:
                progress_callback(0.9, "Warming up model...")
            self.model_loader.warmup(self.config.performance.warmup_prompt)
            self._warmup_done = True

        if progress_callback:
            progress_callback(1.0, "Ready!")

        return self

    def chat(
        self,
        message: str,
        stream: bool = True,
        **kwargs
    ) -> Union[str, Iterator[str]]:
        """
        Send a message and get a response.

        Args:
            message: User message
            stream: Stream response tokens
            **kwargs: Additional generation parameters

        Returns:
            Response string or streaming iterator

        Raises:
            RuntimeError: If initialize() has not been called.
        """
        if not self._initialized:
            raise RuntimeError("ChatBot not initialized. Call initialize() first.")

        if stream:
            return self.engine.chat(message, self.conversation, stream=True, **kwargs)

        result = self.engine.chat(message, self.conversation, stream=False, **kwargs)
        return result.text

    def chat_simple(self, message: str) -> str:
        """Simple non-streaming chat. Returns complete response."""
        return self.chat(message, stream=False)

    def reset_conversation(self):
        """Reset the conversation history."""
        self.conversation.reset()

    def get_context(self) -> str:
        """Get current conversation context summary."""
        return self.conversation.get_context_summary()

    def get_history(self) -> List[Dict[str, str]]:
        """Get conversation history as chat messages."""
        return self.conversation.get_chat_messages()

    def get_stats(self) -> Dict[str, Any]:
        """Get performance statistics."""
        return self.engine.get_stats()

    @property
    def is_ready(self) -> bool:
        """Check if chatbot is ready for conversation."""
        return self._initialized


# =============================================================================
# CONVENIENCE FUNCTIONS
# =============================================================================

_global_chatbot: Optional[ChatBot] = None


def get_chatbot() -> ChatBot:
    """Get or create the global chatbot instance (not initialized automatically)."""
    global _global_chatbot
    if _global_chatbot is None:
        _global_chatbot = ChatBot()
    return _global_chatbot


def quick_chat(message: str, stream: bool = False) -> Union[str, Iterator[str]]:
    """
    Quick chat function - initializes on first use.

    Args:
        message: User message
        stream: Stream response

    Returns:
        Response string or iterator
    """
    chatbot = get_chatbot()
    if not chatbot.is_ready:
        print("Initializing chatbot (first run)...")
        chatbot.initialize()
    return chatbot.chat(message, stream=stream)


if __name__ == "__main__":
    # Manual smoke test: load the model and run a short scripted conversation.
    from utils import print_banner, print_system_status

    print_banner()
    print_system_status()

    print("🚀 Initializing ChatBot...\n")

    def progress(pct: float, msg: str):
        print(f" [{pct*100:5.1f}%] {msg}")

    chatbot = ChatBot()
    chatbot.initialize(progress_callback=progress)

    print("\n✓ ChatBot ready!")
    print("\n" + "=" * 50)
    print("Testing conversation...\n")

    test_messages = [
        "Hello! How are you today?",
        "I'm learning Python. Any tips?",
        "What about for data science specifically?",
        "Thanks! That's helpful.",
    ]

    for msg in test_messages:
        print(f"You: {msg}")
        print("AI: ", end="", flush=True)

        for token in chatbot.chat(msg, stream=True):
            print(token, end="", flush=True)

        print("\n")

    print("\n" + "=" * 50)
    print("\n📊 Performance Stats:")
    stats = chatbot.get_stats()

    if "performance" in stats:
        perf = stats["performance"]
        if "tokens_per_second" in perf:
            print(f" Tokens/second: {perf['tokens_per_second'].get('avg', 0):.1f}")
        if "generation_time" in perf:
            print(f" Avg generation time: {perf['generation_time'].get('avg', 0):.2f}s")

    print(f" Memory usage: {stats.get('memory_usage_gb', 0):.2f} GB")