""" Model loader for all Arabic NLP models. Handles loading and inference with error handling. """ import os import logging from pathlib import Path import json import pickle import difflib import torch import importlib.util from transformers import ( MBartForConditionalGeneration, AutoTokenizer, AutoConfig, AutoModelForSeq2SeqLM, AutoModelForCausalLM, EncoderDecoderModel, BertConfig, EncoderDecoderConfig ) # Imported lazily inside SpellingModel so summarization-only workflows do not # require spelling dependencies at module import time. ArabicSpellChecker = None logger = logging.getLogger(__name__) # Model paths MODEL_BASE_PATH = Path(__file__).parent.parent / "models" SUMMARIZATION_PATH = MODEL_BASE_PATH / "Summarization" / "Model" SPELLING_PATH = MODEL_BASE_PATH / "Spelling" / "Model" AUTOCOMPLETE_PATH = MODEL_BASE_PATH / "Autocomplete" / "Model" GRAMMAR_PATH = MODEL_BASE_PATH / "Grammrar" / "Model" PUNCTUATION_PATH = MODEL_BASE_PATH / "Punctuation" / "Model" class SummarizationModel: """Wrapper class for the Arabic summarization model.""" def __init__(self, model_path): """ Initialize the model. Args: model_path: Path to the model directory Raises: FileNotFoundError: If model files are not found RuntimeError: If model loading fails """ self.model_source = str(model_path) self._is_remote_source = self._looks_like_remote_source(self.model_source) self.model_path = None if self._is_remote_source else Path(model_path) self.model = None self.tokenizer = None self.device = None self._validate_path() self._load_model() @staticmethod def _looks_like_remote_source(source): """Detect Hugging Face repo ids or URLs.""" source = str(source) return source.startswith(("http://", "https://")) or ("/" in source and not Path(source).exists()) def _validate_path(self): """Validate that the model path exists and contains required files.""" if self._is_remote_source: logger.info(f"Using remote model source: {self.model_source}") return if not self.model_path.exists(): raise FileNotFoundError(f"Model path does not exist: {self.model_path}") required_files = ['config.json', 'tokenizer.json', 'model.safetensors'] missing_files = [] for file in required_files: if not (self.model_path / file).exists(): missing_files.append(file) if missing_files: raise FileNotFoundError( f"Missing required model files: {', '.join(missing_files)}" ) logger.info(f"Model path validated: {self.model_path}") def _fix_generation_config(self): """Fix generation_config.json and config.json if early_stopping is None/null.""" if self._is_remote_source: return gen_config_path = self.model_path / "generation_config.json" config_path = self.model_path / "config.json" try: # Fix config.json first (this is the main issue) if config_path.exists(): with open(config_path, 'r', encoding='utf-8') as f: config = json.load(f) # Check if early_stopping exists in config and is None if 'early_stopping' in config and config['early_stopping'] is None: logger.info("Fixing early_stopping in config.json (was None)") config['early_stopping'] = True with open(config_path, 'w', encoding='utf-8') as f: json.dump(config, f, indent=2, ensure_ascii=False) logger.info("Fixed config.json - set early_stopping to True") # Fix generation_config.json if gen_config_path.exists(): with open(gen_config_path, 'r', encoding='utf-8') as f: gen_config = json.load(f) # Fix early_stopping if it's None/null if gen_config.get('early_stopping') is None: logger.info("Fixing early_stopping in generation_config.json (was None)") gen_config['early_stopping'] = True with open(gen_config_path, 'w', encoding='utf-8') as f: json.dump(gen_config, f, indent=2, ensure_ascii=False) logger.info("Fixed generation_config.json - set early_stopping to True") except Exception as e: logger.warning(f"Could not fix generation config files: {str(e)}") # Continue anyway, we'll try to load with workaround def _load_model(self): """Load the model and tokenizer.""" try: # Determine device if torch.cuda.is_available(): self.device = torch.device('cuda') logger.info(f"Using CUDA device: {torch.cuda.get_device_name(0)}") else: self.device = torch.device('cpu') logger.info("Using CPU device") # Load tokenizer with error handling logger.info("Loading tokenizer...") try: self.tokenizer = AutoTokenizer.from_pretrained( self.model_source, local_files_only=True, trust_remote_code=False ) logger.info("Tokenizer loaded successfully") except Exception as e: logger.error(f"Failed to load tokenizer: {str(e)}") # Try without local_files_only as fallback logger.info("Retrying tokenizer load without local_files_only...") self.tokenizer = AutoTokenizer.from_pretrained( self.model_source, trust_remote_code=False ) logger.info("Tokenizer loaded successfully (fallback method)") # Load model with error handling logger.info("Loading model (this may take a while)...") # Fix generation config files if needed self._fix_generation_config() # Load config and fix early_stopping if needed try: config = AutoConfig.from_pretrained( self.model_source, local_files_only=True, trust_remote_code=False ) # Fix early_stopping in config if it's None if hasattr(config, 'early_stopping') and config.early_stopping is None: logger.info("Fixing early_stopping in loaded config (was None)") config.early_stopping = True # Load model with fixed config try: self.model = MBartForConditionalGeneration.from_pretrained( self.model_source, config=config, local_files_only=True, trust_remote_code=False, torch_dtype=torch.float16 ) except Exception as e: logger.warning(f"Failed to load with config: {str(e)}") # Try without explicit config self.model = MBartForConditionalGeneration.from_pretrained( self.model_source, local_files_only=True, trust_remote_code=False, torch_dtype=torch.float16 ) except Exception as e: logger.warning(f"Failed to load config: {str(e)}") logger.info("Retrying model load without config fix...") try: self.model = MBartForConditionalGeneration.from_pretrained( self.model_source, local_files_only=True, trust_remote_code=False, torch_dtype=torch.float16 ) except Exception as e2: logger.warning(f"Failed to load with local_files_only: {str(e2)}") logger.info("Retrying model load without local_files_only...") self.model = MBartForConditionalGeneration.from_pretrained( self.model_source, trust_remote_code=False, torch_dtype=torch.float16 ) # Move model to device self.model.to(self.device) self.model.eval() # Set to evaluation mode # Clear cache if using CUDA if torch.cuda.is_available(): torch.cuda.empty_cache() logger.info(f"Model loaded successfully on {self.device}") except Exception as e: logger.error(f"Error loading model: {str(e)}") import traceback logger.error(traceback.format_exc()) raise RuntimeError(f"Failed to load model: {str(e)}") def summarize(self, text, max_length=150, min_length=30, num_beams=1, **kwargs): """ Summarize Arabic text. Args: text: Input Arabic text to summarize max_length: Maximum length of the summary min_length: Minimum length of the summary num_beams: Number of beams for beam search **kwargs: Additional generation parameters Returns: str: Summarized text Raises: RuntimeError: If summarization fails """ if self.model is None or self.tokenizer is None: raise RuntimeError("Model not loaded") try: # Tokenize input inputs = self.tokenizer( text, max_length=1024, truncation=True, padding=True, return_tensors="pt" ) # Move inputs to device inputs = {k: v.to(self.device) for k, v in inputs.items()} # Keep decoding conservative: this model performed best with greedy decoding # and became more generic/hallucinated with beam search or sampling. generate_kwargs = dict( max_new_tokens=max(20, min(max_length, 160)), min_new_tokens=max(0, min_length), num_beams=num_beams, do_sample=False, early_stopping=False, no_repeat_ngram_size=3, repetition_penalty=1.1, ) generate_kwargs.update(kwargs) # Remove legacy max_length/min_length if caller supplied them; we prefer max_new_tokens. generate_kwargs.pop('max_length', None) generate_kwargs.pop('min_length', None) # Generate summary with torch.no_grad(): outputs = self.model.generate( **inputs, **generate_kwargs, ) # Decode output summary = self.tokenizer.decode(outputs[0], skip_special_tokens=True) summary = summary.strip() if self._needs_fallback(text, summary): return self._extractive_fallback(text, max_words=max(12, min(max_length, 80))) return summary except Exception as e: logger.error(f"Error during summarization: {str(e)}") raise RuntimeError(f"Summarization failed: {str(e)}") def _needs_fallback(self, source_text, summary_text): """Return True when generated summary appears too far from the source text.""" if not summary_text: return True source_words = set(source_text.split()) summary_words = summary_text.split() if not summary_words: return True overlap = sum(1 for word in summary_words if word in source_words) overlap_ratio = overlap / max(1, len(summary_words)) # Also guard against summaries that are lexically too dissimilar. ratio = difflib.SequenceMatcher(None, source_text[:500], summary_text[:500]).ratio() return overlap_ratio < 0.35 or ratio < 0.22 def _extractive_fallback(self, source_text, max_words=40): """Build a conservative summary from the opening sentences of the source text.""" text = source_text.strip() if not text: return text sentence_endings = ['.', '!', '?', '؟', '۔', '،'] sentences = [] current = [] for chunk in text.replace('\n', ' ').split(' '): current.append(chunk) if chunk and chunk[-1] in sentence_endings: sentence = ' '.join(current).strip() if sentence: sentences.append(sentence) current = [] if current: sentence = ' '.join(current).strip() if sentence: sentences.append(sentence) if not sentences: words = text.split() return ' '.join(words[:max_words]).strip() chosen = [] total_words = 0 for sentence in sentences: sentence_words = sentence.split() if not sentence_words: continue if total_words + len(sentence_words) > max_words and chosen: break chosen.append(sentence) total_words += len(sentence_words) if total_words >= max_words: break if not chosen: return ' '.join(text.split()[:max_words]).strip() return ' '.join(chosen).strip() def __del__(self): """Cleanup resources.""" try: if getattr(self, 'model', None) is not None: del self.model # Guard torch.cuda access in case torch or cuda subsystems are unavailable if 'torch' in globals() and hasattr(torch, 'cuda'): try: if torch.cuda.is_available(): torch.cuda.empty_cache() except Exception: # Ignore any cuda-related errors during cleanup pass except Exception: # Ensure destructor never raises pass class SpellingModel: """Wrapper class for the Arabic spelling correction model.""" def __init__(self, model_path=None): """ Initialize the spelling model. Args: model_path: Path to the model directory (defaults to SPELLING_PATH) """ self.model_path = Path(model_path) if model_path else SPELLING_PATH self.model = None self.device = None self._validate_path() self._load_model() def _validate_path(self): """Validate that the model path exists.""" if not self.model_path.exists(): raise FileNotFoundError(f"Model path does not exist: {self.model_path}") # Check for .pt file pt_files = list(self.model_path.glob("*.pt")) if not pt_files: raise FileNotFoundError(f"No .pt model file found in: {self.model_path}") logger.info(f"Spelling model path validated: {self.model_path}") def _load_model(self): """Load the spelling model.""" try: global ArabicSpellChecker if ArabicSpellChecker is None: try: from ara_spell import ArabicSpellChecker as _ArabicSpellChecker except ImportError: try: from src.ara_spell import ArabicSpellChecker as _ArabicSpellChecker except ImportError: from .ara_spell import ArabicSpellChecker as _ArabicSpellChecker ArabicSpellChecker = _ArabicSpellChecker self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') logger.info(f"Loading spelling model on {self.device}...") # Load tokenizer (using AraBERT tokenizer as it matches the vocab size 64000) logger.info("Loading tokenizer for spelling model...") try: self.tokenizer = AutoTokenizer.from_pretrained("aubmindlab/bert-base-arabertv02") except Exception as e: logger.warning(f"Could not load aubmindlab/bert-base-arabertv02 tokenizer: {e}") # Fallback or error - strictly speaking we need this one raise RuntimeError(f"Tokenizer load failed: {e}") # Define the configuration for the EncoderDecoder model # Based on inspection, it uses BERT-like architecture with 64000 vocab # We'll create a generic config that matches what we found config_encoder = BertConfig( vocab_size=64000, hidden_size=768, num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072 ) config_decoder = BertConfig( vocab_size=64000, hidden_size=768, num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072, is_decoder=True, add_cross_attention=True ) config = EncoderDecoderConfig.from_encoder_decoder_configs(config_encoder, config_decoder) # Initialize empty model self.model = EncoderDecoderModel(config=config) # Load state dict pt_file = list(self.model_path.glob("*.pt"))[0] logger.info(f"Loading weights from {pt_file}...") checkpoint = torch.load(pt_file, map_location=self.device) if "model_state_dict" in checkpoint: state_dict = checkpoint["model_state_dict"] else: state_dict = checkpoint # Load weights into model self.model.load_state_dict(state_dict) self.model.to(self.device) self.model.eval() # Set special tokens for generation # Usually strict encoder-decoder models need decoder_start_token_id self.model.config.decoder_start_token_id = self.tokenizer.cls_token_id self.model.config.eos_token_id = self.tokenizer.sep_token_id self.model.config.pad_token_id = self.tokenizer.pad_token_id self.model.config.vocab_size = self.model.config.encoder.vocab_size # Initialize the ArabicSpellChecker engine logger.info("Initializing ArabicSpellChecker engine...") self.engine = ArabicSpellChecker(self.model, self.tokenizer, self.device, use_contextual=True) logger.info("ArabicSpellChecker engine initialized successfully") logger.info("Spelling model loaded successfully") except Exception as e: logger.error(f"Error loading spelling model: {str(e)}") raise RuntimeError(f"Failed to load spelling model: {str(e)}") def correct(self, text): """ Correct spelling in Arabic text. Args: text: Input Arabic text Returns: str: Corrected text """ if not hasattr(self, 'engine'): # This should have happened in __init__ -> _load_model, but just in case logger.warning("Engine not found, reloading spelling model...") self._load_model() try: # Use the integrated AraSpell pipeline (Pre-process -> Generate -> Rerank -> Post-process) return self.engine.correct(text) except Exception as e: logger.error(f"Error during spelling correction: {str(e)}") # In case of error, return original text to avoid crashing the app flow logger.warning("Returning original text due to error.") return text class AutocompleteModel: """Wrapper class for the Arabic autocomplete model.""" def __init__(self, model_path=None, lazy=True): """ Initialize the autocomplete model. Args: model_path: Path to the model directory (defaults to AUTOCOMPLETE_PATH) lazy: If True, models will be loaded on first use instead of initialization """ self.model_path = Path(model_path) if model_path else AUTOCOMPLETE_PATH # GPT-2 only components (no bigram .pkl on this machine) self.bigram_model = None # kept for backward compatibility (unused) self.ngram_model = None # unused self.unigrams = None # unused self.bigrams = None # unused self._hybrid = None # reference to hybrid_module (for GPT-2 helpers) self._gpt_tokenizer = None self._gpt_model = None self.lazy = lazy self.enabled = os.environ.get("LOAD_AUTOCOMPLETE", "false").lower() == "true" if not self.enabled: logger.info("Autocomplete model is disabled (LOAD_AUTOCOMPLETE=false)") return self._validate_path() if not self.lazy: self._load_model() def _validate_path(self): """Validate that the model path exists.""" if not self.model_path.exists(): raise FileNotFoundError(f"Model path does not exist: {self.model_path}") logger.info(f"Autocomplete model path validated: {self.model_path}") def _load_model(self): """Load GPT-2 autocomplete model using hybrid_module (GPT-2 only, no bigram .pkl).""" try: logger.info("Loading autocomplete models...") # Load hybrid_module.py dynamically from the model directory hybrid_path = self.model_path / "hybrid_module.py" if not hybrid_path.exists(): raise FileNotFoundError(f"hybrid_module.py not found at: {hybrid_path}") spec = importlib.util.spec_from_file_location("autocomplete_hybrid", str(hybrid_path)) if spec is None or spec.loader is None: raise ImportError(f"Could not load spec for hybrid_module from: {hybrid_path}") hybrid = importlib.util.module_from_spec(spec) spec.loader.exec_module(hybrid) self._hybrid = hybrid logger.info("hybrid_module imported successfully") # Load GPT-2 model (on CPU) for autocomplete. if hasattr(hybrid, "load_gpt2"): try: logger.info("Loading GPT-2 model for autocomplete (CPU, GPT-2 only mode)...") tokenizer, model = hybrid.load_gpt2() # Force CPU to avoid GPU OOM / kills model.to(torch.device("cpu")) self._gpt_tokenizer = tokenizer self._gpt_model = model logger.info("GPT-2 model loaded successfully for autocomplete (CPU)") except Exception as gpt_err: logger.error(f"Failed to load GPT-2 for autocomplete: {gpt_err}") raise else: raise RuntimeError("hybrid_module.load_gpt2 not found; cannot run GPT-2-only autocomplete") logger.info("Autocomplete GPT-2-only model loaded successfully") except Exception as e: logger.error(f"Error loading autocomplete models: {str(e)}") raise RuntimeError(f"Failed to load autocomplete models: {str(e)}") def predict(self, text, n=5): """ Predict next words for autocomplete. Args: text: Input Arabic text n: Number of suggestions to return Returns: list: List of suggested completions """ if not self.enabled: return [] if self._hybrid is None or self._gpt_model is None or self._gpt_tokenizer is None: if self.lazy: try: self._load_model() except Exception as e: logger.error(f"Lazy loading of autocomplete failed: {str(e)}") self.enabled = False return [] else: raise RuntimeError("Model not loaded and lazy loading is off") try: # GPT-2 only autocomplete using hybrid_module.gpt2_next_token_probs if not self._hybrid or self._gpt_model is None or self._gpt_tokenizer is None: raise RuntimeError("GPT-2 autocomplete components not loaded") logger.info( f"[Autocomplete] predict called | mode=gpt2-only | text_tail='{text[-50:]}'" ) if not hasattr(self._hybrid, "gpt2_next_token_probs"): raise RuntimeError("hybrid_module.gpt2_next_token_probs not found") # Get token probability dict from GPT-2 prob_dict = self._hybrid.gpt2_next_token_probs( text, self._gpt_tokenizer, self._gpt_model, top_k=max(n * 2, 10), # grab a few extra for diversity ) # Convert dict to sorted list of (token, prob) preds = sorted(prob_dict.items(), key=lambda x: x[1], reverse=True) logger.info(f"[Autocomplete] raw GPT-2 predictions (top 5): {preds[:5]}") # Extract top-n tokens as suggestions suggestions = [w for (w, _p) in preds[:n] if w] logger.info(f"[Autocomplete] suggestions returned to API: {suggestions}") return suggestions except Exception as e: logger.error(f"Error during autocomplete prediction: {str(e)}") raise RuntimeError(f"Autocomplete prediction failed: {str(e)}") class GrammarModel: """Wrapper class for the Arabic grammar correction model.""" def __init__(self, model_path=None): """ Initialize the grammar model. Args: model_path: Path to the model directory (defaults to GRAMMAR_PATH) """ self.model_path = Path(model_path) if model_path else GRAMMAR_PATH self.model = None self.tokenizer = None self.device = None self._validate_path() self._load_model() def _validate_path(self): """Validate that the model path exists and contains required files.""" if not self.model_path.exists(): raise FileNotFoundError(f"Model path does not exist: {self.model_path}") required_files = ['config.json', 'tokenizer.json', 'model.safetensors'] missing_files = [f for f in required_files if not (self.model_path / f).exists()] if missing_files: raise FileNotFoundError(f"Missing required files: {', '.join(missing_files)}") logger.info(f"Grammar model path validated: {self.model_path}") def _load_model(self): """Load the grammar model and tokenizer.""" try: # Force CPU-only to avoid GPU OOM / system freezes self.device = torch.device('cpu') logger.info("Loading grammar model on CPU (GPU disabled by design)...") self.tokenizer = AutoTokenizer.from_pretrained( str(self.model_path), local_files_only=True ) self.model = AutoModelForCausalLM.from_pretrained( str(self.model_path), local_files_only=True, trust_remote_code=True, torch_dtype=torch.float32 # safe default for CPU inference ) self.model.to(self.device) self.model.eval() logger.info("Grammar model loaded successfully") except Exception as e: logger.error(f"Error loading grammar model: {str(e)}") raise RuntimeError(f"Failed to load grammar model: {str(e)}") def correct(self, text): """ Correct grammar in Arabic text with a timeout. Args: text: Input Arabic text Returns: str: Grammar-corrected text """ if self.model is None or self.tokenizer is None: raise RuntimeError("Model not loaded") import threading result = [text] # default to original error = [None] def _generate(): try: # Use Gemma 3 chat template — pass text only (model is GEC-trained) messages = [{"role": "user", "content": text}] prompt = self.tokenizer.apply_chat_template( messages, tokenize=False, add_generation_prompt=True, ) inputs = self.tokenizer( prompt, max_length=256, truncation=True, padding=True, return_tensors="pt", add_special_tokens=False, ) inputs = {k: v.to(self.device) for k, v in inputs.items()} input_length = inputs['input_ids'].shape[1] with torch.no_grad(): outputs = self.model.generate( **inputs, max_new_tokens=64, # enough for corrected sentence do_sample=False, ) # Decode only the NEW tokens (skip the prompt) new_tokens = outputs[0][input_length:] corrected = self.tokenizer.decode(new_tokens, skip_special_tokens=True).strip() # If model returned empty or nonsense, keep original if not corrected or len(corrected) < 2: return # Reject generic instruction phrases (model fallback, not actual correction) generic_phrases = ( "أعد كتابتها", "أعد كتابته", "أعد كتابتها.", "أعد كتابته.", "اعيد كتابتها", "اعيد كتابته", "أعد كتابة", "اعيد كتابة", "أعد كتابتها فقط", "أعد كتابته فقط", ) corrected_lower = corrected.strip() for phrase in generic_phrases: if phrase in corrected_lower or corrected_lower.startswith(phrase): return # Take only the first non-empty line as the corrected sentence corrected_lines = [l.strip() for l in corrected.split('\n') if l.strip()] if corrected_lines: first = corrected_lines[0] if first and len(first) >= 2 and first not in generic_phrases: result[0] = first except Exception as e: error[0] = e # Run generation in a thread with a 30s timeout thread = threading.Thread(target=_generate) thread.start() thread.join(timeout=30) if thread.is_alive(): logger.warning("[Grammar] Timed out after 30s — returning original text") return text if error[0]: logger.error(f"Error during grammar correction: {str(error[0])}") logger.warning("Returning original text due to grammar error.") return result[0] class PunctuationModel: """Wrapper class for the Arabic punctuation model.""" def __init__(self, model_path=None): """ Initialize the punctuation model. Args: model_path: Path to the model directory (defaults to PUNCTUATION_PATH) """ self.model_path = Path(model_path) if model_path else PUNCTUATION_PATH self.model = None self.tokenizer = None self.device = None self._validate_path() self._load_model() def _validate_path(self): """Validate that the model path exists and contains required files.""" if not self.model_path.exists(): raise FileNotFoundError(f"Model path does not exist: {self.model_path}") required_files = ['config.json', 'tokenizer.json', 'model.safetensors'] missing_files = [f for f in required_files if not (self.model_path / f).exists()] if missing_files: raise FileNotFoundError(f"Missing required files: {', '.join(missing_files)}") logger.info(f"Punctuation model path validated: {self.model_path}") def _load_model(self): """Load the punctuation model and tokenizer.""" try: self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') logger.info(f"Loading punctuation model on {self.device}...") self.tokenizer = AutoTokenizer.from_pretrained( str(self.model_path), local_files_only=True ) self.model = AutoModelForSeq2SeqLM.from_pretrained( str(self.model_path), local_files_only=True ) self.model.to(self.device) self.model.eval() logger.info("Punctuation model loaded successfully") except Exception as e: logger.error(f"Error loading punctuation model: {str(e)}") raise RuntimeError(f"Failed to load punctuation model: {str(e)}") def add_punctuation(self, text): """ Add punctuation to Arabic text. Args: text: Input Arabic text without punctuation Returns: str: Text with punctuation added """ if self.model is None or self.tokenizer is None: raise RuntimeError("Model not loaded") try: inputs = self.tokenizer( text, max_length=512, truncation=True, padding=True, return_tensors="pt" ) inputs = {k: v.to(self.device) for k, v in inputs.items()} # Since it's an EncoderDecoderModel, we use generate() with torch.no_grad(): outputs = self.model.generate( **inputs, max_length=512, decoder_start_token_id=self.tokenizer.cls_token_id, eos_token_id=self.tokenizer.sep_token_id, pad_token_id=self.tokenizer.pad_token_id ) punctuated = self.tokenizer.decode(outputs[0], skip_special_tokens=True) return punctuated.strip() except Exception as e: logger.error(f"Error during punctuation: {str(e)}") logger.warning("Returning original text due to punctuation error.") return text