asdf98 committed
Commit 0dbbcfa · verified · 1 Parent(s): 1392e15

Add v3 training script (self-contained, works with HF Jobs)
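For reference, one way to launch this on HF Jobs is sketched below. Illustrative only, not part of this commit: the `hf jobs` subcommand, flags, flavor name, and script URL (assuming the script lives in this repo) are assumptions and should be checked against the current huggingface_hub documentation.

    # illustrative sketch; verify syntax against your huggingface_hub version
    hf jobs uv run --flavor a10g-small --secrets HF_TOKEN \
        --env VARIANT=small --env EPOCHS=10 --env BATCH_SIZE=8 \
        https://huggingface.co/asdf98/BokehFlow/resolve/main/train_v3.py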

Files changed (1):
  1. train_v3.py +552 -0
train_v3.py ADDED
@@ -0,0 +1,552 @@
+ """
+ BokehFlow v3 Training Script
+ Trains on the RealBokeh_3MP dataset (timseizinger/RealBokeh_3MP).
+
+ Self-contained: all model code is inline, so this works as a standalone
+ script in HF Jobs or any GPU environment.
+
+ Usage:
+     # Quick test (200 scenes, 3 epochs)
+     VARIANT=small MAX_SCENES=200 EPOCHS=3 BATCH_SIZE=4 python train_v3.py
+
+     # Full training (all 3960 scenes, 10 epochs)
+     VARIANT=small EPOCHS=10 BATCH_SIZE=8 python train_v3.py
+
+ Environment variables:
+     VARIANT: nano/small/base (default: small)
+     MAX_SCENES: limit scenes for testing (default: 0 = all)
+     EPOCHS: number of epochs (default: 10)
+     BATCH_SIZE: batch size (default: 4)
+     CROP_SIZE: random crop size (default: 256)
+     LR: learning rate (default: 2e-4)
+     HUB_MODEL_ID: HF model repo to push to (default: asdf98/BokehFlow)
+
+ Requirements:
+     pip install torch torchvision Pillow huggingface_hub trackio aiohttp
+ """
+
+ import os, time, json, random
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from torch.utils.data import Dataset, DataLoader
+ from pathlib import Path
+ from dataclasses import dataclass
+
+
+ # ===================================================================
+ # Model (inline; identical to bokehflow_v3.py)
+ # ===================================================================
+
+ @dataclass
+ class BokehFlowConfig:
+     variant: str = "small"
+     embed_dim: int = 96
+     depth_blocks: int = 6
+     bokeh_blocks: int = 6
+     fusion_every: int = 2
+     stem_channels: int = 48
+     patch_stride: int = 4
+     max_coc_radius: int = 31
+     num_depth_layers: int = 8
+     aperture_embed_dim: int = 64
+     dropout: float = 0.0
+     sensor_width_mm: float = 36.0
+     default_focal_mm: float = 50.0
+     default_fnumber: float = 2.0
+     default_focus_m: float = 2.0
+     ffn_expansion: int = 2
+     large_kernel: int = 7
+
+     def __post_init__(self):
+         if self.variant == "nano":
+             self.embed_dim = 48
+             self.depth_blocks = 4
+             self.bokeh_blocks = 4
+         elif self.variant == "small":
+             self.embed_dim = 96
+             self.depth_blocks = 6
+             self.bokeh_blocks = 6
+         elif self.variant == "base":
+             self.embed_dim = 192
+             self.depth_blocks = 8
+             self.bokeh_blocks = 8
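+
+ # Illustrative check of the variant presets above (not used in training):
+ #   BokehFlowConfig(variant="nano").embed_dim    -> 48
+ #   BokehFlowConfig(variant="base").depth_blocks -> 8
+
+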
+ class GatedConvRecurrence(nn.Module):
+     def __init__(self, dim, kernel_size=7, ffn_expansion=2):
+         super().__init__()
+         k = kernel_size; p = k // 2
+         self.norm1 = nn.GroupNorm(8, dim)
+         self.dw1 = nn.Conv2d(dim, dim, k, padding=p, groups=dim, bias=False)
+         self.dw2 = nn.Conv2d(dim, dim, k, padding=p, groups=dim, bias=False)
+         self.pw = nn.Conv2d(dim, dim, 1, bias=False)
+         self.gate_proj = nn.Conv2d(dim, dim, 1, bias=True)
+         self.norm2 = nn.GroupNorm(8, dim)
+         h = dim * ffn_expansion
+         self.ffn = nn.Sequential(nn.Conv2d(dim, h, 1, bias=False), nn.GELU(), nn.Conv2d(h, dim, 1, bias=False))
+         # Zero-init the last projection of each branch so every block starts
+         # out as an identity mapping (residual-friendly initialization).
+         nn.init.zeros_(self.pw.weight)
+         nn.init.zeros_(self.ffn[-1].weight)
+
+     def forward(self, x):
+         h = self.norm1(x)
+         spatial = self.dw2(F.silu(self.dw1(h)))
+         spatial = self.pw(spatial)
+         gate = torch.sigmoid(self.gate_proj(h))
+         x = x + spatial * gate
+         x = x + self.ffn(self.norm2(x))
+         return x
+ class GatedConvRecurrenceWithACFM(GatedConvRecurrence):
+     def __init__(self, dim, kernel_size=7, ffn_expansion=2, aperture_embed_dim=64):
+         super().__init__(dim, kernel_size, ffn_expansion)
+         self.acfm = nn.Linear(aperture_embed_dim, dim * 2)
+         nn.init.zeros_(self.acfm.weight)
+         self.acfm.bias.data[:dim] = 1.0
+         self.acfm.bias.data[dim:] = 0.0
+
+     def forward(self, x, aperture_embed=None):
+         x = super().forward(x)
+         if aperture_embed is not None:
+             B, C, H, W = x.shape
+             ss = self.acfm(aperture_embed)
+             scale = ss[:, :C].view(B, C, 1, 1)
+             shift = ss[:, C:].view(B, C, 1, 1)
+             x = x * scale + shift
+         return x
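+
+ # Note: the ACFM layer is FiLM-style conditioning: a per-channel scale and
+ # shift predicted from the aperture embedding. Its weights start at zero
+ # with bias (scale=1, shift=0), so each block begins as an identity
+ # modulation and learns aperture dependence during training.
+
+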
+ class ConvStem(nn.Module):
+     def __init__(self, in_ch=3, stem_ch=48, embed_dim=96):
+         super().__init__()
+         # Two stride-2 convs: overall 4x spatial downsampling into embed_dim.
+         self.net = nn.Sequential(
+             nn.Conv2d(in_ch, stem_ch, 7, stride=2, padding=3, bias=False),
+             nn.GroupNorm(8, stem_ch), nn.GELU(),
+             nn.Conv2d(stem_ch, stem_ch, 3, stride=2, padding=1, groups=stem_ch, bias=False),
+             nn.Conv2d(stem_ch, embed_dim, 1, bias=False),
+             nn.GroupNorm(8, embed_dim), nn.GELU())
+
+     def forward(self, x): return self.net(x)
+ class ApertureEncoder(nn.Module):
+     def __init__(self, embed_dim=64):
+         super().__init__()
+         self.mlp = nn.Sequential(nn.Linear(3, embed_dim), nn.GELU(), nn.Linear(embed_dim, embed_dim), nn.GELU())
+         # Normalization ranges for (f-number, focal length mm, focus distance m).
+         self.register_buffer('p_min', torch.tensor([1., 10., 0.1]))
+         self.register_buffer('p_max', torch.tensor([22., 200., 100.]))
+
+     def forward(self, f_number, focal_mm, focus_m):
+         p = torch.stack([f_number, focal_mm, focus_m], -1)
+         return self.mlp(((p - self.p_min) / (self.p_max - self.p_min + 1e-6)).clamp(0, 1))
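+
+ # Worked example (illustrative): f/2.0 at 50 mm focused at 2 m normalizes to
+ #   ((2-1)/21, (50-10)/190, (2-0.1)/99.9) ~= (0.048, 0.211, 0.019)
+ # before entering the MLP.
+
+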
+ class CrossFusion(nn.Module):
+     def __init__(self, d):
+         super().__init__()
+         self.gate_d = nn.Sequential(nn.Conv2d(d, d, 1), nn.Sigmoid())
+         self.gate_b = nn.Sequential(nn.Conv2d(d, d, 1), nn.Sigmoid())
+         self.proj_d = nn.Conv2d(d, d, 1, bias=False)
+         self.proj_b = nn.Conv2d(d, d, 1, bias=False)
+         nn.init.zeros_(self.proj_d.weight); nn.init.zeros_(self.proj_b.weight)
+
+     def forward(self, d_feat, b_feat):
+         # Gated bidirectional exchange between the depth and bokeh streams.
+         return (d_feat + self.gate_d(b_feat) * self.proj_d(b_feat),
+                 b_feat + self.gate_b(d_feat) * self.proj_b(d_feat))
+
+
+ class DepthHead(nn.Module):
+     def __init__(self, dim=96):
+         super().__init__()
+         self.net = nn.Sequential(
+             nn.Conv2d(dim, dim//2, 3, padding=1), nn.GELU(),
+             nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
+             nn.Conv2d(dim//2, dim//4, 3, padding=1), nn.GELU(),
+             nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
+             nn.Conv2d(dim//4, 1, 3, padding=1), nn.Softplus())
+
+     def forward(self, x):
+         # Softplus keeps depth positive; the clamp caps it at 100 m.
+         return self.net(x).clamp(max=100.0)
+
+
+ class BokehHead(nn.Module):
+     def __init__(self, dim=96):
+         super().__init__()
+         self.net = nn.Sequential(
+             nn.Conv2d(dim, dim, 3, padding=1), nn.GELU(),
+             nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
+             nn.Conv2d(dim, dim//2, 3, padding=1), nn.GELU(),
+             nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
+             nn.Conv2d(dim//2, 3, 3, padding=1))
+
+     def forward(self, x): return self.net(x)
+
+
+ class PGCoC(nn.Module):
+     def __init__(self, sensor_width=36.0, max_radius=31, n_levels=5):
+         super().__init__()
+         self.sensor_width = sensor_width
+         self.max_radius = max_radius
+         self.n_levels = n_levels
+         # Precompute a pyramid of fixed Gaussian kernels, one per blur level.
+         self.kernels = nn.ParameterList()
+         for i in range(n_levels):
+             sigma = (i + 1) * max_radius / n_levels / 3.0
+             ks = int(sigma * 6) | 1; ks = max(ks, 3); ks = min(ks, 31)
+             k1d = torch.exp(-torch.arange(-(ks//2), ks//2+1).float()**2 / (2*sigma**2+1e-6))
+             k1d = k1d / k1d.sum()
+             k2d = k1d.unsqueeze(1) @ k1d.unsqueeze(0)
+             self.kernels.append(nn.Parameter(k2d.unsqueeze(0).unsqueeze(0), requires_grad=False))
+         self.refine = nn.Sequential(nn.Conv2d(3, 16, 3, padding=1), nn.GELU(), nn.Conv2d(16, 3, 3, padding=1))
+
+     def _blur_at_level(self, image, kernel):
+         B, C, H, W = image.shape
+         k = kernel.expand(C, -1, -1, -1)
+         p = kernel.shape[-1] // 2
+         return F.conv2d(F.pad(image, [p]*4, mode='reflect'), k, groups=C)
+
+     def forward(self, image, depth, f_number, focal_mm, focus_m):
+         B, C, H, W = image.shape
+         f = focal_mm.view(-1,1,1,1); N = f_number.view(-1,1,1,1)
+         S1 = (focus_m.view(-1,1,1,1) * 1000).clamp(min=51)   # focus distance in mm
+         D = (depth * 1000).clamp(min=100)                    # scene depth in mm
+         # Thin-lens circle of confusion: coc = f^2 / (N * (S1 - f)) * |D - S1| / D
+         coc = (f**2 / (N * (S1 - f).clamp(min=1))) * (D - S1).abs() / D
+         # Convert CoC diameter (mm on sensor) to a radius in pixels.
+         coc_px = (coc * W / self.sensor_width / 2).clamp(0, self.max_radius)
+         coc_norm = coc_px / self.max_radius
+         blurred_levels = [self._blur_at_level(image, kernel) for kernel in self.kernels]
+         # Per pixel, linearly interpolate between the two adjacent blur levels.
+         level_float = coc_norm * (self.n_levels - 1)
+         level_low = level_float.long().clamp(0, self.n_levels - 2)
+         level_frac = (level_float - level_low.float()).clamp(0, 1)
+         rendered = image.clone()
+         for lv in range(self.n_levels - 1):
+             mask = (level_low == lv).float()
+             if mask.sum() > 0:
+                 interp = blurred_levels[lv] * (1 - level_frac) + blurred_levels[lv + 1] * level_frac
+                 rendered = rendered * (1 - mask) + interp * mask
+         mask_top = (level_low >= self.n_levels - 2).float() * (level_frac > 0.99).float()
+         rendered = rendered * (1 - mask_top) + blurred_levels[-1] * mask_top
+         rendered = rendered + self.refine(rendered) * 0.1
+         return rendered, coc_px
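+
+ # Worked example (illustrative) for the thin-lens formula above: with
+ # f = 50 mm, N = 2.0, focus S1 = 2 m and scene depth D = 10 m,
+ #   coc = 50^2 / (2 * (2000 - 50)) * (10000 - 2000) / 10000 ~= 0.51 mm,
+ # which on a 36 mm sensor rendered at W = 256 px is a radius of about
+ #   0.51 * 256 / 36 / 2 ~= 1.8 px.
+
+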
+ class BokehFlow(nn.Module):
+     def __init__(self, config=None):
+         super().__init__()
+         if config is None: config = BokehFlowConfig()
+         self.config = config; c = config
+         self.stem = ConvStem(3, c.stem_channels, c.embed_dim)
+         self.aperture_enc = ApertureEncoder(c.aperture_embed_dim)
+         self.depth_blocks = nn.ModuleList([
+             GatedConvRecurrence(c.embed_dim, c.large_kernel, c.ffn_expansion)
+             for _ in range(c.depth_blocks)])
+         self.bokeh_blocks = nn.ModuleList([
+             GatedConvRecurrenceWithACFM(c.embed_dim, c.large_kernel, c.ffn_expansion, c.aperture_embed_dim)
+             for _ in range(c.bokeh_blocks)])
+         n_fusions = max(c.depth_blocks, c.bokeh_blocks) // c.fusion_every
+         self.fusions = nn.ModuleList([CrossFusion(c.embed_dim) for _ in range(n_fusions)])
+         self.depth_head = DepthHead(c.embed_dim)
+         self.bokeh_head = BokehHead(c.embed_dim)
+         self.pgcoc = PGCoC(c.sensor_width_mm, c.max_coc_radius)
+         self.blend_w = nn.Parameter(torch.tensor(0.5))
+
+     def forward(self, image, f_number=None, focal_length_mm=None, focus_distance_m=None, **kw):
+         B = image.shape[0]; dev = image.device; c = self.config
+         # Fall back to the config defaults when camera parameters are omitted.
+         if f_number is None: f_number = torch.full((B,), c.default_fnumber, device=dev)
+         if focal_length_mm is None: focal_length_mm = torch.full((B,), c.default_focal_mm, device=dev)
+         if focus_distance_m is None: focus_distance_m = torch.full((B,), c.default_focus_m, device=dev)
+         ae = self.aperture_enc(f_number, focal_length_mm, focus_distance_m)
+         feat = self.stem(image)
+         d_feat = feat; b_feat = feat; fi = 0
+         n_blk = max(c.depth_blocks, c.bokeh_blocks)
+         for i in range(n_blk):
+             if i < c.depth_blocks: d_feat = self.depth_blocks[i](d_feat)
+             if i < c.bokeh_blocks: b_feat = self.bokeh_blocks[i](b_feat, ae)
+             if (i+1) % c.fusion_every == 0 and fi < len(self.fusions):
+                 d_feat, b_feat = self.fusions[fi](d_feat, b_feat); fi += 1
+         depth = self.depth_head(d_feat)
+         if depth.shape[2:] != image.shape[2:]:
+             depth = F.interpolate(depth, image.shape[2:], mode='bilinear', align_corners=False)
+         physics_bokeh, coc_map = self.pgcoc(image, depth, f_number, focal_length_mm, focus_distance_m)
+         learned_bokeh = self.bokeh_head(b_feat)
+         if learned_bokeh.shape[2:] != image.shape[2:]:
+             learned_bokeh = F.interpolate(learned_bokeh, image.shape[2:], mode='bilinear', align_corners=False)
+         # Learned blend between the physics-based render and the learned residual.
+         w = torch.sigmoid(self.blend_w)
+         bokeh = (w * physics_bokeh + (1-w) * (image + learned_bokeh)).clamp(0, 1)
+         return {'bokeh': bokeh, 'depth': depth, 'coc_map': coc_map}
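+
+ # Smoke test (illustrative sketch, not executed during training):
+ #   model = BokehFlow(BokehFlowConfig(variant="nano"))
+ #   out = model(torch.rand(1, 3, 256, 256))
+ #   out['bokeh'].shape  -> torch.Size([1, 3, 256, 256])
+ #   out['depth'].shape  -> torch.Size([1, 1, 256, 256])
+
+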
+ class BokehFlowLoss(nn.Module):
+     """L1 plus single-scale SSIM loss; SSIM statistics are computed over
+     11x11 box windows (avg_pool2d) rather than the usual Gaussian windows.
+     Total loss is L1 + (1 - mean SSIM), equally weighted."""
+     def forward(self, pred, targets):
+         bp, bg = pred['bokeh'], targets['bokeh_gt']
+         l1 = F.l1_loss(bp, bg)
+         C1, C2 = 0.01**2, 0.03**2
+         mu_p = F.avg_pool2d(bp, 11, 1, 5); mu_g = F.avg_pool2d(bg, 11, 1, 5)
+         mu_pp = mu_p*mu_p; mu_gg = mu_g*mu_g; mu_pg = mu_p*mu_g
+         sig_pp = F.avg_pool2d(bp*bp, 11, 1, 5) - mu_pp
+         sig_gg = F.avg_pool2d(bg*bg, 11, 1, 5) - mu_gg
+         sig_pg = F.avg_pool2d(bp*bg, 11, 1, 5) - mu_pg
+         ssim_map = ((2*mu_pg+C1)*(2*sig_pg+C2)) / ((mu_pp+mu_gg+C1)*(sig_pp+sig_gg+C2))
+         ssim_loss = 1.0 - ssim_map.mean()
+         return {'total': l1 + ssim_loss, 'l1': l1.detach(), 'ssim': ssim_loss.detach()}
+
+
+ # ===================================================================
+ # Dataset
+ # ===================================================================
+
+ class RealBokehDataset(Dataset):
+     """Loads pairs from local disk after snapshot_download: the f/22 input
+     image, the ground-truth image at the target f-stop, and per-scene
+     metadata when available."""
+     def __init__(self, root, crop_size=256, split='train', target_fstop='2.0'):
+         self.crop = crop_size
+         self.pairs = []
+         in_dir = Path(root) / split / 'in'
+         gt_dir = Path(root) / split / 'gt'
+         meta_dir = Path(root) / split / 'metadata'
+
+         for in_path in sorted(in_dir.glob('*_f22.JPG')):
+             sid = in_path.stem.split('_')[0]
+             gt_path = gt_dir / sid / f'{sid}_f{target_fstop}.JPG'
+             meta_path = meta_dir / f'{sid}.json'
+             if gt_path.exists():
+                 meta = {}
+                 if meta_path.exists():
+                     with open(meta_path) as f:
+                         meta = json.load(f)
+                 self.pairs.append((str(in_path), str(gt_path), meta))
+
+         print(f"RealBokehDataset: {len(self.pairs)} pairs found (target f/{target_fstop})")
+
+     def __len__(self):
+         return len(self.pairs)
+
+     def __getitem__(self, idx):
+         from PIL import Image
+         import torchvision.transforms.functional as TF
+
+         in_path, gt_path, meta = self.pairs[idx]
+         inp = Image.open(in_path).convert('RGB')
+         gt = Image.open(gt_path).convert('RGB')
+
+         # Resize to a manageable size first, then crop.
+         short = min(inp.size)
+         if short > 512:
+             scale = 512.0 / short
+             new_w = int(inp.size[0] * scale)
+             new_h = int(inp.size[1] * scale)
+             inp = inp.resize((new_w, new_h), Image.LANCZOS)
+             gt = gt.resize((new_w, new_h), Image.LANCZOS)
+
+         inp = TF.to_tensor(inp)
+         gt = TF.to_tensor(gt)
+
+         # Random crop (identical window for input and GT).
+         _, h, w = inp.shape
+         cs = self.crop
+         if h >= cs and w >= cs:
+             i = random.randint(0, h - cs)
+             j = random.randint(0, w - cs)
+             inp = inp[:, i:i+cs, j:j+cs]
+             gt = gt[:, i:i+cs, j:j+cs]
+         else:
+             inp = F.interpolate(inp.unsqueeze(0), (cs, cs), mode='bilinear', align_corners=False).squeeze(0)
+             gt = F.interpolate(gt.unsqueeze(0), (cs, cs), mode='bilinear', align_corners=False).squeeze(0)
+
+         # Random horizontal flip
+         if random.random() > 0.5:
+             inp = inp.flip(-1)
+             gt = gt.flip(-1)
+
+         focal = meta.get('focal_length', 50.0)
+         focus = meta.get('focus_plane_distance', 2.0)
+         fnum = 2.0  # GT images are shot at f/2.0
+
+         return {
+             'image': inp,
+             'bokeh_gt': gt,
+             'f_number': torch.tensor(fnum, dtype=torch.float32),
+             'focal_length_mm': torch.tensor(float(focal), dtype=torch.float32),
+             'focus_distance_m': torch.tensor(float(focus), dtype=torch.float32),
+         }
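+
+ # Expected on-disk layout, implied by the loaders above:
+ #   {root}/train/in/{sid}_f22.JPG          (narrow-aperture input)
+ #   {root}/train/gt/{sid}/{sid}_f2.0.JPG   (wide-aperture ground truth)
+ #   {root}/train/metadata/{sid}.json       (focal_length, focus_plane_distance)
+
+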
+ # ===================================================================
+ # Data download
+ # ===================================================================
+
+ def download_realbokeh(max_scenes=None):
+     """Download RealBokeh_3MP using snapshot_download with exact file patterns."""
+     from huggingface_hub import snapshot_download
+
+     data_dir = '/tmp/realbokeh'
+     # Cheap cache check: if scene 1's input exists, assume the data is present.
+     check_file = Path(data_dir) / 'train' / 'in' / '1_f22.JPG'
+     if check_file.exists():
+         n = len(list(Path(data_dir).glob('train/in/*_f22.JPG')))
+         print(f"Data already cached: {n} scenes")
+         return data_dir
+
+     print("Fetching metadata to build download list...")
+     snapshot_download(
+         'timseizinger/RealBokeh_3MP',
+         repo_type='dataset',
+         local_dir=data_dir,
+         allow_patterns=['train/metadata/*.json'],
+     )
+
+     meta_dir = Path(data_dir) / 'train' / 'metadata'
+     scene_ids = sorted([p.stem for p in meta_dir.glob('*.json')], key=lambda x: int(x))
+
+     if max_scenes:
+         scene_ids = scene_ids[:max_scenes]
+
+     print(f"Found {len(scene_ids)} scenes. Downloading input + f/2.0 GT images...")
+
+     # Download only the f/22 input and the f/2.0 GT image for each scene.
+     patterns = []
+     for sid in scene_ids:
+         patterns.append(f'train/in/{sid}_f22.JPG')
+         patterns.append(f'train/gt/{sid}/{sid}_f2.0.JPG')
+
+     t0 = time.time()
+     snapshot_download(
+         'timseizinger/RealBokeh_3MP',
+         repo_type='dataset',
+         local_dir=data_dir,
+         allow_patterns=patterns,
+     )
+     elapsed = time.time() - t0
+     n_found = len(list(Path(data_dir).glob('train/in/*_f22.JPG')))
+     print(f"Downloaded {n_found} scenes in {elapsed:.0f}s")
+     return data_dir
+
+
+ # ===================================================================
+ # Training loop
+ # ===================================================================
+
+ def train():
+     import trackio
+
+     VARIANT = os.environ.get('VARIANT', 'small')
+     MAX_SCENES = int(os.environ.get('MAX_SCENES', '0')) or None
+     EPOCHS = int(os.environ.get('EPOCHS', '10'))
+     BATCH_SIZE = int(os.environ.get('BATCH_SIZE', '4'))
+     CROP_SIZE = int(os.environ.get('CROP_SIZE', '256'))
+     LR = float(os.environ.get('LR', '2e-4'))
+     HUB_MODEL_ID = os.environ.get('HUB_MODEL_ID', 'asdf98/BokehFlow')
+
+     device = 'cuda' if torch.cuda.is_available() else 'cpu'
+     print(f"Device: {device}")
+     if device == 'cuda':
+         print(f"GPU: {torch.cuda.get_device_name(0)}")
+         print(f"VRAM: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
+
+     trackio.init(project="bokehflow", name=f"v3-{VARIANT}-e{EPOCHS}-lr{LR}")
+
+     data_dir = download_realbokeh(max_scenes=MAX_SCENES)
+
+     ds = RealBokehDataset(data_dir, crop_size=CROP_SIZE)
+     dl = DataLoader(ds, batch_size=BATCH_SIZE, shuffle=True, num_workers=4,
+                     pin_memory=True, drop_last=True, persistent_workers=True)
+     print(f"Batches per epoch: {len(dl)}")
+
+     config = BokehFlowConfig(variant=VARIANT)
+     model = BokehFlow(config).to(device)
+     n_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
+     print(f"Model: BokehFlow-{VARIANT}, {n_params:,} params")
+
+     optimizer = torch.optim.AdamW(model.parameters(), lr=LR, weight_decay=0.01)
+     total_steps = EPOCHS * len(dl)
+     scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, total_steps, eta_min=LR/20)
+     loss_fn = BokehFlowLoss()
+
+     # Mixed-precision training (a no-op on CPU).
+     scaler = torch.amp.GradScaler('cuda', enabled=(device == 'cuda'))
+
+     global_step = 0
+     best_loss = float('inf')
+
+     for epoch in range(EPOCHS):
+         model.train()
+         epoch_loss = 0.0
+         t_epoch = time.time()
+
+         for batch_idx, batch in enumerate(dl):
+             t_step = time.time()
+             image = batch['image'].to(device)
+             bokeh_gt = batch['bokeh_gt'].to(device)
+             f_number = batch['f_number'].to(device)
+             focal_mm = batch['focal_length_mm'].to(device)
+             focus_m = batch['focus_distance_m'].to(device)
+
+             optimizer.zero_grad(set_to_none=True)
+
+             with torch.amp.autocast('cuda', enabled=(device == 'cuda')):
+                 out = model(image, f_number, focal_mm, focus_m)
+                 losses = loss_fn(out, {'bokeh_gt': bokeh_gt})
+                 loss = losses['total']
+
+             # Unscale before clipping so the norm is computed on true gradients.
+             scaler.scale(loss).backward()
+             scaler.unscale_(optimizer)
+             torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
+             scaler.step(optimizer)
+             scaler.update()
+             scheduler.step()
+
+             epoch_loss += loss.item()
+             global_step += 1
+             step_time = time.time() - t_step
+
+             if global_step % 10 == 0 or batch_idx == 0:
+                 lr_now = scheduler.get_last_lr()[0]
+                 print(f"Ep {epoch+1}/{EPOCHS} [{batch_idx+1}/{len(dl)}] "
+                       f"loss={loss.item():.4f} l1={losses['l1'].item():.4f} "
+                       f"ssim={losses['ssim'].item():.4f} lr={lr_now:.2e} "
+                       f"step={step_time*1000:.0f}ms")
+                 trackio.log({
+                     "loss": loss.item(),
+                     "l1": losses['l1'].item(),
+                     "ssim_loss": losses['ssim'].item(),
+                     "lr": lr_now,
+                     "step_ms": step_time * 1000,
+                     "epoch": epoch + 1,
+                 })
+
+             if device == 'cuda' and global_step == 1:
+                 vram = torch.cuda.max_memory_allocated() / 1e9
+                 print(f"Peak VRAM after 1st step: {vram:.2f} GB")
+                 trackio.log({"peak_vram_gb": vram})
+
+         epoch_time = time.time() - t_epoch
+         avg_loss = epoch_loss / len(dl)
+         print(f"Epoch {epoch+1}/{EPOCHS} done in {epoch_time:.0f}s, avg_loss={avg_loss:.4f}")
+         trackio.log({"epoch_avg_loss": avg_loss, "epoch_time_s": epoch_time})
+
+         if avg_loss < best_loss:
+             best_loss = avg_loss
+             torch.save({
+                 'model_state_dict': model.state_dict(),
+                 'config': config.__dict__,
+                 'epoch': epoch + 1,
+                 'loss': avg_loss,
+             }, '/tmp/bokehflow_best.pt')
+             print(f"  Saved best model (loss={avg_loss:.4f})")
+
+     # Push to Hub
+     print("\nPushing model to Hub...")
+     from huggingface_hub import HfApi
+     api = HfApi()
+     api.create_repo(HUB_MODEL_ID, exist_ok=True)  # make sure the target repo exists
+
+     torch.save({
+         'model_state_dict': model.state_dict(),
+         'config': config.__dict__,
+         'epoch': EPOCHS,
+         'loss': avg_loss,
+     }, '/tmp/bokehflow_final.pt')
+
+     for fname in ['bokehflow_best.pt', 'bokehflow_final.pt']:
+         fpath = f'/tmp/{fname}'
+         if os.path.exists(fpath):
+             api.upload_file(
+                 path_or_fileobj=fpath,
+                 path_in_repo=f'checkpoints/{fname}',
+                 repo_id=HUB_MODEL_ID,
+             )
+             print(f"  Uploaded {fname}")
+
+     print(f"\nTraining complete! Model: https://huggingface.co/{HUB_MODEL_ID}")
+
+
+ if __name__ == '__main__':
+     train()
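+
+
+ # Loading a saved checkpoint for inference (illustrative sketch):
+ #   ckpt = torch.load('/tmp/bokehflow_best.pt', map_location='cpu')
+ #   model = BokehFlow(BokehFlowConfig(**ckpt['config']))
+ #   model.load_state_dict(ckpt['model_state_dict'])
+ #   model.eval()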