diff --git a/.claude/settings.local.json b/.claude/settings.local.json new file mode 100644 index 0000000000000000000000000000000000000000..12587fc189df5c71df36ea522c04008085fa2188 --- /dev/null +++ b/.claude/settings.local.json @@ -0,0 +1,7 @@ +{ + "permissions": { + "allow": [ + "Bash(pip show:*)" + ] + } +} diff --git a/.env.example b/.env.example new file mode 100644 index 0000000000000000000000000000000000000000..b9bf3ba7e96a5522d9f89362f86cbf0b76e4479f --- /dev/null +++ b/.env.example @@ -0,0 +1,20 @@ +# HuggingFace read token (required for accessing google/waxal dataset) +HF_TOKEN=hf_your_token_here + +# Model +MODEL_ID=openai/whisper-large-v3-turbo + +# Adapter paths (relative to project root) +BAMBARA_ADAPTER_PATH=./adapters/bambara +FULA_ADAPTER_PATH=./adapters/fula + +# IoT sensor API endpoint (leave empty to use mock data in development) +SENSOR_API_URL= + +# FastAPI server +API_HOST=0.0.0.0 +API_PORT=8000 +LOG_LEVEL=INFO + +# Device: "cuda" for GPU, "cpu" for CPU-only +DEVICE=cuda diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..27956a691f4b2a18736838014870cc2b37d0e223 --- /dev/null +++ b/.gitignore @@ -0,0 +1,66 @@ +# Python +__pycache__/ +*.py[cod] +*.pyo +*.pyd +.Python +*.egg-info/ +dist/ +build/ +.eggs/ + +# Environment +.env +venv/ +.venv/ +env/ + +# Model weights (large binary files) +*.pt +*.pth +*.bin +*.safetensors +*.ckpt + +# ONNX / TFLite exports +*.onnx +*.tflite +models/onnx/ +models/tflite/ + +# HuggingFace cache +data_cache/ +.cache/ + +# Audio noise samples (user must provide their own) +noise_samples/*.wav +noise_samples/*.mp3 +noise_samples/*.ogg + +# Trained adapters (tracked separately or via DVC) +adapters/bambara/ +adapters/fula/ + +# IDE +.vscode/settings.json +.idea/ +*.code-workspace + +# OS +.DS_Store +Thumbs.db + +# Logs +*.log +logs/ + +# Local feedback data (audio + corrections live in HF Dataset repo, not git) +feedback/ + +# Local model downloads +models/ + +# Pytest +.pytest_cache/ +htmlcov/ +.coverage diff --git a/.vscode/extensions.json b/.vscode/extensions.json new file mode 100644 index 0000000000000000000000000000000000000000..20b81770f4c1d7dcdb7595e11eb3e642a4f5a48d --- /dev/null +++ b/.vscode/extensions.json @@ -0,0 +1,5 @@ +{ + "recommendations": [ + "anthropic.claude-code" + ] +} \ No newline at end of file diff --git a/README.md b/README.md new file mode 100644 index 0000000000000000000000000000000000000000..930c751c9530c249578e97d03312603738112ec2 --- /dev/null +++ b/README.md @@ -0,0 +1,39 @@ +--- +title: Sahel-Agri Voice AI +emoji: 🌾 +colorFrom: green +colorTo: yellow +sdk: gradio +sdk_version: "4.44.0" +app_file: app.py +hardware: cpu-basic +pinned: false +license: mit +tags: + - agriculture + - bambara + - fula + - speech-recognition + - text-to-speech + - west-africa + - low-resource-nlp +--- + +# 🌾 Sahel-Agri Voice AI + +Two-way voice assistant for Malian and Guinean farmers. Speak in **Bambara** or **Fula** β€” get agricultural insights spoken back in your language. 
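+
+## Run locally
+
+A minimal local smoke test (a sketch: assumes `ffmpeg` is available on your machine, as on Spaces via `packages.txt`; without `HF_TOKEN`, feedback is saved to `./feedback/` instead of the Hub):
+
+```bash
+cp .env.example .env             # set HF_TOKEN if you have one
+pip install -r requirements.txt
+python app.py                    # Gradio UI at http://localhost:9001
+```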
+ +## Features +- πŸŽ™οΈ Voice input via microphone or file upload +- 🌍 Bambara (bam) and Fula (ful) speech recognition via Whisper + LoRA adapters +- πŸ”Š Native-language voice responses via Facebook MMS-TTS +- πŸ“Š Soil, weather, irrigation, and pest alerts from IoT sensors +- πŸ’Ύ Feedback saved to HuggingFace Dataset for continuous improvement + +## Languages supported +| Language | STT | TTS | +|----------|-----|-----| +| Bambara (bam) | βœ… Whisper + LoRA | βœ… facebook/mms-tts-bam | +| Fula (ful) | βœ… Whisper + LoRA | βœ… facebook/mms-tts-ful | +| French (fr) | βœ… Whisper | βœ… facebook/mms-tts-fra | +| English (en) | βœ… Whisper | βœ… facebook/mms-tts-eng | diff --git a/app.py b/app.py new file mode 100644 index 0000000000000000000000000000000000000000..f63075e79fb9e49ebbda237da0e18fe7be6570dd --- /dev/null +++ b/app.py @@ -0,0 +1,611 @@ +""" +Sahel-Agri Voice AI β€” HuggingFace Spaces (ZeroGPU) +Two-way voice assistant: Bambara / Fula / French / English β†’ voice response + +Environment variables (set in Space Settings β†’ Secrets): + HF_TOKEN β€” HF write-access token + FEEDBACK_REPO_ID β€” e.g. ous-sow/sahel-agri-feedback (dataset, private) + ADAPTER_REPO_ID β€” e.g. ous-sow/sahel-agri-adapters (model, private) + WHISPER_MODEL_ID β€” default: openai/whisper-large-v3-turbo + (use openai/whisper-base for local CPU testing) +""" + +from __future__ import annotations + +import io +import json +import os +import sys +import tempfile +import threading +from datetime import datetime, timezone +from pathlib import Path + +import gradio as gr +import numpy as np + +ROOT = Path(__file__).parent +sys.path.insert(0, str(ROOT)) + +# ── env ─────────────────────────────────────────────────────────────────────── +HF_TOKEN = os.environ.get("HF_TOKEN") +FEEDBACK_REPO_ID = os.environ.get("FEEDBACK_REPO_ID", "ous-sow/sahel-agri-feedback") +ADAPTER_REPO_ID = os.environ.get("ADAPTER_REPO_ID", "ous-sow/sahel-agri-adapters") +# whisper-small: ~10s on cpu-basic, good multilingual quality. +# Override via WHISPER_MODEL_ID env var if you upgrade to a GPU Space later. +WHISPER_MODEL_ID = os.environ.get("WHISPER_MODEL_ID", "openai/whisper-small") + +# On local CPU (no HF_TOKEN / no spaces package) fall back gracefully +_ON_SPACES = os.environ.get("SPACE_ID") is not None + +SUPPORTED_LANGUAGES = { + "Bambara (bam)": "bam", + "Fula (ful)": "ful", + "French / FranΓ§ais": "fr", + "English": "en", +} + +# ── ZeroGPU decorator (no-op locally) ──────────────────────────────────────── +try: + import spaces # type: ignore + _gpu = spaces.GPU(duration=55) +except ImportError: + def _gpu(fn): # local fallback: plain function + return fn + +# ── Module-level model state (CPU-resident between requests) ───────────────── +_whisper_model = None # WhisperForConditionalGeneration (base) +_whisper_processor = None +_adapter_manager = None # AdapterManager (wraps base model with PEFT if adapters loaded) +_model_lock = threading.Lock() +_model_status = "not loaded" +_adapters_loaded = set() # set of language codes with loaded adapters, e.g. 
{"bam", "ful"} + +from src.tts.mms_tts import MMSTTSEngine +from src.iot.intent_parser import IntentParser +from src.iot.sensor_bridge import SensorBridge +from src.iot.voice_responder import VoiceResponder + +_tts = MMSTTSEngine() +_intent_parser = IntentParser() +_sensor_bridge = SensorBridge() + +# HF API β€” only instantiate when token present +_hf_api = None +if HF_TOKEN: + from huggingface_hub import HfApi + _hf_api = HfApi(token=HF_TOKEN) + + +# ── Model loading ───────────────────────────────────────────────────────────── + +def _do_load_whisper(): + global _whisper_model, _whisper_processor, _adapter_manager, _model_status + import torch + from transformers import WhisperForConditionalGeneration, WhisperProcessor + from src.engine.adapter_manager import AdapterManager + + _model_status = "loading…" + try: + _whisper_processor = WhisperProcessor.from_pretrained( + WHISPER_MODEL_ID, token=HF_TOKEN + ) + _whisper_model = WhisperForConditionalGeneration.from_pretrained( + WHISPER_MODEL_ID, + torch_dtype=torch.float32, + token=HF_TOKEN, + ) + _whisper_model.eval() + + # Create the AdapterManager wrapping the base model + _adapter_manager = AdapterManager(base_model=_whisper_model, config={}) + + # Try to load adapters from the local adapter repo snapshot (if already downloaded) + _try_load_local_adapters() + + _model_status = f"ready ({WHISPER_MODEL_ID})" + except Exception as e: + _model_status = f"error: {e}" + + +def _try_load_local_adapters() -> None: + """Load any adapter snapshots that are already on disk (downloaded previously).""" + global _adapters_loaded + if _adapter_manager is None: + return + if not ADAPTER_REPO_ID: + return + try: + from huggingface_hub import try_to_load_from_cache + lang_dirs = {"bam": "adapters/bambara", "ful": "adapters/fula"} + for lang, subdir in lang_dirs.items(): + cached = try_to_load_from_cache( + repo_id=ADAPTER_REPO_ID, + filename=f"{subdir}/adapter_config.json", + repo_type="model", + token=HF_TOKEN, + ) + if cached: + import os + adapter_path = str(os.path.dirname(cached)) + _adapter_manager.register(lang, adapter_path) + try: + _adapter_manager.load_adapter(lang) + _adapters_loaded.add(lang) + except Exception: + pass + except Exception: + pass # Adapters not cached yet β€” will load after first Hub download + + +def _ensure_whisper_loaded(): + """Load Whisper to CPU in a background thread on first call. Non-blocking.""" + global _model_status + with _model_lock: + if _whisper_model is None and "loading" not in _model_status and "error" not in _model_status: + t = threading.Thread(target=_do_load_whisper, daemon=True) + t.start() + return _model_status + + +def get_model_status() -> str: + s = _ensure_whisper_loaded() + if "ready" in s: + return f"🟒 {s}" + if "loading" in s: + return f"🟑 {s}" + if "error" in s: + return f"πŸ”΄ {s}" + return f"βšͺ {s}" + + +# ── Core GPU pipeline ───────────────────────────────────────────────────────── + +@_gpu +def _run_pipeline(audio_path: str, language_code: str): + """ + Full STT β†’ Intent β†’ Sensor β†’ TTS pipeline. + Decorated with @spaces.GPU(duration=55) on HF Spaces; plain function locally. + Returns: (transcript, response_text, (sample_rate, wav_np)) + """ + import asyncio + import torch + + device = "cuda" if torch.cuda.is_available() else "cpu" + + # ── 1. 
Whisper STT ──────────────────────────────────────────────────────── + if _whisper_model is None: + return "⏳ Model still loading…", "", None + + import librosa + + audio_np, _ = librosa.load(audio_path, sr=16000, mono=True) + + # Use adapter-wrapped model if an adapter for this language is loaded; + # otherwise fall back to base Whisper. + if _adapter_manager is not None and language_code in _adapters_loaded: + _adapter_manager.activate(language_code) + active_model = _adapter_manager.get_model() + else: + active_model = _whisper_model + + active_model.to(device) + with _model_lock: + inputs = _whisper_processor.feature_extractor( + audio_np, sampling_rate=16000, return_tensors="pt" + ) + input_features = inputs.input_features.to(device) + + # Bambara and Fula have no Whisper language token β€” pass None so the model + # auto-detects or falls back to multilingual decoding. + if language_code in ("bam", "ful"): + forced_ids = None + else: + forced_ids = _whisper_processor.get_decoder_prompt_ids( + language=language_code, task="transcribe" + ) + + with torch.no_grad(): + predicted_ids = active_model.generate( + input_features, + forced_decoder_ids=forced_ids if forced_ids else None, + max_new_tokens=256, + ) + + transcript = _whisper_processor.batch_decode( + predicted_ids, skip_special_tokens=True + )[0].strip() + + # Free GPU VRAM before TTS + active_model.to("cpu") + if device == "cuda": + torch.cuda.empty_cache() + + # ── 2. Intent + sensor data (CPU) ───────────────────────────────────────── + intent = _intent_parser.parse(transcript, language=language_code) + + try: + loop = asyncio.new_event_loop() + sensor_data = loop.run_until_complete(_sensor_bridge.fetch(intent)) + loop.close() + except Exception: + from src.iot.sensor_bridge import SensorData + sensor_data = SensorData(sensor_type="soil", values={ + "moisture_pct": 45.0, "ph": 6.5, "temperature_c": 28.0 + }) + + responder = VoiceResponder(language=language_code) + response_text = responder.generate_response(intent, sensor_data) + + # ── 3. MMS-TTS (GPU) ────────────────────────────────────────────────────── + wav_np, sample_rate = _tts.synthesize(response_text, language_code, device=device) + + return transcript, response_text, (sample_rate, wav_np) + + +# ── HF Hub feedback persistence ─────────────────────────────────────────────── + +def _save_feedback_to_hub( + audio_path: str | None, + transcript: str, + corrected_text: str, + response_text: str, + rating: int, + notes: str, + language_label: str, +) -> str: + language_code = SUPPORTED_LANGUAGES.get(language_label, "bam") + + if not corrected_text.strip(): + return "⚠️ Corrected text is empty." 
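+
+    # Each submission becomes one JSON line appended to corrections.jsonl;
+    # train_colab.ipynb later filters records by `is_correction` and
+    # `language` when assembling the fine-tuning set.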
+ + timestamp = datetime.now(timezone.utc).strftime("%Y%m%d_%H%M%S_%f") + + record = { + "id": timestamp, + "timestamp": datetime.now(timezone.utc).isoformat(), + "language": language_code, + "audio_file": f"audio/{language_code}_{timestamp}.wav", + "whisper_output": transcript, + "corrected_text": corrected_text.strip(), + "response_text": response_text, + "rating": rating, + "notes": notes.strip(), + "is_correction": transcript.strip() != corrected_text.strip(), + "model": WHISPER_MODEL_ID, + } + + if _hf_api is None: + # Local: save to disk instead + fb_dir = ROOT / "feedback" + fb_dir.mkdir(exist_ok=True) + (fb_dir / "audio").mkdir(exist_ok=True) + corrections_path = fb_dir / "corrections.jsonl" + if audio_path: + import shutil + shutil.copy2(audio_path, fb_dir / "audio" / f"{language_code}_{timestamp}.wav") + with open(corrections_path, "a", encoding="utf-8") as f: + f.write(json.dumps(record, ensure_ascii=False) + "\n") + total = sum(1 for _ in open(corrections_path, encoding="utf-8")) + return f"βœ… Saved locally (#{total}) β€” HF_TOKEN not set, Hub upload skipped." + + try: + # Upload audio + if audio_path: + _hf_api.upload_file( + path_or_fileobj=audio_path, + path_in_repo=f"audio/{language_code}_{timestamp}.wav", + repo_id=FEEDBACK_REPO_ID, + repo_type="dataset", + ) + + # Download β†’ append β†’ re-upload corrections.jsonl (with retry on conflict) + from huggingface_hub import hf_hub_download + for attempt in range(2): + try: + local_jsonl = hf_hub_download( + repo_id=FEEDBACK_REPO_ID, + filename="corrections.jsonl", + repo_type="dataset", + token=HF_TOKEN, + ) + with open(local_jsonl, encoding="utf-8") as f: + existing = f.read() + except Exception: + existing = "" + + updated = existing + json.dumps(record, ensure_ascii=False) + "\n" + buf = io.BytesIO(updated.encode("utf-8")) + + try: + _hf_api.upload_file( + path_or_fileobj=buf, + path_in_repo="corrections.jsonl", + repo_id=FEEDBACK_REPO_ID, + repo_type="dataset", + ) + break + except Exception as e: + if attempt == 1: + return f"⚠️ Audio uploaded but corrections.jsonl update failed: {e}" + + total = updated.count("\n") + return f"βœ… Saved to Hub (#{total}) β€” {FEEDBACK_REPO_ID}" + + except Exception as e: + return f"❌ Hub upload error: {e}" + + +# ── Adapter reload ──────────────────────────────────────────────────────────── + +def _reload_adapters_from_hub() -> str: + global _adapters_loaded + if _hf_api is None: + return "⚠️ HF_TOKEN not set β€” cannot download adapters." + if _adapter_manager is None: + return "⏳ Base model not loaded yet β€” wait for model to finish loading and try again." 
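+    # snapshot_download pulls the whole adapter repo once; each language
+    # subdirectory is then checked for a valid PEFT adapter_config.json
+    # before being registered and loaded.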
+ try: + from huggingface_hub import snapshot_download + local_dir = snapshot_download( + repo_id=ADAPTER_REPO_ID, repo_type="model", token=HF_TOKEN + ) + results = [] + for lang, subdir in (("bam", "adapters/bambara"), ("ful", "adapters/fula")): + adapter_path = Path(local_dir) / subdir + if not adapter_path.exists(): + results.append(f"⚠️ {lang}: `{subdir}` not found in repo") + continue + # Check that this looks like a valid PEFT adapter + if not (adapter_path / "adapter_config.json").exists(): + results.append(f"⚠️ {lang}: `{subdir}` missing adapter_config.json β€” run training first") + continue + try: + _adapter_manager.register(lang, str(adapter_path)) + _adapter_manager.load_adapter(lang) + _adapters_loaded.add(lang) + results.append(f"βœ… {lang}: adapter loaded from `{subdir}`") + except Exception as e: + results.append(f"❌ {lang}: load failed β€” {e}") + + summary = "\n".join(results) + active = ", ".join(_adapters_loaded) if _adapters_loaded else "none" + return f"{summary}\n\n**Active adapters:** {active}\n**Repo:** `{ADAPTER_REPO_ID}`" + except Exception as e: + return f"❌ Adapter reload failed: {e}" + + +def _get_adapter_status() -> str: + lines = [] + + # Show which adapters are currently active in memory + if _adapters_loaded: + lines.append(f"**Active adapters (in memory):** {', '.join(sorted(_adapters_loaded))}") + else: + lines.append("**Active adapters:** none β€” using base Whisper") + + if _hf_api is None: + lines.append("_HF_TOKEN not set β€” Hub check skipped._") + return "\n".join(lines) + + try: + from huggingface_hub import list_repo_files + files = list(list_repo_files(ADAPTER_REPO_ID, repo_type="model", token=HF_TOKEN)) + bam_ok = any("bambara" in f and "adapter_config" in f for f in files) + ful_ok = any("fula" in f and "adapter_config" in f for f in files) + lines += [ + f"\n**Hub repo:** `{ADAPTER_REPO_ID}`", + f"- Bambara (bam): {'βœ… trained adapter present' if bam_ok else '⚠️ not yet trained β€” run bootstrap notebook'}", + f"- Fula (ful): {'βœ… trained adapter present' if ful_ok else '⚠️ not yet trained β€” run bootstrap notebook'}", + ] + if bam_ok or ful_ok: + lines.append("\n_Click **Reload Adapters** to activate them._") + except Exception as e: + lines.append(f"_Could not read Hub repo: {e}_") + + return "\n".join(lines) + + +# ── Main ask handler ────────────────────────────────────────────────────────── + +def handle_ask(audio_path, language_label): + if audio_path is None: + return "⚠️ No audio β€” press Record or upload a file.", "", None + + language_code = SUPPORTED_LANGUAGES.get(language_label, "bam") + status = _ensure_whisper_loaded() + + if _whisper_model is None: + return f"⏳ Model loading ({status}). Wait a moment and try again.", "", None + + try: + transcript, response_text, audio_out = _run_pipeline(audio_path, language_code) + return transcript, response_text, audio_out + except Exception as e: + return f"❌ {e}", "", None + + +# ── Gradio UI ───────────────────────────────────────────────────────────────── + +def build_ui() -> gr.Blocks: + with gr.Blocks(title="Sahel-Agri Voice AI") as demo: + gr.Markdown("# 🌾 Sahel-Agri Voice AI") + gr.Markdown( + "Speak in **Bambara** or **Fula** β€” get agricultural insights spoken back " + "in your language. Also supports French and English." 
+ ) + + model_status_box = gr.Textbox( + value=get_model_status, + label="Model status", + interactive=False, + every=3, + ) + + with gr.Tabs(): + + # ── Tab 1: Voice Assistant ──────────────────────────────────────── + with gr.TabItem("πŸŽ™οΈ Voice Assistant"): + with gr.Row(): + with gr.Column(scale=1): + language_dd = gr.Dropdown( + choices=list(SUPPORTED_LANGUAGES.keys()), + value="Bambara (bam)", + label="Language / Kan", + ) + audio_input = gr.Audio( + sources=["microphone", "upload"], + type="filepath", + label="Record or upload audio", + ) + ask_btn = gr.Button("β–Ά Ask / ƝinΙ›", variant="primary") + + with gr.Column(scale=1): + transcript_box = gr.Textbox( + label="Whisper heard", + lines=3, + placeholder="Your words will appear here…", + interactive=False, + ) + response_box = gr.Textbox( + label="Response / Jaabi", + lines=3, + placeholder="Agricultural advice will appear here…", + interactive=False, + ) + audio_output = gr.Audio( + label="Voice response", + autoplay=True, + interactive=False, + ) + + ask_btn.click( + fn=handle_ask, + inputs=[audio_input, language_dd], + outputs=[transcript_box, response_box, audio_output], + ) + + # ── Tab 2: Feedback & Correction ───────────────────────────────── + with gr.TabItem("πŸ“ Feedback & Correction"): + gr.Markdown( + "Help improve the model by correcting transcription errors. " + "Your audio and corrections are saved to the training dataset." + ) + with gr.Row(): + with gr.Column(): + fb_lang = gr.Dropdown( + choices=list(SUPPORTED_LANGUAGES.keys()), + value="Bambara (bam)", + label="Language", + ) + fb_audio = gr.Audio( + sources=["microphone", "upload"], + type="filepath", + label="Audio (re-record or upload)", + ) + fb_transcript = gr.Textbox( + label="Whisper output (what it heard)", + lines=3, + placeholder="Paste or type what Whisper said…", + ) + fb_corrected = gr.Textbox( + label="Corrected transcription (what was actually said)", + lines=3, + placeholder="Type the correct text here…", + ) + + with gr.Column(): + fb_response = gr.Textbox( + label="Response text (optional β€” for rating)", + lines=2, + placeholder="Copy the response from Tab 1…", + ) + fb_rating = gr.Slider( + minimum=1, maximum=5, step=1, value=3, + label="Response quality (1 = poor, 5 = excellent)", + ) + fb_notes = gr.Textbox( + label="Notes (optional)", + lines=2, + placeholder="e.g. noisy background, strong accent…", + ) + save_btn = gr.Button("πŸ’Ύ Save to Dataset", variant="secondary") + save_status = gr.Textbox( + label="Save status", interactive=False, lines=2 + ) + + save_btn.click( + fn=_save_feedback_to_hub, + inputs=[ + fb_audio, fb_transcript, fb_corrected, + fb_response, fb_rating, fb_notes, fb_lang, + ], + outputs=[save_status], + ) + + # ── Tab 3: Training Status ──────────────────────────────────────── + with gr.TabItem("πŸ”§ Training Status"): + gr.Markdown( + "After collecting β‰₯10 corrections per language, run the training " + "notebook on Google Colab (free GPU), then reload adapters here." + ) + adapter_status_md = gr.Markdown(value=_get_adapter_status()) + reload_btn = gr.Button("πŸ”„ Reload Adapters from Hub") + reload_out = gr.Markdown() + + gr.Markdown("---") + gr.Markdown( + "**Training notebook**: " + "`notebooks/train_colab.ipynb` β€” open in Colab, run all cells." 
+ ) + gr.Markdown( + "**Feedback dataset**: " + f"`{FEEDBACK_REPO_ID}` (private, auto-updated on each save)" + ) + gr.Markdown( + "**Adapter repo**: " + f"`{ADAPTER_REPO_ID}` (private, updated after each training run)" + ) + + reload_btn.click( + fn=_reload_adapters_from_hub, + outputs=[reload_out], + ) + reload_btn.click( + fn=_get_adapter_status, + outputs=[adapter_status_md], + ) + + return demo + + +# ── Entry point ─────────────────────────────────────────────────────────────── + +if __name__ == "__main__": + from dotenv import load_dotenv + load_dotenv() + + # Re-read env after dotenv + HF_TOKEN = os.environ.get("HF_TOKEN") + FEEDBACK_REPO_ID = os.environ.get("FEEDBACK_REPO_ID", "ous-sow/sahel-agri-feedback") + ADAPTER_REPO_ID = os.environ.get("ADAPTER_REPO_ID", "ous-sow/sahel-agri-adapters") + WHISPER_MODEL_ID = os.environ.get("WHISPER_MODEL_ID", "openai/whisper-small") + + if HF_TOKEN: + from huggingface_hub import HfApi + _hf_api = HfApi(token=HF_TOKEN) + + # Kick off background model load immediately + _ensure_whisper_loaded() + + print(f"Whisper model : {WHISPER_MODEL_ID}") + print(f"Feedback repo : {FEEDBACK_REPO_ID}") + print(f"Adapter repo : {ADAPTER_REPO_ID}") + print(f"HF_TOKEN set : {'yes' if HF_TOKEN else 'no (local-only mode)'}") + print() + + demo = build_ui() + demo.launch( + server_port=9001, + inbrowser=True, + share=False, + ) diff --git a/configs/api_config.yaml b/configs/api_config.yaml new file mode 100644 index 0000000000000000000000000000000000000000..93ff802de271e62687d640079667648535a3ac86 --- /dev/null +++ b/configs/api_config.yaml @@ -0,0 +1,21 @@ +server: + host: "0.0.0.0" + port: 8000 + workers: 1 # Single worker: shares GPU model in memory + timeout_keep_alive: 30 + +inference: + default_language: "bam" + max_audio_size_mb: 10 + supported_languages: + - "bam" + - "ful" + +iot: + sensor_poll_timeout_s: 5 + response_language: "fr" # French for farmer-facing TTS output + intent_confidence_threshold: 0.7 + +rate_limit: + requests_per_minute: 60 + burst: 10 diff --git a/configs/base_config.yaml b/configs/base_config.yaml new file mode 100644 index 0000000000000000000000000000000000000000..c5a81b27af3f0e6b62b9e4b1115fcd238ce14a42 --- /dev/null +++ b/configs/base_config.yaml @@ -0,0 +1,30 @@ +model: + id: "openai/whisper-large-v3-turbo" + task: "transcribe" + max_new_tokens: 128 + chunk_length_s: 30 + +training: + output_dir: "./adapters" + per_device_train_batch_size: 4 + gradient_accumulation_steps: 4 + warmup_steps: 200 + max_steps: 4000 + save_steps: 500 + eval_steps: 500 + learning_rate: 1.0e-4 + fp16: true + # CRITICAL on Windows: multiprocessing spawn breaks with tokenizers + dataloader_num_workers: 0 + +audio: + sample_rate: 16000 + max_duration_s: 30 + noise_snr_db_range: [5, 20] + augmentation_prob: 0.6 + +paths: + data_cache: "./data_cache" + adapters: "./adapters" + models: "./models" + noise_samples: "./noise_samples" diff --git a/configs/lora_bambara.yaml b/configs/lora_bambara.yaml new file mode 100644 index 0000000000000000000000000000000000000000..ea8bf5b64504f397b191cda69837085deb97798d --- /dev/null +++ b/configs/lora_bambara.yaml @@ -0,0 +1,19 @@ +language: "bam" +language_code: "bm" # ISO 639-1 code used for Whisper forced_decoder_ids +dataset_subset: "bam" +adapter_name: "bambara" +output_dir: "./adapters/bambara" + +lora: + r: 32 + lora_alpha: 64 + target_modules: + - "q_proj" + - "v_proj" + - "k_proj" + - "out_proj" + - "fc1" + - "fc2" + lora_dropout: 0.05 + bias: "none" + task_type: "SEQ_2_SEQ_LM" diff --git a/configs/lora_fula.yaml 
b/configs/lora_fula.yaml new file mode 100644 index 0000000000000000000000000000000000000000..fb2eedfc4bab1fc8e7e4d1f92ac05e81a1499bd4 --- /dev/null +++ b/configs/lora_fula.yaml @@ -0,0 +1,19 @@ +language: "ful" +language_code: "ff" # ISO 639-1 code used for Whisper forced_decoder_ids +dataset_subset: "ful" +adapter_name: "fula" +output_dir: "./adapters/fula" + +lora: + r: 16 # Smaller rank β€” Fula dataset is smaller than Bambara + lora_alpha: 32 + target_modules: + - "q_proj" + - "v_proj" + - "k_proj" + - "out_proj" + - "fc1" + - "fc2" + lora_dropout: 0.05 + bias: "none" + task_type: "SEQ_2_SEQ_LM" diff --git a/noise_samples/README.md b/noise_samples/README.md new file mode 100644 index 0000000000000000000000000000000000000000..406995b2f67f1216cebd8c16d1b37cea71372ab5 --- /dev/null +++ b/noise_samples/README.md @@ -0,0 +1,20 @@ +# Field Noise Samples + +Place `.wav` audio files here to enable realistic field-noise augmentation during training. + +## Required Files (16kHz mono, any duration β‰₯5s) +- `tractor_engine.wav` β€” diesel tractor idling or working +- `wind_field.wav` β€” wind in open farmland +- `livestock_ambient.wav` β€” cattle, goats, or chickens in background + +## Suggested Sources +- [Freesound.org](https://freesound.org) β€” search "tractor", "wind field", "livestock ambient" (filter by CC0 / CC-BY) +- Field recordings from partner NGOs or agricultural organizations in Mali/Guinea + +## Licensing Note +Ensure all audio files are licensed for use in ML training datasets. +CC0 (public domain) or CC-BY are preferred. + +## Without Noise Files +The augmenter will fall back to Gaussian noise only. +Training will still work but model robustness to real-world conditions may be reduced. diff --git a/notebooks/bootstrap_repos.ipynb b/notebooks/bootstrap_repos.ipynb new file mode 100644 index 0000000000000000000000000000000000000000..7f35bf205783087161e949cf5985408de94cf1be --- /dev/null +++ b/notebooks/bootstrap_repos.ipynb @@ -0,0 +1,308 @@ +{ + "nbformat": 4, + "nbformat_minor": 5, + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "name": "python", + "version": "3.10.0" + }, + "colab": { + "provenance": [], + "gpuType": "T4" + }, + "accelerator": "GPU" + }, + "cells": [ + { + "cell_type": "markdown", + "id": "cell-title", + "metadata": {}, + "source": [ + "# 🌾 Sahel-Agri Voice AI β€” One-Time Bootstrap\n", + "\n", + "**Run this notebook ONCE** before deploying your Space. It:\n", + "\n", + "1. Creates the three HuggingFace repos (`sahel-agri-feedback`, `sahel-agri-adapters`, `sahel-agri-voice`)\n", + "2. Seeds the feedback dataset with a `corrections.jsonl` placeholder\n", + "3. Trains v0 LoRA adapters for **Bambara** and **Fula** on the full Google Waxal dataset\n", + "4. 
Pushes adapters to `ous-sow/sahel-agri-adapters`\n", + "\n", + "After this notebook completes, push your project code to the Space and your app will start\n", + "with working Bambara/Fula speech recognition from day 1 β€” **no user corrections needed yet**.\n", + "\n", + "For subsequent improvement runs (after collecting farmer feedback), use `train_colab.ipynb`.\n", + "\n", + "---\n", + "**Before running:** Runtime β†’ Change runtime type β†’ **T4 GPU**" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "cell-gpu-check", + "metadata": {}, + "outputs": [], + "source": [ + "# Cell 1 β€” GPU check\n", + "import subprocess\n", + "result = subprocess.run(['nvidia-smi'], capture_output=True, text=True)\n", + "if result.returncode != 0:\n", + " raise RuntimeError('No GPU! Runtime β†’ Change runtime type β†’ T4 GPU')\n", + "print(result.stdout[:500])\n", + "print('βœ… GPU ready')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "cell-install", + "metadata": {}, + "outputs": [], + "source": [ + "# Cell 2 β€” Install dependencies\n", + "!pip install -q \\\n", + " torch==2.11.0 torchaudio==2.11.0 \\\n", + " transformers==5.5.0 datasets==4.8.4 \\\n", + " accelerate==1.13.0 evaluate==0.4.2 \\\n", + " huggingface-hub==1.9.0 peft==0.18.1 \\\n", + " librosa==0.10.2 soundfile==0.12.1 \\\n", + " jiwer==3.0.4 pyyaml==6.0.2\n", + "print('βœ… Packages installed')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "cell-hf-login", + "metadata": {}, + "outputs": [], + "source": "# Cell 3 β€” HuggingFace login\n# Colab: πŸ”‘ icon (left sidebar) β†’ Add new secret β†’ name=HF_TOKEN\nimport os\ntry:\n from google.colab import userdata # type: ignore\n HF_TOKEN = userdata.get('HF_TOKEN')\nexcept Exception:\n HF_TOKEN = os.environ.get('HF_TOKEN', '')\n\nif not HF_TOKEN:\n raise ValueError(\n 'HF_TOKEN not found.\\n'\n 'Colab: click the πŸ”‘ icon β†’ Add new secret β†’ name=HF_TOKEN'\n )\n\nfrom huggingface_hub import login, HfApi\nlogin(token=HF_TOKEN, add_to_git_credential=False)\napi = HfApi(token=HF_TOKEN)\n\nHF_USERNAME = 'ous-sow'\nFEEDBACK_REPO_ID = f'{HF_USERNAME}/sahel-agri-feedback'\nADAPTER_REPO_ID = f'{HF_USERNAME}/sahel-agri-adapters'\nSPACE_REPO_ID = f'{HF_USERNAME}/sahel-agri-voice'\n# whisper-small trains on Colab T4 in ~25 min and runs on CPU in ~10s.\n# Change to 'openai/whisper-large-v3-turbo' only if you upgrade to a GPU Space.\nWHISPER_MODEL_ID = 'openai/whisper-small'\n\nprint(f'βœ… Logged in as {HF_USERNAME}')" + }, + { + "cell_type": "code", + "execution_count": null, + "id": "cell-create-repos", + "metadata": {}, + "outputs": [], + "source": [ + "# Cell 4 β€” Create HuggingFace repos (skips if they already exist)\n", + "from huggingface_hub import RepoUrl\n", + "\n", + "def create_repo_if_missing(repo_id, repo_type, private=True):\n", + " try:\n", + " url = api.create_repo(\n", + " repo_id=repo_id,\n", + " repo_type=repo_type,\n", + " private=private,\n", + " exist_ok=True,\n", + " )\n", + " print(f' βœ… {repo_type}: {repo_id}')\n", + " return url\n", + " except Exception as e:\n", + " print(f' ⚠️ {repo_id}: {e}')\n", + "\n", + "print('Creating repos...')\n", + "create_repo_if_missing(FEEDBACK_REPO_ID, 'dataset', private=True)\n", + "create_repo_if_missing(ADAPTER_REPO_ID, 'model', private=True)\n", + "create_repo_if_missing(SPACE_REPO_ID, 'space', private=False)\n", + "\n", + "# Seed the feedback dataset with an empty corrections.jsonl\n", + "import io\n", + "try:\n", + " api.upload_file(\n", + " 
path_or_fileobj=io.BytesIO(b''),\n", + " path_in_repo='corrections.jsonl',\n", + " repo_id=FEEDBACK_REPO_ID,\n", + " repo_type='dataset',\n", + " commit_message='Init: empty corrections.jsonl',\n", + " )\n", + " print(f' βœ… {FEEDBACK_REPO_ID}/corrections.jsonl initialised')\n", + "except Exception as e:\n", + " print(f' ⚠️ corrections.jsonl upload: {e} (may already exist)')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "cell-clone-space", + "metadata": {}, + "outputs": [], + "source": [ + "# Cell 5 β€” Clone Space code (so we can use src/ and configs/)\n", + "# If the Space is brand new and has no code yet, clone from the local zip instead.\n", + "import sys\n", + "from pathlib import Path\n", + "from huggingface_hub import snapshot_download\n", + "\n", + "try:\n", + " space_dir = Path(snapshot_download(\n", + " repo_id=SPACE_REPO_ID, repo_type='space', token=HF_TOKEN\n", + " ))\n", + " print(f'Space code: {space_dir}')\n", + "except Exception as e:\n", + " print(f'Could not download Space ({e})')\n", + " print('Uploading project code to Space first...')\n", + " # If you have the project on Colab already (e.g. mounted Drive), set:\n", + " # space_dir = Path('/content/drive/MyDrive/voice-model')\n", + " # Otherwise upload via git (see README step 6) and re-run this cell.\n", + " raise RuntimeError(\n", + " 'Push your project to the Space first:\\n'\n", + " ' git remote add space https://huggingface.co/spaces/ous-sow/sahel-agri-voice\\n'\n", + " ' git push space main\\n'\n", + " 'Then re-run this notebook.'\n", + " )\n", + "\n", + "sys.path.insert(0, str(space_dir))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "cell-train-bam", + "metadata": {}, + "outputs": [], + "source": [ + "# Cell 6 β€” Train v0 Bambara adapter on full Waxal (bam)\n", + "#\n", + "# Uses streaming β€” Waxal is ~4h of audio, we cap at 2000 samples for Colab budget.\n", + "# Full training (~4000 steps) on the entire dataset: use a Kaggle P100 (12h limit).\n", + "import os, yaml\n", + "os.environ['HF_TOKEN'] = HF_TOKEN\n", + "\n", + "from src.training.trainer import WhisperLoRATrainer\n", + "\n", + "WAXAL_CAP = 2000 # raise to 10000+ on Kaggle for a stronger v0 model\n", + "\n", + "base_cfg = str(space_dir / 'configs' / 'base_config.yaml')\n", + "bam_cfg_src = str(space_dir / 'configs' / 'lora_bambara.yaml')\n", + "bam_out = '/tmp/sahel_adapter_bam'\n", + "\n", + "# Override output_dir\n", + "with open(bam_cfg_src) as f:\n", + " bam_config = yaml.safe_load(f)\n", + "bam_config['output_dir'] = bam_out\n", + "tmp_bam_cfg = '/tmp/lora_bam.yaml'\n", + "with open(tmp_bam_cfg, 'w') as f:\n", + " yaml.dump(bam_config, f)\n", + "\n", + "# Also override max_steps in base config to match Waxal cap\n", + "with open(base_cfg) as f:\n", + " base_config = yaml.safe_load(f)\n", + "# ~2 steps per sample @ batch_size=4, gradient_acc=4\n", + "base_config['training']['max_steps'] = max(500, WAXAL_CAP // 8)\n", + "tmp_base_cfg = '/tmp/base_config.yaml'\n", + "with open(tmp_base_cfg, 'w') as f:\n", + " yaml.dump(base_config, f)\n", + "\n", + "print(f'Training Bambara v0 adapter (Waxal cap={WAXAL_CAP}, max_steps={base_config[\"training\"][\"max_steps\"]})...')\n", + "trainer_bam = WhisperLoRATrainer(\n", + " base_config_path=tmp_base_cfg,\n", + " language_config_path=tmp_bam_cfg,\n", + ")\n", + "trainer_bam.setup()\n", + "\n", + "# No feedback yet β€” materialise Waxal and train\n", + "trainer_bam.merge_extra_data([], repeat=1, waxal_cap=WAXAL_CAP)\n", + "\n", + "trainer_bam.train()\n", + 
"print(f'βœ… Bambara v0 adapter saved to {bam_out}')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "cell-train-ful", + "metadata": {}, + "outputs": [], + "source": [ + "# Cell 7 β€” Train v0 Fula adapter on full Waxal (ful)\n", + "ful_cfg_src = str(space_dir / 'configs' / 'lora_fula.yaml')\n", + "ful_out = '/tmp/sahel_adapter_ful'\n", + "\n", + "with open(ful_cfg_src) as f:\n", + " ful_config = yaml.safe_load(f)\n", + "ful_config['output_dir'] = ful_out\n", + "tmp_ful_cfg = '/tmp/lora_ful.yaml'\n", + "with open(tmp_ful_cfg, 'w') as f:\n", + " yaml.dump(ful_config, f)\n", + "\n", + "print(f'Training Fula v0 adapter (Waxal cap={WAXAL_CAP})...')\n", + "trainer_ful = WhisperLoRATrainer(\n", + " base_config_path=tmp_base_cfg,\n", + " language_config_path=tmp_ful_cfg,\n", + ")\n", + "trainer_ful.setup()\n", + "trainer_ful.merge_extra_data([], repeat=1, waxal_cap=WAXAL_CAP)\n", + "trainer_ful.train()\n", + "print(f'βœ… Fula v0 adapter saved to {ful_out}')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "cell-push-adapters", + "metadata": {}, + "outputs": [], + "source": [ + "# Cell 8 β€” Push both adapters to HF Model repo\n", + "from huggingface_hub import HfApi\n", + "api = HfApi(token=HF_TOKEN)\n", + "\n", + "for lang, out_dir, path_in_repo in [\n", + " ('bam', bam_out, 'adapters/bambara'),\n", + " ('ful', ful_out, 'adapters/fula'),\n", + "]:\n", + " api.upload_folder(\n", + " folder_path=out_dir,\n", + " repo_id=ADAPTER_REPO_ID,\n", + " repo_type='model',\n", + " path_in_repo=path_in_repo,\n", + " commit_message=f'v0 {lang} adapter trained on Waxal (cap={WAXAL_CAP} samples)',\n", + " )\n", + " print(f'βœ… {lang} β†’ {ADAPTER_REPO_ID}/{path_in_repo}')\n", + "\n", + "print()\n", + "print('Bootstrap complete!')\n", + "print()\n", + "print('Next steps:')\n", + "print(' 1. Push your project code to the Space (git push space main)')\n", + "print(' 2. In Space Settings β†’ Secrets, add HF_TOKEN, FEEDBACK_REPO_ID, ADAPTER_REPO_ID')\n", + "print(' 3. Space will build β€” your app at https://huggingface.co/spaces/ous-sow/sahel-agri-voice')\n", + "print(' 4. Tab 3 β†’ Reload Adapters β€” Bambara + Fula adapters will be loaded')\n", + "print(' 5. Collect farmer corrections, then run train_colab.ipynb to keep improving')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "cell-verify", + "metadata": {}, + "outputs": [], + "source": [ + "# Cell 9 β€” Quick verification: list what was pushed to the adapter repo\n", + "from huggingface_hub import list_repo_files\n", + "\n", + "files = sorted(list_repo_files(ADAPTER_REPO_ID, repo_type='model', token=HF_TOKEN))\n", + "print(f'Files in {ADAPTER_REPO_ID}:')\n", + "for f in files:\n", + " print(f' {f}')\n", + "\n", + "bam_ok = any('bambara/adapter_config.json' in f for f in files)\n", + "ful_ok = any('fula/adapter_config.json' in f for f in files)\n", + "print()\n", + "print(f'Bambara adapter: {\"βœ…\" if bam_ok else \"❌\"}')\n", + "print(f'Fula adapter: {\"βœ…\" if ful_ok else \"❌\"}')\n", + "\n", + "if bam_ok and ful_ok:\n", + " print('\\nπŸŽ‰ Both adapters ready. 
Your Space will use them automatically on the next reload.')\n", + "else:\n", + " print('\\n⚠️ Some adapters are missing β€” check the training cells above for errors.')" + ] + } + ] +} \ No newline at end of file diff --git a/notebooks/train_colab.ipynb b/notebooks/train_colab.ipynb new file mode 100644 index 0000000000000000000000000000000000000000..3aae600e17fd8412449a992a5baa3d2cc509791f --- /dev/null +++ b/notebooks/train_colab.ipynb @@ -0,0 +1,283 @@ +{ + "nbformat": 4, + "nbformat_minor": 5, + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "name": "python", + "version": "3.10.0" + }, + "colab": { + "provenance": [], + "gpuType": "T4" + }, + "accelerator": "GPU" + }, + "cells": [ + { + "cell_type": "markdown", + "id": "cell-title", + "metadata": {}, + "source": [ + "# 🌾 Sahel-Agri Voice AI β€” Fine-tune on Farmer Feedback\n", + "\n", + "**Run after collecting β‰₯10 corrections in the Space.** \n", + "First run? Use `bootstrap_repos.ipynb` instead to train the v0 Waxal adapter.\n", + "\n", + "This notebook fine-tunes the existing LoRA adapter using:\n", + "- **Waxal baseline** (up to 500 samples) β€” keeps the model grounded\n", + "- **Farmer corrections** (3Γ— upsampled) β€” targeted improvement from real field use\n", + "\n", + "**Before running:** Runtime β†’ Change runtime type β†’ **T4 GPU**" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "cell-gpu-check", + "metadata": {}, + "outputs": [], + "source": [ + "# Cell 1 β€” GPU check\n", + "import subprocess\n", + "result = subprocess.run(['nvidia-smi'], capture_output=True, text=True)\n", + "if result.returncode != 0:\n", + " raise RuntimeError('No GPU! Runtime β†’ Change runtime type β†’ T4 GPU')\n", + "print(result.stdout[:500])\n", + "print('βœ… GPU ready')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "cell-install", + "metadata": {}, + "outputs": [], + "source": [ + "# Cell 2 β€” Install dependencies (matching Space versions)\n", + "!pip install -q \\\n", + " torch==2.11.0 torchaudio==2.11.0 \\\n", + " transformers==5.5.0 datasets==4.8.4 \\\n", + " accelerate==1.13.0 evaluate==0.4.2 \\\n", + " huggingface-hub==1.9.0 peft==0.18.1 \\\n", + " librosa==0.10.2 soundfile==0.12.1 \\\n", + " jiwer==3.0.4 pyyaml==6.0.2\n", + "print('βœ… Packages installed')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "cell-hf-login", + "metadata": {}, + "outputs": [], + "source": "# Cell 3 β€” HuggingFace login\n# Colab: πŸ”‘ icon (left sidebar) β†’ Add new secret β†’ name=HF_TOKEN\n# Kaggle: Add Data β†’ add as Kaggle secret named HF_TOKEN\nimport os\ntry:\n from google.colab import userdata # type: ignore\n HF_TOKEN = userdata.get('HF_TOKEN')\nexcept Exception:\n HF_TOKEN = os.environ.get('HF_TOKEN', '')\n\nif not HF_TOKEN:\n raise ValueError('HF_TOKEN not found β€” see instructions above.')\n\nfrom huggingface_hub import login\nlogin(token=HF_TOKEN, add_to_git_credential=False)\n\nSPACE_REPO_ID = 'ous-sow/sahel-agri-voice'\nFEEDBACK_REPO_ID = 'ous-sow/sahel-agri-feedback'\nADAPTER_REPO_ID = 'ous-sow/sahel-agri-adapters'\n# Must match what the Space uses β€” whisper-small for cpu-basic, whisper-large-v3-turbo for GPU.\nWHISPER_MODEL_ID = 'openai/whisper-small'\nTRAIN_LANG = 'bam' # ← change to 'ful' for Fula\n\nprint(f'βœ… Logged in | training language: {TRAIN_LANG}')" + }, + { + "cell_type": "code", + "execution_count": null, + "id": "cell-download", + "metadata": {}, + "outputs": [], + 
"source": [ + "# Cell 4 β€” Download Space code and feedback corrections\n", + "import json, shutil, sys\n", + "from pathlib import Path\n", + "from huggingface_hub import snapshot_download, hf_hub_download\n", + "\n", + "# Get Space code (contains src/, configs/)\n", + "space_dir = Path(snapshot_download(\n", + " repo_id=SPACE_REPO_ID, repo_type='space', token=HF_TOKEN\n", + "))\n", + "sys.path.insert(0, str(space_dir))\n", + "print(f'Space code: {space_dir}')\n", + "\n", + "# Download feedback corrections.jsonl\n", + "jsonl_path = hf_hub_download(\n", + " repo_id=FEEDBACK_REPO_ID,\n", + " filename='corrections.jsonl',\n", + " repo_type='dataset',\n", + " token=HF_TOKEN,\n", + ")\n", + "with open(jsonl_path, encoding='utf-8') as f:\n", + " all_records = [json.loads(l) for l in f if l.strip()]\n", + "\n", + "corrections = [\n", + " r for r in all_records\n", + " if r.get('is_correction') and r['language'] == TRAIN_LANG\n", + "]\n", + "print(f'Total feedback records : {len(all_records)}')\n", + "print(f'Corrections for {TRAIN_LANG} : {len(corrections)}')\n", + "\n", + "if len(corrections) < 5:\n", + " print('⚠️ Very few corrections β€” consider collecting more before training.')\n", + " print(' Training will proceed with Waxal only (corrections will be skipped).')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "cell-download-audio", + "metadata": {}, + "outputs": [], + "source": [ + "# Cell 5 β€” Download feedback audio files from HF Dataset repo\n", + "fb_audio_dir = Path('/tmp/sahel_feedback_audio')\n", + "fb_audio_dir.mkdir(exist_ok=True)\n", + "\n", + "skipped = 0\n", + "for rec in corrections:\n", + " local_path = fb_audio_dir / Path(rec['audio_file']).name\n", + " if local_path.exists():\n", + " continue\n", + " try:\n", + " dl = hf_hub_download(\n", + " repo_id=FEEDBACK_REPO_ID,\n", + " filename=rec['audio_file'],\n", + " repo_type='dataset',\n", + " token=HF_TOKEN,\n", + " )\n", + " shutil.copy(dl, local_path)\n", + " except Exception as e:\n", + " skipped += 1\n", + " print(f' skip {rec[\"audio_file\"]}: {e}')\n", + "\n", + "# Point records at local paths\n", + "for rec in corrections:\n", + " local = fb_audio_dir / Path(rec['audio_file']).name\n", + " if local.exists():\n", + " rec['audio_file'] = str(local)\n", + "\n", + "available = [r for r in corrections if Path(r['audio_file']).exists()]\n", + "print(f'Downloaded {len(available)} / {len(corrections)} audio files (skipped {skipped})')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "cell-train", + "metadata": {}, + "outputs": [], + "source": [ + "# Cell 6 β€” Fine-tune: Waxal baseline + farmer corrections\n", + "#\n", + "# WhisperLoRATrainer.setup() loads Waxal (streaming).\n", + "# merge_extra_data() materialises Waxal (up to 500 samples),\n", + "# appends corrections (3Γ— upsampled), shuffles the combined dataset.\n", + "# train() runs standard Seq2SeqTrainer on the merged dataset.\n", + "\n", + "import os\n", + "os.environ['HF_TOKEN'] = HF_TOKEN\n", + "\n", + "from src.training.trainer import WhisperLoRATrainer\n", + "\n", + "lang_config_map = {'bam': 'lora_bambara.yaml', 'ful': 'lora_fula.yaml'}\n", + "base_cfg = str(space_dir / 'configs' / 'base_config.yaml')\n", + "lang_cfg = str(space_dir / 'configs' / lang_config_map[TRAIN_LANG])\n", + "output_dir = f'/tmp/sahel_adapter_{TRAIN_LANG}'\n", + "\n", + "# Override output_dir so adapter saves to /tmp on Colab\n", + "import yaml\n", + "with open(lang_cfg) as f:\n", + " lang_config = yaml.safe_load(f)\n", + 
"lang_config['output_dir'] = output_dir\n", + "tmp_lang_cfg = f'/tmp/lora_{TRAIN_LANG}_tmp.yaml'\n", + "with open(tmp_lang_cfg, 'w') as f:\n", + " yaml.dump(lang_config, f)\n", + "\n", + "trainer = WhisperLoRATrainer(\n", + " base_config_path=base_cfg,\n", + " language_config_path=tmp_lang_cfg,\n", + ")\n", + "trainer.setup()\n", + "\n", + "if available:\n", + " print(f'Merging {len(available)} corrections (Γ—3) with Waxal baseline (cap=500)...')\n", + " trainer.merge_extra_data(available, repeat=3, waxal_cap=500)\n", + "else:\n", + " print('No corrections available β€” training on Waxal only.')\n", + "\n", + "trainer.train()\n", + "print(f'βœ… Training complete β€” adapter at {output_dir}')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "cell-push", + "metadata": {}, + "outputs": [], + "source": [ + "# Cell 7 β€” Push adapter to HF Model repo\n", + "from huggingface_hub import HfApi\n", + "api = HfApi(token=HF_TOKEN)\n", + "\n", + "path_in_repo = 'adapters/bambara' if TRAIN_LANG == 'bam' else 'adapters/fula'\n", + "n_corrections = len(available)\n", + "\n", + "api.upload_folder(\n", + " folder_path=output_dir,\n", + " repo_id=ADAPTER_REPO_ID,\n", + " repo_type='model',\n", + " path_in_repo=path_in_repo,\n", + " commit_message=(\n", + " f'Fine-tune {TRAIN_LANG}: Waxal baseline + {n_corrections} farmer corrections'\n", + " ),\n", + ")\n", + "print(f'βœ… Pushed to {ADAPTER_REPO_ID}/{path_in_repo}')\n", + "print('\\nNext: Space β†’ Tab 3 β†’ Reload Adapters from Hub')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "cell-sanity", + "metadata": {}, + "outputs": [], + "source": [ + "# Cell 8 β€” Sanity check: compare WER before vs after adapter\n", + "import random, torch, librosa, jiwer\n", + "from transformers import WhisperForConditionalGeneration, WhisperProcessor\n", + "from peft import PeftModel\n", + "\n", + "if not available:\n", + " print('No test samples β€” skipping sanity check.')\n", + "else:\n", + " test_rec = random.choice(available)\n", + " print(f'Audio : {Path(test_rec[\"audio_file\"]).name}')\n", + " print(f'Expected : {test_rec[\"corrected_text\"]}')\n", + " print(f'Pre-train: {test_rec[\"whisper_output\"]}')\n", + "\n", + " # Load base + adapter\n", + " processor = WhisperProcessor.from_pretrained(WHISPER_MODEL_ID, token=HF_TOKEN)\n", + " base = WhisperForConditionalGeneration.from_pretrained(\n", + " WHISPER_MODEL_ID, torch_dtype=torch.float16, token=HF_TOKEN\n", + " ).to('cuda')\n", + " model = PeftModel.from_pretrained(base, output_dir).eval()\n", + "\n", + " audio_np, _ = librosa.load(test_rec['audio_file'], sr=16000, mono=True)\n", + " feats = processor.feature_extractor(\n", + " audio_np, sampling_rate=16000, return_tensors='pt'\n", + " ).input_features.half().to('cuda')\n", + "\n", + " with torch.no_grad():\n", + " ids = model.generate(feats, max_new_tokens=256)\n", + " result = processor.batch_decode(ids, skip_special_tokens=True)[0].strip()\n", + " print(f'Post-train: {result}')\n", + "\n", + " ref = test_rec['corrected_text']\n", + " wer_before = jiwer.wer(ref, test_rec['whisper_output']) if test_rec.get('whisper_output') else 1.0\n", + " wer_after = jiwer.wer(ref, result)\n", + " print(f'\\nWER before: {wer_before:.1%} β†’ WER after: {wer_after:.1%}')\n", + " if wer_after < wer_before:\n", + " print('βœ… Adapter improved transcription quality!')\n", + " else:\n", + " print('ℹ️ No improvement on this single sample β€” collect more corrections and retrain.')" + ] + } + ] +} \ No newline at end of file diff --git 
a/packages.txt b/packages.txt new file mode 100644 index 0000000000000000000000000000000000000000..20645e641240cb419f5fc66c14c1447e91daf669 --- /dev/null +++ b/packages.txt @@ -0,0 +1 @@ +ffmpeg diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..565f3a86262df3f5262f225683bc5e7421d8ebbd --- /dev/null +++ b/requirements.txt @@ -0,0 +1,50 @@ +# ----------------------------------------------------------------------------- +# Sahel-Agri Voice AI β€” Python Dependencies +# HuggingFace Spaces (ZeroGPU) deployment β€” CUDA pre-installed, no +cu128 suffix +# +# Local CPU test: +# pip install -r requirements.txt +# ----------------------------------------------------------------------------- + +# PyTorch (CPU build β€” works on HF Spaces cpu-basic and locally) +torch==2.11.0 +torchaudio==2.11.0 + +# HuggingFace core +transformers==5.5.0 +datasets==4.8.4 +accelerate==1.13.0 +evaluate==0.4.2 +huggingface-hub==1.9.0 + +# PEFT (LoRA adapters) +peft==0.18.1 + +# Audio processing +librosa==0.10.2 +soundfile==0.12.1 +audiomentations==0.43.1 + +# Quantization (CPU: installs fine; 4-bit/8-bit requires GPU at runtime) +bitsandbytes==0.49.2 + +# Metrics +jiwer==3.0.4 + +# Config & environment +pyyaml==6.0.2 +python-dotenv==1.1.0 + +# Gradio (must match sdk_version in README.md) +gradio==4.44.0 + +# Pydantic v2 +pydantic==2.11.3 + +# Testing +pytest==8.3.5 +pytest-asyncio==0.26.0 + +# Utilities +numpy==2.2.4 +scipy==1.15.2 diff --git a/scripts/export_onnx.py b/scripts/export_onnx.py new file mode 100644 index 0000000000000000000000000000000000000000..e9838b28e9885cf1c6134467dcd0e91773d7cfb6 --- /dev/null +++ b/scripts/export_onnx.py @@ -0,0 +1,67 @@ +""" +Phase 4a: Merge LoRA adapters and export language-specific ONNX models. +Validates that ONNX WER is within 2% of PyTorch baseline. 
+ +Usage: + python scripts/export_onnx.py +""" +import logging +import os +import sys +from pathlib import Path + +sys.path.insert(0, str(Path(__file__).parent.parent)) + +from dotenv import load_dotenv + +load_dotenv() + +logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(name)s β€” %(message)s") + +import yaml + +from src.optimization.onnx_exporter import ONNXExporter + + +def export_language(language: str, adapter_path: str, config: dict) -> None: + from peft import PeftModel + from transformers import WhisperForConditionalGeneration, WhisperProcessor + + hf_token = os.getenv("HF_TOKEN") + model_id = config["model"]["id"] + + print(f"\n[{language.upper()}] Loading base model...") + base_model = WhisperForConditionalGeneration.from_pretrained(model_id, token=hf_token) + processor = WhisperProcessor.from_pretrained(model_id, token=hf_token) + + print(f"[{language.upper()}] Loading adapter from {adapter_path}...") + peft_model = PeftModel.from_pretrained(base_model, adapter_path, adapter_name=language) + + output_dir = f"{config['paths']['models']}/onnx/{language}" + exporter = ONNXExporter() + result_path = exporter.merge_and_export(peft_model, processor, output_dir, language) + print(f"[{language.upper()}] ONNX exported to: {result_path}") + + +def main() -> None: + with open("configs/base_config.yaml") as f: + config = yaml.safe_load(f) + + print("=" * 60) + print("Sahel-Agri Voice AI β€” ONNX Export") + print("=" * 60) + + bambara_path = os.getenv("BAMBARA_ADAPTER_PATH", "./adapters/bambara") + fula_path = os.getenv("FULA_ADAPTER_PATH", "./adapters/fula") + + for language, adapter_path in [("bambara", bambara_path), ("fula", fula_path)]: + if Path(adapter_path).exists(): + export_language(language, adapter_path, config) + else: + print(f"\nSkipping {language}: adapter not found at {adapter_path}") + + print("\nExport complete.") + + +if __name__ == "__main__": + main() diff --git a/scripts/run_data_pipeline.py b/scripts/run_data_pipeline.py new file mode 100644 index 0000000000000000000000000000000000000000..ffaac7a511e851a7b0aee363e49c0f7957089fb7 --- /dev/null +++ b/scripts/run_data_pipeline.py @@ -0,0 +1,76 @@ +""" +Phase 2: Download google/waxal, apply augmentation, print statistics. +Streams examples and caches to data_cache/ as Arrow files. 
+ +Usage: + python scripts/run_data_pipeline.py --subset bam --max-examples 100 +""" +import argparse +import sys +import time +from pathlib import Path + +sys.path.insert(0, str(Path(__file__).parent.parent)) + +import os + +from dotenv import load_dotenv + +load_dotenv() + + +def main(subset: str, max_examples: int) -> None: + import yaml + from transformers import WhisperProcessor + + from src.data.augmentation import FieldNoiseAugmenter + from src.data.waxal_loader import WaxalDataLoader + + with open("configs/base_config.yaml") as f: + config = yaml.safe_load(f) + + hf_token = os.getenv("HF_TOKEN") + model_id = config["model"]["id"] + + print("=" * 60) + print(f"Waxal Data Pipeline β€” subset: {subset}") + print("=" * 60) + + print(f"\n[1/4] Loading WhisperProcessor ({model_id})...") + processor = WhisperProcessor.from_pretrained(model_id, token=hf_token) + + print("[2/4] Initializing augmenter...") + augmenter = FieldNoiseAugmenter(config["paths"]["noise_samples"], config) + print(f" Augmenter ready: {augmenter.is_ready()}") + + print(f"[3/4] Streaming google/waxal subset={subset}...") + loader = WaxalDataLoader(subset, config, hf_token=hf_token) + + t0 = time.time() + count = 0 + total_duration = 0.0 + + for example in loader.iter_processed(processor, split="train", augmenter=augmenter): + count += 1 + # input_features shape: (80, 3000) = 30 seconds at most + # Estimate actual audio duration from non-padding frames + total_duration += 30.0 # max chunk + if count >= max_examples: + break + + elapsed = time.time() - t0 + + print(f"\n[4/4] Results:") + print(f" Examples processed: {count}") + print(f" Approx total audio: {total_duration / 3600:.2f} hours") + print(f" Processing time: {elapsed:.1f}s") + print(f" Throughput: {count / elapsed:.1f} examples/sec") + print(f"\nData pipeline PASSED.") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--subset", default="bam", choices=["bam", "ful"]) + parser.add_argument("--max-examples", type=int, default=50) + args = parser.parse_args() + main(args.subset, args.max_examples) diff --git a/scripts/run_server.py b/scripts/run_server.py new file mode 100644 index 0000000000000000000000000000000000000000..15486ad493eacc47825c6b9fe7b70dc6b101c2a1 --- /dev/null +++ b/scripts/run_server.py @@ -0,0 +1,42 @@ +""" +Phase 4b: Start the FastAPI inference server. 
+ +Usage: + python scripts/run_server.py + python scripts/run_server.py --host 0.0.0.0 --port 8000 +""" +import argparse +import sys +from pathlib import Path + +sys.path.insert(0, str(Path(__file__).parent.parent)) + +from dotenv import load_dotenv + +load_dotenv() + +import uvicorn + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Start Sahel-Agri Voice AI server") + parser.add_argument("--host", default="0.0.0.0") + parser.add_argument("--port", type=int, default=8000) + parser.add_argument("--reload", action="store_true", help="Enable hot-reload (dev only)") + args = parser.parse_args() + + print(f"Starting server on http://{args.host}:{args.port}") + print("Endpoints:") + print(f" GET http://localhost:{args.port}/api/v1/health") + print(f" POST http://localhost:{args.port}/api/v1/transcribe") + print(f" POST http://localhost:{args.port}/api/v1/query") + print(f" GET http://localhost:{args.port}/docs (Swagger UI)") + print() + + uvicorn.run( + "src.api.app:app", + host=args.host, + port=args.port, + workers=1, # Single worker: GPU model shared in memory + reload=args.reload, + log_level="info", + ) diff --git a/scripts/train_bambara.py b/scripts/train_bambara.py new file mode 100644 index 0000000000000000000000000000000000000000..ccb102ba10db5a5b4b4999da9b9a1b0855e693f1 --- /dev/null +++ b/scripts/train_bambara.py @@ -0,0 +1,28 @@ +""" +Phase 3a: Fine-tune LoRA adapter for Bambara (bam). + +Usage: + python scripts/train_bambara.py +""" +import logging +import sys +from pathlib import Path + +sys.path.insert(0, str(Path(__file__).parent.parent)) + +from dotenv import load_dotenv + +load_dotenv() + +logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(name)s β€” %(message)s") + +from src.training.trainer import WhisperLoRATrainer + +if __name__ == "__main__": + trainer = WhisperLoRATrainer( + base_config_path="configs/base_config.yaml", + language_config_path="configs/lora_bambara.yaml", + ) + trainer.setup() + trainer.train() + print("\nBambara training complete. Adapter saved to adapters/bambara/") diff --git a/scripts/train_fula.py b/scripts/train_fula.py new file mode 100644 index 0000000000000000000000000000000000000000..5c1fccb449e4cf18e90128c683941c47f70eeab6 --- /dev/null +++ b/scripts/train_fula.py @@ -0,0 +1,29 @@ +""" +Phase 3b: Fine-tune LoRA adapter for Fula (ful). +Trains on the same frozen backbone as Bambara β€” base model weights are NOT modified. + +Usage: + python scripts/train_fula.py +""" +import logging +import sys +from pathlib import Path + +sys.path.insert(0, str(Path(__file__).parent.parent)) + +from dotenv import load_dotenv + +load_dotenv() + +logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(name)s β€” %(message)s") + +from src.training.trainer import WhisperLoRATrainer + +if __name__ == "__main__": + trainer = WhisperLoRATrainer( + base_config_path="configs/base_config.yaml", + language_config_path="configs/lora_fula.yaml", + ) + trainer.setup() + trainer.train() + print("\nFula training complete. Adapter saved to adapters/fula/") diff --git a/scripts/verify_baseline.py b/scripts/verify_baseline.py new file mode 100644 index 0000000000000000000000000000000000000000..168d1e8ae1dc23ef3dc0ed594cf85d74d9e50693 --- /dev/null +++ b/scripts/verify_baseline.py @@ -0,0 +1,78 @@ +""" +Phase 1 smoke test: load Whisper, run inference on a sample audio clip. +Prints model info, inference time, GPU memory usage, and sample transcript. 
+ +Usage: + python scripts/verify_baseline.py +""" +import sys +import time +from pathlib import Path + +# Allow imports from project root +sys.path.insert(0, str(Path(__file__).parent.parent)) + +import numpy as np +import torch + + +def main() -> None: + from src.engine.whisper_base import WhisperBackbone + + print("=" * 60) + print("Sahel-Agri Voice AI β€” Baseline Verification") + print("=" * 60) + + # 1. Check environment + print(f"\nPython: {sys.version.split()[0]}") + print(f"PyTorch: {torch.__version__}") + print(f"CUDA available: {torch.cuda.is_available()}") + if torch.cuda.is_available(): + print(f"GPU: {torch.cuda.get_device_name(0)}") + print(f"VRAM: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB") + + # 2. Load model + print("\n[1/3] Loading backbone model...") + t0 = time.time() + backbone = WhisperBackbone("configs/base_config.yaml") + backbone.load(device="cuda") + load_time = time.time() - t0 + print(f" Loaded in {load_time:.1f}s") + + if torch.cuda.is_available(): + used = torch.cuda.memory_allocated() / 1e9 + reserved = torch.cuda.memory_reserved() / 1e9 + print(f" GPU memory: {used:.2f} GB allocated / {reserved:.2f} GB reserved") + + # 3. Generate synthetic test audio (1 second of silence with slight noise) + print("\n[2/3] Generating test audio (1s white noise)...") + sample_rate = 16000 + duration = 1.0 + audio = np.random.randn(int(sample_rate * duration)).astype(np.float32) * 0.01 + + # 4. Run inference + print("[3/3] Running inference...") + processor = backbone.processor + model = backbone.model + + inputs = processor(audio, sampling_rate=sample_rate, return_tensors="pt") + input_features = inputs.input_features.to(backbone.device) + if backbone.device == "cuda": + input_features = input_features.half() + + t0 = time.time() + with torch.no_grad(): + predicted_ids = model.generate(input_features, max_new_tokens=50) + infer_time = time.time() - t0 + + transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)[0] + + print(f"\n{'=' * 60}") + print(f"Transcript: '{transcription}' (noise input β€” blank expected)") + print(f"Inference time: {infer_time * 1000:.0f} ms") + print(f"\nBaseline verification PASSED.") + print(f"{'=' * 60}") + + +if __name__ == "__main__": + main() diff --git a/src/__init__.py b/src/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/api/__init__.py b/src/api/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/api/app.py b/src/api/app.py new file mode 100644 index 0000000000000000000000000000000000000000..46460dfd9fb4abb0b50d0097d8d2412c05eaeec8 --- /dev/null +++ b/src/api/app.py @@ -0,0 +1,98 @@ +""" +FastAPI application factory. +Uses lifespan context manager to load the Whisper model at startup +and register language adapters β€” keeping a single backbone in GPU memory. 
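+
+Illustrative access pattern from inside a route handler (a sketch; the real
+lookups live in src/api/dependencies.py):
+
+    transcriber = request.app.state.transcriber
+    adapter_manager = request.app.state.adapter_manager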
+""" +from __future__ import annotations + +import logging +import os +from contextlib import asynccontextmanager + +import yaml +from fastapi import FastAPI + +from src.api.middleware import register_middleware +from src.api.routes import health, iot, transcribe +from src.engine.adapter_manager import AdapterManager +from src.engine.transcriber import Transcriber +from src.engine.whisper_base import WhisperBackbone +from src.iot.sensor_bridge import SensorBridge + +logger = logging.getLogger(__name__) + +logging.basicConfig( + level=os.getenv("LOG_LEVEL", "INFO"), + format="%(asctime)s %(levelname)s %(name)s β€” %(message)s", +) + + +@asynccontextmanager +async def lifespan(app: FastAPI): + """Load model at startup, free GPU memory at shutdown.""" + with open("configs/base_config.yaml") as f: + config = yaml.safe_load(f) + + hf_token = os.getenv("HF_TOKEN") + device = os.getenv("DEVICE", "cuda") + bambara_path = os.getenv("BAMBARA_ADAPTER_PATH", "./adapters/bambara") + fula_path = os.getenv("FULA_ADAPTER_PATH", "./adapters/fula") + sensor_api_url = os.getenv("SENSOR_API_URL") or None + + # 1. Load backbone + logger.info("Loading Whisper backbone...") + backbone = WhisperBackbone("configs/base_config.yaml") + backbone.load(device=device, hf_token=hf_token) + + # 2. Register adapters (they are loaded on first use via activate()) + adapter_manager = AdapterManager(backbone.model, config) + adapter_manager.register("bam", bambara_path) + adapter_manager.register("ful", fula_path) + + # 3. Pre-load the default adapter to warm up VRAM + try: + adapter_manager.load_adapter("bam") + logger.info("Default adapter 'bam' pre-loaded.") + except Exception as e: + logger.warning("Could not pre-load 'bam' adapter: %s", e) + + # 4. Create transcriber and sensor bridge + transcriber = Transcriber(backbone, adapter_manager) + sensor_bridge = SensorBridge(sensor_api_url=sensor_api_url) + + # 5. Attach to app.state for dependency injection + app.state.backbone = backbone + app.state.adapter_manager = adapter_manager + app.state.transcriber = transcriber + app.state.sensor_bridge = sensor_bridge + + logger.info("Sahel-Agri Voice AI server ready.") + yield + + # Shutdown + logger.info("Shutting down β€” freeing GPU memory...") + backbone.free() + + +def create_app() -> FastAPI: + app = FastAPI( + title="Sahel-Agri Voice AI", + description=( + "Modular STT engine for Bambara and Fula β€” serving Mali and Guinea farmers " + "via voice-first agricultural intelligence." 
+ ), + version="0.1.0", + lifespan=lifespan, + ) + + register_middleware(app) + + # Register routes + app.include_router(health.router, prefix="/api/v1", tags=["health"]) + app.include_router(transcribe.router, prefix="/api/v1", tags=["transcribe"]) + app.include_router(iot.router, prefix="/api/v1", tags=["iot"]) + + return app + + +app = create_app() diff --git a/src/api/dependencies.py b/src/api/dependencies.py new file mode 100644 index 0000000000000000000000000000000000000000..649790b14b1bd029abb34eca3225655ad152de4d --- /dev/null +++ b/src/api/dependencies.py @@ -0,0 +1,20 @@ +"""FastAPI dependency injection: retrieves shared model objects from app.state.""" +from __future__ import annotations + +from fastapi import Request + +from src.engine.adapter_manager import AdapterManager +from src.engine.transcriber import Transcriber +from src.iot.sensor_bridge import SensorBridge + + +def get_transcriber(request: Request) -> Transcriber: + return request.app.state.transcriber + + +def get_adapter_manager(request: Request) -> AdapterManager: + return request.app.state.adapter_manager + + +def get_sensor_bridge(request: Request) -> SensorBridge: + return request.app.state.sensor_bridge diff --git a/src/api/middleware.py b/src/api/middleware.py new file mode 100644 index 0000000000000000000000000000000000000000..7c726db440e66338c77a8ecda24725e47f4f2211 --- /dev/null +++ b/src/api/middleware.py @@ -0,0 +1,47 @@ +"""CORS, structured request logging, and rate-limit middleware.""" +from __future__ import annotations + +import logging +import time +import uuid + +from fastapi import FastAPI, Request, Response +from fastapi.middleware.cors import CORSMiddleware +from slowapi import Limiter, _rate_limit_exceeded_handler +from slowapi.errors import RateLimitExceeded +from slowapi.util import get_remote_address + +logger = logging.getLogger(__name__) + +limiter = Limiter(key_func=get_remote_address, default_limits=["60/minute"]) + + +def register_middleware(app: FastAPI) -> None: + """Attach all middleware to the FastAPI app.""" + + # CORS β€” allow WhatsApp webhook domain and local development + app.add_middleware( + CORSMiddleware, + allow_origins=["*"], # Tighten in production with specific domains + allow_credentials=True, + allow_methods=["GET", "POST"], + allow_headers=["*"], + ) + + # Rate limiting + app.state.limiter = limiter + app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler) + + @app.middleware("http") + async def logging_middleware(request: Request, call_next) -> Response: + request_id = str(uuid.uuid4())[:8] + t0 = time.perf_counter() + response = await call_next(request) + elapsed_ms = int((time.perf_counter() - t0) * 1000) + logger.info( + "req_id=%s method=%s path=%s status=%d latency_ms=%d", + request_id, request.method, request.url.path, + response.status_code, elapsed_ms, + ) + response.headers["X-Request-ID"] = request_id + return response diff --git a/src/api/routes/__init__.py b/src/api/routes/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/api/routes/health.py b/src/api/routes/health.py new file mode 100644 index 0000000000000000000000000000000000000000..805d1d6e30a15d1b376998eab0a077863b689998 --- /dev/null +++ b/src/api/routes/health.py @@ -0,0 +1,25 @@ +"""GET /api/v1/health β€” model status and adapter availability.""" +from __future__ import annotations + +from fastapi import APIRouter, Depends, Request + +from src.api.dependencies import get_adapter_manager 
+from src.api.schemas import HealthResponse +from src.engine.adapter_manager import AdapterManager + +router = APIRouter() + + +@router.get("/health", response_model=HealthResponse) +async def health_check( + request: Request, + adapter_manager: AdapterManager = Depends(get_adapter_manager), +) -> HealthResponse: + model_loaded = hasattr(request.app.state, "transcriber") + return HealthResponse( + status="ok" if model_loaded else "loading", + model_loaded=model_loaded, + active_adapter=adapter_manager.get_active(), + adapters_available=adapter_manager.list_available(), + adapters_loaded=adapter_manager.list_loaded(), + ) diff --git a/src/api/routes/iot.py b/src/api/routes/iot.py new file mode 100644 index 0000000000000000000000000000000000000000..79eeadf799f236faed0459bda3b029b18a286b23 --- /dev/null +++ b/src/api/routes/iot.py @@ -0,0 +1,90 @@ +"""POST /api/v1/query β€” full pipeline: audio β†’ transcription β†’ intent β†’ sensor β†’ voice response.""" +from __future__ import annotations + +import logging +import os +import tempfile +import time +from typing import Annotated, Optional + +from fastapi import APIRouter, Depends, File, Form, HTTPException, UploadFile + +from src.api.dependencies import get_sensor_bridge, get_transcriber +from src.api.schemas import IoTQueryResponse +from src.engine.transcriber import Transcriber +from src.iot.intent_parser import IntentParser +from src.iot.sensor_bridge import SensorBridge +from src.iot.voice_responder import VoiceResponder + +logger = logging.getLogger(__name__) +router = APIRouter() + +_intent_parser = IntentParser() +_voice_responder = VoiceResponder(language="fr") + +SUPPORTED_LANGUAGES = {"bam", "ful"} +MAX_AUDIO_BYTES = 10 * 1024 * 1024 + + +@router.post("/query", response_model=IoTQueryResponse) +async def agricultural_query( + audio_file: Annotated[UploadFile, File(description="Audio file with farmer's voice query")], + language: Annotated[str, Form(description="Language code: 'bam' or 'ful'")] = "bam", + field_id: Annotated[Optional[str], Form(description="Field/location ID for sensor lookup")] = None, + transcriber: Transcriber = Depends(get_transcriber), + sensor_bridge: SensorBridge = Depends(get_sensor_bridge), +) -> IoTQueryResponse: + t0 = time.perf_counter() + + if language not in SUPPORTED_LANGUAGES: + raise HTTPException( + status_code=422, + detail=f"Unsupported language '{language}'. Supported: {sorted(SUPPORTED_LANGUAGES)}", + ) + + audio_bytes = await audio_file.read() + if len(audio_bytes) > MAX_AUDIO_BYTES: + raise HTTPException(status_code=413, detail="Audio file too large. 
Max 10 MB.") + + ext = os.path.splitext(audio_file.filename or "audio.wav")[1].lower() or ".wav" + tmp_path = None + try: + # Step 1: Transcribe + with tempfile.NamedTemporaryFile(delete=False, suffix=ext) as tmp: + tmp.write(audio_bytes) + tmp_path = tmp.name + + transcription_result = transcriber.transcribe_file(tmp_path, language) + + # Step 2: Parse intent + intent = _intent_parser.parse(transcription_result.text, language) + + # Step 3: Fetch sensor data + sensor_data = await sensor_bridge.fetch(intent, field_id=field_id) + + # Step 4: Generate voice response + voice_response = _voice_responder.generate_response(intent, sensor_data) + + except HTTPException: + raise + except Exception as e: + logger.error("IoT query failed: %s", e, exc_info=True) + raise HTTPException(status_code=500, detail=str(e)) + finally: + if tmp_path and os.path.exists(tmp_path): + os.unlink(tmp_path) + + elapsed_ms = int((time.perf_counter() - t0) * 1000) + + return IoTQueryResponse( + transcription=transcription_result.text, + language=language, + intent={ + "action": intent.action, + "entity": intent.entity, + "confidence": intent.confidence, + }, + sensor_data=sensor_data.values, + voice_response=voice_response, + processing_time_ms=elapsed_ms, + ) diff --git a/src/api/routes/transcribe.py b/src/api/routes/transcribe.py new file mode 100644 index 0000000000000000000000000000000000000000..7a0725336b1ba6e9dc4bc36bb00957957660ecbf --- /dev/null +++ b/src/api/routes/transcribe.py @@ -0,0 +1,74 @@ +"""POST /api/v1/transcribe β€” convert uploaded audio to text.""" +from __future__ import annotations + +import logging +import os +import tempfile +from typing import Annotated + +from fastapi import APIRouter, Depends, File, Form, HTTPException, UploadFile + +from src.api.dependencies import get_transcriber +from src.api.schemas import TranscribeResponse +from src.engine.transcriber import Transcriber + +logger = logging.getLogger(__name__) +router = APIRouter() + +SUPPORTED_LANGUAGES = {"bam", "ful"} +SUPPORTED_EXTENSIONS = {".wav", ".mp3", ".ogg", ".m4a", ".flac", ".webm"} +MAX_AUDIO_BYTES = 10 * 1024 * 1024 # 10 MB + + +@router.post("/transcribe", response_model=TranscribeResponse) +async def transcribe_audio( + audio_file: Annotated[UploadFile, File(description="Audio file (wav/mp3/ogg/m4a/flac/webm)")], + language: Annotated[str, Form(description="Language code: 'bam' (Bambara) or 'ful' (Fula)")] = "bam", + transcriber: Transcriber = Depends(get_transcriber), +) -> TranscribeResponse: + # Validate language + if language not in SUPPORTED_LANGUAGES: + raise HTTPException( + status_code=422, + detail=f"Unsupported language '{language}'. Supported: {sorted(SUPPORTED_LANGUAGES)}", + ) + + # Validate file extension + filename = audio_file.filename or "audio.wav" + ext = os.path.splitext(filename)[1].lower() + if ext not in SUPPORTED_EXTENSIONS: + raise HTTPException( + status_code=422, + detail=f"Unsupported file type '{ext}'. Supported: {sorted(SUPPORTED_EXTENSIONS)}", + ) + + # Read and size-check + audio_bytes = await audio_file.read() + if len(audio_bytes) > MAX_AUDIO_BYTES: + raise HTTPException( + status_code=413, + detail=f"File too large ({len(audio_bytes) / 1e6:.1f} MB). 
Max 10 MB.", + ) + + # Windows-safe temp file: delete=False + manual unlink in finally + tmp_path = None + try: + with tempfile.NamedTemporaryFile(delete=False, suffix=ext) as tmp: + tmp.write(audio_bytes) + tmp_path = tmp.name + + result = transcriber.transcribe_file(tmp_path, language) + except Exception as e: + logger.error("Transcription failed: %s", e, exc_info=True) + raise HTTPException(status_code=500, detail=str(e)) + finally: + if tmp_path and os.path.exists(tmp_path): + os.unlink(tmp_path) + + return TranscribeResponse( + text=result.text, + language=result.language, + duration_s=result.duration_s, + processing_time_ms=result.processing_time_ms, + confidence=result.confidence, + ) diff --git a/src/api/schemas.py b/src/api/schemas.py new file mode 100644 index 0000000000000000000000000000000000000000..38905fb7db8951037ed7eedd2a7b4a6706267547 --- /dev/null +++ b/src/api/schemas.py @@ -0,0 +1,36 @@ +"""Pydantic v2 request and response models for all API endpoints.""" +from __future__ import annotations + +from typing import Literal, Optional + +from pydantic import BaseModel, Field + + +class TranscribeResponse(BaseModel): + text: str + language: str + duration_s: float + processing_time_ms: int + confidence: Optional[float] = None + + +class IoTQueryResponse(BaseModel): + transcription: str + language: str + intent: dict + sensor_data: dict + voice_response: str + processing_time_ms: int + + +class HealthResponse(BaseModel): + status: str + model_loaded: bool + active_adapter: Optional[str] + adapters_available: list[str] + adapters_loaded: list[str] + + +class ErrorResponse(BaseModel): + error: str + detail: str diff --git a/src/data/__init__.py b/src/data/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/data/agri_dictionary.py b/src/data/agri_dictionary.py new file mode 100644 index 0000000000000000000000000000000000000000..c57fbfe291e42843190d466be0bc074f891db1ea --- /dev/null +++ b/src/data/agri_dictionary.py @@ -0,0 +1,92 @@ +""" +Agricultural vocabulary for Bambara and Fula. +Used to bias the Whisper decoder toward domain-specific terms via decoder prompt injection. 
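+
+Illustrative usage (a sketch: `model_id` is assumed to come from the caller, and
+whether generate() accepts prompt_ids this way depends on the transformers version):
+
+    processor = WhisperProcessor.from_pretrained(model_id)
+    agri_dict = AgriculturalDictionary()
+    prompt_ids = agri_dict.build_prompt_ids(processor, "bam")   # shape (1, N)
+    # predicted_ids = model.generate(input_features, prompt_ids=prompt_ids[0])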
+""" +from __future__ import annotations + +from typing import TYPE_CHECKING + +import torch + +if TYPE_CHECKING: + from transformers import WhisperProcessor + +# Bambara (bam) agricultural vocabulary +BAMBARA_VOCAB: dict[str, str] = { + "sΙ›nΙ›": "farming", + "jiriw": "trees", + "nΙ”gΙ”": "soil", + "sani": "fertilizer", + "kogomali": "groundnut", + "kaba": "corn/maize", + "tiga": "peanut", + "ji": "water", + "sanji": "rain", + "teliman": "weather", + "suruku": "pest/predator", + "bunding": "soil/earth", + "sira": "path/way", + "foro": "field", + "dugu": "village/land", + "dibi": "darkness/shade", + "fanga": "strength/fertilizer", + "kungoloni": "insects/pests", +} + +# Fula (ful / Fulfulde) agricultural vocabulary +FULA_VOCAB: dict[str, str] = { + "ngesa": "field", + "leydi": "land/soil", + "kosam": "milk", + "nagge": "cattle", + "leeΙ—e": "crops", + "ndiyam": "water", + "yeeso": "wind/weather", + "laabi": "road/way", + "demoore": "farming", + "hoore": "head/top", + "biΓ±-biΓ±": "insects/pests", + "fuΙ—orde": "sunrise/east field", + "ngaari": "bull", + "mbabba": "donkey", + "ladde": "bush/forest", + "wutte": "clothing/harvest", +} + +LANGUAGE_VOCABS: dict[str, dict[str, str]] = { + "bam": BAMBARA_VOCAB, + "ful": FULA_VOCAB, +} + + +class AgriculturalDictionary: + """Converts agricultural vocabulary into decoder prompt token IDs for Whisper.""" + + def get_vocab(self, language: str) -> dict[str, str]: + if language not in LANGUAGE_VOCABS: + raise ValueError(f"No vocabulary for language '{language}'. Available: {list(LANGUAGE_VOCABS)}") + return LANGUAGE_VOCABS[language] + + def get_prompt_text(self, language: str) -> str: + """Return a comma-joined string of all terms, used as decoder text prompt.""" + vocab = self.get_vocab(language) + return ", ".join(vocab.keys()) + + def build_prompt_ids(self, processor: "WhisperProcessor", language: str) -> torch.Tensor: + """ + Tokenize the vocabulary as a decoder prompt. + Pass this as `decoder_input_ids` or `prompt_ids` to model.generate() + to bias decoding toward known agricultural terms. + """ + prompt_text = self.get_prompt_text(language) + token_ids = processor.tokenizer( + prompt_text, + return_tensors="pt", + add_special_tokens=False, + ).input_ids + return token_ids # shape: (1, N) + + def get_token_ids(self, processor: "WhisperProcessor", language: str) -> list[int]: + """Return flat list of token IDs for all vocabulary terms.""" + ids = self.build_prompt_ids(processor, language) + return ids[0].tolist() diff --git a/src/data/augmentation.py b/src/data/augmentation.py new file mode 100644 index 0000000000000000000000000000000000000000..c00051edd9fd8e71ce82fc5c2d56540169b72ff2 --- /dev/null +++ b/src/data/augmentation.py @@ -0,0 +1,84 @@ +""" +Field noise augmentation for West African farm environments. +Mixes clean speech with tractor, wind, and livestock audio samples. +Degrades gracefully to Gaussian noise when no .wav files are present. +""" +from __future__ import annotations + +import logging +from pathlib import Path + +import numpy as np + +logger = logging.getLogger(__name__) + + +class FieldNoiseAugmenter: + """ + Applies audiomentations transforms that simulate noisy field conditions. + If the noise_dir has no .wav files, falls back to Gaussian noise only. 
+ """ + + def __init__(self, noise_dir: str, config: dict) -> None: + self.noise_dir = Path(noise_dir) + self.config = config + self._compose = None + self._gaussian_only = False + self._build_pipeline() + + def _build_pipeline(self) -> None: + try: + from audiomentations import ( + AddBackgroundNoise, + AddGaussianNoise, + Compose, + RoomSimulator, + TimeStretch, + ) + except ImportError: + logger.warning("audiomentations not installed β€” augmentation disabled.") + self._compose = None + return + + snr_range = self.config.get("audio", {}).get("noise_snr_db_range", [5, 20]) + prob = self.config.get("audio", {}).get("augmentation_prob", 0.6) + + wav_files = list(self.noise_dir.glob("*.wav")) if self.noise_dir.exists() else [] + + transforms = [] + + if wav_files: + transforms.append( + AddBackgroundNoise( + sounds_path=str(self.noise_dir), + min_snr_db=float(snr_range[0]), + max_snr_db=float(snr_range[1]), + p=prob, + ) + ) + logger.info("FieldNoiseAugmenter: loaded %d noise files from %s", len(wav_files), self.noise_dir) + else: + logger.warning( + "FieldNoiseAugmenter: no .wav files found in %s β€” using Gaussian noise only. " + "Populate noise_samples/ for realistic field augmentation.", + self.noise_dir, + ) + self._gaussian_only = True + + transforms += [ + AddGaussianNoise(min_amplitude=0.001, max_amplitude=0.015, p=0.3), + TimeStretch(min_rate=0.9, max_rate=1.1, leave_length_unchanged=True, p=0.2), + RoomSimulator(p=0.3), + ] + + self._compose = Compose(transforms) + + def augment(self, audio: np.ndarray, sr: int) -> np.ndarray: + """Apply augmentation pipeline to a float32 audio array.""" + if self._compose is None: + return audio + return self._compose(samples=audio, sample_rate=sr) + + def is_ready(self) -> bool: + """Returns True if augmentation is available (even Gaussian-only).""" + return self._compose is not None diff --git a/src/data/feature_extractor.py b/src/data/feature_extractor.py new file mode 100644 index 0000000000000000000000000000000000000000..57d42becc11d6105ee359d69da2bde5412a1a97c --- /dev/null +++ b/src/data/feature_extractor.py @@ -0,0 +1,89 @@ +""" +Log-mel spectrogram extraction, padding/truncation, and batch collation for Whisper. +""" +from __future__ import annotations + +import logging +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any + +import numpy as np +import torch +import torchaudio + +if TYPE_CHECKING: + from transformers import WhisperProcessor + +logger = logging.getLogger(__name__) + +TARGET_SR = 16_000 +MEL_FRAMES = 3000 # 30 seconds at 100 frames/sec +N_MELS = 80 + + +class AudioFeatureExtractor: + """Wraps WhisperProcessor to extract and normalize audio features.""" + + def __init__(self, processor: "WhisperProcessor", config: dict) -> None: + self.processor = processor + self.sample_rate = config.get("audio", {}).get("sample_rate", TARGET_SR) + + def extract(self, audio: np.ndarray, sr: int) -> torch.Tensor: + """ + Resample audio to 16kHz, extract log-mel features. + Returns tensor of shape (80, 3000). 
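+
+        Sketch, given an AudioFeatureExtractor instance `extractor` (the 44.1 kHz
+        input rate is illustrative):
+
+            feats = extractor.extract(audio, sr=44_100)    # resampled internally
+            feats = extractor.pad_or_truncate(feats)       # exactly (80, 3000)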
+ """ + if sr != TARGET_SR: + tensor = torch.from_numpy(audio).unsqueeze(0) + tensor = torchaudio.functional.resample(tensor, sr, TARGET_SR) + audio = tensor.squeeze(0).numpy() + + inputs = self.processor.feature_extractor( + audio, + sampling_rate=TARGET_SR, + return_tensors="pt", + ) + features = inputs.input_features[0] # (80, 3000) + return features + + def pad_or_truncate(self, features: torch.Tensor) -> torch.Tensor: + """Ensure features are exactly (80, 3000).""" + _, t = features.shape + if t < MEL_FRAMES: + pad = torch.zeros(N_MELS, MEL_FRAMES - t, dtype=features.dtype) + features = torch.cat([features, pad], dim=-1) + elif t > MEL_FRAMES: + features = features[:, :MEL_FRAMES] + return features + + +@dataclass +class DataCollatorSpeechSeq2SeqWithPadding: + """ + Pads input_features to uniform length and label sequences with -100 + (so they are ignored in the cross-entropy loss). + Compatible with HuggingFace Seq2SeqTrainer. + """ + processor: Any + decoder_start_token_id: int + + def __call__(self, features: list[dict]) -> dict[str, torch.Tensor]: + # Separate input_features and labels + input_features = [{"input_features": f["input_features"]} for f in features] + label_features = [{"input_ids": f["labels"]} for f in features] + + # Pad input features (processor handles this) + batch = self.processor.feature_extractor.pad(input_features, return_tensors="pt") + + # Pad labels + labels_batch = self.processor.tokenizer.pad(label_features, return_tensors="pt") + labels = labels_batch["input_ids"].masked_fill( + labels_batch.attention_mask.ne(1), -100 + ) + + # Remove decoder start token if it was prepended + if (labels[:, 0] == self.decoder_start_token_id).all().item(): + labels = labels[:, 1:] + + batch["labels"] = labels + return batch diff --git a/src/data/waxal_loader.py b/src/data/waxal_loader.py new file mode 100644 index 0000000000000000000000000000000000000000..9f3af8ab4b814d860a8ec590ad51ae3c1a9b9212 --- /dev/null +++ b/src/data/waxal_loader.py @@ -0,0 +1,119 @@ +""" +Loads and preprocesses the google/waxal dataset for Bambara (bam) and Fula (ful). +Uses streaming to avoid downloading the full corpus before training. 
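+
+Typical streaming loop (a sketch; `config` and `processor` are assumed to be set
+up by the calling script, as in scripts/run_data_pipeline.py):
+
+    loader = WaxalDataLoader("bam", config, hf_token=os.getenv("HF_TOKEN"))
+    for example in loader.iter_processed(processor, split="train"):
+        ...  # example["input_features"] (80, 3000), example["labels"] token ids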
+""" +from __future__ import annotations + +import logging +from typing import TYPE_CHECKING, Callable, Iterator + +import numpy as np +import torch +import torchaudio +from datasets import load_dataset + +if TYPE_CHECKING: + from datasets import Dataset, IterableDataset + from transformers import WhisperProcessor + + from src.data.augmentation import FieldNoiseAugmenter + +logger = logging.getLogger(__name__) + +# google/waxal column names +AUDIO_COL = "audio" +TEXT_COL = "transcription" +TARGET_SR = 16_000 + + +class WaxalDataLoader: + """Streams the google/waxal dataset and prepares examples for Whisper training.""" + + def __init__( + self, + subset: str, + config: dict, + hf_token: str | None = None, + ) -> None: + if subset not in ("bam", "ful"): + raise ValueError(f"subset must be 'bam' or 'ful', got '{subset}'") + self.subset = subset + self.config = config + self.hf_token = hf_token + + def load_split(self, split: str = "train", streaming: bool = True) -> "IterableDataset | Dataset": + """Return a single split of google/waxal.""" + logger.info("Loading google/waxal subset=%s split=%s streaming=%s", self.subset, split, streaming) + ds = load_dataset( + "google/waxal", + self.subset, + split=split, + token=self.hf_token, + streaming=streaming, + trust_remote_code=True, + ) + if streaming: + ds = ds.shuffle(seed=42, buffer_size=1000) + return ds + + def get_splits(self, streaming: bool = True) -> dict[str, "IterableDataset | Dataset"]: + """Return train / validation / test splits.""" + splits = {} + for split in ("train", "validation", "test"): + try: + splits[split] = self.load_split(split, streaming=streaming) + except Exception: + logger.warning("Split '%s' not available for subset '%s'", split, self.subset) + return splits + + def make_preprocess_fn( + self, + processor: "WhisperProcessor", + augmenter: "FieldNoiseAugmenter | None" = None, + ) -> Callable[[dict], dict]: + """Return a function that converts a raw Waxal example into model inputs.""" + + def preprocess(example: dict) -> dict: + # Extract and resample audio + audio_array = np.array(example[AUDIO_COL]["array"], dtype=np.float32) + orig_sr: int = example[AUDIO_COL]["sampling_rate"] + + if orig_sr != TARGET_SR: + tensor = torch.from_numpy(audio_array).unsqueeze(0) + tensor = torchaudio.functional.resample(tensor, orig_sr, TARGET_SR) + audio_array = tensor.squeeze(0).numpy() + + # Apply field noise augmentation if provided + if augmenter is not None and augmenter.is_ready(): + audio_array = augmenter.augment(audio_array, TARGET_SR) + + # Extract log-mel features + inputs = processor.feature_extractor( + audio_array, + sampling_rate=TARGET_SR, + return_tensors="np", + ) + input_features = inputs.input_features[0] # shape (80, 3000) + + # Tokenize transcript + text: str = example[TEXT_COL] + labels = processor.tokenizer(text, return_tensors="np").input_ids[0] + + return { + "input_features": input_features, + "labels": labels, + } + + return preprocess + + def iter_processed( + self, + processor: "WhisperProcessor", + split: str = "train", + augmenter: "FieldNoiseAugmenter | None" = None, + ) -> Iterator[dict]: + """Yield preprocessed examples one at a time (streaming).""" + ds = self.load_split(split, streaming=True) + fn = self.make_preprocess_fn(processor, augmenter) + for example in ds: + yield fn(example) diff --git a/src/engine/__init__.py b/src/engine/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/engine/adapter_manager.py 
b/src/engine/adapter_manager.py new file mode 100644 index 0000000000000000000000000000000000000000..565c03f91b31cd493bdf3de082b52a4a27ae3f4d --- /dev/null +++ b/src/engine/adapter_manager.py @@ -0,0 +1,106 @@ +""" +LoRA adapter hot-swap manager. + +Uses PEFT's multi-adapter API: + - model.load_adapter(path, adapter_name=lang) β€” first load (~2s per adapter) + - model.set_adapter(lang) β€” subsequent swap (~50ms) + +This keeps a single backbone in VRAM and swaps only the ~50MB adapter weights, +vs reloading the full 1.5GB model per language. +""" +from __future__ import annotations + +import logging +from pathlib import Path +from typing import TYPE_CHECKING + +from peft import PeftModel + +if TYPE_CHECKING: + from transformers import WhisperForConditionalGeneration + +logger = logging.getLogger(__name__) + + +class AdapterManager: + """Manages registration and hot-swapping of LoRA language adapters.""" + + def __init__(self, base_model: "WhisperForConditionalGeneration", config: dict) -> None: + self._base_model = base_model + self._config = config + self._registry: dict[str, str] = {} # language_code -> adapter_path + self._peft_model: PeftModel | None = None + self._active: str | None = None + + def register(self, language: str, adapter_path: str) -> None: + """Register an adapter path. Does not load it yet.""" + path = Path(adapter_path) + if not path.exists(): + logger.warning( + "Adapter path '%s' for language '%s' does not exist. " + "Run training first, or check the path.", + adapter_path, language, + ) + self._registry[language] = str(path) + logger.info("Registered adapter '%s' β†’ %s", language, adapter_path) + + def load_adapter(self, language: str) -> None: + """ + Load an adapter into the model for the first time. + Slow (~2s): reads adapter weights from disk. + Subsequent activate() calls reuse the already-loaded weights. + """ + if language not in self._registry: + raise KeyError(f"No adapter registered for language '{language}'. " + f"Available: {list(self._registry)}") + + adapter_path = self._registry[language] + + if self._peft_model is None: + # First adapter: wrap the base model with PeftModel + logger.info("Wrapping base model with first adapter '%s'...", language) + self._peft_model = PeftModel.from_pretrained( + self._base_model, + adapter_path, + adapter_name=language, + ) + else: + # Subsequent adapters: load into the existing PeftModel + logger.info("Loading adapter '%s' into existing PeftModel...", language) + self._peft_model.load_adapter(adapter_path, adapter_name=language) + + self._active = language + logger.info("Adapter '%s' loaded and active.", language) + + def activate(self, language: str) -> None: + """ + Hot-swap to a previously loaded adapter (~50ms). + Call load_adapter() first if this adapter hasn't been loaded. 
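+
+        Typical lifecycle (illustrative; timings are the estimates quoted above):
+
+            manager.register("bam", "./adapters/bambara")
+            manager.load_adapter("bam")    # first load reads weights from disk
+            manager.activate("ful")        # not loaded yet, falls back to load_adapter()
+            manager.activate("bam")        # already loaded, fast set_adapter() swap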
+ """ + if self._peft_model is None: + self.load_adapter(language) + return + + loaded = set(self._peft_model.peft_config.keys()) + if language not in loaded: + self.load_adapter(language) + return + + self._peft_model.set_adapter(language) + self._active = language + logger.debug("Hot-swapped to adapter '%s'.", language) + + def get_model(self) -> "WhisperForConditionalGeneration | PeftModel": + """Return the PeftModel (or base model if no adapter loaded yet).""" + return self._peft_model if self._peft_model is not None else self._base_model + + def get_active(self) -> str | None: + return self._active + + def list_available(self) -> list[str]: + return list(self._registry.keys()) + + def list_loaded(self) -> list[str]: + if self._peft_model is None: + return [] + return list(self._peft_model.peft_config.keys()) diff --git a/src/engine/transcriber.py b/src/engine/transcriber.py new file mode 100644 index 0000000000000000000000000000000000000000..3301989c15bc6501eba2fef95cd71aa5a218986e --- /dev/null +++ b/src/engine/transcriber.py @@ -0,0 +1,132 @@ +""" +Public inference interface. +Accepts audio as a file path or numpy array and returns transcribed text. +Handles chunking for audio longer than 30 seconds. +""" +from __future__ import annotations + +import logging +import os +import time +from dataclasses import dataclass, field +from pathlib import Path +from typing import TYPE_CHECKING + +import numpy as np +import torch + +if TYPE_CHECKING: + from src.engine.adapter_manager import AdapterManager + from src.engine.whisper_base import WhisperBackbone + +logger = logging.getLogger(__name__) + +TARGET_SR = 16_000 + + +@dataclass +class TranscriptionResult: + text: str + language: str + duration_s: float + processing_time_ms: int + confidence: float | None = None + + +class Transcriber: + """ + Composes WhisperBackbone + AdapterManager to provide a simple transcription API. + Thread-safety: Not thread-safe by design β€” use one worker process. + """ + + def __init__(self, backbone: "WhisperBackbone", adapter_manager: "AdapterManager") -> None: + self._backbone = backbone + self._adapter_manager = adapter_manager + + def transcribe( + self, + audio: np.ndarray, + sample_rate: int, + language: str, + use_agri_prompt: bool = True, + ) -> TranscriptionResult: + """ + Transcribe a float32 audio array. + For audio > 30s, uses transformers pipeline with chunking. 
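+
+        Example (a sketch; assumes a Transcriber built from the backbone and
+        adapter manager loaded at startup):
+
+            result = transcriber.transcribe(audio, sample_rate=16_000, language="bam")
+            print(result.text, result.duration_s, result.processing_time_ms)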
+ """ + t0 = time.time() + + # Activate the correct language adapter + self._adapter_manager.activate(language) + + processor = self._backbone.processor + model = self._adapter_manager.get_model() + device = self._backbone.device + duration_s = len(audio) / sample_rate + + if duration_s <= 30.0: + text = self._transcribe_chunk(audio, sample_rate, language, processor, model, device) + else: + text = self._transcribe_long(audio, sample_rate, language, processor, model, device) + + elapsed_ms = int((time.time() - t0) * 1000) + return TranscriptionResult( + text=text.strip(), + language=language, + duration_s=duration_s, + processing_time_ms=elapsed_ms, + ) + + def transcribe_file(self, audio_path: str, language: str) -> TranscriptionResult: + """Load audio from disk and transcribe.""" + import librosa + audio, sr = librosa.load(audio_path, sr=TARGET_SR, mono=True) + return self.transcribe(audio, sr, language) + + def _transcribe_chunk( + self, + audio: np.ndarray, + sr: int, + language: str, + processor, + model, + device: str, + ) -> str: + """Transcribe a single ≀30s chunk.""" + inputs = processor.feature_extractor( + audio, sampling_rate=sr, return_tensors="pt" + ) + input_features = inputs.input_features.to(device) + if device == "cuda": + input_features = input_features.half() + + forced_decoder_ids = processor.get_decoder_prompt_ids( + language=language, task="transcribe" + ) + + with torch.no_grad(): + predicted_ids = model.generate( + input_features, + forced_decoder_ids=forced_decoder_ids, + max_new_tokens=128, + ) + + return processor.batch_decode(predicted_ids, skip_special_tokens=True)[0] + + def _transcribe_long( + self, + audio: np.ndarray, + sr: int, + language: str, + processor, + model, + device: str, + ) -> str: + """Chunk audio into 30s segments and concatenate transcriptions.""" + chunk_size = TARGET_SR * 30 + chunks = [audio[i : i + chunk_size] for i in range(0, len(audio), chunk_size)] + parts = [] + for chunk in chunks: + text = self._transcribe_chunk(chunk, sr, language, processor, model, device) + parts.append(text) + return " ".join(parts) diff --git a/src/engine/whisper_base.py b/src/engine/whisper_base.py new file mode 100644 index 0000000000000000000000000000000000000000..8fe3a85d09c355a02fb4e1c0fc1b8663c57c0842 --- /dev/null +++ b/src/engine/whisper_base.py @@ -0,0 +1,77 @@ +""" +Loads the Whisper backbone model and processor once. +All other modules receive references to this shared instance. +""" +from __future__ import annotations + +import logging +from pathlib import Path + +import torch +import yaml +from transformers import WhisperForConditionalGeneration, WhisperProcessor + +logger = logging.getLogger(__name__) + + +class WhisperBackbone: + """Singleton-style loader for the Whisper base model and processor.""" + + def __init__(self, config_path: str = "configs/base_config.yaml") -> None: + config_path = Path(config_path) + with open(config_path) as f: + cfg = yaml.safe_load(f) + self._model_id: str = cfg["model"]["id"] + self._model: WhisperForConditionalGeneration | None = None + self._processor: WhisperProcessor | None = None + self._device: str = "cpu" + + def load(self, device: str = "cuda", hf_token: str | None = None) -> None: + """Load model and processor into memory. 
Call once at startup.""" + self._device = device if torch.cuda.is_available() and device == "cuda" else "cpu" + logger.info("Loading %s on %s", self._model_id, self._device) + + self._processor = WhisperProcessor.from_pretrained( + self._model_id, + token=hf_token, + ) + + dtype = torch.float16 if self._device == "cuda" else torch.float32 + self._model = WhisperForConditionalGeneration.from_pretrained( + self._model_id, + torch_dtype=dtype, + token=hf_token, + ).to(self._device) + + self._model.eval() + logger.info("Model loaded successfully (dtype=%s, device=%s)", dtype, self._device) + + @property + def model(self) -> WhisperForConditionalGeneration: + if self._model is None: + raise RuntimeError("Call WhisperBackbone.load() before accessing the model.") + return self._model + + @property + def processor(self) -> WhisperProcessor: + if self._processor is None: + raise RuntimeError("Call WhisperBackbone.load() before accessing the processor.") + return self._processor + + @property + def device(self) -> str: + return self._device + + @property + def model_id(self) -> str: + return self._model_id + + def free(self) -> None: + """Release GPU memory.""" + del self._model + del self._processor + self._model = None + self._processor = None + if torch.cuda.is_available(): + torch.cuda.empty_cache() + logger.info("Backbone freed from memory.") diff --git a/src/iot/__init__.py b/src/iot/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/iot/intent_parser.py b/src/iot/intent_parser.py new file mode 100644 index 0000000000000000000000000000000000000000..f79933a1f76066112b09d1c0b6629724943aee9a --- /dev/null +++ b/src/iot/intent_parser.py @@ -0,0 +1,75 @@ +""" +Maps transcribed Bambara/Fula text to structured intents for IoT sensor queries. +Uses keyword matching (no ML required for v1). +Confidence = fraction of intent keywords present in the transcription. +""" +from __future__ import annotations + +from dataclasses import dataclass, field + + +@dataclass +class Intent: + action: str # e.g., "check_soil", "check_weather" + entity: str # e.g., "soil", "weather" + parameters: dict = field(default_factory=dict) + confidence: float = 0.0 + + +# Intent keyword taxonomy for Bambara (bam) and Fula (ful) +INTENT_KEYWORDS: dict[str, dict[str, list[str]]] = { + "check_soil": { + "bam": ["bunding", "nΙ”gΙ”", "dugu", "foro", "sani"], + "ful": ["leydi", "ngesa", "ladde"], + }, + "check_weather": { + "bam": ["teliman", "sanji", "dibi", "sira"], + "ful": ["yeeso", "fuΙ—orde"], + }, + "irrigation_status": { + "bam": ["ji", "sanji", "foro"], + "ful": ["ndiyam", "ngesa"], + }, + "pest_alert": { + "bam": ["kungoloni", "suruku"], + "ful": ["biΓ±-biΓ±"], + }, +} + +INTENT_ENTITIES = { + "check_soil": "soil", + "check_weather": "weather", + "irrigation_status": "irrigation", + "pest_alert": "pest", +} + + +class IntentParser: + """Parses a transcription string into a structured Intent.""" + + def parse(self, text: str, language: str) -> Intent: + """ + Find the best matching intent by counting keyword overlaps. + Returns the highest-confidence intent. 
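+
+        Example (illustrative transcript; confidence = matched / total keywords):
+
+            intent = IntentParser().parse("foro ji", "bam")
+            # "ji" and "foro" hit irrigation_status (2 of 3 keywords)
+            # -> Intent(action="irrigation_status", entity="irrigation", confidence=0.667)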
+ """ + text_lower = text.lower() + best_action = "unknown" + best_confidence = 0.0 + + for action, lang_keywords in INTENT_KEYWORDS.items(): + keywords = lang_keywords.get(language, []) + if not keywords: + continue + + matches = sum(1 for kw in keywords if kw in text_lower) + confidence = matches / len(keywords) + + if confidence > best_confidence: + best_confidence = confidence + best_action = action + + return Intent( + action=best_action, + entity=INTENT_ENTITIES.get(best_action, "unknown"), + confidence=round(best_confidence, 3), + ) diff --git a/src/iot/sensor_bridge.py b/src/iot/sensor_bridge.py new file mode 100644 index 0000000000000000000000000000000000000000..7e5376a8f588feb2a6f7fa9720ab6bddbfc716ae --- /dev/null +++ b/src/iot/sensor_bridge.py @@ -0,0 +1,121 @@ +""" +Fetches sensor data (soil moisture, weather, irrigation) from the IoT backend API. +Falls back to synthetic mock data when SENSOR_API_URL is not configured. +""" +from __future__ import annotations + +import logging +import random +from dataclasses import dataclass, field +from datetime import datetime +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from src.iot.intent_parser import Intent + +logger = logging.getLogger(__name__) + + +@dataclass +class SensorData: + sensor_type: str + values: dict[str, float] + timestamp: str + unit: str = "" + + +class SensorBridge: + """Async bridge to IoT sensor API. Uses mock data when no API URL is configured.""" + + def __init__(self, sensor_api_url: str | None = None, timeout_s: float = 5.0) -> None: + self.sensor_api_url = sensor_api_url + self.timeout_s = timeout_s + self._mock_mode = not sensor_api_url + + if self._mock_mode: + logger.info("SensorBridge: running in MOCK mode (set SENSOR_API_URL to use real sensors).") + + async def fetch(self, intent: "Intent", field_id: str | None = None) -> SensorData: + """Dispatch to the correct sensor fetch method based on intent entity.""" + action = intent.action + if action == "check_soil": + return await self.get_soil_data(field_id or "default") + elif action == "check_weather": + return await self.get_weather(field_id or "default") + elif action == "irrigation_status": + return await self.get_irrigation(field_id or "default") + elif action == "pest_alert": + return await self.get_pest_status(field_id or "default") + else: + return SensorData( + sensor_type="unknown", + values={}, + timestamp=datetime.utcnow().isoformat(), + ) + + async def get_soil_data(self, location_id: str) -> SensorData: + if self._mock_mode: + return SensorData( + sensor_type="soil", + values={ + "moisture_pct": round(random.uniform(25, 65), 1), + "ph": round(random.uniform(5.5, 7.5), 1), + "nitrogen_ppm": round(random.uniform(10, 40), 1), + "temperature_c": round(random.uniform(24, 35), 1), + }, + timestamp=datetime.utcnow().isoformat(), + ) + return await self._get(f"/sensors/soil/{location_id}", "soil") + + async def get_weather(self, location_id: str) -> SensorData: + if self._mock_mode: + return SensorData( + sensor_type="weather", + values={ + "temperature_c": round(random.uniform(28, 42), 1), + "humidity_pct": round(random.uniform(20, 80), 1), + "wind_speed_kmh": round(random.uniform(0, 25), 1), + "rain_probability_pct": round(random.uniform(0, 100), 1), + }, + timestamp=datetime.utcnow().isoformat(), + ) + return await self._get(f"/sensors/weather/{location_id}", "weather") + + async def get_irrigation(self, field_id: str) -> SensorData: + if self._mock_mode: + return SensorData( + sensor_type="irrigation", + values={ + "flow_rate_lph": 
round(random.uniform(0, 500), 1), + "pressure_bar": round(random.uniform(1.0, 4.0), 2), + "active": float(random.choice([0, 1])), + "last_irrigation_h_ago": round(random.uniform(1, 48), 1), + }, + timestamp=datetime.utcnow().isoformat(), + ) + return await self._get(f"/sensors/irrigation/{field_id}", "irrigation") + + async def get_pest_status(self, field_id: str) -> SensorData: + if self._mock_mode: + return SensorData( + sensor_type="pest", + values={ + "trap_count_24h": float(random.randint(0, 50)), + "alert_level": float(random.randint(0, 3)), # 0=none 1=low 2=medium 3=high + }, + timestamp=datetime.utcnow().isoformat(), + ) + return await self._get(f"/sensors/pest/{field_id}", "pest") + + async def _get(self, path: str, sensor_type: str) -> SensorData: + import httpx + url = f"{self.sensor_api_url}{path}" + async with httpx.AsyncClient(timeout=self.timeout_s) as client: + response = await client.get(url) + response.raise_for_status() + data = response.json() + return SensorData( + sensor_type=sensor_type, + values=data.get("values", data), + timestamp=data.get("timestamp", datetime.utcnow().isoformat()), + ) diff --git a/src/iot/voice_responder.py b/src/iot/voice_responder.py new file mode 100644 index 0000000000000000000000000000000000000000..396fa5efd533e100e86f03eba4347745aa3bb64e --- /dev/null +++ b/src/iot/voice_responder.py @@ -0,0 +1,260 @@ +""" +Generates voice response text from sensor data in the farmer's own language. +Supports Bambara (bam), Fula (ful), French (fr), and English (en). +Bambara/Fula templates use short sentences (≀15 words) for best MMS-TTS quality. +""" +from __future__ import annotations + +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from src.iot.intent_parser import Intent + from src.iot.sensor_bridge import SensorData + +# Alert thresholds +SOIL_MOISTURE_LOW = 30.0 # Below this β†’ immediate irrigation recommended +SOIL_MOISTURE_HIGH = 70.0 # Above this β†’ drainage warning +SOIL_PH_LOW = 5.5 +SOIL_PH_HIGH = 7.5 +TEMP_HIGH = 38.0 +PEST_ALERT_HIGH = 2 # Alert level β‰₯ 2 β†’ warning + +# ── Bambara templates (≀6 words per sentence for clear MMS-TTS output) ─────── +BAMBARA_TEMPLATES = { + "soil_moisture_low": "Bunding ji dΙ”gΙ”. I ka foro ji.", + "soil_moisture_high": "Ji ca kojugu. Foro ma fΙ›.", + "soil_ph_low": "Bunding kΙ”nΙ” jugu. Kalisi fara a kan.", + "soil_ph_high": "Bunding kΙ”nΙ” tΙ›mΙ›. Soufre fara a kan.", + "weather_hot": "Teliman gbΙ›lΙ›. Tile ma sigi.", + "rain_likely": "Sanji bΙ› na. SΙ”rΙ” jΙ”.", + "pest_high": "DΙ”gΙ”w bΙ› foro kΙ”nΙ”. BΙ” u.", + "irrigation_needed": "Foro fΙ› ji. Ji sira yΙ”rΙ”.", + "irrigation_active": "Ji bΙ› taa. A bΙ› kΙ› cogo di.", + "default": "Kabako jumanw sΙ”rΙ”la.", +} + +# ── Fula templates (≀6 words per sentence for clear MMS-TTS output) ────────── +FULA_TEMPLATES = { + "soil_moisture_low": "Leydi ndiyam famΙ—i. Wado ngesa.", + "soil_moisture_high": "Ndiyam heewi. Leydi famΙ—aali.", + "soil_ph_low": "Leydi suurii. WaΙ— kalisi.", + "soil_ph_high": "Leydi alkalii. WaΙ— soufre.", + "weather_hot": "Nguleeki heewi. Muusal.", + "rain_likely": "Ndiyam wadata. Loosu ngesa.", + "pest_high": "BiΓ±-biΓ± ngesa nder. Fiil Ι—en.", + "irrigation_needed": "Ngesa fΙ›Ι—Ι›li ndiyam. 
Wado.", + "irrigation_active": "Ndiyam wona jooni.", + "default": "Humpito juuti waΙ—aama.", +} + + +class VoiceResponder: + """Converts sensor readings into actionable voice messages in the farmer's language.""" + + def __init__(self, language: str = "fr") -> None: + self.language = language + + def generate_response(self, intent: "Intent", sensor_data: "SensorData") -> str: + if self.language == "bam": + return self._bambara_response(sensor_data) + elif self.language == "ful": + return self._fula_response(sensor_data) + else: + return self._french_response(sensor_data) + + # ── Bambara ────────────────────────────────────────────────────────────── + + def _bambara_response(self, sensor_data: "SensorData") -> str: + t = sensor_data.sensor_type + v = sensor_data.values + T = BAMBARA_TEMPLATES + + if t == "soil": + moisture = v.get("moisture_pct") + if moisture is not None: + if moisture < SOIL_MOISTURE_LOW: + return T["soil_moisture_low"] + elif moisture > SOIL_MOISTURE_HIGH: + return T["soil_moisture_high"] + ph = v.get("ph") + if ph is not None: + if ph < SOIL_PH_LOW: + return T["soil_ph_low"] + elif ph > SOIL_PH_HIGH: + return T["soil_ph_high"] + + elif t == "weather": + temp = v.get("temperature_c") + rain = v.get("rain_probability_pct") + if temp is not None and temp > TEMP_HIGH: + return T["weather_hot"] + if rain is not None and rain > 70: + return T["rain_likely"] + + elif t == "irrigation": + last = v.get("last_irrigation_h_ago") + active = v.get("active") + if active: + return T["irrigation_active"] + if last is not None and last > 24: + return T["irrigation_needed"] + + elif t == "pest": + level = int(v.get("alert_level", 0)) + if level >= PEST_ALERT_HIGH: + return T["pest_high"] + + return T["default"] + + # ── Fula ───────────────────────────────────────────────────────────────── + + def _fula_response(self, sensor_data: "SensorData") -> str: + t = sensor_data.sensor_type + v = sensor_data.values + T = FULA_TEMPLATES + + if t == "soil": + moisture = v.get("moisture_pct") + if moisture is not None: + if moisture < SOIL_MOISTURE_LOW: + return T["soil_moisture_low"] + elif moisture > SOIL_MOISTURE_HIGH: + return T["soil_moisture_high"] + ph = v.get("ph") + if ph is not None: + if ph < SOIL_PH_LOW: + return T["soil_ph_low"] + elif ph > SOIL_PH_HIGH: + return T["soil_ph_high"] + + elif t == "weather": + temp = v.get("temperature_c") + rain = v.get("rain_probability_pct") + if temp is not None and temp > TEMP_HIGH: + return T["weather_hot"] + if rain is not None and rain > 70: + return T["rain_likely"] + + elif t == "irrigation": + active = v.get("active") + last = v.get("last_irrigation_h_ago") + if active: + return T["irrigation_active"] + if last is not None and last > 24: + return T["irrigation_needed"] + + elif t == "pest": + level = int(v.get("alert_level", 0)) + if level >= PEST_ALERT_HIGH: + return T["pest_high"] + + return T["default"] + + # ── French (original) ───────────────────────────────────────────────────── + + def _french_response(self, sensor_data: "SensorData") -> str: + t = sensor_data.sensor_type + v = sensor_data.values + if t == "soil": + return self._soil_response(v) + elif t == "weather": + return self._weather_response(v) + elif t == "irrigation": + return self._irrigation_response(v) + elif t == "pest": + return self._pest_response(v) + else: + return "DonnΓ©es du capteur non disponibles pour le moment." 
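+
+    # Illustrative composition with SensorBridge output (a sketch, not called from
+    # production code; 22% moisture sits below SOIL_MOISTURE_LOW):
+    #
+    #     responder = VoiceResponder(language="bam")
+    #     data = SensorData("soil", {"moisture_pct": 22.0}, timestamp="...")
+    #     responder.generate_response(intent, data)  # -> BAMBARA_TEMPLATES["soil_moisture_low"]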
+ + def _soil_response(self, v: dict) -> str: + parts = [] + moisture = v.get("moisture_pct") + ph = v.get("ph") + temp = v.get("temperature_c") + nitrogen = v.get("nitrogen_ppm") + + if moisture is not None: + parts.append(f"HumiditΓ© du sol : {moisture:.0f}%.") + if moisture < SOIL_MOISTURE_LOW: + parts.append("Irrigation recommandΓ©e immΓ©diatement.") + elif moisture > SOIL_MOISTURE_HIGH: + parts.append("Sol trop humide, risque d'engorgement.") + + if ph is not None: + parts.append(f"pH du sol : {ph:.1f}.") + if ph < SOIL_PH_LOW: + parts.append("Sol trop acide β€” envisagez un amendement calcaire.") + elif ph > SOIL_PH_HIGH: + parts.append("Sol trop alcalin β€” un apport de soufre peut aider.") + + if temp is not None: + parts.append(f"TempΓ©rature du sol : {temp:.0f}Β°C.") + + if nitrogen is not None: + parts.append(f"Azote disponible : {nitrogen:.0f} ppm.") + if nitrogen < 15: + parts.append("Niveau d'azote faible β€” envisagez un engrais azotΓ©.") + + return " ".join(parts) if parts else "DonnΓ©es du sol reΓ§ues." + + def _weather_response(self, v: dict) -> str: + parts = [] + temp = v.get("temperature_c") + humidity = v.get("humidity_pct") + wind = v.get("wind_speed_kmh") + rain = v.get("rain_probability_pct") + + if temp is not None: + parts.append(f"TempΓ©rature : {temp:.0f}Β°C.") + if temp > TEMP_HIGH: + parts.append("Chaleur excessive β€” Γ©vitez les travaux aux heures les plus chaudes.") + + if humidity is not None: + parts.append(f"HumiditΓ© de l'air : {humidity:.0f}%.") + + if wind is not None: + parts.append(f"Vent : {wind:.0f} km/h.") + + if rain is not None: + parts.append(f"ProbabilitΓ© de pluie : {rain:.0f}%.") + if rain > 70: + parts.append("Pluie probable β€” reportez les traitements pesticides.") + + return " ".join(parts) if parts else "DonnΓ©es mΓ©tΓ©o reΓ§ues." + + def _irrigation_response(self, v: dict) -> str: + parts = [] + active = v.get("active") + last = v.get("last_irrigation_h_ago") + flow = v.get("flow_rate_lph") + + if active is not None: + state = "en marche" if active else "arrΓͺtΓ©e" + parts.append(f"Irrigation {state}.") + + if flow is not None and active: + parts.append(f"DΓ©bit : {flow:.0f} litres par heure.") + + if last is not None: + parts.append(f"DerniΓ¨re irrigation il y a {last:.0f} heures.") + if last > 24: + parts.append("Plus de 24 heures sans irrigation β€” vΓ©rifiez les besoins en eau.") + + return " ".join(parts) if parts else "Statut d'irrigation reΓ§u." + + def _pest_response(self, v: dict) -> str: + level = int(v.get("alert_level", 0)) + count = v.get("trap_count_24h") + + level_labels = {0: "aucune", 1: "faible", 2: "modΓ©rΓ©e", 3: "Γ©levΓ©e"} + label = level_labels.get(level, "inconnue") + + parts = [f"PrΓ©sence d'insectes nuisibles : niveau {label}."] + + if count is not None: + parts.append(f"{count:.0f} insectes capturΓ©s en 24 heures.") + + if level >= PEST_ALERT_HIGH: + parts.append("Traitement recommandΓ© β€” consultez un agent agricole.") + + return " ".join(parts) diff --git a/src/optimization/__init__.py b/src/optimization/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/optimization/onnx_exporter.py b/src/optimization/onnx_exporter.py new file mode 100644 index 0000000000000000000000000000000000000000..02368708a122e65468babbf496e40e5412c7f73f --- /dev/null +++ b/src/optimization/onnx_exporter.py @@ -0,0 +1,106 @@ +""" +Merges LoRA adapter weights into the backbone and exports to ONNX. 
+Produces one ONNX file per language (ONNX cannot hot-swap adapters at runtime). + +Requires: optimum[onnxruntime] +""" +from __future__ import annotations + +import logging +from pathlib import Path +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from peft import PeftModel + from transformers import WhisperProcessor + +logger = logging.getLogger(__name__) + + +class ONNXExporter: + """Merges a LoRA PeftModel into its base model and exports to ONNX.""" + + def merge_and_export( + self, + peft_model: "PeftModel", + processor: "WhisperProcessor", + output_dir: str, + language: str, + ) -> Path: + """ + 1. Merge LoRA weights into base model (merge_and_unload) + 2. Export merged model to ONNX via optimum + Returns the output directory path. + """ + output_path = Path(output_dir) / language + output_path.mkdir(parents=True, exist_ok=True) + + logger.info("Merging LoRA adapter '%s' into base model...", language) + merged_model = peft_model.merge_and_unload() + merged_model.eval() + + logger.info("Exporting to ONNX: %s", output_path) + self._export_with_optimum(merged_model, processor, str(output_path)) + + return output_path + + def _export_with_optimum( + self, + merged_model, + processor: "WhisperProcessor", + output_dir: str, + ) -> None: + """Use optimum's ONNX export pipeline.""" + from optimum.exporters.onnx import main_export + + # Save merged model to a temp directory first + import tempfile + + with tempfile.TemporaryDirectory() as tmp_dir: + logger.info("Saving merged model to temp dir for export...") + merged_model.save_pretrained(tmp_dir) + processor.save_pretrained(tmp_dir) + + logger.info("Running optimum ONNX export...") + main_export( + model_name_or_path=tmp_dir, + output=output_dir, + task="automatic-speech-recognition", + opset=17, + optimize="O2", + ) + + logger.info("ONNX export complete: %s", output_dir) + + def validate( + self, + onnx_dir: str, + processor: "WhisperProcessor", + test_audio_arrays: list, + sample_rate: int = 16_000, + reference_texts: list[str] | None = None, + ) -> dict: + """ + Run inference with the exported ONNX model and compute WER vs. references. + """ + import numpy as np + from optimum.onnxruntime import ORTModelForSpeechSeq2Seq + + logger.info("Validating ONNX model at %s...", onnx_dir) + ort_model = ORTModelForSpeechSeq2Seq.from_pretrained(onnx_dir) + + transcriptions = [] + for audio in test_audio_arrays: + inputs = processor(audio, sampling_rate=sample_rate, return_tensors="pt") + outputs = ort_model.generate(inputs.input_features) + text = processor.batch_decode(outputs, skip_special_tokens=True)[0] + transcriptions.append(text) + + result = {"transcriptions": transcriptions} + + if reference_texts: + import jiwer + wer = jiwer.wer(reference_texts, transcriptions) + result["wer"] = round(wer, 4) + + return result diff --git a/src/optimization/quantizer.py b/src/optimization/quantizer.py new file mode 100644 index 0000000000000000000000000000000000000000..6d9c21b2f85bf19fc0620904c5b9c4431d7dfe04 --- /dev/null +++ b/src/optimization/quantizer.py @@ -0,0 +1,95 @@ +""" +BitsAndBytes quantization for GPU-constrained deployment. +4-bit NF4: reduces Whisper-large-v3-turbo from ~3GB to ~1GB VRAM. +8-bit: intermediate option with less accuracy loss. 
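+
+Illustrative usage (a sketch; requires a CUDA device plus the bitsandbytes package,
+and `processor` / `audio_array` are assumed to exist in the caller):
+
+    model = load_4bit("openai/whisper-large-v3-turbo")
+    stats = ModelQuantizer("openai/whisper-large-v3-turbo").benchmark(model, processor, [audio_array])
+    # stats -> {"mean_latency_ms": ..., "max_latency_ms": ..., "vram_allocated_gb": ...}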
+""" +from __future__ import annotations + +import logging +import time +from typing import TYPE_CHECKING + +import torch +from transformers import BitsAndBytesConfig, WhisperForConditionalGeneration, WhisperProcessor + +if TYPE_CHECKING: + pass + +logger = logging.getLogger(__name__) + + +def load_4bit(model_id: str, hf_token: str | None = None) -> WhisperForConditionalGeneration: + """Load Whisper with 4-bit NF4 quantization. Reduces VRAM to ~1GB.""" + bnb_config = BitsAndBytesConfig( + load_in_4bit=True, + bnb_4bit_quant_type="nf4", + bnb_4bit_compute_dtype=torch.float16, + bnb_4bit_use_double_quant=True, + ) + logger.info("Loading %s with 4-bit NF4 quantization...", model_id) + model = WhisperForConditionalGeneration.from_pretrained( + model_id, + quantization_config=bnb_config, + device_map="auto", + token=hf_token, + ) + return model + + +def load_8bit(model_id: str, hf_token: str | None = None) -> WhisperForConditionalGeneration: + """Load Whisper with 8-bit quantization. Reduces VRAM to ~1.5GB.""" + bnb_config = BitsAndBytesConfig(load_in_8bit=True) + logger.info("Loading %s with 8-bit quantization...", model_id) + model = WhisperForConditionalGeneration.from_pretrained( + model_id, + quantization_config=bnb_config, + device_map="auto", + token=hf_token, + ) + return model + + +class ModelQuantizer: + """Benchmarks quantized vs full-precision models.""" + + def __init__(self, model_id: str, hf_token: str | None = None) -> None: + self.model_id = model_id + self.hf_token = hf_token + + def benchmark( + self, + model: WhisperForConditionalGeneration, + processor: WhisperProcessor, + test_audio_arrays: list, + sample_rate: int = 16_000, + ) -> dict: + """Measure latency and memory for a list of audio arrays.""" + import numpy as np + + device = next(model.parameters()).device + latencies = [] + + for audio in test_audio_arrays: + inputs = processor.feature_extractor(audio, sampling_rate=sample_rate, return_tensors="pt") + features = inputs.input_features.to(device) + + if device.type == "cuda": + torch.cuda.synchronize() + t0 = time.perf_counter() + + with torch.no_grad(): + model.generate(features, max_new_tokens=50) + + if device.type == "cuda": + torch.cuda.synchronize() + latencies.append((time.perf_counter() - t0) * 1000) + + result = { + "mean_latency_ms": round(sum(latencies) / len(latencies), 1), + "max_latency_ms": round(max(latencies), 1), + } + + if torch.cuda.is_available(): + result["vram_allocated_gb"] = round(torch.cuda.memory_allocated() / 1e9, 2) + + return result diff --git a/src/optimization/tflite_converter.py b/src/optimization/tflite_converter.py new file mode 100644 index 0000000000000000000000000000000000000000..7341a086ed88b3f76f91d641fa2609d346985045 --- /dev/null +++ b/src/optimization/tflite_converter.py @@ -0,0 +1,76 @@ +""" +Converts ONNX models to TFLite for offline edge deployment (Android phones in rural areas). +Note: Whisper's encoder and decoder are exported as separate TFLite models and +orchestrated together at inference time. + +Requires: onnx-tf, tensorflow (install separately β€” large dependencies) +""" +from __future__ import annotations + +import logging +from pathlib import Path + +logger = logging.getLogger(__name__) + + +class TFLiteConverter: + """Converts ONNX Whisper models to TFLite format for edge deployment.""" + + def convert( + self, + onnx_encoder_path: str, + onnx_decoder_path: str, + output_dir: str, + quantize: bool = True, + ) -> dict[str, Path]: + """ + Convert encoder and decoder ONNX models to TFLite. 
+ Returns paths to the generated .tflite files. + """ + output_path = Path(output_dir) + output_path.mkdir(parents=True, exist_ok=True) + + encoder_tflite = output_path / "encoder.tflite" + decoder_tflite = output_path / "decoder.tflite" + + logger.info("Converting encoder ONNX β†’ TFLite...") + self._onnx_to_tflite(onnx_encoder_path, str(encoder_tflite), quantize=quantize) + + logger.info("Converting decoder ONNX β†’ TFLite...") + self._onnx_to_tflite(onnx_decoder_path, str(decoder_tflite), quantize=quantize) + + return {"encoder": encoder_tflite, "decoder": decoder_tflite} + + def _onnx_to_tflite(self, onnx_path: str, output_path: str, quantize: bool) -> None: + """Convert a single ONNX model to TFLite via onnx-tf + tensorflow.""" + try: + import onnx + import onnx_tf + import tensorflow as tf + except ImportError as e: + raise ImportError( + "TFLite conversion requires onnx-tf and tensorflow. " + "Install with: pip install onnx-tf tensorflow" + ) from e + + import tempfile + + # Step 1: ONNX β†’ TensorFlow SavedModel + with tempfile.TemporaryDirectory() as tmp_dir: + onnx_model = onnx.load(onnx_path) + tf_rep = onnx_tf.backend.prepare(onnx_model) + tf_rep.export_graph(tmp_dir) + + # Step 2: TF SavedModel β†’ TFLite + converter = tf.lite.TFLiteConverter.from_saved_model(tmp_dir) + + if quantize: + converter.optimizations = [tf.lite.Optimize.DEFAULT] + + tflite_model = converter.convert() + + with open(output_path, "wb") as f: + f.write(tflite_model) + + size_mb = Path(output_path).stat().st_size / 1e6 + logger.info("TFLite model saved: %s (%.1f MB)", output_path, size_mb) diff --git a/src/training/__init__.py b/src/training/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/training/callbacks.py b/src/training/callbacks.py new file mode 100644 index 0000000000000000000000000000000000000000..33ac11d58dacc4a34dc805e5b4175156a0278c18 --- /dev/null +++ b/src/training/callbacks.py @@ -0,0 +1,83 @@ +""" +Custom HuggingFace Trainer callbacks: +- EarlyStoppingOnWER: stops training when WER stops improving +- AdapterCheckpointCallback: saves only adapter weights (not full model) per checkpoint +""" +from __future__ import annotations + +import logging +from pathlib import Path +from typing import TYPE_CHECKING + +from transformers import TrainerCallback, TrainerControl, TrainerState, TrainingArguments + +if TYPE_CHECKING: + pass + +logger = logging.getLogger(__name__) + + +class EarlyStoppingOnWER(TrainerCallback): + """ + Stops training if eval WER does not improve by min_delta over `patience` evaluations. + """ + + def __init__(self, patience: int = 5, min_delta: float = 0.001) -> None: + self.patience = patience + self.min_delta = min_delta + self._best_wer: float = float("inf") + self._no_improve_count: int = 0 + + def on_evaluate( + self, + args: TrainingArguments, + state: TrainerState, + control: TrainerControl, + metrics: dict, + **kwargs, + ) -> None: + wer = metrics.get("eval_wer") + if wer is None: + return + + if wer < self._best_wer - self.min_delta: + self._best_wer = wer + self._no_improve_count = 0 + logger.info("WER improved to %.4f", wer) + else: + self._no_improve_count += 1 + logger.info( + "WER %.4f did not improve (best: %.4f). 
No-improve count: %d/%d", + wer, self._best_wer, self._no_improve_count, self.patience, + ) + if self._no_improve_count >= self.patience: + logger.warning("Early stopping triggered after %d evaluations without improvement.", self.patience) + control.should_training_stop = True + + +class AdapterCheckpointCallback(TrainerCallback): + """ + Saves only the LoRA adapter weights on each checkpoint event. + Adapter weights are ~50MB vs ~3GB for the full model. + """ + + def __init__(self, adapter_output_dir: str) -> None: + self.adapter_output_dir = Path(adapter_output_dir) + + def on_save( + self, + args: TrainingArguments, + state: TrainerState, + control: TrainerControl, + model, + **kwargs, + ) -> None: + checkpoint_dir = self.adapter_output_dir / f"checkpoint-{state.global_step}" + checkpoint_dir.mkdir(parents=True, exist_ok=True) + + # model is a PeftModel β€” save only adapter weights + if hasattr(model, "save_pretrained"): + model.save_pretrained(str(checkpoint_dir)) + logger.info("Adapter checkpoint saved: %s", checkpoint_dir) + else: + logger.warning("Model does not have save_pretrained β€” skipping adapter checkpoint.") diff --git a/src/training/metrics.py b/src/training/metrics.py new file mode 100644 index 0000000000000000000000000000000000000000..2729fec2037e8244c0a840b0e9b21f9e9e7400f5 --- /dev/null +++ b/src/training/metrics.py @@ -0,0 +1,40 @@ +""" +WER and CER computation for Seq2SeqTrainer eval loop. +""" +from __future__ import annotations + +from typing import TYPE_CHECKING, Callable + +import numpy as np + +if TYPE_CHECKING: + from transformers import EvalPrediction, WhisperProcessor + + +def make_compute_metrics(processor: "WhisperProcessor") -> Callable[["EvalPrediction"], dict[str, float]]: + """ + Returns a compute_metrics function compatible with HuggingFace Seq2SeqTrainer. + Computes Word Error Rate (WER) and Character Error Rate (CER). 
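+
+    Example (sketch — wire the returned callable into Seq2SeqTrainer, which passes
+    EvalPrediction objects to it during evaluation):
+
+        trainer = Seq2SeqTrainer(..., compute_metrics=make_compute_metrics(processor))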
+ """ + import jiwer + + def compute_metrics(pred: "EvalPrediction") -> dict[str, float]: + pred_ids = pred.predictions + label_ids = pred.label_ids + + # Replace -100 (loss mask) with pad token id so decoding doesn't fail + label_ids = np.where(label_ids != -100, label_ids, processor.tokenizer.pad_token_id) + + pred_str = processor.batch_decode(pred_ids, skip_special_tokens=True) + label_str = processor.batch_decode(label_ids, skip_special_tokens=True) + + # Normalize whitespace + pred_str = [" ".join(s.split()) for s in pred_str] + label_str = [" ".join(s.split()) for s in label_str] + + wer = jiwer.wer(label_str, pred_str) + cer = jiwer.cer(label_str, pred_str) + + return {"wer": round(wer, 4), "cer": round(cer, 4)} + + return compute_metrics diff --git a/src/training/trainer.py b/src/training/trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..62f6fbc6272e9339be642c3dc93e80ec9700ad6e --- /dev/null +++ b/src/training/trainer.py @@ -0,0 +1,235 @@ +""" +Orchestrates full LoRA fine-tuning: + WhisperBackbone + PEFT LoraConfig + WaxalDataLoader + Seq2SeqTrainer + +Usage: + trainer = WhisperLoRATrainer("configs/base_config.yaml", "configs/lora_bambara.yaml") + trainer.setup() + trainer.train() +""" +from __future__ import annotations + +import logging +import os +from pathlib import Path + +import torch +import yaml +from peft import LoraConfig, TaskType, get_peft_model +from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments + +from src.data.augmentation import FieldNoiseAugmenter +from src.data.feature_extractor import DataCollatorSpeechSeq2SeqWithPadding +from src.data.waxal_loader import WaxalDataLoader +from src.engine.whisper_base import WhisperBackbone +from src.training.callbacks import AdapterCheckpointCallback, EarlyStoppingOnWER +from src.training.metrics import make_compute_metrics + +logger = logging.getLogger(__name__) + + +class WhisperLoRATrainer: + """Fine-tunes a language-specific LoRA adapter on top of Whisper.""" + + def __init__(self, base_config_path: str, language_config_path: str) -> None: + self._base_config_path = base_config_path + with open(base_config_path) as f: + self.config = yaml.safe_load(f) + with open(language_config_path) as f: + self.lang_config = yaml.safe_load(f) + + self._backbone: WhisperBackbone | None = None + self._peft_model = None + self._processor = None + self._train_dataset = None + self._eval_dataset = None + + def setup(self) -> None: + """Load backbone, build LoRA config, prepare datasets.""" + hf_token = os.getenv("HF_TOKEN") + device = "cuda" if torch.cuda.is_available() else "cpu" + + # 1. Load backbone + logger.info("Loading backbone model...") + self._backbone = WhisperBackbone(config_path=self._base_config_path) + self._backbone.load(device=device, hf_token=hf_token) + self._processor = self._backbone.processor + + # Disable cache for training + self._backbone.model.config.use_cache = False + + # 2. Wrap with LoRA + lora_cfg = self.lang_config["lora"] + lora_config = LoraConfig( + r=lora_cfg["r"], + lora_alpha=lora_cfg["lora_alpha"], + target_modules=lora_cfg["target_modules"], + lora_dropout=lora_cfg["lora_dropout"], + bias=lora_cfg["bias"], + task_type=TaskType.SEQ_2_SEQ_LM, + ) + self._peft_model = get_peft_model(self._backbone.model, lora_config) + self._peft_model.print_trainable_parameters() + + # 3. 
Load data + subset = self.lang_config["dataset_subset"] + augmenter = FieldNoiseAugmenter(self.config["paths"]["noise_samples"], self.config) + loader = WaxalDataLoader(subset, self.config, hf_token=hf_token) + + logger.info("Loading training data (streaming)...") + self._train_dataset = loader.load_split("train", streaming=True) + self._train_dataset = self._train_dataset.map( + loader.make_preprocess_fn(self._processor, augmenter), + remove_columns=self._train_dataset.column_names, + ) + + try: + self._eval_dataset = loader.load_split("validation", streaming=False) + self._eval_dataset = self._eval_dataset.map( + loader.make_preprocess_fn(self._processor, augmenter=None), + remove_columns=self._eval_dataset.column_names, + ) + except Exception: + logger.warning("No validation split found β€” eval will be skipped.") + self._eval_dataset = None + + def build_training_args(self) -> Seq2SeqTrainingArguments: + tc = self.config["training"] + output_dir = self.lang_config.get("output_dir", tc["output_dir"]) + return Seq2SeqTrainingArguments( + output_dir=output_dir, + per_device_train_batch_size=tc["per_device_train_batch_size"], + gradient_accumulation_steps=tc["gradient_accumulation_steps"], + warmup_steps=tc["warmup_steps"], + max_steps=tc["max_steps"], + save_steps=tc["save_steps"], + eval_steps=tc["eval_steps"] if self._eval_dataset is not None else None, + evaluation_strategy="steps" if self._eval_dataset is not None else "no", + learning_rate=tc["learning_rate"], + fp16=tc["fp16"] and torch.cuda.is_available(), + dataloader_num_workers=tc["dataloader_num_workers"], # 0 on Windows + predict_with_generate=True, + generation_max_length=128, + logging_steps=25, + load_best_model_at_end=self._eval_dataset is not None, + metric_for_best_model="wer", + greater_is_better=False, + report_to="none", + ) + + def merge_extra_data( + self, + feedback_records: list[dict], + repeat: int = 3, + waxal_cap: int = 500, + ) -> None: + """ + Merge feedback corrections into the training dataset. + + Materializes up to `waxal_cap` Waxal samples (converts streaming β†’ Dataset), + then appends `feedback_records` (each repeated `repeat` times for upsampling) + preprocessed into {input_features, labels} format. + + Call this after setup() and before train(). + + Args: + feedback_records: List of dicts from corrections.jsonl with keys + 'audio_file' (path) and 'corrected_text'. + repeat: How many times to repeat each feedback sample + (3Γ— keeps corrections competitive with Waxal baseline). + waxal_cap: Max Waxal samples to materialise (avoids OOM on Colab T4). + """ + if self._peft_model is None: + raise RuntimeError("Call setup() before merge_extra_data().") + + import librosa + import numpy as np + from datasets import Dataset, concatenate_datasets + + logger.info( + "Merging %d feedback records (Γ—%d) with Waxal (cap=%d)...", + len(feedback_records), repeat, waxal_cap, + ) + + # ── 1. Materialise Waxal streaming dataset ───────────────────────────── + waxal_rows: list[dict] = [] + for row in self._train_dataset: + waxal_rows.append(row) + if len(waxal_rows) >= waxal_cap: + break + waxal_ds = Dataset.from_list(waxal_rows) + logger.info("Materialised %d Waxal samples.", len(waxal_ds)) + + # ── 2. 
Preprocess feedback records ───────────────────────────────────── + def _load_preprocess(rec: dict) -> dict | None: + try: + audio_np, _ = librosa.load(rec["audio_file"], sr=16000, mono=True) + inputs = self._processor.feature_extractor( + audio_np, sampling_rate=16000, return_tensors="np" + ) + labels = self._processor.tokenizer( + rec["corrected_text"], return_tensors="np" + ).input_ids[0] + return { + "input_features": inputs.input_features[0], + "labels": labels, + } + except Exception as e: + logger.warning("Skipping feedback record %s: %s", rec.get("id", "?"), e) + return None + + fb_rows = [] + for rec in feedback_records * repeat: + processed = _load_preprocess(rec) + if processed is not None: + fb_rows.append(processed) + + if not fb_rows: + logger.warning("No feedback records could be processed β€” using Waxal only.") + self._train_dataset = waxal_ds + return + + fb_ds = Dataset.from_list(fb_rows) + logger.info("Preprocessed %d feedback rows (after Γ—%d repeat).", len(fb_ds), repeat) + + # ── 3. Concatenate and replace train_dataset ─────────────────────────── + self._train_dataset = concatenate_datasets([waxal_ds, fb_ds]).shuffle(seed=42) + logger.info("Final training dataset: %d samples.", len(self._train_dataset)) + + def train(self) -> None: + if self._peft_model is None: + raise RuntimeError("Call setup() before train().") + + training_args = self.build_training_args() + output_dir = self.lang_config.get("output_dir", self.config["training"]["output_dir"]) + + collator = DataCollatorSpeechSeq2SeqWithPadding( + processor=self._processor, + decoder_start_token_id=self._backbone.model.config.decoder_start_token_id, + ) + + callbacks = [ + AdapterCheckpointCallback(output_dir), + EarlyStoppingOnWER(patience=5), + ] + + compute_metrics = make_compute_metrics(self._processor) if self._eval_dataset is not None else None + + trainer = Seq2SeqTrainer( + model=self._peft_model, + args=training_args, + train_dataset=self._train_dataset, + eval_dataset=self._eval_dataset, + data_collator=collator, + compute_metrics=compute_metrics, + callbacks=callbacks, + tokenizer=self._processor.feature_extractor, + ) + + logger.info("Starting training for language '%s'...", self.lang_config["language"]) + trainer.train() + + # Save final adapter weights + Path(output_dir).mkdir(parents=True, exist_ok=True) + self._peft_model.save_pretrained(output_dir) + logger.info("Adapter saved to %s", output_dir) diff --git a/src/tts/__init__.py b/src/tts/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/tts/mms_tts.py b/src/tts/mms_tts.py new file mode 100644 index 0000000000000000000000000000000000000000..fe5d2ec23cf266695222f619494c5d593f1b4134 --- /dev/null +++ b/src/tts/mms_tts.py @@ -0,0 +1,145 @@ +""" +Facebook MMS-TTS engine for Bambara, Fula, French, and English. 
+ +Usage: + engine = MMSTTSEngine() + wav_np, sample_rate = engine.synthesize("Foro fΙ› ji.", "bam", device="cuda") + wav_bytes = engine.text_to_audio_bytes("Foro fΙ› ji.", "bam", device="cuda") +""" +from __future__ import annotations + +import io +import re +from typing import Dict, Tuple + +import numpy as np +import soundfile as sf + +MODEL_IDS: Dict[str, str] = { + "bam": "facebook/mms-tts-bam", + "ful": "facebook/mms-tts-ful", + "fr": "facebook/mms-tts-fra", + "en": "facebook/mms-tts-eng", +} + +# Fallback for unknown languages β€” use French +_DEFAULT_LANG = "fr" + +# MMS-TTS quality degrades beyond ~15 words; split longer text at sentence boundaries +_MAX_WORDS_PER_CHUNK = 15 + +# Sentence-boundary split pattern (period, exclamation, question mark followed by space or end) +_SENTENCE_RE = re.compile(r"(?<=[.!?])\s+") + + +class MMSTTSEngine: + """Lazy-loading MMS-TTS engine. Models are loaded on first use and cached in CPU RAM.""" + + def __init__(self) -> None: + # {language_code: (VitsModel, VitsTokenizer)} + self._cache: Dict[str, tuple] = {} + + # ── private helpers ────────────────────────────────────────────────────── + + def _get_model(self, language: str): + """Return (model, tokenizer) for the requested language, loading if needed.""" + lang = language if language in MODEL_IDS else _DEFAULT_LANG + if lang not in self._cache: + from transformers import VitsModel, VitsTokenizer # type: ignore + model_id = MODEL_IDS[lang] + tokenizer = VitsTokenizer.from_pretrained(model_id) + model = VitsModel.from_pretrained(model_id) + model.eval() + # Keep on CPU until synthesize() moves it to the target device + self._cache[lang] = (model, tokenizer) + return self._cache[lang] + + @staticmethod + def _split_sentences(text: str) -> list[str]: + """Split text into chunks of ≀ _MAX_WORDS_PER_CHUNK words.""" + sentences = _SENTENCE_RE.split(text.strip()) + chunks: list[str] = [] + current: list[str] = [] + current_words = 0 + + for sent in sentences: + words = sent.split() + if current_words + len(words) > _MAX_WORDS_PER_CHUNK and current: + chunks.append(" ".join(current)) + current = words + current_words = len(words) + else: + current.extend(words) + current_words += len(words) + + if current: + chunks.append(" ".join(current)) + + return chunks or [text] + + def _synthesize_chunk( + self, text: str, model, tokenizer, device: str + ) -> np.ndarray: + """Synthesize a single short text chunk. Returns 1-D float32 numpy array.""" + import torch + + model.to(device) + inputs = tokenizer(text, return_tensors="pt") + inputs = {k: v.to(device) for k, v in inputs.items()} + + with torch.no_grad(): + output = model(**inputs) + + waveform = output.waveform[0].cpu().numpy() # shape: (samples,) + return waveform + + # ── public API ─────────────────────────────────────────────────────────── + + def synthesize( + self, text: str, language: str, device: str = "cuda" + ) -> Tuple[np.ndarray, int]: + """ + Convert text to speech waveform. + + Args: + text: Text to synthesize (any length β€” long text is split automatically). + language: Language code: "bam", "ful", "fr", or "en". + device: "cuda" or "cpu". + + Returns: + (waveform_np, sample_rate) β€” float32 numpy array, sample rate in Hz. 
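+
+        Example (a sketch — the French text and CPU device are illustrative; soundfile,
+        already imported at module level here, writes the result to disk):
+
+            import soundfile as sf
+            wav, sr = engine.synthesize("Il faut irriguer le champ.", "fr", device="cpu")
+            sf.write("reply.wav", wav, sr)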
+ """ + lang = language if language in MODEL_IDS else _DEFAULT_LANG + model, tokenizer = self._get_model(lang) + + chunks = self._split_sentences(text) + waveforms: list[np.ndarray] = [] + + for chunk in chunks: + if not chunk.strip(): + continue + waveforms.append(self._synthesize_chunk(chunk, model, tokenizer, device)) + + # Free device memory before returning + model.to("cpu") + + if not waveforms: + return np.zeros(1, dtype=np.float32), model.config.sampling_rate + + combined = np.concatenate(waveforms) + return combined, model.config.sampling_rate + + def text_to_audio_bytes( + self, text: str, language: str, device: str = "cuda" + ) -> bytes: + """ + Convert text to WAV bytes suitable for gr.Audio or HTTP response. + + Returns raw WAV file bytes (16-bit PCM). + """ + waveform, sample_rate = self.synthesize(text, language, device=device) + + buf = io.BytesIO() + # soundfile expects float32 in [-1, 1]; MMS output is already normalised + sf.write(buf, waveform, sample_rate, format="WAV", subtype="PCM_16") + return buf.getvalue() diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/tests/test_api.py b/tests/test_api.py new file mode 100644 index 0000000000000000000000000000000000000000..f23bd45db159555ea76e1d2ed673e0600d524565 --- /dev/null +++ b/tests/test_api.py @@ -0,0 +1,178 @@ +""" +Integration tests for the FastAPI endpoints. +Uses httpx.AsyncClient with a mocked transcriber so GPU hardware is not required. +""" +from __future__ import annotations + +import io +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest +import pytest_asyncio + + +def _make_mock_app(): + """Create a test FastAPI app with mocked model state.""" + from fastapi import FastAPI + from fastapi.testclient import TestClient + + from src.api.routes import health, iot, transcribe + from src.engine.transcriber import TranscriptionResult + + app = FastAPI() + app.include_router(health.router, prefix="/api/v1") + app.include_router(transcribe.router, prefix="/api/v1") + app.include_router(iot.router, prefix="/api/v1") + + # Mock adapter manager + mock_adapter_manager = MagicMock() + mock_adapter_manager.get_active.return_value = "bam" + mock_adapter_manager.list_available.return_value = ["bam", "ful"] + mock_adapter_manager.list_loaded.return_value = ["bam"] + + # Mock transcriber + mock_transcriber = MagicMock() + mock_transcriber.transcribe_file.return_value = TranscriptionResult( + text="bunding nΙ”gΙ” foro", + language="bam", + duration_s=3.0, + processing_time_ms=250, + ) + + # Mock sensor bridge + mock_sensor_bridge = MagicMock() + from src.iot.sensor_bridge import SensorData + from datetime import datetime + mock_sensor_bridge.fetch = AsyncMock( + return_value=SensorData( + sensor_type="soil", + values={"moisture_pct": 42.0, "ph": 6.5}, + timestamp=datetime.utcnow().isoformat(), + ) + ) + + app.state.adapter_manager = mock_adapter_manager + app.state.transcriber = mock_transcriber + app.state.sensor_bridge = mock_sensor_bridge + + return app, TestClient(app) + + +class TestHealthEndpoint: + def setup_method(self): + self.app, self.client = _make_mock_app() + + def test_health_returns_200(self): + response = self.client.get("/api/v1/health") + assert response.status_code == 200 + + def test_health_response_structure(self): + data = self.client.get("/api/v1/health").json() + assert "status" in data + assert "model_loaded" in data + assert "adapters_available" in data + + def 
test_health_model_loaded_true(self): + data = self.client.get("/api/v1/health").json() + assert data["model_loaded"] is True + + def test_health_active_adapter(self): + data = self.client.get("/api/v1/health").json() + assert data["active_adapter"] == "bam" + + +class TestTranscribeEndpoint: + def setup_method(self): + self.app, self.client = _make_mock_app() + + def _wav_bytes(self) -> bytes: + """Minimal valid WAV file bytes for testing.""" + import wave + import struct + buf = io.BytesIO() + with wave.open(buf, "wb") as wf: + wf.setnchannels(1) + wf.setsampwidth(2) + wf.setframerate(16000) + wf.writeframes(struct.pack("<" + "h" * 160, *([0] * 160))) + return buf.getvalue() + + def test_transcribe_returns_200(self): + response = self.client.post( + "/api/v1/transcribe", + data={"language": "bam"}, + files={"audio_file": ("test.wav", self._wav_bytes(), "audio/wav")}, + ) + assert response.status_code == 200 + + def test_transcribe_response_has_text(self): + data = self.client.post( + "/api/v1/transcribe", + data={"language": "bam"}, + files={"audio_file": ("test.wav", self._wav_bytes(), "audio/wav")}, + ).json() + assert "text" in data + assert isinstance(data["text"], str) + + def test_invalid_language_returns_422(self): + response = self.client.post( + "/api/v1/transcribe", + data={"language": "xyz"}, + files={"audio_file": ("test.wav", self._wav_bytes(), "audio/wav")}, + ) + assert response.status_code == 422 + + def test_unsupported_file_type_returns_422(self): + response = self.client.post( + "/api/v1/transcribe", + data={"language": "bam"}, + files={"audio_file": ("test.txt", b"not audio", "text/plain")}, + ) + assert response.status_code == 422 + + +class TestIoTQueryEndpoint: + def setup_method(self): + self.app, self.client = _make_mock_app() + + def _wav_bytes(self) -> bytes: + import io + import struct + import wave + buf = io.BytesIO() + with wave.open(buf, "wb") as wf: + wf.setnchannels(1) + wf.setsampwidth(2) + wf.setframerate(16000) + wf.writeframes(struct.pack("<" + "h" * 160, *([0] * 160))) + return buf.getvalue() + + def test_query_returns_200(self): + response = self.client.post( + "/api/v1/query", + data={"language": "bam", "field_id": "field_001"}, + files={"audio_file": ("test.wav", self._wav_bytes(), "audio/wav")}, + ) + assert response.status_code == 200 + + def test_query_response_structure(self): + data = self.client.post( + "/api/v1/query", + data={"language": "bam"}, + files={"audio_file": ("test.wav", self._wav_bytes(), "audio/wav")}, + ).json() + assert "transcription" in data + assert "intent" in data + assert "sensor_data" in data + assert "voice_response" in data + + def test_query_voice_response_is_french(self): + data = self.client.post( + "/api/v1/query", + data={"language": "bam"}, + files={"audio_file": ("test.wav", self._wav_bytes(), "audio/wav")}, + ).json() + # French response should contain at least one French word + response_text = data["voice_response"] + french_indicators = ["du", "de", "le", "la", "les", "et", "HumiditΓ©", "sol", "pH"] + assert any(word in response_text for word in french_indicators) diff --git a/tests/test_data_pipeline.py b/tests/test_data_pipeline.py new file mode 100644 index 0000000000000000000000000000000000000000..a474159147b258bd34ae634134c4f53ca9947221 --- /dev/null +++ b/tests/test_data_pipeline.py @@ -0,0 +1,100 @@ +""" +Unit tests for the data pipeline: augmentation, feature extractor, agri dictionary. 
+""" +from __future__ import annotations + +import numpy as np +import pytest + + +class TestFieldNoiseAugmenter: + def test_augmenter_without_noise_files(self, tmp_path): + """Augmenter with empty noise_dir should fall back to Gaussian-only and still be ready.""" + config = {"audio": {"noise_snr_db_range": [5, 20], "augmentation_prob": 0.6}} + from src.data.augmentation import FieldNoiseAugmenter + + augmenter = FieldNoiseAugmenter(str(tmp_path), config) + assert augmenter.is_ready() + assert augmenter._gaussian_only + + def test_augmenter_output_shape(self, tmp_path): + """Augmented audio should have the same length as input.""" + config = {"audio": {"noise_snr_db_range": [5, 20], "augmentation_prob": 1.0}} + from src.data.augmentation import FieldNoiseAugmenter + + augmenter = FieldNoiseAugmenter(str(tmp_path), config) + audio = np.random.randn(16000).astype(np.float32) * 0.01 + augmented = augmenter.augment(audio, 16000) + assert augmented.shape == audio.shape + + def test_augmenter_no_crash_on_silent_audio(self, tmp_path): + """Silent audio (all zeros) should not crash the augmenter.""" + config = {"audio": {"noise_snr_db_range": [5, 20], "augmentation_prob": 0.5}} + from src.data.augmentation import FieldNoiseAugmenter + + augmenter = FieldNoiseAugmenter(str(tmp_path), config) + audio = np.zeros(16000, dtype=np.float32) + result = augmenter.augment(audio, 16000) + assert result is not None + + +class TestAgriculturalDictionary: + def test_bambara_vocab_not_empty(self): + from src.data.agri_dictionary import BAMBARA_VOCAB + + assert len(BAMBARA_VOCAB) > 0 + + def test_fula_vocab_not_empty(self): + from src.data.agri_dictionary import FULA_VOCAB + + assert len(FULA_VOCAB) > 0 + + def test_get_vocab_invalid_language(self): + from src.data.agri_dictionary import AgriculturalDictionary + + d = AgriculturalDictionary() + with pytest.raises(ValueError): + d.get_vocab("xyz") + + def test_prompt_text_contains_terms(self): + from src.data.agri_dictionary import AgriculturalDictionary + + d = AgriculturalDictionary() + prompt = d.get_prompt_text("bam") + assert "sΙ›nΙ›" in prompt + assert "kaba" in prompt + + +class TestDataCollator: + def test_collator_pads_labels(self): + """DataCollator should pad labels and replace pad tokens with -100.""" + from unittest.mock import MagicMock + + import torch + + from src.data.feature_extractor import DataCollatorSpeechSeq2SeqWithPadding + + # Mock processor + processor = MagicMock() + processor.feature_extractor.pad.return_value = { + "input_features": torch.zeros(2, 80, 3000) + } + # Simulate padded labels batch + padded_labels = MagicMock() + padded_labels.input_ids = torch.tensor([[1, 2, 3, 0], [1, 4, 0, 0]]) + padded_labels.attention_mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]]) + processor.tokenizer.pad.return_value = padded_labels + + collator = DataCollatorSpeechSeq2SeqWithPadding( + processor=processor, + decoder_start_token_id=1, + ) + + features = [ + {"input_features": np.zeros((80, 3000)), "labels": [1, 2, 3]}, + {"input_features": np.zeros((80, 3000)), "labels": [1, 4]}, + ] + batch = collator(features) + assert "labels" in batch + # -100 should appear where attention_mask is 0 + assert -100 in batch["labels"].tolist()[0] or -100 in batch["labels"].tolist()[1] diff --git a/tests/test_engine.py b/tests/test_engine.py new file mode 100644 index 0000000000000000000000000000000000000000..81c62d8fc8daba7647ad044448fadd4f107dc30d --- /dev/null +++ b/tests/test_engine.py @@ -0,0 +1,96 @@ +""" +Unit tests for the engine layer: WhisperBackbone, 
AdapterManager, Transcriber. +These tests use mocks so they run without GPU or downloaded model weights. +""" +from __future__ import annotations + +from unittest.mock import MagicMock, patch + +import numpy as np +import pytest + + +class TestWhisperBackbone: + def test_raises_before_load(self, tmp_path): + """Accessing model or processor before load() should raise RuntimeError.""" + import yaml + + config = {"model": {"id": "openai/whisper-large-v3-turbo"}, "training": {}, "audio": {}, "paths": {}} + config_path = tmp_path / "base_config.yaml" + config_path.write_text(yaml.dump(config)) + + from src.engine.whisper_base import WhisperBackbone + + backbone = WhisperBackbone(str(config_path)) + with pytest.raises(RuntimeError): + _ = backbone.model + with pytest.raises(RuntimeError): + _ = backbone.processor + + def test_model_id_read_from_config(self, tmp_path): + import yaml + + config = {"model": {"id": "test-model-id"}, "training": {}, "audio": {}, "paths": {}} + config_path = tmp_path / "base_config.yaml" + config_path.write_text(yaml.dump(config)) + + from src.engine.whisper_base import WhisperBackbone + + backbone = WhisperBackbone(str(config_path)) + assert backbone.model_id == "test-model-id" + + +class TestAdapterManager: + def _make_mock_model(self): + mock_model = MagicMock() + mock_model.peft_config = {} + return mock_model + + def test_register_missing_path(self, tmp_path): + """Registering a non-existent path should log a warning but not raise.""" + from src.engine.adapter_manager import AdapterManager + + model = self._make_mock_model() + manager = AdapterManager(model, {}) + # Should not raise + manager.register("bam", str(tmp_path / "nonexistent")) + assert "bam" in manager.list_available() + + def test_list_available(self, tmp_path): + from src.engine.adapter_manager import AdapterManager + + bam_path = tmp_path / "bam" + bam_path.mkdir() + ful_path = tmp_path / "ful" + ful_path.mkdir() + + model = self._make_mock_model() + manager = AdapterManager(model, {}) + manager.register("bam", str(bam_path)) + manager.register("ful", str(ful_path)) + + available = manager.list_available() + assert "bam" in available + assert "ful" in available + + def test_unregistered_language_raises(self): + from src.engine.adapter_manager import AdapterManager + + model = self._make_mock_model() + manager = AdapterManager(model, {}) + with pytest.raises(KeyError): + manager.load_adapter("xyz") + + +class TestTranscriptionResult: + def test_dataclass_fields(self): + from src.engine.transcriber import TranscriptionResult + + result = TranscriptionResult( + text="test", + language="bam", + duration_s=5.0, + processing_time_ms=120, + ) + assert result.text == "test" + assert result.confidence is None diff --git a/tests/test_iot.py b/tests/test_iot.py new file mode 100644 index 0000000000000000000000000000000000000000..c9ebbbc7c865d5d5b45739a368f99ae7a40c72ce --- /dev/null +++ b/tests/test_iot.py @@ -0,0 +1,119 @@ +""" +Unit tests for the IoT layer: IntentParser, SensorBridge (mock), VoiceResponder. 
+""" +from __future__ import annotations + +import asyncio + +import pytest + + +class TestIntentParser: + def setup_method(self): + from src.iot.intent_parser import IntentParser + self.parser = IntentParser() + + def test_bambara_soil_intent(self): + intent = self.parser.parse("bunding nΙ”gΙ” foro", "bam") + assert intent.action == "check_soil" + assert intent.confidence > 0 + + def test_bambara_water_intent(self): + intent = self.parser.parse("ji sanji foro", "bam") + assert intent.action == "irrigation_status" + + def test_fula_soil_intent(self): + intent = self.parser.parse("leydi ngesa ladde", "ful") + assert intent.action == "check_soil" + + def test_fula_pest_intent(self): + intent = self.parser.parse("biΓ±-biΓ±", "ful") + assert intent.action == "pest_alert" + + def test_unknown_text_returns_some_intent(self): + intent = self.parser.parse("hello world", "bam") + # Should return something without crashing + assert intent.action is not None + assert intent.confidence == 0.0 + + def test_confidence_range(self): + intent = self.parser.parse("bunding nΙ”gΙ”", "bam") + assert 0.0 <= intent.confidence <= 1.0 + + +class TestSensorBridgeMock: + def setup_method(self): + from src.iot.sensor_bridge import SensorBridge + self.bridge = SensorBridge(sensor_api_url=None) + + def test_mock_mode_enabled(self): + assert self.bridge._mock_mode is True + + def test_get_soil_data_returns_sensor_data(self): + from src.iot.sensor_bridge import SensorData + data = asyncio.run(self.bridge.get_soil_data("field_001")) + assert isinstance(data, SensorData) + assert data.sensor_type == "soil" + assert "moisture_pct" in data.values + assert "ph" in data.values + + def test_get_weather_returns_sensor_data(self): + from src.iot.sensor_bridge import SensorData + data = asyncio.run(self.bridge.get_weather("zone_a")) + assert isinstance(data, SensorData) + assert "temperature_c" in data.values + + def test_fetch_dispatches_by_intent_action(self): + from src.iot.intent_parser import Intent + from src.iot.sensor_bridge import SensorData + + intent = Intent(action="check_soil", entity="soil", confidence=0.8) + data = asyncio.run(self.bridge.fetch(intent)) + assert isinstance(data, SensorData) + assert data.sensor_type == "soil" + + +class TestVoiceResponder: + def setup_method(self): + from src.iot.voice_responder import VoiceResponder + self.responder = VoiceResponder(language="fr") + + def _make_intent(self, action): + from src.iot.intent_parser import Intent + return Intent(action=action, entity=action, confidence=1.0) + + def _make_sensor_data(self, sensor_type, values): + from src.iot.sensor_bridge import SensorData + from datetime import datetime + return SensorData(sensor_type=sensor_type, values=values, timestamp=datetime.utcnow().isoformat()) + + def test_low_moisture_triggers_irrigation_warning(self): + intent = self._make_intent("check_soil") + data = self._make_sensor_data("soil", {"moisture_pct": 20.0, "ph": 6.5}) + response = self.responder.generate_response(intent, data) + assert "Irrigation" in response or "irrigation" in response + + def test_normal_soil_contains_moisture(self): + intent = self._make_intent("check_soil") + data = self._make_sensor_data("soil", {"moisture_pct": 45.0, "ph": 6.8}) + response = self.responder.generate_response(intent, data) + assert "45" in response + + def test_high_pest_level_triggers_warning(self): + intent = self._make_intent("pest_alert") + data = self._make_sensor_data("pest", {"alert_level": 3.0, "trap_count_24h": 45.0}) + response = 
self.responder.generate_response(intent, data) + assert "Traitement" in response or "traitement" in response + + def test_weather_response_contains_temperature(self): + intent = self._make_intent("check_weather") + data = self._make_sensor_data("weather", {"temperature_c": 36.0, "rain_probability_pct": 10.0}) + response = self.responder.generate_response(intent, data) + assert "36" in response + + def test_french_response_output(self): + intent = self._make_intent("check_weather") + data = self._make_sensor_data("weather", {"temperature_c": 30.0}) + response = self.responder.generate_response(intent, data) + # Response should be in French + assert any(word in response for word in ["TempΓ©rature", "tempΓ©rature", "HumiditΓ©", "Vent", "pluie"])