gezi2333 committed on
Commit 3589275 · verified · 1 Parent(s): e40b7dd

Upload folder using huggingface_hub

This view is limited to 50 files because it contains too many changes. See raw diff

Files changed (50)
  1. .gitattributes +2 -0
  2. README.md +207 -0
  3. assets/WVO-WD-TFS.png +3 -0
  4. assets/orthoreg_loss.png +3 -0
  5. environment.yml +140 -0
  6. src/__init__.py +0 -0
  7. src/__pycache__/__init__.cpython-310.pyc +0 -0
  8. src/__pycache__/args.cpython-310.pyc +0 -0
  9. src/__pycache__/attention_only_finetune.cpython-310.pyc +0 -0
  10. src/__pycache__/distributed.cpython-310.pyc +0 -0
  11. src/__pycache__/eval.cpython-310.pyc +0 -0
  12. src/__pycache__/heads.cpython-310.pyc +0 -0
  13. src/__pycache__/linearize.cpython-310.pyc +0 -0
  14. src/__pycache__/modeling.cpython-310.pyc +0 -0
  15. src/__pycache__/task_vectors.cpython-310.pyc +0 -0
  16. src/__pycache__/utils.cpython-310.pyc +0 -0
  17. src/args.py +153 -0
  18. src/attention_only_finetune.py +116 -0
  19. src/datasets/__pycache__/cars.cpython-310.pyc +0 -0
  20. src/datasets/__pycache__/cifar10.cpython-310.pyc +0 -0
  21. src/datasets/__pycache__/cifar100.cpython-310.pyc +0 -0
  22. src/datasets/__pycache__/common.cpython-310.pyc +0 -0
  23. src/datasets/__pycache__/dtd.cpython-310.pyc +0 -0
  24. src/datasets/__pycache__/emnist.cpython-310.pyc +0 -0
  25. src/datasets/__pycache__/eurosat.cpython-310.pyc +0 -0
  26. src/datasets/__pycache__/gtsrb.cpython-310.pyc +0 -0
  27. src/datasets/__pycache__/imagenet.cpython-310.pyc +0 -0
  28. src/datasets/__pycache__/kmnist.cpython-310.pyc +0 -0
  29. src/datasets/__pycache__/mnist.cpython-310.pyc +0 -0
  30. src/datasets/__pycache__/oxfordpets.cpython-310.pyc +0 -0
  31. src/datasets/__pycache__/registry.cpython-310.pyc +0 -0
  32. src/datasets/__pycache__/resisc45.cpython-310.pyc +0 -0
  33. src/datasets/__pycache__/stl10.cpython-310.pyc +0 -0
  34. src/datasets/__pycache__/sun397.cpython-310.pyc +0 -0
  35. src/datasets/__pycache__/svhn.cpython-310.pyc +0 -0
  36. src/datasets/__pycache__/templates.cpython-310.pyc +0 -0
  37. src/datasets/cars.py +155 -0
  38. src/datasets/cifar10.py +56 -0
  39. src/datasets/cifar100.py +30 -0
  40. src/datasets/common.py +139 -0
  41. src/datasets/dtd.py +34 -0
  42. src/datasets/emnist.py +74 -0
  43. src/datasets/eurosat.py +75 -0
  44. src/datasets/gtsrb.py +205 -0
  45. src/datasets/imagenet.py +253 -0
  46. src/datasets/kmnist.py +39 -0
  47. src/datasets/mnist.py +41 -0
  48. src/datasets/oxfordpets.py +38 -0
  49. src/datasets/registry.py +103 -0
  50. src/datasets/resisc45.py +304 -0
.gitattributes CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ assets/WVO-WD-TFS.png filter=lfs diff=lfs merge=lfs -text
+ assets/orthoreg_loss.png filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,207 @@
+ # Understanding and Enforcing Weight Disentanglement in Task Arithmetic
+
+ [CVPR 2026] Official code of the paper **"Understanding and Enforcing Weight Disentanglement in Task Arithmetic"**.
+
+ [[Paper](https://arxiv.org/abs/2604.17078)]   [[Checkpoints](#-checkpoints)]   [[Datasets](#-datasets)]
+
+ ---
+
+ ## 🎯 Abstract
+
+ Task arithmetic provides an efficient, training-free way to edit pre-trained models, yet it lacks a fundamental theoretical explanation for its success. The existing concept of "weight disentanglement" describes the ideal outcome of non-interfering task composition but does not reveal its underlying cause. Crucially, which intrinsic properties of the pre-trained model ($\theta_0$) or the task vectors ($\tau_t$) enable this disentanglement remains underexplored. In this paper, we introduce Task-Feature Specialization (TFS), a model's ability to allocate distinct internal features to different tasks, as the fundamental principle. We first prove that TFS is a sufficient condition for weight disentanglement. More importantly, we find that TFS also gives rise to an observable geometric consequence: weight vector orthogonality. This positions TFS as the common cause of both the desired functional outcome (disentanglement) and a measurable geometric property (orthogonality). This relationship provides the key insight behind our method: since the abstract TFS property is intractable to enforce directly, we can instead promote weight disentanglement by shaping its concrete geometric consequence, orthogonality. We therefore propose OrthoReg, a simple and effective regularization method that actively enforces an internal orthogonal structure on the weight updates ($\Delta W$) that constitute $\tau_t$ during fine-tuning, and we prove theoretically that OrthoReg promotes disentanglement. Extensive experiments demonstrate that OrthoReg consistently and significantly enhances the performance of various task arithmetic methods.
+
+ <p align="center">
+ <img src="assets/WVO-WD-TFS.png" width="500"/>
+ <br>
+ <em>TFS is the common cause connecting Weight Vector Orthogonality (WVO) with Weight Disentanglement (WD).</em>
+ </p>
+
+ ### ✨ Key Contributions
+
+ - 📐 **Theory**: We identify TFS as a sufficient condition for weight disentanglement, and WVO as its geometric consequence, providing the first principled explanation for task arithmetic.
+ - 🔧 **Method (OrthoReg)**: A simple regularization term added to the fine-tuning loss that enforces column-wise orthogonality on ΔW, for which we prove theoretical efficacy.
+ - 🔗 **Connection to TTA**: We show that OrthoReg and Tangent Task Arithmetic (TTA) share the same underlying mechanism (i.e., inter-task vector orthogonality), but OrthoReg achieves this more efficiently.
+ - 📊 **Experiments**: Consistent and significant improvements over Non-linear FT, TTA, ATT-FT, and LoRA-ATT across ViT-B-32, ViT-B-16, and ViT-L-14.
+
+ ---
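The WVO property above can be probed directly on a pair of checkpoints. The following is a minimal sketch (assuming plain PyTorch; the helper names are ours, not part of this repository): flatten each task vector $\tau_t = \theta_t - \theta_0$ and measure the cosine similarity between them, which should be near zero when WVO holds.

```python
import torch

def task_vector(finetuned_sd, pretrained_sd):
    """Flatten the weight delta (task vector) between two state dicts into one 1-D tensor."""
    return torch.cat([
        (finetuned_sd[k] - pretrained_sd[k]).flatten()
        for k in sorted(pretrained_sd)
        if k in finetuned_sd and finetuned_sd[k].dtype.is_floating_point
    ])

def wvo_cosine(tau_a: torch.Tensor, tau_b: torch.Tensor) -> float:
    """Cosine similarity between two flattened task vectors; values near 0 indicate WVO."""
    return torch.nn.functional.cosine_similarity(tau_a, tau_b, dim=0).item()
```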
+ ### The OrthoReg Loss
+
+ <p align="center">
+ <img src="assets/orthoreg_loss.png" width="560"/>
+ </p>
+
+ The total loss adds a regularization term to the standard task objective:
+
+ $$\mathcal{L} = \mathcal{L}_{\text{task}}(\theta_0 + \Delta\theta) + \lambda \cdot \mathcal{L}_{\text{ortho}}(\Delta\theta)$$
+
+ $$\mathcal{L}_{\text{ortho}}(\Delta\theta) = \sum_l \left\|(\Delta W^{(l)})^\top \Delta W^{(l)} - I\right\|_F^2$$
+
+ ---
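As a minimal PyTorch sketch of this penalty (function names are ours; the repository's actual implementation lives in the fine-tuning code and may differ in details, e.g. it may omit the square on the norm), the per-layer term and its sum over a state dict can be written as:

```python
import torch

def ortho_penalty(delta_w: torch.Tensor) -> torch.Tensor:
    """Squared Frobenius distance between the Gram matrix of delta_w and the identity.

    Uses the smaller of the two Gram matrices (delta_w @ delta_w.T vs.
    delta_w.T @ delta_w) so the identity target is well-sized for non-square layers.
    """
    rows, cols = delta_w.shape
    if rows < cols:
        gram = delta_w @ delta_w.T
    else:
        gram = delta_w.T @ delta_w
    eye = torch.eye(gram.shape[0], device=delta_w.device, dtype=delta_w.dtype)
    return torch.linalg.norm(gram - eye, ord="fro") ** 2

def orthoreg_loss(finetuned_sd, pretrained_sd) -> torch.Tensor:
    """Sum the penalty over all 2-D weight deltas shared by the two state dicts."""
    total = torch.zeros(())
    for name, w in finetuned_sd.items():
        if w.dim() == 2 and name in pretrained_sd:
            total = total + ortho_penalty(w - pretrained_sd[name])
    return total
```

An orthonormal-column ΔW drives the penalty to zero, which is the structure OrthoReg rewards during fine-tuning.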
+ ## 🛠️ Installation
+
+ This codebase is built on top of [Tangent Task Arithmetic (TTA)](https://github.com/gortizji/tangent_task_arithmetic). Environment setup follows theirs exactly.
+
+ To run the code, please install all its dependencies:
+ ```sh
+ conda env create
+ conda activate tangent-arithmetic
+ ```
+ and add the `src` directory to the `PYTHONPATH`:
+ ```sh
+ cd OrthoReg
+ export PYTHONPATH="$PYTHONPATH:$PWD"
+ ```
+
+ ---
+
+ ## 📦 Datasets
+
+ We evaluate on 8 image classification benchmarks following [Task Arithmetic](https://github.com/mlfoundations/task_vectors) and [TTA](https://github.com/gortizji/tangent_task_arithmetic):
+
+ **Cars · DTD · EuroSAT · GTSRB · MNIST · RESISC45 · SUN397 · SVHN**
+
+ For dataset download and preparation, please follow the instructions in the [TTA repository](https://github.com/gortizji/tangent_task_arithmetic#datasets).
+
+ We also provide a pre-packaged dataset archive for convenience:
+
+ > 📥 **Dataset Download:** `https://pan.baidu.com/s/1PgLyjUrAhsmgSAz4ms5mcQ?pwd=fwf5`
+
+ Set the root path via `--data-location /path/to/datasets/`.
+
+ ---
+
+ ## 🚀 Quick Start
+
+ All scripts are run from the `OrthoReg/` directory. This repository implements **6 finetuning modes**:
+
+ | `--finetuning-mode` | Description |
+ |---|---|
+ | `standard` | Non-linear full fine-tuning (baseline) |
+ | `standard_ortho` | Non-linear FT + OrthoReg |
+ | `linear` | TTA — tangent space fine-tuning (baseline) |
+ | `linear_ortho` | TTA + OrthoReg |
+ | `linear-2` | ATT-FT — attention-only fine-tuning (baseline) |
+ | `linear-2_ortho` | ATT-FT + OrthoReg |
+
+ > **Note on LoRA-ATT:** The LoRA-ATT and LoRA-ATT+OrthoReg results from the paper are implemented in a separate repository due to the complexity of patching OpenCLIP's fused QKV projection. Code will be released at: `https://github.com/lshangge/OrthoReg_lora`
+
+ ### Step 1 — Fine-tune
+
+ ```bash
+ python src/finetune.py \
+     --model ViT-B-32 \
+     --finetuning-mode standard_ortho \
+     --ortho-lambda 10 \
+     --lr 1e-5 \
+     --data-location /path/to/datasets/
+ ```
+
+ Switch between all six modes by changing `--finetuning-mode` and `--ortho-lambda`:
+
+ ```bash
+ --finetuning-mode standard        --ortho-lambda 0    # Non-linear FT
+ --finetuning-mode standard_ortho  --ortho-lambda xx   # Non-linear FT + OrthoReg
+ --finetuning-mode linear          --ortho-lambda 0    # TTA
+ --finetuning-mode linear_ortho    --ortho-lambda xx   # TTA + OrthoReg
+ --finetuning-mode linear-2        --ortho-lambda 0    # ATT-FT
+ --finetuning-mode linear-2_ortho  --ortho-lambda xx   # ATT-FT + OrthoReg
+ ```
+
+ Checkpoints are saved to:
+ - `checkpoints_{seed}/{mode}_{lr}_{model}/` — for baselines
+ - `checkpoints_{seed}/{mode}_{lr}_lambda{lambda}_{model}/` — for OrthoReg variants
+
+ ### Step 2 — Evaluate Single-Task Accuracy
+
+ ```bash
+ python src/eval_single_task.py \
+     --model ViT-B-32 \
+     --finetuning-mode standard_ortho \
+     --ortho-lambda 10 \
+     --lr 1e-5 \
+     --data-location /path/to/datasets/
+ ```
+
+ > Run `eval_single_task` with `--finetuning-mode none --ortho-lambda 0` first to generate `zeroshot_accuracies.json`, which is required as the reference for normalized accuracy in Steps 3–4.
+
+ ### Step 3 — Evaluate Task Addition
+
+ ```bash
+ python src/eval_task_addition.py \
+     --model ViT-B-32 \
+     --finetuning-mode standard_ortho \
+     --ortho-lambda 10 \
+     --lr 1e-5 \
+     --data-location /path/to/datasets/
+ ```
+
+ ### Step 4 — Evaluate Task Negation
+
+ ```bash
+ python src/eval_task_negation.py \
+     --model ViT-B-32 \
+     --finetuning-mode standard_ortho \
+     --ortho-lambda 10 \
+     --lr 1e-5 \
+     --data-location /path/to/datasets/
+ ```
+
+ ---
+
+ ## 🔧 Key Arguments
+
+ | Argument | Default | Description |
+ |---|:---:|---|
+ | `--model` | `ViT-B-32` | CLIP model architecture |
+ | `--finetuning-mode` | — | One of the 6 modes above |
+ | `--ortho-lambda` | `0.0` | OrthoReg strength λ; set to `0` for baselines |
+ | `--lr` | `1e-5` | Learning rate |
+ | `--seed` | `1993` | Random seed |
+ | `--world-size` | `1` | Number of GPUs (DDP) |
+ | `--data-location` | — | Dataset root directory |
+ | `--batch-size` | `128` | Batch size per GPU |
+
+ ---
+
+ ## 📁 Checkpoints
+
+ We release fine-tuned checkpoints for ViT-B-32, ViT-B-16, and ViT-L-14 on all 8 tasks, covering all 6 modes.
+
+ > 📥 **Checkpoint Download:** `https://huggingface.co/gezi2333/OrthoReg_checkpoints`
+
+ Unzip into `OrthoReg/checkpoints_{seed}/` and pass the corresponding `--seed`, `--lr`, and `--ortho-lambda` to the eval scripts to reproduce the paper's results directly.
+
+ ---
+
+ ## 📝 Citation
+
+ If you find this work useful, please cite:
+
+ ```bibtex
+ @inproceedings{liu2026orthoreg,
+   title     = {Understanding and Enforcing Weight Disentanglement in Task Arithmetic},
+   author    = {Liu, Shangge and Yin, Yuehan and Wang, Lei and Fan, Qi and
+                Shi, Yinghuan and Li, Wenbin and Gao, Yang and Tao, Dacheng},
+   booktitle = {CVPR},
+   year      = {2026}
+ }
+ ```
+
+ ---
+
+ ## 📞 Contact
+
+ For questions or issues, please:
+
+ - Open an issue on GitHub
+ - Contact the authors at [lshangge@smail.nju.edu.cn](mailto:lshangge@smail.nju.edu.cn)
+
+ ---
+
+ ## 📬 Acknowledgements
+
+ This codebase is built on top of [Task Arithmetic](https://github.com/mlfoundations/task_vectors), [Tangent Task Arithmetic](https://github.com/gortizji/tangent_task_arithmetic), and [Attention-Only Fine-tuning](https://github.com/kyrie-23/linear_task_arithmetic). We thank the authors for releasing their code.
+
assets/WVO-WD-TFS.png ADDED

Git LFS Details

  • SHA256: bc8a9efc76ecb495a5de03a98215606a8cbab5b38cdbb53ea5d2c2ed133e535a
  • Pointer size: 131 Bytes
  • Size of remote file: 150 kB
assets/orthoreg_loss.png ADDED

Git LFS Details

  • SHA256: d7974c4003a412cc9a5134f26b71cc58739f5cd1f48301ab0d3d428cb17ecb8e
  • Pointer size: 131 Bytes
  • Size of remote file: 158 kB
environment.yml ADDED
@@ -0,0 +1,140 @@
+ name: tangent-arithmetic
+ channels:
+ - pytorch
+ - nvidia
+ - defaults
+ dependencies:
+ - _libgcc_mutex=0.1
+ - _openmp_mutex=5.1
+ - blas=1.0
+ - brotlipy=0.7.0
+ - bzip2=1.0.8
+ - ca-certificates=2023.05.30
+ - certifi=2023.5.7
+ - cffi=1.15.1
+ - charset-normalizer=2.0.4
+ - cryptography=39.0.1
+ - cuda=11.6.1
+ - cuda-cccl=11.6.55
+ - cuda-command-line-tools=11.6.2
+ - cuda-compiler=11.6.2
+ - cuda-cudart=11.6.55
+ - cuda-cudart-dev=11.6.55
+ - cuda-cuobjdump=11.6.124
+ - cuda-cupti=11.6.124
+ - cuda-cuxxfilt=11.6.124
+ - cuda-driver-dev=11.6.55
+ - cuda-gdb=12.1.105
+ - cuda-libraries=11.6.1
+ - cuda-libraries-dev=11.6.1
+ - cuda-memcheck=11.8.86
+ - cuda-nsight=12.1.105
+ - cuda-nsight-compute=12.1.1
+ - cuda-nvcc=11.6.124
+ - cuda-nvdisasm=12.1.105
+ - cuda-nvml-dev=11.6.55
+ - cuda-nvprof=12.1.105
+ - cuda-nvprune=11.6.124
+ - cuda-nvrtc=11.6.124
+ - cuda-nvrtc-dev=11.6.124
+ - cuda-nvtx=11.6.124
+ - cuda-nvvp=12.1.105
+ - cuda-runtime=11.6.1
+ - cuda-samples=11.6.101
+ - cuda-sanitizer-api=12.1.105
+ - cuda-toolkit=11.6.1
+ - cuda-tools=11.6.1
+ - cuda-visual-tools=11.6.1
+ - ffmpeg=4.3
+ - freetype=2.12.1
+ - gds-tools=1.6.1.9
+ - giflib=5.2.1
+ - gmp=6.2.1
+ - gnutls=3.6.15
+ - idna=3.4
+ - intel-openmp=2023.1.0
+ - jpeg=9e
+ - lame=3.100
+ - lcms2=2.12
+ - ld_impl_linux-64=2.38
+ - lerc=3.0
+ - libcublas=11.9.2.110
+ - libcublas-dev=11.9.2.110
+ - libcufft=10.7.1.112
+ - libcufft-dev=10.7.1.112
+ - libcufile=1.6.1.9
+ - libcufile-dev=1.6.1.9
+ - libcurand=10.3.2.106
+ - libcurand-dev=10.3.2.106
+ - libcusolver=11.3.4.124
+ - libcusparse=11.7.2.124
+ - libcusparse-dev=11.7.2.124
+ - libdeflate=1.17
+ - libffi=3.4.4
+ - libgcc-ng=11.2.0
+ - libgomp=11.2.0
+ - libiconv=1.16
+ - libidn2=2.3.4
+ - libnpp=11.6.3.124
+ - libnpp-dev=11.6.3.124
+ - libnvjpeg=11.6.2.124
+ - libnvjpeg-dev=11.6.2.124
+ - libpng=1.6.39
+ - libstdcxx-ng=11.2.0
+ - libtasn1=4.19.0
+ - libtiff=4.5.0
+ - libunistring=0.9.10
+ - libuuid=1.41.5
+ - libwebp=1.2.4
+ - libwebp-base=1.2.4
+ - lz4-c=1.9.4
+ - mkl=2023.1.0
+ - mkl-service=2.4.0
+ - mkl_fft=1.3.6
+ - mkl_random=1.2.2
+ - ncurses=6.4
+ - nettle=3.7.3
+ - nsight-compute=2023.1.1.4
+ - numpy=1.24.3
+ - numpy-base=1.24.3
+ - openh264=2.1.1
+ - openssl=1.1.1t
+ - pillow=9.4.0
+ - pip=23.0.1
+ - pycparser=2.21
+ - pyopenssl=23.0.0
+ - pysocks=1.7.1
+ - python=3.10.11
+ - pytorch=1.13.1
+ - pytorch-cuda=11.6
+ - pytorch-mutex=1.0
+ - readline=8.2
+ - requests=2.29.0
+ - setuptools=67.8.0
+ - sqlite=3.41.2
+ - tbb=2021.8.0
+ - tk=8.6.12
+ - torchaudio=0.13.1
+ - torchvision=0.14.1
+ - typing_extensions=4.5.0
+ - tzdata=2023c
+ - urllib3=1.26.16
+ - wheel=0.38.4
+ - xz=5.4.2
+ - zlib=1.2.13
+ - zstd=1.5.5
+ - pip:
+   - filelock==3.12.0
+   - fsspec==2023.5.0
+   - ftfy==6.1.1
+   - huggingface-hub==0.15.1
+   - open-clip-torch==2.10.1
+   - packaging==23.1
+   - protobuf==3.20.3
+   - pyyaml==6.0
+   - regex==2023.6.3
+   - safetensors==0.3.1
+   - scipy==1.10.1
+   - sentencepiece==0.1.99
+   - timm==0.9.2
+   - wcwidth==0.2.6
src/__init__.py ADDED
File without changes
src/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (139 Bytes). View file
 
src/__pycache__/args.cpython-310.pyc ADDED
Binary file (3.4 kB). View file
 
src/__pycache__/attention_only_finetune.cpython-310.pyc ADDED
Binary file (3.6 kB). View file
 
src/__pycache__/distributed.cpython-310.pyc ADDED
Binary file (1.19 kB). View file
 
src/__pycache__/eval.cpython-310.pyc ADDED
Binary file (3.38 kB). View file
 
src/__pycache__/heads.cpython-310.pyc ADDED
Binary file (1.92 kB). View file
 
src/__pycache__/linearize.cpython-310.pyc ADDED
Binary file (6.29 kB). View file
 
src/__pycache__/modeling.cpython-310.pyc ADDED
Binary file (6.52 kB). View file
 
src/__pycache__/task_vectors.cpython-310.pyc ADDED
Binary file (7.75 kB). View file
 
src/__pycache__/utils.cpython-310.pyc ADDED
Binary file (5.95 kB). View file
 
src/args.py ADDED
@@ -0,0 +1,153 @@
+ import argparse
+ import os
+
+ import torch
+
+
+ def parse_arguments():
+     parser = argparse.ArgumentParser()
+     parser.add_argument(
+         "--data-location",
+         type=str,
+         default=os.path.expanduser("/path/datasets/"),
+         help="The root directory for the datasets.",
+     )
+     parser.add_argument(
+         "--eval-datasets",
+         default=None,
+         type=lambda x: x.split(","),
+         help="Which datasets to use for evaluation. Split by comma, e.g. MNIST,EuroSAT.",
+     )
+     parser.add_argument(
+         "--train-dataset",
+         default=None,
+         type=lambda x: x.split(","),
+         help="Which dataset(s) to patch on.",
+     )
+     parser.add_argument(
+         "--exp_name",
+         type=str,
+         default=None,
+         help="Name of the experiment, for organization purposes only.",
+     )
+     parser.add_argument(
+         "--results-db",
+         type=str,
+         default=None,
+         help="Where to store the results; if None, results are not stored.",
+     )
+     parser.add_argument(
+         "--model",
+         type=str,
+         default="ViT-B-32",
+         help="The type of model (e.g. RN50, ViT-B-32).",
+     )
+     parser.add_argument(
+         "--batch-size",
+         type=int,
+         default=128,
+     )
+     parser.add_argument(
+         "--num-grad-accumulation",
+         type=int,
+         default=1,
+         help="Number of gradient accumulation steps.",
+     )
+     parser.add_argument("--lr", type=float, default=0.001, help="Learning rate.")
+     parser.add_argument("--wd", type=float, default=0.1, help="Weight decay.")
+     parser.add_argument("--ls", type=float, default=0.0, help="Label smoothing.")
+     parser.add_argument(
+         "--warmup_length",
+         type=int,
+         default=500,
+     )
+     parser.add_argument(
+         "--epochs",
+         type=int,
+         default=10,
+     )
+     parser.add_argument(
+         "--load",
+         type=lambda x: x.split(","),
+         default=None,
+         help="Optionally load _classifiers_, e.g. a zero-shot classifier or probe, or an ensemble of both.",
+     )
+     parser.add_argument(
+         "--save",
+         type=str,
+         default=None,
+         help="Optionally save a _classifier_, e.g. a zero-shot classifier or probe.",
+     )
+     parser.add_argument(
+         "--cache-dir",
+         type=str,
+         default=None,
+         help="Directory for caching features and encoder.",
+     )
+     parser.add_argument(
+         "--openclip-cachedir",
+         type=str,
+         default=os.path.expanduser("~/openclip-cachedir/open_clip"),
+         help="Directory for caching models from OpenCLIP.",
+     )
+     parser.add_argument(
+         "--world-size",
+         type=int,
+         default=1,
+         help="Number of processes for distributed training.",
+     )
+     parser.add_argument(
+         "--checkpoint-every",
+         type=int,
+         default=-1,
+         help="How often to checkpoint the model.",
+     )
+     parser.add_argument(
+         "--port",
+         type=int,
+         default=12355,
+         help="Port for distributed training.",
+     )
+     parser.add_argument(
+         "--seed",
+         type=int,
+         default=1993,
+         help="Random seed.",
+     )
+     parser.add_argument(
+         "--finetuning-mode",
+         choices=["standard", "standard_ortho", "linear", "linear_ortho", "linear-2", "linear-2_ortho"],
+         help="Finetuning mode: standard/linear/linear-2, with optional ortho regularization.",
+     )
+     parser.add_argument(
+         "--n-eval-points",
+         type=int,
+         default=21,
+         help="Number of evaluation points used to find the optimal coefficient in task arithmetic.",
+     )
+     parser.add_argument(
+         "--ortho-lambda",
+         type=float,
+         default=0.0,
+         help="Weight of the orthogonality regularization term. Default 0.0 means no regularization.",
+     )
+     parser.add_argument(
+         "--control_threshold",
+         type=float,
+         default=0.95,
+         help="Control dataset performance degradation threshold.",
+     )
+     parser.add_argument(
+         "--alpha",
+         type=float,
+         default=None,
+         help="Manually specify the scaling coefficient for task vectors. If None, it is optimized on the validation set.",
+     )
+
+     parsed_args = parser.parse_args()
+     parsed_args.device = "cuda" if torch.cuda.is_available() else "cpu"
+
+     if parsed_args.load is not None and len(parsed_args.load) == 1:
+         parsed_args.load = parsed_args.load[0]
+
+     return parsed_args
src/attention_only_finetune.py ADDED
+ import os
+
+ import torch
+ import torch.nn as nn
+
+ from src.modeling import ImageEncoder
+ from src.utils import DotDict
+
+
+ class AttentionOnlyFinetuneEncoder(ImageEncoder):
+     """
+     A specialized ImageEncoder that fine-tunes only the attention module weights in the ViT.
+     Corresponds to the method described in Jin et al. (2025).
+     """
+
+     def __init__(self, args, keep_lang=False):
+         # 1. Call the parent constructor to build the full model as usual
+         super().__init__(args, keep_lang=keep_lang)
+
+         self.args = args
+
+         # 2. Freeze all model parameters
+         for param in self.model.parameters():
+             param.requires_grad = False
+
+         # 3. Unfreeze only the attention module weights (Wq, Wk, Wv, Wo)
+         self._unfreeze_attention_weights(self.model.visual)
+
+         # 4. (Optional but recommended) Print trainable parameters for verification
+         # self._verify_trainable_params()
+
+     def _unfreeze_attention_weights(self, vit_model):
+         """Iterate over all Transformer blocks and unfreeze the attention projection weights."""
+         for block in vit_model.transformer.resblocks:
+             # Unfreeze the combined input projection weight for Q, K, V
+             block.attn.in_proj_weight.requires_grad = True
+             # Unfreeze the output projection weight
+             block.attn.out_proj.weight.requires_grad = True
+             # Per the paper's ablation study, not fine-tuning biases yields better
+             # results, so the biases stay frozen:
+             # block.attn.in_proj_bias.requires_grad = True
+             # block.attn.out_proj.bias.requires_grad = True
+
+     def _verify_trainable_params(self):
+         """Print all trainable parameters for debugging and verification."""
+         print("=" * 80)
+         print("Trainable parameters in AttentionOnlyFinetuneEncoder:")
+         trainable_params_count = 0
+         for name, param in self.model.named_parameters():
+             if param.requires_grad:
+                 print(f"  - {name}")
+                 trainable_params_count += param.numel()
+         print(f"Total trainable parameters: {trainable_params_count / 1e6:.2f}M")
+         print("=" * 80)
+
+     def forward(self, images, calculate_ortho_loss=False, pretrained_state_dict=None):
+         """
+         Extended forward method that optionally computes and returns the orthogonality loss.
+         Consistent with the logic implemented for standard_ortho.
+         """
+         # Original forward pass
+         features = self.model.encode_image(images)
+
+         # Return features directly if the orthogonality loss is not needed
+         if not calculate_ortho_loss:
+             return features
+
+         # --- Compute the orthogonality loss if requested ---
+         if pretrained_state_dict is None:
+             raise ValueError("pretrained_state_dict must be provided when calculate_ortho_loss is True")
+
+         ortho_loss = 0.0
+         # self.model is the open_clip model (e.g. a ViT); iterate over its parameters
+         for name, p_finetuned in self.model.named_parameters():
+             # Only penalize trainable 2-D weight matrices
+             if p_finetuned.requires_grad and p_finetuned.dim() == 2:
+                 if name in pretrained_state_dict:
+                     p_pretrained = pretrained_state_dict[name].to(p_finetuned.device)
+
+                     delta_W = p_finetuned - p_pretrained
+
+                     # Frobenius norm of the Gram matrix minus identity, using the
+                     # smaller Gram matrix for non-square layers
+                     rows, cols = delta_W.shape
+                     if rows < cols:
+                         mat = delta_W @ delta_W.T
+                         identity = torch.eye(rows, device=delta_W.device)
+                     else:
+                         mat = delta_W.T @ delta_W
+                         identity = torch.eye(cols, device=delta_W.device)
+
+                     ortho_loss += torch.norm(mat - identity, p="fro")
+
+         return features, ortho_loss
+
+     def __call__(self, inputs, calculate_ortho_loss=False, pretrained_state_dict=None):
+         # Ensure __call__ forwards all arguments
+         return self.forward(inputs, calculate_ortho_loss, pretrained_state_dict)
+
+     def save(self, filename):
+         """Save model weights."""
+         if os.path.dirname(filename):
+             os.makedirs(os.path.dirname(filename), exist_ok=True)
+         # Save only the state_dict; reconstruct the model on load
+         torch.save(self.state_dict(), filename)
+
+     @classmethod
+     def load(cls, filename, args):
+         """Load model from a state_dict."""
+         encoder = cls(args)  # Create a new instance
+         state_dict = torch.load(filename, map_location="cpu")
+         encoder.load_state_dict(state_dict)  # Load weights
+         return encoder
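A fine-tuning step that consumes this forward signature might look like the following sketch (the classification head, loss weighting, and function name are illustrative, not the repository's exact training loop):

```python
import torch
import torch.nn.functional as F

def training_step(encoder, head, images, labels, pretrained_state_dict, ortho_lambda):
    """One OrthoReg training step: task cross-entropy plus the weighted orthogonality loss.

    `encoder` is expected to return (features, ortho_loss) when called with
    calculate_ortho_loss=True, matching AttentionOnlyFinetuneEncoder.forward above.
    """
    features, ortho_loss = encoder(
        images,
        calculate_ortho_loss=True,
        pretrained_state_dict=pretrained_state_dict,
    )
    logits = head(features)
    # Total loss: L_task + lambda * L_ortho
    return F.cross_entropy(logits, labels) + ortho_lambda * ortho_loss
```

Setting `ortho_lambda` to 0 recovers the corresponding baseline objective.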
src/datasets/__pycache__/cars.cpython-310.pyc ADDED
Binary file (5.94 kB). View file
 
src/datasets/__pycache__/cifar10.cpython-310.pyc ADDED
Binary file (2.15 kB). View file
 
src/datasets/__pycache__/cifar100.cpython-310.pyc ADDED
Binary file (924 Bytes). View file
 
src/datasets/__pycache__/common.cpython-310.pyc ADDED
Binary file (5.22 kB). View file
 
src/datasets/__pycache__/dtd.cpython-310.pyc ADDED
Binary file (1.39 kB). View file
 
src/datasets/__pycache__/emnist.cpython-310.pyc ADDED
Binary file (1.46 kB). View file
 
src/datasets/__pycache__/eurosat.cpython-310.pyc ADDED
Binary file (3.02 kB). View file
 
src/datasets/__pycache__/gtsrb.cpython-310.pyc ADDED
Binary file (7.68 kB). View file
 
src/datasets/__pycache__/imagenet.cpython-310.pyc ADDED
Binary file (15.9 kB). View file
 
src/datasets/__pycache__/kmnist.cpython-310.pyc ADDED
Binary file (952 Bytes). View file
 
src/datasets/__pycache__/mnist.cpython-310.pyc ADDED
Binary file (947 Bytes). View file
 
src/datasets/__pycache__/oxfordpets.cpython-310.pyc ADDED
Binary file (986 Bytes). View file
 
src/datasets/__pycache__/registry.cpython-310.pyc ADDED
Binary file (3.07 kB). View file
 
src/datasets/__pycache__/resisc45.cpython-310.pyc ADDED
Binary file (9.24 kB). View file
 
src/datasets/__pycache__/stl10.cpython-310.pyc ADDED
Binary file (976 Bytes). View file
 
src/datasets/__pycache__/sun397.cpython-310.pyc ADDED
Binary file (1.41 kB). View file
 
src/datasets/__pycache__/svhn.cpython-310.pyc ADDED
Binary file (1.04 kB). View file
 
src/datasets/__pycache__/templates.cpython-310.pyc ADDED
Binary file (18 kB). View file
 
src/datasets/cars.py ADDED
@@ -0,0 +1,155 @@
+ import os
+ import pathlib
+ from typing import Any, Callable, Optional, Tuple
+
+ import torch
+ import torchvision.datasets as datasets
+ from PIL import Image
+ from torchvision.datasets.utils import download_and_extract_archive, download_url, verify_str_arg
+ from torchvision.datasets.vision import VisionDataset
+
+
+ class PytorchStanfordCars(VisionDataset):
+     """`Stanford Cars <https://ai.stanford.edu/~jkrause/cars/car_dataset.html>`_ Dataset
+
+     The Cars dataset contains 16,185 images of 196 classes of cars. The data is
+     split into 8,144 training images and 8,041 testing images, where each class
+     has been split roughly in a 50-50 split.
+
+     .. note::
+
+         This class needs `scipy <https://docs.scipy.org/doc/>`_ to load target files from `.mat` format.
+
+     Args:
+         root (string): Root directory of dataset
+         split (string, optional): The dataset split, supports ``"train"`` (default) or ``"test"``.
+         transform (callable, optional): A function/transform that takes in a PIL image
+             and returns a transformed version. E.g., ``transforms.RandomCrop``
+         target_transform (callable, optional): A function/transform that takes in the
+             target and transforms it.
+         download (bool, optional): If True, downloads the dataset from the internet and
+             puts it in the root directory. If the dataset is already downloaded, it is not
+             downloaded again."""
+
+     def __init__(
+         self,
+         root: str,
+         split: str = "train",
+         transform: Optional[Callable] = None,
+         target_transform: Optional[Callable] = None,
+         download: bool = False,
+     ) -> None:
+
+         try:
+             import scipy.io as sio
+         except ImportError:
+             raise RuntimeError("Scipy is not found. This dataset needs to have scipy installed: pip install scipy")
+
+         super().__init__(root, transform=transform, target_transform=target_transform)
+
+         self._split = verify_str_arg(split, "split", ("train", "test"))
+         self._base_folder = pathlib.Path(root) / "stanford_cars"
+         devkit = self._base_folder / "devkit"
+
+         if self._split == "train":
+             self._annotations_mat_path = devkit / "cars_train_annos.mat"
+             self._images_base_path = self._base_folder / "cars_train"
+         else:
+             self._annotations_mat_path = self._base_folder / "cars_test_annos_withlabels.mat"
+             self._images_base_path = self._base_folder / "cars_test"
+
+         if download:
+             self.download()
+
+         if not self._check_exists():
+             raise RuntimeError("Dataset not found. You can use download=True to download it")
+
+         self._samples = [
+             (
+                 str(self._images_base_path / annotation["fname"]),
+                 annotation["class"] - 1,  # Original target mapping starts from 1, hence -1
+             )
+             for annotation in sio.loadmat(self._annotations_mat_path, squeeze_me=True)["annotations"]
+         ]
+
+         self.classes = sio.loadmat(str(devkit / "cars_meta.mat"), squeeze_me=True)["class_names"].tolist()
+         self.class_to_idx = {cls: i for i, cls in enumerate(self.classes)}
+
+     def __len__(self) -> int:
+         return len(self._samples)
+
+     def __getitem__(self, idx: int) -> Tuple[Any, Any]:
+         """Returns pil_image and class_id for a given index"""
+         image_path, target = self._samples[idx]
+         pil_image = Image.open(image_path).convert("RGB")
+
+         if self.transform is not None:
+             pil_image = self.transform(pil_image)
+         if self.target_transform is not None:
+             target = self.target_transform(target)
+         return pil_image, target
+
+     def download(self) -> None:
+         if self._check_exists():
+             return
+
+         download_and_extract_archive(
+             url="https://ai.stanford.edu/~jkrause/cars/car_devkit.tgz",
+             download_root=str(self._base_folder),
+             md5="c3b158d763b6e2245038c8ad08e45376",
+         )
+         if self._split == "train":
+             download_and_extract_archive(
+                 url="https://ai.stanford.edu/~jkrause/car196/cars_train.tgz",
+                 download_root=str(self._base_folder),
+                 md5="065e5b463ae28d29e77c1b4b166cfe61",
+             )
+         else:
+             download_and_extract_archive(
+                 url="https://ai.stanford.edu/~jkrause/car196/cars_test.tgz",
+                 download_root=str(self._base_folder),
+                 md5="4ce7ebf6a94d07f1952d94dd34c4d501",
+             )
+             download_url(
+                 url="https://ai.stanford.edu/~jkrause/car196/cars_test_annos_withlabels.mat",
119
+ root=str(self._base_folder),
120
+ md5="b0a2b23655a3edd16d84508592a98d10",
121
+ )
122
+
123
+ def _check_exists(self) -> bool:
124
+ if not (self._base_folder / "devkit").is_dir():
125
+ return False
126
+
127
+ return self._annotations_mat_path.exists() and self._images_base_path.is_dir()
128
+
129
+
130
+ class Cars:
131
+ def __init__(self,
132
+ preprocess,
133
+ location=os.path.expanduser('~/data'),
134
+ batch_size=32,
135
+ num_workers=16):
136
+ # Data loading code
137
+
138
+ self.train_dataset = PytorchStanfordCars(location, 'train', preprocess, download=False)
139
+ self.train_loader = torch.utils.data.DataLoader(
140
+ self.train_dataset,
141
+ shuffle=True,
142
+ batch_size=batch_size,
143
+ num_workers=num_workers,
144
+ )
145
+
146
+ self.test_dataset = PytorchStanfordCars(location, 'test', preprocess, download=False)
147
+ self.test_loader = torch.utils.data.DataLoader(
148
+ self.test_dataset,
149
+ batch_size=batch_size,
150
+ num_workers=num_workers
151
+ )
152
+ idx_to_class = dict((v, k)
153
+ for k, v in self.train_dataset.class_to_idx.items())
154
+ self.classnames = [idx_to_class[i].replace(
155
+ '_', ' ') for i in range(len(idx_to_class))]
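The dataset wrappers in this folder all derive human-readable classnames the same way: invert `class_to_idx` and replace underscores with spaces. A minimal self-contained sketch of that pattern (the mapping below is hypothetical, not the real Stanford Cars label set):

```python
# Hypothetical class_to_idx mapping, standing in for train_dataset.class_to_idx.
class_to_idx = {"AM_General_Hummer": 0, "Acura_RL": 1}

# Invert the mapping and strip underscores, exactly as the Cars wrapper does.
idx_to_class = {v: k for k, v in class_to_idx.items()}
classnames = [idx_to_class[i].replace("_", " ") for i in range(len(idx_to_class))]

print(classnames)  # ['AM General Hummer', 'Acura RL']
```

Iterating `range(len(idx_to_class))` rather than over the dict guarantees the classnames come out in label-index order.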
src/datasets/cifar10.py ADDED
@@ -0,0 +1,56 @@
+ import os
+ import PIL
+ import torch
+ import numpy as np
+ import torchvision
+ from torchvision import transforms
+ from torchvision.datasets import CIFAR10 as PyTorchCIFAR10
+ from torchvision.datasets import VisionDataset
+ 
+ cifar_classnames = ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
+ 
+ class CIFAR10:
+     def __init__(self, preprocess,
+                  location=os.path.expanduser('~/data'),
+                  batch_size=128,
+                  num_workers=16):
+ 
+         self.train_dataset = PyTorchCIFAR10(
+             root=location, download=True, train=True, transform=preprocess
+         )
+ 
+         self.train_loader = torch.utils.data.DataLoader(
+             self.train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers
+         )
+ 
+         self.test_dataset = PyTorchCIFAR10(
+             root=location, download=True, train=False, transform=preprocess
+         )
+ 
+         self.test_loader = torch.utils.data.DataLoader(
+             self.test_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers
+         )
+ 
+         self.classnames = self.test_dataset.classes
+ 
+ def convert(x):
+     if isinstance(x, np.ndarray):
+         return torchvision.transforms.functional.to_pil_image(x)
+     return x
+ 
+ class BasicVisionDataset(VisionDataset):
+     def __init__(self, images, targets, transform=None, target_transform=None):
+         if transform is not None:
+             transform.transforms.insert(0, convert)
+         super(BasicVisionDataset, self).__init__(root=None, transform=transform, target_transform=target_transform)
+         assert len(images) == len(targets)
+ 
+         self.images = images
+         self.targets = targets
+ 
+     def __getitem__(self, index):
+         return self.transform(self.images[index]), self.targets[index]
+ 
+     def __len__(self):
+         return len(self.targets)
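`BasicVisionDataset` prepends `convert` to the transform pipeline so numpy arrays become PIL images before the remaining transforms run. The move can be sketched with plain callables in place of torchvision transforms (everything below is a stand-in, not the real pipeline):

```python
# A "pipeline" is just an ordered list of callables, like Compose.transforms.
pipeline = [str.upper]  # stand-in for e.g. ToTensor / Normalize

def convert(x):
    # Stand-in for the numpy-array -> PIL-image conversion: normalize the input type.
    return x if isinstance(x, str) else str(x)

# Same move as transform.transforms.insert(0, convert): conversion runs first.
pipeline.insert(0, convert)

value = 42
for step in pipeline:
    value = step(value)
print(value)  # '42'
```

Inserting at index 0 matters: the downstream steps assume the converted type, so `convert` must run before any of them.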
src/datasets/cifar100.py ADDED
@@ -0,0 +1,30 @@
+ import os
+ import torch
+ from torchvision.datasets import CIFAR100 as PyTorchCIFAR100
+ 
+ class CIFAR100:
+     def __init__(self,
+                  preprocess,
+                  location=os.path.expanduser('~/data'),
+                  batch_size=128,
+                  num_workers=16):
+ 
+         self.train_dataset = PyTorchCIFAR100(
+             root=location, download=True, train=True, transform=preprocess
+         )
+ 
+         self.train_loader = torch.utils.data.DataLoader(
+             self.train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers
+         )
+ 
+         self.test_dataset = PyTorchCIFAR100(
+             root=location, download=True, train=False, transform=preprocess
+         )
+ 
+         self.test_loader = torch.utils.data.DataLoader(
+             self.test_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers
+         )
+ 
+         self.classnames = self.test_dataset.classes
+ 
+ 
src/datasets/common.py ADDED
@@ -0,0 +1,140 @@
+ import os
+ import torch
+ import json
+ import glob
+ import collections
+ import random
+ 
+ import numpy as np
+ 
+ from tqdm import tqdm
+ 
+ import torchvision.datasets as datasets
+ from torch.utils.data import Dataset, DataLoader, Sampler
+ 
+ 
+ class SubsetSampler(Sampler):
+     def __init__(self, indices):
+         self.indices = indices
+ 
+     def __iter__(self):
+         return (i for i in self.indices)
+ 
+     def __len__(self):
+         return len(self.indices)
+ 
+ class ImageFolderWithPaths(datasets.ImageFolder):
+     def __init__(self, path, transform, flip_label_prob=0.0):
+         super().__init__(path, transform)
+         self.flip_label_prob = flip_label_prob
+         if self.flip_label_prob > 0:
+             print(f'Flipping labels with probability {self.flip_label_prob}')
+             num_classes = len(self.classes)
+             for i in range(len(self.samples)):
+                 if random.random() < self.flip_label_prob:
+                     new_label = random.randint(0, num_classes - 1)
+                     self.samples[i] = (
+                         self.samples[i][0],
+                         new_label
+                     )
+ 
+     def __getitem__(self, index):
+         image, label = super(ImageFolderWithPaths, self).__getitem__(index)
+         return {
+             'images': image,
+             'labels': label,
+             'image_paths': self.samples[index][0]
+         }
+ 
+ 
+ def maybe_dictionarize(batch):
+     if isinstance(batch, dict):
+         return batch
+ 
+     if len(batch) == 2:
+         batch = {'images': batch[0], 'labels': batch[1]}
+     elif len(batch) == 3:
+         batch = {'images': batch[0], 'labels': batch[1], 'metadata': batch[2]}
+     else:
+         raise ValueError(f'Unexpected number of elements: {len(batch)}')
+ 
+     return batch
+ 
+ 
+ def get_features_helper(image_encoder, dataloader, device):
+     all_data = collections.defaultdict(list)
+ 
+     image_encoder = image_encoder.to(device)
+     image_encoder = torch.nn.DataParallel(image_encoder, device_ids=[x for x in range(torch.cuda.device_count())])
+     image_encoder.eval()
+ 
+     with torch.no_grad():
+         for batch in tqdm(dataloader):
+             batch = maybe_dictionarize(batch)
+             features = image_encoder(batch['images'].cuda())
+ 
+             all_data['features'].append(features.cpu())
+ 
+             for key, val in batch.items():
+                 if key == 'images':
+                     continue
+                 if hasattr(val, 'cpu'):
+                     val = val.cpu()
+                     all_data[key].append(val)
+                 else:
+                     all_data[key].extend(val)
+ 
+     for key, val in all_data.items():
+         if torch.is_tensor(val[0]):
+             all_data[key] = torch.cat(val).numpy()
+ 
+     return all_data
+ 
+ 
+ def get_features(is_train, image_encoder, dataset, device):
+     split = 'train' if is_train else 'val'
+     dname = type(dataset).__name__
+     cache_dir = None  # keeps the messages below well-defined when no cache directory is set
+     if image_encoder.cache_dir is not None:
+         cache_dir = f'{image_encoder.cache_dir}/{dname}/{split}'
+         cached_files = glob.glob(f'{cache_dir}/*')
+     if image_encoder.cache_dir is not None and len(cached_files) > 0:
+         print(f'Getting features from {cache_dir}')
+         data = {}
+         for cached_file in cached_files:
+             name = os.path.splitext(os.path.basename(cached_file))[0]
+             data[name] = torch.load(cached_file)
+     else:
+         print(f'Did not find cached features at {cache_dir}. Building from scratch.')
+         loader = dataset.train_loader if is_train else dataset.test_loader
+         data = get_features_helper(image_encoder, loader, device)
+         if image_encoder.cache_dir is None:
+             print('Not caching because no cache directory was passed.')
+         else:
+             os.makedirs(cache_dir, exist_ok=True)
+             print(f'Caching data at {cache_dir}')
+             for name, val in data.items():
+                 torch.save(val, f'{cache_dir}/{name}.pt')
+     return data
+ 
+ 
+ class FeatureDataset(Dataset):
+     def __init__(self, is_train, image_encoder, dataset, device):
+         self.data = get_features(is_train, image_encoder, dataset, device)
+ 
+     def __len__(self):
+         return len(self.data['features'])
+ 
+     def __getitem__(self, idx):
+         data = {k: v[idx] for k, v in self.data.items()}
+         data['features'] = torch.from_numpy(data['features']).float()
+         return data
+ 
+ 
+ def get_dataloader(dataset, is_train, args, image_encoder=None):
+     if image_encoder is not None:
+         feature_dataset = FeatureDataset(is_train, image_encoder, dataset, args.device)
+         dataloader = DataLoader(feature_dataset, batch_size=args.batch_size, shuffle=is_train)
+     else:
+         dataloader = dataset.train_loader if is_train else dataset.test_loader
+     return dataloader
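The `maybe_dictionarize` helper above defines the batch contract the rest of the code relies on: tuples of length 2 or 3 are normalized to dicts, and dicts pass through untouched. Its logic can be exercised in isolation; the copy below repeats it so the example runs without torch:

```python
def maybe_dictionarize(batch):
    # Already a dict: pass through unchanged.
    if isinstance(batch, dict):
        return batch
    # (images, labels) or (images, labels, metadata) tuples become dicts.
    if len(batch) == 2:
        return {'images': batch[0], 'labels': batch[1]}
    if len(batch) == 3:
        return {'images': batch[0], 'labels': batch[1], 'metadata': batch[2]}
    raise ValueError(f'Unexpected number of elements: {len(batch)}')

print(maybe_dictionarize(('img', 0)))   # {'images': 'img', 'labels': 0}
print(maybe_dictionarize({'labels': 1}))  # {'labels': 1}
```

This is why `ImageFolderWithPaths.__getitem__` can return a dict while plain torchvision datasets return tuples: both shapes converge downstream.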
src/datasets/dtd.py ADDED
@@ -0,0 +1,34 @@
+ import os
+ import torch
+ import torchvision.datasets as datasets
+ 
+ 
+ class DTD:
+     def __init__(self,
+                  preprocess,
+                  location=os.path.expanduser('~/data'),
+                  batch_size=32,
+                  num_workers=16):
+         # Data loading code
+         traindir = os.path.join(location, 'dtd', 'train')
+         valdir = os.path.join(location, 'dtd', 'val')
+ 
+         self.train_dataset = datasets.ImageFolder(
+             traindir, transform=preprocess)
+         self.train_loader = torch.utils.data.DataLoader(
+             self.train_dataset,
+             shuffle=True,
+             batch_size=batch_size,
+             num_workers=num_workers,
+         )
+ 
+         self.test_dataset = datasets.ImageFolder(valdir, transform=preprocess)
+         self.test_loader = torch.utils.data.DataLoader(
+             self.test_dataset,
+             batch_size=batch_size,
+             num_workers=num_workers
+         )
+         idx_to_class = dict((v, k)
+                             for k, v in self.train_dataset.class_to_idx.items())
+         self.classnames = [idx_to_class[i].replace(
+             '_', ' ') for i in range(len(idx_to_class))]
src/datasets/emnist.py ADDED
@@ -0,0 +1,74 @@
+ import os
+ 
+ import torch
+ 
+ import torchvision
+ import torchvision.datasets as datasets
+ 
+ 
+ def rotate_img(img):
+     return torchvision.transforms.functional.rotate(img, -90)
+ 
+ 
+ def flip_img(img):
+     return torchvision.transforms.functional.hflip(img)
+ 
+ 
+ def emnist_preprocess():
+     return torchvision.transforms.Compose(
+         [
+             rotate_img,
+             flip_img,
+         ]
+     )
+ 
+ 
+ class EMNIST:
+     def __init__(
+         self,
+         preprocess,
+         location,
+         batch_size=128,
+         num_workers=8,
+     ):
+         preprocess1 = emnist_preprocess()
+         preprocess = torchvision.transforms.Compose(
+             [
+                 preprocess,
+                 preprocess1,
+             ]
+         )
+         # if not os.path.exists(location):
+         #     os.makedirs(location, exist_ok=True)
+ 
+         self.train_dataset = datasets.EMNIST(
+             root=location,
+             download=True,
+             split="digits",
+             transform=preprocess,
+             train=True,
+         )
+ 
+         self.train_loader = torch.utils.data.DataLoader(
+             self.train_dataset,
+             batch_size=batch_size,
+             shuffle=True,
+             num_workers=num_workers,
+         )
+ 
+         self.test_dataset = datasets.EMNIST(
+             root=location,
+             download=True,
+             split="digits",
+             transform=preprocess,
+             train=False,
+         )
+ 
+         self.test_loader = torch.utils.data.DataLoader(
+             self.test_dataset,
+             batch_size=32,
+             shuffle=False,
+             num_workers=num_workers,
+         )
+ 
+         self.classnames = self.train_dataset.classes
src/datasets/eurosat.py ADDED
@@ -0,0 +1,75 @@
+ import os
+ import torch
+ import torchvision.datasets as datasets
+ import re
+ 
+ def pretify_classname(classname):
+     l = re.findall(r'[A-Z](?:[a-z]+|[A-Z]*(?=[A-Z]|$))', classname)
+     l = [i.lower() for i in l]
+     out = ' '.join(l)
+     if out.endswith('al'):
+         return out + ' area'
+     return out
+ 
+ class EuroSATBase:
+     def __init__(self,
+                  preprocess,
+                  test_split,
+                  location='~/datasets',
+                  batch_size=32,
+                  num_workers=16):
+         # Data loading code
+         traindir = os.path.join(location, 'EuroSAT_splits', 'train')
+         testdir = os.path.join(location, 'EuroSAT_splits', test_split)
+ 
+         self.train_dataset = datasets.ImageFolder(traindir, transform=preprocess)
+         self.train_loader = torch.utils.data.DataLoader(
+             self.train_dataset,
+             shuffle=True,
+             batch_size=batch_size,
+             num_workers=num_workers,
+         )
+ 
+         self.test_dataset = datasets.ImageFolder(testdir, transform=preprocess)
+         self.test_loader = torch.utils.data.DataLoader(
+             self.test_dataset,
+             batch_size=batch_size,
+             num_workers=num_workers
+         )
+         idx_to_class = dict((v, k)
+                             for k, v in self.train_dataset.class_to_idx.items())
+         self.classnames = [idx_to_class[i].replace('_', ' ') for i in range(len(idx_to_class))]
+         self.classnames = [pretify_classname(c) for c in self.classnames]
+         ours_to_open_ai = {
+             'annual crop': 'annual crop land',
+             'forest': 'forest',
+             'herbaceous vegetation': 'brushland or shrubland',
+             'highway': 'highway or road',
+             'industrial area': 'industrial buildings or commercial buildings',
+             'pasture': 'pasture land',
+             'permanent crop': 'permanent crop land',
+             'residential area': 'residential buildings or homes or apartments',
+             'river': 'river',
+             'sea lake': 'lake or sea',
+         }
+         for i in range(len(self.classnames)):
+             self.classnames[i] = ours_to_open_ai[self.classnames[i]]
+ 
+ 
+ class EuroSAT(EuroSATBase):
+     def __init__(self,
+                  preprocess,
+                  location='~/datasets',
+                  batch_size=32,
+                  num_workers=16):
+         super().__init__(preprocess, 'test', location, batch_size, num_workers)
+ 
+ 
+ class EuroSATVal(EuroSATBase):
+     def __init__(self,
+                  preprocess,
+                  location='~/datasets',
+                  batch_size=32,
+                  num_workers=16):
+         super().__init__(preprocess, 'val', location, batch_size, num_workers)
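The `pretify_classname` helper above splits CamelCase folder names into lower-cased words and appends "area" to names ending in "al" (so "Industrial" and "Residential" map onto the prompt vocabulary). A lightly cleaned, stand-alone copy of that logic:

```python
import re

# Same regex and suffix rule as pretify_classname in eurosat.py.
def pretify_classname(classname):
    parts = re.findall(r'[A-Z](?:[a-z]+|[A-Z]*(?=[A-Z]|$))', classname)
    out = ' '.join(p.lower() for p in parts)
    if out.endswith('al'):
        return out + ' area'
    return out

print(pretify_classname('AnnualCrop'))   # annual crop
print(pretify_classname('Industrial'))   # industrial area
print(pretify_classname('Residential'))  # residential area
```

The output strings are exactly the keys of `ours_to_open_ai`, which then maps them to the richer CLIP prompt phrases.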
src/datasets/gtsrb.py ADDED
@@ -0,0 +1,205 @@
+ import csv
+ import os
+ import pathlib
+ from typing import Any, Callable, Dict, List, Optional, Tuple
+ 
+ import numpy as np
+ import PIL
+ import torch
+ from torchvision.datasets.folder import make_dataset
+ from torchvision.datasets.utils import (download_and_extract_archive,
+                                         verify_str_arg)
+ from torchvision.datasets.vision import VisionDataset
+ 
+ def find_classes(directory: str) -> Tuple[List[str], Dict[str, int]]:
+     """Finds the class folders in a dataset.
+ 
+     See :class:`DatasetFolder` for details.
+     """
+     classes = sorted(entry.name for entry in os.scandir(directory) if entry.is_dir())
+     if not classes:
+         raise FileNotFoundError(f"Couldn't find any class folder in {directory}.")
+ 
+     class_to_idx = {cls_name: i for i, cls_name in enumerate(classes)}
+     return classes, class_to_idx
+ 
+ class PyTorchGTSRB(VisionDataset):
+     """`German Traffic Sign Recognition Benchmark (GTSRB) <https://benchmark.ini.rub.de/>`_ Dataset.
+ 
+     Modified from https://pytorch.org/vision/main/_modules/torchvision/datasets/gtsrb.html#GTSRB.
+ 
+     Args:
+         root (string): Root directory of the dataset.
+         split (string, optional): The dataset split, supports ``"train"`` (default), or ``"test"``.
+         transform (callable, optional): A function/transform that takes in a PIL image and returns a transformed
+             version. E.g., ``transforms.RandomCrop``.
+         target_transform (callable, optional): A function/transform that takes in the target and transforms it.
+         download (bool, optional): If True, downloads the dataset from the internet and
+             puts it in root directory. If dataset is already downloaded, it is not
+             downloaded again.
+     """
+ 
+     def __init__(
+         self,
+         root: str,
+         split: str = "train",
+         transform: Optional[Callable] = None,
+         target_transform: Optional[Callable] = None,
+         download: bool = False,
+     ) -> None:
+ 
+         super().__init__(root, transform=transform, target_transform=target_transform)
+ 
+         self._split = verify_str_arg(split, "split", ("train", "test"))
+         self._base_folder = pathlib.Path(root) / "gtsrb"
+         self._target_folder = (
+             self._base_folder / "GTSRB" / ("Training" if self._split == "train" else "Final_Test/Images")
+         )
+ 
+         if download:
+             self.download()
+ 
+         if not self._check_exists():
+             raise RuntimeError("Dataset not found. You can use download=True to download it")
+ 
+         if self._split == "train":
+             _, class_to_idx = find_classes(str(self._target_folder))
+             samples = make_dataset(str(self._target_folder), extensions=(".ppm",), class_to_idx=class_to_idx)
+         else:
+             with open(self._base_folder / "GT-final_test.csv") as csv_file:
+                 samples = [
+                     (str(self._target_folder / row["Filename"]), int(row["ClassId"]))
+                     for row in csv.DictReader(csv_file, delimiter=";", skipinitialspace=True)
+                 ]
+ 
+         self._samples = samples
+         self.transform = transform
+         self.target_transform = target_transform
+ 
+     def __len__(self) -> int:
+         return len(self._samples)
+ 
+     def __getitem__(self, index: int) -> Tuple[Any, Any]:
+ 
+         path, target = self._samples[index]
+         sample = PIL.Image.open(path).convert("RGB")
+ 
+         if self.transform is not None:
+             sample = self.transform(sample)
+ 
+         if self.target_transform is not None:
+             target = self.target_transform(target)
+ 
+         return sample, target
+ 
+     def _check_exists(self) -> bool:
+         return self._target_folder.is_dir()
+ 
+     def download(self) -> None:
+         if self._check_exists():
+             return
+ 
+         base_url = "https://sid.erda.dk/public/archives/daaeac0d7ce1152aea9b61d9f1e19370/"
+ 
+         if self._split == "train":
+             download_and_extract_archive(
+                 f"{base_url}GTSRB-Training_fixed.zip",
+                 download_root=str(self._base_folder),
+                 md5="513f3c79a4c5141765e10e952eaa2478",
+             )
+         else:
+             download_and_extract_archive(
+                 f"{base_url}GTSRB_Final_Test_Images.zip",
+                 download_root=str(self._base_folder),
+                 md5="c7e4e6327067d32654124b0fe9e82185",
+             )
+             download_and_extract_archive(
+                 f"{base_url}GTSRB_Final_Test_GT.zip",
+                 download_root=str(self._base_folder),
+                 md5="fe31e9c9270bbcd7b84b7f21a9d9d9e5",
+             )
+ 
+ 
+ class GTSRB:
+     def __init__(self,
+                  preprocess,
+                  location=os.path.expanduser('~/data'),
+                  batch_size=128,
+                  num_workers=16):
+ 
+         # to fit with repo conventions for location
+         self.train_dataset = PyTorchGTSRB(
+             root=location,
+             download=True,
+             split='train',
+             transform=preprocess
+         )
+ 
+         self.train_loader = torch.utils.data.DataLoader(
+             self.train_dataset,
+             batch_size=batch_size,
+             shuffle=True,
+             num_workers=num_workers
+         )
+ 
+         self.test_dataset = PyTorchGTSRB(
+             root=location,
+             download=True,
+             split='test',
+             transform=preprocess
+         )
+ 
+         self.test_loader = torch.utils.data.DataLoader(
+             self.test_dataset,
+             batch_size=batch_size,
+             shuffle=False,
+             num_workers=num_workers
+         )
+ 
+         # from https://github.com/openai/CLIP/blob/e184f608c5d5e58165682f7c332c3a8b4c1545f2/data/prompts.md
+         self.classnames = [
+             'red and white circle 20 kph speed limit',
+             'red and white circle 30 kph speed limit',
+             'red and white circle 50 kph speed limit',
+             'red and white circle 60 kph speed limit',
+             'red and white circle 70 kph speed limit',
+             'red and white circle 80 kph speed limit',
+             'end / de-restriction of 80 kph speed limit',
+             'red and white circle 100 kph speed limit',
+             'red and white circle 120 kph speed limit',
+             'red and white circle red car and black car no passing',
+             'red and white circle red truck and black car no passing',
+             'red and white triangle road intersection warning',
+             'white and yellow diamond priority road',
+             'red and white upside down triangle yield right-of-way',
+             'stop',
+             'empty red and white circle',
+             'red and white circle no truck entry',
+             'red circle with white horizonal stripe no entry',
+             'red and white triangle with exclamation mark warning',
+             'red and white triangle with black left curve approaching warning',
+             'red and white triangle with black right curve approaching warning',
+             'red and white triangle with black double curve approaching warning',
+             'red and white triangle rough / bumpy road warning',
+             'red and white triangle car skidding / slipping warning',
+             'red and white triangle with merging / narrow lanes warning',
+             'red and white triangle with person digging / construction / road work warning',
+             'red and white triangle with traffic light approaching warning',
+             'red and white triangle with person walking warning',
+             'red and white triangle with child and person walking warning',
+             'red and white triangle with bicyle warning',
+             'red and white triangle with snowflake / ice warning',
+             'red and white triangle with deer warning',
+             'white circle with gray strike bar no speed limit',
+             'blue circle with white right turn arrow mandatory',
+             'blue circle with white left turn arrow mandatory',
+             'blue circle with white forward arrow mandatory',
+             'blue circle with white forward or right turn arrow mandatory',
+             'blue circle with white forward or left turn arrow mandatory',
+             'blue circle with white keep right arrow mandatory',
+             'blue circle with white keep left arrow mandatory',
+             'blue circle with white arrows indicating a traffic circle',
+             'white circle with gray strike bar indicating no passing for cars has ended',
+             'white circle with gray strike bar indicating no passing for trucks has ended',
+         ]
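The GTSRB test split above is parsed from a semicolon-delimited `GT-final_test.csv` via `csv.DictReader`. That parsing step can be sketched on an in-memory example (the rows below are made up, not real GTSRB annotations):

```python
import csv
import io

# Hypothetical rows mimicking the GT-final_test.csv layout (semicolon-delimited).
fake_csv = "Filename;ClassId\n00000.ppm;16\n00001.ppm;1\n"

# Same comprehension shape as PyTorchGTSRB's test branch, minus the path prefix.
samples = [
    (row["Filename"], int(row["ClassId"]))
    for row in csv.DictReader(io.StringIO(fake_csv), delimiter=";", skipinitialspace=True)
]
print(samples)  # [('00000.ppm', 16), ('00001.ppm', 1)]
```

`DictReader` keys each row by the header line, so the code stays robust to extra columns in the real annotation file.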
src/datasets/imagenet.py ADDED
@@ -0,0 +1,253 @@
+ import os
+ import torch
+ 
+ from .common import ImageFolderWithPaths, SubsetSampler
+ import numpy as np
+ 
+ 
+ imagenet_classnames = [
+     "tench", "goldfish", "great white shark", "tiger shark", "hammerhead shark", "electric ray",
+     "stingray", "rooster", "hen", "ostrich", "brambling", "goldfinch", "house finch", "junco",
+     "indigo bunting", "American robin", "bulbul", "jay", "magpie", "chickadee", "American dipper",
+     "kite (bird of prey)", "bald eagle", "vulture", "great grey owl", "fire salamander",
+     "smooth newt", "newt", "spotted salamander", "axolotl", "American bullfrog", "tree frog",
+     "tailed frog", "loggerhead sea turtle", "leatherback sea turtle", "mud turtle", "terrapin",
+     "box turtle", "banded gecko", "green iguana", "Carolina anole",
+     "desert grassland whiptail lizard", "agama", "frilled-necked lizard", "alligator lizard",
+     "Gila monster", "European green lizard", "chameleon", "Komodo dragon", "Nile crocodile",
+     "American alligator", "triceratops", "worm snake", "ring-necked snake",
+     "eastern hog-nosed snake", "smooth green snake", "kingsnake", "garter snake", "water snake",
+     "vine snake", "night snake", "boa constrictor", "African rock python", "Indian cobra",
+     "green mamba", "sea snake", "Saharan horned viper", "eastern diamondback rattlesnake",
+     "sidewinder rattlesnake", "trilobite", "harvestman", "scorpion", "yellow garden spider",
+     "barn spider", "European garden spider", "southern black widow", "tarantula", "wolf spider",
+     "tick", "centipede", "black grouse", "ptarmigan", "ruffed grouse", "prairie grouse", "peafowl",
+     "quail", "partridge", "african grey parrot", "macaw", "sulphur-crested cockatoo", "lorikeet",
+     "coucal", "bee eater", "hornbill", "hummingbird", "jacamar", "toucan", "duck",
+     "red-breasted merganser", "goose", "black swan", "tusker", "echidna", "platypus", "wallaby",
+     "koala", "wombat", "jellyfish", "sea anemone", "brain coral", "flatworm", "nematode", "conch",
+     "snail", "slug", "sea slug", "chiton", "chambered nautilus", "Dungeness crab", "rock crab",
+     "fiddler crab", "red king crab", "American lobster", "spiny lobster", "crayfish", "hermit crab",
+     "isopod", "white stork", "black stork", "spoonbill", "flamingo", "little blue heron",
+     "great egret", "bittern bird", "crane bird", "limpkin", "common gallinule", "American coot",
+     "bustard", "ruddy turnstone", "dunlin", "common redshank", "dowitcher", "oystercatcher",
+     "pelican", "king penguin", "albatross", "grey whale", "killer whale", "dugong", "sea lion",
+     "Chihuahua", "Japanese Chin", "Maltese", "Pekingese", "Shih Tzu", "King Charles Spaniel",
+     "Papillon", "toy terrier", "Rhodesian Ridgeback", "Afghan Hound", "Basset Hound", "Beagle",
+     "Bloodhound", "Bluetick Coonhound", "Black and Tan Coonhound", "Treeing Walker Coonhound",
+     "English foxhound", "Redbone Coonhound", "borzoi", "Irish Wolfhound", "Italian Greyhound",
+     "Whippet", "Ibizan Hound", "Norwegian Elkhound", "Otterhound", "Saluki", "Scottish Deerhound",
+     "Weimaraner", "Staffordshire Bull Terrier", "American Staffordshire Terrier",
+     "Bedlington Terrier", "Border Terrier", "Kerry Blue Terrier", "Irish Terrier",
+     "Norfolk Terrier", "Norwich Terrier", "Yorkshire Terrier", "Wire Fox Terrier",
+     "Lakeland Terrier", "Sealyham Terrier", "Airedale Terrier", "Cairn Terrier",
+     "Australian Terrier", "Dandie Dinmont Terrier", "Boston Terrier", "Miniature Schnauzer",
+     "Giant Schnauzer", "Standard Schnauzer", "Scottish Terrier", "Tibetan Terrier",
+     "Australian Silky Terrier", "Soft-coated Wheaten Terrier", "West Highland White Terrier",
+     "Lhasa Apso", "Flat-Coated Retriever", "Curly-coated Retriever", "Golden Retriever",
+     "Labrador Retriever", "Chesapeake Bay Retriever", "German Shorthaired Pointer", "Vizsla",
+     "English Setter", "Irish Setter", "Gordon Setter", "Brittany dog", "Clumber Spaniel",
+     "English Springer Spaniel", "Welsh Springer Spaniel", "Cocker Spaniel", "Sussex Spaniel",
+     "Irish Water Spaniel", "Kuvasz", "Schipperke", "Groenendael dog", "Malinois", "Briard",
+     "Australian Kelpie", "Komondor", "Old English Sheepdog", "Shetland Sheepdog", "collie",
+     "Border Collie", "Bouvier des Flandres dog", "Rottweiler", "German Shepherd Dog", "Dobermann",
+     "Miniature Pinscher", "Greater Swiss Mountain Dog", "Bernese Mountain Dog",
+     "Appenzeller Sennenhund", "Entlebucher Sennenhund", "Boxer", "Bullmastiff", "Tibetan Mastiff",
+     "French Bulldog", "Great Dane", "St. Bernard", "husky", "Alaskan Malamute", "Siberian Husky",
+     "Dalmatian", "Affenpinscher", "Basenji", "pug", "Leonberger", "Newfoundland dog",
+     "Great Pyrenees dog", "Samoyed", "Pomeranian", "Chow Chow", "Keeshond", "brussels griffon",
+     "Pembroke Welsh Corgi", "Cardigan Welsh Corgi", "Toy Poodle", "Miniature Poodle",
+     "Standard Poodle", "Mexican hairless dog (xoloitzcuintli)", "grey wolf", "Alaskan tundra wolf",
+     "red wolf or maned wolf", "coyote", "dingo", "dhole", "African wild dog", "hyena", "red fox",
+     "kit fox", "Arctic fox", "grey fox", "tabby cat", "tiger cat", "Persian cat", "Siamese cat",
+     "Egyptian Mau", "cougar", "lynx", "leopard", "snow leopard", "jaguar", "lion", "tiger",
+     "cheetah", "brown bear", "American black bear", "polar bear", "sloth bear", "mongoose",
+     "meerkat", "tiger beetle", "ladybug", "ground beetle", "longhorn beetle", "leaf beetle",
+     "dung beetle", "rhinoceros beetle", "weevil", "fly", "bee", "ant", "grasshopper",
+     "cricket insect", "stick insect", "cockroach", "praying mantis", "cicada", "leafhopper",
+     "lacewing", "dragonfly", "damselfly", "red admiral butterfly", "ringlet butterfly",
+     "monarch butterfly", "small white butterfly", "sulphur butterfly", "gossamer-winged butterfly",
+     "starfish", "sea urchin", "sea cucumber", "cottontail rabbit", "hare", "Angora rabbit",
+     "hamster", "porcupine", "fox squirrel", "marmot", "beaver", "guinea pig", "common sorrel horse",
+     "zebra", "pig", "wild boar", "warthog", "hippopotamus", "ox", "water buffalo", "bison",
+     "ram (adult male sheep)", "bighorn sheep", "Alpine ibex", "hartebeest", "impala (antelope)",
+     "gazelle", "arabian camel", "llama", "weasel", "mink", "European polecat",
+     "black-footed ferret", "otter", "skunk", "badger", "armadillo", "three-toed sloth", "orangutan",
+     "gorilla", "chimpanzee", "gibbon", "siamang", "guenon", "patas monkey", "baboon", "macaque",
+     "langur", "black-and-white colobus", "proboscis monkey", "marmoset", "white-headed capuchin",
+     "howler monkey", "titi monkey", "Geoffroy's spider monkey", "common squirrel monkey",
+     "ring-tailed lemur", "indri", "Asian elephant", "African bush elephant", "red panda",
+     "giant panda", "snoek fish", "eel", "silver salmon", "rock beauty fish", "clownfish",
+     "sturgeon", "gar fish", "lionfish", "pufferfish", "abacus", "abaya", "academic gown",
+     "accordion", "acoustic guitar", "aircraft carrier", "airliner", "airship", "altar", "ambulance",
+     "amphibious vehicle", "analog clock", "apiary", "apron", "trash can", "assault rifle",
+     "backpack", "bakery", "balance beam", "balloon", "ballpoint pen", "Band-Aid", "banjo",
+     "baluster / handrail", "barbell", "barber chair", "barbershop", "barn", "barometer", "barrel",
+     "wheelbarrow", "baseball", "basketball", "bassinet", "bassoon", "swimming cap", "bath towel",
+     "bathtub", "station wagon", "lighthouse", "beaker", "military hat (bearskin or shako)",
+     "beer bottle", "beer glass", "bell tower", "baby bib", "tandem bicycle", "bikini",
+     "ring binder", "binoculars", "birdhouse", "boathouse", "bobsleigh", "bolo tie", "poke bonnet",
+     "bookcase", "bookstore", "bottle cap", "hunting bow", "bow tie", "brass memorial plaque", "bra",
+     "breakwater", "breastplate", "broom", "bucket", "buckle", "bulletproof vest",
+     "high-speed train", "butcher shop", "taxicab", "cauldron", "candle", "cannon", "canoe",
+     "can opener", "cardigan", "car mirror", "carousel", "tool kit", "cardboard box / carton",
+     "car wheel", "automated teller machine", "cassette", "cassette player", "castle", "catamaran",
+     "CD player", "cello", "mobile phone", "chain", "chain-link fence", "chain mail", "chainsaw",
+     "storage chest", "chiffonier", "bell or wind chime", "china cabinet", "Christmas stocking",
+     "church", "movie theater", "cleaver", "cliff dwelling", "cloak", "clogs", "cocktail shaker",
+     "coffee mug", "coffeemaker", "spiral or coil", "combination lock", "computer keyboard",
+     "candy store", "container ship", "convertible", "corkscrew", "cornet", "cowboy boot",
+     "cowboy hat", "cradle", "construction crane", "crash helmet", "crate", "infant bed",
101
+ "Crock Pot", "croquet ball", "crutch", "cuirass", "dam", "desk", "desktop computer",
102
+ "rotary dial telephone", "diaper", "digital clock", "digital watch", "dining table",
103
+ "dishcloth", "dishwasher", "disc brake", "dock", "dog sled", "dome", "doormat", "drilling rig",
104
+ "drum", "drumstick", "dumbbell", "Dutch oven", "electric fan", "electric guitar",
105
+ "electric locomotive", "entertainment center", "envelope", "espresso machine", "face powder",
106
+ "feather boa", "filing cabinet", "fireboat", "fire truck", "fire screen", "flagpole", "flute",
107
+ "folding chair", "football helmet", "forklift", "fountain", "fountain pen", "four-poster bed",
108
+ "freight car", "French horn", "frying pan", "fur coat", "garbage truck",
109
+ "gas mask or respirator", "gas pump", "goblet", "go-kart", "golf ball", "golf cart", "gondola",
110
+ "gong", "gown", "grand piano", "greenhouse", "radiator grille", "grocery store", "guillotine",
111
+ "hair clip", "hair spray", "half-track", "hammer", "hamper", "hair dryer", "hand-held computer",
112
+ "handkerchief", "hard disk drive", "harmonica", "harp", "combine harvester", "hatchet",
113
+ "holster", "home theater", "honeycomb", "hook", "hoop skirt", "gymnastic horizontal bar",
114
+ "horse-drawn vehicle", "hourglass", "iPod", "clothes iron", "carved pumpkin", "jeans", "jeep",
115
+ "T-shirt", "jigsaw puzzle", "rickshaw", "joystick", "kimono", "knee pad", "knot", "lab coat",
116
+ "ladle", "lampshade", "laptop computer", "lawn mower", "lens cap", "letter opener", "library",
117
+ "lifeboat", "lighter", "limousine", "ocean liner", "lipstick", "slip-on shoe", "lotion",
118
+ "music speaker", "loupe magnifying glass", "sawmill", "magnetic compass", "messenger bag",
119
+ "mailbox", "tights", "one-piece bathing suit", "manhole cover", "maraca", "marimba", "mask",
120
+ "matchstick", "maypole", "maze", "measuring cup", "medicine cabinet", "megalith", "microphone",
121
+ "microwave oven", "military uniform", "milk can", "minibus", "miniskirt", "minivan", "missile",
122
+ "mitten", "mixing bowl", "mobile home", "ford model t", "modem", "monastery", "monitor",
123
+ "moped", "mortar and pestle", "graduation cap", "mosque", "mosquito net", "vespa",
124
+ "mountain bike", "tent", "computer mouse", "mousetrap", "moving van", "muzzle", "metal nail",
125
+ "neck brace", "necklace", "baby pacifier", "notebook computer", "obelisk", "oboe", "ocarina",
126
+ "odometer", "oil filter", "pipe organ", "oscilloscope", "overskirt", "bullock cart",
127
+ "oxygen mask", "product packet / packaging", "paddle", "paddle wheel", "padlock", "paintbrush",
128
+ "pajamas", "palace", "pan flute", "paper towel", "parachute", "parallel bars", "park bench",
129
+ "parking meter", "railroad car", "patio", "payphone", "pedestal", "pencil case",
130
+ "pencil sharpener", "perfume", "Petri dish", "photocopier", "plectrum", "Pickelhaube",
131
+ "picket fence", "pickup truck", "pier", "piggy bank", "pill bottle", "pillow", "ping-pong ball",
132
+ "pinwheel", "pirate ship", "drink pitcher", "block plane", "planetarium", "plastic bag",
133
+ "plate rack", "farm plow", "plunger", "Polaroid camera", "pole", "police van", "poncho",
134
+ "pool table", "soda bottle", "plant pot", "potter's wheel", "power drill", "prayer rug",
135
+ "printer", "prison", "missile", "projector", "hockey puck", "punching bag", "purse", "quill",
136
+ "quilt", "race car", "racket", "radiator", "radio", "radio telescope", "rain barrel",
137
+ "recreational vehicle", "fishing casting reel", "reflex camera", "refrigerator",
138
+ "remote control", "restaurant", "revolver", "rifle", "rocking chair", "rotisserie", "eraser",
139
+ "rugby ball", "ruler measuring stick", "sneaker", "safe", "safety pin", "salt shaker", "sandal",
140
+ "sarong", "saxophone", "scabbard", "weighing scale", "school bus", "schooner", "scoreboard",
141
+ "CRT monitor", "screw", "screwdriver", "seat belt", "sewing machine", "shield", "shoe store",
142
+ "shoji screen / room divider", "shopping basket", "shopping cart", "shovel", "shower cap",
143
+ "shower curtain", "ski", "balaclava ski mask", "sleeping bag", "slide rule", "sliding door",
144
+ "slot machine", "snorkel", "snowmobile", "snowplow", "soap dispenser", "soccer ball", "sock",
145
+ "solar thermal collector", "sombrero", "soup bowl", "keyboard space bar", "space heater",
146
+ "space shuttle", "spatula", "motorboat", "spider web", "spindle", "sports car", "spotlight",
147
+ "stage", "steam locomotive", "through arch bridge", "steel drum", "stethoscope", "scarf",
148
+ "stone wall", "stopwatch", "stove", "strainer", "tram", "stretcher", "couch", "stupa",
149
+ "submarine", "suit", "sundial", "sunglasses", "sunglasses", "sunscreen", "suspension bridge",
150
+ "mop", "sweatshirt", "swim trunks / shorts", "swing", "electrical switch", "syringe",
151
+ "table lamp", "tank", "tape player", "teapot", "teddy bear", "television", "tennis ball",
152
+ "thatched roof", "front curtain", "thimble", "threshing machine", "throne", "tile roof",
153
+ "toaster", "tobacco shop", "toilet seat", "torch", "totem pole", "tow truck", "toy store",
154
+ "tractor", "semi-trailer truck", "tray", "trench coat", "tricycle", "trimaran", "tripod",
155
+ "triumphal arch", "trolleybus", "trombone", "hot tub", "turnstile", "typewriter keyboard",
156
+ "umbrella", "unicycle", "upright piano", "vacuum cleaner", "vase", "vaulted or arched ceiling",
157
+ "velvet fabric", "vending machine", "vestment", "viaduct", "violin", "volleyball",
158
+ "waffle iron", "wall clock", "wallet", "wardrobe", "military aircraft", "sink",
159
+ "washing machine", "water bottle", "water jug", "water tower", "whiskey jug", "whistle",
160
+ "hair wig", "window screen", "window shade", "Windsor tie", "wine bottle", "airplane wing",
161
+ "wok", "wooden spoon", "wool", "split-rail fence", "shipwreck", "sailboat", "yurt", "website",
162
+ "comic book", "crossword", "traffic or street sign", "traffic light", "dust jacket", "menu",
163
+ "plate", "guacamole", "consomme", "hot pot", "trifle", "ice cream", "popsicle", "baguette",
164
+ "bagel", "pretzel", "cheeseburger", "hot dog", "mashed potatoes", "cabbage", "broccoli",
165
+ "cauliflower", "zucchini", "spaghetti squash", "acorn squash", "butternut squash", "cucumber",
166
+ "artichoke", "bell pepper", "cardoon", "mushroom", "Granny Smith apple", "strawberry", "orange",
167
+ "lemon", "fig", "pineapple", "banana", "jackfruit", "cherimoya (custard apple)", "pomegranate",
168
+ "hay", "carbonara", "chocolate syrup", "dough", "meatloaf", "pizza", "pot pie", "burrito",
169
+ "red wine", "espresso", "tea cup", "eggnog", "mountain", "bubble", "cliff", "coral reef",
170
+ "geyser", "lakeshore", "promontory", "sandbar", "beach", "valley", "volcano", "baseball player",
171
+ "bridegroom", "scuba diver", "rapeseed", "daisy", "yellow lady's slipper", "corn", "acorn",
172
+ "rose hip", "horse chestnut seed", "coral fungus", "agaric", "gyromitra", "stinkhorn mushroom",
173
+ "earth star fungus", "hen of the woods mushroom", "bolete", "corn cob", "toilet paper"
174
+ ]
175
+
176
+ class ImageNet:
177
+ def __init__(self,
178
+ preprocess,
179
+ location='~/data',
180
+ batch_size=32,
181
+ num_workers=32):
182
+ self.preprocess = preprocess
183
+ self.location = '/path/ImageNet2012/' # TODO
184
+ self.batch_size = batch_size
185
+ self.num_workers = num_workers
186
+ self.classnames = imagenet_classnames
187
+
188
+ self.populate_train()
189
+ self.populate_test()
190
+
191
+ def populate_train(self):
192
+ traindir = os.path.join(self.location, 'train')
193
+ self.train_dataset = ImageFolderWithPaths(
194
+ traindir,
195
+ transform=self.preprocess)
196
+ sampler = self.get_train_sampler()
197
+ kwargs = {'shuffle' : True} if sampler is None else {}
198
+ self.train_loader = torch.utils.data.DataLoader(
199
+ self.train_dataset,
200
+ sampler=sampler,
201
+ batch_size=self.batch_size,
202
+ num_workers=self.num_workers,
203
+ **kwargs,
204
+ )
205
+
206
+ def populate_test(self):
207
+ self.test_dataset = self.get_test_dataset()
208
+ self.test_loader = torch.utils.data.DataLoader(
209
+ self.test_dataset,
210
+ batch_size=self.batch_size,
211
+ num_workers=self.num_workers,
212
+ sampler=self.get_test_sampler()
213
+ )
214
+
215
+ def get_test_path(self):
216
+ test_path = os.path.join(self.location, 'val_dir')
217
+ if not os.path.exists(test_path):
218
+ test_path = os.path.join(self.location,'val')
219
+ return test_path
220
+
221
+ def get_train_sampler(self):
222
+ return None
223
+
224
+ def get_test_sampler(self):
225
+ return None
226
+
227
+ def get_test_dataset(self):
228
+ return ImageFolderWithPaths(self.get_test_path(), transform=self.preprocess)
229
+
230
+ def name(self):
231
+ return 'imagenet'
232
+
233
+ class ImageNetTrain(ImageNet):
234
+
235
+ def get_test_dataset(self):
236
+ pass
237
+
238
+ class ImageNetK(ImageNet):
239
+
240
+ def get_train_sampler(self):
241
+ idxs = np.zeros(len(self.train_dataset.targets))
242
+ target_array = np.array(self.train_dataset.targets)
243
+ for c in range(1000):
244
+ m = target_array == c
245
+ n = len(idxs[m])
246
+ arr = np.zeros(n)
247
+ arr[:self.k()] = 1
248
+ np.random.shuffle(arr)
249
+ idxs[m] = arr
250
+
251
+ idxs = idxs.astype('int')
252
+ sampler = SubsetSampler(np.where(idxs)[0])
253
+ return sampler
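
`ImageNetK.get_train_sampler` builds a k-shot subset by, for each of the 1000 classes, marking `self.k()` entries of a 0/1 mask and shuffling it within that class. The same per-class subsampling can be sketched in pure Python (the `k_shot_indices` helper and the toy `targets` list below are illustrative, not part of this repo):

```python
import random

def k_shot_indices(targets, k, num_classes, seed=0):
    # Mirror the sampler's idea: for each class, randomly keep
    # at most k of its example indices.
    rng = random.Random(seed)
    chosen = []
    for c in range(num_classes):
        class_idxs = [i for i, t in enumerate(targets) if t == c]
        rng.shuffle(class_idxs)   # random k-shot subset within the class
        chosen.extend(class_idxs[:k])
    return sorted(chosen)

targets = [0, 1, 0, 1, 0, 2]      # 3 classes, unbalanced
subset = k_shot_indices(targets, k=2, num_classes=3)
print(subset)
```

In the real class the chosen indices are wrapped in a `SubsetSampler` and handed to the `DataLoader`, which is why `populate_train` only passes `shuffle=True` when no sampler is set.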
src/datasets/kmnist.py ADDED
@@ -0,0 +1,39 @@
+ import os
+
+ import torch
+ import torchvision.datasets as datasets
+
+
+ class KMNIST:
+     def __init__(
+         self,
+         preprocess,
+         location=os.path.expanduser("~/data"),
+         batch_size=128,
+         num_workers=6,
+     ):
+
+         location = os.path.join(location, "KMNIST")
+         self.train_dataset = datasets.KMNIST(
+             root=location, download=True, train=True, transform=preprocess
+         )
+
+         self.train_loader = torch.utils.data.DataLoader(
+             self.train_dataset,
+             batch_size=batch_size,
+             shuffle=True,
+             num_workers=num_workers,
+         )
+
+         self.test_dataset = datasets.KMNIST(
+             root=location, download=True, train=False, transform=preprocess
+         )
+
+         self.test_loader = torch.utils.data.DataLoader(
+             self.test_dataset,
+             batch_size=batch_size,
+             shuffle=False,
+             num_workers=num_workers,
+         )
+
+         self.classnames = self.train_dataset.classes
src/datasets/mnist.py ADDED
@@ -0,0 +1,41 @@
+ import os
+ import torch
+ import torchvision.datasets as datasets
+
+ class MNIST:
+     def __init__(self,
+                  preprocess,
+                  location=os.path.expanduser('~/data'),
+                  batch_size=128,
+                  num_workers=16):
+
+
+         self.train_dataset = datasets.MNIST(
+             root=location,
+             download=True,
+             train=True,
+             transform=preprocess
+         )
+
+         self.train_loader = torch.utils.data.DataLoader(
+             self.train_dataset,
+             batch_size=batch_size,
+             shuffle=True,
+             num_workers=num_workers
+         )
+
+         self.test_dataset = datasets.MNIST(
+             root=location,
+             download=True,
+             train=False,
+             transform=preprocess
+         )
+
+         self.test_loader = torch.utils.data.DataLoader(
+             self.test_dataset,
+             batch_size=batch_size,
+             shuffle=False,
+             num_workers=num_workers
+         )
+
+         self.classnames = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9']
src/datasets/oxfordpets.py ADDED
@@ -0,0 +1,38 @@
+ import os
+ import torch
+ import torchvision.datasets as datasets
+
+
+ class OxfordIIITPet:
+     def __init__(
+         self,
+         preprocess,
+         location=os.path.expanduser("~/data"),
+         batch_size=128,
+         num_workers=6,
+     ):
+
+         location = os.path.join(location, "OxfordIIITPet")
+         self.train_dataset = datasets.OxfordIIITPet(
+             root=location, download=True, split="trainval", transform=preprocess
+         )
+
+         self.train_loader = torch.utils.data.DataLoader(
+             self.train_dataset,
+             batch_size=batch_size,
+             shuffle=True,
+             num_workers=num_workers,
+         )
+
+         self.test_dataset = datasets.OxfordIIITPet(
+             root=location, download=True, split="test", transform=preprocess
+         )
+
+         self.test_loader = torch.utils.data.DataLoader(
+             self.test_dataset,
+             batch_size=batch_size,
+             shuffle=False,
+             num_workers=num_workers,
+         )
+
+         self.classnames = self.train_dataset.classes
src/datasets/registry.py ADDED
@@ -0,0 +1,103 @@
+ import sys
+ import inspect
+ import random
+ import torch
+ import copy
+
+ from torch.utils.data.dataset import random_split
+
+ from src.datasets.cars import Cars
+ from src.datasets.cifar10 import CIFAR10
+ from src.datasets.cifar100 import CIFAR100
+ from src.datasets.dtd import DTD
+ from src.datasets.eurosat import EuroSAT, EuroSATVal
+ from src.datasets.gtsrb import GTSRB
+ from src.datasets.imagenet import ImageNet
+ from src.datasets.mnist import MNIST
+ from src.datasets.resisc45 import RESISC45
+ from src.datasets.stl10 import STL10
+ from src.datasets.svhn import SVHN
+ from src.datasets.sun397 import SUN397
+ from src.datasets.emnist import EMNIST
+ from src.datasets.kmnist import KMNIST
+ from src.datasets.oxfordpets import OxfordIIITPet
+
+ registry = {
+     name: obj for name, obj in inspect.getmembers(sys.modules[__name__], inspect.isclass)
+ }
+
+
+ class GenericDataset(object):
+     def __init__(self):
+         self.train_dataset = None
+         self.train_loader = None
+         self.test_dataset = None
+         self.test_loader = None
+         self.classnames = None
+
+
+ def split_train_into_train_val(dataset, new_dataset_class_name, batch_size, num_workers, val_fraction, max_val_samples=None, seed=0):
+     assert val_fraction > 0. and val_fraction < 1.
+     total_size = len(dataset.train_dataset)
+     val_size = int(total_size * val_fraction)
+     if max_val_samples is not None:
+         val_size = min(val_size, max_val_samples)
+     train_size = total_size - val_size
+
+     assert val_size > 0
+     assert train_size > 0
+
+     lengths = [train_size, val_size]
+
+     trainset, valset = random_split(
+         dataset.train_dataset,
+         lengths,
+         generator=torch.Generator().manual_seed(seed)
+     )
+     if new_dataset_class_name == 'MNISTVal':
+         assert trainset.indices[0] == 36044
+
+
+     new_dataset = None
+
+     new_dataset_class = type(new_dataset_class_name, (GenericDataset, ), {})
+     new_dataset = new_dataset_class()
+
+     new_dataset.train_dataset = trainset
+     new_dataset.train_loader = torch.utils.data.DataLoader(
+         new_dataset.train_dataset,
+         shuffle=True,
+         batch_size=batch_size,
+         num_workers=num_workers,
+     )
+
+     new_dataset.test_dataset = valset
+     new_dataset.test_loader = torch.utils.data.DataLoader(
+         new_dataset.test_dataset,
+         batch_size=batch_size,
+         num_workers=num_workers
+     )
+
+     new_dataset.classnames = copy.copy(dataset.classnames)
+
+     return new_dataset
+
+
+ def get_dataset(dataset_name, preprocess, location, batch_size=128, num_workers=16, val_fraction=0.1, max_val_samples=5000):
+     if dataset_name.endswith('Val'):
+         # Handle val splits
+         if dataset_name in registry:
+             dataset_class = registry[dataset_name]
+         else:
+             base_dataset_name = dataset_name.split('Val')[0]
+             base_dataset = get_dataset(base_dataset_name, preprocess, location, batch_size, num_workers)
+             dataset = split_train_into_train_val(
+                 base_dataset, dataset_name, batch_size, num_workers, val_fraction, max_val_samples)
+             return dataset
+     else:
+         assert dataset_name in registry, f'Unsupported dataset: {dataset_name}. Supported datasets: {list(registry.keys())}'
+     dataset_class = registry[dataset_name]
+     dataset = dataset_class(
+         preprocess, location=location, batch_size=batch_size, num_workers=num_workers
+     )
+     return dataset
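
`split_train_into_train_val` sizes the validation split as a fraction of the whole training set, capped by `max_val_samples`, and asserts both halves are non-empty. That sizing arithmetic can be checked in isolation (the `val_split_sizes` helper below is a hypothetical extraction for illustration, not a function in this repo):

```python
def val_split_sizes(total_size, val_fraction, max_val_samples=None):
    # Same sizing rule as split_train_into_train_val:
    # val is a fraction of the whole, optionally capped at max_val_samples.
    assert 0. < val_fraction < 1.
    val_size = int(total_size * val_fraction)
    if max_val_samples is not None:
        val_size = min(val_size, max_val_samples)
    train_size = total_size - val_size
    assert train_size > 0 and val_size > 0
    return train_size, val_size

# MNIST-sized example with the get_dataset defaults (0.1, cap 5000):
print(val_split_sizes(60000, 0.1, max_val_samples=5000))  # (55000, 5000)
```

With the defaults used by `get_dataset` (`val_fraction=0.1`, `max_val_samples=5000`), the cap binds for any training set larger than 50,000 examples, which is why MNIST's val split lands at exactly 5,000.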
src/datasets/resisc45.py ADDED
@@ -0,0 +1,304 @@
+ import os
+ import torch
+
+ import abc
+ import os
+ from typing import Any, Callable, Dict, Optional, Tuple
+
+ import numpy as np
+ import torch
+ from torch import Tensor
+ from torch.utils.data import Dataset
+ from torchvision.datasets import ImageFolder
+ from torchvision.datasets.folder import default_loader as pil_loader
+
+
+ # modified from: https://github.com/microsoft/torchgeo
+ class VisionDataset(Dataset[Dict[str, Any]], abc.ABC):
+     """Abstract base class for datasets lacking geospatial information.
+     This base class is designed for datasets with pre-defined image chips.
+     """
+
+     @abc.abstractmethod
+     def __getitem__(self, index: int) -> Dict[str, Any]:
+         """Return an index within the dataset.
+         Args:
+             index: index to return
+         Returns:
+             data and labels at that index
+         Raises:
+             IndexError: if index is out of range of the dataset
+         """
+
+     @abc.abstractmethod
+     def __len__(self) -> int:
+         """Return the length of the dataset.
+         Returns:
+             length of the dataset
+         """
+
+     def __str__(self) -> str:
+         """Return the informal string representation of the object.
+         Returns:
+             informal string representation
+         """
+         return f"""\
+ {self.__class__.__name__} Dataset
+     type: VisionDataset
+     size: {len(self)}"""
+
+
+ class VisionClassificationDataset(VisionDataset, ImageFolder):
+     """Abstract base class for classification datasets lacking geospatial information.
+     This base class is designed for datasets with pre-defined image chips which
+     are separated into separate folders per class.
+     """
+
+     def __init__(
+         self,
+         root: str,
+         transforms: Optional[Callable[[Dict[str, Tensor]], Dict[str, Tensor]]] = None,
+         loader: Optional[Callable[[str], Any]] = pil_loader,
+         is_valid_file: Optional[Callable[[str], bool]] = None,
+     ) -> None:
+         """Initialize a new VisionClassificationDataset instance.
+         Args:
+             root: root directory where dataset can be found
+             transforms: a function/transform that takes input sample and its target as
+                 entry and returns a transformed version
+             loader: a callable function which takes as input a path to an image and
+                 returns a PIL Image or numpy array
+             is_valid_file: A function that takes the path of an Image file and checks if
+                 the file is a valid file
+         """
+         # When transform & target_transform are None, ImageFolder.__getitem__(index)
+         # returns a PIL.Image and int for image and label, respectively
+         super().__init__(
+             root=root,
+             transform=None,
+             target_transform=None,
+             loader=loader,
+             is_valid_file=is_valid_file,
+         )
+
+         # Must be set after calling super().__init__()
+         self.transforms = transforms
+
+     def __getitem__(self, index: int) -> Dict[str, Tensor]:
+         """Return an index within the dataset.
+         Args:
+             index: index to return
+         Returns:
+             data and label at that index
+         """
+         image, label = self._load_image(index)
+
+         if self.transforms is not None:
+             return self.transforms(image), label
+
+         return image, label
+
+     def __len__(self) -> int:
+         """Return the number of data points in the dataset.
+         Returns:
+             length of the dataset
+         """
+         return len(self.imgs)
+
+     def _load_image(self, index: int) -> Tuple[Tensor, Tensor]:
+         """Load a single image and its class label.
+         Args:
+             index: index to return
+         Returns:
+             the image
+             the image class label
+         """
+         img, label = ImageFolder.__getitem__(self, index)
+         label = torch.tensor(label)
+         return img, label
+
+
+ class RESISC45Dataset(VisionClassificationDataset):
+     """RESISC45 dataset.
+     The `RESISC45 <http://www.escience.cn/people/JunweiHan/NWPU-RESISC45.html>`_
+     dataset is a dataset for remote sensing image scene classification.
+     Dataset features:
+     * 31,500 images with 0.2-30 m per pixel resolution (256x256 px)
+     * three spectral bands - RGB
+     * 45 scene classes, 700 images per class
+     * images extracted from Google Earth from over 100 countries
+     * image conditions with high variability (resolution, weather, illumination)
+     Dataset format:
+     * images are three-channel jpgs
+     Dataset classes:
+     0. airplane
+     1. airport
+     2. baseball_diamond
+     3. basketball_court
+     4. beach
+     5. bridge
+     6. chaparral
+     7. church
+     8. circular_farmland
+     9. cloud
+     10. commercial_area
+     11. dense_residential
+     12. desert
+     13. forest
+     14. freeway
+     15. golf_course
+     16. ground_track_field
+     17. harbor
+     18. industrial_area
+     19. intersection
+     20. island
+     21. lake
+     22. meadow
+     23. medium_residential
+     24. mobile_home_park
+     25. mountain
+     26. overpass
+     27. palace
+     28. parking_lot
+     29. railway
+     30. railway_station
+     31. rectangular_farmland
+     32. river
+     33. roundabout
+     34. runway
+     35. sea_ice
+     36. ship
+     37. snowberg
+     38. sparse_residential
+     39. stadium
+     40. storage_tank
+     41. tennis_court
+     42. terrace
+     43. thermal_power_station
+     44. wetland
+     This dataset uses the train/val/test splits defined in the "In-domain representation
+     learning for remote sensing" paper:
+     * https://arxiv.org/abs/1911.06721
+     If you use this dataset in your research, please cite the following paper:
+     * https://doi.org/10.1109/jproc.2017.2675998
+     """
+
+     # url = "https://drive.google.com/file/d/1DnPSU5nVSN7xv95bpZ3XQ0JhKXZOKgIv"
+     # md5 = "d824acb73957502b00efd559fc6cfbbb"
+     # filename = "NWPU-RESISC45.rar"
+     directory = "resisc45/NWPU-RESISC45"
+
+     splits = ["train", "val", "test"]
+     split_urls = {
+         "train": "https://storage.googleapis.com/remote_sensing_representations/resisc45-train.txt",  # noqa: E501
+         "val": "https://storage.googleapis.com/remote_sensing_representations/resisc45-val.txt",  # noqa: E501
+         "test": "https://storage.googleapis.com/remote_sensing_representations/resisc45-test.txt",  # noqa: E501
+     }
+     split_md5s = {
+         "train": "b5a4c05a37de15e4ca886696a85c403e",
+         "val": "a0770cee4c5ca20b8c32bbd61e114805",
+         "test": "3dda9e4988b47eb1de9f07993653eb08",
+     }
+     classes = [
+         "airplane",
+         "airport",
+         "baseball_diamond",
+         "basketball_court",
+         "beach",
+         "bridge",
+         "chaparral",
+         "church",
+         "circular_farmland",
+         "cloud",
+         "commercial_area",
+         "dense_residential",
+         "desert",
+         "forest",
+         "freeway",
+         "golf_course",
+         "ground_track_field",
+         "harbor",
+         "industrial_area",
+         "intersection",
+         "island",
+         "lake",
+         "meadow",
+         "medium_residential",
+         "mobile_home_park",
+         "mountain",
+         "overpass",
+         "palace",
+         "parking_lot",
+         "railway",
+         "railway_station",
+         "rectangular_farmland",
+         "river",
+         "roundabout",
+         "runway",
+         "sea_ice",
+         "ship",
+         "snowberg",
+         "sparse_residential",
+         "stadium",
+         "storage_tank",
+         "tennis_court",
+         "terrace",
+         "thermal_power_station",
+         "wetland",
+     ]
+
+     def __init__(
+         self,
+         root: str = "data",
+         split: str = "train",
+         transforms: Optional[Callable[[Dict[str, Tensor]], Dict[str, Tensor]]] = None,
+     ) -> None:
+         """Initialize a new RESISC45 dataset instance.
+         Args:
+             root: root directory where dataset can be found
+             split: one of "train", "val", or "test"
+             transforms: a function/transform that takes input sample and its target as
+                 entry and returns a transformed version
+         """
+         assert split in self.splits
+         self.root = root
+
+         valid_fns = set()
+         with open(os.path.join(self.root, "resisc45", f"resisc45-{split}.txt")) as f:
+             for fn in f:
+                 valid_fns.add(fn.strip())
+         is_in_split: Callable[[str], bool] = lambda x: os.path.basename(x) in valid_fns
+
+         super().__init__(
+             root=os.path.join(root, self.directory),
+             transforms=transforms,
+             is_valid_file=is_in_split,
+         )
+
+
+ class RESISC45:
+     def __init__(self,
+                  preprocess,
+                  location=os.path.expanduser('~/data'),
+                  batch_size=32,
+                  num_workers=16):
+
+         self.train_dataset = RESISC45Dataset(root=location, split='train', transforms=preprocess)
+         self.train_loader = torch.utils.data.DataLoader(
+             self.train_dataset,
+             shuffle=True,
+             batch_size=batch_size,
+             num_workers=num_workers,
+         )
+
+         self.test_dataset = RESISC45Dataset(root=location, split='test', transforms=preprocess)
+         self.test_loader = torch.utils.data.DataLoader(
+             self.test_dataset,
+             batch_size=batch_size,
+             num_workers=num_workers
+         )
+
+         # class names have _ so split on this for better zero-shot head
+         self.classnames = [' '.join(c.split('_')) for c in RESISC45Dataset.classes]
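
The last line of the RESISC45 wrapper replaces underscores with spaces so the class names read as natural phrases in a zero-shot text head's prompts. A minimal standalone illustration of that normalization (the three-element `classes` sample below is just a subset of the 45 for demonstration):

```python
# Underscore-separated folder names -> natural-language class names,
# as done for RESISC45's zero-shot classification head.
classes = ["baseball_diamond", "thermal_power_station", "airplane"]
classnames = [' '.join(c.split('_')) for c in classes]
print(classnames)  # ['baseball diamond', 'thermal power station', 'airplane']
```

Names without underscores (like "airplane") pass through unchanged, since `split('_')` returns the whole string as a single element.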