# Imported from KelSolaar's repository (initial commit 3c7db92).
"""
Common utilities for learning-munsell.
Provides shared functions for MLflow tracking, model comparison reports,
and benchmarking across all training and comparison scripts.
"""
from __future__ import annotations
import logging
import os
import time
from datetime import datetime
from pathlib import Path
from typing import TYPE_CHECKING
import mlflow
import numpy as np
from learning_munsell import PROJECT_ROOT
if TYPE_CHECKING:
from collections.abc import Callable
__all__ = [
"setup_mlflow_experiment",
"log_training_epoch",
"fix_onnx_dynamic_batch",
"get_model_size_mb",
"benchmark_inference_speed",
"generate_html_report_header",
"generate_html_report_footer",
"generate_best_models_summary",
"generate_ranking_section",
]
def setup_mlflow_experiment(direction: str, model_name: str) -> str:
    """
    Set up MLflow experiment and return run name.

    Parameters
    ----------
    direction
        Conversion direction, either "from_xyY" or "to_xyY".
    model_name
        Name of the model being trained.

    Returns
    -------
    str
        Generated run name with timestamp.
    """
    tracking_db = PROJECT_ROOT / 'mlruns.db'
    mlflow.set_tracking_uri(f"sqlite:///{tracking_db}")
    mlflow.set_experiment(f"learning-munsell-{direction}")

    # MLflow changes the root logger level to WARNING; restore INFO so
    # training output remains visible.
    logging.getLogger().setLevel(logging.INFO)

    stamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    return f"{model_name}_{stamp}"
def log_training_epoch(
    epoch: int, train_loss: float, val_loss: float, lr: float
) -> None:
    """
    Log standard training metrics for an epoch.

    Parameters
    ----------
    epoch
        Current epoch number.
    train_loss
        Training loss for the epoch.
    val_loss
        Validation loss for the epoch.
    lr
        Current learning rate.
    """
    metrics = {
        "train_loss": train_loss,
        "val_loss": val_loss,
        "learning_rate": lr,
    }
    mlflow.log_metrics(metrics, step=epoch)
def fix_onnx_dynamic_batch(model_path: str | Path) -> None:
    """
    Fix hardcoded batch dimensions in an *ONNX* model.

    *PyTorch*'s ``torch.onnx.export`` may bake intermediate tensor shapes
    from the trace (typically ``batch=1``) into ``value_info`` annotations.
    *ONNX Runtime*'s graph optimiser then uses those shapes to insert
    ``Reshape`` nodes (e.g. ``gemm_input_reshape``) that fail at runtime
    with a different batch size.

    This function clears all intermediate shape annotations and marks
    output batch dimensions as dynamic so that the model accepts any
    batch size.

    Parameters
    ----------
    model_path
        Path to the *ONNX* model file. The file is overwritten in place.
    """
    import onnx  # noqa: PLC0415

    path = str(model_path)
    model = onnx.load(path)

    # Drop baked-in intermediate shapes so that ORT computes them at
    # runtime instead of trusting the traced values.
    for value_info in model.graph.value_info:
        tensor_type = value_info.type.tensor_type
        if tensor_type.HasField("shape"):
            tensor_type.ClearField("shape")

    # Relabel each output's leading dimension as a symbolic batch axis.
    for output in model.graph.output:
        dims = output.type.tensor_type.shape.dim
        if dims and dims[0].dim_value > 0:
            dims[0].dim_value = 0
            dims[0].dim_param = "batch_size"

    onnx.save(model, path)
def get_model_size_mb(file_paths: list[Path]) -> float:
    """Get total size of model files in MB (includes .data files)."""
    size_bytes = 0
    for path in file_paths:
        if not path.exists():
            continue
        size_bytes += path.stat().st_size
        # ONNX models may keep their weights in an external
        # "<name>.data" sidecar; count it alongside the model file.
        sidecar = Path(f"{path}.data")
        if sidecar.exists():
            size_bytes += sidecar.stat().st_size
    return size_bytes / (1024 * 1024)
def benchmark_inference_speed(
    session_callable: Callable,
    input_data: np.ndarray,
    num_iterations: int = 10,
    warmup_iterations: int = 3,
) -> float:
    """
    Benchmark inference speed in milliseconds per sample.

    Parameters
    ----------
    session_callable
        Function that performs inference
    input_data
        Input data for inference
    num_iterations
        Number of iterations for benchmarking
    warmup_iterations
        Number of warmup iterations

    Returns
    -------
    float
        Average time per sample in milliseconds
    """
    # Warm up caches / lazy initialisation before taking measurements.
    for _ in range(warmup_iterations):
        session_callable()

    # Time all iterations as one batch to amortise timer overhead.
    started = time.perf_counter()
    for _ in range(num_iterations):
        session_callable()
    elapsed_ms = (time.perf_counter() - started) * 1000

    per_iteration_ms = elapsed_ms / num_iterations
    # Normalise to a per-sample figure using the batch size.
    return per_iteration_ms / len(input_data)
def generate_html_report_header(
    title: str,
    subtitle: str,
    num_samples: int,
) -> str:
    """
    Generate HTML report header with styling.

    Parameters
    ----------
    title
        Report title shown in the ``<title>`` tag and the page header.
    subtitle
        Secondary line displayed under the title.
    num_samples
        Number of test samples, rendered with thousands separators.

    Returns
    -------
    str
        Opening HTML (doctype through the page header) meant to be paired
        with the matching report footer.
    """
    # Capture the timestamp once so the <title> and the "Generated:" line
    # can never disagree (previously datetime.now() was called twice and
    # could straddle a minute boundary).
    now = datetime.now()
    return f"""<!DOCTYPE html>
<html lang="en" class="dark">
<head>
    <meta charset="UTF-8">
    <meta name="viewport" content="width=device-width, initial-scale=1.0">
    <title>{title} - {now.strftime("%Y-%m-%d %H:%M")}</title>
    <script src="https://cdn.tailwindcss.com"></script>
    <script>
        tailwind.config = {{
            darkMode: 'class',
            theme: {{
                extend: {{
                    colors: {{
                        border: "hsl(240 3.7% 15.9%)",
                        input: "hsl(240 3.7% 15.9%)",
                        ring: "hsl(240 4.9% 83.9%)",
                        background: "hsl(240 10% 3.9%)",
                        foreground: "hsl(0 0% 98%)",
                        primary: {{
                            DEFAULT: "hsl(263 70% 60%)",
                            foreground: "hsl(0 0% 98%)",
                        }},
                        secondary: {{
                            DEFAULT: "hsl(240 3.7% 15.9%)",
                            foreground: "hsl(0 0% 98%)",
                        }},
                        muted: {{
                            DEFAULT: "hsl(240 3.7% 15.9%)",
                            foreground: "hsl(240 5% 64.9%)",
                        }},
                        accent: {{
                            DEFAULT: "hsl(240 3.7% 15.9%)",
                            foreground: "hsl(0 0% 98%)",
                        }},
                        card: {{
                            DEFAULT: "hsl(240 10% 6%)",
                            foreground: "hsl(0 0% 98%)",
                        }},
                    }}
                }}
            }}
        }}
    </script>
    <style>
        .gradient-primary {{
            background: linear-gradient(
                135deg, hsl(263 70% 50%) 0%, hsl(280 70% 45%) 100%);
        }}
        .bar-fill {{
            background: linear-gradient(
                90deg, hsl(263 70% 60%) 0%, hsl(280 70% 55%) 100%);
            transition: width 0.5s cubic-bezier(0.4, 0, 0.2, 1);
        }}
    </style>
</head>
<body class="bg-background text-foreground antialiased">
    <div class="max-w-7xl mx-auto p-6 space-y-6">
        <!-- Header -->
        <div class="gradient-primary rounded-lg p-8 shadow-2xl
            border border-primary/20">
            <h1 class="text-4xl font-bold text-white mb-2">{title}</h1>
            <div class="text-white/90 space-y-1">
                <p class="text-lg">{subtitle}</p>
                <p class="text-sm">
                    Generated: {now.strftime("%Y-%m-%d %H:%M:%S")}</p>
                <p class="text-sm">Test Samples:
                    <span class="font-semibold">{num_samples:,}</span>
                    real Munsell colors</p>
            </div>
        </div>
"""
def generate_html_report_footer() -> str:
    """Generate HTML report footer."""
    # Closes the page wrapper <div> opened by the report header.
    return "\n".join(["", "    </div>", "</body>", "</html>", ""])
def generate_best_models_summary(
    results: dict,
    metrics: list[tuple[str, str, str]],
) -> str:
    """
    Generate best models summary section.

    Parameters
    ----------
    results
        Dictionary of model results
    metrics
        List of (metric_key, display_name, format_string) tuples
    """

    def _usable(value) -> bool:
        # NaN metrics (e.g. failed runs) are excluded from the comparison.
        return not (isinstance(value, float) and np.isnan(value))

    parts = [
        """
    <!-- Best Models Summary -->
    <div class="bg-card rounded-lg border border-border p-6 shadow-lg">
        <h2 class="text-2xl font-semibold mb-6 pb-3
            border-b border-primary/30">Best Models by Metric</h2>
        <div class="grid grid-cols-1 md:grid-cols-2 lg:grid-cols-4 gap-4">
"""
    ]
    for metric_key, display_name, fmt in metrics:
        candidates = [
            (name, res[metric_key])
            for name, res in results.items()
            if metric_key in res and _usable(res[metric_key])
        ]
        if not candidates:
            continue
        # Lower is better for every summarised metric.
        best_model, best_value = min(candidates, key=lambda item: item[1])
        parts.append(
            f"""
            <div class="bg-gradient-to-br from-primary/10 to-primary/5
                rounded-lg p-5 border border-primary/20">
                <div class="text-xs font-semibold text-muted-foreground
                    uppercase tracking-wide mb-2">{display_name}</div>
                <div class="text-3xl font-bold text-primary mb-3">
                    {fmt.format(best_value)}</div>
                <div class="text-sm text-foreground/80">{best_model}</div>
            </div>
"""
        )
    parts.append(
        """
        </div>
    </div>
"""
    )
    return "".join(parts)
def generate_ranking_section(
    results: dict,
    metric_key: str,
    title: str,
    lower_is_better: bool = True,
) -> str:
    """
    Generate a ranking bar chart section.

    Parameters
    ----------
    results
        Dictionary of model results keyed by model name.
    metric_key
        Metric to rank models by.
    title
        Section heading.
    lower_is_better
        Whether smaller metric values rank first.

    Returns
    -------
    str
        HTML fragment, or an empty string when no model has a usable value.
    """
    # Skip models that are missing the metric entirely (previously a
    # KeyError; now consistent with generate_best_models_summary) or
    # recorded it as NaN.
    sorted_results = sorted(
        [
            (name, res[metric_key])
            for name, res in results.items()
            if metric_key in res
            and not (isinstance(res[metric_key], float) and np.isnan(res[metric_key]))
        ],
        key=lambda x: x[1],
        reverse=not lower_is_better,
    )
    if not sorted_results:
        return ""
    max_value = max(v for _, v in sorted_results)
    html = f"""
    <!-- {title} -->
    <div class="bg-card rounded-lg border border-border p-6 shadow-lg">
        <h2 class="text-2xl font-semibold mb-4 pb-2
            border-b border-primary/30">{title}</h2>
        <div class="space-y-1">
"""
    for rank, (model_name, value) in enumerate(sorted_results, 1):
        # Bar width is relative to the worst (largest) value; guard against
        # a non-positive maximum to avoid division by zero.
        width_pct = (value / max_value) * 100 if max_value > 0 else 0
        html += f"""
            <div class="flex items-center gap-3 p-2 rounded-md
                hover:bg-muted/50 transition-colors">
                <div class="flex-none w-80 text-sm font-medium">
                    <span class="text-muted-foreground">{rank}.</span>
                    {model_name}
                </div>
                <div class="flex-1 h-6 bg-muted rounded-md overflow-hidden">
                    <div class="bar-fill h-full rounded-md"
                        style="width: {width_pct}%"></div>
                </div>
                <div class="flex-none w-24 text-right font-bold
                    text-primary">{value:.6f}</div>
            </div>
"""
    html += """
        </div>
    </div>
"""
    return html