AbstractPhil commited on
Commit
cf76b2b
Β·
verified Β·
1 Parent(s): a6ac8e9

Create trainer_model.py

Browse files
Files changed (1) hide show
  1. trainer_model.py +528 -0
trainer_model.py ADDED
@@ -0,0 +1,528 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Flow Matching β€” Constellation Bottleneck
4
+ ==========================================
5
+ The constellation IS the bottleneck. Not a regulator. Not a side channel.
6
+ All information passes through S^15 triangulation.
7
+
8
+ Architecture:
9
+ Encoder: 3Γ—32Γ—32 β†’ 64Γ—32 β†’ 128Γ—16 β†’ 256Γ—8
10
+ Bottleneck:
11
+ flatten 256Γ—8Γ—8 = 16384 β†’ Linear(16384, 256) β†’ L2 normalize
12
+ β†’ Constellation: 16 patches Γ— 16d, 16 anchors, 3 phases
13
+ β†’ Triangulation profile: 16 patches Γ— 48 = 768 dims
14
+ β†’ Condition injection: concat(tri, time_emb, class_emb)
15
+ β†’ Patchwork MLP: 768+cond β†’ 256 β†’ 16384 β†’ reshape 256Γ—8Γ—8
16
+ Decoder: 256Γ—8 β†’ 128Γ—16 β†’ 64Γ—32 β†’ 3Γ—32Γ—32
17
+
18
+ The triangulation profile IS the representation.
19
+ Time and class conditioning enter at the triangulation level β€”
20
+ they modulate what the patchwork does with the geometric reading.
21
+ """
22
+
23
+ import torch
24
+ import torch.nn as nn
25
+ import torch.nn.functional as F
26
+ import numpy as np
27
+ import math
28
+ import os
29
+ import time
30
+ from tqdm import tqdm
31
+ from torchvision import datasets, transforms
32
+ from torchvision.utils import save_image, make_grid
33
+
34
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
35
+ torch.backends.cuda.matmul.allow_tf32 = True
36
+ torch.backends.cudnn.allow_tf32 = True
37
+
38
+
39
+ # ══════════════════════════════════════════════════════════════════
40
+ # CONSTELLATION BOTTLENECK
41
+ # ══════════════════════════════════════════════════════════════════
42
+
43
+ class ConstellationBottleneck(nn.Module):
44
+ """
45
+ The constellation as information bottleneck.
46
+
47
+ Input: (B, spatial_dim) flattened feature map
48
+ Output: (B, spatial_dim) reconstructed through geometric encoding
49
+
50
+ All information passes through S^(d-1) triangulation.
51
+ Time + class conditioning injected at the triangulation level.
52
+ """
53
+ def __init__(
54
+ self,
55
+ spatial_dim, # 256*8*8 = 16384
56
+ embed_dim=256, # project to this before sphere
57
+ patch_dim=16,
58
+ n_anchors=16,
59
+ n_phases=3,
60
+ cond_dim=256,
61
+ pw_hidden=512,
62
+ ):
63
+ super().__init__()
64
+ self.spatial_dim = spatial_dim
65
+ self.embed_dim = embed_dim
66
+ self.patch_dim = patch_dim
67
+ self.n_patches = embed_dim // patch_dim
68
+ self.n_anchors = n_anchors
69
+ self.n_phases = n_phases
70
+
71
+ P, A, d = self.n_patches, n_anchors, patch_dim
72
+
73
+ # Project feature map β†’ embedding sphere
74
+ self.proj_in = nn.Linear(spatial_dim, embed_dim)
75
+ self.proj_in_norm = nn.LayerNorm(embed_dim)
76
+
77
+ # Constellation anchors
78
+ home = torch.empty(P, A, d)
79
+ nn.init.xavier_normal_(home.view(P * A, d))
80
+ home = F.normalize(home.view(P, A, d), dim=-1)
81
+ self.register_buffer('home', home)
82
+ self.anchors = nn.Parameter(home.clone())
83
+
84
+ # Triangulation β†’ total dims = P * (A * n_phases)
85
+ tri_dim_per_patch = A * n_phases
86
+ total_tri_dim = P * tri_dim_per_patch
87
+
88
+ # Patchwork reads triangulation + conditioning
89
+ # This is where time and class information enter
90
+ pw_input = total_tri_dim + cond_dim
91
+ self.patchwork = nn.Sequential(
92
+ nn.Linear(pw_input, pw_hidden),
93
+ nn.GELU(),
94
+ nn.LayerNorm(pw_hidden),
95
+ nn.Linear(pw_hidden, pw_hidden),
96
+ nn.GELU(),
97
+ nn.LayerNorm(pw_hidden),
98
+ nn.Linear(pw_hidden, spatial_dim),
99
+ )
100
+
101
+ # Skip projection β€” residual through the bottleneck
102
+ self.skip_proj = nn.Linear(spatial_dim, spatial_dim)
103
+ self.skip_gate = nn.Parameter(torch.tensor(-2.0)) # sigmoid β‰ˆ 0.12
104
+
105
+ def drift(self):
106
+ h, c = F.normalize(self.home, dim=-1), F.normalize(self.anchors, dim=-1)
107
+ return torch.acos((h * c).sum(-1).clamp(-1 + 1e-7, 1 - 1e-7))
108
+
109
+ def at_phase(self, t):
110
+ h, c = F.normalize(self.home, dim=-1), F.normalize(self.anchors, dim=-1)
111
+ omega = self.drift().unsqueeze(-1)
112
+ so = omega.sin().clamp(min=1e-7)
113
+ return torch.sin((1-t)*omega)/so * h + torch.sin(t*omega)/so * c
114
+
115
+ def triangulate(self, emb_norm):
116
+ """
117
+ Multi-phase triangulation on the sphere.
118
+ emb_norm: (B, P, d) normalized patches on S^(d-1)
119
+ Returns: (B, P * A * n_phases) full triangulation profile
120
+ """
121
+ phases = torch.linspace(0, 1, self.n_phases, device=emb_norm.device).tolist()
122
+ tris = []
123
+ for t in phases:
124
+ anchors_t = F.normalize(self.at_phase(t), dim=-1) # (P, A, d)
125
+ cos = torch.einsum('bpd,pad->bpa', emb_norm, anchors_t)
126
+ tris.append(1.0 - cos)
127
+ # (B, P, A*phases) β†’ flatten β†’ (B, P*A*phases)
128
+ tri = torch.cat(tris, dim=-1)
129
+ return tri.reshape(emb_norm.shape[0], -1)
130
+
131
+ def forward(self, x_flat, cond):
132
+ """
133
+ x_flat: (B, spatial_dim) β€” flattened bottleneck features
134
+ cond: (B, cond_dim) β€” time + class conditioning
135
+ Returns: (B, spatial_dim)
136
+ """
137
+ B = x_flat.shape[0]
138
+
139
+ # Project to embedding space β†’ normalize to sphere
140
+ emb = self.proj_in(x_flat)
141
+ emb = self.proj_in_norm(emb)
142
+ patches = emb.reshape(B, self.n_patches, self.patch_dim)
143
+ patches_n = F.normalize(patches, dim=-1) # on S^(d-1)
144
+
145
+ # Triangulate β€” the geometric encoding
146
+ tri_profile = self.triangulate(patches_n) # (B, P*A*phases)
147
+
148
+ # Inject conditioning at the triangulation level
149
+ pw_input = torch.cat([tri_profile, cond], dim=-1)
150
+
151
+ # Patchwork reads the geometric profile + conditioning
152
+ decoded = self.patchwork(pw_input) # (B, spatial_dim)
153
+
154
+ # Gated skip connection through the bottleneck
155
+ skip = self.skip_proj(x_flat)
156
+ gate = self.skip_gate.sigmoid()
157
+ return gate * skip + (1 - gate) * decoded
158
+
159
+
160
+ # ══════════════════════════════════════════════════════════════════
161
+ # UNET BUILDING BLOCKS
162
+ # ══════════════════════════════════════════════════════════════════
163
+
164
+ class SinusoidalPosEmb(nn.Module):
165
+ def __init__(self, dim):
166
+ super().__init__()
167
+ self.dim = dim
168
+
169
+ def forward(self, t):
170
+ half = self.dim // 2
171
+ emb = math.log(10000) / (half - 1)
172
+ emb = torch.exp(torch.arange(half, device=t.device, dtype=t.dtype) * -emb)
173
+ emb = t.unsqueeze(-1) * emb.unsqueeze(0)
174
+ return torch.cat([emb.sin(), emb.cos()], dim=-1)
175
+
176
+
177
+ class AdaGroupNorm(nn.Module):
178
+ def __init__(self, channels, cond_dim, n_groups=8):
179
+ super().__init__()
180
+ self.gn = nn.GroupNorm(min(n_groups, channels), channels, affine=False)
181
+ self.proj = nn.Linear(cond_dim, channels * 2)
182
+ nn.init.zeros_(self.proj.weight)
183
+ nn.init.zeros_(self.proj.bias)
184
+
185
+ def forward(self, x, cond):
186
+ x = self.gn(x)
187
+ scale, shift = self.proj(cond).unsqueeze(-1).unsqueeze(-1).chunk(2, dim=1)
188
+ return x * (1 + scale) + shift
189
+
190
+
191
+ class ConvBlock(nn.Module):
192
+ def __init__(self, channels, cond_dim):
193
+ super().__init__()
194
+ self.dw_conv = nn.Conv2d(channels, channels, 7, padding=3, groups=channels)
195
+ self.norm = AdaGroupNorm(channels, cond_dim)
196
+ self.pw1 = nn.Conv2d(channels, channels * 4, 1)
197
+ self.pw2 = nn.Conv2d(channels * 4, channels, 1)
198
+ self.act = nn.GELU()
199
+
200
+ def forward(self, x, cond):
201
+ residual = x
202
+ x = self.dw_conv(x)
203
+ x = self.norm(x, cond)
204
+ x = self.pw1(x)
205
+ x = self.act(x)
206
+ x = self.pw2(x)
207
+ return residual + x
208
+
209
+
210
+ class Downsample(nn.Module):
211
+ def __init__(self, ch):
212
+ super().__init__()
213
+ self.conv = nn.Conv2d(ch, ch, 3, stride=2, padding=1)
214
+ def forward(self, x):
215
+ return self.conv(x)
216
+
217
+
218
+ class Upsample(nn.Module):
219
+ def __init__(self, ch):
220
+ super().__init__()
221
+ self.conv = nn.Conv2d(ch, ch, 3, padding=1)
222
+ def forward(self, x):
223
+ return self.conv(F.interpolate(x, scale_factor=2, mode='nearest'))
224
+
225
+
226
+ # ══════════════════════════════════════════════════════════════════
227
+ # FLOW MATCHING UNET WITH CONSTELLATION BOTTLENECK
228
+ # ══════════════════════════════════════════════════════════════════
229
+
230
+ class FlowMatchConstellationUNet(nn.Module):
231
+ """
232
+ UNet where the middle block IS the constellation.
233
+ No attention. The constellation is the information bottleneck.
234
+
235
+ 32Γ—32 β†’ 16Γ—16 β†’ 8Γ—8 β†’ flatten β†’ project β†’ S^15 β†’ triangulate
236
+ β†’ patchwork(tri + time + class) β†’ project back β†’ 8Γ—8 β†’ 16Γ—16 β†’ 32Γ—32
237
+ """
238
+ def __init__(
239
+ self,
240
+ in_channels=3,
241
+ base_ch=64,
242
+ channel_mults=(1, 2, 4),
243
+ n_classes=10,
244
+ cond_dim=256,
245
+ embed_dim=256,
246
+ n_anchors=16,
247
+ n_phases=3,
248
+ pw_hidden=512,
249
+ ):
250
+ super().__init__()
251
+ self.channel_mults = channel_mults
252
+
253
+ # Conditioning
254
+ self.time_emb = nn.Sequential(
255
+ SinusoidalPosEmb(cond_dim),
256
+ nn.Linear(cond_dim, cond_dim), nn.GELU(),
257
+ nn.Linear(cond_dim, cond_dim))
258
+ self.class_emb = nn.Embedding(n_classes, cond_dim)
259
+
260
+ # Input
261
+ self.in_conv = nn.Conv2d(in_channels, base_ch, 3, padding=1)
262
+
263
+ # Encoder
264
+ self.enc = nn.ModuleList()
265
+ self.enc_down = nn.ModuleList()
266
+ ch = base_ch
267
+ enc_channels = [base_ch]
268
+
269
+ for i, mult in enumerate(channel_mults):
270
+ ch_out = base_ch * mult
271
+ self.enc.append(nn.ModuleList([
272
+ ConvBlock(ch, cond_dim) if ch == ch_out
273
+ else nn.Sequential(nn.Conv2d(ch, ch_out, 1), ConvBlock(ch_out, cond_dim)),
274
+ ConvBlock(ch_out, cond_dim),
275
+ ]))
276
+ ch = ch_out
277
+ enc_channels.append(ch)
278
+ if i < len(channel_mults) - 1:
279
+ self.enc_down.append(Downsample(ch))
280
+
281
+ # Constellation bottleneck
282
+ # At this point: (B, ch, 8, 8) where ch = base_ch * channel_mults[-1]
283
+ mid_ch = ch
284
+ spatial = 8 * 8 # after two downsamples from 32
285
+ spatial_dim = mid_ch * spatial
286
+
287
+ self.bottleneck = ConstellationBottleneck(
288
+ spatial_dim=spatial_dim,
289
+ embed_dim=embed_dim,
290
+ patch_dim=16,
291
+ n_anchors=n_anchors,
292
+ n_phases=n_phases,
293
+ cond_dim=cond_dim,
294
+ pw_hidden=pw_hidden,
295
+ )
296
+ self.mid_ch = mid_ch
297
+ self.mid_spatial = spatial
298
+
299
+ # Decoder
300
+ self.dec_up = nn.ModuleList()
301
+ self.dec_skip_proj = nn.ModuleList()
302
+ self.dec = nn.ModuleList()
303
+
304
+ for i in range(len(channel_mults) - 1, -1, -1):
305
+ ch_out = base_ch * channel_mults[i]
306
+ skip_ch = enc_channels.pop()
307
+ self.dec_skip_proj.append(nn.Conv2d(ch + skip_ch, ch_out, 1))
308
+ self.dec.append(nn.ModuleList([
309
+ ConvBlock(ch_out, cond_dim),
310
+ ConvBlock(ch_out, cond_dim),
311
+ ]))
312
+ ch = ch_out
313
+ if i > 0:
314
+ self.dec_up.append(Upsample(ch))
315
+
316
+ # Output
317
+ self.out_norm = nn.GroupNorm(8, ch)
318
+ self.out_conv = nn.Conv2d(ch, in_channels, 3, padding=1)
319
+ nn.init.zeros_(self.out_conv.weight)
320
+ nn.init.zeros_(self.out_conv.bias)
321
+
322
+ def forward(self, x, t, class_labels):
323
+ cond = self.time_emb(t) + self.class_emb(class_labels)
324
+ h = self.in_conv(x)
325
+ skips = [h]
326
+
327
+ # Encoder
328
+ for i in range(len(self.channel_mults)):
329
+ for block in self.enc[i]:
330
+ if isinstance(block, ConvBlock):
331
+ h = block(h, cond)
332
+ elif isinstance(block, nn.Sequential):
333
+ h = block[0](h)
334
+ h = block[1](h, cond)
335
+ skips.append(h)
336
+ if i < len(self.enc_down):
337
+ h = self.enc_down[i](h)
338
+
339
+ # β˜… CONSTELLATION BOTTLENECK β˜…
340
+ B, C, H, W = h.shape
341
+ h_flat = h.reshape(B, -1) # (B, C*H*W)
342
+ h_flat = self.bottleneck(h_flat, cond) # through S^15
343
+ h = h_flat.reshape(B, C, H, W)
344
+
345
+ # Decoder
346
+ for i in range(len(self.channel_mults)):
347
+ skip = skips.pop()
348
+ if i > 0:
349
+ h = self.dec_up[i - 1](h)
350
+ h = torch.cat([h, skip], dim=1)
351
+ h = self.dec_skip_proj[i](h)
352
+ for block in self.dec[i]:
353
+ h = block(h, cond)
354
+
355
+ h = self.out_norm(h)
356
+ h = F.silu(h)
357
+ return self.out_conv(h)
358
+
359
+
360
+ # ══════════════════════════════════════════════════════════════════
361
+ # SAMPLING
362
+ # ══════════════════════════════════════════════════════════════════
363
+
364
+ @torch.no_grad()
365
+ def sample(model, n_samples=64, n_steps=50, class_label=None, n_classes=10):
366
+ model.eval()
367
+ x = torch.randn(n_samples, 3, 32, 32, device=DEVICE)
368
+ if class_label is not None:
369
+ labels = torch.full((n_samples,), class_label, dtype=torch.long, device=DEVICE)
370
+ else:
371
+ labels = torch.randint(0, n_classes, (n_samples,), device=DEVICE)
372
+
373
+ dt = 1.0 / n_steps
374
+ for step in range(n_steps):
375
+ t_val = 1.0 - step * dt
376
+ t = torch.full((n_samples,), t_val, device=DEVICE)
377
+ with torch.amp.autocast("cuda", dtype=torch.bfloat16):
378
+ v = model(x, t, labels)
379
+ x = x - v.float() * dt
380
+ return x.clamp(-1, 1), labels
381
+
382
+
383
+ # ══════════════════════════════════════════════════════════════════
384
+ # TRAINING
385
+ # ══════════════════════════════════════════════════════════════════
386
+
387
+ BATCH = 128
388
+ EPOCHS = 50
389
+ LR = 3e-4
390
+ N_CLASSES = 10
391
+ SAMPLE_EVERY = 5
392
+
393
+ print("=" * 70)
394
+ print("FLOW MATCHING β€” CONSTELLATION BOTTLENECK")
395
+ print(f" No attention. The constellation IS the bottleneck.")
396
+ print(f" Device: {DEVICE}")
397
+ print("=" * 70)
398
+
399
+ transform = transforms.Compose([
400
+ transforms.RandomHorizontalFlip(),
401
+ transforms.ToTensor(),
402
+ transforms.Normalize((0.5,)*3, (0.5,)*3),
403
+ ])
404
+ train_ds = datasets.CIFAR10('./data', train=True, download=True, transform=transform)
405
+ train_loader = torch.utils.data.DataLoader(
406
+ train_ds, batch_size=BATCH, shuffle=True,
407
+ num_workers=4, pin_memory=True, drop_last=True)
408
+
409
+ model = FlowMatchConstellationUNet(
410
+ in_channels=3, base_ch=64, channel_mults=(1, 2, 4),
411
+ n_classes=N_CLASSES, cond_dim=256, embed_dim=256,
412
+ n_anchors=16, n_phases=3, pw_hidden=512,
413
+ ).to(DEVICE)
414
+
415
+ n_params = sum(p.numel() for p in model.parameters())
416
+ n_bottleneck = sum(p.numel() for p in model.bottleneck.parameters())
417
+ print(f" Total params: {n_params:,}")
418
+ print(f" Bottleneck params: {n_bottleneck:,} ({100*n_bottleneck/n_params:.1f}%)")
419
+ print(f" Train: {len(train_ds):,} images")
420
+
421
+ # Verify shapes
422
+ with torch.no_grad():
423
+ dummy = torch.randn(2, 3, 32, 32, device=DEVICE)
424
+ t_dummy = torch.rand(2, device=DEVICE)
425
+ c_dummy = torch.randint(0, 10, (2,), device=DEVICE)
426
+ out = model(dummy, t_dummy, c_dummy)
427
+ print(f" Shape check: {dummy.shape} β†’ {out.shape} βœ“")
428
+
429
+ # Show bottleneck info
430
+ bn = model.bottleneck
431
+ drift = bn.drift()
432
+ print(f" Bottleneck: {bn.spatial_dim}d β†’ {bn.embed_dim}d sphere "
433
+ f"β†’ {bn.n_patches}p Γ— {bn.patch_dim}d Γ— {bn.n_anchors}A Γ— {bn.n_phases}ph "
434
+ f"= {bn.n_patches * bn.n_anchors * bn.n_phases} tri dims")
435
+ print(f" Skip gate init: {bn.skip_gate.sigmoid().item():.4f}")
436
+
437
+ optimizer = torch.optim.AdamW(model.parameters(), lr=LR, weight_decay=0.01)
438
+ scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
439
+ optimizer, T_max=EPOCHS * len(train_loader), eta_min=1e-6)
440
+ scaler = torch.amp.GradScaler("cuda")
441
+
442
+ os.makedirs("samples_bn", exist_ok=True)
443
+ os.makedirs("checkpoints", exist_ok=True)
444
+
445
+ print(f"\n{'='*70}")
446
+ print(f"TRAINING β€” {EPOCHS} epochs")
447
+ print(f"{'='*70}")
448
+
449
+ best_loss = float('inf')
450
+
451
+ for epoch in range(EPOCHS):
452
+ model.train()
453
+ t0 = time.time()
454
+ total_loss = 0
455
+ n = 0
456
+
457
+ pbar = tqdm(train_loader, desc=f"E{epoch+1:3d}/{EPOCHS}", unit="b")
458
+ for images, labels in pbar:
459
+ images = images.to(DEVICE, non_blocking=True)
460
+ labels = labels.to(DEVICE, non_blocking=True)
461
+ B = images.shape[0]
462
+
463
+ t = torch.rand(B, device=DEVICE)
464
+ eps = torch.randn_like(images)
465
+ t_b = t.view(B, 1, 1, 1)
466
+ x_t = (1 - t_b) * images + t_b * eps
467
+ v_target = eps - images
468
+
469
+ with torch.amp.autocast("cuda", dtype=torch.bfloat16):
470
+ v_pred = model(x_t, t, labels)
471
+ loss = F.mse_loss(v_pred, v_target)
472
+
473
+ optimizer.zero_grad(set_to_none=True)
474
+ scaler.scale(loss).backward()
475
+ scaler.unscale_(optimizer)
476
+ nn.utils.clip_grad_norm_(model.parameters(), 1.0)
477
+ scaler.step(optimizer)
478
+ scaler.update()
479
+ scheduler.step()
480
+
481
+ total_loss += loss.item()
482
+ n += 1
483
+ if n % 20 == 0:
484
+ pbar.set_postfix(loss=f"{total_loss/n:.4f}", lr=f"{scheduler.get_last_lr()[0]:.1e}")
485
+
486
+ elapsed = time.time() - t0
487
+ avg_loss = total_loss / n
488
+
489
+ mk = ""
490
+ if avg_loss < best_loss:
491
+ best_loss = avg_loss
492
+ torch.save({
493
+ 'state_dict': model.state_dict(),
494
+ 'epoch': epoch + 1,
495
+ 'loss': avg_loss,
496
+ }, 'checkpoints/constellation_bn_best.pt')
497
+ mk = " β˜…"
498
+
499
+ print(f" E{epoch+1:3d}: loss={avg_loss:.4f} lr={scheduler.get_last_lr()[0]:.1e} "
500
+ f"({elapsed:.0f}s){mk}")
501
+
502
+ # Diagnostics
503
+ if (epoch + 1) % 10 == 0:
504
+ bn = model.bottleneck
505
+ drift = bn.drift()
506
+ gate = bn.skip_gate.sigmoid().item()
507
+ print(f" Bottleneck: drift={drift.mean():.4f}rad ({math.degrees(drift.mean()):.1f}Β°) "
508
+ f"max={drift.max():.4f}rad gate={gate:.4f}")
509
+
510
+ # Sample
511
+ if (epoch + 1) % SAMPLE_EVERY == 0 or epoch == 0:
512
+ imgs, _ = sample(model, 64, 50)
513
+ imgs = (imgs + 1) / 2
514
+ save_image(make_grid(imgs, nrow=8), f'samples_bn/epoch_{epoch+1:03d}.png')
515
+ print(f" β†’ samples_bn/epoch_{epoch+1:03d}.png")
516
+
517
+ if (epoch + 1) % (SAMPLE_EVERY * 2) == 0:
518
+ class_names = ['plane','auto','bird','cat','deer',
519
+ 'dog','frog','horse','ship','truck']
520
+ for c in range(N_CLASSES):
521
+ cs, _ = sample(model, 8, 50, class_label=c)
522
+ save_image(make_grid((cs+1)/2, nrow=8),
523
+ f'samples_bn/epoch_{epoch+1:03d}_{class_names[c]}.png')
524
+
525
+ print(f"\n{'='*70}")
526
+ print(f"DONE β€” Best loss: {best_loss:.4f}")
527
+ print(f" Params: {n_params:,} (bottleneck: {n_bottleneck:,})")
528
+ print(f"{'='*70}")