diff --git a/.gitattributes b/.gitattributes
index a6344aac8c09253b3b630fb776ae94478aa0275b..9c87bd622a152a7102d40e5127a6c332bf8a36f4 100644
--- a/.gitattributes
+++ b/.gitattributes
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
*.zip filter=lfs diff=lfs merge=lfs -text
*.zst filter=lfs diff=lfs merge=lfs -text
*tfevents* filter=lfs diff=lfs merge=lfs -text
+assets/WVO-WD-TFS.png filter=lfs diff=lfs merge=lfs -text
+assets/orthoreg_loss.png filter=lfs diff=lfs merge=lfs -text
diff --git a/README.md b/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..603865857d29639fc3ed3299c2e1d2221e8e2bcb
--- /dev/null
+++ b/README.md
@@ -0,0 +1,207 @@
+# Understanding and Enforcing Weight Disentanglement in Task Arithmetic
+
+[CVPR 2026] Official code of the paper **"Understanding and Enforcing Weight Disentanglement in Task Arithmetic"**.
+
+[[Paper](https://arxiv.org/abs/2604.17078)] [[Checkpoints](#-checkpoints)] [[Datasets](#-datasets)]
+
+---
+
+## 🎯 Abstract
+
+Task arithmetic provides an efficient, training-free way to edit pre-trained models, yet it lacks a fundamental theoretical explanation for its success. The existing concept of "weight disentanglement" describes the ideal outcome of non-interfering task composition but does not reveal its underlying cause. Crucially, which intrinsic properties of the pre-trained model ($\theta_0$) or the task vectors ($\tau_t$) enable this disentanglement remains underexplored. In this paper, we introduce Task-Feature Specialization (TFS), a model's ability to allocate distinct internal features to different tasks, as the fundamental principle. We first prove that TFS is a sufficient condition for weight disentanglement. More importantly, we find that TFS also gives rise to an observable geometric consequence: weight vector orthogonality. This positions TFS as the common cause of both the desired functional outcome (disentanglement) and a measurable geometric property (orthogonality). This relationship provides the key insight for our method: since the abstract TFS property is intractable to enforce directly, we can instead promote weight disentanglement by shaping its concrete geometric consequence, orthogonality. We therefore propose OrthoReg, a simple and effective regularization method that actively enforces an internal orthogonal structure on the weight updates ($\Delta W$) that constitute $\tau_t$ during fine-tuning, and we theoretically prove that OrthoReg promotes disentanglement. Extensive experiments demonstrate that OrthoReg consistently and significantly enhances the performance of various task arithmetic methods.
+
+
+![WVO-WD-TFS](assets/WVO-WD-TFS.png)
+
+*TFS is the common cause connecting Weight Vector Orthogonality (WVO) with Weight Disentanglement (WD).*
+
+### ✨ Key Contributions
+
+- 📐 **Theory**: We identify TFS as a sufficient condition for weight disentanglement, and WVO as its geometric consequence, providing the first principled explanation for task arithmetic.
+- 🔧 **Method (OrthoReg)**: A simple regularization term added to the fine-tuning loss that enforces column-wise orthogonality on ΔW, for which we prove theoretical efficacy.
+- 🔗 **Connection to TTA**: We show that OrthoReg and Tangent Task Arithmetic (TTA) share the same underlying mechanism (i.e., inter-task vector orthogonality), but OrthoReg achieves this more efficiently.
+- 📊 **Experiments**: Consistent and significant improvements over Non-linear FT, TTA, ATT-FT, and LoRA-ATT across ViT-B-32, ViT-B-16, and ViT-L-14.
+
+---
+
+### The OrthoReg Loss
+
+![OrthoReg loss](assets/orthoreg_loss.png)
+
+The total loss adds a regularization term to the standard task objective:
+
+$$\mathcal{L} = \mathcal{L}_{\text{task}}(\theta_0 + \Delta\theta) + \lambda \cdot \mathcal{L}_{\text{ortho}}(\Delta\theta)$$
+
+$$\mathcal{L}_{\text{ortho}}(\Delta\theta) = \sum_l \left\|(\Delta W^{(l)})^\top \Delta W^{(l)} - I\right\|_F^2$$
+
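The regularizer can be read off directly from the equations above: per layer, it is the squared Frobenius distance between the Gram matrix of the update and the identity. A minimal PyTorch sketch (function names here are illustrative, not the repository's API):

```python
import torch

def orthoreg_loss(deltas):
    """Sum over layers of ||dW^T dW - I||_F^2 for a list of 2-D updates dW."""
    loss = torch.zeros(())
    for dW in deltas:
        gram = dW.T @ dW                                  # (cols, cols) Gram matrix
        eye = torch.eye(gram.shape[0], device=dW.device)
        loss = loss + torch.linalg.norm(gram - eye, ord="fro") ** 2
    return loss

def total_loss(task_loss, deltas, lam):
    """L = L_task + lambda * L_ortho, matching the first equation above."""
    return task_loss + lam * orthoreg_loss(deltas)
```

An update with orthonormal columns incurs zero penalty. Note that the in-repo implementation (see `src/attention_only_finetune.py`) differs slightly: for wide matrices it forms the smaller Gram matrix ΔW ΔWᵀ, and it accumulates the unsquared Frobenius norm.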
+---
+
+## 🛠️ Installation
+
+This codebase is built on top of [Tangent Task Arithmetic (TTA)](https://github.com/gortizji/tangent_task_arithmetic). Environment setup follows theirs exactly.
+
+
+To run the code, please install all its dependencies:
+```sh
+conda env create
+conda activate tangent-arithmetic
+```
+and add the `src` directory to the `PYTHONPATH`:
+```sh
+cd OrthoReg
+export PYTHONPATH="$PYTHONPATH:$PWD"
+```
+
+---
+
+## 📦 Datasets
+
+We evaluate on 8 image classification benchmarks following [Task Arithmetic](https://github.com/mlfoundations/task_vectors) and [TTA](https://github.com/gortizji/tangent_task_arithmetic):
+
+**Cars ยท DTD ยท EuroSAT ยท GTSRB ยท MNIST ยท RESISC45 ยท SUN397 ยท SVHN**
+
+For dataset download and preparation, please follow the instructions in the [TTA repository](https://github.com/gortizji/tangent_task_arithmetic#datasets).
+
+We also provide a pre-packaged dataset archive for convenience:
+
+> 📥 **Dataset Download:** `https://pan.baidu.com/s/1PgLyjUrAhsmgSAz4ms5mcQ?pwd=fwf5`
+
+Set the root path via `--data-location /path/to/datasets/`.
+
+---
+
+## 🚀 Quick Start
+
+All scripts are run from the `OrthoReg/` directory. This repository implements **6 finetuning modes**:
+
+| `--finetuning-mode` | Description |
+|---|---|
+| `standard` | Non-linear full fine-tuning (baseline) |
+| `standard_ortho` | Non-linear FT + OrthoReg |
+| `linear` | TTA: tangent-space fine-tuning (baseline) |
+| `linear_ortho` | TTA + OrthoReg |
+| `linear-2` | ATT-FT: attention-only fine-tuning (baseline) |
+| `linear-2_ortho` | ATT-FT + OrthoReg |
+
+> **Note on LoRA-ATT:** The LoRA-ATT and LoRA-ATT+OrthoReg results from the paper are implemented in a separate repository due to the complexity of patching OpenCLIP's fused QKV projection. Code will be released at: `https://github.com/lshangge/OrthoReg_lora`
+
+### Step 1: Fine-tune
+
+```bash
+python src/finetune.py \
+ --model ViT-B-32 \
+ --finetuning-mode standard_ortho \
+ --ortho-lambda 10 \
+ --lr 1e-5 \
+    --data-location /path/to/datasets/
+```
+
+Switch between all six modes by changing `--finetuning-mode` and `--ortho-lambda`:
+
+```bash
+--finetuning-mode standard --ortho-lambda 0 # Non-linear FT
+--finetuning-mode standard_ortho --ortho-lambda xx # Non-linear FT + OrthoReg
+--finetuning-mode linear --ortho-lambda 0 # TTA
+--finetuning-mode linear_ortho --ortho-lambda xx # TTA + OrthoReg
+--finetuning-mode linear-2 --ortho-lambda 0 # ATT-FT
+--finetuning-mode linear-2_ortho --ortho-lambda xx # ATT-FT + OrthoReg
+```
+
+Checkpoints are saved to:
+- `checkpoints_{seed}/{mode}_{lr}_{model}/` for baselines
+- `checkpoints_{seed}/{mode}_{lr}_lambda{lambda}_{model}/` for OrthoReg variants
+
+### Step 2: Evaluate Single-Task Accuracy
+
+```bash
+python src/eval_single_task.py \
+ --model ViT-B-32 \
+ --finetuning-mode standard_ortho \
+ --ortho-lambda 10 \
+ --lr 1e-5 \
+ --data-location /path/to/datasets/
+```
+
+> Run `eval_single_task` with `--finetuning-mode none --ortho-lambda 0` first to generate `zeroshot_accuracies.json`, which is required as the reference for normalized accuracy in Steps 3โ4.
+
+### Step 3: Evaluate Task Addition
+
+```bash
+python src/eval_task_addition.py \
+ --model ViT-B-32 \
+ --finetuning-mode standard_ortho \
+ --ortho-lambda 10 \
+ --lr 1e-5 \
+ --data-location /path/to/datasets/
+```
+
+### Step 4: Evaluate Task Negation
+
+```bash
+python src/eval_task_negation.py \
+ --model ViT-B-32 \
+ --finetuning-mode standard_ortho \
+ --ortho-lambda 10 \
+ --lr 1e-5 \
+ --data-location /path/to/datasets/
+```
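Steps 3 and 4 both rest on the same task-vector arithmetic: with $\tau_t = \theta_t - \theta_0$, task addition evaluates $\theta_0 + \alpha \sum_t \tau_t$ and task negation evaluates $\theta_0 - \alpha \tau_t$. A schematic sketch on plain per-parameter dicts (helper names are illustrative; the actual scripts operate on model state dicts and, unless `--alpha` is given, tune $\alpha$ on validation data):

```python
def task_vector(theta_ft, theta_0):
    """tau = theta_ft - theta_0, computed parameter-wise."""
    return {k: theta_ft[k] - theta_0[k] for k in theta_0}

def apply_task_vectors(theta_0, taus, alpha):
    """theta_0 + alpha * sum_t tau_t; a negative alpha performs negation."""
    edited = dict(theta_0)
    for tau in taus:
        for k, v in tau.items():
            edited[k] = edited[k] + alpha * v
    return edited
```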
+
+---
+
+## 🔧 Key Arguments
+
+| Argument | Default | Description |
+|---|:---:|---|
+| `--model` | `ViT-B-32` | CLIP model architecture |
+| `--finetuning-mode` | (required) | One of the six modes above |
+| `--ortho-lambda` | `0.0` | OrthoReg strength λ; set to `0` for baselines |
+| `--lr` | `1e-5` | Learning rate |
+| `--seed` | `1993` | Random seed |
+| `--world-size` | `1` | Number of GPUs (DDP) |
+| `--data-location` | (required) | Dataset root directory |
+| `--batch-size` | `128` | Batch size per GPU |
+
+---
+
+## 📂 Checkpoints
+
+We release fine-tuned checkpoints for ViT-B-32, ViT-B-16, and ViT-L-14 on all 8 tasks, covering all 6 modes.
+
+> 📥 **Checkpoint Download:** `https://huggingface.co/gezi2333/OrthoReg_checkpoints`
+
+Unzip into `OrthoReg/checkpoints_{seed}/` and pass the corresponding `--seed`, `--lr`, and `--ortho-lambda` to the eval scripts to reproduce the paper's results directly.
+
+---
+
+## 📝 Citation
+
+If you find this work useful, please cite:
+
+```bibtex
+@inproceedings{liu2026orthoreg,
+ title = {Understanding and Enforcing Weight Disentanglement in Task Arithmetic},
+ author = {Liu, Shangge and Yin, Yuehan and Wang, Lei and Fan, Qi and
+ Shi, Yinghuan and Li, Wenbin and Gao, Yang and Tao, Dacheng},
+ booktitle = {CVPR},
+ year = {2026}
+}
+```
+
+---
+
+## 📞 Contact
+
+For questions or issues, please:
+
+- Open an issue on GitHub
+- Contact the authors at lshangge@smail.nju.edu.cn
+
+---
+
+## 💬 Acknowledgements
+
+This codebase is built on top of [Task Arithmetic](https://github.com/mlfoundations/task_vectors), [Tangent Task Arithmetic](https://github.com/gortizji/tangent_task_arithmetic), and [Attention-Only Fine-tuning](https://github.com/kyrie-23/linear_task_arithmetic). We thank the authors for releasing their code.
+
diff --git a/assets/WVO-WD-TFS.png b/assets/WVO-WD-TFS.png
new file mode 100644
index 0000000000000000000000000000000000000000..8e1296166a218be6e3e652c466f2ea56da1ab3ce
--- /dev/null
+++ b/assets/WVO-WD-TFS.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:bc8a9efc76ecb495a5de03a98215606a8cbab5b38cdbb53ea5d2c2ed133e535a
+size 149842
diff --git a/assets/orthoreg_loss.png b/assets/orthoreg_loss.png
new file mode 100644
index 0000000000000000000000000000000000000000..c1370b3fe742341b87d034583f312ae5ac4327f7
--- /dev/null
+++ b/assets/orthoreg_loss.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:d7974c4003a412cc9a5134f26b71cc58739f5cd1f48301ab0d3d428cb17ecb8e
+size 157982
diff --git a/environment.yml b/environment.yml
new file mode 100644
index 0000000000000000000000000000000000000000..c3fa11f1d080d690ddf12ad5a58ff3c507e19b31
--- /dev/null
+++ b/environment.yml
@@ -0,0 +1,140 @@
+name: tangent-arithmetic
+channels:
+ - pytorch
+ - nvidia
+ - defaults
+dependencies:
+ - _libgcc_mutex=0.1
+ - _openmp_mutex=5.1
+ - blas=1.0
+ - brotlipy=0.7.0
+ - bzip2=1.0.8
+ - ca-certificates=2023.05.30
+ - certifi=2023.5.7
+ - cffi=1.15.1
+ - charset-normalizer=2.0.4
+ - cryptography=39.0.1
+ - cuda=11.6.1
+ - cuda-cccl=11.6.55
+ - cuda-command-line-tools=11.6.2
+ - cuda-compiler=11.6.2
+ - cuda-cudart=11.6.55
+ - cuda-cudart-dev=11.6.55
+ - cuda-cuobjdump=11.6.124
+ - cuda-cupti=11.6.124
+ - cuda-cuxxfilt=11.6.124
+ - cuda-driver-dev=11.6.55
+ - cuda-gdb=12.1.105
+ - cuda-libraries=11.6.1
+ - cuda-libraries-dev=11.6.1
+ - cuda-memcheck=11.8.86
+ - cuda-nsight=12.1.105
+ - cuda-nsight-compute=12.1.1
+ - cuda-nvcc=11.6.124
+ - cuda-nvdisasm=12.1.105
+ - cuda-nvml-dev=11.6.55
+ - cuda-nvprof=12.1.105
+ - cuda-nvprune=11.6.124
+ - cuda-nvrtc=11.6.124
+ - cuda-nvrtc-dev=11.6.124
+ - cuda-nvtx=11.6.124
+ - cuda-nvvp=12.1.105
+ - cuda-runtime=11.6.1
+ - cuda-samples=11.6.101
+ - cuda-sanitizer-api=12.1.105
+ - cuda-toolkit=11.6.1
+ - cuda-tools=11.6.1
+ - cuda-visual-tools=11.6.1
+ - ffmpeg=4.3
+ - freetype=2.12.1
+ - gds-tools=1.6.1.9
+ - giflib=5.2.1
+ - gmp=6.2.1
+ - gnutls=3.6.15
+ - idna=3.4
+ - intel-openmp=2023.1.0
+ - jpeg=9e
+ - lame=3.100
+ - lcms2=2.12
+ - ld_impl_linux-64=2.38
+ - lerc=3.0
+ - libcublas=11.9.2.110
+ - libcublas-dev=11.9.2.110
+ - libcufft=10.7.1.112
+ - libcufft-dev=10.7.1.112
+ - libcufile=1.6.1.9
+ - libcufile-dev=1.6.1.9
+ - libcurand=10.3.2.106
+ - libcurand-dev=10.3.2.106
+ - libcusolver=11.3.4.124
+ - libcusparse=11.7.2.124
+ - libcusparse-dev=11.7.2.124
+ - libdeflate=1.17
+ - libffi=3.4.4
+ - libgcc-ng=11.2.0
+ - libgomp=11.2.0
+ - libiconv=1.16
+ - libidn2=2.3.4
+ - libnpp=11.6.3.124
+ - libnpp-dev=11.6.3.124
+ - libnvjpeg=11.6.2.124
+ - libnvjpeg-dev=11.6.2.124
+ - libpng=1.6.39
+ - libstdcxx-ng=11.2.0
+ - libtasn1=4.19.0
+ - libtiff=4.5.0
+ - libunistring=0.9.10
+ - libuuid=1.41.5
+ - libwebp=1.2.4
+ - libwebp-base=1.2.4
+ - lz4-c=1.9.4
+ - mkl=2023.1.0
+ - mkl-service=2.4.0
+ - mkl_fft=1.3.6
+ - mkl_random=1.2.2
+ - ncurses=6.4
+ - nettle=3.7.3
+ - nsight-compute=2023.1.1.4
+ - numpy=1.24.3
+ - numpy-base=1.24.3
+ - openh264=2.1.1
+ - openssl=1.1.1t
+ - pillow=9.4.0
+ - pip=23.0.1
+ - pycparser=2.21
+ - pyopenssl=23.0.0
+ - pysocks=1.7.1
+ - python=3.10.11
+ - pytorch=1.13.1
+ - pytorch-cuda=11.6
+ - pytorch-mutex=1.0
+ - readline=8.2
+ - requests=2.29.0
+ - setuptools=67.8.0
+ - sqlite=3.41.2
+ - tbb=2021.8.0
+ - tk=8.6.12
+ - torchaudio=0.13.1
+ - torchvision=0.14.1
+ - typing_extensions=4.5.0
+ - tzdata=2023c
+ - urllib3=1.26.16
+ - wheel=0.38.4
+ - xz=5.4.2
+ - zlib=1.2.13
+ - zstd=1.5.5
+ - pip:
+ - filelock==3.12.0
+ - fsspec==2023.5.0
+ - ftfy==6.1.1
+ - huggingface-hub==0.15.1
+ - open-clip-torch==2.10.1
+ - packaging==23.1
+ - protobuf==3.20.3
+ - pyyaml==6.0
+ - regex==2023.6.3
+ - safetensors==0.3.1
+ - scipy==1.10.1
+ - sentencepiece==0.1.99
+ - timm==0.9.2
+ - wcwidth==0.2.6
diff --git a/src/__init__.py b/src/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/src/__pycache__/__init__.cpython-310.pyc b/src/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..8a7cc90a8c713cef874a8e3c40b8d24b5f361d3e
Binary files /dev/null and b/src/__pycache__/__init__.cpython-310.pyc differ
diff --git a/src/__pycache__/args.cpython-310.pyc b/src/__pycache__/args.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..c5f3a4d5e74a11757f90d656cb4bbec449a7dc94
Binary files /dev/null and b/src/__pycache__/args.cpython-310.pyc differ
diff --git a/src/__pycache__/attention_only_finetune.cpython-310.pyc b/src/__pycache__/attention_only_finetune.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..8ea5e49c92e2c2dd9ebb0791e73cb383d95a5c21
Binary files /dev/null and b/src/__pycache__/attention_only_finetune.cpython-310.pyc differ
diff --git a/src/__pycache__/distributed.cpython-310.pyc b/src/__pycache__/distributed.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..340b9dedddebb30967cf4e654dd64c089ca69d39
Binary files /dev/null and b/src/__pycache__/distributed.cpython-310.pyc differ
diff --git a/src/__pycache__/eval.cpython-310.pyc b/src/__pycache__/eval.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..11c84b135a6e0563a6713b2fb974809f520a75a0
Binary files /dev/null and b/src/__pycache__/eval.cpython-310.pyc differ
diff --git a/src/__pycache__/heads.cpython-310.pyc b/src/__pycache__/heads.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..5560cfb464a58f4e14cee196889d273e886ac012
Binary files /dev/null and b/src/__pycache__/heads.cpython-310.pyc differ
diff --git a/src/__pycache__/linearize.cpython-310.pyc b/src/__pycache__/linearize.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..0333faf468cb7e4bdc4601287a562bd070f187e3
Binary files /dev/null and b/src/__pycache__/linearize.cpython-310.pyc differ
diff --git a/src/__pycache__/modeling.cpython-310.pyc b/src/__pycache__/modeling.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b61a266a18103c25d4bf3023dc2bb7def057bd5c
Binary files /dev/null and b/src/__pycache__/modeling.cpython-310.pyc differ
diff --git a/src/__pycache__/task_vectors.cpython-310.pyc b/src/__pycache__/task_vectors.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b3bb60527b24f61861488ebe52b59d46d856b4cc
Binary files /dev/null and b/src/__pycache__/task_vectors.cpython-310.pyc differ
diff --git a/src/__pycache__/utils.cpython-310.pyc b/src/__pycache__/utils.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..257c32d1052c6b28a85df1a438f5c1f8b1e6ffb4
Binary files /dev/null and b/src/__pycache__/utils.cpython-310.pyc differ
diff --git a/src/args.py b/src/args.py
new file mode 100644
index 0000000000000000000000000000000000000000..47e11d8c846276908874a5c8a1540d94a9dc2d74
--- /dev/null
+++ b/src/args.py
@@ -0,0 +1,153 @@
+import argparse
+import os
+
+import torch
+
+
+def parse_arguments():
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ "--data-location",
+ type=str,
+ default=os.path.expanduser("/path/datasets/"),
+ help="The root directory for the datasets.",
+ )
+ parser.add_argument(
+ "--eval-datasets",
+ default=None,
+ type=lambda x: x.split(","),
+ help="Which datasets to use for evaluation. Split by comma, e.g. MNIST,EuroSAT. ",
+ )
+ parser.add_argument(
+ "--train-dataset",
+ default=None,
+ type=lambda x: x.split(","),
+ help="Which dataset(s) to patch on.",
+ )
+ parser.add_argument(
+ "--exp_name",
+ type=str,
+ default=None,
+ help="Name of the experiment, for organization purposes only.",
+ )
+ parser.add_argument(
+ "--results-db",
+ type=str,
+ default=None,
+ help="Where to store the results, else does not store",
+ )
+ parser.add_argument(
+ "--model",
+ type=str,
+ default="ViT-B-32",
+ help="The type of model (e.g. RN50, ViT-B-32).",
+ )
+ parser.add_argument(
+ "--batch-size",
+ type=int,
+ default=128,
+ )
+ parser.add_argument(
+ "--num-grad-accumulation",
+ type=int,
+ default=1,
+ help="Number of gradient accumulation steps.",
+ )
+ parser.add_argument("--lr", type=float, default=0.001, help="Learning rate.")
+ parser.add_argument("--wd", type=float, default=0.1, help="Weight decay")
+ parser.add_argument("--ls", type=float, default=0.0, help="Label smoothing.")
+ parser.add_argument(
+ "--warmup_length",
+ type=int,
+ default=500,
+ )
+ parser.add_argument(
+ "--epochs",
+ type=int,
+ default=10,
+ )
+ parser.add_argument(
+ "--load",
+ type=lambda x: x.split(","),
+ default=None,
+ help="Optionally load _classifiers_, e.g. a zero shot classifier or probe or ensemble both.",
+ )
+ parser.add_argument(
+ "--save",
+ type=str,
+ default=None,
+ help="Optionally save a _classifier_, e.g. a zero shot classifier or probe.",
+ )
+ parser.add_argument(
+ "--cache-dir",
+ type=str,
+ default=None,
+ help="Directory for caching features and encoder",
+ )
+ parser.add_argument(
+ "--openclip-cachedir",
+ type=str,
+ default=os.path.expanduser("~/openclip-cachedir/open_clip"),
+ help="Directory for caching models from OpenCLIP",
+ )
+ parser.add_argument(
+ "--world-size",
+ type=int,
+ default=1,
+ help="Number of processes for distributed training.",
+ )
+ parser.add_argument(
+ "--checkpoint-every",
+ type=int,
+ default=-1,
+ help="How often to checkpoint the model.",
+ )
+ parser.add_argument(
+ "--port",
+ type=int,
+ default=12355,
+ help="Port for distributed training.",
+ )
+ parser.add_argument(
+ "--seed",
+ type=int,
+ default=1993,
+ help="Random seed.",
+ )
+    parser.add_argument(
+        "--finetuning-mode",
+        choices=["none", "standard", "standard_ortho", "linear", "linear_ortho", "linear-2", "linear-2_ortho"],
+        help="Finetuning mode: none (zero-shot evaluation) or standard/linear/linear-2, each with an optional ortho-regularized variant.",
+    )
+ parser.add_argument(
+ "--n-eval-points",
+ type=int,
+ default=21,
+ help="Number of evaluation points used to find optimal coefficient in task arithmetic.",
+ )
+ parser.add_argument(
+ "--ortho-lambda",
+ type=float,
+ default=0.0,
+ help="Weight of the orthogonality regularization term. Default 0.0 means no regularization.",
+ )
+ parser.add_argument(
+ "--control_threshold",
+ type=float,
+ default=0.95,
+ help="Control dataset performance degradation threshold.",
+ )
+ parser.add_argument(
+ "--alpha",
+ type=float,
+ default=None,
+ help="Manually specify the scaling coefficient for task vectors. If None, it will be optimized on validation set.",
+ )
+
+ parsed_args = parser.parse_args()
+ parsed_args.device = "cuda" if torch.cuda.is_available() else "cpu"
+
+ if parsed_args.load is not None and len(parsed_args.load) == 1:
+ parsed_args.load = parsed_args.load[0]
+
+ return parsed_args
diff --git a/src/attention_only_finetune.py b/src/attention_only_finetune.py
new file mode 100644
index 0000000000000000000000000000000000000000..3a77b965fddf2370fdb56b4fbd4e6f182b4bf2b1
--- /dev/null
+++ b/src/attention_only_finetune.py
@@ -0,0 +1,116 @@
+import os
+import torch
+import torch.nn as nn
+from src.modeling import ImageEncoder
+from src.utils import DotDict
+
+class AttentionOnlyFinetuneEncoder(ImageEncoder):
+ """
+ A specialized ImageEncoder that fine-tunes only the attention module weights in the ViT.
+ Corresponds to the method described in Jin et al. (2025).
+ """
+ def __init__(self, args, keep_lang=False):
+ # 1. Call the parent constructor to build the full model as usual
+ super().__init__(args, keep_lang=keep_lang)
+
+ self.args = args
+
+ # 2. Freeze all model parameters
+ # print("Freezing all parameters of the model initially...")
+ for param in self.model.parameters():
+ param.requires_grad = False
+
+ # 3. Unfreeze only the Attention module weights (Wq, Wk, Wv, Wo)
+ # print("Unfreezing Attention module weights for fine-tuning...")
+ self._unfreeze_attention_weights(self.model.visual)
+
+ # 4. (Optional but recommended) Print trainable parameters for verification
+ # self._verify_trainable_params()
+
+ def _unfreeze_attention_weights(self, vit_model):
+ """
+ Iterate over all Transformer blocks and unfreeze the attention projection weights.
+ """
+ # Iterate over the model and unfreeze target parameters
+ for block in vit_model.transformer.resblocks:
+ # Unfreeze the combined input projection weight for Q, K, V
+ block.attn.in_proj_weight.requires_grad = True
+
+ # Unfreeze the output projection weight
+ block.attn.out_proj.weight.requires_grad = True
+
+ # Per the paper's ablation study, not fine-tuning biases yields better results; keep them frozen
+ # block.attn.in_proj_bias.requires_grad = True
+ # block.attn.out_proj.bias.requires_grad = True
+
+ def _verify_trainable_params(self):
+ """Print all trainable parameters for debugging and verification."""
+ print("="*80)
+ print("Trainable parameters in AttentionOnlyFinetuneEncoder:")
+ trainable_params_count = 0
+ for name, param in self.model.named_parameters():
+ if param.requires_grad:
+ print(f" - {name}")
+ trainable_params_count += param.numel()
+ print(f"Total trainable parameters: {trainable_params_count / 1e6:.2f}M")
+ print("="*80)
+
+ def forward(self, images, calculate_ortho_loss=False, pretrained_state_dict=None):
+ """
+ Extended forward method to optionally compute and return the orthogonal loss.
+ Consistent with the logic implemented for standard_ortho.
+ """
+ # Original forward pass
+ features = self.model.encode_image(images)
+
+ # Return features directly if orthogonal loss is not needed
+ if not calculate_ortho_loss:
+ return features
+
+ # --- Compute orthogonal loss if requested ---
+ if pretrained_state_dict is None:
+ raise ValueError("pretrained_state_dict must be provided when calculate_ortho_loss is True")
+
+ ortho_loss = 0.0
+ # self.model is the open_clip model (e.g. ViT); iterate over its parameters
+ for name, p_finetuned in self.model.named_parameters():
+            # Penalize only trainable 2-D weight matrices (skip frozen params and biases)
+ if p_finetuned.requires_grad and p_finetuned.dim() == 2:
+ if name in pretrained_state_dict:
+ p_pretrained = pretrained_state_dict[name].to(p_finetuned.device)
+
+ delta_W = p_finetuned - p_pretrained
+
+                    # Frobenius distance of the (smaller) Gram matrix of delta_W from the identity
+ rows, cols = delta_W.shape
+ if rows < cols:
+ mat = delta_W @ delta_W.T
+ identity = torch.eye(rows, device=delta_W.device)
+ else:
+ mat = delta_W.T @ delta_W
+ identity = torch.eye(cols, device=delta_W.device)
+
+ ortho_loss += torch.norm(mat - identity, p='fro')
+
+ return features, ortho_loss
+
+ def __call__(self, inputs, calculate_ortho_loss=False, pretrained_state_dict=None):
+ # Ensure __call__ forwards all arguments
+ return self.forward(inputs, calculate_ortho_loss, pretrained_state_dict)
+
+ def save(self, filename):
+ """Save model weights."""
+        # print(f"Saving AttentionOnlyFinetuneEncoder state_dict to {filename}")
+ if os.path.dirname(filename):
+ os.makedirs(os.path.dirname(filename), exist_ok=True)
+ # Save only the state_dict; reconstruct the model on load
+ torch.save(self.state_dict(), filename)
+
+ @classmethod
+ def load(cls, filename, args):
+ """Load model from a state_dict."""
+        # print(f"Loading AttentionOnlyFinetuneEncoder from {filename}")
+ encoder = cls(args) # Create a new instance
+ state_dict = torch.load(filename, map_location='cpu')
+ encoder.load_state_dict(state_dict) # Load weights
+ return encoder
\ No newline at end of file
diff --git a/src/datasets/__pycache__/cars.cpython-310.pyc b/src/datasets/__pycache__/cars.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..c9d4405bfc3a5b86e1f34823494ed704d655402e
Binary files /dev/null and b/src/datasets/__pycache__/cars.cpython-310.pyc differ
diff --git a/src/datasets/__pycache__/cifar10.cpython-310.pyc b/src/datasets/__pycache__/cifar10.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..392eb3ee6607d1678c45dc1de315783067186aa9
Binary files /dev/null and b/src/datasets/__pycache__/cifar10.cpython-310.pyc differ
diff --git a/src/datasets/__pycache__/cifar100.cpython-310.pyc b/src/datasets/__pycache__/cifar100.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..cfce4e0671aa1ca3848a5ebd644343c16d4cdeeb
Binary files /dev/null and b/src/datasets/__pycache__/cifar100.cpython-310.pyc differ
diff --git a/src/datasets/__pycache__/common.cpython-310.pyc b/src/datasets/__pycache__/common.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..5283c26161f41f3a07d31f3e96aa755b05e4a4d1
Binary files /dev/null and b/src/datasets/__pycache__/common.cpython-310.pyc differ
diff --git a/src/datasets/__pycache__/dtd.cpython-310.pyc b/src/datasets/__pycache__/dtd.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..75cde45975e35320f1fddf7ddc0f3d400d34e34c
Binary files /dev/null and b/src/datasets/__pycache__/dtd.cpython-310.pyc differ
diff --git a/src/datasets/__pycache__/emnist.cpython-310.pyc b/src/datasets/__pycache__/emnist.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..0ba145d7792e5347fe678092596ca250d00a5d73
Binary files /dev/null and b/src/datasets/__pycache__/emnist.cpython-310.pyc differ
diff --git a/src/datasets/__pycache__/eurosat.cpython-310.pyc b/src/datasets/__pycache__/eurosat.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..7982e0eba87a3be3f0a2ec694fd22857a67a101f
Binary files /dev/null and b/src/datasets/__pycache__/eurosat.cpython-310.pyc differ
diff --git a/src/datasets/__pycache__/gtsrb.cpython-310.pyc b/src/datasets/__pycache__/gtsrb.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..54dd342d1682b0e5349233c103ad1e4aef31578b
Binary files /dev/null and b/src/datasets/__pycache__/gtsrb.cpython-310.pyc differ
diff --git a/src/datasets/__pycache__/imagenet.cpython-310.pyc b/src/datasets/__pycache__/imagenet.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d2c2ecbd705b1dd2672d571b1477f366f517671a
Binary files /dev/null and b/src/datasets/__pycache__/imagenet.cpython-310.pyc differ
diff --git a/src/datasets/__pycache__/kmnist.cpython-310.pyc b/src/datasets/__pycache__/kmnist.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..3f3c07ef170873d91e6d0a7729d0d766ca25e9ce
Binary files /dev/null and b/src/datasets/__pycache__/kmnist.cpython-310.pyc differ
diff --git a/src/datasets/__pycache__/mnist.cpython-310.pyc b/src/datasets/__pycache__/mnist.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..a6c961255829fd3125c65faf3f761217bfdad25f
Binary files /dev/null and b/src/datasets/__pycache__/mnist.cpython-310.pyc differ
diff --git a/src/datasets/__pycache__/oxfordpets.cpython-310.pyc b/src/datasets/__pycache__/oxfordpets.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f95c4a61bbb3e78942c8689fddf973dc6010bf26
Binary files /dev/null and b/src/datasets/__pycache__/oxfordpets.cpython-310.pyc differ
diff --git a/src/datasets/__pycache__/registry.cpython-310.pyc b/src/datasets/__pycache__/registry.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..722d832aa156ec400101753f3b608d785faee229
Binary files /dev/null and b/src/datasets/__pycache__/registry.cpython-310.pyc differ
diff --git a/src/datasets/__pycache__/resisc45.cpython-310.pyc b/src/datasets/__pycache__/resisc45.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..4d115426a3da66263410795803530cf51d66233b
Binary files /dev/null and b/src/datasets/__pycache__/resisc45.cpython-310.pyc differ
diff --git a/src/datasets/__pycache__/stl10.cpython-310.pyc b/src/datasets/__pycache__/stl10.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b4723be7f4219d30244aa269f37c4b9a9953b022
Binary files /dev/null and b/src/datasets/__pycache__/stl10.cpython-310.pyc differ
diff --git a/src/datasets/__pycache__/sun397.cpython-310.pyc b/src/datasets/__pycache__/sun397.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..43bc175deb4d4221c924b91ac1c54c7c238ddd33
Binary files /dev/null and b/src/datasets/__pycache__/sun397.cpython-310.pyc differ
diff --git a/src/datasets/__pycache__/svhn.cpython-310.pyc b/src/datasets/__pycache__/svhn.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..2f277d480777bae9330af34f83874359aa5c8d84
Binary files /dev/null and b/src/datasets/__pycache__/svhn.cpython-310.pyc differ
diff --git a/src/datasets/__pycache__/templates.cpython-310.pyc b/src/datasets/__pycache__/templates.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..2972bcb8cee9391d7f9dcd35ac6c53b62a881077
Binary files /dev/null and b/src/datasets/__pycache__/templates.cpython-310.pyc differ
diff --git a/src/datasets/cars.py b/src/datasets/cars.py
new file mode 100644
index 0000000000000000000000000000000000000000..4546031aa1bde18c22e753da607ee0db555f08f2
--- /dev/null
+++ b/src/datasets/cars.py
@@ -0,0 +1,155 @@
+import os
+import torch
+import torchvision.datasets as datasets
+
+
+import pathlib
+from typing import Callable, Optional, Any, Tuple
+
+from PIL import Image
+
+from torchvision.datasets.utils import download_and_extract_archive, download_url, verify_str_arg
+from torchvision.datasets.vision import VisionDataset
+
+
+class PytorchStanfordCars(VisionDataset):
+    """`Stanford Cars <https://ai.stanford.edu/~jkrause/cars/car_dataset.html>`_ Dataset
+
+ The Cars dataset contains 16,185 images of 196 classes of cars. The data is
+ split into 8,144 training images and 8,041 testing images, where each class
+ has been split roughly in a 50-50 split
+
+ .. note::
+
+        This class needs `scipy <https://scipy.org/>`_ to load target files from `.mat` format.
+
+ Args:
+ root (string): Root directory of dataset
+ split (string, optional): The dataset split, supports ``"train"`` (default) or ``"test"``.
+ transform (callable, optional): A function/transform that takes in an PIL image
+ and returns a transformed version. E.g, ``transforms.RandomCrop``
+ target_transform (callable, optional): A function/transform that takes in the
+ target and transforms it.
+ download (bool, optional): If True, downloads the dataset from the internet and
+ puts it in root directory. If dataset is already downloaded, it is not
+ downloaded again."""
+
+ def __init__(
+ self,
+ root: str,
+ split: str = "train",
+ transform: Optional[Callable] = None,
+ target_transform: Optional[Callable] = None,
+ download: bool = False,
+ ) -> None:
+
+ try:
+ import scipy.io as sio
+ except ImportError:
+ raise RuntimeError("Scipy is not found. This dataset needs to have scipy installed: pip install scipy")
+
+ super().__init__(root, transform=transform, target_transform=target_transform)
+
+ self._split = verify_str_arg(split, "split", ("train", "test"))
+ self._base_folder = pathlib.Path(root) / "stanford_cars"
+ devkit = self._base_folder / "devkit"
+
+ if self._split == "train":
+ self._annotations_mat_path = devkit / "cars_train_annos.mat"
+ self._images_base_path = self._base_folder / "cars_train"
+ else:
+ self._annotations_mat_path = self._base_folder / "cars_test_annos_withlabels.mat"
+ self._images_base_path = self._base_folder / "cars_test"
+
+ if download:
+ self.download()
+
+ if not self._check_exists():
+ raise RuntimeError("Dataset not found. You can use download=True to download it")
+
+ self._samples = [
+ (
+ str(self._images_base_path / annotation["fname"]),
+ annotation["class"] - 1, # Original target mapping starts from 1, hence -1
+ )
+ for annotation in sio.loadmat(self._annotations_mat_path, squeeze_me=True)["annotations"]
+ ]
+
+ self.classes = sio.loadmat(str(devkit / "cars_meta.mat"), squeeze_me=True)["class_names"].tolist()
+ self.class_to_idx = {cls: i for i, cls in enumerate(self.classes)}
+
+ def __len__(self) -> int:
+ return len(self._samples)
+
+ def __getitem__(self, idx: int) -> Tuple[Any, Any]:
+ """Returns pil_image and class_id for given index"""
+ image_path, target = self._samples[idx]
+ pil_image = Image.open(image_path).convert("RGB")
+
+ if self.transform is not None:
+ pil_image = self.transform(pil_image)
+ if self.target_transform is not None:
+ target = self.target_transform(target)
+ return pil_image, target
+
+
+ def download(self) -> None:
+ if self._check_exists():
+ return
+
+ download_and_extract_archive(
+ url="https://ai.stanford.edu/~jkrause/cars/car_devkit.tgz",
+ download_root=str(self._base_folder),
+ md5="c3b158d763b6e2245038c8ad08e45376",
+ )
+ if self._split == "train":
+ download_and_extract_archive(
+ url="https://ai.stanford.edu/~jkrause/car196/cars_train.tgz",
+ download_root=str(self._base_folder),
+ md5="065e5b463ae28d29e77c1b4b166cfe61",
+ )
+ else:
+ download_and_extract_archive(
+ url="https://ai.stanford.edu/~jkrause/car196/cars_test.tgz",
+ download_root=str(self._base_folder),
+ md5="4ce7ebf6a94d07f1952d94dd34c4d501",
+ )
+ download_url(
+ url="https://ai.stanford.edu/~jkrause/car196/cars_test_annos_withlabels.mat",
+ root=str(self._base_folder),
+ md5="b0a2b23655a3edd16d84508592a98d10",
+ )
+
+ def _check_exists(self) -> bool:
+ if not (self._base_folder / "devkit").is_dir():
+ return False
+
+ return self._annotations_mat_path.exists() and self._images_base_path.is_dir()
+
+
+class Cars:
+ def __init__(self,
+ preprocess,
+ location=os.path.expanduser('~/data'),
+ batch_size=32,
+ num_workers=16):
+ # Data loading code
+
+ self.train_dataset = PytorchStanfordCars(location, 'train', preprocess, download=False)
+ self.train_loader = torch.utils.data.DataLoader(
+ self.train_dataset,
+ shuffle=True,
+ batch_size=batch_size,
+ num_workers=num_workers,
+ )
+
+ self.test_dataset = PytorchStanfordCars(location, 'test', preprocess, download=False)
+ self.test_loader = torch.utils.data.DataLoader(
+ self.test_dataset,
+ batch_size=batch_size,
+ num_workers=num_workers
+ )
+ idx_to_class = dict((v, k)
+ for k, v in self.train_dataset.class_to_idx.items())
+ self.classnames = [idx_to_class[i].replace(
+ '_', ' ') for i in range(len(idx_to_class))]
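The `Cars` wrapper above recovers ordered, human-readable class names by inverting `class_to_idx`; the same idiom reappears in the DTD and EuroSAT loaders later in this diff. A minimal sketch of that inversion, using a hypothetical two-class mapping in place of a real torchvision dataset:

```python
# Hypothetical stand-in for a torchvision dataset's class_to_idx mapping.
class_to_idx = {'AM_General_Hummer': 0, 'Acura_RL': 1}

# Flip name -> index into index -> name, then emit names in index order,
# replacing underscores with spaces (as the Cars/DTD loaders do).
idx_to_class = {v: k for k, v in class_to_idx.items()}
classnames = [idx_to_class[i].replace('_', ' ') for i in range(len(idx_to_class))]
```

This guarantees `classnames[i]` is the label text for class index `i`, which the CLIP classification heads built from these loaders depend on.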
diff --git a/src/datasets/cifar10.py b/src/datasets/cifar10.py
new file mode 100644
index 0000000000000000000000000000000000000000..096913b59e5158c2eb507ebf066658547df5b533
--- /dev/null
+++ b/src/datasets/cifar10.py
@@ -0,0 +1,56 @@
+import os
+import PIL
+import torch
+import numpy as np
+import torchvision
+from torchvision import transforms
+from torchvision.datasets import CIFAR10 as PyTorchCIFAR10
+from torchvision.datasets import VisionDataset
+
+cifar_classnames = ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
+
+class CIFAR10:
+ def __init__(self, preprocess,
+ location=os.path.expanduser('~/data'),
+ batch_size=128,
+ num_workers=16):
+
+
+ self.train_dataset = PyTorchCIFAR10(
+ root=location, download=True, train=True, transform=preprocess
+ )
+
+ self.train_loader = torch.utils.data.DataLoader(
+ self.train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers
+ )
+
+ self.test_dataset = PyTorchCIFAR10(
+ root=location, download=True, train=False, transform=preprocess
+ )
+
+ self.test_loader = torch.utils.data.DataLoader(
+ self.test_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers
+ )
+
+ self.classnames = self.test_dataset.classes
+
+def convert(x):
+ if isinstance(x, np.ndarray):
+ return torchvision.transforms.functional.to_pil_image(x)
+ return x
+
+class BasicVisionDataset(VisionDataset):
+ def __init__(self, images, targets, transform=None, target_transform=None):
+ if transform is not None:
+ transform.transforms.insert(0, convert)
+ super(BasicVisionDataset, self).__init__(root=None, transform=transform, target_transform=target_transform)
+ assert len(images) == len(targets)
+
+ self.images = images
+ self.targets = targets
+
+ def __getitem__(self, index):
+ return self.transform(self.images[index]), self.targets[index]
+
+ def __len__(self):
+ return len(self.targets)
diff --git a/src/datasets/cifar100.py b/src/datasets/cifar100.py
new file mode 100644
index 0000000000000000000000000000000000000000..c7b3bb4953336e7ef3f24bc429f51df8953b4c47
--- /dev/null
+++ b/src/datasets/cifar100.py
@@ -0,0 +1,30 @@
+import os
+import torch
+from torchvision.datasets import CIFAR100 as PyTorchCIFAR100
+
+class CIFAR100:
+ def __init__(self,
+ preprocess,
+ location=os.path.expanduser('~/data'),
+ batch_size=128,
+ num_workers=16):
+
+ self.train_dataset = PyTorchCIFAR100(
+ root=location, download=True, train=True, transform=preprocess
+ )
+
+ self.train_loader = torch.utils.data.DataLoader(
+ self.train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers
+ )
+
+ self.test_dataset = PyTorchCIFAR100(
+ root=location, download=True, train=False, transform=preprocess
+ )
+
+ self.test_loader = torch.utils.data.DataLoader(
+ self.test_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers
+ )
+
+ self.classnames = self.test_dataset.classes
+
+
diff --git a/src/datasets/common.py b/src/datasets/common.py
new file mode 100644
index 0000000000000000000000000000000000000000..7cb6cce76f71f0eea7f770fa4508795703082a95
--- /dev/null
+++ b/src/datasets/common.py
@@ -0,0 +1,139 @@
+import os
+import torch
+import json
+import glob
+import collections
+import random
+
+import numpy as np
+
+from tqdm import tqdm
+
+import torchvision.datasets as datasets
+from torch.utils.data import Dataset, DataLoader, Sampler
+
+
+class SubsetSampler(Sampler):
+ def __init__(self, indices):
+ self.indices = indices
+
+ def __iter__(self):
+ return (i for i in self.indices)
+
+ def __len__(self):
+ return len(self.indices)
+
+class ImageFolderWithPaths(datasets.ImageFolder):
+ def __init__(self, path, transform, flip_label_prob=0.0):
+ super().__init__(path, transform)
+ self.flip_label_prob = flip_label_prob
+ if self.flip_label_prob > 0:
+ print(f'Flipping labels with probability {self.flip_label_prob}')
+ num_classes = len(self.classes)
+ for i in range(len(self.samples)):
+ if random.random() < self.flip_label_prob:
+ new_label = random.randint(0, num_classes-1)
+ self.samples[i] = (
+ self.samples[i][0],
+ new_label
+ )
+
+ def __getitem__(self, index):
+ image, label = super(ImageFolderWithPaths, self).__getitem__(index)
+ return {
+ 'images': image,
+ 'labels': label,
+ 'image_paths': self.samples[index][0]
+ }
+
+
+def maybe_dictionarize(batch):
+ if isinstance(batch, dict):
+ return batch
+
+ if len(batch) == 2:
+ batch = {'images': batch[0], 'labels': batch[1]}
+ elif len(batch) == 3:
+ batch = {'images': batch[0], 'labels': batch[1], 'metadata': batch[2]}
+ else:
+ raise ValueError(f'Unexpected number of elements: {len(batch)}')
+
+ return batch
+
+
+def get_features_helper(image_encoder, dataloader, device):
+ all_data = collections.defaultdict(list)
+
+ image_encoder = image_encoder.to(device)
+ image_encoder = torch.nn.DataParallel(image_encoder, device_ids=list(range(torch.cuda.device_count())))
+ image_encoder.eval()
+
+ with torch.no_grad():
+ for batch in tqdm(dataloader):
+ batch = maybe_dictionarize(batch)
+ features = image_encoder(batch['images'].cuda())
+
+ all_data['features'].append(features.cpu())
+
+ for key, val in batch.items():
+ if key == 'images':
+ continue
+ if hasattr(val, 'cpu'):
+ val = val.cpu()
+ all_data[key].append(val)
+ else:
+ all_data[key].extend(val)
+
+ for key, val in all_data.items():
+ if torch.is_tensor(val[0]):
+ all_data[key] = torch.cat(val).numpy()
+
+ return all_data
+
+
+def get_features(is_train, image_encoder, dataset, device):
+ split = 'train' if is_train else 'val'
+ dname = type(dataset).__name__
+ # cache_dir may be None; define both names so the else-branch below never raises NameError
+ cache_dir = f'{image_encoder.cache_dir}/{dname}/{split}' if image_encoder.cache_dir is not None else None
+ cached_files = glob.glob(f'{cache_dir}/*') if cache_dir is not None else []
+ if len(cached_files) > 0:
+ print(f'Getting features from {cache_dir}')
+ data = {}
+ for cached_file in cached_files:
+ name = os.path.splitext(os.path.basename(cached_file))[0]
+ data[name] = torch.load(cached_file)
+ else:
+ print(f'Did not find cached features at {cache_dir}. Building from scratch.')
+ loader = dataset.train_loader if is_train else dataset.test_loader
+ data = get_features_helper(image_encoder, loader, device)
+ if image_encoder.cache_dir is None:
+ print('Not caching because no cache directory was passed.')
+ else:
+ os.makedirs(cache_dir, exist_ok=True)
+ print(f'Caching data at {cache_dir}')
+ for name, val in data.items():
+ torch.save(val, f'{cache_dir}/{name}.pt')
+ return data
+
+
+class FeatureDataset(Dataset):
+ def __init__(self, is_train, image_encoder, dataset, device):
+ self.data = get_features(is_train, image_encoder, dataset, device)
+
+ def __len__(self):
+ return len(self.data['features'])
+
+ def __getitem__(self, idx):
+ data = {k: v[idx] for k, v in self.data.items()}
+ data['features'] = torch.from_numpy(data['features']).float()
+ return data
+
+
+def get_dataloader(dataset, is_train, args, image_encoder=None):
+ if image_encoder is not None:
+ feature_dataset = FeatureDataset(is_train, image_encoder, dataset, args.device)
+ dataloader = DataLoader(feature_dataset, batch_size=args.batch_size, shuffle=is_train)
+ else:
+ dataloader = dataset.train_loader if is_train else dataset.test_loader
+ return dataloader
\ No newline at end of file
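`maybe_dictionarize` above is the glue that lets the training and evaluation loops consume both tuple-style torchvision batches and the dict batches produced by `ImageFolderWithPaths`. A self-contained copy of its logic with illustrative dummy batches (no real tensors needed, since the function only rearranges the batch container):

```python
def maybe_dictionarize(batch):
    # Dicts pass through unchanged; 2- and 3-element batches are wrapped
    # into the {'images', 'labels'[, 'metadata']} schema used downstream.
    if isinstance(batch, dict):
        return batch
    if len(batch) == 2:
        return {'images': batch[0], 'labels': batch[1]}
    if len(batch) == 3:
        return {'images': batch[0], 'labels': batch[1], 'metadata': batch[2]}
    raise ValueError(f'Unexpected number of elements: {len(batch)}')

# Illustrative stand-ins for (image, label) and (image, label, path) batches.
pair = maybe_dictionarize(('img_tensor', 7))
triple = maybe_dictionarize(('img_tensor', 7, 'a/b.jpg'))
```

Because dicts are returned unchanged, calling it on an already-dictionarized batch is a no-op, so loops can apply it unconditionally.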
diff --git a/src/datasets/dtd.py b/src/datasets/dtd.py
new file mode 100644
index 0000000000000000000000000000000000000000..79fb3c3018634e9a9c7ecef38a86f0b3573f39a7
--- /dev/null
+++ b/src/datasets/dtd.py
@@ -0,0 +1,34 @@
+import os
+import torch
+import torchvision.datasets as datasets
+
+
+class DTD:
+ def __init__(self,
+ preprocess,
+ location=os.path.expanduser('~/data'),
+ batch_size=32,
+ num_workers=16):
+ # Data loading code
+ traindir = os.path.join(location, 'dtd', 'train')
+ valdir = os.path.join(location, 'dtd', 'val')
+
+ self.train_dataset = datasets.ImageFolder(
+ traindir, transform=preprocess)
+ self.train_loader = torch.utils.data.DataLoader(
+ self.train_dataset,
+ shuffle=True,
+ batch_size=batch_size,
+ num_workers=num_workers,
+ )
+
+ self.test_dataset = datasets.ImageFolder(valdir, transform=preprocess)
+ self.test_loader = torch.utils.data.DataLoader(
+ self.test_dataset,
+ batch_size=batch_size,
+ num_workers=num_workers
+ )
+ idx_to_class = dict((v, k)
+ for k, v in self.train_dataset.class_to_idx.items())
+ self.classnames = [idx_to_class[i].replace(
+ '_', ' ') for i in range(len(idx_to_class))]
\ No newline at end of file
diff --git a/src/datasets/emnist.py b/src/datasets/emnist.py
new file mode 100644
index 0000000000000000000000000000000000000000..790eca50a2ef5c772b06fb60e2bebd5584eaa14a
--- /dev/null
+++ b/src/datasets/emnist.py
@@ -0,0 +1,74 @@
+import os
+
+import torch
+
+import torchvision
+import torchvision.datasets as datasets
+
+
+def rotate_img(img):
+ return torchvision.transforms.functional.rotate(img, -90)
+
+
+def flip_img(img):
+ return torchvision.transforms.functional.hflip(img)
+
+
+def emnist_preprocess():
+ return torchvision.transforms.Compose(
+ [
+ rotate_img,
+ flip_img,
+ ]
+ )
+
+
+class EMNIST:
+ def __init__(
+ self,
+ preprocess,
+ location,
+ batch_size=128,
+ num_workers=8,
+ ):
+ preprocess1 = emnist_preprocess()
+ preprocess = torchvision.transforms.Compose(
+ [
+ preprocess,
+ preprocess1,
+ ]
+ )
+ # if not os.path.exists(location):
+ # os.makedirs(location, exist_ok=True)
+
+ self.train_dataset = datasets.EMNIST(
+ root=location,
+ download=True,
+ split="digits",
+ transform=preprocess,
+ train=True,
+ )
+
+ self.train_loader = torch.utils.data.DataLoader(
+ self.train_dataset,
+ batch_size=batch_size,
+ shuffle=True,
+ num_workers=num_workers,
+ )
+
+ self.test_dataset = datasets.EMNIST(
+ root=location,
+ download=True,
+ split="digits",
+ transform=preprocess,
+ train=False,
+ )
+
+ self.test_loader = torch.utils.data.DataLoader(
+ self.test_dataset,
+ batch_size=batch_size,
+ shuffle=False,
+ num_workers=num_workers,
+ )
+
+ self.classnames = self.train_dataset.classes
diff --git a/src/datasets/eurosat.py b/src/datasets/eurosat.py
new file mode 100644
index 0000000000000000000000000000000000000000..8ef20d90f16ef8cf046895e89c525ae09d37d1f0
--- /dev/null
+++ b/src/datasets/eurosat.py
@@ -0,0 +1,75 @@
+import os
+import torch
+import torchvision.datasets as datasets
+import re
+
+def pretify_classname(classname):
+ l = re.findall(r'[A-Z](?:[a-z]+|[A-Z]*(?=[A-Z]|$))', classname)
+ l = [i.lower() for i in l]
+ out = ' '.join(l)
+ if out.endswith('al'):
+ return out + ' area'
+ return out
+
+class EuroSATBase:
+ def __init__(self,
+ preprocess,
+ test_split,
+ location='~/datasets',
+ batch_size=32,
+ num_workers=16):
+ # Data loading code
+ traindir = os.path.join(location, 'EuroSAT_splits', 'train')
+ testdir = os.path.join(location, 'EuroSAT_splits', test_split)
+
+
+ self.train_dataset = datasets.ImageFolder(traindir, transform=preprocess)
+ self.train_loader = torch.utils.data.DataLoader(
+ self.train_dataset,
+ shuffle=True,
+ batch_size=batch_size,
+ num_workers=num_workers,
+ )
+
+ self.test_dataset = datasets.ImageFolder(testdir, transform=preprocess)
+ self.test_loader = torch.utils.data.DataLoader(
+ self.test_dataset,
+ batch_size=batch_size,
+ num_workers=num_workers
+ )
+ idx_to_class = dict((v, k)
+ for k, v in self.train_dataset.class_to_idx.items())
+ self.classnames = [idx_to_class[i].replace('_', ' ') for i in range(len(idx_to_class))]
+ self.classnames = [pretify_classname(c) for c in self.classnames]
+ ours_to_open_ai = {
+ 'annual crop': 'annual crop land',
+ 'forest': 'forest',
+ 'herbaceous vegetation': 'brushland or shrubland',
+ 'highway': 'highway or road',
+ 'industrial area': 'industrial buildings or commercial buildings',
+ 'pasture': 'pasture land',
+ 'permanent crop': 'permanent crop land',
+ 'residential area': 'residential buildings or homes or apartments',
+ 'river': 'river',
+ 'sea lake': 'lake or sea',
+ }
+ for i in range(len(self.classnames)):
+ self.classnames[i] = ours_to_open_ai[self.classnames[i]]
+
+
+class EuroSAT(EuroSATBase):
+ def __init__(self,
+ preprocess,
+ location='~/datasets',
+ batch_size=32,
+ num_workers=16):
+ super().__init__(preprocess, 'test', location, batch_size, num_workers)
+
+
+class EuroSATVal(EuroSATBase):
+ def __init__(self,
+ preprocess,
+ location='~/datasets',
+ batch_size=32,
+ num_workers=16):
+ super().__init__(preprocess, 'val', location, batch_size, num_workers)
diff --git a/src/datasets/gtsrb.py b/src/datasets/gtsrb.py
new file mode 100644
index 0000000000000000000000000000000000000000..b089e128489f1526c27b542f3bb8dcf5a1683d8e
--- /dev/null
+++ b/src/datasets/gtsrb.py
@@ -0,0 +1,205 @@
+import csv
+import os
+import pathlib
+from typing import Any, Callable, Dict, List, Optional, Tuple
+
+import numpy as np
+import PIL
+import torch
+from torchvision.datasets.folder import make_dataset
+from torchvision.datasets.utils import (download_and_extract_archive,
+ verify_str_arg)
+from torchvision.datasets.vision import VisionDataset
+
+def find_classes(directory: str) -> Tuple[List[str], Dict[str, int]]:
+ """Finds the class folders in a dataset.
+
+ See :class:`DatasetFolder` for details.
+ """
+ classes = sorted(entry.name for entry in os.scandir(directory) if entry.is_dir())
+ if not classes:
+ raise FileNotFoundError(f"Couldn't find any class folder in {directory}.")
+
+ class_to_idx = {cls_name: i for i, cls_name in enumerate(classes)}
+ return classes, class_to_idx
+
+class PyTorchGTSRB(VisionDataset):
+ """German Traffic Sign Recognition Benchmark (GTSRB) Dataset.
+
+ Modified from https://pytorch.org/vision/main/_modules/torchvision/datasets/gtsrb.html#GTSRB.
+
+ Args:
+ root (string): Root directory of the dataset.
+ split (string, optional): The dataset split, supports ``"train"`` (default), or ``"test"``.
+ transform (callable, optional): A function/transform that takes in a PIL image and returns a transformed
+ version. E.g., ``transforms.RandomCrop``.
+ target_transform (callable, optional): A function/transform that takes in the target and transforms it.
+ download (bool, optional): If True, downloads the dataset from the internet and
+ puts it in root directory. If dataset is already downloaded, it is not
+ downloaded again.
+ """
+
+ def __init__(
+ self,
+ root: str,
+ split: str = "train",
+ transform: Optional[Callable] = None,
+ target_transform: Optional[Callable] = None,
+ download: bool = False,
+ ) -> None:
+
+ super().__init__(root, transform=transform, target_transform=target_transform)
+
+ self._split = verify_str_arg(split, "split", ("train", "test"))
+ self._base_folder = pathlib.Path(root) / "gtsrb"
+ self._target_folder = (
+ self._base_folder / "GTSRB" / ("Training" if self._split == "train" else "Final_Test/Images")
+ )
+
+ if download:
+ self.download()
+
+ if not self._check_exists():
+ raise RuntimeError("Dataset not found. You can use download=True to download it")
+
+ if self._split == "train":
+ _, class_to_idx = find_classes(str(self._target_folder))
+ samples = make_dataset(str(self._target_folder), extensions=(".ppm",), class_to_idx=class_to_idx)
+ else:
+ with open(self._base_folder / "GT-final_test.csv") as csv_file:
+ samples = [
+ (str(self._target_folder / row["Filename"]), int(row["ClassId"]))
+ for row in csv.DictReader(csv_file, delimiter=";", skipinitialspace=True)
+ ]
+
+ self._samples = samples
+ self.transform = transform
+ self.target_transform = target_transform
+
+ def __len__(self) -> int:
+ return len(self._samples)
+
+ def __getitem__(self, index: int) -> Tuple[Any, Any]:
+
+ path, target = self._samples[index]
+ sample = PIL.Image.open(path).convert("RGB")
+
+ if self.transform is not None:
+ sample = self.transform(sample)
+
+ if self.target_transform is not None:
+ target = self.target_transform(target)
+
+ return sample, target
+
+
+ def _check_exists(self) -> bool:
+ return self._target_folder.is_dir()
+
+ def download(self) -> None:
+ if self._check_exists():
+ return
+
+ base_url = "https://sid.erda.dk/public/archives/daaeac0d7ce1152aea9b61d9f1e19370/"
+
+ if self._split == "train":
+ download_and_extract_archive(
+ f"{base_url}GTSRB-Training_fixed.zip",
+ download_root=str(self._base_folder),
+ md5="513f3c79a4c5141765e10e952eaa2478",
+ )
+ else:
+ download_and_extract_archive(
+ f"{base_url}GTSRB_Final_Test_Images.zip",
+ download_root=str(self._base_folder),
+ md5="c7e4e6327067d32654124b0fe9e82185",
+ )
+ download_and_extract_archive(
+ f"{base_url}GTSRB_Final_Test_GT.zip",
+ download_root=str(self._base_folder),
+ md5="fe31e9c9270bbcd7b84b7f21a9d9d9e5",
+ )
+
+
+class GTSRB:
+ def __init__(self,
+ preprocess,
+ location=os.path.expanduser('~/data'),
+ batch_size=128,
+ num_workers=16):
+
+ # to fit with repo conventions for location
+ self.train_dataset = PyTorchGTSRB(
+ root=location,
+ download=True,
+ split='train',
+ transform=preprocess
+ )
+
+ self.train_loader = torch.utils.data.DataLoader(
+ self.train_dataset,
+ batch_size=batch_size,
+ shuffle=True,
+ num_workers=num_workers
+ )
+
+ self.test_dataset = PyTorchGTSRB(
+ root=location,
+ download=True,
+ split='test',
+ transform=preprocess
+ )
+
+ self.test_loader = torch.utils.data.DataLoader(
+ self.test_dataset,
+ batch_size=batch_size,
+ shuffle=False,
+ num_workers=num_workers
+ )
+
+ # from https://github.com/openai/CLIP/blob/e184f608c5d5e58165682f7c332c3a8b4c1545f2/data/prompts.md
+ self.classnames = [
+ 'red and white circle 20 kph speed limit',
+ 'red and white circle 30 kph speed limit',
+ 'red and white circle 50 kph speed limit',
+ 'red and white circle 60 kph speed limit',
+ 'red and white circle 70 kph speed limit',
+ 'red and white circle 80 kph speed limit',
+ 'end / de-restriction of 80 kph speed limit',
+ 'red and white circle 100 kph speed limit',
+ 'red and white circle 120 kph speed limit',
+ 'red and white circle red car and black car no passing',
+ 'red and white circle red truck and black car no passing',
+ 'red and white triangle road intersection warning',
+ 'white and yellow diamond priority road',
+ 'red and white upside down triangle yield right-of-way',
+ 'stop',
+ 'empty red and white circle',
+ 'red and white circle no truck entry',
+ 'red circle with white horizonal stripe no entry',
+ 'red and white triangle with exclamation mark warning',
+ 'red and white triangle with black left curve approaching warning',
+ 'red and white triangle with black right curve approaching warning',
+ 'red and white triangle with black double curve approaching warning',
+ 'red and white triangle rough / bumpy road warning',
+ 'red and white triangle car skidding / slipping warning',
+ 'red and white triangle with merging / narrow lanes warning',
+ 'red and white triangle with person digging / construction / road work warning',
+ 'red and white triangle with traffic light approaching warning',
+ 'red and white triangle with person walking warning',
+ 'red and white triangle with child and person walking warning',
+ 'red and white triangle with bicyle warning',
+ 'red and white triangle with snowflake / ice warning',
+ 'red and white triangle with deer warning',
+ 'white circle with gray strike bar no speed limit',
+ 'blue circle with white right turn arrow mandatory',
+ 'blue circle with white left turn arrow mandatory',
+ 'blue circle with white forward arrow mandatory',
+ 'blue circle with white forward or right turn arrow mandatory',
+ 'blue circle with white forward or left turn arrow mandatory',
+ 'blue circle with white keep right arrow mandatory',
+ 'blue circle with white keep left arrow mandatory',
+ 'blue circle with white arrows indicating a traffic circle',
+ 'white circle with gray strike bar indicating no passing for cars has ended',
+ 'white circle with gray strike bar indicating no passing for trucks has ended',
+ ]
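`find_classes` in `gtsrb.py` mirrors torchvision's `DatasetFolder` convention: class subfolders sorted by name, then mapped to contiguous integer indices. A quick self-contained check against a temporary directory tree (the folder names are illustrative, in the zero-padded style of GTSRB class folders):

```python
import os
import tempfile

def find_classes(directory):
    # Same logic as the helper above: sorted class folders -> integer ids.
    classes = sorted(entry.name for entry in os.scandir(directory) if entry.is_dir())
    if not classes:
        raise FileNotFoundError(f"Couldn't find any class folder in {directory}.")
    class_to_idx = {cls_name: i for i, cls_name in enumerate(classes)}
    return classes, class_to_idx

with tempfile.TemporaryDirectory() as root:
    for name in ('00002', '00000', '00001'):  # illustrative class folders
        os.makedirs(os.path.join(root, name))
    classes, class_to_idx = find_classes(root)
```

Sorting before enumeration is what makes the class-to-index assignment deterministic across machines and filesystem orderings.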
diff --git a/src/datasets/imagenet.py b/src/datasets/imagenet.py
new file mode 100644
index 0000000000000000000000000000000000000000..e9d5a5613047c134ff81c33393d3e88cab3c471b
--- /dev/null
+++ b/src/datasets/imagenet.py
@@ -0,0 +1,253 @@
+import os
+import torch
+
+from .common import ImageFolderWithPaths, SubsetSampler
+import numpy as np
+
+
+imagenet_classnames = [
+ "tench", "goldfish", "great white shark", "tiger shark", "hammerhead shark", "electric ray",
+ "stingray", "rooster", "hen", "ostrich", "brambling", "goldfinch", "house finch", "junco",
+ "indigo bunting", "American robin", "bulbul", "jay", "magpie", "chickadee", "American dipper",
+ "kite (bird of prey)", "bald eagle", "vulture", "great grey owl", "fire salamander",
+ "smooth newt", "newt", "spotted salamander", "axolotl", "American bullfrog", "tree frog",
+ "tailed frog", "loggerhead sea turtle", "leatherback sea turtle", "mud turtle", "terrapin",
+ "box turtle", "banded gecko", "green iguana", "Carolina anole",
+ "desert grassland whiptail lizard", "agama", "frilled-necked lizard", "alligator lizard",
+ "Gila monster", "European green lizard", "chameleon", "Komodo dragon", "Nile crocodile",
+ "American alligator", "triceratops", "worm snake", "ring-necked snake",
+ "eastern hog-nosed snake", "smooth green snake", "kingsnake", "garter snake", "water snake",
+ "vine snake", "night snake", "boa constrictor", "African rock python", "Indian cobra",
+ "green mamba", "sea snake", "Saharan horned viper", "eastern diamondback rattlesnake",
+ "sidewinder rattlesnake", "trilobite", "harvestman", "scorpion", "yellow garden spider",
+ "barn spider", "European garden spider", "southern black widow", "tarantula", "wolf spider",
+ "tick", "centipede", "black grouse", "ptarmigan", "ruffed grouse", "prairie grouse", "peafowl",
+ "quail", "partridge", "african grey parrot", "macaw", "sulphur-crested cockatoo", "lorikeet",
+ "coucal", "bee eater", "hornbill", "hummingbird", "jacamar", "toucan", "duck",
+ "red-breasted merganser", "goose", "black swan", "tusker", "echidna", "platypus", "wallaby",
+ "koala", "wombat", "jellyfish", "sea anemone", "brain coral", "flatworm", "nematode", "conch",
+ "snail", "slug", "sea slug", "chiton", "chambered nautilus", "Dungeness crab", "rock crab",
+ "fiddler crab", "red king crab", "American lobster", "spiny lobster", "crayfish", "hermit crab",
+ "isopod", "white stork", "black stork", "spoonbill", "flamingo", "little blue heron",
+ "great egret", "bittern bird", "crane bird", "limpkin", "common gallinule", "American coot",
+ "bustard", "ruddy turnstone", "dunlin", "common redshank", "dowitcher", "oystercatcher",
+ "pelican", "king penguin", "albatross", "grey whale", "killer whale", "dugong", "sea lion",
+ "Chihuahua", "Japanese Chin", "Maltese", "Pekingese", "Shih Tzu", "King Charles Spaniel",
+ "Papillon", "toy terrier", "Rhodesian Ridgeback", "Afghan Hound", "Basset Hound", "Beagle",
+ "Bloodhound", "Bluetick Coonhound", "Black and Tan Coonhound", "Treeing Walker Coonhound",
+ "English foxhound", "Redbone Coonhound", "borzoi", "Irish Wolfhound", "Italian Greyhound",
+ "Whippet", "Ibizan Hound", "Norwegian Elkhound", "Otterhound", "Saluki", "Scottish Deerhound",
+ "Weimaraner", "Staffordshire Bull Terrier", "American Staffordshire Terrier",
+ "Bedlington Terrier", "Border Terrier", "Kerry Blue Terrier", "Irish Terrier",
+ "Norfolk Terrier", "Norwich Terrier", "Yorkshire Terrier", "Wire Fox Terrier",
+ "Lakeland Terrier", "Sealyham Terrier", "Airedale Terrier", "Cairn Terrier",
+ "Australian Terrier", "Dandie Dinmont Terrier", "Boston Terrier", "Miniature Schnauzer",
+ "Giant Schnauzer", "Standard Schnauzer", "Scottish Terrier", "Tibetan Terrier",
+ "Australian Silky Terrier", "Soft-coated Wheaten Terrier", "West Highland White Terrier",
+ "Lhasa Apso", "Flat-Coated Retriever", "Curly-coated Retriever", "Golden Retriever",
+ "Labrador Retriever", "Chesapeake Bay Retriever", "German Shorthaired Pointer", "Vizsla",
+ "English Setter", "Irish Setter", "Gordon Setter", "Brittany dog", "Clumber Spaniel",
+ "English Springer Spaniel", "Welsh Springer Spaniel", "Cocker Spaniel", "Sussex Spaniel",
+ "Irish Water Spaniel", "Kuvasz", "Schipperke", "Groenendael dog", "Malinois", "Briard",
+ "Australian Kelpie", "Komondor", "Old English Sheepdog", "Shetland Sheepdog", "collie",
+ "Border Collie", "Bouvier des Flandres dog", "Rottweiler", "German Shepherd Dog", "Dobermann",
+ "Miniature Pinscher", "Greater Swiss Mountain Dog", "Bernese Mountain Dog",
+ "Appenzeller Sennenhund", "Entlebucher Sennenhund", "Boxer", "Bullmastiff", "Tibetan Mastiff",
+ "French Bulldog", "Great Dane", "St. Bernard", "husky", "Alaskan Malamute", "Siberian Husky",
+ "Dalmatian", "Affenpinscher", "Basenji", "pug", "Leonberger", "Newfoundland dog",
+ "Great Pyrenees dog", "Samoyed", "Pomeranian", "Chow Chow", "Keeshond", "brussels griffon",
+ "Pembroke Welsh Corgi", "Cardigan Welsh Corgi", "Toy Poodle", "Miniature Poodle",
+ "Standard Poodle", "Mexican hairless dog (xoloitzcuintli)", "grey wolf", "Alaskan tundra wolf",
+ "red wolf or maned wolf", "coyote", "dingo", "dhole", "African wild dog", "hyena", "red fox",
+ "kit fox", "Arctic fox", "grey fox", "tabby cat", "tiger cat", "Persian cat", "Siamese cat",
+ "Egyptian Mau", "cougar", "lynx", "leopard", "snow leopard", "jaguar", "lion", "tiger",
+ "cheetah", "brown bear", "American black bear", "polar bear", "sloth bear", "mongoose",
+ "meerkat", "tiger beetle", "ladybug", "ground beetle", "longhorn beetle", "leaf beetle",
+ "dung beetle", "rhinoceros beetle", "weevil", "fly", "bee", "ant", "grasshopper",
+ "cricket insect", "stick insect", "cockroach", "praying mantis", "cicada", "leafhopper",
+ "lacewing", "dragonfly", "damselfly", "red admiral butterfly", "ringlet butterfly",
+ "monarch butterfly", "small white butterfly", "sulphur butterfly", "gossamer-winged butterfly",
+ "starfish", "sea urchin", "sea cucumber", "cottontail rabbit", "hare", "Angora rabbit",
+ "hamster", "porcupine", "fox squirrel", "marmot", "beaver", "guinea pig", "common sorrel horse",
+ "zebra", "pig", "wild boar", "warthog", "hippopotamus", "ox", "water buffalo", "bison",
+ "ram (adult male sheep)", "bighorn sheep", "Alpine ibex", "hartebeest", "impala (antelope)",
+ "gazelle", "arabian camel", "llama", "weasel", "mink", "European polecat",
+ "black-footed ferret", "otter", "skunk", "badger", "armadillo", "three-toed sloth", "orangutan",
+ "gorilla", "chimpanzee", "gibbon", "siamang", "guenon", "patas monkey", "baboon", "macaque",
+ "langur", "black-and-white colobus", "proboscis monkey", "marmoset", "white-headed capuchin",
+ "howler monkey", "titi monkey", "Geoffroy's spider monkey", "common squirrel monkey",
+ "ring-tailed lemur", "indri", "Asian elephant", "African bush elephant", "red panda",
+ "giant panda", "snoek fish", "eel", "silver salmon", "rock beauty fish", "clownfish",
+ "sturgeon", "gar fish", "lionfish", "pufferfish", "abacus", "abaya", "academic gown",
+ "accordion", "acoustic guitar", "aircraft carrier", "airliner", "airship", "altar", "ambulance",
+ "amphibious vehicle", "analog clock", "apiary", "apron", "trash can", "assault rifle",
+ "backpack", "bakery", "balance beam", "balloon", "ballpoint pen", "Band-Aid", "banjo",
+ "baluster / handrail", "barbell", "barber chair", "barbershop", "barn", "barometer", "barrel",
+ "wheelbarrow", "baseball", "basketball", "bassinet", "bassoon", "swimming cap", "bath towel",
+ "bathtub", "station wagon", "lighthouse", "beaker", "military hat (bearskin or shako)",
+ "beer bottle", "beer glass", "bell tower", "baby bib", "tandem bicycle", "bikini",
+ "ring binder", "binoculars", "birdhouse", "boathouse", "bobsleigh", "bolo tie", "poke bonnet",
+ "bookcase", "bookstore", "bottle cap", "hunting bow", "bow tie", "brass memorial plaque", "bra",
+ "breakwater", "breastplate", "broom", "bucket", "buckle", "bulletproof vest",
+ "high-speed train", "butcher shop", "taxicab", "cauldron", "candle", "cannon", "canoe",
+ "can opener", "cardigan", "car mirror", "carousel", "tool kit", "cardboard box / carton",
+ "car wheel", "automated teller machine", "cassette", "cassette player", "castle", "catamaran",
+ "CD player", "cello", "mobile phone", "chain", "chain-link fence", "chain mail", "chainsaw",
+ "storage chest", "chiffonier", "bell or wind chime", "china cabinet", "Christmas stocking",
+ "church", "movie theater", "cleaver", "cliff dwelling", "cloak", "clogs", "cocktail shaker",
+ "coffee mug", "coffeemaker", "spiral or coil", "combination lock", "computer keyboard",
+ "candy store", "container ship", "convertible", "corkscrew", "cornet", "cowboy boot",
+ "cowboy hat", "cradle", "construction crane", "crash helmet", "crate", "infant bed",
+ "Crock Pot", "croquet ball", "crutch", "cuirass", "dam", "desk", "desktop computer",
+ "rotary dial telephone", "diaper", "digital clock", "digital watch", "dining table",
+ "dishcloth", "dishwasher", "disc brake", "dock", "dog sled", "dome", "doormat", "drilling rig",
+ "drum", "drumstick", "dumbbell", "Dutch oven", "electric fan", "electric guitar",
+ "electric locomotive", "entertainment center", "envelope", "espresso machine", "face powder",
+ "feather boa", "filing cabinet", "fireboat", "fire truck", "fire screen", "flagpole", "flute",
+ "folding chair", "football helmet", "forklift", "fountain", "fountain pen", "four-poster bed",
+ "freight car", "French horn", "frying pan", "fur coat", "garbage truck",
+ "gas mask or respirator", "gas pump", "goblet", "go-kart", "golf ball", "golf cart", "gondola",
+ "gong", "gown", "grand piano", "greenhouse", "radiator grille", "grocery store", "guillotine",
+ "hair clip", "hair spray", "half-track", "hammer", "hamper", "hair dryer", "hand-held computer",
+ "handkerchief", "hard disk drive", "harmonica", "harp", "combine harvester", "hatchet",
+ "holster", "home theater", "honeycomb", "hook", "hoop skirt", "gymnastic horizontal bar",
+ "horse-drawn vehicle", "hourglass", "iPod", "clothes iron", "carved pumpkin", "jeans", "jeep",
+ "T-shirt", "jigsaw puzzle", "rickshaw", "joystick", "kimono", "knee pad", "knot", "lab coat",
+ "ladle", "lampshade", "laptop computer", "lawn mower", "lens cap", "letter opener", "library",
+ "lifeboat", "lighter", "limousine", "ocean liner", "lipstick", "slip-on shoe", "lotion",
+ "music speaker", "loupe magnifying glass", "sawmill", "magnetic compass", "messenger bag",
+ "mailbox", "tights", "one-piece bathing suit", "manhole cover", "maraca", "marimba", "mask",
+ "matchstick", "maypole", "maze", "measuring cup", "medicine cabinet", "megalith", "microphone",
+ "microwave oven", "military uniform", "milk can", "minibus", "miniskirt", "minivan", "missile",
+ "mitten", "mixing bowl", "mobile home", "ford model t", "modem", "monastery", "monitor",
+ "moped", "mortar and pestle", "graduation cap", "mosque", "mosquito net", "vespa",
+ "mountain bike", "tent", "computer mouse", "mousetrap", "moving van", "muzzle", "metal nail",
+ "neck brace", "necklace", "baby pacifier", "notebook computer", "obelisk", "oboe", "ocarina",
+ "odometer", "oil filter", "pipe organ", "oscilloscope", "overskirt", "bullock cart",
+ "oxygen mask", "product packet / packaging", "paddle", "paddle wheel", "padlock", "paintbrush",
+ "pajamas", "palace", "pan flute", "paper towel", "parachute", "parallel bars", "park bench",
+ "parking meter", "railroad car", "patio", "payphone", "pedestal", "pencil case",
+ "pencil sharpener", "perfume", "Petri dish", "photocopier", "plectrum", "Pickelhaube",
+ "picket fence", "pickup truck", "pier", "piggy bank", "pill bottle", "pillow", "ping-pong ball",
+ "pinwheel", "pirate ship", "drink pitcher", "block plane", "planetarium", "plastic bag",
+ "plate rack", "farm plow", "plunger", "Polaroid camera", "pole", "police van", "poncho",
+ "pool table", "soda bottle", "plant pot", "potter's wheel", "power drill", "prayer rug",
+ "printer", "prison", "missile", "projector", "hockey puck", "punching bag", "purse", "quill",
+ "quilt", "race car", "racket", "radiator", "radio", "radio telescope", "rain barrel",
+ "recreational vehicle", "fishing casting reel", "reflex camera", "refrigerator",
+ "remote control", "restaurant", "revolver", "rifle", "rocking chair", "rotisserie", "eraser",
+ "rugby ball", "ruler measuring stick", "sneaker", "safe", "safety pin", "salt shaker", "sandal",
+ "sarong", "saxophone", "scabbard", "weighing scale", "school bus", "schooner", "scoreboard",
+ "CRT monitor", "screw", "screwdriver", "seat belt", "sewing machine", "shield", "shoe store",
+ "shoji screen / room divider", "shopping basket", "shopping cart", "shovel", "shower cap",
+ "shower curtain", "ski", "balaclava ski mask", "sleeping bag", "slide rule", "sliding door",
+ "slot machine", "snorkel", "snowmobile", "snowplow", "soap dispenser", "soccer ball", "sock",
+ "solar thermal collector", "sombrero", "soup bowl", "keyboard space bar", "space heater",
+ "space shuttle", "spatula", "motorboat", "spider web", "spindle", "sports car", "spotlight",
+ "stage", "steam locomotive", "through arch bridge", "steel drum", "stethoscope", "scarf",
+ "stone wall", "stopwatch", "stove", "strainer", "tram", "stretcher", "couch", "stupa",
+ "submarine", "suit", "sundial", "sunglasses", "sunglasses", "sunscreen", "suspension bridge",
+ "mop", "sweatshirt", "swim trunks / shorts", "swing", "electrical switch", "syringe",
+ "table lamp", "tank", "tape player", "teapot", "teddy bear", "television", "tennis ball",
+ "thatched roof", "front curtain", "thimble", "threshing machine", "throne", "tile roof",
+ "toaster", "tobacco shop", "toilet seat", "torch", "totem pole", "tow truck", "toy store",
+ "tractor", "semi-trailer truck", "tray", "trench coat", "tricycle", "trimaran", "tripod",
+ "triumphal arch", "trolleybus", "trombone", "hot tub", "turnstile", "typewriter keyboard",
+ "umbrella", "unicycle", "upright piano", "vacuum cleaner", "vase", "vaulted or arched ceiling",
+ "velvet fabric", "vending machine", "vestment", "viaduct", "violin", "volleyball",
+ "waffle iron", "wall clock", "wallet", "wardrobe", "military aircraft", "sink",
+ "washing machine", "water bottle", "water jug", "water tower", "whiskey jug", "whistle",
+ "hair wig", "window screen", "window shade", "Windsor tie", "wine bottle", "airplane wing",
+ "wok", "wooden spoon", "wool", "split-rail fence", "shipwreck", "sailboat", "yurt", "website",
+ "comic book", "crossword", "traffic or street sign", "traffic light", "dust jacket", "menu",
+ "plate", "guacamole", "consomme", "hot pot", "trifle", "ice cream", "popsicle", "baguette",
+ "bagel", "pretzel", "cheeseburger", "hot dog", "mashed potatoes", "cabbage", "broccoli",
+ "cauliflower", "zucchini", "spaghetti squash", "acorn squash", "butternut squash", "cucumber",
+ "artichoke", "bell pepper", "cardoon", "mushroom", "Granny Smith apple", "strawberry", "orange",
+ "lemon", "fig", "pineapple", "banana", "jackfruit", "cherimoya (custard apple)", "pomegranate",
+ "hay", "carbonara", "chocolate syrup", "dough", "meatloaf", "pizza", "pot pie", "burrito",
+ "red wine", "espresso", "tea cup", "eggnog", "mountain", "bubble", "cliff", "coral reef",
+ "geyser", "lakeshore", "promontory", "sandbar", "beach", "valley", "volcano", "baseball player",
+ "bridegroom", "scuba diver", "rapeseed", "daisy", "yellow lady's slipper", "corn", "acorn",
+ "rose hip", "horse chestnut seed", "coral fungus", "agaric", "gyromitra", "stinkhorn mushroom",
+ "earth star fungus", "hen of the woods mushroom", "bolete", "corn cob", "toilet paper"
+]
+
+class ImageNet:
+ def __init__(self,
+ preprocess,
+ location='~/data',
+ batch_size=32,
+ num_workers=32):
+ self.preprocess = preprocess
+        self.location = '/path/ImageNet2012/'  # TODO: set to your local ImageNet2012 root (note: overrides the location argument)
+ self.batch_size = batch_size
+ self.num_workers = num_workers
+ self.classnames = imagenet_classnames
+
+ self.populate_train()
+ self.populate_test()
+
+ def populate_train(self):
+ traindir = os.path.join(self.location, 'train')
+ self.train_dataset = ImageFolderWithPaths(
+ traindir,
+ transform=self.preprocess)
+ sampler = self.get_train_sampler()
+        kwargs = {'shuffle': True} if sampler is None else {}
+ self.train_loader = torch.utils.data.DataLoader(
+ self.train_dataset,
+ sampler=sampler,
+ batch_size=self.batch_size,
+ num_workers=self.num_workers,
+ **kwargs,
+ )
+
+ def populate_test(self):
+ self.test_dataset = self.get_test_dataset()
+ self.test_loader = torch.utils.data.DataLoader(
+ self.test_dataset,
+ batch_size=self.batch_size,
+ num_workers=self.num_workers,
+ sampler=self.get_test_sampler()
+ )
+
+ def get_test_path(self):
+ test_path = os.path.join(self.location, 'val_dir')
+ if not os.path.exists(test_path):
+ test_path = os.path.join(self.location,'val')
+ return test_path
+
+ def get_train_sampler(self):
+ return None
+
+ def get_test_sampler(self):
+ return None
+
+ def get_test_dataset(self):
+ return ImageFolderWithPaths(self.get_test_path(), transform=self.preprocess)
+
+ def name(self):
+ return 'imagenet'
+
+class ImageNetTrain(ImageNet):
+
+ def get_test_dataset(self):
+ pass
+
+class ImageNetK(ImageNet):
+
+ def get_train_sampler(self):
+ idxs = np.zeros(len(self.train_dataset.targets))
+ target_array = np.array(self.train_dataset.targets)
+ for c in range(1000):
+ m = target_array == c
+ n = len(idxs[m])
+ arr = np.zeros(n)
+ arr[:self.k()] = 1
+ np.random.shuffle(arr)
+ idxs[m] = arr
+
+ idxs = idxs.astype('int')
+ sampler = SubsetSampler(np.where(idxs)[0])
+ return sampler
\ No newline at end of file
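`ImageNetK` expects a subclass to provide `k()` (the number of training images kept per class) and builds a `SubsetSampler` from a per-class random mask. That selection logic can be sketched standalone; the helper name `per_class_subset_indices` below is illustrative and not part of the repo:

```python
import numpy as np

def per_class_subset_indices(targets, k, num_classes, seed=0):
    # Mirror ImageNetK.get_train_sampler: keep k random samples per class.
    rng = np.random.default_rng(seed)
    targets = np.asarray(targets)
    keep = np.zeros(len(targets), dtype=int)
    for c in range(num_classes):
        positions = np.where(targets == c)[0]
        chosen = rng.choice(positions, size=min(k, len(positions)), replace=False)
        keep[chosen] = 1
    return np.where(keep)[0]

targets = [0, 0, 0, 1, 1, 2, 2, 2, 2]
subset = per_class_subset_indices(targets, k=2, num_classes=3)
```

With `k=2` and three classes, exactly two indices per class survive, matching the `arr[:self.k()] = 1` trick in `get_train_sampler`.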
diff --git a/src/datasets/kmnist.py b/src/datasets/kmnist.py
new file mode 100644
index 0000000000000000000000000000000000000000..243359d7a6cc7aec1bdd99126995daf3d4c4b414
--- /dev/null
+++ b/src/datasets/kmnist.py
@@ -0,0 +1,39 @@
+import os
+
+import torch
+import torchvision.datasets as datasets
+
+
+class KMNIST:
+ def __init__(
+ self,
+ preprocess,
+ location=os.path.expanduser("~/data"),
+ batch_size=128,
+ num_workers=6,
+ ):
+
+ location = os.path.join(location, "KMNIST")
+ self.train_dataset = datasets.KMNIST(
+ root=location, download=True, train=True, transform=preprocess
+ )
+
+ self.train_loader = torch.utils.data.DataLoader(
+ self.train_dataset,
+ batch_size=batch_size,
+ shuffle=True,
+ num_workers=num_workers,
+ )
+
+ self.test_dataset = datasets.KMNIST(
+ root=location, download=True, train=False, transform=preprocess
+ )
+
+ self.test_loader = torch.utils.data.DataLoader(
+ self.test_dataset,
+ batch_size=batch_size,
+ shuffle=False,
+ num_workers=num_workers,
+ )
+
+ self.classnames = self.train_dataset.classes
diff --git a/src/datasets/mnist.py b/src/datasets/mnist.py
new file mode 100644
index 0000000000000000000000000000000000000000..dd4819351db8edf370311851026eba40b698ad3e
--- /dev/null
+++ b/src/datasets/mnist.py
@@ -0,0 +1,41 @@
+import os
+import torch
+import torchvision.datasets as datasets
+
+class MNIST:
+ def __init__(self,
+ preprocess,
+ location=os.path.expanduser('~/data'),
+ batch_size=128,
+ num_workers=16):
+
+
+ self.train_dataset = datasets.MNIST(
+ root=location,
+ download=True,
+ train=True,
+ transform=preprocess
+ )
+
+ self.train_loader = torch.utils.data.DataLoader(
+ self.train_dataset,
+ batch_size=batch_size,
+ shuffle=True,
+ num_workers=num_workers
+ )
+
+ self.test_dataset = datasets.MNIST(
+ root=location,
+ download=True,
+ train=False,
+ transform=preprocess
+ )
+
+ self.test_loader = torch.utils.data.DataLoader(
+ self.test_dataset,
+ batch_size=batch_size,
+ shuffle=False,
+ num_workers=num_workers
+ )
+
+ self.classnames = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9']
\ No newline at end of file
diff --git a/src/datasets/oxfordpets.py b/src/datasets/oxfordpets.py
new file mode 100644
index 0000000000000000000000000000000000000000..7285b424c6f970df692bd7eb3fd9656d423e7cb0
--- /dev/null
+++ b/src/datasets/oxfordpets.py
@@ -0,0 +1,38 @@
+import os
+import torch
+import torchvision.datasets as datasets
+
+
+class OxfordIIITPet:
+ def __init__(
+ self,
+ preprocess,
+ location=os.path.expanduser("~/data"),
+ batch_size=128,
+ num_workers=6,
+ ):
+
+ location = os.path.join(location, "OxfordIIITPet")
+ self.train_dataset = datasets.OxfordIIITPet(
+ root=location, download=True, split="trainval", transform=preprocess
+ )
+
+ self.train_loader = torch.utils.data.DataLoader(
+ self.train_dataset,
+ batch_size=batch_size,
+ shuffle=True,
+ num_workers=num_workers,
+ )
+
+ self.test_dataset = datasets.OxfordIIITPet(
+ root=location, download=True, split="test", transform=preprocess
+ )
+
+ self.test_loader = torch.utils.data.DataLoader(
+ self.test_dataset,
+ batch_size=batch_size,
+ shuffle=False,
+ num_workers=num_workers,
+ )
+
+ self.classnames = self.train_dataset.classes
diff --git a/src/datasets/registry.py b/src/datasets/registry.py
new file mode 100644
index 0000000000000000000000000000000000000000..3f732c998918194b21926aa177ffb19502d874a3
--- /dev/null
+++ b/src/datasets/registry.py
@@ -0,0 +1,103 @@
+import sys
+import inspect
+import random
+import torch
+import copy
+
+from torch.utils.data.dataset import random_split
+
+from src.datasets.cars import Cars
+from src.datasets.cifar10 import CIFAR10
+from src.datasets.cifar100 import CIFAR100
+from src.datasets.dtd import DTD
+from src.datasets.eurosat import EuroSAT, EuroSATVal
+from src.datasets.gtsrb import GTSRB
+from src.datasets.imagenet import ImageNet
+from src.datasets.mnist import MNIST
+from src.datasets.resisc45 import RESISC45
+from src.datasets.stl10 import STL10
+from src.datasets.svhn import SVHN
+from src.datasets.sun397 import SUN397
+from src.datasets.emnist import EMNIST
+from src.datasets.kmnist import KMNIST
+from src.datasets.oxfordpets import OxfordIIITPet
+
+registry = {
+ name: obj for name, obj in inspect.getmembers(sys.modules[__name__], inspect.isclass)
+}
+
+
+class GenericDataset(object):
+ def __init__(self):
+ self.train_dataset = None
+ self.train_loader = None
+ self.test_dataset = None
+ self.test_loader = None
+ self.classnames = None
+
+
+def split_train_into_train_val(dataset, new_dataset_class_name, batch_size, num_workers, val_fraction, max_val_samples=None, seed=0):
+ assert val_fraction > 0. and val_fraction < 1.
+ total_size = len(dataset.train_dataset)
+ val_size = int(total_size * val_fraction)
+ if max_val_samples is not None:
+ val_size = min(val_size, max_val_samples)
+ train_size = total_size - val_size
+
+ assert val_size > 0
+ assert train_size > 0
+
+ lengths = [train_size, val_size]
+
+ trainset, valset = random_split(
+ dataset.train_dataset,
+ lengths,
+ generator=torch.Generator().manual_seed(seed)
+ )
+ if new_dataset_class_name == 'MNISTVal':
+ assert trainset.indices[0] == 36044
+
+    new_dataset_class = type(new_dataset_class_name, (GenericDataset,), {})
+    new_dataset = new_dataset_class()
+
+ new_dataset.train_dataset = trainset
+ new_dataset.train_loader = torch.utils.data.DataLoader(
+ new_dataset.train_dataset,
+ shuffle=True,
+ batch_size=batch_size,
+ num_workers=num_workers,
+ )
+
+ new_dataset.test_dataset = valset
+ new_dataset.test_loader = torch.utils.data.DataLoader(
+ new_dataset.test_dataset,
+ batch_size=batch_size,
+ num_workers=num_workers
+ )
+
+ new_dataset.classnames = copy.copy(dataset.classnames)
+
+ return new_dataset
+
+
+def get_dataset(dataset_name, preprocess, location, batch_size=128, num_workers=16, val_fraction=0.1, max_val_samples=5000):
+ if dataset_name.endswith('Val'):
+ # Handle val splits
+ if dataset_name in registry:
+ dataset_class = registry[dataset_name]
+ else:
+ base_dataset_name = dataset_name.split('Val')[0]
+ base_dataset = get_dataset(base_dataset_name, preprocess, location, batch_size, num_workers)
+ dataset = split_train_into_train_val(
+ base_dataset, dataset_name, batch_size, num_workers, val_fraction, max_val_samples)
+ return dataset
+ else:
+ assert dataset_name in registry, f'Unsupported dataset: {dataset_name}. Supported datasets: {list(registry.keys())}'
+ dataset_class = registry[dataset_name]
+ dataset = dataset_class(
+ preprocess, location=location, batch_size=batch_size, num_workers=num_workers
+ )
+ return dataset
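The registry is built from the classes imported above, and `get_dataset` resolves a `'Val'`-suffixed name by splitting the base dataset's train set unless a dedicated class (e.g. `EuroSATVal`) already exists. A minimal sketch of that dispatch rule, assuming a hypothetical `resolve` helper that is not part of the repo:

```python
def resolve(dataset_name, registry):
    # Mirror get_dataset's dispatch: a 'Val'-suffixed name with no dedicated
    # class falls back to splitting the base dataset's train set.
    if dataset_name.endswith('Val') and dataset_name not in registry:
        return ('split_from_base', dataset_name.split('Val')[0])
    assert dataset_name in registry, f'Unsupported dataset: {dataset_name}'
    return ('direct', dataset_name)

# Stand-in registry; the real one is populated via inspect.getmembers.
registry = {'CIFAR10': object, 'EuroSATVal': object}
```

So `'CIFAR10Val'` is carved out of `CIFAR10`'s train set, while `'EuroSATVal'` uses its dedicated class directly.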
diff --git a/src/datasets/resisc45.py b/src/datasets/resisc45.py
new file mode 100644
index 0000000000000000000000000000000000000000..056122dbe5d43751dbdb4ea82048d47b56c6fe38
--- /dev/null
+++ b/src/datasets/resisc45.py
@@ -0,0 +1,304 @@
+import abc
+import os
+from typing import Any, Callable, Dict, Optional, Tuple
+
+import numpy as np
+import torch
+from torch import Tensor
+from torch.utils.data import Dataset
+from torchvision.datasets import ImageFolder
+from torchvision.datasets.folder import default_loader as pil_loader
+
+
+# modified from: https://github.com/microsoft/torchgeo
+class VisionDataset(Dataset[Dict[str, Any]], abc.ABC):
+ """Abstract base class for datasets lacking geospatial information.
+ This base class is designed for datasets with pre-defined image chips.
+ """
+
+ @abc.abstractmethod
+ def __getitem__(self, index: int) -> Dict[str, Any]:
+ """Return an index within the dataset.
+ Args:
+ index: index to return
+ Returns:
+ data and labels at that index
+ Raises:
+ IndexError: if index is out of range of the dataset
+ """
+
+ @abc.abstractmethod
+ def __len__(self) -> int:
+ """Return the length of the dataset.
+ Returns:
+ length of the dataset
+ """
+
+ def __str__(self) -> str:
+ """Return the informal string representation of the object.
+ Returns:
+ informal string representation
+ """
+ return f"""\
+{self.__class__.__name__} Dataset
+ type: VisionDataset
+ size: {len(self)}"""
+
+
+class VisionClassificationDataset(VisionDataset, ImageFolder):
+ """Abstract base class for classification datasets lacking geospatial information.
+ This base class is designed for datasets with pre-defined image chips which
+ are separated into separate folders per class.
+ """
+
+ def __init__(
+ self,
+ root: str,
+ transforms: Optional[Callable[[Dict[str, Tensor]], Dict[str, Tensor]]] = None,
+ loader: Optional[Callable[[str], Any]] = pil_loader,
+ is_valid_file: Optional[Callable[[str], bool]] = None,
+ ) -> None:
+ """Initialize a new VisionClassificationDataset instance.
+ Args:
+ root: root directory where dataset can be found
+ transforms: a function/transform that takes input sample and its target as
+ entry and returns a transformed version
+ loader: a callable function which takes as input a path to an image and
+ returns a PIL Image or numpy array
+ is_valid_file: A function that takes the path of an Image file and checks if
+ the file is a valid file
+ """
+ # When transform & target_transform are None, ImageFolder.__getitem__(index)
+ # returns a PIL.Image and int for image and label, respectively
+ super().__init__(
+ root=root,
+ transform=None,
+ target_transform=None,
+ loader=loader,
+ is_valid_file=is_valid_file,
+ )
+
+ # Must be set after calling super().__init__()
+ self.transforms = transforms
+
+ def __getitem__(self, index: int) -> Dict[str, Tensor]:
+ """Return an index within the dataset.
+ Args:
+ index: index to return
+ Returns:
+ data and label at that index
+ """
+ image, label = self._load_image(index)
+
+ if self.transforms is not None:
+ return self.transforms(image), label
+
+ return image, label
+
+ def __len__(self) -> int:
+ """Return the number of data points in the dataset.
+ Returns:
+ length of the dataset
+ """
+ return len(self.imgs)
+
+ def _load_image(self, index: int) -> Tuple[Tensor, Tensor]:
+        """Load a single image and its class label.
+ Args:
+ index: index to return
+ Returns:
+ the image
+ the image class label
+ """
+ img, label = ImageFolder.__getitem__(self, index)
+ label = torch.tensor(label)
+ return img, label
+
+
+class RESISC45Dataset(VisionClassificationDataset):
+ """RESISC45 dataset.
+    The RESISC45 dataset is a dataset for remote sensing image scene classification.
+ Dataset features:
+ * 31,500 images with 0.2-30 m per pixel resolution (256x256 px)
+ * three spectral bands - RGB
+ * 45 scene classes, 700 images per class
+ * images extracted from Google Earth from over 100 countries
+    * image conditions with high variability (resolution, weather, illumination)
+ Dataset format:
+ * images are three-channel jpgs
+ Dataset classes:
+ 0. airplane
+ 1. airport
+ 2. baseball_diamond
+ 3. basketball_court
+ 4. beach
+ 5. bridge
+ 6. chaparral
+ 7. church
+ 8. circular_farmland
+ 9. cloud
+ 10. commercial_area
+ 11. dense_residential
+ 12. desert
+ 13. forest
+ 14. freeway
+ 15. golf_course
+ 16. ground_track_field
+ 17. harbor
+ 18. industrial_area
+ 19. intersection
+ 20. island
+ 21. lake
+ 22. meadow
+ 23. medium_residential
+ 24. mobile_home_park
+ 25. mountain
+ 26. overpass
+ 27. palace
+ 28. parking_lot
+ 29. railway
+ 30. railway_station
+ 31. rectangular_farmland
+ 32. river
+ 33. roundabout
+ 34. runway
+ 35. sea_ice
+ 36. ship
+ 37. snowberg
+ 38. sparse_residential
+ 39. stadium
+ 40. storage_tank
+ 41. tennis_court
+ 42. terrace
+ 43. thermal_power_station
+ 44. wetland
+ This dataset uses the train/val/test splits defined in the "In-domain representation
+ learning for remote sensing" paper:
+ * https://arxiv.org/abs/1911.06721
+ If you use this dataset in your research, please cite the following paper:
+ * https://doi.org/10.1109/jproc.2017.2675998
+ """
+
+ # url = "https://drive.google.com/file/d/1DnPSU5nVSN7xv95bpZ3XQ0JhKXZOKgIv"
+ # md5 = "d824acb73957502b00efd559fc6cfbbb"
+ # filename = "NWPU-RESISC45.rar"
+ directory = "resisc45/NWPU-RESISC45"
+
+ splits = ["train", "val", "test"]
+ split_urls = {
+ "train": "https://storage.googleapis.com/remote_sensing_representations/resisc45-train.txt", # noqa: E501
+ "val": "https://storage.googleapis.com/remote_sensing_representations/resisc45-val.txt", # noqa: E501
+ "test": "https://storage.googleapis.com/remote_sensing_representations/resisc45-test.txt", # noqa: E501
+ }
+ split_md5s = {
+ "train": "b5a4c05a37de15e4ca886696a85c403e",
+ "val": "a0770cee4c5ca20b8c32bbd61e114805",
+ "test": "3dda9e4988b47eb1de9f07993653eb08",
+ }
+ classes = [
+ "airplane",
+ "airport",
+ "baseball_diamond",
+ "basketball_court",
+ "beach",
+ "bridge",
+ "chaparral",
+ "church",
+ "circular_farmland",
+ "cloud",
+ "commercial_area",
+ "dense_residential",
+ "desert",
+ "forest",
+ "freeway",
+ "golf_course",
+ "ground_track_field",
+ "harbor",
+ "industrial_area",
+ "intersection",
+ "island",
+ "lake",
+ "meadow",
+ "medium_residential",
+ "mobile_home_park",
+ "mountain",
+ "overpass",
+ "palace",
+ "parking_lot",
+ "railway",
+ "railway_station",
+ "rectangular_farmland",
+ "river",
+ "roundabout",
+ "runway",
+ "sea_ice",
+ "ship",
+ "snowberg",
+ "sparse_residential",
+ "stadium",
+ "storage_tank",
+ "tennis_court",
+ "terrace",
+ "thermal_power_station",
+ "wetland",
+ ]
+
+ def __init__(
+ self,
+ root: str = "data",
+ split: str = "train",
+ transforms: Optional[Callable[[Dict[str, Tensor]], Dict[str, Tensor]]] = None,
+ ) -> None:
+ """Initialize a new RESISC45 dataset instance.
+ Args:
+ root: root directory where dataset can be found
+ split: one of "train", "val", or "test"
+ transforms: a function/transform that takes input sample and its target as
+ entry and returns a transformed version
+ """
+ assert split in self.splits
+ self.root = root
+
+ valid_fns = set()
+ with open(os.path.join(self.root, "resisc45", f"resisc45-{split}.txt")) as f:
+ for fn in f:
+ valid_fns.add(fn.strip())
+        def is_in_split(path: str) -> bool:
+            return os.path.basename(path) in valid_fns
+
+ super().__init__(
+ root=os.path.join(root, self.directory),
+ transforms=transforms,
+ is_valid_file=is_in_split,
+ )
+
+
+
+class RESISC45:
+ def __init__(self,
+ preprocess,
+ location=os.path.expanduser('~/data'),
+ batch_size=32,
+ num_workers=16):
+
+ self.train_dataset = RESISC45Dataset(root=location, split='train', transforms=preprocess)
+ self.train_loader = torch.utils.data.DataLoader(
+ self.train_dataset,
+ shuffle=True,
+ batch_size=batch_size,
+ num_workers=num_workers,
+ )
+
+ self.test_dataset = RESISC45Dataset(root=location, split='test', transforms=preprocess)
+ self.test_loader = torch.utils.data.DataLoader(
+ self.test_dataset,
+ batch_size=batch_size,
+ num_workers=num_workers
+ )
+
+ # class names have _ so split on this for better zero-shot head
+ self.classnames = [' '.join(c.split('_')) for c in RESISC45Dataset.classes]
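The underscore-to-space conversion for the zero-shot head is a one-liner worth sanity-checking in isolation (the small `classes` list here is just sample data):

```python
# RESISC45 folder names use underscores; the zero-shot text head wants plain words.
classes = ["baseball_diamond", "thermal_power_station", "airplane"]
classnames = [' '.join(c.split('_')) for c in classes]
```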
diff --git a/src/datasets/stl10.py b/src/datasets/stl10.py
new file mode 100644
index 0000000000000000000000000000000000000000..0c7237f014e6b7983d46e2ab9866fa523b6a0ef4
--- /dev/null
+++ b/src/datasets/stl10.py
@@ -0,0 +1,41 @@
+import os
+import torch
+import torchvision.datasets as datasets
+
+class STL10:
+ def __init__(self,
+ preprocess,
+ location=os.path.expanduser('~/data'),
+ batch_size=128,
+ num_workers=16):
+
+ location = os.path.join(location, 'stl10')
+ self.train_dataset = datasets.STL10(
+ root=location,
+ download=True,
+ split='train',
+ transform=preprocess
+ )
+
+ self.train_loader = torch.utils.data.DataLoader(
+ self.train_dataset,
+ batch_size=batch_size,
+ shuffle=True,
+ num_workers=num_workers
+ )
+
+ self.test_dataset = datasets.STL10(
+ root=location,
+ download=True,
+ split='test',
+ transform=preprocess
+ )
+
+ self.test_loader = torch.utils.data.DataLoader(
+ self.test_dataset,
+ batch_size=batch_size,
+ shuffle=False,
+ num_workers=num_workers
+ )
+
+ self.classnames = self.train_dataset.classes
\ No newline at end of file
diff --git a/src/datasets/sun397.py b/src/datasets/sun397.py
new file mode 100644
index 0000000000000000000000000000000000000000..684c648146ed78c838a8b155f4ffb7278e3433b4
--- /dev/null
+++ b/src/datasets/sun397.py
@@ -0,0 +1,32 @@
+import os
+import torch
+import torchvision.datasets as datasets
+
+class SUN397:
+ def __init__(self,
+ preprocess,
+ location=os.path.expanduser('~/data'),
+ batch_size=32,
+ num_workers=16):
+ # Data loading code
+ traindir = os.path.join(location, 'sun397', 'train')
+ valdir = os.path.join(location, 'sun397', 'val')
+
+
+ self.train_dataset = datasets.ImageFolder(traindir, transform=preprocess)
+ self.train_loader = torch.utils.data.DataLoader(
+ self.train_dataset,
+ shuffle=True,
+ batch_size=batch_size,
+ num_workers=num_workers,
+ )
+
+ self.test_dataset = datasets.ImageFolder(valdir, transform=preprocess)
+ self.test_loader = torch.utils.data.DataLoader(
+ self.test_dataset,
+ batch_size=batch_size,
+ num_workers=num_workers
+ )
+ idx_to_class = dict((v, k)
+ for k, v in self.train_dataset.class_to_idx.items())
+ self.classnames = [idx_to_class[i][2:].replace('_', ' ') for i in range(len(idx_to_class))]
diff --git a/src/datasets/svhn.py b/src/datasets/svhn.py
new file mode 100644
index 0000000000000000000000000000000000000000..0e9b47c7b04b7c419d486afcf0cae467cec93728
--- /dev/null
+++ b/src/datasets/svhn.py
@@ -0,0 +1,45 @@
+import os
+import torch
+from torchvision.datasets import SVHN as PyTorchSVHN
+import numpy as np
+
+
+class SVHN:
+ def __init__(self,
+ preprocess,
+ location=os.path.expanduser('~/data'),
+ batch_size=128,
+ num_workers=16):
+
+ # to fit with repo conventions for location
+ modified_location = os.path.join(location, 'svhn')
+
+ self.train_dataset = PyTorchSVHN(
+ root=modified_location,
+ download=True,
+ split='train',
+ transform=preprocess
+ )
+
+ self.train_loader = torch.utils.data.DataLoader(
+ self.train_dataset,
+ batch_size=batch_size,
+ shuffle=True,
+ num_workers=num_workers
+ )
+
+ self.test_dataset = PyTorchSVHN(
+ root=modified_location,
+ download=True,
+ split='test',
+ transform=preprocess
+ )
+
+ self.test_loader = torch.utils.data.DataLoader(
+ self.test_dataset,
+ batch_size=batch_size,
+ shuffle=False,
+ num_workers=num_workers
+ )
+
+ self.classnames = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9']
diff --git a/src/datasets/templates.py b/src/datasets/templates.py
new file mode 100644
index 0000000000000000000000000000000000000000..56fb4a6d8bf4ac341bd06b6165120a1fa72c4f21
--- /dev/null
+++ b/src/datasets/templates.py
@@ -0,0 +1,239 @@
+cars_template = [
+ lambda c: f'a photo of a {c}.',
+ lambda c: f'a photo of the {c}.',
+ lambda c: f'a photo of my {c}.',
+ lambda c: f'i love my {c}!',
+ lambda c: f'a photo of my dirty {c}.',
+ lambda c: f'a photo of my clean {c}.',
+ lambda c: f'a photo of my new {c}.',
+ lambda c: f'a photo of my old {c}.',
+]
+
+cifar10_template = [
+ lambda c: f'a photo of a {c}.',
+ lambda c: f'a blurry photo of a {c}.',
+ lambda c: f'a black and white photo of a {c}.',
+ lambda c: f'a low contrast photo of a {c}.',
+ lambda c: f'a high contrast photo of a {c}.',
+ lambda c: f'a bad photo of a {c}.',
+ lambda c: f'a good photo of a {c}.',
+ lambda c: f'a photo of a small {c}.',
+ lambda c: f'a photo of a big {c}.',
+ lambda c: f'a photo of the {c}.',
+ lambda c: f'a blurry photo of the {c}.',
+ lambda c: f'a black and white photo of the {c}.',
+ lambda c: f'a low contrast photo of the {c}.',
+ lambda c: f'a high contrast photo of the {c}.',
+ lambda c: f'a bad photo of the {c}.',
+ lambda c: f'a good photo of the {c}.',
+ lambda c: f'a photo of the small {c}.',
+ lambda c: f'a photo of the big {c}.',
+]
+
+cifar100_template = [
+ lambda c: f'a photo of a {c}.',
+ lambda c: f'a blurry photo of a {c}.',
+ lambda c: f'a black and white photo of a {c}.',
+ lambda c: f'a low contrast photo of a {c}.',
+ lambda c: f'a high contrast photo of a {c}.',
+ lambda c: f'a bad photo of a {c}.',
+ lambda c: f'a good photo of a {c}.',
+ lambda c: f'a photo of a small {c}.',
+ lambda c: f'a photo of a big {c}.',
+ lambda c: f'a photo of the {c}.',
+ lambda c: f'a blurry photo of the {c}.',
+ lambda c: f'a black and white photo of the {c}.',
+ lambda c: f'a low contrast photo of the {c}.',
+ lambda c: f'a high contrast photo of the {c}.',
+ lambda c: f'a bad photo of the {c}.',
+ lambda c: f'a good photo of the {c}.',
+ lambda c: f'a photo of the small {c}.',
+ lambda c: f'a photo of the big {c}.',
+]
+
+dtd_template = [
+ lambda c: f'a photo of a {c} texture.',
+ lambda c: f'a photo of a {c} pattern.',
+ lambda c: f'a photo of a {c} thing.',
+ lambda c: f'a photo of a {c} object.',
+ lambda c: f'a photo of the {c} texture.',
+ lambda c: f'a photo of the {c} pattern.',
+ lambda c: f'a photo of the {c} thing.',
+ lambda c: f'a photo of the {c} object.',
+]
+
+eurosat_template = [
+ lambda c: f'a centered satellite photo of {c}.',
+ lambda c: f'a centered satellite photo of a {c}.',
+ lambda c: f'a centered satellite photo of the {c}.',
+]
+
+food101_template = [
+ lambda c: f'a photo of {c}, a type of food.',
+]
+
+gtsrb_template = [
+ lambda c: f'a zoomed in photo of a "{c}" traffic sign.',
+ lambda c: f'a centered photo of a "{c}" traffic sign.',
+ lambda c: f'a close up photo of a "{c}" traffic sign.',
+]
+
+mnist_template = [
+ lambda c: f'a photo of the number: "{c}".',
+]
+
+imagenet_template = [
+ lambda c: f'a bad photo of a {c}.',
+ lambda c: f'a photo of many {c}.',
+ lambda c: f'a sculpture of a {c}.',
+ lambda c: f'a photo of the hard to see {c}.',
+ lambda c: f'a low resolution photo of the {c}.',
+ lambda c: f'a rendering of a {c}.',
+ lambda c: f'graffiti of a {c}.',
+ lambda c: f'a bad photo of the {c}.',
+ lambda c: f'a cropped photo of the {c}.',
+ lambda c: f'a tattoo of a {c}.',
+ lambda c: f'the embroidered {c}.',
+ lambda c: f'a photo of a hard to see {c}.',
+ lambda c: f'a bright photo of a {c}.',
+ lambda c: f'a photo of a clean {c}.',
+ lambda c: f'a photo of a dirty {c}.',
+ lambda c: f'a dark photo of the {c}.',
+ lambda c: f'a drawing of a {c}.',
+ lambda c: f'a photo of my {c}.',
+ lambda c: f'the plastic {c}.',
+ lambda c: f'a photo of the cool {c}.',
+ lambda c: f'a close-up photo of a {c}.',
+ lambda c: f'a black and white photo of the {c}.',
+ lambda c: f'a painting of the {c}.',
+ lambda c: f'a painting of a {c}.',
+ lambda c: f'a pixelated photo of the {c}.',
+ lambda c: f'a sculpture of the {c}.',
+ lambda c: f'a bright photo of the {c}.',
+ lambda c: f'a cropped photo of a {c}.',
+ lambda c: f'a plastic {c}.',
+ lambda c: f'a photo of the dirty {c}.',
+ lambda c: f'a jpeg corrupted photo of a {c}.',
+ lambda c: f'a blurry photo of the {c}.',
+ lambda c: f'a photo of the {c}.',
+ lambda c: f'a good photo of the {c}.',
+ lambda c: f'a rendering of the {c}.',
+ lambda c: f'a {c} in a video game.',
+ lambda c: f'a photo of one {c}.',
+ lambda c: f'a doodle of a {c}.',
+ lambda c: f'a close-up photo of the {c}.',
+ lambda c: f'a photo of a {c}.',
+ lambda c: f'the origami {c}.',
+ lambda c: f'the {c} in a video game.',
+ lambda c: f'a sketch of a {c}.',
+ lambda c: f'a doodle of the {c}.',
+ lambda c: f'a origami {c}.',
+ lambda c: f'a low resolution photo of a {c}.',
+ lambda c: f'the toy {c}.',
+ lambda c: f'a rendition of the {c}.',
+ lambda c: f'a photo of the clean {c}.',
+ lambda c: f'a photo of a large {c}.',
+ lambda c: f'a rendition of a {c}.',
+ lambda c: f'a photo of a nice {c}.',
+ lambda c: f'a photo of a weird {c}.',
+ lambda c: f'a blurry photo of a {c}.',
+ lambda c: f'a cartoon {c}.',
+ lambda c: f'art of a {c}.',
+ lambda c: f'a sketch of the {c}.',
+ lambda c: f'a embroidered {c}.',
+ lambda c: f'a pixelated photo of a {c}.',
+ lambda c: f'itap of the {c}.',
+ lambda c: f'a jpeg corrupted photo of the {c}.',
+ lambda c: f'a good photo of a {c}.',
+ lambda c: f'a plushie {c}.',
+ lambda c: f'a photo of the nice {c}.',
+ lambda c: f'a photo of the small {c}.',
+ lambda c: f'a photo of the weird {c}.',
+ lambda c: f'the cartoon {c}.',
+ lambda c: f'art of the {c}.',
+ lambda c: f'a drawing of the {c}.',
+ lambda c: f'a photo of the large {c}.',
+ lambda c: f'a black and white photo of a {c}.',
+ lambda c: f'the plushie {c}.',
+ lambda c: f'a dark photo of a {c}.',
+ lambda c: f'itap of a {c}.',
+ lambda c: f'graffiti of the {c}.',
+ lambda c: f'a toy {c}.',
+ lambda c: f'itap of my {c}.',
+ lambda c: f'a photo of a cool {c}.',
+ lambda c: f'a photo of a small {c}.',
+ lambda c: f'a tattoo of the {c}.',
+]
+
+resisc45_template = [
+ lambda c: f'satellite imagery of {c}.',
+ lambda c: f'aerial imagery of {c}.',
+ lambda c: f'satellite photo of {c}.',
+ lambda c: f'aerial photo of {c}.',
+ lambda c: f'satellite view of {c}.',
+ lambda c: f'aerial view of {c}.',
+ lambda c: f'satellite imagery of a {c}.',
+ lambda c: f'aerial imagery of a {c}.',
+ lambda c: f'satellite photo of a {c}.',
+ lambda c: f'aerial photo of a {c}.',
+ lambda c: f'satellite view of a {c}.',
+ lambda c: f'aerial view of a {c}.',
+ lambda c: f'satellite imagery of the {c}.',
+ lambda c: f'aerial imagery of the {c}.',
+ lambda c: f'satellite photo of the {c}.',
+ lambda c: f'aerial photo of the {c}.',
+ lambda c: f'satellite view of the {c}.',
+ lambda c: f'aerial view of the {c}.',
+]
+
+stl10_template = [
+ lambda c: f'a photo of a {c}.',
+ lambda c: f'a photo of the {c}.',
+]
+
+sun397_template = [
+ lambda c: f'a photo of a {c}.',
+ lambda c: f'a photo of the {c}.',
+]
+
+svhn_template = [
+ lambda c: f'a photo of the number: "{c}".',
+]
+
+oxfordpets_template = [
+ lambda c: f"a photo of a {c}, a type of pet.",
+]
+
+emnist_template = [
+ lambda c: f'a photo of the number: "{c}".',
+]
+
+kmnist_template = [
+ lambda c: f"a photo of the character {c}.",
+]
+
+dataset_to_template = {
+ 'Cars': cars_template,
+ 'CIFAR10': cifar10_template,
+ 'CIFAR100': cifar100_template,
+ 'DTD': dtd_template,
+ 'EuroSAT': eurosat_template,
+ 'Food101': food101_template,
+ 'GTSRB': gtsrb_template,
+ 'MNIST': mnist_template,
+ 'ImageNet': imagenet_template,
+ 'RESISC45': resisc45_template,
+ 'STL10': stl10_template,
+ 'SUN397': sun397_template,
+ 'SVHN': svhn_template,
+ "OxfordIIITPet": oxfordpets_template,
+ "EMNIST": emnist_template,
+ "KMNIST": kmnist_template,
+}
+
+
+def get_templates(dataset_name):
+ if dataset_name.endswith('Val'):
+ return get_templates(dataset_name.replace('Val', ''))
+ assert dataset_name in dataset_to_template, f'Unsupported dataset: {dataset_name}'
+ return dataset_to_template[dataset_name]
\ No newline at end of file
diff --git a/src/distributed.py b/src/distributed.py
new file mode 100644
index 0000000000000000000000000000000000000000..2f401bcf4694e1b05a90e8041b7d11af04316713
--- /dev/null
+++ b/src/distributed.py
@@ -0,0 +1,39 @@
+import os
+
+import torch
+
+
+def setup_ddp(rank, world_size, port=12357):
+ os.environ["MASTER_ADDR"] = "localhost"
+ os.environ["MASTER_PORT"] = str(port)
+
+ # initialize the process group
+ torch.distributed.init_process_group(
+ "nccl",
+ rank=rank,
+ world_size=world_size,
+ )
+ torch.cuda.set_device(rank)
+ torch.distributed.barrier()
+
+
+def cleanup_ddp():
+ torch.distributed.destroy_process_group()
+
+
+def is_main_process():
+ return torch.distributed.get_rank() == 0
+
+
+def distribute_loader(loader):
+ return torch.utils.data.DataLoader(
+ loader.dataset,
+ batch_size=loader.batch_size // torch.distributed.get_world_size(),
+ sampler=torch.utils.data.distributed.DistributedSampler(
+ loader.dataset,
+ num_replicas=torch.distributed.get_world_size(),
+ rank=torch.distributed.get_rank(),
+ ),
+ num_workers=loader.num_workers,
+ pin_memory=loader.pin_memory,
+ )
diff --git a/src/eval.py b/src/eval.py
new file mode 100644
index 0000000000000000000000000000000000000000..7ddbd85b6190dec9ae01fe43607fed1b02da76be
--- /dev/null
+++ b/src/eval.py
@@ -0,0 +1,126 @@
+import numpy as np
+import torch
+import tqdm
+
+from src import utils
+from src.datasets.common import get_dataloader, maybe_dictionarize
+from src.datasets.registry import get_dataset
+from src.heads import get_classification_head
+from src.linearize import LinearizedImageEncoder
+from src.modeling import ImageClassifier
+
+
+def eval_single_dataset(image_encoder, dataset_name, args):
+ classification_head = get_classification_head(args, dataset_name)
+ model = ImageClassifier(image_encoder, classification_head)
+
+ model.eval()
+
+ dataset = get_dataset(
+ dataset_name,
+ model.val_preprocess,
+ location=args.data_location,
+ batch_size=args.batch_size,
+ )
+ dataloader = get_dataloader(dataset, is_train=False, args=args, image_encoder=None)
+ device = args.device
+
+ with torch.no_grad():
+ top1, correct, n = 0.0, 0.0, 0.0
+ for _, data in enumerate(tqdm.tqdm(dataloader)):
+ data = maybe_dictionarize(data)
+ x = data["images"].to(device)
+ y = data["labels"].to(device)
+
+ logits = utils.get_logits(x, model)
+
+ pred = logits.argmax(dim=1, keepdim=True).to(device)
+
+ correct += pred.eq(y.view_as(pred)).sum().item()
+
+ n += y.size(0)
+
+ top1 = correct / n
+
+ metrics = {"top1": top1}
+ print(f"Done evaluating on {dataset_name}. Accuracy: {100*top1:.2f}%")
+
+ return metrics
+
+
+def evaluate(image_encoder, args):
+ if args.eval_datasets is None:
+ return
+ per_dataset_results = {}
+ eval_datasets = (
+ args.eval_datasets
+ if args.control_dataset is None
+ else args.eval_datasets + [args.control_dataset]
+ )
+ for dataset_name in eval_datasets:
+ print("Evaluating on", dataset_name)
+
+ results = eval_single_dataset(image_encoder, dataset_name, args)
+
+ print(f"{dataset_name} Top-1 accuracy: {results['top1']:.4f}")
+ per_dataset_results[dataset_name + ":top1"] = results["top1"]
+
+ return per_dataset_results
+
+
+def evaluate_task_vector_at_coef(
+ task_vector, pretrained_checkpoint, args, scaling_coef, posthoc_linearization=False
+):
+ image_encoder = task_vector.apply_to(
+ pretrained_checkpoint, scaling_coef=scaling_coef
+ )
+ if posthoc_linearization:
+ pretrained_encoder = task_vector.apply_to(
+ pretrained_checkpoint, scaling_coef=0.0
+ )
+ image_encoder = LinearizedImageEncoder(
+ init_encoder=pretrained_encoder, image_encoder=image_encoder, args=args
+ )
+ coef_info = evaluate(image_encoder, args)
+
+ coef_info = add_normalized_accuracy(coef_info, args)
+ coef_info["avg_normalized_top1"] = np.mean(
+ [coef_info[dataset + ":normalized_top1"] for dataset in args.eval_datasets]
+ )
+ coef_info["avg_top1"] = np.mean(
+ [coef_info[dataset + ":top1"] for dataset in args.eval_datasets]
+ )
+
+ return coef_info
+
+
+def evaluate_task_vector(
+ task_vector, pretrained_checkpoint, args, posthoc_linearization=False
+):
+ info = {}
+ for scaling_coef in np.linspace(0.0, 1.0, args.n_eval_points):
+ print(f"Evaluating for scaling coefficient {scaling_coef:.2f}")
+ info[scaling_coef] = evaluate_task_vector_at_coef(
+ task_vector,
+ pretrained_checkpoint,
+ args,
+ scaling_coef,
+ posthoc_linearization,
+ )
+
+ return info
+
+
+def add_normalized_accuracy(results, args):
+ for dataset_name in args.eval_datasets:
+ results[dataset_name + ":normalized_top1"] = (
+ results[dataset_name + ":top1"] / args.finetuning_accuracies[dataset_name]
+ )
+
+ return results
+
+
+def nonlinear_advantage(acc_linear, acc_nonlinear, num_classes):
+ err_linear = 1 - acc_linear
+ err_nonlinear = 1 - acc_nonlinear
+ return (err_linear - err_nonlinear) * num_classes / (num_classes - 1)
diff --git a/src/eval_single_task.py b/src/eval_single_task.py
new file mode 100644
index 0000000000000000000000000000000000000000..1b2b7ab154058302a50c36c640c6a78bcc8f7726
--- /dev/null
+++ b/src/eval_single_task.py
@@ -0,0 +1,115 @@
+import json
+import os
+
+from src.args import parse_arguments
+from src.eval import eval_single_dataset
+from src.linearize import LinearizedImageEncoder
+from src.task_vectors import LinearizedTaskVector, NonLinearTaskVector
+from src.attention_only_finetune import AttentionOnlyFinetuneEncoder
+
+args = parse_arguments()
+if 'ortho' in args.finetuning_mode:
+ args.save = f"checkpoints_{args.seed}/{args.finetuning_mode}_{args.lr}_lambda{args.ortho_lambda}_{args.model}"
+else:
+ if args.seed is not None:
+ args.save = f"checkpoints_{args.seed}/{args.finetuning_mode}_{args.lr}_{args.model}"
+ else:
+ args.save = f"checkpoints/{args.finetuning_mode}_{args.lr}_{args.model}"
+
+accuracies = {}
+
+print("*" * 100)
+mode_labels = {
+ "standard": "Evaluating non-linear FT models.",
+ "standard_ortho": "Evaluating standard FT models with orthogonality regularization.",
+ "linear": "Evaluating linear FT models.",
+ "linear_ortho": "Evaluating linear FT models with orthogonality regularization.",
+ "linear-2": "Evaluating Attention-Only Finetune models.",
+ "linear-2_ortho": "Evaluating Attention-Only Finetune models with orthogonality regularization.",
+}
+print(mode_labels.get(args.finetuning_mode, f"Evaluating {args.finetuning_mode} models."))
+
+for dataset in [
+ "Cars", "DTD", "EuroSAT", "GTSRB", "MNIST", "RESISC45", "SUN397", "SVHN",
+]:
+ print("*" * 100)
+ print(f"Evaluating on {dataset}")
+
+ mode = args.finetuning_mode
+
+    if mode not in (
+        "standard", "standard_ortho", "linear", "linear_ortho", "linear-2", "linear-2_ortho",
+    ):
+        print(f"Unknown finetuning mode: {mode}")
+        continue
+
+    # Checkpoint filenames share one prefix scheme; "standard" has no prefix.
+    prefix = "" if mode == "standard" else mode + "_"
+    # Linearized fine-tuning stores tangent-space checkpoints, so it needs its own vector class.
+    vector_cls = LinearizedTaskVector if mode in ("linear", "linear_ortho") else NonLinearTaskVector
+
+    pretrained_checkpoint = f"{args.save}/{dataset}Val/{prefix}zeroshot.pt"
+    finetuned_checkpoint = f"{args.save}/{dataset}Val/{prefix}finetuned.pt"
+    try:
+        task_vector = vector_cls(pretrained_checkpoint, finetuned_checkpoint)
+        image_encoder = task_vector.apply_to(pretrained_checkpoint, scaling_coef=1.0)
+    except FileNotFoundError:
+        print(f"Error: Could not find checkpoints for {dataset} (mode {mode}).")
+        continue
+
+ for split in ["test", "val"]:
+ print("=" * 100)
+ print(f"Evaluating on {split} split.")
+ eval_dataset = dataset if split == "test" else f"{dataset}Val"
+ accuracies[eval_dataset] = eval_single_dataset(image_encoder, eval_dataset, args)["top1"]
+
+# Save results
+save_name_map = {
+ "standard": "ft_accuracies.json",
+ "standard_ortho": "standard_ortho_ft_accuracies.json",
+ "linear": "linear_ft_accuracies.json",
+ "linear_ortho": "linear_ortho_ft_accuracies.json",
+ "linear-2": "linear-2_ft_accuracies.json",
+ "linear-2_ortho": "linear-2_ortho_ft_accuracies.json",
+}
+
+save_path = os.path.join(args.save, save_name_map[args.finetuning_mode])
+os.makedirs(os.path.dirname(save_path), exist_ok=True)
+with open(save_path, "w") as f:
+ json.dump(accuracies, f, indent=4)
+print(f"Results saved to {save_path}")
diff --git a/src/eval_task_addition.py b/src/eval_task_addition.py
new file mode 100644
index 0000000000000000000000000000000000000000..987e89d65ad62d1c2356ba78ed45c2cd78275577
--- /dev/null
+++ b/src/eval_task_addition.py
@@ -0,0 +1,154 @@
+import json
+import os
+
+from utils import find_optimal_coef
+
+from src.args import parse_arguments
+from src.eval import evaluate_task_vector, evaluate_task_vector_at_coef
+from src.task_vectors import LinearizedTaskVector, NonLinearTaskVector
+from src.attention_only_finetune import AttentionOnlyFinetuneEncoder
+
+args = parse_arguments()
+
+if 'ortho' in args.finetuning_mode:
+ args.save = f"checkpoints_{args.seed}/{args.finetuning_mode}_{args.lr}_lambda{args.ortho_lambda}_{args.model}"
+else:
+ if args.seed is not None:
+ args.save = f"checkpoints_{args.seed}/{args.finetuning_mode}_{args.lr}_{args.model}"
+ else:
+ args.save = f"checkpoints/{args.finetuning_mode}_{args.lr}_{args.model}"
+
+print("*" * 100)
+mode_labels = {
+ "standard": "Evaluating non-linear FT models.",
+ "standard_ortho": "Evaluating standard FT models with orthogonality regularization.",
+ "linear": "Evaluating linear FT models.",
+ "linear_ortho": "Evaluating linear FT models with orthogonality regularization.",
+ "linear-2": "Evaluating Attention-Only Finetune models.",
+ "linear-2_ortho": "Evaluating Attention-Only Finetune models with orthogonality regularization.",
+}
+ft_accuracies_name_map = {
+ "standard": "ft_accuracies.json",
+ "standard_ortho": "standard_ortho_ft_accuracies.json",
+ "linear": "linear_ft_accuracies.json",
+ "linear_ortho": "linear_ortho_ft_accuracies.json",
+ "linear-2": "linear-2_ft_accuracies.json",
+ "linear-2_ortho": "linear-2_ortho_ft_accuracies.json",
+}
+print(mode_labels.get(args.finetuning_mode, f"Evaluating {args.finetuning_mode} models."))
+print("*" * 100)
+
+ft_accuracies_path = os.path.join(args.save, ft_accuracies_name_map[args.finetuning_mode])
+with open(ft_accuracies_path) as f:
+ args.finetuning_accuracies = json.load(f)
+
+if args.seed is not None:
+ base_model_save_path = f"checkpoints_{args.seed}/{args.model}"
+else:
+ base_model_save_path = f"checkpoints/{args.model}"
+
+with open(os.path.join(base_model_save_path, "zeroshot_accuracies.json")) as f:
+ pretrained_accuracies = json.load(f)
+
+eval_datasets = [
+ "Cars", "DTD", "EuroSAT", "GTSRB", "MNIST", "RESISC45", "SVHN", "SUN397",
+]
+
+task_vectors = []
+mode = args.finetuning_mode
+
+for dataset in eval_datasets:
+    # Checkpoint filenames share one prefix scheme; "standard" has no prefix.
+    prefix = "" if mode == "standard" else mode + "_"
+    vector_cls = LinearizedTaskVector if mode in ("linear", "linear_ortho") else NonLinearTaskVector
+
+    pretrained_checkpoint = f"{args.save}/{dataset}Val/{prefix}zeroshot.pt"
+    finetuned_checkpoint = f"{args.save}/{dataset}Val/{prefix}finetuned.pt"
+    if not (os.path.exists(pretrained_checkpoint) and os.path.exists(finetuned_checkpoint)):
+        print(f"Warning: Missing checkpoints for {dataset}. Skipping.")
+        continue
+    task_vectors.append(vector_cls(pretrained_checkpoint, finetuned_checkpoint))
+
+if not task_vectors:
+ print("No task vectors were created. Exiting.")
+ exit()
+
+task_vector = sum(task_vectors)
+
+# Determine the base pretrained checkpoint
+mode_prefix_map = {
+ "standard": "",
+ "standard_ortho": "standard_ortho_",
+ "linear": "linear_",
+ "linear_ortho": "linear_ortho_",
+ "linear-2": "linear-2_",
+ "linear-2_ortho": "linear-2_ortho_",
+}
+mode_prefix = mode_prefix_map[mode]
+pretrained_checkpoint = f"{args.save}/{eval_datasets[0]}Val/{mode_prefix}zeroshot.pt"
+
+if not os.path.exists(pretrained_checkpoint):
+ print(f"Error: Base pretrained checkpoint not found at {pretrained_checkpoint}")
+ exit()
+
+args.eval_datasets = [dataset + "Val" for dataset in eval_datasets]
+args.control_dataset = None
+
+val_metrics = evaluate_task_vector(
+ task_vector,
+ pretrained_checkpoint,
+ args,
+ posthoc_linearization=False,
+)
+
+optimal_coef = find_optimal_coef(
+ val_metrics,
+ metric="avg_normalized_top1",
+ minimize=False,
+)
+
+args.eval_datasets = [dataset for dataset in eval_datasets]
+test_metrics = evaluate_task_vector_at_coef(
+ task_vector,
+ pretrained_checkpoint,
+ args,
+ float(optimal_coef),
+ posthoc_linearization=False,
+)
+
+print("=" * 100)
+print(f"Optimal Coefficient: {optimal_coef}")
+print(f"Test normalized accuracy: {test_metrics['avg_normalized_top1']}")
+print(f"Test absolute accuracy: {test_metrics['avg_top1']}")
+additive_accuracies = {"test": test_metrics, "val": val_metrics, "optimal_coef": optimal_coef}
+
+save_name_map = {
+ "standard": "additions.json",
+ "standard_ortho": "standard_ortho_additions.json",
+ "linear": "linear_additions.json",
+ "linear_ortho": "linear_ortho_additions.json",
+ "linear-2": "linear-2_additions.json",
+ "linear-2_ortho": "linear-2_ortho_additions.json",
+}
+save_file = os.path.join(args.save, save_name_map[mode])
+with open(save_file, "w") as f:
+ json.dump(additive_accuracies, f, indent=4)
+print(f"Addition results saved to {save_file}")
diff --git a/src/eval_task_negation.py b/src/eval_task_negation.py
new file mode 100644
index 0000000000000000000000000000000000000000..86d115405e992004bc9200e28fcaadbdbc3f5efe
--- /dev/null
+++ b/src/eval_task_negation.py
@@ -0,0 +1,163 @@
+import json
+import os
+
+from utils import find_optimal_coef
+
+from src.args import parse_arguments
+from src.eval import evaluate_task_vector, evaluate_task_vector_at_coef
+from src.task_vectors import LinearizedTaskVector, NonLinearTaskVector
+from src.attention_only_finetune import AttentionOnlyFinetuneEncoder
+
+args = parse_arguments()
+
+if 'ortho' in args.finetuning_mode:
+ args.save = f"checkpoints_{args.seed}/{args.finetuning_mode}_{args.lr}_lambda{args.ortho_lambda}_{args.model}"
+else:
+ if args.seed is not None:
+ args.save = f"checkpoints_{args.seed}/{args.finetuning_mode}_{args.lr}_{args.model}"
+ else:
+ args.save = f"checkpoints/{args.finetuning_mode}_{args.lr}_{args.model}"
+
+if args.seed is not None:
+ base_model_save_path = f"checkpoints_{args.seed}/{args.model}"
+else:
+ base_model_save_path = f"checkpoints/{args.model}"
+
+with open(os.path.join(base_model_save_path, "zeroshot_accuracies.json")) as f:
+ pretrained_accuracies = json.load(f)
+
+eval_datasets = [
+ "Cars", "DTD", "EuroSAT", "GTSRB", "MNIST", "RESISC45", "SUN397", "SVHN",
+]
+
+print("*" * 100)
+mode_labels = {
+ "standard": "Evaluating non-linear FT models.",
+ "standard_ortho": "Evaluating standard FT models with orthogonality regularization.",
+ "linear": "Evaluating linear FT models.",
+ "linear_ortho": "Evaluating linear FT models with orthogonality regularization.",
+ "linear-2": "Evaluating Attention-Only Finetune models.",
+ "linear-2_ortho": "Evaluating Attention-Only Finetune models with orthogonality regularization.",
+}
+ft_accuracies_name_map = {
+ "standard": "ft_accuracies.json",
+ "standard_ortho": "standard_ortho_ft_accuracies.json",
+ "linear": "linear_ft_accuracies.json",
+ "linear_ortho": "linear_ortho_ft_accuracies.json",
+ "linear-2": "linear-2_ft_accuracies.json",
+ "linear-2_ortho": "linear-2_ortho_ft_accuracies.json",
+}
+print(mode_labels.get(args.finetuning_mode, f"Evaluating {args.finetuning_mode} models."))
+print("*" * 100)
+
+ft_accuracies_path = os.path.join(args.save, ft_accuracies_name_map[args.finetuning_mode])
+with open(ft_accuracies_path) as f:
+ args.finetuning_accuracies = json.load(f)
+
+control_dataset = "ImageNet"
+negation_accuracies = {}
+mode = args.finetuning_mode
+
+for dataset in eval_datasets:
+    # Checkpoint filenames share one prefix scheme; "standard" has no prefix.
+    prefix = "" if mode == "standard" else mode + "_"
+    vector_cls = LinearizedTaskVector if mode in ("linear", "linear_ortho") else NonLinearTaskVector
+
+    pretrained_checkpoint = f"{args.save}/{dataset}Val/{prefix}zeroshot.pt"
+    finetuned_checkpoint = f"{args.save}/{dataset}Val/{prefix}finetuned.pt"
+    if not (os.path.exists(pretrained_checkpoint) and os.path.exists(finetuned_checkpoint)):
+        print(f"Warning: Missing checkpoints for {dataset}. Skipping.")
+        continue
+
+    # Negate the task vector to unlearn the target task.
+    task_vector = -vector_cls(pretrained_checkpoint, finetuned_checkpoint)
+
+ args.eval_datasets = [dataset + "Val"]
+ args.control_dataset = control_dataset + "Val"
+ val_metrics = evaluate_task_vector(
+ task_vector,
+ pretrained_checkpoint,
+ args,
+ posthoc_linearization=False,
+ )
+
+ optimal_coef = find_optimal_coef(
+ val_metrics,
+ metric=f"{dataset}Val:top1",
+ minimize=True,
+ control_metric=f"{control_dataset}Val:top1",
+ control_metric_threshold=args.control_threshold * pretrained_accuracies[control_dataset + "Val"],
+ )
+
+ args.eval_datasets = [dataset]
+ args.control_dataset = control_dataset
+ test_metrics = evaluate_task_vector_at_coef(
+ task_vector,
+ pretrained_checkpoint,
+ args,
+ optimal_coef,
+ posthoc_linearization=False,
+ )
+
+ print("=" * 100)
+ print(f"Results for dataset: {dataset}")
+ print(f"Optimal Coefficient: {optimal_coef}")
+ print(f"Test accuracy: {test_metrics.get(f'{dataset}:top1', 'N/A')}")
+ print(f"Control accuracy on {control_dataset}: {test_metrics.get(f'{control_dataset}:top1', 'N/A')}")
+
+ negation_accuracies[dataset] = {
+ "test": test_metrics.get(f"{dataset}:top1"),
+ "test_control": test_metrics.get(f"{control_dataset}:top1"),
+ "val": val_metrics,
+ "optimal_coef": optimal_coef,
+ }
+
+save_name_map = {
+ "standard": "negations.json",
+ "standard_ortho": "standard_ortho_negations.json",
+ "linear": "linear_negations.json",
+ "linear_ortho": "linear_ortho_negations.json",
+ "linear-2": "linear-2_negations.json",
+ "linear-2_ortho": "linear-2_ortho_negations.json",
+}
+save_file = os.path.join(args.save, save_name_map[mode])
+with open(save_file, "w") as f:
+ json.dump(negation_accuracies, f, indent=4)
+print(f"Negation results saved to {save_file}")
diff --git a/src/finetune.py b/src/finetune.py
new file mode 100644
index 0000000000000000000000000000000000000000..5d0bb949408d3bf6c00171af91be635bab3ab06b
--- /dev/null
+++ b/src/finetune.py
@@ -0,0 +1,243 @@
+import os
+import time
+
+import torch
+
+from src.args import parse_arguments
+from src.datasets.common import get_dataloader, maybe_dictionarize
+from src.datasets.registry import get_dataset
+from src.distributed import cleanup_ddp, distribute_loader, is_main_process, setup_ddp
+from src.eval import eval_single_dataset
+from src.heads import get_classification_head
+from src.linearize import LinearizedImageEncoder
+from src.modeling import ImageClassifier, ImageEncoder
+from src.attention_only_finetune import AttentionOnlyFinetuneEncoder
+from src.utils import LabelSmoothing, cosine_lr, accuracy
+
+
+def finetune(rank, args):
+ setup_ddp(rank, args.world_size, port=args.port)
+
+ train_dataset = args.train_dataset
+ ckpdir = os.path.join(args.save, train_dataset)
+
+ valid_modes = [
+ "standard", "standard_ortho",
+ "linear", "linear_ortho",
+ "linear-2", "linear-2_ortho",
+ ]
+ assert args.finetuning_mode in valid_modes, f"Mode {args.finetuning_mode} not supported."
+
+ is_linearized = args.finetuning_mode in ("linear", "linear_ortho")
+ is_linear2 = args.finetuning_mode in ("linear-2", "linear-2_ortho")
+ is_standard_ortho = args.finetuning_mode == "standard_ortho"
+ is_linear_ortho = args.finetuning_mode == "linear_ortho"
+ is_linear2_ortho = args.finetuning_mode == "linear-2_ortho"
+ needs_ortho = is_standard_ortho or is_linear_ortho or is_linear2_ortho
+
+ print(f"Using fine-tuning mode: {args.finetuning_mode}")
+ if needs_ortho and args.ortho_lambda > 0:
+ print(f" -> With OrthoReg (lambda={args.ortho_lambda})")
+
+ mode_prefix_map = {
+ "standard": "",
+ "standard_ortho": "standard_ortho",
+ "linear": "linear",
+ "linear_ortho": "linear_ortho",
+ "linear-2": "linear-2",
+ "linear-2_ortho": "linear-2_ortho",
+ }
+ mode_prefix = mode_prefix_map[args.finetuning_mode]
+
+ ft_path = os.path.join(ckpdir, f"{mode_prefix}_finetuned.pt" if mode_prefix else "finetuned.pt")
+ zs_path = os.path.join(ckpdir, f"{mode_prefix}_zeroshot.pt" if mode_prefix else "zeroshot.pt")
+
+    if os.path.exists(zs_path) and os.path.exists(ft_path):
+        print(f"Skipping fine-tuning because {ft_path} exists.")
+        cleanup_ddp()
+        return zs_path, ft_path
+
+ assert train_dataset is not None, "Please provide a training dataset."
+
+ if args.load is not None and args.load.endswith("pt"):
+ if is_linearized:
+ image_encoder = LinearizedImageEncoder.load(args.load)
+ elif is_linear2:
+ image_encoder = AttentionOnlyFinetuneEncoder.load(args.load, args)
+ else:
+ image_encoder = ImageEncoder.load(args.load)
+ else:
+ print("Building image encoder.")
+ if is_linearized:
+ image_encoder = LinearizedImageEncoder(args, keep_lang=False)
+ elif is_linear2:
+ image_encoder = AttentionOnlyFinetuneEncoder(args, keep_lang=False)
+ else:
+ image_encoder = ImageEncoder(args)
+
+ # Save a frozen copy of pretrained weights for ortho loss (standard_ortho / linear-2_ortho)
+ pretrained_state_dict_ref = None
+ if is_standard_ortho or is_linear2_ortho:
+ print("Saving pretrained state dict reference for ortho loss.")
+ pretrained_state_dict_ref = {
+ k: v.clone().detach() for k, v in image_encoder.model.state_dict().items()
+ }
+
+ classification_head = get_classification_head(args, train_dataset)
+ model = ImageClassifier(image_encoder, classification_head)
+ model.freeze_head()
+ model = model.cuda()
+
+ preprocess_fn = model.train_preprocess
+ print_every = 100
+
+ dataset = get_dataset(
+ train_dataset,
+ preprocess_fn,
+ location=args.data_location,
+ batch_size=args.batch_size,
+ )
+ data_loader = get_dataloader(dataset, is_train=True, args=args, image_encoder=None)
+ num_batches = len(dataset.train_loader)
+
+ ddp_loader = distribute_loader(data_loader)
+ ddp_model = torch.nn.parallel.DistributedDataParallel(
+ model,
+ device_ids=[rank],
+ find_unused_parameters=True,
+ output_device=rank,
+ )
+
+ loss_fn = LabelSmoothing(args.ls) if args.ls > 0 else torch.nn.CrossEntropyLoss()
+
+ params = [p for p in ddp_model.parameters() if p.requires_grad]
+ optimizer = torch.optim.AdamW(params, lr=args.lr, weight_decay=args.wd)
+ scheduler = cosine_lr(
+ optimizer,
+ args.lr,
+ args.warmup_length,
+ args.epochs * num_batches // args.num_grad_accumulation,
+ )
+
+ if args.save is not None and is_main_process():
+ os.makedirs(ckpdir, exist_ok=True)
+ ddp_model.module.image_encoder.save(zs_path)
+
+ for epoch in range(args.epochs):
+ ddp_model.train()
+
+ for i, batch in enumerate(ddp_loader):
+ start_time = time.time()
+ step = (
+ i // args.num_grad_accumulation
+ + epoch * num_batches // args.num_grad_accumulation
+ )
+
+ batch = maybe_dictionarize(batch)
+ inputs = batch["images"].cuda()
+ labels = batch["labels"].cuda()
+ data_time = time.time() - start_time
+
+ ortho_loss = 0.0
+ if needs_ortho and args.ortho_lambda > 0:
+ logits, ortho_loss = ddp_model(
+ inputs,
+ calculate_ortho_loss=True,
+ pretrained_state_dict=pretrained_state_dict_ref,
+ )
+ else:
+ logits = ddp_model(inputs)
+
+ classification_loss = loss_fn(logits, labels)
+ loss = classification_loss + args.ortho_lambda * ortho_loss
+
+ (acc1,) = accuracy(logits, labels, topk=(1,))
+ acc1 /= labels.size(0)
+
+ loss.backward()
+
+ if (i + 1) % args.num_grad_accumulation == 0:
+ scheduler(step)
+ torch.nn.utils.clip_grad_norm_(params, 1.0)
+ optimizer.step()
+ optimizer.zero_grad()
+
+ batch_time = time.time() - start_time
+
+ if (
+ args.checkpoint_every > 0
+ and step % args.checkpoint_every == 0
+ and is_main_process()
+ ):
+ ckpt_name = f"{mode_prefix}_checkpoint_{step}.pt" if mode_prefix else f"checkpoint_{step}.pt"
+ ddp_model.module.image_encoder.save(os.path.join(ckpdir, ckpt_name))
+
+ if (
+ step % print_every == 0
+ and ((i + 1) % args.num_grad_accumulation == 0)
+ and is_main_process()
+ ):
+ percent_complete = 100 * i / len(ddp_loader)
+ log_msg = (
+ f"Train Epoch: {epoch} [{percent_complete:.0f}%]\t"
+ f"Total Loss: {loss.item():.6f}\t"
+ f"CE Loss: {classification_loss.item():.6f}\t"
+ )
+ if needs_ortho and args.ortho_lambda > 0:
+ log_msg += f"Ortho Loss: {ortho_loss.item():.6f}\t"
+ log_msg += f"Acc@1: {100*acc1:.2f}%\tData (t) {data_time:.3f}"
+ print(log_msg, flush=True)
+
+    if is_main_process():
+        image_encoder = ddp_model.module.image_encoder
+        eval_single_dataset(image_encoder, train_dataset, args)
+        if args.save is not None:
+            image_encoder.save(ft_path)
+
+    # Destroy the process group on every rank so workers exit cleanly.
+    cleanup_ddp()
+    return zs_path, ft_path
+
+
+if __name__ == "__main__":
+ train_datasets = [
+ "Cars",
+ "DTD",
+ "EuroSAT",
+ "GTSRB",
+ "MNIST",
+ "RESISC45",
+ "SUN397",
+ "SVHN",
+ ]
+ epochs = {
+ "Cars": 35,
+ "DTD": 76,
+ "EuroSAT": 12,
+ "GTSRB": 11,
+ "MNIST": 5,
+ "RESISC45": 15,
+ "SUN397": 14,
+ "SVHN": 4,
+ }
+
+ for dataset in train_datasets:
+ args = parse_arguments()
+
+ args.epochs = epochs[dataset]
+ args.train_dataset = dataset + "Val"
+
+ args.batch_size = 64 if args.model == "ViT-L-14" else 128
+ args.num_grad_accumulation = 2 if args.model == "ViT-L-14" else 1
+
+ if 'ortho' in args.finetuning_mode:
+ args.save = f"checkpoints_{args.seed}/{args.finetuning_mode}_{args.lr}_lambda{args.ortho_lambda}_{args.model}"
+ else:
+ if args.seed is not None:
+ args.save = f"checkpoints_{args.seed}/{args.finetuning_mode}_{args.lr}_{args.model}"
+ else:
+ args.save = f"checkpoints/{args.finetuning_mode}_{args.lr}_{args.model}"
+
+ print("=" * 100)
+ print(f"Finetuning {args.model} on {dataset}")
+ print("=" * 100)
+ torch.multiprocessing.spawn(finetune, args=(args,), nprocs=args.world_size)
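The loss composition in the training loop above reduces to `total = CE + lambda * ortho`. A minimal sketch of one such step; `ortho_lambda` and the scalar `ortho_loss` value are hypothetical stand-ins for `args.ortho_lambda` and the penalty returned by `ddp_model(..., calculate_ortho_loss=True)`:

```python
import torch

torch.manual_seed(0)
ortho_lambda = 0.1                 # stand-in for args.ortho_lambda
logits = torch.randn(8, 10, requires_grad=True)
labels = torch.randint(0, 10, (8,))
ortho_loss = torch.tensor(2.5)     # stand-in for the model-returned penalty

classification_loss = torch.nn.functional.cross_entropy(logits, labels)
loss = classification_loss + ortho_lambda * ortho_loss
loss.backward()

# Only the CE term contributes gradient to the logits; the ortho term
# shifts the total loss by lambda * ortho_loss = 0.25.
print(round((loss - classification_loss).item(), 4))  # 0.25
```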
diff --git a/src/heads.py b/src/heads.py
new file mode 100644
index 0000000000000000000000000000000000000000..c60d022284760b621e9b26147541f7b5988ea1ac
--- /dev/null
+++ b/src/heads.py
@@ -0,0 +1,68 @@
+import os
+
+import open_clip
+import torch
+from tqdm import tqdm
+
+from src.datasets.registry import get_dataset
+from src.datasets.templates import get_templates
+from src.modeling import ClassificationHead, ImageEncoder
+
+
+def build_classification_head(model, dataset_name, template, data_location, device):
+ template = get_templates(dataset_name)
+
+ logit_scale = model.logit_scale
+ dataset = get_dataset(dataset_name, None, location=data_location)
+ model.eval()
+ model.to(device)
+
+ print("Building classification head.")
+ with torch.no_grad():
+ zeroshot_weights = []
+ for classname in tqdm(dataset.classnames):
+ texts = []
+ for t in template:
+ texts.append(t(classname))
+ texts = open_clip.tokenize(texts).to(device) # tokenize
+ embeddings = model.encode_text(texts) # embed with text encoder
+ embeddings /= embeddings.norm(dim=-1, keepdim=True)
+
+ embeddings = embeddings.mean(dim=0, keepdim=True)
+ embeddings /= embeddings.norm()
+
+ zeroshot_weights.append(embeddings)
+
+ zeroshot_weights = torch.stack(zeroshot_weights, dim=0).to(device)
+ zeroshot_weights = torch.transpose(zeroshot_weights, 0, 2)
+
+ zeroshot_weights *= logit_scale.exp()
+
+ zeroshot_weights = zeroshot_weights.squeeze().float()
+ zeroshot_weights = torch.transpose(zeroshot_weights, 0, 1)
+
+ classification_head = ClassificationHead(normalize=True, weights=zeroshot_weights)
+
+ return classification_head
+
+
+def get_classification_head(args, dataset):
+ if not dataset.endswith("Val"):
+ # We want to load the head for the validation set always to be consistent with the one generated at training time.
+ dataset += "Val"
+
+ filename = os.path.join(args.save, f"head_{dataset}.pt")
+ if os.path.exists(filename):
+        print(f"Classification head for {args.model} on {dataset} exists at {filename}")
+ return ClassificationHead.load(filename)
+ print(
+        f"Did not find classification head for {args.model} on {dataset} at {filename}, building one from scratch."  # noqa: E501
+ )
+ model = ImageEncoder(args, keep_lang=True).model
+ template = get_templates(dataset)
+ classification_head = build_classification_head(
+ model, dataset, template, args.data_location, args.device
+ )
+ os.makedirs(args.save, exist_ok=True)
+ classification_head.save(filename)
+ return classification_head
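The template-averaging step inside `build_classification_head` boils down to: L2-normalize each per-template text embedding, average them, and renormalize the mean. A sketch with random stand-in embeddings (3 hypothetical templates, 512-dim):

```python
import torch

torch.manual_seed(0)
# Hypothetical text embeddings for one class: 3 prompt templates, 512-dim each.
embeddings = torch.randn(3, 512)
embeddings = embeddings / embeddings.norm(dim=-1, keepdim=True)  # per-template norm
class_embedding = embeddings.mean(dim=0)
class_embedding = class_embedding / class_embedding.norm()       # renormalize the mean
print(round(class_embedding.norm().item(), 4))  # 1.0
```

The final renormalization matters: the mean of unit vectors is generally shorter than unit length.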
diff --git a/src/linearize.py b/src/linearize.py
new file mode 100644
index 0000000000000000000000000000000000000000..2523a39d3f3f14654105aa16c3499281a1150a49
--- /dev/null
+++ b/src/linearize.py
@@ -0,0 +1,179 @@
+import abc
+import os
+
+import torch
+import torch.nn as nn
+from functorch import jvp, make_functional_with_buffers
+
+from src.modeling import ImageEncoder
+from src.utils import DotDict
+
+
+class LinearizedModel(nn.Module):
+ """Creates a linearized version of a nn.Module.
+
+ The linearized version of a model is a proper PyTorch model and can be
+ trained as any other nn.Module.
+
+ Args:
+ model (nn.Module): The model to linearize. The trainable parameters of
+ the linearized model will be initialized to the parameters of this
+ model.
+ init_model (nn.Module): A model of the same type as `model` containing
+ the parameters around which the model is initialized. If not
+ provided, `model` is used as the initialization model.
+ """
+
+ def __init__(self, model: nn.Module, init_model: nn.Module = None) -> None:
+ """Initializes the linearized model."""
+ super().__init__()
+ if init_model is None:
+ init_model = model
+
+ func0, params0, self.buffers0 = make_functional_with_buffers(
+ init_model.eval(), disable_autograd_tracking=True
+ )
+ self.func0 = lambda params, x: func0(params, self.buffers0, x)
+
+ _, params, _ = make_functional_with_buffers(
+ model, disable_autograd_tracking=True
+ )
+
+ self.params = nn.ParameterList(params)
+ self.params0 = nn.ParameterList(params0)
+ self._model_name = model.__class__.__name__
+
+        # The initial parameters are not trainable.
+ for p in self.params0:
+ p.requires_grad = False
+
+ # The params are.
+ for p in self.params:
+ p.requires_grad = True
+
+ def forward(self, x, calculate_ortho_loss=False):
+ """
+ Computes the linearized model output and optionally the orthogonality loss.
+ The method name is changed from __call__ to forward for clarity with DDP.
+ """
+ dparams = [p - p0 for p, p0 in zip(self.params, self.params0)]
+ out, dp = jvp(
+ lambda param: self.func0(param, x),
+ (tuple(self.params0),),
+ (tuple(dparams),),
+ )
+ output = out + dp
+
+ if not calculate_ortho_loss:
+ return output
+
+ # --- Integrate Orthogonality Loss Calculation ---
+ ortho_loss = 0.0
+ for p_finetuned, p_pretrained in zip(self.params, self.params0):
+ if p_finetuned.requires_grad and p_finetuned.dim() == 2:
+ delta_W = p_finetuned - p_pretrained
+
+ rows, cols = delta_W.shape
+ if rows < cols:
+ mat = delta_W @ delta_W.T
+ identity = torch.eye(rows, device=delta_W.device)
+ else:
+ mat = delta_W.T @ delta_W
+ identity = torch.eye(cols, device=delta_W.device)
+
+ ortho_loss += torch.norm(mat - identity, p='fro')
+
+ return output, ortho_loss
+
+ # for tau_jp
+ def dp(self, x) -> torch.Tensor:
+ """
+ Computes only the change in output (JVP) due to parameter shift.
+ Used for the 'Reg' penalty calculation.
+ """
+ dparams = [p - p0 for p, p0 in zip(self.params, self.params0)]
+ _, dp = jvp(
+ lambda param: self.func0(param, x),
+ (tuple(self.params0),),
+ (tuple(dparams),),
+ )
+ return dp
+
+    def __call__(self, x, calculate_ortho_loss=False):
+        return self.forward(x, calculate_ortho_loss)
+
+
+class LinearizedImageEncoder(abc.ABC, nn.Module):
+ """Creates a linearized version of an image encoder."""
+
+ def __init__(
+ self, args=None, keep_lang=False, image_encoder=None, init_encoder=None
+ ):
+ super().__init__()
+ if image_encoder is None:
+ image_encoder = ImageEncoder(args, keep_lang)
+ if init_encoder is None:
+ init_encoder = image_encoder
+
+ # Copy the attributes from the image encoder.
+ self.train_preprocess = image_encoder.train_preprocess
+ self.val_preprocess = image_encoder.val_preprocess
+ self.cache_dir = image_encoder.cache_dir
+
+ self._model_name = self._get_name(args.model)
+ self.model = LinearizedModel(init_model=init_encoder, model=image_encoder)
+
+ def _get_name(self, model_name):
+ if "__pretrained__" in model_name:
+            model_name, _ = model_name.split("__pretrained__")
+ return model_name
+
+ def forward(self, x, calculate_ortho_loss=False, pretrained_state_dict=None):
+ # Pass the flag down to the wrapped LinearizedModel
+ return self.model(x, calculate_ortho_loss=calculate_ortho_loss)
+
+ def __call__(self, x, calculate_ortho_loss=False, pretrained_state_dict=None):
+ return self.forward(x, calculate_ortho_loss, pretrained_state_dict)
+
+ def save(self, filename):
+ """Saves the linearized image encoder.
+
+ We save the model name in the state dict so that we can load the
+ correct model when loading the linearized image encoder. Directly using
+        torch.save would not work because func0 is not serializable.
+
+ Args:
+ filename (str): The path to save the taylorized image encoder.
+ """
+ if os.path.dirname(filename) != "":
+ os.makedirs(os.path.dirname(filename), exist_ok=True)
+
+ state_dict = self.state_dict()
+ state_dict["model_name"] = self._model_name
+
+ torch.save(state_dict, filename)
+
+ @classmethod
+ def load(cls, filename):
+ """Loads a linearized image encoder.
+
+ It first loads the state dict with the model name and then creates the
+ correct model and loads the state dict.
+
+ Args:
+ filename (str): The path to the taylorized image encoder.
+
+ Returns:
+ LinearizedImageEncoder: The loaded taylorized image encoder.
+ """
+        print(f"Loading image encoder from {filename}")
+ state_dict = torch.load(filename, map_location="cpu")
+
+ # ImageEncoder expects a DotDict
+ args = DotDict({"model": state_dict["model_name"]})
+ taylorized_encoder = cls(args)
+
+ # Remove the model name from the state dict so that we can load the
+ # model.
+ state_dict.pop("model_name")
+ taylorized_encoder.load_state_dict(state_dict)
+ return taylorized_encoder
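`LinearizedModel.forward` evaluates the first-order Taylor expansion f(x; θ0) + ∇θ f(x; θ0)·(θ − θ0) via a JVP. A sketch of the same idea using the `torch.func` API (rather than the older `functorch` import used above) on a toy linear layer, where the expansion is exact:

```python
import torch
from torch.func import functional_call, jvp

torch.manual_seed(0)
model = torch.nn.Linear(4, 2, bias=False)
params0 = {k: v.detach().clone() for k, v in model.named_parameters()}
# Hypothetical "finetuned" parameters near the initialization.
params = {k: v + 0.01 * torch.randn_like(v) for k, v in params0.items()}
x = torch.randn(3, 4)

def f(p):
    return functional_call(model, p, (x,))

dparams = {k: params[k] - params0[k] for k in params0}
out0, dp = jvp(f, (params0,), (dparams,))
linearized = out0 + dp  # f(x; theta0) + J (theta - theta0)

# A bias-free linear layer is linear in its weights, so the first-order
# expansion reproduces the finetuned output exactly.
exact = functional_call(model, params, (x,))
print(torch.allclose(linearized, exact, atol=1e-5))  # True
```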
diff --git a/src/modeling.py b/src/modeling.py
new file mode 100644
index 0000000000000000000000000000000000000000..846c2366e4ee8f03b3d94852be35c9c16dc37d48
--- /dev/null
+++ b/src/modeling.py
@@ -0,0 +1,209 @@
+import open_clip
+import torch
+
+from src import utils
+
+
+class ImageEncoder(torch.nn.Module):
+ def __init__(self, args, keep_lang=False):
+ super().__init__()
+
+ print(f"Loading {args.model} pre-trained weights.")
+ if "__pretrained__" in args.model:
+ name, pretrained = args.model.split("__pretrained__")
+ elif "__init__" in args.model:
+ print("Using random initialization.")
+ name, pretrained = args.model.split("__init__")[0], None
+ else:
+ name = args.model
+ pretrained = "openai"
+ (
+ self.model,
+ self.train_preprocess,
+ self.val_preprocess,
+ ) = open_clip.create_model_and_transforms(
+ name, pretrained=pretrained, cache_dir=args.openclip_cachedir
+ )
+
+ self.cache_dir = args.cache_dir
+
+ if not keep_lang and hasattr(self.model, "transformer"):
+ delattr(self.model, "transformer")
+
+ # def forward(self, images):
+ # assert self.model is not None
+ # return self.model.encode_image(images)
+
+ # def __call__(self, inputs):
+ # return self.forward(inputs)
+
+ def forward(self, images, calculate_ortho_loss=False, pretrained_state_dict=None):
+ """
+ Extended forward method to optionally compute and return the orthogonal loss.
+ """
+ # Original forward pass
+ features = self.model.encode_image(images)
+
+ # Return features directly if orthogonal loss is not needed
+ if not calculate_ortho_loss:
+ return features
+
+ # --- Compute orthogonal loss if requested ---
+ # This logic is moved here from utils.py
+ if pretrained_state_dict is None:
+ raise ValueError("pretrained_state_dict must be provided when calculate_ortho_loss is True")
+
+ ortho_loss = 0.0
+ # self.model is the open_clip model (e.g. ViT); iterate over its parameters
+ for name, p_finetuned in self.model.named_parameters():
+ if p_finetuned.requires_grad and p_finetuned.dim() == 2:
+ if name in pretrained_state_dict:
+ p_pretrained = pretrained_state_dict[name].to(p_finetuned.device)
+
+ delta_W = p_finetuned - p_pretrained
+
+ rows, cols = delta_W.shape
+ if rows < cols:
+ mat = delta_W @ delta_W.T
+ identity = torch.eye(rows, device=delta_W.device)
+ else:
+ mat = delta_W.T @ delta_W
+ identity = torch.eye(cols, device=delta_W.device)
+
+ ortho_loss += torch.norm(mat - identity, p='fro')
+
+ return features, ortho_loss
+
+ def __call__(self, inputs, calculate_ortho_loss=False, pretrained_state_dict=None):
+ # Ensure __call__ forwards all arguments
+ return self.forward(inputs, calculate_ortho_loss, pretrained_state_dict)
+
+ def save(self, filename):
+        print(f"Saving image encoder to {filename}")
+ utils.torch_save(self, filename)
+
+    @classmethod
+    def load(cls, model_name, filename):
+        print(f"Loading image encoder from {filename}")
+        state_dict = torch.load(filename, map_location="cpu")
+        return cls.load_from_state_dict(model_name, state_dict)
+
+    @classmethod
+    def load_from_state_dict(cls, model_name, state_dict):
+        # Rebuild the open_clip backbone from the model name, then restore the
+        # saved weights. Returns the underlying open_clip model.
+        if "__pretrained__" in model_name:
+            name, pretrained = model_name.split("__pretrained__")
+        else:
+            name, pretrained = model_name, "openai"
+        model, _, _ = open_clip.create_model_and_transforms(
+            name, pretrained=pretrained
+        )
+        model.load_state_dict(state_dict)
+        return model
+
+
+class ClassificationHead(torch.nn.Linear):
+ def __init__(self, normalize, weights, biases=None):
+ output_size, input_size = weights.shape
+ super().__init__(input_size, output_size)
+ self.normalize = normalize
+ if weights is not None:
+ self.weight = torch.nn.Parameter(weights.clone())
+ if biases is not None:
+ self.bias = torch.nn.Parameter(biases.clone())
+ else:
+ self.bias = torch.nn.Parameter(torch.zeros_like(self.bias))
+
+ def forward(self, inputs):
+ if self.normalize:
+ inputs = inputs / inputs.norm(dim=-1, keepdim=True)
+ return super().forward(inputs)
+
+ def __call__(self, inputs):
+ return self.forward(inputs)
+
+ def save(self, filename):
+        print(f"Saving classification head to {filename}")
+ utils.torch_save(self, filename)
+
+ @classmethod
+ def load(cls, filename):
+        print(f"Loading classification head from {filename}")
+ return utils.torch_load(filename)
+
+
+class ImageClassifier(torch.nn.Module):
+ def __init__(self, image_encoder, classification_head):
+ super().__init__()
+ self.image_encoder = image_encoder
+ self.classification_head = classification_head
+ if self.image_encoder is not None:
+ self.train_preprocess = self.image_encoder.train_preprocess
+ self.val_preprocess = self.image_encoder.val_preprocess
+
+ def freeze_head(self):
+ self.classification_head.weight.requires_grad_(False)
+ self.classification_head.bias.requires_grad_(False)
+
+ # def forward(self, inputs):
+ # features = self.image_encoder(inputs)
+ # outputs = self.classification_head(features)
+ # return outputs
+
+ # def __call__(self, inputs):
+ # return self.forward(inputs)
+
+ def forward(self, inputs, calculate_ortho_loss=False, pretrained_state_dict=None):
+ # Forward arguments to image_encoder
+ encoder_output = self.image_encoder(inputs, calculate_ortho_loss, pretrained_state_dict)
+
+ if calculate_ortho_loss:
+ features, ortho_loss = encoder_output
+ outputs = self.classification_head(features)
+ return outputs, ortho_loss
+ else:
+ features = encoder_output
+ outputs = self.classification_head(features)
+ return outputs
+
+ def __call__(self, inputs, calculate_ortho_loss=False, pretrained_state_dict=None):
+ return self.forward(inputs, calculate_ortho_loss, pretrained_state_dict)
+
+ def save(self, filename):
+        print(f"Saving image classifier to {filename}")
+ utils.torch_save(self, filename)
+
+ @classmethod
+ def load(cls, filename):
+        print(f"Loading image classifier from {filename}")
+ return utils.torch_load(filename)
+
+
+class MultiHeadImageClassifier(torch.nn.Module):
+ def __init__(self, image_encoder, classification_heads):
+ super().__init__()
+ self.image_encoder = image_encoder
+ self.classification_heads = torch.nn.ModuleList(classification_heads)
+ if self.image_encoder is not None:
+ self.train_preprocess = self.image_encoder.train_preprocess
+ self.val_preprocess = self.image_encoder.val_preprocess
+
+ def freeze_head(self):
+ for idx in range(len(self.classification_heads)):
+ self.classification_heads[idx].weight.requires_grad_(False)
+ self.classification_heads[idx].bias.requires_grad_(False)
+
+ def forward(self, inputs, head_idx):
+ features = self.image_encoder(inputs)
+ outputs = self.classification_heads[head_idx](features)
+ return outputs
+
+ def __call__(self, inputs, head_idx):
+ return self.forward(inputs, head_idx)
+
+ def save(self, filename):
+        print(f"Saving image classifier to {filename}")
+ utils.torch_save(self, filename)
+
+ @classmethod
+ def load(cls, filename):
+        print(f"Loading image classifier from {filename}")
+ return utils.torch_load(filename)
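The OrthoReg penalty computed in `ImageEncoder.forward` is ‖G − I‖_F, where G is the smaller of the two Gram matrices of ΔW (choosing the smaller side keeps the identity target well-sized). A standalone sketch of the same computation, with no model required:

```python
import torch

def ortho_penalty(delta_W):
    """||G - I||_F with G the smaller Gram matrix of delta_W."""
    rows, cols = delta_W.shape
    if rows < cols:
        gram = delta_W @ delta_W.T
        eye = torch.eye(rows)
    else:
        gram = delta_W.T @ delta_W
        eye = torch.eye(cols)
    return torch.norm(gram - eye, p="fro")

torch.manual_seed(0)
# An update with orthonormal rows incurs a near-zero penalty...
q, _ = torch.linalg.qr(torch.randn(16, 4))
print(ortho_penalty(q.T).item() < 1e-4)  # True
# ...while a rank-1 update is heavily penalized.
print(ortho_penalty(torch.ones(4, 16)).item() > 1.0)  # True
```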
diff --git a/src/task_vectors.py b/src/task_vectors.py
new file mode 100644
index 0000000000000000000000000000000000000000..4b2489a230492f9fc28a41930d76b47ae1c85999
--- /dev/null
+++ b/src/task_vectors.py
@@ -0,0 +1,230 @@
+import abc
+import torch
+
+from src.linearize import LinearizedImageEncoder
+from src.modeling import ImageEncoder
+from src.attention_only_finetune import AttentionOnlyFinetuneEncoder
+
+
+class _TaskVector(abc.ABC):
+ def __init__(
+ self, pretrained_checkpoint=None, finetuned_checkpoint=None, vector=None
+ ):
+ """
+ Initializes the task vector from a pretrained and a finetuned checkpoints.
+ This can either be done by passing two state dicts (one corresponding to the
+ pretrained model, and another to the finetuned model), or by directly passing in
+ the task vector state dict.
+ """
+ if vector is not None:
+ self.vector = vector
+ else:
+ assert (
+ pretrained_checkpoint is not None and finetuned_checkpoint is not None
+ )
+ with torch.no_grad():
+ pretrained_obj = self._load_checkpoint(pretrained_checkpoint)
+ finetuned_obj = self._load_checkpoint(finetuned_checkpoint)
+
+ if hasattr(pretrained_obj, 'state_dict'):
+ pretrained_state_dict = pretrained_obj.state_dict()
+ else:
+ pretrained_state_dict = pretrained_obj
+
+ if hasattr(finetuned_obj, 'state_dict'):
+ finetuned_state_dict = finetuned_obj.state_dict()
+ else:
+ finetuned_state_dict = finetuned_obj
+
+ self.vector = {}
+ for key in pretrained_state_dict:
+ if pretrained_state_dict[key].dtype not in [torch.float32, torch.float16, torch.bfloat16]:
+ continue
+ if key in finetuned_state_dict:
+ self.vector[key] = (
+ finetuned_state_dict[key] - pretrained_state_dict[key]
+ )
+
+ @abc.abstractmethod
+ def _load_checkpoint(self, checkpoint):
+ raise NotImplementedError
+
+ @abc.abstractmethod
+ def _cast_to_same_type(self, other):
+ raise NotImplementedError
+
+ def __add__(self, other):
+ other = self._cast_to_same_type(other)
+ with torch.no_grad():
+ new_vector = {}
+ for key in self.vector:
+ if key not in other.vector:
+ print(f"Warning, key {key} is not present in both task vectors.")
+ continue
+ new_vector[key] = self.vector[key] + other.vector[key]
+ return self.__class__(vector=new_vector)
+
+ def __sub__(self, other):
+ return self.__add__(-other)
+
+ def __radd__(self, other):
+ if other is None or isinstance(other, int):
+ return self
+ return self.__add__(other)
+
+ def __neg__(self):
+ with torch.no_grad():
+ new_vector = {}
+ for key in self.vector:
+ new_vector[key] = -self.vector[key]
+ return self.__class__(vector=new_vector)
+
+ def __pow__(self, power):
+ with torch.no_grad():
+ new_vector = {}
+ for key in self.vector:
+ new_vector[key] = self.vector[key] ** power
+ return self.__class__(vector=new_vector)
+
+ def __mul__(self, other):
+ with torch.no_grad():
+ new_vector = {}
+ for key in self.vector:
+ new_vector[key] = other * self.vector[key]
+ return self.__class__(vector=new_vector)
+
+ def dot(self, other):
+ other = self._cast_to_same_type(other)
+ with torch.no_grad():
+ dot_product = 0.0
+ for key in self.vector:
+ if key not in other.vector:
+ print(f"Warning, key {key} is not present in both task vectors.")
+ continue
+ dot_product += torch.sum(self.vector[key] * other.vector[key])
+ return dot_product
+
+ def norm(self):
+ return torch.sqrt(self.dot(self))
+
+ def apply_to(self, pretrained_checkpoint, scaling_coef=1.0):
+ """Apply a task vector to a pretrained model."""
+ with torch.no_grad():
+ pretrained_model = self._load_checkpoint(pretrained_checkpoint)
+
+ if hasattr(pretrained_model, 'state_dict'):
+ new_state_dict = pretrained_model.state_dict()
+ else:
+ new_state_dict = pretrained_model.copy()
+
+ pretrained_state_dict = new_state_dict.copy()
+
+ for key in pretrained_state_dict:
+ if key in self.vector:
+ new_state_dict[key] = (
+ pretrained_state_dict[key] + scaling_coef * self.vector[key]
+ )
+
+ if hasattr(pretrained_model, 'state_dict'):
+ pretrained_model.load_state_dict(new_state_dict)
+ return pretrained_model
+ else:
+ from src.args import parse_arguments
+ args = parse_arguments()
+ if isinstance(self, NonLinearTaskVector):
+ encoder = self._build_model_from_checkpoint(pretrained_checkpoint, args)
+ encoder.load_state_dict(new_state_dict)
+ return encoder
+ else:
+ pretrained_model.load_state_dict(new_state_dict)
+ return pretrained_model
+
+
+class NonLinearTaskVector(_TaskVector):
+ """A task vector for nonlinear models."""
+
+ def _load_checkpoint(self, checkpoint):
+ return torch.load(checkpoint, map_location="cpu")
+
+ def _build_model_from_checkpoint(self, checkpoint_path, args):
+ mode = args.finetuning_mode
+ if mode in ["linear-2", "linear-2_ortho"]:
+ return AttentionOnlyFinetuneEncoder(args)
+ return ImageEncoder(args)
+
+ def apply_to(self, pretrained_checkpoint, scaling_coef=1.0):
+ with torch.no_grad():
+ from src.args import parse_arguments
+ args = parse_arguments()
+ pretrained_model = self._build_model_from_checkpoint(pretrained_checkpoint, args)
+ pretrained_state_dict = torch.load(pretrained_checkpoint, map_location='cpu')
+
+ if hasattr(pretrained_state_dict, 'state_dict'):
+ pretrained_state_dict = pretrained_state_dict.state_dict()
+
+ new_state_dict = pretrained_state_dict.copy()
+
+ for key in pretrained_state_dict:
+ if key in self.vector:
+ new_state_dict[key] += scaling_coef * self.vector[key]
+
+ pretrained_model.load_state_dict(new_state_dict)
+ return pretrained_model
+
+ def _cast_to_same_type(self, other):
+ if isinstance(other, LinearizedTaskVector):
+ return linear_to_nonlinear(other, self.vector.keys())
+ return other
+
+
+class LinearizedTaskVector(_TaskVector):
+ """A task vector for linearized models."""
+
+ def _load_checkpoint(self, checkpoint):
+ return LinearizedImageEncoder.load(checkpoint)
+
+ def apply_to(self, pretrained_checkpoint, scaling_coef=1.0):
+ with torch.no_grad():
+ pretrained_model = self._load_checkpoint(pretrained_checkpoint)
+ new_state_dict = pretrained_model.state_dict()
+ pretrained_state_dict = new_state_dict.copy()
+
+ for key in pretrained_state_dict:
+ if key in self.vector:
+ new_state_dict[key] += scaling_coef * self.vector[key]
+
+ pretrained_model.load_state_dict(new_state_dict)
+ return pretrained_model
+
+ def get_named_parameters(self, param_names):
+ params = {k: v for k, v in self.vector.items() if "model.params0" not in k}
+ return {k: v for k, v in zip(param_names, params.values())}
+
+ def _cast_to_same_type(self, other):
+ if isinstance(other, NonLinearTaskVector):
+ return nonlinear_to_linear(other)
+ return other
+
+
+def nonlinear_to_linear(nonlinear_task_vector):
+ if isinstance(nonlinear_task_vector, LinearizedTaskVector):
+ return nonlinear_task_vector
+ else:
+ linear_params = {
+ f"model.params.{i}": v
+ for i, v in enumerate(nonlinear_task_vector.vector.values())
+ }
+ linear_params.update({
+ f"model.params0.{i}": torch.zeros_like(v)
+ for i, v in enumerate(nonlinear_task_vector.vector.values())
+ })
+ return LinearizedTaskVector(vector=linear_params)
+
+
+def linear_to_nonlinear(linear_task_vector, param_names):
+ if isinstance(linear_task_vector, NonLinearTaskVector):
+ return linear_task_vector
+ else:
+ return NonLinearTaskVector(
+ vector=linear_task_vector.get_named_parameters(param_names)
+ )
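On plain state dicts, the task-vector algebra implemented above reduces to elementwise arithmetic: τ_t = θ_t − θ0, and the merged model is θ0 + α Σ_t τ_t. A toy sketch with hypothetical 2×2 weights whose updates touch disjoint coordinates:

```python
import torch

theta0 = {"w": torch.zeros(2, 2)}
theta_a = {"w": torch.tensor([[1.0, 0.0], [0.0, 0.0]])}  # hypothetical finetune A
theta_b = {"w": torch.tensor([[0.0, 0.0], [0.0, 1.0]])}  # hypothetical finetune B

tau_a = {k: theta_a[k] - theta0[k] for k in theta0}      # tau_t = theta_t - theta_0
tau_b = {k: theta_b[k] - theta0[k] for k in theta0}

alpha = 0.5
merged = {k: theta0[k] + alpha * (tau_a[k] + tau_b[k]) for k in theta0}
# Disjoint (orthogonal) updates compose without interference: diag(0.5, 0.5).
print(merged["w"])
```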
diff --git a/src/utils.py b/src/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..8c6260bc271f972a25d4f58e1a4c413cb40f38c2
--- /dev/null
+++ b/src/utils.py
@@ -0,0 +1,181 @@
+import os
+import pickle
+
+import numpy as np
+import torch
+
+
+def assign_learning_rate(param_group, new_lr):
+ param_group["lr"] = new_lr
+
+
+def _warmup_lr(base_lr, warmup_length, step):
+ return base_lr * (step + 1) / warmup_length
+
+
+def cosine_lr(optimizer, base_lrs, warmup_length, steps):
+ if not isinstance(base_lrs, list):
+ base_lrs = [base_lrs for _ in optimizer.param_groups]
+ assert len(base_lrs) == len(optimizer.param_groups)
+
+ def _lr_adjuster(step):
+ for param_group, base_lr in zip(optimizer.param_groups, base_lrs):
+ if step < warmup_length:
+ lr = _warmup_lr(base_lr, warmup_length, step)
+ else:
+ e = step - warmup_length
+ es = steps - warmup_length
+ lr = 0.5 * (1 + np.cos(np.pi * e / es)) * base_lr
+ assign_learning_rate(param_group, lr)
+
+ return _lr_adjuster
+
+
+def accuracy(output, target, topk=(1,)):
+ pred = output.topk(max(topk), 1, True, True)[1].t()
+ correct = pred.eq(target.view(1, -1).expand_as(pred))
+ return [
+ float(correct[:k].reshape(-1).float().sum(0, keepdim=True).cpu().numpy())
+ for k in topk
+ ]
+
+
+def torch_load_old(save_path, device=None):
+ with open(save_path, "rb") as f:
+ classifier = pickle.load(f)
+ if device is not None:
+ classifier = classifier.to(device)
+ return classifier
+
+
+def torch_save(model, save_path):
+ if os.path.dirname(save_path) != "":
+ os.makedirs(os.path.dirname(save_path), exist_ok=True)
+ torch.save(model, save_path)
+
+
+def torch_load(save_path, device=None):
+ model = torch.load(save_path, map_location="cpu")
+ if device is not None:
+ model = model.to(device)
+ return model
+
+
+def get_logits(inputs, classifier):
+ assert callable(classifier)
+ if hasattr(classifier, "to"):
+ classifier = classifier.to(inputs.device)
+ return classifier(inputs)
+
+
+def get_probs(inputs, classifier):
+ if hasattr(classifier, "predict_proba"):
+ probs = classifier.predict_proba(inputs.detach().cpu().numpy())
+ return torch.from_numpy(probs)
+ logits = get_logits(inputs, classifier)
+ return logits.softmax(dim=1)
+
+
+class LabelSmoothing(torch.nn.Module):
+ def __init__(self, smoothing=0.0):
+ super(LabelSmoothing, self).__init__()
+ self.confidence = 1.0 - smoothing
+ self.smoothing = smoothing
+
+ def forward(self, x, target):
+ logprobs = torch.nn.functional.log_softmax(x, dim=-1)
+
+ nll_loss = -logprobs.gather(dim=-1, index=target.unsqueeze(1))
+ nll_loss = nll_loss.squeeze(1)
+ smooth_loss = -logprobs.mean(dim=-1)
+ loss = self.confidence * nll_loss + self.smoothing * smooth_loss
+ return loss.mean()
+
+
+class DotDict(dict):
+ """dot.notation access to dictionary attributes"""
+
+ __getattr__ = dict.get
+ __setattr__ = dict.__setitem__
+ __delattr__ = dict.__delitem__
+
+
+def find_optimal_coef(
+ results,
+ metric="avg_normalized_top1",
+ minimize=False,
+ control_metric=None,
+ control_metric_threshold=0.0,
+):
+ best_coef = None
+    # Infinities make the search correct even for metrics outside [0, 1].
+    best_metric = float("inf") if minimize else float("-inf")
+ for scaling_coef in results.keys():
+ if control_metric is not None:
+ if results[scaling_coef][control_metric] < control_metric_threshold:
+ print(f"Control metric fell below {control_metric_threshold} threshold")
+ continue
+ if minimize:
+ if results[scaling_coef][metric] < best_metric:
+ best_metric = results[scaling_coef][metric]
+ best_coef = scaling_coef
+ else:
+ if results[scaling_coef][metric] > best_metric:
+ best_metric = results[scaling_coef][metric]
+ best_coef = scaling_coef
+ return best_coef
+
+
+def nonlinear_advantage(nonlinear_acc, linear_acc, num_classes):
+ return (nonlinear_acc - linear_acc) / (1.0 - 1.0 / num_classes)
+
+
+def calculate_linearized_orthogonality_loss(linearized_model):
+ """Compute orthogonality loss ||delta_W^T delta_W - I||_F for a LinearizedModel."""
+ ortho_loss = 0.0
+ for p_finetuned, p_pretrained in zip(linearized_model.params, linearized_model.params0):
+ if p_finetuned.requires_grad and p_finetuned.dim() == 2:
+ delta_W = p_finetuned - p_pretrained
+
+ rows, cols = delta_W.shape
+ if rows < cols:
+ mat = delta_W @ delta_W.T
+ identity = torch.eye(rows, device=delta_W.device)
+ else:
+ mat = delta_W.T @ delta_W
+ identity = torch.eye(cols, device=delta_W.device)
+
+ ortho_loss += torch.norm(mat - identity, p='fro')
+
+ return ortho_loss
+
+
+def calculate_standard_orthogonality_loss(model, pretrained_state_dict):
+ """Compute orthogonality loss ||delta_W^T delta_W - I||_F for standard/linear-2 finetuning.
+
+ Args:
+ model: DDP-wrapped ImageClassifier (ddp_model).
+ pretrained_state_dict: snapshot of the pretrained model's inner ViT state_dict.
+ """
+ ortho_loss = 0.0
+
+ for name, p_finetuned in model.module.image_encoder.model.named_parameters():
+ if p_finetuned.requires_grad and p_finetuned.dim() == 2:
+ if name in pretrained_state_dict:
+ p_pretrained = pretrained_state_dict[name].to(p_finetuned.device)
+
+ delta_W = p_finetuned - p_pretrained
+
+ rows, cols = delta_W.shape
+ if rows < cols:
+ mat = delta_W @ delta_W.T
+ identity = torch.eye(rows, device=delta_W.device)
+ else:
+ mat = delta_W.T @ delta_W
+ identity = torch.eye(cols, device=delta_W.device)
+
+ ortho_loss += torch.norm(mat - identity, p='fro')
+
+ return ortho_loss
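`find_optimal_coef` scans a sweep of scaling coefficients and skips any whose control metric falls below the threshold. A sketch of that selection logic on hypothetical results (threshold 0.9 assumed):

```python
# Hypothetical sweep results keyed by scaling coefficient.
results = {
    0.1: {"avg_normalized_top1": 0.71, "control": 0.95},
    0.3: {"avg_normalized_top1": 0.84, "control": 0.92},
    0.5: {"avg_normalized_top1": 0.88, "control": 0.60},  # best metric, fails control
}

best_coef, best_metric = None, float("-inf")
for coef, r in results.items():
    if r["control"] < 0.9:        # control_metric_threshold analogue
        continue
    if r["avg_normalized_top1"] > best_metric:
        best_metric, best_coef = r["avg_normalized_top1"], coef
print(best_coef)  # 0.3
```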