"""
Inference Engine for the Tiny Conversational AI.

Handles optimized text generation with streaming, caching, and performance tuning.
"""
import hashlib
import json
import threading
import time
from typing import Optional, Iterator, Dict, Any, List, Callable, Union
from dataclasses import dataclass

from config import Config, get_config
from utils import (
    get_logger, Timer, PerformanceTracker, LRUCache,
    clean_response, get_memory_usage_gb
)
from model_loader import ModelLoader, get_loader
from conversation import Conversation, create_conversation

logger = get_logger(__name__)

@dataclass
class GenerationResult:
    """Result of text generation."""
    text: str
    tokens_generated: int
    total_time_seconds: float
    first_token_time_seconds: float
    tokens_per_second: float
    prompt_tokens: int
    from_cache: bool = False

    def __str__(self):
        return (
            f"Generated {self.tokens_generated} tokens in {self.total_time_seconds:.2f}s "
            f"({self.tokens_per_second:.1f} tokens/s)"
        )
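
# Illustrative sketch (hypothetical numbers) of what a populated result looks like:
#
#     result = GenerationResult(
#         text="Hi!", tokens_generated=12, total_time_seconds=0.60,
#         first_token_time_seconds=0.05, tokens_per_second=20.0, prompt_tokens=34,
#     )
#     print(result)  # -> Generated 12 tokens in 0.60s (20.0 tokens/s)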

class InferenceEngine:
    """
    High-performance inference engine for conversational AI.

    Features:
    - Streaming generation for instant first-token response
    - Response caching for common queries
    - Performance tracking and optimization
    - Conversation context management
    """

    def __init__(
        self,
        model_loader: Optional[ModelLoader] = None,
        config: Optional[Config] = None
    ):
        self.config = config or get_config()
        self.model_loader = model_loader or get_loader()

        # Caching
        self._response_cache = LRUCache(
            max_size=self.config.performance.cache_size,
            ttl_seconds=self.config.performance.cache_ttl_seconds
        )
        self._cache_enabled = self.config.performance.enable_response_cache

        # Performance tracking
        self._perf_tracker = PerformanceTracker()

        # Generation settings
        self._default_params = {
            "max_tokens": self.config.model.max_new_tokens,
            "temperature": self.config.model.temperature,
            "top_p": self.config.model.top_p,
            "top_k": self.config.model.top_k,
            "repeat_penalty": self.config.model.repeat_penalty,
        }

        # Thread safety
        self._lock = threading.Lock()
        self._generating = False

    def generate(
        self,
        prompt: str,
        max_tokens: Optional[int] = None,
        temperature: Optional[float] = None,
        top_p: Optional[float] = None,
        top_k: Optional[int] = None,
        repeat_penalty: Optional[float] = None,
        stop_sequences: Optional[List[str]] = None,
        stream: bool = False,
    ) -> Union[GenerationResult, Iterator[str]]:
        """
        Generate text from a prompt.

        Args:
            prompt: The input prompt
            max_tokens: Maximum tokens to generate
            temperature: Sampling temperature (0.0 = deterministic, 1.0 = creative)
            top_p: Nucleus sampling parameter
            top_k: Top-k sampling parameter
            repeat_penalty: Penalty for repeating tokens
            stop_sequences: Sequences that stop generation
            stream: If True, yields tokens as they're generated

        Returns:
            GenerationResult if stream=False, Iterator[str] if stream=True
        """
        # Check if model is loaded
        if not self.model_loader.is_loaded():
            raise RuntimeError("Model not loaded. Call load_model() first.")

        # Build parameters
        params = self._default_params.copy()
        if max_tokens is not None:
            params["max_tokens"] = max_tokens
        if temperature is not None:
            params["temperature"] = temperature
        if top_p is not None:
            params["top_p"] = top_p
        if top_k is not None:
            params["top_k"] = top_k
        if repeat_penalty is not None:
            params["repeat_penalty"] = repeat_penalty

        # Default stop sequences for TinyLlama / Zephyr format
        if stop_sequences is None:
            stop_sequences = ["</s>", "<|user|>", "<|end|>", "<|system|>", "\n\nUser:", "\n\nHuman:"]

        # Check cache (only for non-streaming)
        if not stream and self._cache_enabled:
            cache_key = self._make_cache_key(prompt, params)
            cached = self._response_cache.get(cache_key)
            if cached:
                logger.debug("Cache hit")
                return GenerationResult(
                    text=cached["text"],
                    tokens_generated=cached["tokens"],
                    total_time_seconds=0.001,
                    first_token_time_seconds=0.001,
                    tokens_per_second=cached["tokens"] / 0.001,
                    prompt_tokens=cached["prompt_tokens"],
                    from_cache=True,
                )

        if stream:
            return self._generate_streaming(prompt, params, stop_sequences)
        else:
            return self._generate_complete(prompt, params, stop_sequences)
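
    # Usage sketch (assumes a model has already been loaded through the shared
    # ModelLoader; prompt text is illustrative):
    #
    #     engine = InferenceEngine()
    #     result = engine.generate("Q: What is 2 + 2?\nA:", max_tokens=16)
    #     print(result.text, f"{result.tokens_per_second:.1f} tok/s")
    #
    #     for token in engine.generate("Tell me a joke.", stream=True):
    #         print(token, end="", flush=True)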

    def _generate_complete(
        self,
        prompt: str,
        params: Dict[str, Any],
        stop_sequences: List[str]
    ) -> GenerationResult:
        """Generate a complete response (non-streaming)."""
        model = self.model_loader.get_model()

        timer = Timer()
        timer.start()

        # Generate
        response = model(
            prompt,
            max_tokens=params["max_tokens"],
            temperature=params["temperature"],
            top_p=params["top_p"],
            top_k=params["top_k"],
            repeat_penalty=params["repeat_penalty"],
            stop=stop_sequences,
        )

        timer.stop()

        # Extract text
        text = response["choices"][0]["text"]
        text = clean_response(text)

        # Count tokens
        tokens_generated = response["usage"]["completion_tokens"]
        prompt_tokens = response["usage"]["prompt_tokens"]

        # Calculate metrics
        tokens_per_second = tokens_generated / timer.elapsed if timer.elapsed > 0 else 0

        # Track performance
        self._perf_tracker.record("tokens_per_second", tokens_per_second)
        self._perf_tracker.record("generation_time", timer.elapsed)

        # Cache response
        if self._cache_enabled:
            cache_key = self._make_cache_key(prompt, params)
            self._response_cache.set(cache_key, {
                "text": text,
                "tokens": tokens_generated,
                "prompt_tokens": prompt_tokens,
            })

        return GenerationResult(
            text=text,
            tokens_generated=tokens_generated,
            total_time_seconds=timer.elapsed,
            first_token_time_seconds=timer.elapsed * 0.1,  # Rough estimate; not measured on this path
            tokens_per_second=tokens_per_second,
            prompt_tokens=prompt_tokens,
        )

    def _generate_streaming(
        self,
        prompt: str,
        params: Dict[str, Any],
        stop_sequences: List[str]
    ) -> Iterator[str]:
        """Generate a response with streaming."""
        model = self.model_loader.get_model()

        start_time = time.perf_counter()
        first_token_time = None
        tokens_generated = 0

        # Stream generation
        for output in model(
            prompt,
            max_tokens=params["max_tokens"],
            temperature=params["temperature"],
            top_p=params["top_p"],
            top_k=params["top_k"],
            repeat_penalty=params["repeat_penalty"],
            stop=stop_sequences,
            stream=True,
        ):
            # Record first token time
            if first_token_time is None:
                first_token_time = time.perf_counter() - start_time
                self._perf_tracker.record("first_token_time", first_token_time)

            # Extract token text
            token = output["choices"][0]["text"]
            tokens_generated += 1

            yield token

        # Record final metrics
        total_time = time.perf_counter() - start_time
        tokens_per_second = tokens_generated / total_time if total_time > 0 else 0
        self._perf_tracker.record("tokens_per_second", tokens_per_second)
        self._perf_tracker.record("generation_time", total_time)

    def chat_generate(
        self,
        messages: List[Dict[str, str]],
        stream: bool = False,
        thinking_mode: bool = False,
        **kwargs
    ) -> Union[GenerationResult, Iterator[str]]:
        """
        Generate using the proper chat completion API with a message list.

        This uses llama-cpp's create_chat_completion, which applies
        the model's built-in chat template correctly.
        Supports thinking mode for deeper reasoning.
        """
        if not self.model_loader.is_loaded():
            raise RuntimeError("Model not loaded.")

        model = self.model_loader.get_model()
        params = self._default_params.copy()
        params.update({k: v for k, v in kwargs.items() if v is not None})

        # In thinking mode, allow more tokens for reasoning
        if thinking_mode:
            params["max_tokens"] = min(int(params.get("max_tokens", 160) * 1.35), 256)
            params["temperature"] = max(params.get("temperature", 0.45), 0.55)

        if stream:
            return self._chat_generate_streaming(model, messages, params)
        else:
            return self._chat_generate_complete(model, messages, params)
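
    # Example message list (OpenAI-style role/content dicts, which is what
    # llama-cpp-python's create_chat_completion expects):
    #
    #     messages = [
    #         {"role": "system", "content": "You are a concise assistant."},
    #         {"role": "user", "content": "Hi there!"},
    #     ]
    #     result = engine.chat_generate(messages, stream=False)
    #     print(result.text)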

    def _chat_generate_complete(
        self,
        model,
        messages: List[Dict[str, str]],
        params: Dict[str, Any]
    ) -> GenerationResult:
        """Non-streaming chat completion."""
        timer = Timer()
        timer.start()

        response = model.create_chat_completion(
            messages=messages,
            max_tokens=params.get("max_tokens", 256),
            temperature=params.get("temperature", 0.7),
            top_p=params.get("top_p", 0.9),
            top_k=params.get("top_k", 40),
            repeat_penalty=params.get("repeat_penalty", 1.3),
        )

        timer.stop()

        text = response["choices"][0]["message"]["content"]
        text = clean_response(text)

        tokens_generated = response["usage"]["completion_tokens"]
        prompt_tokens = response["usage"]["prompt_tokens"]
        tokens_per_second = tokens_generated / timer.elapsed if timer.elapsed > 0 else 0

        self._perf_tracker.record("tokens_per_second", tokens_per_second)
        self._perf_tracker.record("generation_time", timer.elapsed)

        return GenerationResult(
            text=text,
            tokens_generated=tokens_generated,
            total_time_seconds=timer.elapsed,
            first_token_time_seconds=timer.elapsed * 0.1,  # Rough estimate; not measured on this path
            tokens_per_second=tokens_per_second,
            prompt_tokens=prompt_tokens,
        )

    def _chat_generate_streaming(
        self,
        model,
        messages: List[Dict[str, str]],
        params: Dict[str, Any]
    ) -> Iterator[str]:
        """Streaming chat completion."""
        start_time = time.perf_counter()
        first_token_time = None
        tokens_generated = 0

        for chunk in model.create_chat_completion(
            messages=messages,
            max_tokens=params.get("max_tokens", 256),
            temperature=params.get("temperature", 0.7),
            top_p=params.get("top_p", 0.9),
            top_k=params.get("top_k", 40),
            repeat_penalty=params.get("repeat_penalty", 1.3),
            stream=True,
        ):
            if first_token_time is None:
                first_token_time = time.perf_counter() - start_time
                self._perf_tracker.record("first_token_time", first_token_time)

            # Streamed chat chunks carry a "delta" dict; the first chunk holds only
            # the role, so skip chunks without content.
            delta = chunk["choices"][0].get("delta", {})
            token = delta.get("content", "")
            if token:
                tokens_generated += 1
                yield token

        total_time = time.perf_counter() - start_time
        if total_time > 0:
            self._perf_tracker.record("tokens_per_second", tokens_generated / total_time)
            self._perf_tracker.record("generation_time", total_time)

    def chat(
        self,
        message: str,
        conversation: Optional[Conversation] = None,
        stream: bool = True,
        **kwargs
    ) -> Union[GenerationResult, Iterator[str]]:
        """
        Have a conversation with the AI.

        Uses create_chat_completion for proper template handling.
        """
        if conversation is None:
            conversation = create_conversation()

        conversation.add_user_message(message)

        # Build messages list for the chat completion API
        messages = conversation.get_chat_messages()

        if stream:
            def stream_and_collect():
                collected = []
                for token in self.chat_generate(messages, stream=True, **kwargs):
                    collected.append(token)
                    yield token
                full_response = "".join(collected)
                conversation.add_assistant_message(clean_response(full_response))
            return stream_and_collect()
        else:
            result = self.chat_generate(messages, stream=False, **kwargs)
            conversation.add_assistant_message(result.text)
            return result
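
    # Note on the streaming path above: stream_and_collect() is a generator, so the
    # assistant turn is appended to the conversation only after the caller fully
    # consumes the stream. Abandoning the iterator midway leaves the turn unrecorded:
    #
    #     gen = engine.chat("Hello!", stream=True)
    #     text = "".join(gen)  # history is updated here, once the stream is exhausted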

    def _make_cache_key(self, prompt: str, params: Dict[str, Any]) -> str:
        """Create a cache key from prompt and parameters."""
        key_data = {
            "prompt": prompt[-500:],  # Use last 500 chars for uniqueness
            "params": params,
        }
        key_str = json.dumps(key_data, sort_keys=True)
        return hashlib.md5(key_str.encode()).hexdigest()
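
    # Caveat: only the last 500 characters of the prompt feed the key, so two long
    # prompts sharing an identical tail (and identical params) will collide and
    # return each other's cached response. MD5 is used here as a fast fingerprint,
    # not for security.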

    def get_stats(self) -> Dict[str, Any]:
        """Get performance statistics."""
        stats = {
            "performance": self._perf_tracker.get_all_stats(),
            "cache": self._response_cache.stats(),
            "memory_usage_gb": get_memory_usage_gb(),
        }

        if self.model_loader.model_info:
            stats["model"] = {
                "name": self.model_loader.model_info.name,
                "model_id": self.model_loader.model_info.model_id,
                "load_time": self.model_loader.model_info.load_time_seconds,
            }

        return stats

    def clear_cache(self):
        """Clear the response cache."""
        self._response_cache.clear()
        logger.info("Cache cleared")

    def set_cache_enabled(self, enabled: bool):
        """Enable or disable response caching."""
        self._cache_enabled = enabled
        logger.info(f"Cache {'enabled' if enabled else 'disabled'}")

class ChatBot:
    """
    High-level chatbot interface combining all components.

    This is the main class users will interact with.
    """

    def __init__(self, config: Optional[Config] = None):
        self.config = config or get_config()
        self.model_loader = ModelLoader(self.config)
        self.engine = InferenceEngine(self.model_loader, self.config)
        self.conversation = Conversation(config=self.config)
        self._initialized = False
        self._warmup_done = False

    def initialize(
        self,
        model_id: Optional[str] = None,
        auto_download: bool = True,
        warmup: bool = True,
        progress_callback: Optional[Callable[[float, str], None]] = None
    ) -> "ChatBot":
        """
        Initialize the chatbot (load the model).

        Args:
            model_id: Specific model to load (or auto-select)
            auto_download: Download model if not present
            warmup: Run warmup inference for a faster first response
            progress_callback: Progress update callback

        Returns:
            self, for chaining
        """
        # Load model
        self.model_loader.load(
            model_id=model_id,
            auto_download=auto_download,
            progress_callback=progress_callback
        )
        self._initialized = True

        # Warmup
        if warmup and self.config.performance.enable_warmup:
            if progress_callback:
                progress_callback(0.9, "Warming up model...")
            self.model_loader.warmup(self.config.performance.warmup_prompt)
            self._warmup_done = True

        if progress_callback:
            progress_callback(1.0, "Ready!")

        return self
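
    # Typical startup sketch (model_id omitted to let the loader auto-select):
    #
    #     bot = ChatBot().initialize(warmup=True)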

    def chat(
        self,
        message: str,
        stream: bool = True,
        **kwargs
    ) -> Union[str, Iterator[str]]:
        """
        Send a message and get a response.

        Args:
            message: User message
            stream: Stream response tokens
            **kwargs: Additional generation parameters

        Returns:
            Response string or streaming iterator
        """
        if not self._initialized:
            raise RuntimeError("ChatBot not initialized. Call initialize() first.")

        if stream:
            return self.engine.chat(message, self.conversation, stream=True, **kwargs)
        else:
            result = self.engine.chat(message, self.conversation, stream=False, **kwargs)
            return result.text
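
    # Streaming vs. non-streaming usage of an initialized bot:
    #
    #     for token in bot.chat("Hello!", stream=True):
    #         print(token, end="", flush=True)
    #
    #     reply = bot.chat("One more question...", stream=False)  # plain str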

    def chat_simple(self, message: str) -> str:
        """Simple non-streaming chat. Returns the complete response."""
        if not self._initialized:
            raise RuntimeError("ChatBot not initialized. Call initialize() first.")
        result = self.engine.chat(message, self.conversation, stream=False)
        return result.text

    def reset_conversation(self):
        """Reset the conversation history."""
        self.conversation.reset()

    def get_context(self) -> str:
        """Get a summary of the current conversation context."""
        return self.conversation.get_context_summary()

    def get_history(self) -> List[Dict[str, str]]:
        """Get the conversation history."""
        return self.conversation.get_chat_messages()

    def get_stats(self) -> Dict[str, Any]:
        """Get performance statistics."""
        return self.engine.get_stats()

    def is_ready(self) -> bool:
        """Check if the chatbot is ready for conversation."""
        return self._initialized

# =============================================================================
# CONVENIENCE FUNCTIONS
# =============================================================================

_global_chatbot: Optional[ChatBot] = None


def get_chatbot() -> ChatBot:
    """Get or create the global chatbot instance."""
    global _global_chatbot
    if _global_chatbot is None:
        _global_chatbot = ChatBot()
    return _global_chatbot

def quick_chat(message: str, stream: bool = False) -> Union[str, Iterator[str]]:
    """
    Quick chat function - initializes on first use.

    Args:
        message: User message
        stream: Stream response

    Returns:
        Response string or iterator
    """
    chatbot = get_chatbot()
    if not chatbot.is_ready():
        print("Initializing chatbot (first run)...")
        chatbot.initialize()
    return chatbot.chat(message, stream=stream)
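
# Example (the first call loads, and possibly downloads, the model, so it is slow):
#
#     print(quick_chat("What is the capital of France?"))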

if __name__ == "__main__":
    # Test inference engine
    from utils import print_banner, print_system_status

    print_banner()
    print_system_status()

    print("🚀 Initializing ChatBot...\n")

    def progress(pct: float, msg: str):
        print(f"  [{pct*100:5.1f}%] {msg}")

    chatbot = ChatBot()
    chatbot.initialize(progress_callback=progress)

    print("\n✓ ChatBot ready!")
    print("\n" + "=" * 50)
    print("Testing conversation...\n")

    # Test conversation
    test_messages = [
        "Hello! How are you today?",
        "I'm learning Python. Any tips?",
        "What about for data science specifically?",
        "Thanks! That's helpful.",
    ]

    for msg in test_messages:
        print(f"You: {msg}")
        print("AI: ", end="", flush=True)

        # Stream response
        for token in chatbot.chat(msg, stream=True):
            print(token, end="", flush=True)
        print("\n")

    # Show stats
    print("\n" + "=" * 50)
    print("\n📊 Performance Stats:")
    stats = chatbot.get_stats()
    if "performance" in stats:
        perf = stats["performance"]
        if "tokens_per_second" in perf:
            print(f"  Tokens/second: {perf['tokens_per_second'].get('avg', 0):.1f}")
        if "generation_time" in perf:
            print(f"  Avg generation time: {perf['generation_time'].get('avg', 0):.2f}s")
    print(f"  Memory usage: {stats.get('memory_usage_gb', 0):.2f} GB")