| """ |
| Common utilities for learning-munsell. |
| |
| Provides shared functions for MLflow tracking, model comparison reports, |
| and benchmarking across all training and comparison scripts. |
| """ |
|
|
| from __future__ import annotations |
|
|
| import logging |
| import os |
| import time |
| from datetime import datetime |
| from pathlib import Path |
| from typing import TYPE_CHECKING |
|
|
| import mlflow |
| import numpy as np |
|
|
| from learning_munsell import PROJECT_ROOT |
|
|
| if TYPE_CHECKING: |
| from collections.abc import Callable |
|
|
| __all__ = [ |
| "setup_mlflow_experiment", |
| "log_training_epoch", |
| "fix_onnx_dynamic_batch", |
| "get_model_size_mb", |
| "benchmark_inference_speed", |
| "generate_html_report_header", |
| "generate_html_report_footer", |
| "generate_best_models_summary", |
| "generate_ranking_section", |
| ] |
|
|
|
|
def setup_mlflow_experiment(direction: str, model_name: str) -> str:
    """
    Configure MLflow tracking for a training run and build its run name.

    Points MLflow at the project-local SQLite tracking database, selects
    the per-direction experiment, raises root-logger verbosity to INFO,
    and returns a timestamped run name.

    Parameters
    ----------
    direction
        Conversion direction, either "from_xyY" or "to_xyY".
    model_name
        Name of the model being trained.

    Returns
    -------
    str
        Generated run name with timestamp.
    """
    # Tracking data lives in a SQLite file at the project root.
    tracking_db = PROJECT_ROOT / 'mlruns.db'
    mlflow.set_tracking_uri(f"sqlite:///{tracking_db}")
    mlflow.set_experiment(f"learning-munsell-{direction}")

    # Make INFO-level training output visible on the root logger.
    logging.getLogger().setLevel(logging.INFO)

    started_at = datetime.now().strftime("%Y%m%d_%H%M%S")
    return f"{model_name}_{started_at}"
|
|
|
|
def log_training_epoch(
    epoch: int, train_loss: float, val_loss: float, lr: float
) -> None:
    """
    Record one epoch's standard metrics in the active MLflow run.

    Parameters
    ----------
    epoch
        Current epoch number, used as the MLflow step.
    train_loss
        Training loss for the epoch.
    val_loss
        Validation loss for the epoch.
    lr
        Current learning rate.
    """
    metrics = {
        "train_loss": train_loss,
        "val_loss": val_loss,
        "learning_rate": lr,
    }
    mlflow.log_metrics(metrics, step=epoch)
|
|
|
|
def fix_onnx_dynamic_batch(model_path: str | Path) -> None:
    """
    Fix hardcoded batch dimensions in an *ONNX* model.

    *PyTorch*'s ``torch.onnx.export`` may bake intermediate tensor shapes
    from the trace (typically ``batch=1``) into ``value_info`` annotations.
    *ONNX Runtime*'s graph optimiser then uses those shapes to insert
    ``Reshape`` nodes (e.g. ``gemm_input_reshape``) that fail at runtime
    with a different batch size.

    This function clears all intermediate shape annotations and marks
    output batch dimensions as dynamic so that the model accepts any
    batch size.

    Parameters
    ----------
    model_path
        Path to the *ONNX* model file. The file is overwritten in place.
    """

    # Imported lazily so the module stays importable without onnx installed.
    import onnx

    model = onnx.load(str(model_path))
    graph = model.graph

    # Drop every intermediate shape annotation: these carry the traced
    # (batch=1) shapes that mislead ONNX Runtime's optimiser. Inputs and
    # outputs are untouched here — only value_info entries are cleared.
    for vi in graph.value_info:
        if vi.type.tensor_type.HasField("shape"):
            vi.type.tensor_type.ClearField("shape")

    # Replace a concrete leading (batch) output dimension with the
    # symbolic name "batch_size" so any batch size is accepted.
    for out in graph.output:
        shape = out.type.tensor_type.shape
        if shape.dim and shape.dim[0].dim_value > 0:
            # NOTE(review): dim_value/dim_param form a protobuf oneof, so
            # assigning dim_param below clears dim_value anyway; the
            # explicit zeroing looks redundant — confirm before removing.
            shape.dim[0].dim_value = 0
            shape.dim[0].dim_param = "batch_size"

    # Persist the modified graph back to the same file.
    onnx.save(model, str(model_path))
|
|
|
|
def get_model_size_mb(file_paths: list[Path]) -> float:
    """
    Return the combined size of model files in MB.

    For every existing path, the size of the file itself is counted plus
    the size of its external-weights companion — the same file name with
    ``.data`` appended (e.g. ``model.onnx`` -> ``model.onnx.data``), as
    produced when ONNX stores tensors externally.

    Parameters
    ----------
    file_paths
        Candidate model files; paths that do not exist are skipped.

    Returns
    -------
    float
        Total size in MB (MiB, i.e. bytes / 1024**2); 0.0 when nothing
        exists.
    """
    total_bytes = 0
    for path in file_paths:
        if path.exists():
            total_bytes += path.stat().st_size

            # Companion keeps the full original name (suffix included),
            # so append rather than replace the suffix.
            data_file = path.with_name(path.name + ".data")
            if data_file.exists():
                total_bytes += data_file.stat().st_size
    return total_bytes / (1024 * 1024)
|
|
|
|
def benchmark_inference_speed(
    session_callable: Callable,
    input_data: np.ndarray,
    num_iterations: int = 10,
    warmup_iterations: int = 3,
) -> float:
    """
    Benchmark inference speed in milliseconds per sample.

    Runs ``session_callable`` a few times untimed to warm caches/JITs,
    then times ``num_iterations`` calls and averages the wall-clock cost
    over the number of samples in ``input_data``.

    Parameters
    ----------
    session_callable
        Zero-argument function that performs one inference pass.
    input_data
        Batch the callable operates on; only its length is used here.
    num_iterations
        Number of timed iterations.
    warmup_iterations
        Number of untimed warmup iterations.

    Returns
    -------
    float
        Average time per sample in milliseconds.
    """
    # Warmup calls are excluded from the timed section.
    for _ in range(warmup_iterations):
        session_callable()

    t0 = time.perf_counter()
    for _ in range(num_iterations):
        session_callable()
    elapsed_ms = (time.perf_counter() - t0) * 1000

    per_call_ms = elapsed_ms / num_iterations
    return per_call_ms / len(input_data)
|
|
|
|
def generate_html_report_header(
    title: str,
    subtitle: str,
    num_samples: int,
) -> str:
    """
    Generate the opening HTML for a comparison report.

    Emits the document head (Tailwind CDN plus a dark-theme color-token
    config), shared CSS for the gradient header and ranking bars, and a
    hero card showing *title*, *subtitle*, the generation timestamp, and
    *num_samples*.

    Parameters
    ----------
    title
        Report title, used in the ``<title>`` tag and the hero card.
    subtitle
        Secondary line rendered under the title.
    num_samples
        Number of test samples, rendered with thousands separators.

    Returns
    -------
    str
        HTML up to and including the header card. The outer container
        ``<div>`` is left open; it must be closed by
        ``generate_html_report_footer()``.
    """
    # Doubled braces ({{ }}) are f-string escapes producing literal
    # braces in the embedded Tailwind config and CSS.
    return f"""<!DOCTYPE html>
<html lang="en" class="dark">
<head>
    <meta charset="UTF-8">
    <meta name="viewport" content="width=device-width, initial-scale=1.0">
    <title>{title} - {datetime.now().strftime("%Y-%m-%d %H:%M")}</title>
    <script src="https://cdn.tailwindcss.com"></script>
    <script>
        tailwind.config = {{
            darkMode: 'class',
            theme: {{
                extend: {{
                    colors: {{
                        border: "hsl(240 3.7% 15.9%)",
                        input: "hsl(240 3.7% 15.9%)",
                        ring: "hsl(240 4.9% 83.9%)",
                        background: "hsl(240 10% 3.9%)",
                        foreground: "hsl(0 0% 98%)",
                        primary: {{
                            DEFAULT: "hsl(263 70% 60%)",
                            foreground: "hsl(0 0% 98%)",
                        }},
                        secondary: {{
                            DEFAULT: "hsl(240 3.7% 15.9%)",
                            foreground: "hsl(0 0% 98%)",
                        }},
                        muted: {{
                            DEFAULT: "hsl(240 3.7% 15.9%)",
                            foreground: "hsl(240 5% 64.9%)",
                        }},
                        accent: {{
                            DEFAULT: "hsl(240 3.7% 15.9%)",
                            foreground: "hsl(0 0% 98%)",
                        }},
                        card: {{
                            DEFAULT: "hsl(240 10% 6%)",
                            foreground: "hsl(0 0% 98%)",
                        }},
                    }}
                }}
            }}
        }}
    </script>
    <style>
        .gradient-primary {{
            background: linear-gradient(
                135deg, hsl(263 70% 50%) 0%, hsl(280 70% 45%) 100%);
        }}
        .bar-fill {{
            background: linear-gradient(
                90deg, hsl(263 70% 60%) 0%, hsl(280 70% 55%) 100%);
            transition: width 0.5s cubic-bezier(0.4, 0, 0.2, 1);
        }}
    </style>
</head>
<body class="bg-background text-foreground antialiased">
<div class="max-w-7xl mx-auto p-6 space-y-6">
    <!-- Header -->
    <div class="gradient-primary rounded-lg p-8 shadow-2xl
                border border-primary/20">
        <h1 class="text-4xl font-bold text-white mb-2">{title}</h1>
        <div class="text-white/90 space-y-1">
            <p class="text-lg">{subtitle}</p>
            <p class="text-sm">
                Generated: {datetime.now().strftime("%Y-%m-%d %H:%M:%S")}</p>
            <p class="text-sm">Test Samples:
                <span class="font-semibold">{num_samples:,}</span>
                real Munsell colors</p>
        </div>
    </div>
"""
|
|
|
|
def generate_html_report_footer() -> str:
    """
    Generate the closing HTML for a comparison report.

    Closes the container ``<div>`` opened by
    ``generate_html_report_header()`` along with the ``<body>`` and
    ``<html>`` tags.
    """
    return """
    </div>
</body>
</html>
"""
|
|
|
|
def generate_best_models_summary(
    results: dict,
    metrics: list[tuple[str, str, str]],
) -> str:
    """
    Build the "Best Models by Metric" HTML card grid.

    Parameters
    ----------
    results
        Mapping of model name to a dict of metric values.
    metrics
        ``(metric_key, display_name, format_string)`` tuples; one card
        is emitted per metric that has at least one usable value.
    """

    def _is_nan(value) -> bool:
        # Only float NaN is filtered; non-float values pass through.
        return isinstance(value, float) and np.isnan(value)

    parts = ["""
    <!-- Best Models Summary -->
    <div class="bg-card rounded-lg border border-border p-6 shadow-lg">
        <h2 class="text-2xl font-semibold mb-6 pb-3
                   border-b border-primary/30">Best Models by Metric</h2>
        <div class="grid grid-cols-1 md:grid-cols-2 lg:grid-cols-4 gap-4">
    """]

    for metric_key, display_name, fmt in metrics:
        # Keep only models that report this metric with a non-NaN value.
        candidates = [
            (model, data[metric_key])
            for model, data in results.items()
            if metric_key in data and not _is_nan(data[metric_key])
        ]
        if not candidates:
            continue

        # Lower is better for every summarized metric.
        winner, winning_value = min(candidates, key=lambda pair: pair[1])

        parts.append(f"""
        <div class="bg-gradient-to-br from-primary/10 to-primary/5
                    rounded-lg p-5 border border-primary/20">
            <div class="text-xs font-semibold text-muted-foreground
                        uppercase tracking-wide mb-2">{display_name}</div>
            <div class="text-3xl font-bold text-primary mb-3">
                {fmt.format(winning_value)}</div>
            <div class="text-sm text-foreground/80">{winner}</div>
        </div>
        """)

    parts.append("""
        </div>
    </div>
    """)
    return "".join(parts)
|
|
|
|
def generate_ranking_section(
    results: dict,
    metric_key: str,
    title: str,
    lower_is_better: bool = True,
) -> str:
    """
    Generate a ranking bar chart section.

    Parameters
    ----------
    results
        Mapping of model name to a dict of metric values.
    metric_key
        Metric to rank by. Models missing the key, or whose value is a
        float NaN, are skipped.
    title
        Section heading, also emitted as an HTML comment.
    lower_is_better
        When True (default), rank 1 is the smallest value; otherwise
        the largest.

    Returns
    -------
    str
        HTML fragment, or "" when no model has a usable value.
    """
    # Guard with ``metric_key in res`` so a model lacking this metric is
    # skipped rather than raising KeyError — consistent with
    # generate_best_models_summary. NaN floats are filtered as well.
    sorted_results = sorted(
        [
            (name, res[metric_key])
            for name, res in results.items()
            if metric_key in res
            and not (isinstance(res[metric_key], float) and np.isnan(res[metric_key]))
        ],
        key=lambda x: x[1],
        reverse=not lower_is_better,
    )

    if not sorted_results:
        return ""

    # Non-empty is guaranteed by the early return above.
    max_value = max(v for _, v in sorted_results)

    html = f"""
    <!-- {title} -->
    <div class="bg-card rounded-lg border border-border p-6 shadow-lg">
        <h2 class="text-2xl font-semibold mb-4 pb-2
                   border-b border-primary/30">{title}</h2>
        <div class="space-y-1">
    """

    for rank, (model_name, value) in enumerate(sorted_results, 1):
        # Bar width is relative to the largest displayed value; guard
        # against a non-positive maximum to avoid negative widths.
        width_pct = (value / max_value) * 100 if max_value > 0 else 0
        html += f"""
        <div class="flex items-center gap-3 p-2 rounded-md
                    hover:bg-muted/50 transition-colors">
            <div class="flex-none w-80 text-sm font-medium">
                <span class="text-muted-foreground">{rank}.</span>
                {model_name}
            </div>
            <div class="flex-1 h-6 bg-muted rounded-md overflow-hidden">
                <div class="bar-fill h-full rounded-md"
                     style="width: {width_pct}%"></div>
            </div>
            <div class="flex-none w-24 text-right font-bold
                        text-primary">{value:.6f}</div>
        </div>
        """

    html += """
        </div>
    </div>
    """
    return html
|
|