Update svd_triton_gram_newton.py
Browse files- svd_triton_gram_newton.py +263 -23
svd_triton_gram_newton.py
CHANGED
|
@@ -250,27 +250,15 @@ def gram_eigh_svd(A):
|
|
| 250 |
Mathematically exact. The Eckart-Young (1936) shortcut.
|
| 251 |
"""
|
| 252 |
B, M, N = A.shape
|
| 253 |
-
|
| 254 |
-
|
| 255 |
-
|
| 256 |
-
|
| 257 |
-
|
| 258 |
-
|
| 259 |
-
|
| 260 |
-
|
| 261 |
-
|
| 262 |
-
eigenvalues = eigenvalues.flip(-1)
|
| 263 |
-
V = V.flip(-1)
|
| 264 |
-
|
| 265 |
-
# Singular values = sqrt of eigenvalues
|
| 266 |
-
S = torch.sqrt(eigenvalues.clamp(min=1e-12)) # (B, N)
|
| 267 |
-
|
| 268 |
-
# Phase 3: U = A @ V @ diag(1/S) = (A @ V) / S
|
| 269 |
-
U = torch.bmm(A_f, V) / S.unsqueeze(1) # (B, M, N)
|
| 270 |
-
|
| 271 |
-
# Vh = V^T
|
| 272 |
-
Vh = V.transpose(-2, -1).contiguous() # (B, N, N)
|
| 273 |
-
|
| 274 |
return U, S, Vh
|
| 275 |
|
| 276 |
|
|
@@ -359,10 +347,219 @@ def newton_schulz_invsqrt(G, iters=10):
|
|
| 359 |
|
| 360 |
|
| 361 |
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 362 |
-
# β
|
| 363 |
-
# β
|
| 364 |
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 365 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 366 |
def projected_svd(A, target_rank=24, oversampling=8):
|
| 367 |
"""Rank-projected thin SVD for (B, M, N) with large N.
|
| 368 |
|
|
@@ -838,6 +1035,49 @@ def run_validation(B=64, M=1024):
|
|
| 838 |
all_pass = all_pass and p3
|
| 839 |
|
| 840 |
print(f"\n {'ALL PASSED' if all_pass else 'SOME FAILURES'}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 841 |
return all_pass
|
| 842 |
|
| 843 |
|
|
|
|
| 250 |
Mathematically exact. The Eckart-Young (1936) shortcut.
|
| 251 |
"""
|
| 252 |
B, M, N = A.shape
|
| 253 |
+
with torch.amp.autocast('cuda', enabled=False):
|
| 254 |
+
A_f = A.float()
|
| 255 |
+
G = torch.bmm(A_f.transpose(1, 2), A_f) # (B, N, N)
|
| 256 |
+
eigenvalues, V = torch.linalg.eigh(G) # (B, N), (B, N, N)
|
| 257 |
+
eigenvalues = eigenvalues.flip(-1)
|
| 258 |
+
V = V.flip(-1)
|
| 259 |
+
S = torch.sqrt(eigenvalues.clamp(min=1e-12)) # (B, N)
|
| 260 |
+
U = torch.bmm(A_f, V) / S.unsqueeze(1) # (B, M, N)
|
| 261 |
+
Vh = V.transpose(-2, -1).contiguous() # (B, N, N)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 262 |
return U, S, Vh
|
| 263 |
|
| 264 |
|
|
|
|
| 347 |
|
| 348 |
|
| 349 |
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 350 |
+
# β BATCHED PROCRUSTES ALIGNMENT β
|
| 351 |
+
# β Subspace-preserving: rotate in k-d, leave orthogonal complement alone β
|
| 352 |
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 353 |
|
| 354 |
+
def batched_procrustes(source, target, rank=24, whiten=True, schulz_iters=10):
    """Batched Procrustes alignment with rank-k subspace-preserving rotation.

    For N <= 32: runs the full N-d orthogonal Procrustes problem.
    For N > 32: projects to a random rank-k subspace, aligns there, and
    lifts back while preserving the orthogonal complement exactly.

    Args:
        source: (B, n_samples, N) or (n_samples, N) -- source embeddings
        target: (B, n_samples, N) or (n_samples, N) -- target embeddings
        rank: Projection rank for large N. Ignored if N <= 32.
        whiten: If True, apply Newton-Schulz whitening before rotation.
        schulz_iters: Iterations for whitening (if enabled).

    Returns:
        aligned: same shape as source -- source aligned to target
        info: dict with rotation matrix and diagnostics
    """
    unbatched = source.ndim == 2
    if unbatched:
        source = source.unsqueeze(0)
        target = target.unsqueeze(0)

    B, n_samples, N = source.shape
    device = source.device
    source_f = source.float()
    target_f = target.float()

    # Center both clouds; the target mean is restored at the end.
    src_mean = source_f.mean(1, keepdim=True)
    tgt_mean = target_f.mean(1, keepdim=True)
    src_c = source_f - src_mean
    tgt_c = target_f - tgt_mean

    # Optional Newton-Schulz whitening (pure bmm; no eigendecomposition).
    tgt_W = None
    if whiten:
        denom = max(n_samples - 1, 1)
        src_cov = torch.bmm(src_c.transpose(1, 2), src_c) / denom
        tgt_cov = torch.bmm(tgt_c.transpose(1, 2), tgt_c) / denom
        src_W = newton_schulz_invsqrt(src_cov, iters=schulz_iters)  # (B, N, N)
        tgt_W = newton_schulz_invsqrt(tgt_cov, iters=schulz_iters)
        src_w = F.normalize(torch.bmm(src_c, src_W), dim=-1)
        tgt_w = F.normalize(torch.bmm(tgt_c, tgt_W), dim=-1)
    else:
        src_w = src_c
        tgt_w = tgt_c

    n_diag = min(1000, n_samples)  # diagnostics are computed on a prefix
    use_projection = N > 32 and rank < N

    if not use_projection:
        # --- Full N-d Procrustes: R = U @ Vh of the cross-covariance ---
        C = torch.bmm(src_w.transpose(1, 2), tgt_w)  # (B, N, N)
        U, _, Vh = torch.linalg.svd(C)
        R = torch.bmm(U, Vh)  # (B, N, N) orthogonal
        aligned_w = torch.bmm(src_w, R)

        info = {
            'method': 'full',
            'N': N, 'rank': N,
            'rotation': R,
            'cos_after': F.cosine_similarity(
                aligned_w[:, :n_diag], tgt_w[:, :n_diag], dim=-1).mean().item(),
        }
    else:
        # --- Subspace-preserving rank-k Procrustes ---
        k = min(rank, N - 1)

        # Orthonormal projection basis via QR of a random matrix.
        P_raw = torch.randn(B, N, k, device=device, dtype=torch.float32)
        P = torch.linalg.qr(P_raw).Q  # (B, N, k) orthonormal columns
        P_T = P.transpose(1, 2)       # (B, k, N)

        # Project both clouds to k-d. (src_proj doubles as the in-subspace
        # coefficients below; the original computed the same bmm twice.)
        src_proj = torch.bmm(src_w, P)  # (B, n_samples, k)
        tgt_proj = torch.bmm(tgt_w, P)  # (B, n_samples, k)

        # Procrustes in k-d (cheap -- k x k SVD).
        C_k = torch.bmm(src_proj.transpose(1, 2), tgt_proj)  # (B, k, k)
        U_k, _, Vh_k = torch.linalg.svd(C_k)
        R_k = torch.bmm(U_k, Vh_k)  # (B, k, k)

        # Rotate only the in-subspace component; the orthogonal complement
        # passes through untouched, so the lift is subspace-preserving.
        src_perp = src_w - torch.bmm(src_proj, P_T)  # (B, n_samples, N)
        src_rotated_k = torch.bmm(src_proj, R_k)     # (B, n_samples, k)
        aligned_w = torch.bmm(src_rotated_k, P_T) + src_perp

        info = {
            'method': 'subspace',
            'N': N, 'rank': k,
            'rotation_k': R_k,
            'projection': P,
            'cos_after': F.cosine_similarity(
                aligned_w[:, :n_diag], tgt_w[:, :n_diag], dim=-1).mean().item(),
            'cos_after_k': F.cosine_similarity(
                src_rotated_k[:, :n_diag], tgt_proj[:, :n_diag],
                dim=-1).mean().item(),
        }

    # Unwhiten back into the target space and restore the target mean.
    # (Shared by both branches; the original duplicated this block.)
    if whiten:
        aligned = torch.bmm(aligned_w, torch.linalg.pinv(tgt_W)) + tgt_mean
    else:
        aligned = aligned_w + tgt_mean

    if unbatched:
        aligned = aligned.squeeze(0)

    return aligned, info
|
| 495 |
+
|
| 496 |
+
|
| 497 |
+
def batched_procrustes_align_pair(source, target, rank=24, whiten=True,
                                  schulz_iters=10, n_align=10000):
    """Convenience wrapper: align source to target using a subset, apply to all.

    Computes the alignment on the first n_align samples, then applies the
    resulting transform to the full source.

    Args:
        source: (n_samples, N) source embeddings
        target: (n_samples, N) target embeddings
        rank: Projection rank for N > 32
        whiten: Apply Newton-Schulz whitening
        schulz_iters: Iterations for whitening (if enabled)
        n_align: Number of samples to compute alignment from

    Returns:
        aligned: (n_samples, N) aligned source
        info: alignment diagnostics
    """
    n = min(n_align, source.shape[0], target.shape[0])

    # Compute the alignment on the subset only.
    _, info = batched_procrustes(
        source[:n].unsqueeze(0), target[:n].unsqueeze(0),
        rank=rank, whiten=whiten, schulz_iters=schulz_iters)

    # Re-derive the subset statistics so the transform can be applied to
    # the full source. Means/covariances come from the alignment subset.
    src_sub = source[:n].float()
    tgt_sub = target[:n].float()
    src_mean = src_sub.mean(0, keepdim=True)
    tgt_mean = tgt_sub.mean(0, keepdim=True)
    src_c = source.float() - src_mean

    # Whitening is shared by both branches (the original duplicated it).
    tgt_unW = None
    if whiten:
        denom = max(n - 1, 1)
        src_cov = (src_sub - src_mean).T @ (src_sub - src_mean) / denom
        tgt_cov = (tgt_sub - tgt_mean).T @ (tgt_sub - tgt_mean) / denom
        # BUGFIX: the original dropped `iters=schulz_iters` here, silently
        # ignoring the caller's setting and always using the default.
        src_W = newton_schulz_invsqrt(
            src_cov.unsqueeze(0), iters=schulz_iters).squeeze(0)
        tgt_W = newton_schulz_invsqrt(
            tgt_cov.unsqueeze(0), iters=schulz_iters).squeeze(0)
        tgt_unW = torch.linalg.pinv(tgt_W)
        src_w = F.normalize(src_c @ src_W, dim=-1)
    else:
        src_w = src_c

    if info['method'] == 'full':
        R = info['rotation'].squeeze(0)  # (N, N)
        rotated = src_w @ R
    else:
        P = info['projection'].squeeze(0)    # (N, k)
        R_k = info['rotation_k'].squeeze(0)  # (k, k)
        src_in = src_w @ P                   # (n_all, k)
        src_perp = src_w - src_in @ P.T      # orthogonal complement, untouched
        rotated = src_in @ R_k @ P.T + src_perp

    if whiten:
        aligned = rotated @ tgt_unW + tgt_mean
    else:
        aligned = rotated + tgt_mean

    return aligned, info
|
| 562 |
+
|
| 563 |
def projected_svd(A, target_rank=24, oversampling=8):
|
| 564 |
"""Rank-projected thin SVD for (B, M, N) with large N.
|
| 565 |
|
|
|
|
| 1035 |
all_pass = all_pass and p3
|
| 1036 |
|
| 1037 |
print(f"\n {'ALL PASSED' if all_pass else 'SOME FAILURES'}")
|
| 1038 |
+
|
| 1039 |
+
# ββ Procrustes alignment validation ββ
|
| 1040 |
+
print(f"\n{'='*70}")
|
| 1041 |
+
print(f" PROCRUSTES ALIGNMENT VALIDATION")
|
| 1042 |
+
print(f"{'='*70}")
|
| 1043 |
+
|
| 1044 |
+
for N in [16, 32, 48, 64, 128]:
|
| 1045 |
+
n_samp = 2000
|
| 1046 |
+
# Create correlated source/target
|
| 1047 |
+
shared = torch.randn(n_samp, N, device='cuda')
|
| 1048 |
+
source = shared + 0.3 * torch.randn(n_samp, N, device='cuda')
|
| 1049 |
+
target = shared + 0.3 * torch.randn(n_samp, N, device='cuda')
|
| 1050 |
+
|
| 1051 |
+
rank = min(24, N - 1)
|
| 1052 |
+
aligned, info = batched_procrustes(
|
| 1053 |
+
source.unsqueeze(0), target.unsqueeze(0),
|
| 1054 |
+
rank=rank, whiten=True)
|
| 1055 |
+
aligned = aligned.squeeze(0)
|
| 1056 |
+
|
| 1057 |
+
cos_before = F.cosine_similarity(source, target, dim=-1).mean().item()
|
| 1058 |
+
cos_after = F.cosine_similarity(aligned, target, dim=-1).mean().item()
|
| 1059 |
+
improved = cos_after > cos_before
|
| 1060 |
+
|
| 1061 |
+
print(f" N={N:>3} rank={rank:>3} method={info['method']:>8}:"
|
| 1062 |
+
f" cos {cos_before:.4f} β {cos_after:.4f}"
|
| 1063 |
+
f" {'IMPROVED' if improved else 'WORSE'}")
|
| 1064 |
+
|
| 1065 |
+
# Test unbatched interface
|
| 1066 |
+
source_ub = torch.randn(1000, 48, device='cuda')
|
| 1067 |
+
target_ub = torch.randn(1000, 48, device='cuda') * 0.5 + source_ub * 0.5
|
| 1068 |
+
aligned_ub, info_ub = batched_procrustes(source_ub, target_ub, rank=24)
|
| 1069 |
+
assert aligned_ub.shape == source_ub.shape, f"Shape mismatch: {aligned_ub.shape} vs {source_ub.shape}"
|
| 1070 |
+
print(f" Unbatched API: shape {aligned_ub.shape} β method={info_ub['method']}")
|
| 1071 |
+
|
| 1072 |
+
# Test batched_procrustes_align_pair
|
| 1073 |
+
aligned_pair, info_pair = batched_procrustes_align_pair(
|
| 1074 |
+
source_ub, target_ub, rank=24, n_align=500)
|
| 1075 |
+
assert aligned_pair.shape == source_ub.shape
|
| 1076 |
+
cos_pair = F.cosine_similarity(aligned_pair, target_ub, dim=-1).mean().item()
|
| 1077 |
+
print(f" Align-pair API: cos={cos_pair:.4f} method={info_pair['method']}")
|
| 1078 |
+
|
| 1079 |
+
print(f" PROCRUSTES VALIDATION COMPLETE")
|
| 1080 |
+
|
| 1081 |
return all_pass
|
| 1082 |
|
| 1083 |
|