Upload 6 files
Browse files
Geolip_Procrustes_Bert_Model_Step_Model_Scaling.ipynb
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
experiment_bulk (1).ipynb
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
experiment_bulk.ipynb
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
experiment_bulk_claude_generated.ipynb
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
qwen35_embedding_explorer.ipynb
ADDED
|
@@ -0,0 +1,631 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"nbformat": 4,
|
| 3 |
+
"nbformat_minor": 0,
|
| 4 |
+
"metadata": {
|
| 5 |
+
"colab": {
|
| 6 |
+
"provenance": [],
|
| 7 |
+
"gpuType": "T4"
|
| 8 |
+
},
|
| 9 |
+
"kernelspec": {
|
| 10 |
+
"name": "python3",
|
| 11 |
+
"display_name": "Python 3"
|
| 12 |
+
},
|
| 13 |
+
"accelerator": "GPU"
|
| 14 |
+
},
|
| 15 |
+
"cells": [
|
| 16 |
+
{
|
| 17 |
+
"cell_type": "markdown",
|
| 18 |
+
"metadata": {},
|
| 19 |
+
"source": [
|
| 20 |
+
"# Qwen3.5-0.8B Embedding Explorer\n",
|
| 21 |
+
"Extract all-layer embeddings, compare prompt similarity, and evaluate potential for diffusion conditioning.\n",
|
| 22 |
+
"\n",
|
| 23 |
+
"**Runtime: GPU (T4 is fine for 0.8B)**"
|
| 24 |
+
]
|
| 25 |
+
},
|
| 26 |
+
{
|
| 27 |
+
"cell_type": "code",
|
| 28 |
+
"metadata": {},
|
| 29 |
+
"source": [
|
| 30 |
+
"# Qwen3.5 requires transformers from git main (not yet in PyPI release)\n",
|
| 31 |
+
"!pip install -q \"transformers @ git+https://github.com/huggingface/transformers.git@main\"\n",
|
| 32 |
+
"!pip install -q accelerate torch matplotlib seaborn numpy scipy"
|
| 33 |
+
],
|
| 34 |
+
"execution_count": null,
|
| 35 |
+
"outputs": []
|
| 36 |
+
},
|
| 37 |
+
{
|
| 38 |
+
"cell_type": "code",
|
| 39 |
+
"metadata": {},
|
| 40 |
+
"source": [
|
| 41 |
+
"import torch\n",
|
| 42 |
+
"import numpy as np\n",
|
| 43 |
+
"import matplotlib.pyplot as plt\n",
|
| 44 |
+
"import seaborn as sns\n",
|
| 45 |
+
"from transformers import AutoTokenizer, AutoModelForCausalLM\n",
|
| 46 |
+
"from scipy.spatial.distance import cosine\n",
|
| 47 |
+
"from typing import Optional\n",
|
| 48 |
+
"import gc\n",
|
| 49 |
+
"\n",
|
| 50 |
+
"device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
|
| 51 |
+
"print(f'Device: {device}')\n",
|
| 52 |
+
"if device.type == 'cuda':\n",
|
| 53 |
+
" print(f'GPU: {torch.cuda.get_device_name()}')\n",
|
| 54 |
+
" print(f'VRAM: {torch.cuda.get_device_properties(0).total_mem / 1e9:.1f} GB')"
|
| 55 |
+
],
|
| 56 |
+
"execution_count": null,
|
| 57 |
+
"outputs": []
|
| 58 |
+
},
|
| 59 |
+
{
|
| 60 |
+
"cell_type": "markdown",
|
| 61 |
+
"metadata": {},
|
| 62 |
+
"source": [
|
| 63 |
+
"## Load Model"
|
| 64 |
+
]
|
| 65 |
+
},
|
| 66 |
+
{
|
| 67 |
+
"cell_type": "code",
|
| 68 |
+
"metadata": {},
|
| 69 |
+
"source": [
|
| 70 |
+
"MODEL_ID = 'Qwen/Qwen3.5-0.8B'\n",
|
| 71 |
+
"\n",
|
| 72 |
+
"tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)\n",
|
| 73 |
+
"model = AutoModelForCausalLM.from_pretrained(\n",
|
| 74 |
+
" MODEL_ID,\n",
|
| 75 |
+
" torch_dtype=torch.bfloat16,\n",
|
| 76 |
+
" device_map='auto',\n",
|
| 77 |
+
" trust_remote_code=True,\n",
|
| 78 |
+
" output_hidden_states=True, # Critical: get all layer outputs\n",
|
| 79 |
+
")\n",
|
| 80 |
+
"model.eval()\n",
|
| 81 |
+
"\n",
|
| 82 |
+
"num_layers = model.config.num_hidden_layers\n",
|
| 83 |
+
"hidden_dim = model.config.hidden_size\n",
|
| 84 |
+
"print(f'Layers: {num_layers}, Hidden dim: {hidden_dim}')\n",
|
| 85 |
+
"print(f'Total hidden states returned: {num_layers + 1} (embedding layer + {num_layers} transformer layers)')"
|
| 86 |
+
],
|
| 87 |
+
"execution_count": null,
|
| 88 |
+
"outputs": []
|
| 89 |
+
},
|
| 90 |
+
{
|
| 91 |
+
"cell_type": "markdown",
|
| 92 |
+
"metadata": {},
|
| 93 |
+
"source": [
|
| 94 |
+
"## Embedding Extraction Engine\n",
|
| 95 |
+
"Extracts hidden states from all layers with multiple pooling strategies."
|
| 96 |
+
]
|
| 97 |
+
},
|
| 98 |
+
{
|
| 99 |
+
"cell_type": "code",
|
| 100 |
+
"metadata": {},
|
| 101 |
+
"source": [
|
| 102 |
+
"class QwenEmbeddingExtractor:\n",
|
| 103 |
+
" \"\"\"Extract and pool hidden states from all layers of Qwen3.5-0.8B.\"\"\"\n",
|
| 104 |
+
"\n",
|
| 105 |
+
" def __init__(self, model, tokenizer, device):\n",
|
| 106 |
+
" self.model = model\n",
|
| 107 |
+
" self.tokenizer = tokenizer\n",
|
| 108 |
+
" self.device = device\n",
|
| 109 |
+
" self.num_layers = model.config.num_hidden_layers + 1 # +1 for embedding layer\n",
|
| 110 |
+
" self.hidden_dim = model.config.hidden_size\n",
|
| 111 |
+
"\n",
|
| 112 |
+
" @torch.no_grad()\n",
|
| 113 |
+
" def extract_hidden_states(self, text: str) -> dict:\n",
|
| 114 |
+
" \"\"\"\n",
|
| 115 |
+
" Run forward pass and return all hidden states + metadata.\n",
|
| 116 |
+
"\n",
|
| 117 |
+
" Returns dict with:\n",
|
| 118 |
+
" - hidden_states: tuple of (num_layers+1) tensors, each [1, seq_len, hidden_dim]\n",
|
| 119 |
+
" - input_ids: token IDs\n",
|
| 120 |
+
" - tokens: decoded token strings\n",
|
| 121 |
+
" - seq_len: number of tokens\n",
|
| 122 |
+
" \"\"\"\n",
|
| 123 |
+
" inputs = self.tokenizer(text, return_tensors='pt').to(self.device)\n",
|
| 124 |
+
" outputs = self.model(**inputs)\n",
|
| 125 |
+
"\n",
|
| 126 |
+
" hidden_states = outputs.hidden_states # tuple of (num_layers+1) tensors\n",
|
| 127 |
+
" input_ids = inputs['input_ids'][0]\n",
|
| 128 |
+
" tokens = [self.tokenizer.decode(tid) for tid in input_ids]\n",
|
| 129 |
+
"\n",
|
| 130 |
+
" return {\n",
|
| 131 |
+
" 'hidden_states': hidden_states,\n",
|
| 132 |
+
" 'input_ids': input_ids,\n",
|
| 133 |
+
" 'tokens': tokens,\n",
|
| 134 |
+
" 'seq_len': len(tokens),\n",
|
| 135 |
+
" }\n",
|
| 136 |
+
"\n",
|
| 137 |
+
" def pool_hidden_states(\n",
|
| 138 |
+
" self,\n",
|
| 139 |
+
" hidden_states: tuple,\n",
|
| 140 |
+
" method: str = 'mean',\n",
|
| 141 |
+
" layer_indices: Optional[list] = None,\n",
|
| 142 |
+
" ) -> torch.Tensor:\n",
|
| 143 |
+
" \"\"\"\n",
|
| 144 |
+
" Pool hidden states across tokens for specified layers.\n",
|
| 145 |
+
"\n",
|
| 146 |
+
" Args:\n",
|
| 147 |
+
" hidden_states: tuple from extract_hidden_states\n",
|
| 148 |
+
" method: 'mean', 'last_token', 'max', or 'all_tokens'\n",
|
| 149 |
+
" layer_indices: which layers to return (None = all)\n",
|
| 150 |
+
"\n",
|
| 151 |
+
" Returns:\n",
|
| 152 |
+
" For 'all_tokens': [num_layers, seq_len, hidden_dim]\n",
|
| 153 |
+
" Otherwise: [num_layers, hidden_dim]\n",
|
| 154 |
+
" \"\"\"\n",
|
| 155 |
+
" if layer_indices is None:\n",
|
| 156 |
+
" layer_indices = list(range(len(hidden_states)))\n",
|
| 157 |
+
"\n",
|
| 158 |
+
" pooled = []\n",
|
| 159 |
+
" for idx in layer_indices:\n",
|
| 160 |
+
" hs = hidden_states[idx].squeeze(0) # [seq_len, hidden_dim]\n",
|
| 161 |
+
"\n",
|
| 162 |
+
" if method == 'mean':\n",
|
| 163 |
+
" pooled.append(hs.mean(dim=0)) # [hidden_dim]\n",
|
| 164 |
+
" elif method == 'last_token':\n",
|
| 165 |
+
" pooled.append(hs[-1]) # [hidden_dim]\n",
|
| 166 |
+
" elif method == 'max':\n",
|
| 167 |
+
" pooled.append(hs.max(dim=0).values) # [hidden_dim]\n",
|
| 168 |
+
" elif method == 'all_tokens':\n",
|
| 169 |
+
" pooled.append(hs) # [seq_len, hidden_dim]\n",
|
| 170 |
+
" else:\n",
|
| 171 |
+
" raise ValueError(f'Unknown pooling method: {method}')\n",
|
| 172 |
+
"\n",
|
| 173 |
+
" return torch.stack(pooled)\n",
|
| 174 |
+
"\n",
|
| 175 |
+
" def extract_and_pool(self, text: str, method: str = 'mean') -> dict:\n",
|
| 176 |
+
" \"\"\"\n",
|
| 177 |
+
" Convenience: extract + pool in one call.\n",
|
| 178 |
+
"\n",
|
| 179 |
+
" Returns dict with:\n",
|
| 180 |
+
" - embeddings: [num_layers, hidden_dim] (or [num_layers, seq_len, hidden_dim] for all_tokens)\n",
|
| 181 |
+
" - tokens: list of token strings\n",
|
| 182 |
+
" - seq_len: int\n",
|
| 183 |
+
" \"\"\"\n",
|
| 184 |
+
" data = self.extract_hidden_states(text)\n",
|
| 185 |
+
" embeddings = self.pool_hidden_states(data['hidden_states'], method=method)\n",
|
| 186 |
+
" return {\n",
|
| 187 |
+
" 'embeddings': embeddings,\n",
|
| 188 |
+
" 'tokens': data['tokens'],\n",
|
| 189 |
+
" 'seq_len': data['seq_len'],\n",
|
| 190 |
+
" }\n",
|
| 191 |
+
"\n",
|
| 192 |
+
"extractor = QwenEmbeddingExtractor(model, tokenizer, device)\n",
|
| 193 |
+
"print(f'Extractor ready. Will return {extractor.num_layers} layer embeddings per prompt.')"
|
| 194 |
+
],
|
| 195 |
+
"execution_count": null,
|
| 196 |
+
"outputs": []
|
| 197 |
+
},
|
| 198 |
+
{
|
| 199 |
+
"cell_type": "markdown",
|
| 200 |
+
"metadata": {},
|
| 201 |
+
"source": [
|
| 202 |
+
"## Define Test Prompts\n",
|
| 203 |
+
"Edit these to whatever you want to compare. Grouped by semantic category to see clustering behavior."
|
| 204 |
+
]
|
| 205 |
+
},
|
| 206 |
+
{
|
| 207 |
+
"cell_type": "code",
|
| 208 |
+
"metadata": {},
|
| 209 |
+
"source": [
|
| 210 |
+
"# ---- EDIT THESE ----\n",
|
| 211 |
+
"# Groups help visualize clustering. Flat list is also fine.\n",
|
| 212 |
+
"PROMPT_GROUPS = {\n",
|
| 213 |
+
" 'photorealistic': [\n",
|
| 214 |
+
" 'a photograph of a cat sitting on a windowsill in golden hour light',\n",
|
| 215 |
+
" 'professional photo of a mountain landscape at sunset with dramatic clouds',\n",
|
| 216 |
+
" 'close-up portrait of an elderly man with weathered skin and blue eyes',\n",
|
| 217 |
+
" ],\n",
|
| 218 |
+
" 'artistic': [\n",
|
| 219 |
+
" 'an oil painting of a stormy sea in the style of Turner',\n",
|
| 220 |
+
" 'watercolor illustration of a quiet Japanese garden with cherry blossoms',\n",
|
| 221 |
+
" 'abstract geometric composition with overlapping translucent shapes',\n",
|
| 222 |
+
" ],\n",
|
| 223 |
+
" 'semantic_shift': [\n",
|
| 224 |
+
" 'a red cube on a blue floor',\n",
|
| 225 |
+
" 'a blue cube on a red floor',\n",
|
| 226 |
+
" 'a green sphere floating above a white plane',\n",
|
| 227 |
+
" ],\n",
|
| 228 |
+
" 'edge_cases': [\n",
|
| 229 |
+
" 'darkness',\n",
|
| 230 |
+
" '', # empty string baseline\n",
|
| 231 |
+
" 'asdfghjkl random noise tokens xyzzy',\n",
|
| 232 |
+
" ],\n",
|
| 233 |
+
"}\n",
|
| 234 |
+
"\n",
|
| 235 |
+
"# Flatten for processing\n",
|
| 236 |
+
"prompts = []\n",
|
| 237 |
+
"prompt_labels = []\n",
|
| 238 |
+
"prompt_groups = []\n",
|
| 239 |
+
"for group_name, group_prompts in PROMPT_GROUPS.items():\n",
|
| 240 |
+
" for p in group_prompts:\n",
|
| 241 |
+
" prompts.append(p)\n",
|
| 242 |
+
" label = p[:50] + '...' if len(p) > 50 else p\n",
|
| 243 |
+
" label = label if label else '<empty>'\n",
|
| 244 |
+
" prompt_labels.append(label)\n",
|
| 245 |
+
" prompt_groups.append(group_name)\n",
|
| 246 |
+
"\n",
|
| 247 |
+
"print(f'{len(prompts)} prompts across {len(PROMPT_GROUPS)} groups')"
|
| 248 |
+
],
|
| 249 |
+
"execution_count": null,
|
| 250 |
+
"outputs": []
|
| 251 |
+
},
|
| 252 |
+
{
|
| 253 |
+
"cell_type": "markdown",
|
| 254 |
+
"metadata": {},
|
| 255 |
+
"source": [
|
| 256 |
+
"## Extract All Embeddings"
|
| 257 |
+
]
|
| 258 |
+
},
|
| 259 |
+
{
|
| 260 |
+
"cell_type": "code",
|
| 261 |
+
"metadata": {},
|
| 262 |
+
"source": [
|
| 263 |
+
"POOL_METHODS = ['mean', 'last_token']\n",
|
| 264 |
+
"\n",
|
| 265 |
+
"# Store results: {method: {prompt_idx: [num_layers, hidden_dim]}}\n",
|
| 266 |
+
"all_embeddings = {method: {} for method in POOL_METHODS}\n",
|
| 267 |
+
"token_counts = {}\n",
|
| 268 |
+
"\n",
|
| 269 |
+
"for i, prompt in enumerate(prompts):\n",
|
| 270 |
+
" print(f'[{i+1}/{len(prompts)}] ({len(prompt)} chars) \"{prompt_labels[i]}\"')\n",
|
| 271 |
+
" for method in POOL_METHODS:\n",
|
| 272 |
+
" result = extractor.extract_and_pool(prompt, method=method)\n",
|
| 273 |
+
" all_embeddings[method][i] = result['embeddings'].float().cpu() # [num_layers, hidden_dim]\n",
|
| 274 |
+
" if method == POOL_METHODS[0]:\n",
|
| 275 |
+
" token_counts[i] = result['seq_len']\n",
|
| 276 |
+
"\n",
|
| 277 |
+
"print(f'\\nDone. Shape per prompt per method: {all_embeddings[\"mean\"][0].shape}')\n",
|
| 278 |
+
"print(f'Token counts: {list(token_counts.values())}')"
|
| 279 |
+
],
|
| 280 |
+
"execution_count": null,
|
| 281 |
+
"outputs": []
|
| 282 |
+
},
|
| 283 |
+
{
|
| 284 |
+
"cell_type": "markdown",
|
| 285 |
+
"metadata": {},
|
| 286 |
+
"source": [
|
| 287 |
+
"## Cosine Similarity Analysis\n",
|
| 288 |
+
"Compute pairwise similarity at every layer, for each pooling method."
|
| 289 |
+
]
|
| 290 |
+
},
|
| 291 |
+
{
|
| 292 |
+
"cell_type": "code",
|
| 293 |
+
"metadata": {},
|
| 294 |
+
"source": [
|
| 295 |
+
"def compute_pairwise_cosine(embeddings_dict, num_prompts, num_layers):\n",
|
| 296 |
+
" \"\"\"\n",
|
| 297 |
+
" Compute cosine similarity between all prompt pairs at each layer.\n",
|
| 298 |
+
"\n",
|
| 299 |
+
" Returns: [num_layers, num_prompts, num_prompts] numpy array\n",
|
| 300 |
+
" \"\"\"\n",
|
| 301 |
+
" sim_matrix = np.zeros((num_layers, num_prompts, num_prompts))\n",
|
| 302 |
+
"\n",
|
| 303 |
+
" for layer_idx in range(num_layers):\n",
|
| 304 |
+
" for i in range(num_prompts):\n",
|
| 305 |
+
" for j in range(num_prompts):\n",
|
| 306 |
+
" if i == j:\n",
|
| 307 |
+
" sim_matrix[layer_idx, i, j] = 1.0\n",
|
| 308 |
+
" elif j > i:\n",
|
| 309 |
+
" vec_i = embeddings_dict[i][layer_idx].numpy()\n",
|
| 310 |
+
" vec_j = embeddings_dict[j][layer_idx].numpy()\n",
|
| 311 |
+
" sim = 1.0 - cosine(vec_i, vec_j)\n",
|
| 312 |
+
" sim_matrix[layer_idx, i, j] = sim\n",
|
| 313 |
+
" sim_matrix[layer_idx, j, i] = sim\n",
|
| 314 |
+
"\n",
|
| 315 |
+
" return sim_matrix\n",
|
| 316 |
+
"\n",
|
| 317 |
+
"n_prompts = len(prompts)\n",
|
| 318 |
+
"n_layers = extractor.num_layers\n",
|
| 319 |
+
"\n",
|
| 320 |
+
"sim_matrices = {}\n",
|
| 321 |
+
"for method in POOL_METHODS:\n",
|
| 322 |
+
" sim_matrices[method] = compute_pairwise_cosine(\n",
|
| 323 |
+
" all_embeddings[method], n_prompts, n_layers\n",
|
| 324 |
+
" )\n",
|
| 325 |
+
" print(f'{method}: similarity matrix shape = {sim_matrices[method].shape}')"
|
| 326 |
+
],
|
| 327 |
+
"execution_count": null,
|
| 328 |
+
"outputs": []
|
| 329 |
+
},
|
| 330 |
+
{
|
| 331 |
+
"cell_type": "markdown",
|
| 332 |
+
"metadata": {},
|
| 333 |
+
"source": [
|
| 334 |
+
"## Heatmaps: Per-Layer Similarity\n",
|
| 335 |
+
"Shows how prompt-pair similarity evolves across layers."
|
| 336 |
+
]
|
| 337 |
+
},
|
| 338 |
+
{
|
| 339 |
+
"cell_type": "code",
|
| 340 |
+
"metadata": {},
|
| 341 |
+
"source": [
|
| 342 |
+
"def plot_similarity_heatmaps(sim_matrix, labels, method_name, layers_to_show=None):\n",
|
| 343 |
+
" \"\"\"\n",
|
| 344 |
+
" Plot similarity heatmaps for selected layers.\n",
|
| 345 |
+
" If layers_to_show is None, picks: first, 25%, 50%, 75%, last.\n",
|
| 346 |
+
" \"\"\"\n",
|
| 347 |
+
" n_layers = sim_matrix.shape[0]\n",
|
| 348 |
+
"\n",
|
| 349 |
+
" if layers_to_show is None:\n",
|
| 350 |
+
" layers_to_show = sorted(set([\n",
|
| 351 |
+
" 0,\n",
|
| 352 |
+
" n_layers // 4,\n",
|
| 353 |
+
" n_layers // 2,\n",
|
| 354 |
+
" 3 * n_layers // 4,\n",
|
| 355 |
+
" n_layers - 2, # penultimate\n",
|
| 356 |
+
" n_layers - 1, # final\n",
|
| 357 |
+
" ]))\n",
|
| 358 |
+
"\n",
|
| 359 |
+
" n_plots = len(layers_to_show)\n",
|
| 360 |
+
" fig, axes = plt.subplots(1, n_plots, figsize=(6 * n_plots, 5))\n",
|
| 361 |
+
" if n_plots == 1:\n",
|
| 362 |
+
" axes = [axes]\n",
|
| 363 |
+
"\n",
|
| 364 |
+
" for ax, layer_idx in zip(axes, layers_to_show):\n",
|
| 365 |
+
" layer_name = 'embed' if layer_idx == 0 else f'L{layer_idx}'\n",
|
| 366 |
+
" sns.heatmap(\n",
|
| 367 |
+
" sim_matrix[layer_idx],\n",
|
| 368 |
+
" xticklabels=labels, yticklabels=labels,\n",
|
| 369 |
+
" vmin=-0.2, vmax=1.0,\n",
|
| 370 |
+
" cmap='RdYlBu_r', annot=True, fmt='.2f',\n",
|
| 371 |
+
" ax=ax, square=True,\n",
|
| 372 |
+
" cbar_kws={'shrink': 0.6},\n",
|
| 373 |
+
" )\n",
|
| 374 |
+
" ax.set_title(f'{method_name} | {layer_name}', fontsize=11)\n",
|
| 375 |
+
" ax.tick_params(axis='x', rotation=45)\n",
|
| 376 |
+
" ax.tick_params(axis='y', rotation=0)\n",
|
| 377 |
+
"\n",
|
| 378 |
+
" plt.tight_layout()\n",
|
| 379 |
+
" plt.show()\n",
|
| 380 |
+
"\n",
|
| 381 |
+
"# Short labels for readability\n",
|
| 382 |
+
"short_labels = [l[:30] for l in prompt_labels]\n",
|
| 383 |
+
"\n",
|
| 384 |
+
"for method in POOL_METHODS:\n",
|
| 385 |
+
" print(f'\\n=== {method.upper()} POOLING ===')\n",
|
| 386 |
+
" plot_similarity_heatmaps(sim_matrices[method], short_labels, method)"
|
| 387 |
+
],
|
| 388 |
+
"execution_count": null,
|
| 389 |
+
"outputs": []
|
| 390 |
+
},
|
| 391 |
+
{
|
| 392 |
+
"cell_type": "markdown",
|
| 393 |
+
"metadata": {},
|
| 394 |
+
"source": [
|
| 395 |
+
"## Layer-wise Discriminability\n",
|
| 396 |
+
"For each layer, compute average within-group similarity vs. between-group similarity.\n",
|
| 397 |
+
"Higher gap = better semantic clustering = more useful for conditioning."
|
| 398 |
+
]
|
| 399 |
+
},
|
| 400 |
+
{
|
| 401 |
+
"cell_type": "code",
|
| 402 |
+
"metadata": {},
|
| 403 |
+
"source": [
|
| 404 |
+
"def compute_discriminability(sim_matrix, group_labels):\n",
|
| 405 |
+
" \"\"\"\n",
|
| 406 |
+
" Per-layer: avg within-group sim, avg between-group sim, and gap.\n",
|
| 407 |
+
" Returns arrays of shape [num_layers].\n",
|
| 408 |
+
" \"\"\"\n",
|
| 409 |
+
" n_layers = sim_matrix.shape[0]\n",
|
| 410 |
+
" n = sim_matrix.shape[1]\n",
|
| 411 |
+
" unique_groups = list(set(group_labels))\n",
|
| 412 |
+
"\n",
|
| 413 |
+
" within_sim = np.zeros(n_layers)\n",
|
| 414 |
+
" between_sim = np.zeros(n_layers)\n",
|
| 415 |
+
"\n",
|
| 416 |
+
" for layer in range(n_layers):\n",
|
| 417 |
+
" w_vals, b_vals = [], []\n",
|
| 418 |
+
" for i in range(n):\n",
|
| 419 |
+
" for j in range(i + 1, n):\n",
|
| 420 |
+
" val = sim_matrix[layer, i, j]\n",
|
| 421 |
+
" if group_labels[i] == group_labels[j]:\n",
|
| 422 |
+
" w_vals.append(val)\n",
|
| 423 |
+
" else:\n",
|
| 424 |
+
" b_vals.append(val)\n",
|
| 425 |
+
" within_sim[layer] = np.mean(w_vals) if w_vals else 0\n",
|
| 426 |
+
" between_sim[layer] = np.mean(b_vals) if b_vals else 0\n",
|
| 427 |
+
"\n",
|
| 428 |
+
" return within_sim, between_sim, within_sim - between_sim\n",
|
| 429 |
+
"\n",
|
| 430 |
+
"\n",
|
| 431 |
+
"fig, axes = plt.subplots(1, len(POOL_METHODS), figsize=(8 * len(POOL_METHODS), 5))\n",
|
| 432 |
+
"if len(POOL_METHODS) == 1:\n",
|
| 433 |
+
" axes = [axes]\n",
|
| 434 |
+
"\n",
|
| 435 |
+
"best_layers = {}\n",
|
| 436 |
+
"for ax, method in zip(axes, POOL_METHODS):\n",
|
| 437 |
+
" within, between, gap = compute_discriminability(sim_matrices[method], prompt_groups)\n",
|
| 438 |
+
"\n",
|
| 439 |
+
" layer_x = np.arange(n_layers)\n",
|
| 440 |
+
" ax.plot(layer_x, within, label='Within-group sim', color='#2196F3', linewidth=2)\n",
|
| 441 |
+
" ax.plot(layer_x, between, label='Between-group sim', color='#FF5722', linewidth=2)\n",
|
| 442 |
+
" ax.fill_between(layer_x, between, within, alpha=0.15, color='green')\n",
|
| 443 |
+
" ax.plot(layer_x, gap, label='Gap (discriminability)', color='green', linewidth=2, linestyle='--')\n",
|
| 444 |
+
"\n",
|
| 445 |
+
" best_layer = np.argmax(gap)\n",
|
| 446 |
+
" best_layers[method] = best_layer\n",
|
| 447 |
+
" ax.axvline(best_layer, color='green', linestyle=':', alpha=0.5)\n",
|
| 448 |
+
" ax.annotate(f'Best: L{best_layer}', xy=(best_layer, gap[best_layer]),\n",
|
| 449 |
+
" xytext=(best_layer + 1, gap[best_layer] + 0.02),\n",
|
| 450 |
+
" arrowprops=dict(arrowstyle='->', color='green'),\n",
|
| 451 |
+
" fontsize=10, color='green')\n",
|
| 452 |
+
"\n",
|
| 453 |
+
" ax.set_xlabel('Layer Index')\n",
|
| 454 |
+
" ax.set_ylabel('Cosine Similarity')\n",
|
| 455 |
+
" ax.set_title(f'{method} pooling — Semantic Discriminability')\n",
|
| 456 |
+
" ax.legend()\n",
|
| 457 |
+
" ax.grid(True, alpha=0.3)\n",
|
| 458 |
+
"\n",
|
| 459 |
+
"plt.tight_layout()\n",
|
| 460 |
+
"plt.show()\n",
|
| 461 |
+
"\n",
|
| 462 |
+
"print('\\nBest discriminability layers:')\n",
|
| 463 |
+
"for method, layer in best_layers.items():\n",
|
| 464 |
+
" print(f' {method}: layer {layer}')"
|
| 465 |
+
],
|
| 466 |
+
"execution_count": null,
|
| 467 |
+
"outputs": []
|
| 468 |
+
},
|
| 469 |
+
{
|
| 470 |
+
"cell_type": "markdown",
|
| 471 |
+
"metadata": {},
|
| 472 |
+
"source": [
|
| 473 |
+
"## Embedding Norm & Variance Across Layers\n",
|
| 474 |
+
"Checks for collapse (all norms converging) or explosion — both bad for conditioning."
|
| 475 |
+
]
|
| 476 |
+
},
|
| 477 |
+
{
|
| 478 |
+
"cell_type": "code",
|
| 479 |
+
"metadata": {},
|
| 480 |
+
"source": [
|
| 481 |
+
"fig, axes = plt.subplots(1, 2, figsize=(14, 5))\n",
|
| 482 |
+
"\n",
|
| 483 |
+
"# Norms per layer per prompt\n",
|
| 484 |
+
"for method, ax in zip(POOL_METHODS, axes):\n",
|
| 485 |
+
" for i in range(n_prompts):\n",
|
| 486 |
+
" norms = all_embeddings[method][i].norm(dim=-1).numpy() # [num_layers]\n",
|
| 487 |
+
" ax.plot(range(n_layers), norms, alpha=0.6, label=short_labels[i][:20])\n",
|
| 488 |
+
"\n",
|
| 489 |
+
" ax.set_xlabel('Layer')\n",
|
| 490 |
+
" ax.set_ylabel('L2 Norm')\n",
|
| 491 |
+
" ax.set_title(f'{method} pooling — Embedding Norms')\n",
|
| 492 |
+
" ax.grid(True, alpha=0.3)\n",
|
| 493 |
+
" ax.legend(fontsize=7, loc='upper left')\n",
|
| 494 |
+
"\n",
|
| 495 |
+
"plt.tight_layout()\n",
|
| 496 |
+
"plt.show()"
|
| 497 |
+
],
|
| 498 |
+
"execution_count": null,
|
| 499 |
+
"outputs": []
|
| 500 |
+
},
|
| 501 |
+
{
|
| 502 |
+
"cell_type": "markdown",
|
| 503 |
+
"metadata": {},
|
| 504 |
+
"source": [
|
| 505 |
+
"## Effective Dimensionality per Layer\n",
|
| 506 |
+
"How many dimensions are actually being used? Low rank = bad for diffusion conditioning diversity."
|
| 507 |
+
]
|
| 508 |
+
},
|
| 509 |
+
{
|
| 510 |
+
"cell_type": "code",
|
| 511 |
+
"metadata": {},
|
| 512 |
+
"source": [
|
| 513 |
+
"def effective_dimensionality(embeddings_list):\n",
|
| 514 |
+
" \"\"\"\n",
|
| 515 |
+
" Compute effective dimensionality via participation ratio of singular values.\n",
|
| 516 |
+
" embeddings_list: list of [hidden_dim] vectors\n",
|
| 517 |
+
" Returns: float (effective rank)\n",
|
| 518 |
+
" \"\"\"\n",
|
| 519 |
+
" mat = torch.stack(embeddings_list) # [n_prompts, hidden_dim]\n",
|
| 520 |
+
" mat = mat - mat.mean(dim=0) # center\n",
|
| 521 |
+
" _, S, _ = torch.svd(mat)\n",
|
| 522 |
+
" S = S / S.sum()\n",
|
| 523 |
+
" participation_ratio = 1.0 / (S ** 2).sum().item()\n",
|
| 524 |
+
" return participation_ratio\n",
|
| 525 |
+
"\n",
|
| 526 |
+
"\n",
|
| 527 |
+
"for method in POOL_METHODS:\n",
|
| 528 |
+
" eff_dims = []\n",
|
| 529 |
+
" for layer_idx in range(n_layers):\n",
|
| 530 |
+
" layer_vecs = [all_embeddings[method][i][layer_idx] for i in range(n_prompts)]\n",
|
| 531 |
+
" ed = effective_dimensionality(layer_vecs)\n",
|
| 532 |
+
" eff_dims.append(ed)\n",
|
| 533 |
+
"\n",
|
| 534 |
+
" plt.plot(range(n_layers), eff_dims, label=method, linewidth=2)\n",
|
| 535 |
+
"\n",
|
| 536 |
+
"plt.xlabel('Layer')\n",
|
| 537 |
+
"plt.ylabel('Effective Dimensionality (participation ratio)')\n",
|
| 538 |
+
"plt.title('Effective Rank of Embedding Space per Layer')\n",
|
| 539 |
+
"plt.legend()\n",
|
| 540 |
+
"plt.grid(True, alpha=0.3)\n",
|
| 541 |
+
"plt.tight_layout()\n",
|
| 542 |
+
"plt.show()"
|
| 543 |
+
],
|
| 544 |
+
"execution_count": null,
|
| 545 |
+
"outputs": []
|
| 546 |
+
},
|
| 547 |
+
{
|
| 548 |
+
"cell_type": "markdown",
|
| 549 |
+
"metadata": {},
|
| 550 |
+
"source": [
|
| 551 |
+
"## Quick Diffusion Conditioning Assessment\n",
|
| 552 |
+
"Summary: which layers look most promising as conditioning vectors?"
|
| 553 |
+
]
|
| 554 |
+
},
|
| 555 |
+
{
|
| 556 |
+
"cell_type": "code",
|
| 557 |
+
"metadata": {},
|
| 558 |
+
"source": [
|
| 559 |
+
"print('=' * 70)\n",
|
| 560 |
+
"print('DIFFUSION CONDITIONING VIABILITY SUMMARY')\n",
|
| 561 |
+
"print('=' * 70)\n",
|
| 562 |
+
"print(f'\\nModel: {MODEL_ID}')\n",
|
| 563 |
+
"print(f'Layers: {n_layers} (0=input embeddings, rest=transformer layers)')\n",
|
| 564 |
+
"print(f'Hidden dim: {hidden_dim}')\n",
|
| 565 |
+
"print(f'Prompts tested: {n_prompts}')\n",
|
| 566 |
+
"print()\n",
|
| 567 |
+
"\n",
|
| 568 |
+
"for method in POOL_METHODS:\n",
|
| 569 |
+
" within, between, gap = compute_discriminability(sim_matrices[method], prompt_groups)\n",
|
| 570 |
+
" best_l = np.argmax(gap)\n",
|
| 571 |
+
" worst_l = np.argmin(gap)\n",
|
| 572 |
+
"\n",
|
| 573 |
+
" # Check for near-collapse: if all pairwise sims > 0.95 at any layer\n",
|
| 574 |
+
" collapse_layers = []\n",
|
| 575 |
+
" for l in range(n_layers):\n",
|
| 576 |
+
" off_diag = sim_matrices[method][l][np.triu_indices(n_prompts, k=1)]\n",
|
| 577 |
+
" if off_diag.min() > 0.95:\n",
|
| 578 |
+
" collapse_layers.append(l)\n",
|
| 579 |
+
"\n",
|
| 580 |
+
" print(f'--- {method.upper()} POOLING ---')\n",
|
| 581 |
+
" print(f' Best discriminability: Layer {best_l} (gap = {gap[best_l]:.4f})')\n",
|
| 582 |
+
" print(f' Worst discriminability: Layer {worst_l} (gap = {gap[worst_l]:.4f})')\n",
|
| 583 |
+
" print(f' Penultimate layer gap: {gap[-2]:.4f}')\n",
|
| 584 |
+
" print(f' Final layer gap: {gap[-1]:.4f}')\n",
|
| 585 |
+
" if collapse_layers:\n",
|
| 586 |
+
" print(f' WARNING: Near-collapse at layers: {collapse_layers}')\n",
|
| 587 |
+
" else:\n",
|
| 588 |
+
" print(f' No collapse detected (all layers have some discrimination)')\n",
|
| 589 |
+
" print()\n",
|
| 590 |
+
"\n",
|
| 591 |
+
"print('RECOMMENDATIONS:')\n",
|
| 592 |
+
"print(' For POOLED conditioning (global vector): Use the best discriminability layer.')\n",
|
| 593 |
+
"print(' For TOKEN-LEVEL conditioning (cross-attention): Re-run with method=\"all_tokens\"')\n",
|
| 594 |
+
"print(' and compare token-level structure against T5/CLIP token outputs.')\n",
|
| 595 |
+
"print(' Watch for: norm explosion in later layers (may need LayerNorm before conditioning).')\n",
|
| 596 |
+
"print(' The penultimate layer often outperforms the final layer (CLIP effect).')"
|
| 597 |
+
],
|
| 598 |
+
"execution_count": null,
|
| 599 |
+
"outputs": []
|
| 600 |
+
},
|
| 601 |
+
{
|
| 602 |
+
"cell_type": "markdown",
|
| 603 |
+
"metadata": {},
|
| 604 |
+
"source": [
|
| 605 |
+
"## (Optional) Export Embeddings for Further Analysis\n",
|
| 606 |
+
"Save to disk for loading into your geometric pipeline."
|
| 607 |
+
]
|
| 608 |
+
},
|
| 609 |
+
{
|
| 610 |
+
"cell_type": "code",
|
| 611 |
+
"metadata": {},
|
| 612 |
+
"source": [
|
| 613 |
+
"# Uncomment to save\n",
|
| 614 |
+
"# export = {\n",
|
| 615 |
+
"# 'model_id': MODEL_ID,\n",
|
| 616 |
+
"# 'prompts': prompts,\n",
|
| 617 |
+
"# 'prompt_groups': prompt_groups,\n",
|
| 618 |
+
"# 'pool_methods': POOL_METHODS,\n",
|
| 619 |
+
"# 'embeddings': {m: {i: all_embeddings[m][i] for i in range(n_prompts)} for m in POOL_METHODS},\n",
|
| 620 |
+
"# 'sim_matrices': sim_matrices,\n",
|
| 621 |
+
"# 'num_layers': n_layers,\n",
|
| 622 |
+
"# 'hidden_dim': hidden_dim,\n",
|
| 623 |
+
"# }\n",
|
| 624 |
+
"# torch.save(export, 'qwen35_0.8b_embeddings.pt')\n",
|
| 625 |
+
"# print('Saved to qwen35_0.8b_embeddings.pt')"
|
| 626 |
+
],
|
| 627 |
+
"execution_count": null,
|
| 628 |
+
"outputs": []
|
| 629 |
+
}
|
| 630 |
+
]
|
| 631 |
+
}
|
qwen35_twoshot_embedding_explorer.ipynb
ADDED
|
@@ -0,0 +1,766 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"nbformat": 4,
|
| 3 |
+
"nbformat_minor": 0,
|
| 4 |
+
"metadata": {
|
| 5 |
+
"colab": {
|
| 6 |
+
"provenance": [],
|
| 7 |
+
"gpuType": "T4"
|
| 8 |
+
},
|
| 9 |
+
"kernelspec": {
|
| 10 |
+
"name": "python3",
|
| 11 |
+
"display_name": "Python 3"
|
| 12 |
+
},
|
| 13 |
+
"accelerator": "GPU"
|
| 14 |
+
},
|
| 15 |
+
"cells": [
|
| 16 |
+
{
|
| 17 |
+
"cell_type": "markdown",
|
| 18 |
+
"metadata": {},
|
| 19 |
+
"source": [
|
| 20 |
+
"# Qwen3.5-0.8B Two-Shot Embedding Explorer\n",
|
| 21 |
+
"Generate descriptions via two-shot prompting, then re-encode the output to extract embeddings with actual semantic diversity.\n",
|
| 22 |
+
"\n",
|
| 23 |
+
"**Runtime: GPU (T4)**"
|
| 24 |
+
]
|
| 25 |
+
},
|
| 26 |
+
{
|
| 27 |
+
"cell_type": "code",
|
| 28 |
+
"metadata": {},
|
| 29 |
+
"source": [
|
| 30 |
+
"# Qwen3.5 requires transformers from git main\n",
|
| 31 |
+
"!pip install -q \"transformers @ git+https://github.com/huggingface/transformers.git@main\"\n",
|
| 32 |
+
"!pip install -q accelerate torch matplotlib seaborn numpy scipy"
|
| 33 |
+
],
|
| 34 |
+
"execution_count": null,
|
| 35 |
+
"outputs": []
|
| 36 |
+
},
|
| 37 |
+
{
|
| 38 |
+
"cell_type": "code",
|
| 39 |
+
"metadata": {},
|
| 40 |
+
"source": [
|
| 41 |
+
"import torch\n",
|
| 42 |
+
"import torch.nn.functional as F\n",
|
| 43 |
+
"import numpy as np\n",
|
| 44 |
+
"import matplotlib.pyplot as plt\n",
|
| 45 |
+
"import seaborn as sns\n",
|
| 46 |
+
"from transformers import AutoTokenizer, AutoModelForCausalLM\n",
|
| 47 |
+
"from scipy.spatial.distance import cosine\n",
|
| 48 |
+
"from typing import Optional\n",
|
| 49 |
+
"import gc\n",
|
| 50 |
+
"\n",
|
| 51 |
+
"device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
|
| 52 |
+
"print(f'Device: {device}')\n",
|
| 53 |
+
"if device.type == 'cuda':\n",
|
| 54 |
+
" print(f'GPU: {torch.cuda.get_device_name()}')\n",
|
| 55 |
+
"    print(f'VRAM: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB')"
|
| 56 |
+
],
|
| 57 |
+
"execution_count": null,
|
| 58 |
+
"outputs": []
|
| 59 |
+
},
|
| 60 |
+
{
|
| 61 |
+
"cell_type": "markdown",
|
| 62 |
+
"metadata": {},
|
| 63 |
+
"source": [
|
| 64 |
+
"## Load Model"
|
| 65 |
+
]
|
| 66 |
+
},
|
| 67 |
+
{
|
| 68 |
+
"cell_type": "code",
|
| 69 |
+
"metadata": {},
|
| 70 |
+
"source": [
|
| 71 |
+
"MODEL_ID = 'Qwen/Qwen3.5-0.8B'\n",
|
| 72 |
+
"\n",
|
| 73 |
+
"tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)\n",
|
| 74 |
+
"model = AutoModelForCausalLM.from_pretrained(\n",
|
| 75 |
+
" MODEL_ID,\n",
|
| 76 |
+
" torch_dtype=torch.bfloat16,\n",
|
| 77 |
+
" device_map='auto',\n",
|
| 78 |
+
" trust_remote_code=True,\n",
|
| 79 |
+
")\n",
|
| 80 |
+
"model.eval()\n",
|
| 81 |
+
"\n",
|
| 82 |
+
"num_layers = model.config.num_hidden_layers\n",
|
| 83 |
+
"hidden_dim = model.config.hidden_size\n",
|
| 84 |
+
"print(f'Layers: {num_layers}, Hidden dim: {hidden_dim}')\n",
|
| 85 |
+
"print(f'Total hidden states: {num_layers + 1}')"
|
| 86 |
+
],
|
| 87 |
+
"execution_count": null,
|
| 88 |
+
"outputs": []
|
| 89 |
+
},
|
| 90 |
+
{
|
| 91 |
+
"cell_type": "markdown",
|
| 92 |
+
"metadata": {},
|
| 93 |
+
"source": [
|
| 94 |
+
"## Two-Shot Generation + Re-Encode Pipeline\n",
|
| 95 |
+
"1. Build a two-shot chat prompt with examples\n",
|
| 96 |
+
"2. Generate a description\n",
|
| 97 |
+
"3. Re-encode the generated text (not the prompt) and extract all hidden states"
|
| 98 |
+
]
|
| 99 |
+
},
|
| 100 |
+
{
|
| 101 |
+
"cell_type": "code",
|
| 102 |
+
"metadata": {},
|
| 103 |
+
"source": [
|
| 104 |
+
"class TwoShotEmbeddingExtractor:\n",
|
| 105 |
+
" \"\"\"\n",
|
| 106 |
+
" Two-shot generation -> re-encode pipeline.\n",
|
| 107 |
+
" \n",
|
| 108 |
+
" Step 1: Chat-template two-shot prompt -> generate description\n",
|
| 109 |
+
" Step 2: Encode the GENERATED text alone -> extract hidden states\n",
|
| 110 |
+
" \n",
|
| 111 |
+
" This produces embeddings of the model's own description,\n",
|
| 112 |
+
" which has far more semantic diversity than raw prompt encoding.\n",
|
| 113 |
+
" \"\"\"\n",
|
| 114 |
+
"\n",
|
| 115 |
+
" def __init__(self, model, tokenizer, device, min_tokens=2):\n",
|
| 116 |
+
" self.model = model\n",
|
| 117 |
+
" self.tokenizer = tokenizer\n",
|
| 118 |
+
" self.device = device\n",
|
| 119 |
+
" self.min_tokens = min_tokens\n",
|
| 120 |
+
" self.num_layers = model.config.num_hidden_layers + 1\n",
|
| 121 |
+
" self.hidden_dim = model.config.hidden_size\n",
|
| 122 |
+
"\n",
|
| 123 |
+
" def build_twoshot_prompt(self, subject: str) -> str:\n",
|
| 124 |
+
" \"\"\"Build two-shot chat prompt with visual description examples.\"\"\"\n",
|
| 125 |
+
" messages = [\n",
|
| 126 |
+
" {\n",
|
| 127 |
+
" 'role': 'system',\n",
|
| 128 |
+
" 'content': 'You describe scenes and subjects in exactly one sentence. '\n",
|
| 129 |
+
" 'Be specific about visual features, lighting, colors, and composition.'\n",
|
| 130 |
+
" },\n",
|
| 131 |
+
" {\n",
|
| 132 |
+
" 'role': 'user',\n",
|
| 133 |
+
" 'content': 'Describe: a car on a highway'\n",
|
| 134 |
+
" },\n",
|
| 135 |
+
" {\n",
|
| 136 |
+
" 'role': 'assistant',\n",
|
| 137 |
+
" 'content': 'A silver sedan cruises along a sunlit four-lane highway '\n",
|
| 138 |
+
" 'cutting through rolling green hills under a pale blue sky with wispy cirrus clouds.'\n",
|
| 139 |
+
" },\n",
|
| 140 |
+
" {\n",
|
| 141 |
+
" 'role': 'user',\n",
|
| 142 |
+
" 'content': 'Describe: a sunflower field'\n",
|
| 143 |
+
" },\n",
|
| 144 |
+
" {\n",
|
| 145 |
+
" 'role': 'assistant',\n",
|
| 146 |
+
" 'content': 'Thousands of tall sunflowers with bright yellow petals and dark brown centers '\n",
|
| 147 |
+
" 'stand in dense rows across a flat field stretching to the horizon at golden hour.'\n",
|
| 148 |
+
" },\n",
|
| 149 |
+
" {\n",
|
| 150 |
+
" 'role': 'user',\n",
|
| 151 |
+
" 'content': f'Describe: {subject}'\n",
|
| 152 |
+
" },\n",
|
| 153 |
+
" ]\n",
|
| 154 |
+
" return self.tokenizer.apply_chat_template(\n",
|
| 155 |
+
" messages, tokenize=False, add_generation_prompt=True\n",
|
| 156 |
+
" )\n",
|
| 157 |
+
"\n",
|
| 158 |
+
" @torch.no_grad()\n",
|
| 159 |
+
" def generate_description(self, subject: str, max_new_tokens=80) -> str:\n",
|
| 160 |
+
" \"\"\"Generate a one-sentence visual description via two-shot.\"\"\"\n",
|
| 161 |
+
" prompt = self.build_twoshot_prompt(subject)\n",
|
| 162 |
+
" inputs = self.tokenizer(prompt, return_tensors='pt').to(self.device)\n",
|
| 163 |
+
"\n",
|
| 164 |
+
" output_ids = self.model.generate(\n",
|
| 165 |
+
" **inputs,\n",
|
| 166 |
+
" max_new_tokens=max_new_tokens,\n",
|
| 167 |
+
" do_sample=True,\n",
|
| 168 |
+
" temperature=0.7,\n",
|
| 169 |
+
" top_p=0.9,\n",
|
| 170 |
+
" pad_token_id=self.tokenizer.eos_token_id,\n",
|
| 171 |
+
" )\n",
|
| 172 |
+
"\n",
|
| 173 |
+
" # Decode only the new tokens\n",
|
| 174 |
+
" new_tokens = output_ids[0][inputs['input_ids'].shape[1]:]\n",
|
| 175 |
+
" description = self.tokenizer.decode(new_tokens, skip_special_tokens=True).strip()\n",
|
| 176 |
+
" return description\n",
|
| 177 |
+
"\n",
|
| 178 |
+
" @torch.no_grad()\n",
|
| 179 |
+
" def encode_text(self, text: str) -> dict:\n",
|
| 180 |
+
" \"\"\"\n",
|
| 181 |
+
" Encode text and return all hidden states.\n",
|
| 182 |
+
" Pads ultra-short inputs to avoid conv1d crash in DeltaNet layers.\n",
|
| 183 |
+
" \"\"\"\n",
|
| 184 |
+
" inputs = self.tokenizer(text, return_tensors='pt').to(self.device)\n",
|
| 185 |
+
" seq_len = inputs['input_ids'].shape[1]\n",
|
| 186 |
+
"\n",
|
| 187 |
+
" if seq_len < self.min_tokens:\n",
|
| 188 |
+
" text = text + ' . .'\n",
|
| 189 |
+
" inputs = self.tokenizer(text, return_tensors='pt').to(self.device)\n",
|
| 190 |
+
" seq_len = inputs['input_ids'].shape[1]\n",
|
| 191 |
+
"\n",
|
| 192 |
+
" outputs = self.model(**inputs, output_hidden_states=True)\n",
|
| 193 |
+
"\n",
|
| 194 |
+
" hidden_states = outputs.hidden_states\n",
|
| 195 |
+
" if hidden_states is None:\n",
|
| 196 |
+
" raise RuntimeError('Model returned None for hidden_states.')\n",
|
| 197 |
+
"\n",
|
| 198 |
+
" input_ids = inputs['input_ids'][0]\n",
|
| 199 |
+
" tokens = [self.tokenizer.decode(tid) for tid in input_ids]\n",
|
| 200 |
+
"\n",
|
| 201 |
+
" return {\n",
|
| 202 |
+
" 'hidden_states': hidden_states,\n",
|
| 203 |
+
" 'input_ids': input_ids,\n",
|
| 204 |
+
" 'tokens': tokens,\n",
|
| 205 |
+
" 'seq_len': len(tokens),\n",
|
| 206 |
+
" }\n",
|
| 207 |
+
"\n",
|
| 208 |
+
" def pool_hidden_states(self, hidden_states, method='mean'):\n",
|
| 209 |
+
" \"\"\"Pool across tokens for all layers. Returns [num_layers, hidden_dim].\"\"\"\n",
|
| 210 |
+
" pooled = []\n",
|
| 211 |
+
" for hs in hidden_states:\n",
|
| 212 |
+
" hs = hs.squeeze(0) # [seq_len, hidden_dim]\n",
|
| 213 |
+
" if method == 'mean':\n",
|
| 214 |
+
" pooled.append(hs.mean(dim=0))\n",
|
| 215 |
+
" elif method == 'last_token':\n",
|
| 216 |
+
" pooled.append(hs[-1])\n",
|
| 217 |
+
" elif method == 'max':\n",
|
| 218 |
+
" pooled.append(hs.max(dim=0).values)\n",
|
| 219 |
+
" else:\n",
|
| 220 |
+
" raise ValueError(f'Unknown method: {method}')\n",
|
| 221 |
+
" return torch.stack(pooled)\n",
|
| 222 |
+
"\n",
|
| 223 |
+
" def generate_and_encode(self, subject: str, method='mean') -> dict:\n",
|
| 224 |
+
" \"\"\"\n",
|
| 225 |
+
" Full pipeline: generate description, then re-encode it.\n",
|
| 226 |
+
" Returns embeddings of the GENERATED text, not the prompt.\n",
|
| 227 |
+
" \"\"\"\n",
|
| 228 |
+
" description = self.generate_description(subject)\n",
|
| 229 |
+
" data = self.encode_text(description)\n",
|
| 230 |
+
" embeddings = self.pool_hidden_states(data['hidden_states'], method=method)\n",
|
| 231 |
+
" return {\n",
|
| 232 |
+
" 'embeddings': embeddings,\n",
|
| 233 |
+
" 'description': description,\n",
|
| 234 |
+
" 'tokens': data['tokens'],\n",
|
| 235 |
+
" 'seq_len': data['seq_len'],\n",
|
| 236 |
+
" }\n",
|
| 237 |
+
"\n",
|
| 238 |
+
" def encode_raw(self, text: str, method='mean') -> dict:\n",
|
| 239 |
+
" \"\"\"\n",
|
| 240 |
+
" Direct encode (no generation). For comparison baseline.\n",
|
| 241 |
+
" \"\"\"\n",
|
| 242 |
+
" data = self.encode_text(text)\n",
|
| 243 |
+
" embeddings = self.pool_hidden_states(data['hidden_states'], method=method)\n",
|
| 244 |
+
" return {\n",
|
| 245 |
+
" 'embeddings': embeddings,\n",
|
| 246 |
+
" 'description': text,\n",
|
| 247 |
+
" 'tokens': data['tokens'],\n",
|
| 248 |
+
" 'seq_len': data['seq_len'],\n",
|
| 249 |
+
" }\n",
|
| 250 |
+
"\n",
|
| 251 |
+
"\n",
|
| 252 |
+
"extractor = TwoShotEmbeddingExtractor(model, tokenizer, device)\n",
|
| 253 |
+
"print(f'Extractor ready. {extractor.num_layers} layers, {extractor.hidden_dim}d')"
|
| 254 |
+
],
|
| 255 |
+
"execution_count": null,
|
| 256 |
+
"outputs": []
|
| 257 |
+
},
|
| 258 |
+
{
|
| 259 |
+
"cell_type": "markdown",
|
| 260 |
+
"metadata": {},
|
| 261 |
+
"source": [
|
| 262 |
+
"## Test Generation\n",
|
| 263 |
+
"Quick sanity check that the two-shot pipeline produces good descriptions."
|
| 264 |
+
]
|
| 265 |
+
},
|
| 266 |
+
{
|
| 267 |
+
"cell_type": "code",
|
| 268 |
+
"metadata": {},
|
| 269 |
+
"source": [
|
| 270 |
+
"test_subjects = [\n",
|
| 271 |
+
" 'a cat on a windowsill',\n",
|
| 272 |
+
" 'a red cube on a blue floor',\n",
|
| 273 |
+
" 'an oil painting of a stormy sea',\n",
|
| 274 |
+
" 'darkness',\n",
|
| 275 |
+
" 'cheese',\n",
|
| 276 |
+
"]\n",
|
| 277 |
+
"\n",
|
| 278 |
+
"print('Two-shot generation test:')\n",
|
| 279 |
+
"print('=' * 70)\n",
|
| 280 |
+
"for subject in test_subjects:\n",
|
| 281 |
+
" desc = extractor.generate_description(subject)\n",
|
| 282 |
+
" print(f'\\nSubject: {subject}')\n",
|
| 283 |
+
" print(f'Generated: {desc}')"
|
| 284 |
+
],
|
| 285 |
+
"execution_count": null,
|
| 286 |
+
"outputs": []
|
| 287 |
+
},
|
| 288 |
+
{
|
| 289 |
+
"cell_type": "markdown",
|
| 290 |
+
"metadata": {},
|
| 291 |
+
"source": [
|
| 292 |
+
"## Define Test Subjects\n",
|
| 293 |
+
"Same groups as before. Each gets a two-shot generated description + raw encode for comparison."
|
| 294 |
+
]
|
| 295 |
+
},
|
| 296 |
+
{
|
| 297 |
+
"cell_type": "code",
|
| 298 |
+
"metadata": {},
|
| 299 |
+
"source": [
|
| 300 |
+
"SUBJECT_GROUPS = {\n",
|
| 301 |
+
" 'photorealistic': [\n",
|
| 302 |
+
" 'a cat sitting on a windowsill in golden hour light',\n",
|
| 303 |
+
" 'a mountain landscape at sunset with dramatic clouds',\n",
|
| 304 |
+
" 'an elderly man with weathered skin and blue eyes',\n",
|
| 305 |
+
" ],\n",
|
| 306 |
+
" 'artistic': [\n",
|
| 307 |
+
" 'an oil painting of a stormy sea',\n",
|
| 308 |
+
" 'a quiet Japanese garden with cherry blossoms',\n",
|
| 309 |
+
" 'abstract geometric shapes overlapping',\n",
|
| 310 |
+
" ],\n",
|
| 311 |
+
" 'semantic_shift': [\n",
|
| 312 |
+
" 'a red cube on a blue floor',\n",
|
| 313 |
+
" 'a blue cube on a red floor',\n",
|
| 314 |
+
" 'a green sphere floating above a white plane',\n",
|
| 315 |
+
" ],\n",
|
| 316 |
+
" 'gibberish': [\n",
|
| 317 |
+
" 'mxkrl vvtonp qazhif bwsdee lpoqnr yttmz',\n",
|
| 318 |
+
" 'florpnax grindleby shovantic wumblecrax tazzifer',\n",
|
| 319 |
+
" 'aaaa bbbb cccc dddd eeee ffff gggg hhhh',\n",
|
| 320 |
+
" ],\n",
|
| 321 |
+
" 'short': [\n",
|
| 322 |
+
" 'taco',\n",
|
| 323 |
+
" '1girl',\n",
|
| 324 |
+
" 'cheese',\n",
|
| 325 |
+
" 'cheddar bacon sub',\n",
|
| 326 |
+
" ],\n",
|
| 327 |
+
"}\n",
|
| 328 |
+
"\n",
|
| 329 |
+
"subjects = []\n",
|
| 330 |
+
"subject_labels = []\n",
|
| 331 |
+
"subject_groups = []\n",
|
| 332 |
+
"for group_name, group_items in SUBJECT_GROUPS.items():\n",
|
| 333 |
+
" for s in group_items:\n",
|
| 334 |
+
" subjects.append(s)\n",
|
| 335 |
+
" label = s[:40] + '...' if len(s) > 40 else s\n",
|
| 336 |
+
" subject_labels.append(label)\n",
|
| 337 |
+
" subject_groups.append(group_name)\n",
|
| 338 |
+
"\n",
|
| 339 |
+
"print(f'{len(subjects)} subjects across {len(SUBJECT_GROUPS)} groups')"
|
| 340 |
+
],
|
| 341 |
+
"execution_count": null,
|
| 342 |
+
"outputs": []
|
| 343 |
+
},
|
| 344 |
+
{
|
| 345 |
+
"cell_type": "markdown",
|
| 346 |
+
"metadata": {},
|
| 347 |
+
"source": [
|
| 348 |
+
"## Generate Descriptions + Extract Embeddings\n",
|
| 349 |
+
"Both two-shot (generated) and raw (direct encode) for comparison."
|
| 350 |
+
]
|
| 351 |
+
},
|
| 352 |
+
{
|
| 353 |
+
"cell_type": "code",
|
| 354 |
+
"metadata": {},
|
| 355 |
+
"source": [
|
| 356 |
+
"POOL_METHODS = ['mean', 'last_token']\n",
|
| 357 |
+
"\n",
|
| 358 |
+
"# Two-shot generated embeddings\n",
|
| 359 |
+
"twoshot_embeddings = {method: {} for method in POOL_METHODS}\n",
|
| 360 |
+
"twoshot_descriptions = {}\n",
|
| 361 |
+
"twoshot_token_counts = {}\n",
|
| 362 |
+
"\n",
|
| 363 |
+
"# Raw direct-encode embeddings (baseline)\n",
|
| 364 |
+
"raw_embeddings = {method: {} for method in POOL_METHODS}\n",
|
| 365 |
+
"raw_token_counts = {}\n",
|
| 366 |
+
"\n",
|
| 367 |
+
"print('=== TWO-SHOT GENERATION + ENCODE ===')\n",
|
| 368 |
+
"print('=' * 70)\n",
|
| 369 |
+
"for i, subject in enumerate(subjects):\n",
|
| 370 |
+
" print(f'\\n[{i+1}/{len(subjects)}] \"{subject_labels[i]}\"')\n",
|
| 371 |
+
" for method in POOL_METHODS:\n",
|
| 372 |
+
" result = extractor.generate_and_encode(subject, method=method)\n",
|
| 373 |
+
" twoshot_embeddings[method][i] = result['embeddings'].float().cpu()\n",
|
| 374 |
+
" if method == POOL_METHODS[0]:\n",
|
| 375 |
+
" twoshot_descriptions[i] = result['description']\n",
|
| 376 |
+
" twoshot_token_counts[i] = result['seq_len']\n",
|
| 377 |
+
" print(f' -> \"{result[\"description\"][:80]}...\"' if len(result['description']) > 80 else f' -> \"{result[\"description\"]}\"')\n",
|
| 378 |
+
"\n",
|
| 379 |
+
"print('\\n\\n=== RAW DIRECT ENCODE (BASELINE) ===')\n",
|
| 380 |
+
"print('=' * 70)\n",
|
| 381 |
+
"for i, subject in enumerate(subjects):\n",
|
| 382 |
+
" print(f'[{i+1}/{len(subjects)}] \"{subject_labels[i]}\"')\n",
|
| 383 |
+
" for method in POOL_METHODS:\n",
|
| 384 |
+
" result = extractor.encode_raw(subject, method=method)\n",
|
| 385 |
+
" raw_embeddings[method][i] = result['embeddings'].float().cpu()\n",
|
| 386 |
+
" if method == POOL_METHODS[0]:\n",
|
| 387 |
+
" raw_token_counts[i] = result['seq_len']\n",
|
| 388 |
+
"\n",
|
| 389 |
+
"n_subjects = len(subjects)\n",
|
| 390 |
+
"n_layers = extractor.num_layers\n",
|
| 391 |
+
"\n",
|
| 392 |
+
"print(f'\\nDone. {n_subjects} subjects, {n_layers} layers, {extractor.hidden_dim}d')\n",
|
| 393 |
+
"print(f'Two-shot token counts: {list(twoshot_token_counts.values())}')\n",
|
| 394 |
+
"print(f'Raw token counts: {list(raw_token_counts.values())}')"
|
| 395 |
+
],
|
| 396 |
+
"execution_count": null,
|
| 397 |
+
"outputs": []
|
| 398 |
+
},
|
| 399 |
+
{
|
| 400 |
+
"cell_type": "markdown",
|
| 401 |
+
"metadata": {},
|
| 402 |
+
"source": [
|
| 403 |
+
"## Cosine Similarity Matrices"
|
| 404 |
+
]
|
| 405 |
+
},
|
| 406 |
+
{
|
| 407 |
+
"cell_type": "code",
|
| 408 |
+
"metadata": {},
|
| 409 |
+
"source": [
|
| 410 |
+
"def compute_pairwise_cosine(embeddings_dict, num_prompts, num_layers):\n",
|
| 411 |
+
" sim_matrix = np.zeros((num_layers, num_prompts, num_prompts))\n",
|
| 412 |
+
" for layer_idx in range(num_layers):\n",
|
| 413 |
+
" for i in range(num_prompts):\n",
|
| 414 |
+
" for j in range(num_prompts):\n",
|
| 415 |
+
" if i == j:\n",
|
| 416 |
+
" sim_matrix[layer_idx, i, j] = 1.0\n",
|
| 417 |
+
" elif j > i:\n",
|
| 418 |
+
" vec_i = embeddings_dict[i][layer_idx].numpy()\n",
|
| 419 |
+
" vec_j = embeddings_dict[j][layer_idx].numpy()\n",
|
| 420 |
+
" sim = 1.0 - cosine(vec_i, vec_j)\n",
|
| 421 |
+
" sim_matrix[layer_idx, i, j] = sim\n",
|
| 422 |
+
" sim_matrix[layer_idx, j, i] = sim\n",
|
| 423 |
+
" return sim_matrix\n",
|
| 424 |
+
"\n",
|
| 425 |
+
"# Compute for both pipelines, both pooling methods\n",
|
| 426 |
+
"twoshot_sim = {}\n",
|
| 427 |
+
"raw_sim = {}\n",
|
| 428 |
+
"for method in POOL_METHODS:\n",
|
| 429 |
+
" twoshot_sim[method] = compute_pairwise_cosine(twoshot_embeddings[method], n_subjects, n_layers)\n",
|
| 430 |
+
" raw_sim[method] = compute_pairwise_cosine(raw_embeddings[method], n_subjects, n_layers)\n",
|
| 431 |
+
" print(f'{method}: twoshot {twoshot_sim[method].shape}, raw {raw_sim[method].shape}')"
|
| 432 |
+
],
|
| 433 |
+
"execution_count": null,
|
| 434 |
+
"outputs": []
|
| 435 |
+
},
|
| 436 |
+
{
|
| 437 |
+
"cell_type": "markdown",
|
| 438 |
+
"metadata": {},
|
| 439 |
+
"source": [
|
| 440 |
+
"## Head-to-Head: Two-Shot vs Raw at Best Layer\n",
|
| 441 |
+
"Side-by-side heatmaps showing how two-shot generation changes the similarity landscape."
|
| 442 |
+
]
|
| 443 |
+
},
|
| 444 |
+
{
|
| 445 |
+
"cell_type": "code",
|
| 446 |
+
"metadata": {},
|
| 447 |
+
"source": [
|
| 448 |
+
"def plot_comparison_heatmaps(twoshot_sim, raw_sim, labels, method, layer_idx):\n",
|
| 449 |
+
" \"\"\"Side-by-side: raw vs two-shot at a specific layer.\"\"\"\n",
|
| 450 |
+
" layer_name = 'embed' if layer_idx == 0 else f'L{layer_idx}'\n",
|
| 451 |
+
" fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(24, 10))\n",
|
| 452 |
+
"\n",
|
| 453 |
+
" sns.heatmap(\n",
|
| 454 |
+
" raw_sim[layer_idx], xticklabels=labels, yticklabels=labels,\n",
|
| 455 |
+
" vmin=-0.2, vmax=1.0, cmap='RdYlBu_r', annot=True, fmt='.2f',\n",
|
| 456 |
+
" ax=ax1, square=True, annot_kws={'size': 6}, cbar_kws={'shrink': 0.6},\n",
|
| 457 |
+
" )\n",
|
| 458 |
+
" ax1.set_title(f'RAW encode | {method} | {layer_name}', fontsize=14)\n",
|
| 459 |
+
" ax1.tick_params(axis='x', rotation=90, labelsize=7)\n",
|
| 460 |
+
" ax1.tick_params(axis='y', rotation=0, labelsize=7)\n",
|
| 461 |
+
"\n",
|
| 462 |
+
" sns.heatmap(\n",
|
| 463 |
+
" twoshot_sim[layer_idx], xticklabels=labels, yticklabels=labels,\n",
|
| 464 |
+
" vmin=-0.2, vmax=1.0, cmap='RdYlBu_r', annot=True, fmt='.2f',\n",
|
| 465 |
+
" ax=ax2, square=True, annot_kws={'size': 6}, cbar_kws={'shrink': 0.6},\n",
|
| 466 |
+
" )\n",
|
| 467 |
+
" ax2.set_title(f'TWO-SHOT encode | {method} | {layer_name}', fontsize=14)\n",
|
| 468 |
+
" ax2.tick_params(axis='x', rotation=90, labelsize=7)\n",
|
| 469 |
+
" ax2.tick_params(axis='y', rotation=0, labelsize=7)\n",
|
| 470 |
+
"\n",
|
| 471 |
+
" plt.tight_layout()\n",
|
| 472 |
+
" plt.show()\n",
|
| 473 |
+
"\n",
|
| 474 |
+
"short_labels = [l[:30] for l in subject_labels]\n",
|
| 475 |
+
"\n",
|
| 476 |
+
"# Show at penultimate and final layer for last_token\n",
|
| 477 |
+
"for layer_idx in [n_layers - 2, n_layers - 1]:\n",
|
| 478 |
+
" plot_comparison_heatmaps(\n",
|
| 479 |
+
" twoshot_sim['last_token'], raw_sim['last_token'],\n",
|
| 480 |
+
" short_labels, 'last_token', layer_idx\n",
|
| 481 |
+
" )"
|
| 482 |
+
],
|
| 483 |
+
"execution_count": null,
|
| 484 |
+
"outputs": []
|
| 485 |
+
},
|
| 486 |
+
{
|
| 487 |
+
"cell_type": "markdown",
|
| 488 |
+
"metadata": {},
|
| 489 |
+
"source": [
|
| 490 |
+
"## Two-Shot Heatmap Grid (All Sampled Layers)"
|
| 491 |
+
]
|
| 492 |
+
},
|
| 493 |
+
{
|
| 494 |
+
"cell_type": "code",
|
| 495 |
+
"metadata": {},
|
| 496 |
+
"source": [
|
| 497 |
+
"def plot_heatmap_grid(sim_matrix, labels, method_name, title_prefix=''):\n",
|
| 498 |
+
" n_layers = sim_matrix.shape[0]\n",
|
| 499 |
+
" layers_to_show = sorted(set([\n",
|
| 500 |
+
" 0, n_layers // 4, n_layers // 2,\n",
|
| 501 |
+
" 3 * n_layers // 4, n_layers - 2, n_layers - 1,\n",
|
| 502 |
+
" ]))\n",
|
| 503 |
+
"\n",
|
| 504 |
+
" fig, axes = plt.subplots(2, 3, figsize=(24, 20))\n",
|
| 505 |
+
" axes = axes.flatten()\n",
|
| 506 |
+
"\n",
|
| 507 |
+
" for idx, (ax, layer_idx) in enumerate(zip(axes, layers_to_show)):\n",
|
| 508 |
+
" layer_name = 'embed' if layer_idx == 0 else f'L{layer_idx}'\n",
|
| 509 |
+
" sns.heatmap(\n",
|
| 510 |
+
" sim_matrix[layer_idx],\n",
|
| 511 |
+
" xticklabels=labels, yticklabels=labels,\n",
|
| 512 |
+
" vmin=-0.2, vmax=1.0, cmap='RdYlBu_r',\n",
|
| 513 |
+
" annot=True, fmt='.2f', ax=ax, square=True,\n",
|
| 514 |
+
" annot_kws={'size': 6}, cbar_kws={'shrink': 0.6},\n",
|
| 515 |
+
" )\n",
|
| 516 |
+
" ax.set_title(f'{title_prefix}{method_name} | {layer_name}', fontsize=13)\n",
|
| 517 |
+
" ax.tick_params(axis='x', rotation=90, labelsize=7)\n",
|
| 518 |
+
" ax.tick_params(axis='y', rotation=0, labelsize=7)\n",
|
| 519 |
+
"\n",
|
| 520 |
+
" for idx in range(len(layers_to_show), len(axes)):\n",
|
| 521 |
+
" axes[idx].set_visible(False)\n",
|
| 522 |
+
"\n",
|
| 523 |
+
" plt.tight_layout()\n",
|
| 524 |
+
" plt.show()\n",
|
| 525 |
+
"\n",
|
| 526 |
+
"for method in POOL_METHODS:\n",
|
| 527 |
+
" print(f'\\n=== TWO-SHOT | {method.upper()} ===')\n",
|
| 528 |
+
" plot_heatmap_grid(twoshot_sim[method], short_labels, method, 'twoshot | ')"
|
| 529 |
+
],
|
| 530 |
+
"execution_count": null,
|
| 531 |
+
"outputs": []
|
| 532 |
+
},
|
| 533 |
+
{
|
| 534 |
+
"cell_type": "markdown",
|
| 535 |
+
"metadata": {},
|
| 536 |
+
"source": [
|
| 537 |
+
"## Discriminability: Two-Shot vs Raw"
|
| 538 |
+
]
|
| 539 |
+
},
|
| 540 |
+
{
|
| 541 |
+
"cell_type": "code",
|
| 542 |
+
"metadata": {},
|
| 543 |
+
"source": [
|
| 544 |
+
"def compute_discriminability(sim_matrix, group_labels):\n",
|
| 545 |
+
" n_layers = sim_matrix.shape[0]\n",
|
| 546 |
+
" n = sim_matrix.shape[1]\n",
|
| 547 |
+
" within_sim = np.zeros(n_layers)\n",
|
| 548 |
+
" between_sim = np.zeros(n_layers)\n",
|
| 549 |
+
"\n",
|
| 550 |
+
" for layer in range(n_layers):\n",
|
| 551 |
+
" w_vals, b_vals = [], []\n",
|
| 552 |
+
" for i in range(n):\n",
|
| 553 |
+
" for j in range(i + 1, n):\n",
|
| 554 |
+
" val = sim_matrix[layer, i, j]\n",
|
| 555 |
+
" if group_labels[i] == group_labels[j]:\n",
|
| 556 |
+
" w_vals.append(val)\n",
|
| 557 |
+
" else:\n",
|
| 558 |
+
" b_vals.append(val)\n",
|
| 559 |
+
" within_sim[layer] = np.mean(w_vals) if w_vals else 0\n",
|
| 560 |
+
" between_sim[layer] = np.mean(b_vals) if b_vals else 0\n",
|
| 561 |
+
"\n",
|
| 562 |
+
" return within_sim, between_sim, within_sim - between_sim\n",
|
| 563 |
+
"\n",
|
| 564 |
+
"\n",
|
| 565 |
+
"fig, axes = plt.subplots(2, 2, figsize=(16, 12))\n",
|
| 566 |
+
"\n",
|
| 567 |
+
"configs = [\n",
|
| 568 |
+
" ('mean', raw_sim, 'RAW | mean'),\n",
|
| 569 |
+
" ('mean', twoshot_sim, 'TWO-SHOT | mean'),\n",
|
| 570 |
+
" ('last_token', raw_sim, 'RAW | last_token'),\n",
|
| 571 |
+
" ('last_token', twoshot_sim, 'TWO-SHOT | last_token'),\n",
|
| 572 |
+
"]\n",
|
| 573 |
+
"\n",
|
| 574 |
+
"for ax, (method, sim_dict, title) in zip(axes.flatten(), configs):\n",
|
| 575 |
+
" within, between, gap = compute_discriminability(sim_dict[method], subject_groups)\n",
|
| 576 |
+
" layer_x = np.arange(n_layers)\n",
|
| 577 |
+
"\n",
|
| 578 |
+
" ax.plot(layer_x, within, label='Within-group', color='#2196F3', linewidth=2)\n",
|
| 579 |
+
" ax.plot(layer_x, between, label='Between-group', color='#FF5722', linewidth=2)\n",
|
| 580 |
+
" ax.fill_between(layer_x, between, within, alpha=0.15, color='green')\n",
|
| 581 |
+
" ax.plot(layer_x, gap, label='Gap', color='green', linewidth=2, linestyle='--')\n",
|
| 582 |
+
"\n",
|
| 583 |
+
" best = np.argmax(gap)\n",
|
| 584 |
+
" ax.axvline(best, color='green', linestyle=':', alpha=0.5)\n",
|
| 585 |
+
" ax.annotate(f'Best: L{best} ({gap[best]:.3f})', xy=(best, gap[best]),\n",
|
| 586 |
+
" xytext=(best + 1, gap[best] + 0.02),\n",
|
| 587 |
+
" arrowprops=dict(arrowstyle='->', color='green'),\n",
|
| 588 |
+
" fontsize=9, color='green')\n",
|
| 589 |
+
"\n",
|
| 590 |
+
" ax.set_xlabel('Layer')\n",
|
| 591 |
+
" ax.set_ylabel('Cosine Similarity')\n",
|
| 592 |
+
" ax.set_title(title, fontsize=13)\n",
|
| 593 |
+
" ax.legend(fontsize=8)\n",
|
| 594 |
+
" ax.grid(True, alpha=0.3)\n",
|
| 595 |
+
"\n",
|
| 596 |
+
"plt.tight_layout()\n",
|
| 597 |
+
"plt.show()"
|
| 598 |
+
],
|
| 599 |
+
"execution_count": null,
|
| 600 |
+
"outputs": []
|
| 601 |
+
},
|
| 602 |
+
{
|
| 603 |
+
"cell_type": "markdown",
|
| 604 |
+
"metadata": {},
|
| 605 |
+
"source": [
|
| 606 |
+
"## Similarity Statistics: Two-Shot vs Raw"
|
| 607 |
+
]
|
| 608 |
+
},
|
| 609 |
+
{
|
| 610 |
+
"cell_type": "code",
|
| 611 |
+
"metadata": {},
|
| 612 |
+
"source": [
|
| 613 |
+
"print('=' * 70)\n",
|
| 614 |
+
"print('SIMILARITY STATISTICS COMPARISON')\n",
|
| 615 |
+
"print('=' * 70)\n",
|
| 616 |
+
"\n",
|
| 617 |
+
"for method in POOL_METHODS:\n",
|
| 618 |
+
" print(f'\\n--- {method.upper()} ---')\n",
|
| 619 |
+
" for label, sim_dict in [('RAW', raw_sim), ('TWO-SHOT', twoshot_sim)]:\n",
|
| 620 |
+
" # Use penultimate layer\n",
|
| 621 |
+
" layer = n_layers - 2\n",
|
| 622 |
+
" mat = sim_dict[method][layer]\n",
|
| 623 |
+
" off_diag = mat[np.triu_indices(n_subjects, k=1)]\n",
|
| 624 |
+
"\n",
|
| 625 |
+
" print(f' {label} (L{layer}):')\n",
|
| 626 |
+
" print(f' Mean sim: {off_diag.mean():.4f}')\n",
|
| 627 |
+
" print(f' Std sim: {off_diag.std():.4f}')\n",
|
| 628 |
+
" print(f' Min sim: {off_diag.min():.4f}')\n",
|
| 629 |
+
" print(f' Max sim: {off_diag.max():.4f}')\n",
|
| 630 |
+
" print(f' Near-zero (<0.05): {(off_diag < 0.05).sum()}')\n",
|
| 631 |
+
" print(f' High (>0.9): {(off_diag > 0.9).sum()}')"
|
| 632 |
+
],
|
| 633 |
+
"execution_count": null,
|
| 634 |
+
"outputs": []
|
| 635 |
+
},
|
| 636 |
+
{
|
| 637 |
+
"cell_type": "markdown",
|
| 638 |
+
"metadata": {},
|
| 639 |
+
"source": [
|
| 640 |
+
"## Norms & Effective Dimensionality (Two-Shot)"
|
| 641 |
+
]
|
| 642 |
+
},
|
| 643 |
+
{
|
| 644 |
+
"cell_type": "code",
|
| 645 |
+
"metadata": {},
|
| 646 |
+
"source": [
|
| 647 |
+
"fig, axes = plt.subplots(1, 2, figsize=(14, 5))\n",
|
| 648 |
+
"\n",
|
| 649 |
+
"for method, ax in zip(POOL_METHODS, axes):\n",
|
| 650 |
+
" for i in range(n_subjects):\n",
|
| 651 |
+
" norms = twoshot_embeddings[method][i].norm(dim=-1).numpy()\n",
|
| 652 |
+
" ax.plot(range(n_layers), norms, alpha=0.6, label=short_labels[i][:20])\n",
|
| 653 |
+
" ax.set_xlabel('Layer')\n",
|
| 654 |
+
" ax.set_ylabel('L2 Norm')\n",
|
| 655 |
+
" ax.set_title(f'TWO-SHOT | {method} | Embedding Norms')\n",
|
| 656 |
+
" ax.grid(True, alpha=0.3)\n",
|
| 657 |
+
" ax.legend(fontsize=6, loc='upper left')\n",
|
| 658 |
+
"\n",
|
| 659 |
+
"plt.tight_layout()\n",
|
| 660 |
+
"plt.show()"
|
| 661 |
+
],
|
| 662 |
+
"execution_count": null,
|
| 663 |
+
"outputs": []
|
| 664 |
+
},
|
| 665 |
+
{
|
| 666 |
+
"cell_type": "code",
|
| 667 |
+
"metadata": {},
|
| 668 |
+
"source": [
|
| 669 |
+
"def effective_dimensionality(embeddings_list):\n",
|
| 670 |
+
" mat = torch.stack(embeddings_list)\n",
|
| 671 |
+
" mat = mat - mat.mean(dim=0)\n",
|
| 672 |
+
" _, S, _ = torch.svd(mat)\n",
|
| 673 |
+
" S = S / S.sum()\n",
|
| 674 |
+
" return 1.0 / (S ** 2).sum().item()\n",
|
| 675 |
+
"\n",
|
| 676 |
+
"fig, ax = plt.subplots(figsize=(10, 5))\n",
|
| 677 |
+
"\n",
|
| 678 |
+
"for label, emb_dict, ls in [('raw', raw_embeddings, '--'), ('twoshot', twoshot_embeddings, '-')]:\n",
|
| 679 |
+
" for method, color in zip(POOL_METHODS, ['#2196F3', '#FF5722']):\n",
|
| 680 |
+
" eff_dims = []\n",
|
| 681 |
+
" for layer_idx in range(n_layers):\n",
|
| 682 |
+
" vecs = [emb_dict[method][i][layer_idx] for i in range(n_subjects)]\n",
|
| 683 |
+
" eff_dims.append(effective_dimensionality(vecs))\n",
|
| 684 |
+
" ax.plot(range(n_layers), eff_dims, label=f'{label} | {method}',\n",
|
| 685 |
+
" linewidth=2, linestyle=ls, color=color)\n",
|
| 686 |
+
"\n",
|
| 687 |
+
"ax.set_xlabel('Layer')\n",
|
| 688 |
+
"ax.set_ylabel('Effective Dimensionality')\n",
|
| 689 |
+
"ax.set_title('Effective Rank: Raw (dashed) vs Two-Shot (solid)')\n",
|
| 690 |
+
"ax.legend()\n",
|
| 691 |
+
"ax.grid(True, alpha=0.3)\n",
|
| 692 |
+
"plt.tight_layout()\n",
|
| 693 |
+
"plt.show()"
|
| 694 |
+
],
|
| 695 |
+
"execution_count": null,
|
| 696 |
+
"outputs": []
|
| 697 |
+
},
|
| 698 |
+
{
|
| 699 |
+
"cell_type": "markdown",
|
| 700 |
+
"metadata": {},
|
| 701 |
+
"source": [
|
| 702 |
+
"## Summary"
|
| 703 |
+
]
|
| 704 |
+
},
|
| 705 |
+
{
|
| 706 |
+
"cell_type": "code",
|
| 707 |
+
"metadata": {},
|
| 708 |
+
"source": [
|
| 709 |
+
"print('=' * 70)\n",
|
| 710 |
+
"print('TWO-SHOT vs RAW EMBEDDING SUMMARY')\n",
|
| 711 |
+
"print('=' * 70)\n",
|
| 712 |
+
"print(f'Model: {MODEL_ID}')\n",
|
| 713 |
+
"print(f'Layers: {n_layers}, Hidden dim: {extractor.hidden_dim}')\n",
|
| 714 |
+
"print(f'Subjects: {n_subjects}')\n",
|
| 715 |
+
"print()\n",
|
| 716 |
+
"\n",
|
| 717 |
+
"for method in POOL_METHODS:\n",
|
| 718 |
+
" print(f'--- {method.upper()} ---')\n",
|
| 719 |
+
" for label, sim_dict in [('RAW', raw_sim), ('TWO-SHOT', twoshot_sim)]:\n",
|
| 720 |
+
" within, between, gap = compute_discriminability(sim_dict[method], subject_groups)\n",
|
| 721 |
+
" best_l = np.argmax(gap)\n",
|
| 722 |
+
" print(f' {label}:')\n",
|
| 723 |
+
" print(f' Best layer: L{best_l} (gap = {gap[best_l]:.4f})')\n",
|
| 724 |
+
" print(f' Final layer gap: {gap[-1]:.4f}')\n",
|
| 725 |
+
" print()\n",
|
| 726 |
+
"\n",
|
| 727 |
+
"print('\\nGENERATED DESCRIPTIONS:')\n",
|
| 728 |
+
"for i, subject in enumerate(subjects):\n",
|
| 729 |
+
" print(f' [{subject_labels[i]}]')\n",
|
| 730 |
+
" print(f' -> {twoshot_descriptions[i]}')"
|
| 731 |
+
],
|
| 732 |
+
"execution_count": null,
|
| 733 |
+
"outputs": []
|
| 734 |
+
},
|
| 735 |
+
{
|
| 736 |
+
"cell_type": "markdown",
|
| 737 |
+
"metadata": {},
|
| 738 |
+
"source": [
|
| 739 |
+
"## (Optional) Export"
|
| 740 |
+
]
|
| 741 |
+
},
|
| 742 |
+
{
|
| 743 |
+
"cell_type": "code",
|
| 744 |
+
"metadata": {},
|
| 745 |
+
"source": [
|
| 746 |
+
"# Uncomment to save\n",
|
| 747 |
+
"# export = {\n",
|
| 748 |
+
"# 'model_id': MODEL_ID,\n",
|
| 749 |
+
"# 'subjects': subjects,\n",
|
| 750 |
+
"# 'subject_groups': subject_groups,\n",
|
| 751 |
+
"# 'twoshot_descriptions': twoshot_descriptions,\n",
|
| 752 |
+
"# 'twoshot_embeddings': twoshot_embeddings,\n",
|
| 753 |
+
"# 'raw_embeddings': raw_embeddings,\n",
|
| 754 |
+
"# 'twoshot_sim': twoshot_sim,\n",
|
| 755 |
+
"# 'raw_sim': raw_sim,\n",
|
| 756 |
+
"# 'num_layers': n_layers,\n",
|
| 757 |
+
"# 'hidden_dim': extractor.hidden_dim,\n",
|
| 758 |
+
"# }\n",
|
| 759 |
+
"# torch.save(export, 'qwen35_twoshot_embeddings.pt')\n",
|
| 760 |
+
"# print('Saved.')"
|
| 761 |
+
],
|
| 762 |
+
"execution_count": null,
|
| 763 |
+
"outputs": []
|
| 764 |
+
}
|
| 765 |
+
]
|
| 766 |
+
}
|