bilalnaveed committed on
Commit
2c64828
Β·
verified Β·
1 Parent(s): ac974cb

Upload utils.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. utils.py +564 -0
utils.py ADDED
@@ -0,0 +1,564 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Utility functions for the Tiny Conversational AI.
3
+ Helper functions for system detection, logging, caching, and more.
4
+ """
5
+
6
+ import os
7
+ import sys
8
+ import time
9
+ import json
10
+ import hashlib
11
+ import logging
12
+ import platform
13
+ import threading
14
+ from pathlib import Path
15
+ from typing import Optional, Dict, Any, Callable, List
16
+ from functools import wraps, lru_cache
17
+ from datetime import datetime, timedelta
18
+ from collections import OrderedDict
19
+ from contextlib import contextmanager
20
+
21
+ # Try to import psutil for memory monitoring
22
+ try:
23
+ import psutil
24
+ HAS_PSUTIL = True
25
+ except ImportError:
26
+ HAS_PSUTIL = False
27
+
28
+
29
+ # =============================================================================
30
+ # LOGGING
31
+ # =============================================================================
32
+
33
def setup_logging(
    log_level: str = "INFO",
    log_file: Optional[str] = None,
    log_format: Optional[str] = None
) -> logging.Logger:
    """Set up and configure the application logger.

    Args:
        log_level: Logging level name (e.g. "DEBUG", "INFO").
        log_file: Optional path; when given, logs are also written there.
        log_format: Optional format string; a sensible default is used
            when omitted.

    Returns:
        The configured "tiny_ai" logger.
    """
    if log_format is None:
        log_format = "%(asctime)s | %(levelname)-8s | %(name)s | %(message)s"

    logger = logging.getLogger("tiny_ai")
    logger.setLevel(getattr(logging, log_level.upper()))

    # Drop handlers attached by any previous call; otherwise repeated
    # setup_logging() calls stack handlers and every record prints twice.
    for handler in list(logger.handlers):
        logger.removeHandler(handler)

    # Console handler
    console_handler = logging.StreamHandler(sys.stdout)
    console_handler.setFormatter(logging.Formatter(log_format))
    logger.addHandler(console_handler)

    # File handler (optional)
    if log_file:
        file_handler = logging.FileHandler(log_file, encoding='utf-8')
        file_handler.setFormatter(logging.Formatter(log_format))
        logger.addHandler(file_handler)

    return logger
58
+
59
+
60
def get_logger(name: str = "tiny_ai") -> logging.Logger:
    """Return the logger registered under *name* (defaults to the app logger)."""
    instance = logging.getLogger(name)
    return instance
63
+
64
+
65
+ # =============================================================================
66
+ # SYSTEM DETECTION
67
+ # =============================================================================
68
+
69
@lru_cache(maxsize=1)
def get_system_info() -> Dict[str, Any]:
    """Gather platform, CPU, memory and instruction-set details.

    The result is cached, so the (potentially slow) probes run only once
    per process.
    """
    details: Dict[str, Any] = {
        "platform": platform.system(),
        "platform_release": platform.release(),
        "platform_version": platform.version(),
        "architecture": platform.machine(),
        "processor": platform.processor(),
        "python_version": platform.python_version(),
        "cpu_count": os.cpu_count() or 1,
        "cpu_count_physical": None,
        "total_ram_gb": None,
        "available_ram_gb": None,
        "has_avx": False,
        "has_avx2": False,
        "has_avx512": False,
    }

    if HAS_PSUTIL:
        try:
            details["cpu_count_physical"] = psutil.cpu_count(logical=False)
            vm = psutil.virtual_memory()
            details["total_ram_gb"] = round(vm.total / (1024 ** 3), 2)
            details["available_ram_gb"] = round(vm.available / (1024 ** 3), 2)
        except Exception:
            pass  # best effort -- keep the None placeholders on failure

    # CPU feature probes: AVX support improves llama.cpp performance.
    try:
        if platform.system() == "Windows":
            # No /proc on Windows; fall back to a coarse vendor heuristic.
            vendor = platform.processor().lower()
            details["has_avx"] = True  # most modern CPUs have AVX
            details["has_avx2"] = "intel" in vendor or "amd" in vendor
        elif os.path.exists("/proc/cpuinfo"):
            # On Linux the kernel exposes the exact CPU flags.
            with open("/proc/cpuinfo", "r") as cpu_file:
                flags = cpu_file.read().lower()
            for feature in ("avx", "avx2", "avx512"):
                details["has_" + feature] = feature in flags
    except Exception:
        pass

    return details
116
+
117
+
118
def get_available_ram_gb() -> float:
    """Return the amount of currently-available system RAM in gigabytes."""
    if not HAS_PSUTIL:
        # Without psutil we cannot measure; assume a modest 4GB.
        return 4.0
    return psutil.virtual_memory().available / (1024 ** 3)
123
+
124
+
125
def get_memory_usage_gb() -> float:
    """Return this process's resident memory footprint in GB (0.0 if unknown)."""
    if not HAS_PSUTIL:
        return 0.0  # measurement unavailable without psutil
    rss_bytes = psutil.Process(os.getpid()).memory_info().rss
    return rss_bytes / (1024 ** 3)
131
+
132
+
133
def get_optimal_thread_count() -> int:
    """Pick a thread count for inference, preferring physical core count."""
    logical = os.cpu_count() or 4

    if HAS_PSUTIL:
        physical = psutil.cpu_count(logical=False)
        if physical:
            # Leave one physical core free for the rest of the system.
            return max(1, physical - 1)

    # Without physical-core info, assume SMT and use half the logical cores.
    return max(1, logical // 2)
145
+
146
+
147
def check_system_requirements(min_ram_gb: float = 2.0) -> Dict[str, Any]:
    """Validate the host against minimum requirements.

    Returns a dict with a boolean "meets_requirements" plus lists of
    errors, warnings and recommendations.
    """
    info = get_system_info()

    report: Dict[str, Any] = {
        "meets_requirements": True,
        "warnings": [],
        "errors": [],
        "recommendations": [],
    }

    # RAM check (skipped when psutil could not report availability).
    available = info["available_ram_gb"]
    if available and available < min_ram_gb:
        report["meets_requirements"] = False
        report["errors"].append(
            f"Insufficient RAM: {available:.1f}GB available, "
            f"{min_ram_gb}GB required"
        )
        report["recommendations"].append("Close other applications to free memory")

    # Python version check.
    version_pair = tuple(int(part) for part in info["python_version"].split(".")[:2])
    if version_pair < (3, 8):
        report["meets_requirements"] = False
        report["errors"].append(f"Python 3.8+ required, found {info['python_version']}")

    # Non-fatal performance hints.
    if info["cpu_count"] and info["cpu_count"] < 4:
        report["warnings"].append("Low CPU core count may result in slower responses")

    if not info.get("has_avx"):
        report["warnings"].append("CPU may not support AVX instructions (slower inference)")

    return report
181
+
182
+
183
+ # =============================================================================
184
+ # CACHING
185
+ # =============================================================================
186
+
187
class LRUCache:
    """Thread-safe least-recently-used cache with per-entry TTL."""

    def __init__(self, max_size: int = 100, ttl_seconds: int = 3600):
        self.max_size = max_size
        self.ttl_seconds = ttl_seconds
        self.cache: OrderedDict = OrderedDict()
        self.timestamps: Dict[str, datetime] = {}
        self.lock = threading.Lock()

    def _make_key(self, *args, **kwargs) -> str:
        """Build a stable hash key from arbitrary (JSON-able) arguments."""
        serialized = json.dumps(
            {"args": args, "kwargs": kwargs}, sort_keys=True, default=str
        )
        return hashlib.md5(serialized.encode()).hexdigest()

    def get(self, key: str) -> Optional[Any]:
        """Return the cached value, or None when absent or expired."""
        with self.lock:
            if key not in self.cache:
                return None

            age = datetime.now() - self.timestamps[key]
            if age > timedelta(seconds=self.ttl_seconds):
                # Entry outlived its TTL -- drop it and report a miss.
                del self.cache[key]
                del self.timestamps[key]
                return None

            self.cache.move_to_end(key)  # mark as most recently used
            return self.cache[key]

    def set(self, key: str, value: Any):
        """Insert or refresh an entry, evicting the oldest when full."""
        with self.lock:
            if key in self.cache:
                self.cache.move_to_end(key)
            elif len(self.cache) >= self.max_size:
                # Evict the least-recently-used entry to make room.
                evicted, _ = self.cache.popitem(last=False)
                del self.timestamps[evicted]

            self.cache[key] = value
            self.timestamps[key] = datetime.now()

    def clear(self):
        """Drop every entry."""
        with self.lock:
            self.cache.clear()
            self.timestamps.clear()

    def stats(self) -> Dict[str, Any]:
        """Report current size and configuration."""
        with self.lock:
            return {
                "size": len(self.cache),
                "max_size": self.max_size,
                "ttl_seconds": self.ttl_seconds,
            }
247
+
248
+
249
# Sentinel distinguishing "function returned None" from a cache miss,
# since LRUCache.get() signals a miss by returning None.
_CACHED_NONE = object()


def cached(cache: "LRUCache"):
    """Decorator factory that memoizes results in the given LRUCache.

    A None result is stored as a sentinel so that functions legitimately
    returning None are cached too, instead of being re-executed on every
    call (the classic miss-vs-None confusion).
    """
    def decorator(func: Callable):
        @wraps(func)
        def wrapper(*args, **kwargs):
            key = cache._make_key(func.__name__, *args, **kwargs)
            hit = cache.get(key)
            if hit is not None:
                # Unwrap the stored sentinel back into a real None.
                return None if hit is _CACHED_NONE else hit

            result = func(*args, **kwargs)
            cache.set(key, _CACHED_NONE if result is None else result)
            return result
        return wrapper
    return decorator
264
+
265
+
266
+ # =============================================================================
267
+ # TIMING AND PERFORMANCE
268
+ # =============================================================================
269
+
270
class Timer:
    """Measure wall-clock execution time; usable as a context manager."""

    def __init__(self, name: str = ""):
        self.name = name
        self.start_time: Optional[float] = None
        self.end_time: Optional[float] = None
        self.elapsed: float = 0.0

    def start(self):
        """Begin timing; returns self so calls can be chained."""
        self.start_time = time.perf_counter()
        return self

    def stop(self) -> float:
        """Finish timing and return the elapsed seconds.

        If start() was never called, the elapsed time reads as zero.
        """
        self.end_time = time.perf_counter()
        self.elapsed = self.end_time - (self.start_time or self.end_time)
        return self.elapsed

    def __enter__(self):
        return self.start()

    def __exit__(self, *args):
        self.stop()
296
+
297
+
298
@contextmanager
def measure_time(name: str = "", logger: Optional[logging.Logger] = None):
    """Time the enclosed block, optionally logging the duration at DEBUG.

    Yields the running Timer so callers can read .elapsed afterwards.
    """
    clock = Timer(name).start()
    try:
        yield clock
    finally:
        # Always stop, even when the block raised.
        clock.stop()
        if logger:
            logger.debug(f"{name}: {clock.elapsed:.4f}s")
309
+
310
+
311
class PerformanceTracker:
    """Track rolling performance metrics in a thread-safe way.

    Each metric keeps at most ``window_size`` recent samples.
    """

    def __init__(self, window_size: int = 100):
        self.window_size = window_size
        self.metrics: Dict[str, List[float]] = {}
        self.lock = threading.Lock()

    def record(self, metric_name: str, value: float):
        """Append a sample for *metric_name*, trimming to the window."""
        with self.lock:
            samples = self.metrics.setdefault(metric_name, [])
            samples.append(value)

            # Keep only the newest window_size samples.
            if len(samples) > self.window_size:
                self.metrics[metric_name] = samples[-self.window_size:]

    def get_average(self, metric_name: str) -> Optional[float]:
        """Mean of the recorded samples, or None when none exist."""
        with self.lock:
            samples = self.metrics.get(metric_name)
            if not samples:
                return None
            return sum(samples) / len(samples)

    def get_stats(self, metric_name: str) -> Dict[str, Any]:
        """Count/min/max/avg/last for a metric ({} when no samples)."""
        with self.lock:
            samples = self.metrics.get(metric_name)
            if not samples:
                return {}

            return {
                "count": len(samples),
                "min": min(samples),
                "max": max(samples),
                "avg": sum(samples) / len(samples),
                "last": samples[-1],
            }

    def get_all_stats(self) -> Dict[str, Dict[str, Any]]:
        """Get statistics for all metrics.

        Snapshot the metric names under the lock first: iterating
        self.metrics directly while another thread records a brand-new
        metric would raise RuntimeError (dict changed size during
        iteration).
        """
        with self.lock:
            names = list(self.metrics)
        return {name: self.get_stats(name) for name in names}
356
+
357
+
358
+ # =============================================================================
359
+ # TEXT PROCESSING
360
+ # =============================================================================
361
+
362
def count_tokens_approx(text: str) -> int:
    """Estimate token count assuming roughly 4 characters per token."""
    return 1 + len(text) // 4
365
+
366
+
367
def truncate_text(text: str, max_tokens: int) -> str:
    """Shorten *text* to roughly *max_tokens* tokens, appending '...' if cut."""
    char_budget = max_tokens * 4  # ~4 chars per token
    if len(text) <= char_budget:
        return text

    clipped = text[:char_budget]

    # Prefer ending on a word boundary, but only when that doesn't cost
    # more than 20% of the character budget.
    boundary = clipped.rfind(" ")
    if boundary > char_budget * 0.8:
        clipped = clipped[:boundary]

    return clipped + "..."
380
+
381
+
382
def clean_response(text: str) -> str:
    """Strip whitespace, leaked chat-template tokens and doubled spaces."""
    cleaned = text.strip()

    # Drop template markers that occasionally leak into model output.
    leaked_markers = ("<|system|>", "<|user|>", "<|assistant|>", "</s>", "<|end|>", "<s>")
    for marker in leaked_markers:
        cleaned = cleaned.replace(marker, "")

    # Collapse any run of spaces down to a single space.
    while "  " in cleaned:
        cleaned = cleaned.replace("  ", " ")

    return cleaned.strip()
397
+
398
+
399
def extract_entities(text: str) -> Dict[str, List[str]]:
    """Pull simple entities (pronouns, capitalized names) out of *text*.

    This is a cheap heuristic, not real named-entity recognition; the
    "topics" slot is reserved but currently never populated here.
    """
    import re

    entities: Dict[str, List[str]] = {"names": [], "topics": [], "pronouns": []}

    # Pronouns: membership test on lowercased whitespace-split words.
    known_pronouns = {"it", "they", "them", "this", "that", "these", "those", "he", "she"}
    entities["pronouns"] = [
        word for word in text.lower().split() if word in known_pronouns
    ]

    # Names: capitalized words, minus common sentence starters.
    sentence_starters = {"I", "The", "A", "An", "This", "That", "What", "How", "Why", "When", "Where"}
    capitalized_words = re.findall(r'\b[A-Z][a-z]+\b', text)
    entities["names"] = [
        word for word in capitalized_words if word not in sentence_starters
    ]

    return entities
420
+
421
+
422
+ # =============================================================================
423
+ # PROGRESS DISPLAY
424
+ # =============================================================================
425
+
426
class ProgressBar:
    """Minimal in-terminal progress bar with an ETA estimate."""

    def __init__(self, total: int, prefix: str = "", width: int = 40):
        self.total = total
        self.prefix = prefix
        self.width = width
        self.current = 0
        self.start_time = time.time()

    def update(self, current: Optional[int] = None, increment: int = 1):
        """Advance the bar, either to an absolute value or by *increment*."""
        self.current = current if current is not None else self.current + increment
        self._display()

    def _display(self):
        """Render the bar in place on the current terminal line."""
        fraction = self.current / self.total if self.total > 0 else 1
        filled_cells = int(self.width * fraction)
        bar = "β–ˆ" * filled_cells + "β–‘" * (self.width - filled_cells)

        # Project the remaining time from the average pace so far.
        elapsed = time.time() - self.start_time
        if self.current > 0:
            remaining = elapsed * (self.total - self.current) / self.current
            eta_text = f"ETA: {remaining:.0f}s"
        else:
            eta_text = "ETA: --"

        print(f"\r{self.prefix} |{bar}| {fraction * 100:.1f}% {eta_text}", end="", flush=True)

    def finish(self):
        """Jump to 100% and move the cursor to the next line."""
        self.update(self.total)
        print()
467
+
468
+
469
def print_banner():
    """Print application banner."""
    # Unicode box-drawing banner, written verbatim to stdout; assumes the
    # terminal encoding can render these characters -- TODO confirm on
    # legacy Windows consoles.
    banner = """
    ╔══════════════════════════════════════════════════════════════╗
    β•‘ πŸ€– Tiny Conversational AI β•‘
    β•‘ Fast β€’ Lightweight β€’ Local β•‘
    β•šβ•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•
    """
    print(banner)
478
+
479
+
480
def print_system_status():
    """Print a short human-readable summary of host resources."""
    info = get_system_info()

    # The key always exists (defaulted to None), so .get(..., 'N/A') would
    # still print "None"; fall back with `or` instead.
    physical = info["cpu_count_physical"] or "N/A"

    # Only ellipsize the CPU name when it was actually cut.
    processor = info["processor"]
    processor_text = processor[:50] + ("..." if len(processor) > 50 else "")

    print("\nπŸ“Š System Status:")
    print(f" β€’ Platform: {info['platform']} {info['platform_release']}")
    print(f" β€’ CPU: {processor_text}")
    print(f" β€’ Cores: {info['cpu_count']} logical, {physical} physical")

    if info["total_ram_gb"]:
        print(f" β€’ RAM: {info['available_ram_gb']:.1f}GB available / {info['total_ram_gb']:.1f}GB total")

    print(f" β€’ Python: {info['python_version']}")
    print()
494
+
495
+
496
+ # =============================================================================
497
+ # FILE OPERATIONS
498
+ # =============================================================================
499
+
500
def ensure_dir(path: Path) -> Path:
    """Make sure *path* exists as a directory, creating parents as needed.

    Returns the same path so calls can be chained.
    """
    if not path.is_dir():
        path.mkdir(parents=True, exist_ok=True)
    return path
504
+
505
+
506
def safe_json_load(path: Path, default: Any = None) -> Any:
    """Load JSON from *path*, returning *default* on any read/parse failure.

    Besides a missing file and malformed JSON, this also tolerates
    permission errors, directories and undecodable bytes -- matching the
    "safe" contract of never raising. OSError covers FileNotFoundError;
    ValueError covers json.JSONDecodeError and UnicodeDecodeError.
    """
    try:
        with open(path, 'r', encoding='utf-8') as f:
            return json.load(f)
    except (OSError, ValueError):
        return default
513
+
514
+
515
def safe_json_save(path: Path, data: Any):
    """Serialize *data* as JSON to *path* atomically.

    Writes to a temporary sibling file first and renames it into place,
    so a crash mid-write can never leave a truncated/corrupt file.
    Non-JSON-native values are stringified via ``default=str``.
    """
    path.parent.mkdir(parents=True, exist_ok=True)
    tmp_path = path.with_suffix(path.suffix + ".tmp")
    with open(tmp_path, 'w', encoding='utf-8') as f:
        json.dump(data, f, indent=2, default=str)
    # Atomic on POSIX and Windows alike.
    os.replace(tmp_path, path)
520
+
521
+
522
+ # =============================================================================
523
+ # DOWNLOAD HELPERS
524
+ # =============================================================================
525
+
526
def download_with_progress(url: str, dest_path: Path, chunk_size: int = 8192,
                           timeout: float = 30.0) -> bool:
    """Download *url* to *dest_path*, showing a terminal progress bar.

    Args:
        url: Source URL.
        dest_path: Destination file path (parent directories are created).
        chunk_size: Bytes per streamed chunk.
        timeout: Connect/read timeout in seconds; without one a stalled
            server would hang the download forever.

    Returns:
        True on success, False on any failure (logged, not raised).
    """
    try:
        import requests

        # Stream so large files never sit fully in memory; the `with`
        # closes the connection deterministically even on partial reads.
        with requests.get(url, stream=True, timeout=timeout) as response:
            response.raise_for_status()

            total_size = int(response.headers.get('content-length', 0))

            ensure_dir(dest_path.parent)

            progress = ProgressBar(
                total=total_size,
                prefix=f"Downloading {dest_path.name}",
            )

            downloaded = 0
            with open(dest_path, 'wb') as f:
                for chunk in response.iter_content(chunk_size=chunk_size):
                    if chunk:  # skip keep-alive chunks
                        f.write(chunk)
                        downloaded += len(chunk)
                        progress.update(downloaded)

        progress.finish()
        return True

    except Exception as e:
        get_logger().error(f"Download failed: {e}")
        return False
557
+
558
+
559
+ # =============================================================================
560
+ # INITIALIZATION
561
+ # =============================================================================
562
+
563
# Setup default logger
# NOTE(review): runs at import time -- merely importing this module attaches
# handlers to the "tiny_ai" logger as a side effect; confirm this is intended.
_default_logger = setup_logging()