"""
Export DAP (Depth Any Panoramas) to CoreML for iOS/macOS.

Produces a single-output CoreML model (depth map only) with an ImageType input,
compatible with the Vision framework and the included DepthPredictor.swift.

Usage:
    python export_and_validate_coreml.py
    python export_and_validate_coreml.py --height 768 --width 1536
"""
|
|
import os
import sys
import time
import numpy as np
import torch
import coremltools as ct
from PIL import Image
from argparse import ArgumentParser
import matplotlib
matplotlib.use("Agg")  # headless backend: figures are saved, never displayed
import matplotlib.pyplot as plt
|
|
# Make sibling packages (networks/) importable regardless of the CWD.
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
|
|
from networks.dap import make_model
|
|
|
|
class DAPSingleOutputWrapper(torch.nn.Module):
    """Returns only depth map (clamped to [0, 10]). Uses torch.where for CoreML compat."""

    IMAGENET_MEAN = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)
    IMAGENET_STD = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)
|
|
    def __init__(self, dap_model):
        super().__init__()
        self.dap = dap_model
        # Buffers (not parameters): the constants get baked into the trace.
        self.register_buffer("mean", self.IMAGENET_MEAN)
        self.register_buffer("std", self.IMAGENET_STD)
|
|
    def forward(self, x):
        # Pixels arrive already scaled to [0, 1] (ImageType scale=1/255);
        # ImageNet normalization is applied here, inside the traced graph.
        x = (x - self.mean) / self.std
        out = self.dap(x)
        depth = out["pred_depth"]
        # Clamp to [0, 10] and zero out NaNs using torch.where.
        depth = torch.where(depth < 0.0, torch.zeros_like(depth), depth)
        depth = torch.where(depth > 10.0, torch.full_like(depth, 10.0), depth)
        depth = torch.where(torch.isnan(depth), torch.zeros_like(depth), depth)
        return depth
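
# A more direct alternative for the range clamp above (assuming your
# coremltools version lowers torch.clamp cleanly) would be:
#
#     depth = torch.clamp(depth, 0.0, 10.0)
#
# torch.where is kept for conversion robustness; NaN replacement needs the
# explicit torch.isnan select either way.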
|
|
|
|
def load_weights(model, weight_path):
    state_dict = torch.load(weight_path, map_location="cpu", weights_only=False)
    cleaned = {}
    for k, v in state_dict.items():
        # Strip the "module." prefix left by DataParallel/DDP checkpoints.
        if k.startswith("module."):
            k = k[len("module."):]
        if not isinstance(v, torch.Tensor):
            continue
        cleaned[k] = v
    model_keys = set(model.state_dict().keys())
    matched = {k: v for k, v in cleaned.items() if k in model_keys}
    unmatched = set(cleaned.keys()) - model_keys
    if unmatched:
        print(f" [info] Skipping {len(unmatched)} unmatched weight keys")
    if not matched:
        print(" [error] No weights matched!")
        sys.exit(1)
    model.load_state_dict(matched, strict=False)
    print(f" Loaded {len(matched)} weight tensors from {weight_path}")
|
|
|
|
def prepare_image(image_path, height, width):
    """Load and resize an image; return a float32 CHW tensor plus the resized PIL image."""
    img = Image.open(image_path).convert("RGB")
    # PIL's resize takes (width, height), not (height, width).
    img_resized = img.resize((width, height), Image.LANCZOS)
    img_np = np.array(img_resized).astype(np.float32) / 255.0
    img_tensor = torch.from_numpy(img_np).permute(2, 0, 1).unsqueeze(0)
    return img_tensor, img_resized
|
|
|
|
def validate_trace(original, traced, example):
    original.eval()
    traced.eval()
    with torch.no_grad():
        out_orig = original(example)
        out_traced = traced(example)
    max_diff = (out_orig - out_traced).abs().max().item()
    has_nan = torch.isnan(out_traced).any().item()
    print(f" Trace validation: max_diff={max_diff:.2e}, has_nan={has_nan}")
    if has_nan:
        print(" [error] Traced output contains NaN!")
        sys.exit(1)
    return max_diff
|
|
|
|
def run_pytorch_inference(model, img_tensor):
    """Run PyTorch inference on the raw model (normalization applied manually)."""
    IMAGENET_MEAN = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)
    IMAGENET_STD = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)
    model.eval()
    with torch.no_grad():
        x = (img_tensor - IMAGENET_MEAN) / IMAGENET_STD
        output = model(x)
    depth = output["pred_depth"].squeeze().cpu().numpy()
    # Mirror the wrapper's clamping so the PyTorch reference is directly
    # comparable to the CoreML output.
    depth = np.where(depth < 0.0, 0.0, depth)
    depth = np.where(depth > 10.0, 10.0, depth)
    depth = np.nan_to_num(depth, nan=0.0, posinf=10.0, neginf=0.0)
    return depth
|
|
|
|
def run_coreml_inference(mlpackage_path, pil_image):
    """Run CoreML inference with a PIL Image (ImageType input)."""
    model = ct.models.MLModel(mlpackage_path)

    # predict() accepts a PIL.Image directly for ImageType inputs.
    start = time.time()
    output = model.predict({"image": pil_image})
    elapsed = time.time() - start

    # Single output, named "depth" at conversion time; drop the leading
    # batch and channel dimensions.
    output_tensor = output["depth"]
    depth = output_tensor[0, 0, :, :]

    return depth, elapsed
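
# Minimal standalone use of the exported package (paths and sizes assume the
# CLI defaults in main()):
#
#     import coremltools as ct
#     from PIL import Image
#     m = ct.models.MLModel("DAPModel.mlpackage")
#     img = Image.open("test/test.png").convert("RGB").resize((1536, 768))
#     depth = m.predict({"image": img})["depth"][0, 0]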
|
|
|
|
def compute_metrics(depth_pt, depth_cl):
    if depth_pt.shape != depth_cl.shape:
        print(f" [warn] Shape mismatch: PyTorch {depth_pt.shape} vs CoreML {depth_cl.shape}")
        # scipy is only needed for this fallback resampling.
        from scipy.ndimage import zoom
        depth_cl = zoom(depth_cl, (depth_pt.shape[0] / depth_cl.shape[0],
                                   depth_pt.shape[1] / depth_cl.shape[1]))

    diff = np.abs(depth_pt - depth_cl)
    max_diff = diff.max()
    mean_diff = diff.mean()
    rmse = np.sqrt(np.mean((depth_pt - depth_cl) ** 2))

    # Relative error only where the reference depth is non-trivial.
    mask = depth_pt > 1e-6
    rel_error = (diff[mask] / depth_pt[mask]).mean() if mask.sum() > 0 else float("nan")
    correlation = np.corrcoef(depth_pt.flatten(), depth_cl.flatten())[0, 1]

    return {
        "max_abs_diff": max_diff,
        "mean_abs_diff": mean_diff,
        "rmse": rmse,
        "mean_rel_error": rel_error,
        "correlation": correlation,
    }
|
|
|
|
def save_comparison_viz(depth_pt, depth_cl, metrics, output_dir):
    os.makedirs(output_dir, exist_ok=True)

    fig, axes = plt.subplots(3, 1, figsize=(6, 12))

    # Share a color scale so the two depth maps are visually comparable.
    vmax = max(depth_pt.max(), depth_cl.max())
    im0 = axes[0].imshow(depth_pt, cmap="Spectral", vmin=0, vmax=vmax)
    axes[0].set_title(f"PyTorch Depth\n[{depth_pt.min():.4f}, {depth_pt.max():.4f}]")
    axes[0].axis("off")

    im1 = axes[1].imshow(depth_cl, cmap="Spectral", vmin=0, vmax=vmax)
    axes[1].set_title(f"CoreML Depth\n[{depth_cl.min():.4f}, {depth_cl.max():.4f}]")
    axes[1].axis("off")

    diff = np.abs(depth_pt - depth_cl)
    im2 = axes[2].imshow(diff, cmap="hot")
    axes[2].set_title(f"Abs Diff\nmax={diff.max():.6f}, mean={diff.mean():.6f}")
    axes[2].axis("off")

    plt.colorbar(im0, ax=axes[0], fraction=0.046)
    plt.colorbar(im1, ax=axes[1], fraction=0.046)
    plt.colorbar(im2, ax=axes[2], fraction=0.046)

    plt.tight_layout()
    viz_path = os.path.join(output_dir, "comparison.png")
    plt.savefig(viz_path, dpi=150)
    plt.close()
    print(f" Saved visualization to {viz_path}")
|
|
|
|
def main():
    parser = ArgumentParser(description="Export DAP to CoreML and validate against PyTorch")
    parser.add_argument("--image", default=os.path.join(os.path.dirname(__file__), "test", "test.png"))
    parser.add_argument("--height", type=int, default=768)
    parser.add_argument("--width", type=int, default=1536)
    parser.add_argument("--model_type", choices=["vits", "vitb", "vitl", "vitg"], default="vitl")
    parser.add_argument("--weights", default=os.path.join(os.path.dirname(__file__), "model.pth"))
    parser.add_argument("--output", default=os.path.join(os.path.dirname(__file__), "DAPModel.mlpackage"))
    parser.add_argument("--results", default=os.path.join(os.path.dirname(__file__), "test_output"))
    parser.add_argument("--threshold", type=float, default=0.05)
    parser.add_argument("--skip_export", action="store_true", help="Skip CoreML export, only validate")
    args = parser.parse_args()

    # The ViT backbone needs spatial dims divisible by its patch size.
    patch_size = 16
    if args.height % patch_size != 0 or args.width % patch_size != 0:
        print(f" [error] Dimensions must be multiples of {patch_size}. Got {args.height}x{args.width}")
        sys.exit(1)
|
|
| print("=" * 60) |
| print("DAP CoreML Export + Validation (single output, ImageType)") |
| print("=" * 60) |
|
|
| print(f"\n[1/5] Loading test image: {args.image}") |
| if not os.path.exists(args.image): |
| print(f" [error] Image not found: {args.image}") |
| sys.exit(1) |
| img_tensor, pil_image = prepare_image(args.image, args.height, args.width) |
| print(f" Image resized to {args.width}x{args.height}") |
|
|
| print(f"\n[2/5] Building DAP model ({args.model_type}) + PyTorch inference ...") |
| model = make_model(midas_model_type=args.model_type) |
| load_weights(model, args.weights) |
| model.eval() |
|
|
| pt_start = time.time() |
| depth_pt = run_pytorch_inference(model, img_tensor) |
| pt_time = time.time() - pt_start |
| print(f" PyTorch time: {pt_time*1000:.1f}ms") |
| print(f" Depth: {depth_pt.shape}, range=[{depth_pt.min():.4f}, {depth_pt.max():.4f}]") |
|
|
| |
| os.makedirs(args.results, exist_ok=True) |
| np.save(os.path.join(args.results, "pytorch_depth.npy"), depth_pt) |
| print(f" Ground truth saved to {args.results}/pytorch_depth.npy") |
|
|
    if not args.skip_export:
        print("\n[3/5] Exporting CoreML model (ImageType input) ...")
        wrapped = DAPSingleOutputWrapper(model)
        wrapped.eval()

        # Trace with a random input in [0, 1], matching the ImageType scaling.
        example_input = torch.rand(1, 3, args.height, args.width)
        traced = torch.jit.trace(wrapped, example_input)
        validate_trace(wrapped, traced, example_input)

        print(" Converting to CoreML (this may take a few minutes) ...")
        # CoreML applies y = scale * pixel + bias before the network, so
        # scale=1/255 maps 8-bit RGB into [0, 1]; ImageNet normalization
        # stays inside the wrapper.
        image_input = ct.ImageType(
            name="image",
            shape=(1, 3, args.height, args.width),
            scale=1 / 255.0,
            bias=[0.0, 0.0, 0.0],
            color_layout=ct.colorlayout.RGB,
            channel_first=True,
        )

        mlmodel = ct.convert(
            traced,
            inputs=[image_input],
            outputs=[ct.TensorType(name="depth", dtype=np.float32)],
            minimum_deployment_target=ct.target.iOS18,
            compute_precision=ct.precision.FLOAT32,
            compute_units=ct.ComputeUnit.ALL,
        )
        mlmodel.save(args.output)

        # An .mlpackage is a directory; sum its files to report the size.
        total_size = sum(
            os.path.getsize(os.path.join(dp, f))
            for dp, _, fnames in os.walk(args.output)
            for f in fnames
        )
        print(f" Saved to {args.output} ({total_size / (1024**2):.0f} MB)")
    else:
        print("\n[3/5] Skipping CoreML export (--skip_export)")
|
|
| print(f"\n[4/5] Validating CoreML model ...") |
| if os.path.exists(args.output): |
| depth_cl, cl_time = run_coreml_inference(args.output, pil_image) |
| print(f" CoreML time: {cl_time*1000:.1f}ms") |
| print(f" Depth: {depth_cl.shape}, range=[{depth_cl.min():.4f}, {depth_cl.max():.4f}]") |
|
|
| metrics = compute_metrics(depth_pt, depth_cl) |
| passed = metrics["max_abs_diff"] < args.threshold |
|
|
| print(f" Max Abs Diff: {metrics['max_abs_diff']:.2e}") |
| print(f" Mean Abs Diff: {metrics['mean_abs_diff']:.2e}") |
| print(f" RMSE: {metrics['rmse']:.2e}") |
| print(f" Rel Error: {metrics['mean_rel_error']:.2e}") |
| print(f" Correlation: {metrics['correlation']:.6f}") |
| print(f" {'PASS' if passed else 'FAIL'} (threshold: {args.threshold})") |
|
|
| save_comparison_viz(depth_pt, depth_cl, metrics, args.results) |
| else: |
| print(f" [warn] Model not found: {args.output}") |
| metrics = None |
| passed = False |
|
|
| print("\n" + "=" * 60) |
| print("SUMMARY") |
| print("=" * 60) |
| print(f" Model: DAP {args.model_type}") |
| print(f" Input: {args.width}x{args.height} (ImageType)") |
| print(f" Image: {args.image}") |
| print(f" Threshold: {args.threshold}") |
| print("-" * 60) |
| if metrics: |
| status = "PASS" if passed else "FAIL" |
| print(f" Validation: {status} (max_diff={metrics['max_abs_diff']:.2e})") |
| print("-" * 60) |
| print(f" PyTorch: {pt_time*1000:.1f}ms") |
| if 'cl_time' in dir(): |
| print(f" CoreML: {cl_time*1000:.1f}ms") |
| print(f" Results: {args.results}/") |
| print("=" * 60) |
|
|
| |
| if metrics and not passed: |
| sys.exit(1) |
|
|
|
|
if __name__ == "__main__":
    main()
|
|