""" model.py ======== Public entry point for WEO-SAS/sen2sr stored on HuggingFace Hub. All parameters are read from config.json. Usage ----- from huggingface_hub import snapshot_download import sys local_dir = snapshot_download("WEO-SAS/sen2sr") sys.path.insert(0, local_dir) from model import Model model = Model(local_dir=local_dir) # Array inference: (4, H, W) float32 in [0, 1] -> (4, H*4, W*4) float32 sr = model.predict(image) # GeoTIFF pipeline model.predict_tif("s2_scene.tif", "s2_sr.tif") """ from __future__ import annotations import importlib.util import json import os import sys from typing import List, Optional import numpy as np def _load_module(name: str, path: str): spec = importlib.util.spec_from_file_location(name, path) module = importlib.util.module_from_spec(spec) sys.modules[name] = module spec.loader.exec_module(module) return module class Model: """ Public SEN2SR model interface for HuggingFace Hub users. Parameters ---------- local_dir : str Path to the directory returned by ``snapshot_download(repo_id)``. **overrides Optionally override any value from config.json, e.g. ``Model(local_dir=d, patch_size=256, overlap=64)``. """ def __init__(self, local_dir: str, **overrides): config_path = os.path.join(local_dir, "config.json") with open(config_path) as f: config = json.load(f) config.update(overrides) if local_dir not in sys.path: sys.path.insert(0, local_dir) sen2sr_pt = _load_module("sen2sr_pt", os.path.join(local_dir, "sen2sr_pt.py")) self._model = sen2sr_pt.SEN2SRPT(local_dir=local_dir, config=config) self.description = config.get("description", "") def predict(self, image: np.ndarray) -> np.ndarray: """ Run 4x super-resolution on a single image. Parameters ---------- image : (C, H, W) float32 numpy array, values in [0, 1] C must equal in_channels (4 for RGBN) Returns ------- (C, H*4, W*4) float32 numpy array """ return self._model.predict(image) def predict_tif( self, input_path: str, output_path: str, bands: Optional[List[int]] = None, ) -> None: """ Full GeoTIFF super-resolution pipeline. Parameters ---------- input_path : path to input Sentinel-2 GeoTIFF output_path : output path for the 2.5 m SR GeoTIFF bands : 0-based band indices to read (default: [0, 1, 2, 3]) """ self._model.predict_tif(input_path, output_path, bands)