"""ONNX export utilities.

Provides functions for exporting trained models to ONNX format,
including PyTorch checkpoints and JAX-trained models.
"""

from __future__ import annotations

import logging
from typing import TYPE_CHECKING

import numpy as np
import torch
from torch import nn

from learning_munsell import PROJECT_ROOT
from learning_munsell.models.networks import TransformerToMunsell

if TYPE_CHECKING:
    from pathlib import Path

LOGGER = logging.getLogger(__name__)

__all__ = [
    "export_transformer_to_onnx",
    "export_jax_to_onnx",
]


def _verify_onnx_export(output_path: Path, model: nn.Module) -> None:
    """Verify an exported ONNX model against its source PyTorch model.

    Runs a random float32 batch of shape (10, 3) through both ONNX Runtime
    and the PyTorch model and logs the maximum absolute difference.

    Parameters
    ----------
    output_path : Path
        Path of the exported ONNX file to load.
    model : nn.Module
        The PyTorch model the ONNX file was exported from (already in
        eval mode, with input name "xyY").
    """
    # Imported lazily so onnxruntime is only required when exporting.
    import onnxruntime as ort  # noqa: PLC0415

    session = ort.InferenceSession(str(output_path))
    test_np = np.random.randn(10, 3).astype(np.float32)
    onnx_output = session.run(None, {"xyY": test_np})[0]

    with torch.no_grad():
        torch_output = model(torch.from_numpy(test_np)).numpy()

    max_diff = np.max(np.abs(onnx_output - torch_output))
    LOGGER.info("Max difference between PyTorch and ONNX: %.6f", max_diff)
    if max_diff < 1e-4:
        LOGGER.info("ONNX export verified successfully!")
    else:
        LOGGER.warning("ONNX export may have precision issues")


def export_transformer_to_onnx(
    checkpoint_path: Path | None = None,
    output_path: Path | None = None,
) -> Path:
    """Export Transformer model from checkpoint to ONNX format.

    Parameters
    ----------
    checkpoint_path : Path, optional
        Path to the checkpoint file. Defaults to
        models/from_xyY/transformer_large_best.pth.
    output_path : Path, optional
        Path for the ONNX output file. Defaults to
        models/from_xyY/transformer_large.onnx.

    Returns
    -------
    Path
        Path to the exported ONNX file.

    Raises
    ------
    FileNotFoundError
        If checkpoint file does not exist.
    """
    model_directory = PROJECT_ROOT / "models" / "from_xyY"
    if checkpoint_path is None:
        checkpoint_path = model_directory / "transformer_large_best.pth"
    if output_path is None:
        output_path = model_directory / "transformer_large.onnx"

    if not checkpoint_path.exists():
        msg = f"Checkpoint not found: {checkpoint_path}"
        raise FileNotFoundError(msg)

    LOGGER.info("Loading checkpoint from %s...", checkpoint_path)
    # NOTE(review): weights_only=False unpickles arbitrary objects — only
    # load checkpoints from trusted sources.
    checkpoint = torch.load(checkpoint_path, weights_only=False, map_location="cpu")

    # Hyperparameters must match the training configuration of
    # transformer_large_best.pth for load_state_dict to succeed.
    model = TransformerToMunsell(
        num_features=3,
        embedding_dim=256,
        num_blocks=6,
        num_heads=8,
        ff_dim=1024,
        dropout=0.1,
    )
    model.load_state_dict(checkpoint["model_state_dict"])
    model.eval()

    LOGGER.info("Exporting to ONNX...")
    dummy_input = torch.randn(1, 3)
    torch.onnx.export(
        model,
        dummy_input,
        output_path,
        export_params=True,
        opset_version=14,
        input_names=["xyY"],
        output_names=["munsell_spec"],
        dynamic_axes={"xyY": {0: "batch_size"}, "munsell_spec": {0: "batch_size"}},
        do_constant_folding=True,
        dynamo=False,
    )
    LOGGER.info("ONNX model exported to: %s", output_path)

    _verify_onnx_export(output_path, model)
    return output_path


# JAX-specific classes for weight conversion
class _ComponentMLP(nn.Module):
    """PyTorch MLP matching JAX ComponentMLP architecture for weight loading."""

    def __init__(self, input_dim: int = 3, width_multiplier: float = 1.0) -> None:
        super().__init__()
        # Hourglass widths 128 -> 256 -> 512 -> 256 -> 128, scaled by the
        # multiplier to mirror the JAX branch definitions.
        h1 = int(128 * width_multiplier)
        h2 = int(256 * width_multiplier)
        h3 = int(512 * width_multiplier)
        # A flat ModuleList (rather than Sequential) so layer indices line up
        # with the Dense_i / LayerNorm_i names used in _load_jax_weights.
        self.layers = nn.ModuleList(
            [
                nn.Linear(input_dim, h1),
                nn.LayerNorm(h1),
                nn.Linear(h1, h2),
                nn.LayerNorm(h2),
                nn.Linear(h2, h3),
                nn.LayerNorm(h3),
                nn.Linear(h3, h2),
                nn.LayerNorm(h2),
                nn.Linear(h2, h1),
                nn.LayerNorm(h1),
                nn.Linear(h1, 1),
            ]
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Pattern per hidden stage: Linear -> ReLU -> LayerNorm
        # (ReLU applied *before* the norm, matching the JAX model).
        x = self.layers[0](x)
        x = torch.relu(x)
        x = self.layers[1](x)
        x = self.layers[2](x)
        x = torch.relu(x)
        x = self.layers[3](x)
        x = self.layers[4](x)
        x = torch.relu(x)
        x = self.layers[5](x)
        x = self.layers[6](x)
        x = torch.relu(x)
        x = self.layers[7](x)
        x = self.layers[8](x)
        x = torch.relu(x)
        x = self.layers[9](x)
        return self.layers[10](x)


class _MultiMLPJAX(nn.Module):
    """PyTorch Multi-MLP matching JAX architecture for weight loading."""

    def __init__(self) -> None:
        super().__init__()
        # One independent branch per Munsell component; the chroma branch is
        # twice as wide as the others.
        self.hue_branch = _ComponentMLP(input_dim=3, width_multiplier=1.0)
        self.value_branch = _ComponentMLP(input_dim=3, width_multiplier=1.0)
        self.chroma_branch = _ComponentMLP(input_dim=3, width_multiplier=2.0)
        self.code_branch = _ComponentMLP(input_dim=3, width_multiplier=1.0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        hue = self.hue_branch(x)
        value = self.value_branch(x)
        chroma = self.chroma_branch(x)
        code = self.code_branch(x)
        # Output order fixed as [hue, value, chroma, code].
        return torch.cat([hue, value, chroma, code], dim=-1)


def _load_jax_weights(weights_path: Path, model: nn.Module) -> nn.Module:
    """Load JAX/Flax weights into PyTorch model.

    Expects .npz keys of the form
    ``params_ComponentMLP_<i>_<Dense|LayerNorm>_<j>_<param>`` and copies each
    array into the corresponding layer of *model* (a ``_MultiMLPJAX``).

    Parameters
    ----------
    weights_path : Path
        Path to the .npz file produced by the JAX training run.
    model : nn.Module
        The ``_MultiMLPJAX`` instance to populate in place.

    Returns
    -------
    nn.Module
        The same model, with weights loaded.
    """
    saved = np.load(weights_path)

    # JAX branch index -> PyTorch branch (order fixed by the JAX model def).
    branch_map = {
        "ComponentMLP_0": model.hue_branch,
        "ComponentMLP_1": model.value_branch,
        "ComponentMLP_2": model.chroma_branch,
        "ComponentMLP_3": model.code_branch,
    }
    # Flax layer name -> index into _ComponentMLP.layers.
    layer_map = {
        "Dense_0": 0,
        "LayerNorm_0": 1,
        "Dense_1": 2,
        "LayerNorm_1": 3,
        "Dense_2": 4,
        "LayerNorm_2": 5,
        "Dense_3": 6,
        "LayerNorm_3": 7,
        "Dense_4": 8,
        "LayerNorm_4": 9,
        "Dense_5": 10,
    }

    for key in saved.files:
        if key == "metadata":
            continue
        parts = key.split("_")
        if parts[0] != "params":
            continue
        component_name = f"{parts[1]}_{parts[2]}"
        layer_name = f"{parts[3]}_{parts[4]}"
        param_name = parts[5]

        branch = branch_map[component_name]
        layer_idx = layer_map[layer_name]
        layer = branch.layers[layer_idx]
        weight = saved[key]

        if "Dense" in layer_name:
            if param_name == "kernel":
                # Flax stores Dense kernels as (in, out); PyTorch Linear
                # expects (out, in), hence the transpose.
                layer.weight.data = torch.from_numpy(weight.T).float()
            elif param_name == "bias":
                layer.bias.data = torch.from_numpy(weight).float()
        elif "LayerNorm" in layer_name:
            if param_name == "scale":
                # Flax calls the LayerNorm gain "scale"; PyTorch calls it
                # "weight".
                layer.weight.data = torch.from_numpy(weight).float()
            elif param_name == "bias":
                layer.bias.data = torch.from_numpy(weight).float()

    return model


def export_jax_to_onnx(
    weights_path: Path | None = None,
    output_path: Path | None = None,
) -> Path:
    """Export JAX-trained Multi-MLP model to ONNX format.

    Loads weights from a JAX-trained model, creates an equivalent PyTorch
    model, and exports to ONNX format.

    Parameters
    ----------
    weights_path : Path, optional
        Path to the JAX weights file (.npz). Defaults to
        models/from_xyY/multi_mlp_jax_delta_e.npz.
    output_path : Path, optional
        Path for the ONNX output file. Defaults to
        models/from_xyY/multi_mlp_jax_delta_e.onnx.

    Returns
    -------
    Path
        Path to the exported ONNX file.

    Raises
    ------
    FileNotFoundError
        If weights file does not exist.
    """
    models_dir = PROJECT_ROOT / "models" / "from_xyY"
    if weights_path is None:
        weights_path = models_dir / "multi_mlp_jax_delta_e.npz"
    if output_path is None:
        output_path = models_dir / "multi_mlp_jax_delta_e.onnx"

    if not weights_path.exists():
        msg = f"JAX weights not found: {weights_path}"
        raise FileNotFoundError(msg)

    LOGGER.info("Loading JAX weights from %s", weights_path)
    model = _MultiMLPJAX()
    model = _load_jax_weights(weights_path, model)
    model.eval()

    total_params = sum(p.numel() for p in model.parameters())
    LOGGER.info("Model parameters: %s", f"{total_params:,}")

    dummy_input = torch.randn(1, 3)
    torch.onnx.export(
        model,
        dummy_input,
        output_path,
        input_names=["xyY"],
        output_names=["munsell_spec"],
        dynamic_axes={"xyY": {0: "batch"}, "munsell_spec": {0: "batch"}},
        opset_version=17,
    )
    LOGGER.info("Exported ONNX: %s", output_path)

    # Save normalization params alongside the model.
    norm_params_path = output_path.with_name(
        output_path.stem + "_normalization_parameters.npz"
    )
    # NOTE(review): passing a dict to np.savez stores it as a pickled object
    # array, so consumers must load it with allow_pickle=True — confirm that
    # is what downstream loaders expect before changing this format.
    np.savez(
        norm_params_path,
        output_parameters={
            "hue_range": [0.5, 10.0],
            "value_range": [0.0, 10.0],
            "chroma_range": [0.0, 50.0],
            "code_range": [1.0, 10.0],
        },
    )
    LOGGER.info("Saved normalization params: %s", norm_params_path)

    return output_path


def main() -> None:
    """Export models to ONNX format."""
    import argparse  # noqa: PLC0415

    parser = argparse.ArgumentParser(description="Export models to ONNX")
    parser.add_argument(
        "model",
        choices=["transformer", "jax"],
        help="Model type to export",
    )
    args = parser.parse_args()

    if args.model == "transformer":
        export_transformer_to_onnx()
    elif args.model == "jax":
        export_jax_to_onnx()


if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO, format="%(message)s", force=True)
    main()