bilalnaveed commited on
Commit
cfd8e56
·
verified ·
1 Parent(s): fc3b069

Upload model_loader.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. model_loader.py +430 -0
model_loader.py ADDED
@@ -0,0 +1,430 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Model Loader for the Tiny Conversational AI.
3
+ Handles model downloading, loading, and optimization for CPU inference.
4
+ Uses llama-cpp-python for maximum CPU performance with 4-bit quantization.
5
+ """
6
+
7
+ import os
8
+ import sys
9
+ import time
10
+ import shutil
11
+ from pathlib import Path
12
+ from typing import Optional, Dict, Any, Callable
13
+ from dataclasses import dataclass
14
+ import threading
15
+
16
+ from config import Config, get_config
17
+ from utils import (
18
+ get_logger, get_system_info, get_available_ram_gb, get_optimal_thread_count,
19
+ check_system_requirements, ProgressBar, Timer, ensure_dir
20
+ )
21
+
22
+ logger = get_logger(__name__)
23
+
24
+
25
+ @dataclass
26
+ class ModelInfo:
27
+ """Information about a loaded model."""
28
+ name: str
29
+ model_id: str
30
+ path: Path
31
+ size_gb: float
32
+ context_size: int
33
+ loaded: bool = False
34
+ load_time_seconds: float = 0.0
35
+
36
+
37
+ class ModelLoader:
38
+ """
39
+ Handles loading and managing LLM models.
40
+ Uses llama-cpp-python for efficient CPU inference with 4-bit quantization.
41
+ """
42
+
43
+ def __init__(self, config: Optional[Config] = None):
44
+ self.config = config or get_config()
45
+ self.model = None
46
+ self.model_info: Optional[ModelInfo] = None
47
+ self._lock = threading.Lock()
48
+ self._loading = False
49
+
50
+ # Ensure models directory exists
51
+ ensure_dir(self.config.paths.models_dir)
52
+
53
+ def _get_model_path(self, model_id: str) -> Path:
54
+ """Get the local path for a model."""
55
+ model_config = self.config.model.AVAILABLE_MODELS.get(model_id)
56
+ if not model_config:
57
+ raise ValueError(f"Unknown model: {model_id}")
58
+
59
+ return self.config.paths.models_dir / model_config["file"]
60
+
61
+ def _is_model_downloaded(self, model_id: str) -> bool:
62
+ """Check if a model is already downloaded."""
63
+ path = self._get_model_path(model_id)
64
+ return path.exists()
65
+
66
+ def _download_model(
67
+ self,
68
+ model_id: str,
69
+ progress_callback: Optional[Callable[[float, str], None]] = None
70
+ ) -> Path:
71
+ """
72
+ Download a model from Hugging Face Hub.
73
+ Uses huggingface_hub for reliable downloads with resume support.
74
+ """
75
+ model_config = self.config.model.AVAILABLE_MODELS.get(model_id)
76
+ if not model_config:
77
+ raise ValueError(f"Unknown model: {model_id}")
78
+
79
+ dest_path = self._get_model_path(model_id)
80
+
81
+ if dest_path.exists():
82
+ logger.info(f"Model already downloaded: {dest_path}")
83
+ return dest_path
84
+
85
+ logger.info(f"Downloading model: {model_config['name']}")
86
+ logger.info(f"Repository: {model_config['repo']}")
87
+ logger.info(f"File: {model_config['file']}")
88
+ logger.info(f"Expected size: ~{model_config['size_gb']}GB")
89
+
90
+ if progress_callback:
91
+ progress_callback(0.0, f"Starting download of {model_config['name']}...")
92
+
93
+ try:
94
+ from huggingface_hub import hf_hub_download
95
+
96
+ # Download with progress
97
+ downloaded_path = hf_hub_download(
98
+ repo_id=model_config["repo"],
99
+ filename=model_config["file"],
100
+ local_dir=self.config.paths.models_dir,
101
+ local_dir_use_symlinks=False,
102
+ resume_download=True,
103
+ )
104
+
105
+ # Move to expected location if needed
106
+ downloaded_path = Path(downloaded_path)
107
+ if downloaded_path != dest_path:
108
+ if downloaded_path.exists():
109
+ shutil.move(str(downloaded_path), str(dest_path))
110
+
111
+ if progress_callback:
112
+ progress_callback(1.0, "Download complete!")
113
+
114
+ logger.info(f"Model downloaded successfully: {dest_path}")
115
+ return dest_path
116
+
117
+ except ImportError:
118
+ logger.error("huggingface_hub not installed. Please install it: pip install huggingface_hub")
119
+ raise
120
+ except Exception as e:
121
+ logger.error(f"Download failed: {e}")
122
+ raise
123
+
124
+ def select_best_model(self) -> str:
125
+ """
126
+ Automatically select the best model based on available RAM.
127
+ Returns model_id of the selected model.
128
+ """
129
+ available_ram = get_available_ram_gb()
130
+ logger.info(f"Available RAM: {available_ram:.1f}GB")
131
+
132
+ # Sort models by quality (descending), filter by RAM requirement
133
+ suitable_models = []
134
+
135
+ for model_id, model_config in self.config.model.AVAILABLE_MODELS.items():
136
+ if model_config["min_ram_gb"] <= available_ram:
137
+ suitable_models.append((model_id, model_config))
138
+
139
+ if not suitable_models:
140
+ # Use smallest model as last resort
141
+ logger.warning("Low RAM detected, using smallest model")
142
+ return "tinyllama-1.1b"
143
+
144
+ # Sort by quality * speed score
145
+ suitable_models.sort(
146
+ key=lambda x: x[1]["quality"] * x[1]["speed"],
147
+ reverse=True
148
+ )
149
+
150
+ selected = suitable_models[0][0]
151
+ logger.info(f"Selected model: {selected}")
152
+ return selected
153
+
154
+ def load(
155
+ self,
156
+ model_id: Optional[str] = None,
157
+ auto_download: bool = True,
158
+ progress_callback: Optional[Callable[[float, str], None]] = None
159
+ ) -> Any:
160
+ """
161
+ Load a model for inference.
162
+
163
+ Args:
164
+ model_id: ID of the model to load. If None, auto-selects best model.
165
+ auto_download: Whether to download the model if not present.
166
+ progress_callback: Optional callback for progress updates (progress, message).
167
+
168
+ Returns:
169
+ The loaded model instance.
170
+ """
171
+ with self._lock:
172
+ if self._loading:
173
+ raise RuntimeError("Model is already being loaded")
174
+ self._loading = True
175
+
176
+ try:
177
+ timer = Timer("Model loading")
178
+ timer.start()
179
+
180
+ # Auto-select model if not specified
181
+ if model_id is None:
182
+ model_id = self.select_best_model()
183
+
184
+ model_config = self.config.model.AVAILABLE_MODELS.get(model_id)
185
+ if not model_config:
186
+ raise ValueError(f"Unknown model: {model_id}")
187
+
188
+ model_path = self._get_model_path(model_id)
189
+
190
+ # Download if needed
191
+ if not model_path.exists():
192
+ if auto_download:
193
+ if progress_callback:
194
+ progress_callback(0.0, "Downloading model...")
195
+ model_path = self._download_model(model_id, progress_callback)
196
+ else:
197
+ raise FileNotFoundError(
198
+ f"Model not found: {model_path}\n"
199
+ f"Run with auto_download=True or download manually from: "
200
+ f"https://huggingface.co/{model_config['repo']}"
201
+ )
202
+
203
+ # Check system requirements (use relaxed check - model will use virtual memory if needed)
204
+ check_result = check_system_requirements(min(model_config["min_ram_gb"], 1.0))
205
+ if not check_result["meets_requirements"]:
206
+ for error in check_result["errors"]:
207
+ logger.warning(f"RAM warning (continuing anyway): {error}")
208
+ # Don't raise - let it try to load, OS will use swap if needed
209
+
210
+ for warning in check_result.get("warnings", []):
211
+ logger.warning(warning)
212
+
213
+ if progress_callback:
214
+ progress_callback(0.5, "Loading model into memory...")
215
+
216
+ # Load with llama-cpp-python
217
+ self.model = self._load_llama_cpp(model_path, model_config)
218
+
219
+ timer.stop()
220
+
221
+ # Store model info
222
+ self.model_info = ModelInfo(
223
+ name=model_config["name"],
224
+ model_id=model_id,
225
+ path=model_path,
226
+ size_gb=model_config["size_gb"],
227
+ context_size=model_config["context_size"],
228
+ loaded=True,
229
+ load_time_seconds=timer.elapsed,
230
+ )
231
+
232
+ if progress_callback:
233
+ progress_callback(1.0, f"Model loaded in {timer.elapsed:.1f}s")
234
+
235
+ logger.info(f"Model loaded successfully in {timer.elapsed:.1f}s")
236
+ return self.model
237
+
238
+ finally:
239
+ with self._lock:
240
+ self._loading = False
241
+
242
+ def _load_llama_cpp(self, model_path: Path, model_config: Dict[str, Any]) -> Any:
243
+ """Load model using llama-cpp-python for optimal CPU performance."""
244
+ try:
245
+ from llama_cpp import Llama
246
+ except ImportError:
247
+ logger.error(
248
+ "llama-cpp-python not installed. Please install it:\n"
249
+ "pip install llama-cpp-python"
250
+ )
251
+ raise
252
+
253
+ # Determine optimal settings
254
+ n_threads = self.config.model.n_threads
255
+ if n_threads == 0:
256
+ n_threads = get_optimal_thread_count()
257
+
258
+ n_ctx = min(
259
+ model_config.get("context_size", 4096),
260
+ self.config.model.max_context_length
261
+ )
262
+
263
+ logger.info(f"Loading model: {model_path.name}")
264
+ logger.info(f"Context size: {n_ctx}")
265
+ logger.info(f"Threads: {n_threads}")
266
+ logger.info(f"Batch size: {self.config.model.n_batch}")
267
+
268
+ # Load the model
269
+ model = Llama(
270
+ model_path=str(model_path),
271
+ n_ctx=n_ctx,
272
+ n_threads=n_threads,
273
+ n_batch=self.config.model.n_batch,
274
+ n_gpu_layers=self.config.model.n_gpu_layers,
275
+ use_mmap=self.config.model.use_mmap,
276
+ use_mlock=self.config.model.use_mlock,
277
+ verbose=False, # Reduce noise
278
+ )
279
+
280
+ return model
281
+
282
+ def unload(self):
283
+ """Unload the current model to free memory."""
284
+ with self._lock:
285
+ if self.model is not None:
286
+ del self.model
287
+ self.model = None
288
+ self.model_info = None
289
+
290
+ # Force garbage collection
291
+ import gc
292
+ gc.collect()
293
+
294
+ logger.info("Model unloaded")
295
+
296
+ def is_loaded(self) -> bool:
297
+ """Check if a model is currently loaded."""
298
+ return self.model is not None
299
+
300
+ def get_model(self) -> Any:
301
+ """Get the loaded model instance."""
302
+ if not self.is_loaded():
303
+ raise RuntimeError("No model loaded. Call load() first.")
304
+ return self.model
305
+
306
+ def get_model_info(self) -> Optional[ModelInfo]:
307
+ """Get information about the loaded model."""
308
+ return self.model_info
309
+
310
+ def warmup(self, prompt: str = "Hello") -> float:
311
+ """
312
+ Warm up the model with a simple generation.
313
+ Returns the warmup time in seconds.
314
+ """
315
+ if not self.is_loaded():
316
+ raise RuntimeError("No model loaded. Call load() first.")
317
+
318
+ logger.info("Warming up model...")
319
+ timer = Timer("Warmup")
320
+ timer.start()
321
+
322
+ # Generate a short response
323
+ _ = self.model(
324
+ prompt,
325
+ max_tokens=10,
326
+ temperature=0.7,
327
+ )
328
+
329
+ timer.stop()
330
+ logger.info(f"Warmup complete in {timer.elapsed:.2f}s")
331
+ return timer.elapsed
332
+
333
+ def list_available_models(self) -> Dict[str, Dict[str, Any]]:
334
+ """List all available models with their info."""
335
+ models = {}
336
+
337
+ for model_id, model_config in self.config.model.AVAILABLE_MODELS.items():
338
+ models[model_id] = {
339
+ **model_config,
340
+ "downloaded": self._is_model_downloaded(model_id),
341
+ "path": str(self._get_model_path(model_id)),
342
+ }
343
+
344
+ return models
345
+
346
+ def delete_model(self, model_id: str) -> bool:
347
+ """Delete a downloaded model to free disk space."""
348
+ model_path = self._get_model_path(model_id)
349
+
350
+ if model_path.exists():
351
+ # Don't delete if currently loaded
352
+ if self.model_info and self.model_info.model_id == model_id:
353
+ self.unload()
354
+
355
+ model_path.unlink()
356
+ logger.info(f"Deleted model: {model_path}")
357
+ return True
358
+
359
+ return False
360
+
361
+
362
+ # =============================================================================
363
+ # CONVENIENCE FUNCTIONS
364
+ # =============================================================================
365
+
366
+ _global_loader: Optional[ModelLoader] = None
367
+
368
+
369
+ def get_loader() -> ModelLoader:
370
+ """Get the global model loader instance."""
371
+ global _global_loader
372
+ if _global_loader is None:
373
+ _global_loader = ModelLoader()
374
+ return _global_loader
375
+
376
+
377
+ def load_model(
378
+ model_id: Optional[str] = None,
379
+ auto_download: bool = True
380
+ ) -> Any:
381
+ """Convenience function to load a model."""
382
+ return get_loader().load(model_id, auto_download)
383
+
384
+
385
+ def get_model() -> Any:
386
+ """Get the currently loaded model."""
387
+ return get_loader().get_model()
388
+
389
+
390
+ if __name__ == "__main__":
391
+ # Test model loading
392
+ from utils import print_banner, print_system_status
393
+
394
+ print_banner()
395
+ print_system_status()
396
+
397
+ loader = ModelLoader()
398
+
399
+ print("\n📦 Available Models:")
400
+ for model_id, info in loader.list_available_models().items():
401
+ status = "✓ Downloaded" if info["downloaded"] else "○ Not downloaded"
402
+ print(f" • {model_id}: {info['name']}")
403
+ print(f" Size: {info['size_gb']}GB | Min RAM: {info['min_ram_gb']}GB | {status}")
404
+
405
+ # Auto-select and load best model
406
+ print("\n🚀 Loading model...")
407
+
408
+ try:
409
+ model = loader.load()
410
+
411
+ print(f"\n✓ Model loaded: {loader.model_info.name}")
412
+ print(f" Load time: {loader.model_info.load_time_seconds:.1f}s")
413
+
414
+ # Warmup
415
+ warmup_time = loader.warmup()
416
+ print(f" Warmup time: {warmup_time:.2f}s")
417
+
418
+ # Simple test
419
+ print("\n📝 Test generation:")
420
+ response = model(
421
+ "User: Hello!\nAssistant:",
422
+ max_tokens=50,
423
+ temperature=0.7,
424
+ stop=["User:", "\n\n"],
425
+ )
426
+ print(f"Response: {response['choices'][0]['text'].strip()}")
427
+
428
+ except Exception as e:
429
+ print(f"❌ Error: {e}")
430
+ raise