"""
LoRA adapter hot-swap manager.

Uses PEFT's multi-adapter API:
- model.load_adapter(path, adapter_name=lang) — first load (~2s per adapter)
- model.set_adapter(lang) — subsequent swap (~50ms)

This keeps a single backbone in VRAM and swaps only the ~50MB adapter weights,
vs reloading the full 1.5GB model per language.
"""
from __future__ import annotations

import logging
from pathlib import Path
from typing import TYPE_CHECKING

from peft import PeftModel

if TYPE_CHECKING:
    from transformers import WhisperForConditionalGeneration

logger = logging.getLogger(__name__)
| class AdapterManager: | |
| """Manages registration and hot-swapping of LoRA language adapters.""" | |
| def __init__(self, base_model: "WhisperForConditionalGeneration", config: dict) -> None: | |
| self._base_model = base_model | |
| self._config = config | |
| self._registry: dict[str, str] = {} # language_code -> adapter_path | |
| self._peft_model: PeftModel | None = None | |
| self._active: str | None = None | |
| def register(self, language: str, adapter_path: str) -> None: | |
| """Register an adapter path. Does not load it yet.""" | |
| path = Path(adapter_path) | |
| if not path.exists(): | |
| logger.warning( | |
| "Adapter path '%s' for language '%s' does not exist. " | |
| "Run training first, or check the path.", | |
| adapter_path, language, | |
| ) | |
| self._registry[language] = str(path) | |
| logger.info("Registered adapter '%s' → %s", language, adapter_path) | |
| def load_adapter(self, language: str) -> None: | |
| """ | |
| Load an adapter into the model for the first time. | |
| Slow (~2s): reads adapter weights from disk. | |
| Subsequent activate() calls reuse the already-loaded weights. | |
| """ | |
| if language not in self._registry: | |
| raise KeyError(f"No adapter registered for language '{language}'. " | |
| f"Available: {list(self._registry)}") | |
| adapter_path = self._registry[language] | |
| if self._peft_model is None: | |
| # First adapter: wrap the base model with PeftModel | |
| logger.info("Wrapping base model with first adapter '%s'...", language) | |
| self._peft_model = PeftModel.from_pretrained( | |
| self._base_model, | |
| adapter_path, | |
| adapter_name=language, | |
| ) | |
| else: | |
| # Subsequent adapters: load into the existing PeftModel | |
| logger.info("Loading adapter '%s' into existing PeftModel...", language) | |
| self._peft_model.load_adapter(adapter_path, adapter_name=language) | |
| self._active = language | |
| logger.info("Adapter '%s' loaded and active.", language) | |
| def activate(self, language: str) -> None: | |
| """ | |
| Hot-swap to a previously loaded adapter (~50ms). | |
| Call load_adapter() first if this adapter hasn't been loaded. | |
| """ | |
| if self._peft_model is None: | |
| self.load_adapter(language) | |
| return | |
| loaded = set(self._peft_model.peft_config.keys()) | |
| if language not in loaded: | |
| self.load_adapter(language) | |
| return | |
| self._peft_model.set_adapter(language) | |
| self._active = language | |
| logger.debug("Hot-swapped to adapter '%s'.", language) | |
| def get_model(self) -> "WhisperForConditionalGeneration | PeftModel": | |
| """Return the PeftModel (or base model if no adapter loaded yet).""" | |
| return self._peft_model if self._peft_model is not None else self._base_model | |
| def get_active(self) -> str | None: | |
| return self._active | |
| def list_available(self) -> list[str]: | |
| return list(self._registry.keys()) | |
| def list_loaded(self) -> list[str]: | |
| if self._peft_model is None: | |
| return [] | |
| return list(self._peft_model.peft_config.keys()) | |