diff --git a/.gitattributes b/.gitattributes
index a6344aac8c09253b3b630fb776ae94478aa0275b..9c87bd622a152a7102d40e5127a6c332bf8a36f4 100644
--- a/.gitattributes
+++ b/.gitattributes
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+assets/WVO-WD-TFS.png filter=lfs diff=lfs merge=lfs -text
+assets/orthoreg_loss.png filter=lfs diff=lfs merge=lfs -text
diff --git a/README.md b/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..603865857d29639fc3ed3299c2e1d2221e8e2bcb
--- /dev/null
+++ b/README.md
@@ -0,0 +1,207 @@
+# Understanding and Enforcing Weight Disentanglement in Task Arithmetic
+
+[CVPR 2026] Official code of the paper **"Understanding and Enforcing Weight Disentanglement in Task Arithmetic"**.
+
+[[Paper](https://arxiv.org/abs/2604.17078)]   [[Checkpoints](#-checkpoints)]   [[Datasets](#-datasets)]
+
+---
+
+## 🎯 Abstract
+
+Task arithmetic provides an efficient, training-free way to edit pre-trained models, yet lacks a fundamental theoretical explanation for its success. The existing concept of "weight disentanglement" describes the ideal outcome of non-interfering task composition but does not reveal its underlying cause. Crucially, what intrinsic properties of the pre-trained model ($\theta_0$) or the task vectors ($\tau_t$) enable this disentanglement remains underexplored. In this paper, we introduce Task-Feature Specialization (TFS), a model's ability to allocate distinct internal features to different tasks, as the fundamental principle. We first prove that TFS is a sufficient condition for weight disentanglement. More importantly, we find that TFS also gives rise to an observable geometric consequence: weight vector orthogonality.
This positions TFS as the common cause for both the desired functional outcome (disentanglement) and a measurable geometric property (orthogonality). This relationship provides the key insight for our method: since the abstract TFS property is intractable to enforce directly, we can instead promote weight disentanglement by shaping its concrete geometric consequence, orthogonality. Therefore, we propose OrthoReg, a simple and effective regularization method that, during fine-tuning, actively enforces an internal orthogonal structure on the weight updates ($\Delta W$) that constitute $\tau_t$. We also prove theoretically that OrthoReg promotes disentanglement. Extensive experiments demonstrate that OrthoReg consistently and significantly enhances the performance of various task arithmetic methods.
+
+
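As a toy illustration of the quantities the abstract refers to, the sketch below builds task vectors $\tau_t = \theta_t - \theta_0$ from flat parameter lists, adds them back to the pre-trained weights, and measures the cosine similarity between them. It is not part of this repository: the helper names (`task_vector`, `apply_task_vectors`, `cosine`) and the four-entry "models" are hypothetical, chosen only so that the two tasks modify disjoint coordinates.

```python
import math


def task_vector(pretrained, finetuned):
    """tau_t = theta_t - theta_0, computed entry by entry."""
    return [f - p for p, f in zip(pretrained, finetuned)]


def apply_task_vectors(pretrained, task_vectors, alpha=1.0):
    """theta = theta_0 + alpha * sum_t tau_t."""
    merged = list(pretrained)
    for tau in task_vectors:
        merged = [m + alpha * t for m, t in zip(merged, tau)]
    return merged


def cosine(u, v):
    """Cosine similarity between two non-zero vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    return dot / (norm_u * norm_v)


# Hypothetical toy weights: task A only moves coordinate 0, task B only coordinate 2.
theta_0 = [1.0, 1.0, 1.0, 1.0]
theta_a = [1.5, 1.0, 1.0, 1.0]
theta_b = [1.0, 1.0, 0.5, 1.0]

tau_a = task_vector(theta_0, theta_a)
tau_b = task_vector(theta_0, theta_b)
merged = apply_task_vectors(theta_0, [tau_a, tau_b], alpha=1.0)
similarity = cosine(tau_a, tau_b)  # 0.0: the two toy task vectors are orthogonal
```

In this toy case the merged weights keep both edits intact (`[1.5, 1.0, 0.5, 1.0]`). This is only a loose parameter-space analogy for the non-interference that weight disentanglement describes at the level of model outputs.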

+ +
+ TFS is the common cause connecting Weight Vector Orthogonality (WVO) with Weight Disentanglement (WD). +

+
+### ✨ Key Contributions
+
+- 📐 **Theory**: We identify TFS as a sufficient condition for weight disentanglement and WVO as its geometric consequence, providing the first principled explanation for task arithmetic.
+- 🔧 **Method (OrthoReg)**: A simple regularization term, added to the fine-tuning loss, that enforces column-wise orthogonality on ΔW; we prove its efficacy theoretically.
+- 🔗 **Connection to TTA**: We show that OrthoReg and Tangent Task Arithmetic (TTA) share the same underlying mechanism (i.e., inter-task vector orthogonality), but OrthoReg achieves it more efficiently.
+- 📊 **Experiments**: Consistent and significant improvements over Non-linear FT, TTA, ATT-FT, and LoRA-ATT across ViT-B-32, ViT-B-16, and ViT-L-14.
+
+---
+
+### The OrthoReg Loss
+
+

+ +
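As a rough, framework-free illustration of the regularizer defined in this section, the sketch below evaluates $\sum_l \|(\Delta W^{(l)})^\top \Delta W^{(l)} - I\|_F^2$ for weights stored as nested Python lists and combines it with a task loss as $\mathcal{L}_{\text{task}} + \lambda \cdot \mathcal{L}_{\text{ortho}}$. It follows the formula as written in this README and is not the repository's PyTorch implementation, which may differ in detail. The function names and the toy two-layer "model" are assumptions made for the example.

```python
def gram(delta_w):
    """Return (dW)^T dW for a matrix stored as a list of rows."""
    n_cols = len(delta_w[0])
    return [
        [sum(row[i] * row[j] for row in delta_w) for j in range(n_cols)]
        for i in range(n_cols)
    ]


def layer_penalty(delta_w):
    """Squared Frobenius norm of (dW)^T dW - I for one layer."""
    g = gram(delta_w)
    n = len(g)
    return sum(
        (g[i][j] - (1.0 if i == j else 0.0)) ** 2
        for i in range(n)
        for j in range(n)
    )


def ortho_loss(pretrained, finetuned):
    """Sum the per-layer penalty over all layers (dicts: name -> matrix)."""
    total = 0.0
    for name, w0 in pretrained.items():
        # dW = W_finetuned - W_pretrained, entry by entry
        delta_w = [
            [f - p for p, f in zip(row0, row1)]
            for row0, row1 in zip(w0, finetuned[name])
        ]
        total += layer_penalty(delta_w)
    return total


def total_loss(task_loss, pretrained, finetuned, ortho_lambda):
    """L = L_task + lambda * L_ortho."""
    return task_loss + ortho_lambda * ortho_loss(pretrained, finetuned)


# Hypothetical two-layer example.
theta_0 = {"a": [[0.0, 0.0], [0.0, 0.0]], "b": [[1.0, 1.0], [1.0, 1.0]]}
theta_t = {"a": [[1.0, 0.0], [0.0, 1.0]], "b": [[2.0, 2.0], [1.0, 1.0]]}

penalty = ortho_loss(theta_0, theta_t)  # 2.0
loss = total_loss(0.5, theta_0, theta_t, ortho_lambda=10.0)  # 20.5
```

Layer `a` has an update equal to the identity, so its columns are orthonormal and it contributes nothing. Layer `b` has two identical update columns, so the off-diagonal Gram entries are penalized.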

+
+The total loss adds a regularization term to the standard task objective:
+
+$$\mathcal{L} = \mathcal{L}_{\text{task}}(\theta_0 + \Delta\theta) + \lambda \cdot \mathcal{L}_{\text{ortho}}(\Delta\theta)$$
+
+$$\mathcal{L}_{\text{ortho}}(\Delta\theta) = \sum_l \left\|(\Delta W^{(l)})^\top \Delta W^{(l)} - I\right\|_F^2$$
+
+---
+
+## 🛠️ Installation
+
+This codebase is built on top of [Tangent Task Arithmetic (TTA)](https://github.com/gortizji/tangent_task_arithmetic). Environment setup follows theirs exactly.
+
+
+To run the code, please install all its dependencies:
+```sh
+conda env create
+conda activate tangent-arithmetic
+```
+and add the repository root to the `PYTHONPATH` so that the `src` package can be imported:
+```sh
+cd OrthoReg
+export PYTHONPATH="$PYTHONPATH:$PWD"
+```
+
+---
+
+## 📦 Datasets
+
+We evaluate on 8 image classification benchmarks following [Task Arithmetic](https://github.com/mlfoundations/task_vectors) and [TTA](https://github.com/gortizji/tangent_task_arithmetic):
+
+**Cars · DTD · EuroSAT · GTSRB · MNIST · RESISC45 · SUN397 · SVHN**
+
+For dataset download and preparation, please follow the instructions in the [TTA repository](https://github.com/gortizji/tangent_task_arithmetic#datasets).
+
+We also provide a pre-packaged dataset archive for convenience:
+
+> 📥 **Dataset Download:** `https://pan.baidu.com/s/1PgLyjUrAhsmgSAz4ms5mcQ?pwd=fwf5`
+
+Set the root path via `--data-location /path/to/datasets/`.
+
+---
+
+## 🚀 Quick Start
+
+All scripts are run from the `OrthoReg/` directory.
This repository implements **6 fine-tuning modes**:
+
+| `--finetuning-mode` | Description |
+|---|---|
+| `standard` | Non-linear full fine-tuning (baseline) |
+| `standard_ortho` | Non-linear FT + OrthoReg |
+| `linear` | TTA: tangent space fine-tuning (baseline) |
+| `linear_ortho` | TTA + OrthoReg |
+| `linear-2` | ATT-FT: attention-only fine-tuning (baseline) |
+| `linear-2_ortho` | ATT-FT + OrthoReg |
+
+> **Note on LoRA-ATT:** The LoRA-ATT and LoRA-ATT+OrthoReg experiments from the paper are implemented in a separate repository due to the complexity of patching OpenCLIP's fused QKV projection. Code will be released at: `https://github.com/lshangge/OrthoReg_lora`
+
+### Step 1: Fine-tune
+
+```bash
+python src/finetune.py \
+    --model ViT-B-32 \
+    --finetuning-mode standard_ortho \
+    --ortho-lambda 10 \
+    --lr 1e-5 \
+    --data-location /path/to/datasets/
+```
+
+Switch between all six modes by changing `--finetuning-mode` and `--ortho-lambda`, where `xx` stands for the chosen regularization strength λ:
+
+```bash
+--finetuning-mode standard        --ortho-lambda 0    # Non-linear FT
+--finetuning-mode standard_ortho  --ortho-lambda xx   # Non-linear FT + OrthoReg
+--finetuning-mode linear          --ortho-lambda 0    # TTA
+--finetuning-mode linear_ortho    --ortho-lambda xx   # TTA + OrthoReg
+--finetuning-mode linear-2        --ortho-lambda 0    # ATT-FT
+--finetuning-mode linear-2_ortho  --ortho-lambda xx   # ATT-FT + OrthoReg
+```
+
+Checkpoints are saved to:
+- `checkpoints_{seed}/{mode}_{lr}_{model}/` for baselines
+- `checkpoints_{seed}/{mode}_{lr}_lambda{lambda}_{model}/` for OrthoReg variants
+
+### Step 2: Evaluate Single-Task Accuracy
+
+```bash
+python src/eval_single_task.py \
+    --model ViT-B-32 \
+    --finetuning-mode standard_ortho \
+    --ortho-lambda 10 \
+    --lr 1e-5 \
+    --data-location /path/to/datasets/
+```
+
+> Run `eval_single_task` with `--finetuning-mode none --ortho-lambda 0` first to generate `zeroshot_accuracies.json`, which is required as the reference for normalized accuracy in Steps 3–4.
+
+### Step 3: Evaluate Task Addition
+
+```bash
+python src/eval_task_addition.py \
+    --model ViT-B-32 \
+    --finetuning-mode standard_ortho \
+    --ortho-lambda 10 \
+    --lr 1e-5 \
+    --data-location /path/to/datasets/
+```
+
+### Step 4: Evaluate Task Negation
+
+```bash
+python src/eval_task_negation.py \
+    --model ViT-B-32 \
+    --finetuning-mode standard_ortho \
+    --ortho-lambda 10 \
+    --lr 1e-5 \
+    --data-location /path/to/datasets/
+```
+
+---
+
+## 🔧 Key Arguments
+
+| Argument | Default | Description |
+|---|:---:|---|
+| `--model` | `ViT-B-32` | CLIP model architecture |
+| `--finetuning-mode` | (none) | One of the 6 modes above |
+| `--ortho-lambda` | `0.0` | OrthoReg strength λ; set to `0` for baselines |
+| `--lr` | `0.001` | Learning rate (the commands above pass `1e-5`) |
+| `--seed` | `1993` | Random seed |
+| `--world-size` | `1` | Number of GPUs (DDP) |
+| `--data-location` | `/path/datasets/` | Dataset root directory |
+| `--batch-size` | `128` | Batch size per GPU |
+
+---
+
+## 📁 Checkpoints
+
+We release fine-tuned checkpoints for ViT-B-32, ViT-B-16, and ViT-L-14 on all 8 tasks, covering all 6 modes.
+
+> 📥 **Checkpoint Download:** `https://huggingface.co/gezi2333/OrthoReg_checkpoints`
+
+Unzip into `OrthoReg/checkpoints_{seed}/` and pass the corresponding `--seed`, `--lr`, and `--ortho-lambda` to the eval scripts to reproduce the paper's results directly.
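To make the directory templates from the Quick Start concrete, the sketch below fills in `checkpoints_{seed}/{mode}_{lr}_{model}/` and `checkpoints_{seed}/{mode}_{lr}_lambda{lambda}_{model}/` for a given configuration. The helper `checkpoint_dir` is hypothetical and not part of this repository. It assumes the placeholders are filled by plain string interpolation and that the `_ortho` suffix selects the second template; the actual scripts may format `lr` or `lambda` differently (for example `1e-05` or `10.0`), so check the names of the unzipped folders.

```python
def checkpoint_dir(seed, mode, lr, model, ortho_lambda="0"):
    """Build a checkpoint directory name from the README templates.

    Assumption: `lr` and `ortho_lambda` are passed as strings that already
    match the formatting used on disk.
    """
    if mode.endswith("_ortho"):
        # OrthoReg variants carry the regularization strength in the name.
        return f"checkpoints_{seed}/{mode}_{lr}_lambda{ortho_lambda}_{model}/"
    # Baselines omit the lambda component.
    return f"checkpoints_{seed}/{mode}_{lr}_{model}/"


baseline = checkpoint_dir(1993, "standard", "1e-5", "ViT-B-32")
regularized = checkpoint_dir(1993, "standard_ortho", "1e-5", "ViT-B-32", "10")
```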
+
+---
+
+## 📝 Citation
+
+If you find this work useful, please cite:
+
+```bibtex
+@inproceedings{liu2026orthoreg,
+  title     = {Understanding and Enforcing Weight Disentanglement in Task Arithmetic},
+  author    = {Liu, Shangge and Yin, Yuehan and Wang, Lei and Fan, Qi and
+               Shi, Yinghuan and Li, Wenbin and Gao, Yang and Tao, Dacheng},
+  booktitle = {CVPR},
+  year      = {2026}
+}
+```
+
+---
+
+## 📞 Contact
+
+For questions or issues, please:
+
+- Open an issue on GitHub
+- Contact the authors at [lshangge@smail.nju.edu.cn]
+
+---
+
+## 📬 Acknowledgements
+
+This codebase is built on top of [Task Arithmetic](https://github.com/mlfoundations/task_vectors), [Tangent Task Arithmetic](https://github.com/gortizji/tangent_task_arithmetic), and [Attention-Only Fine-tuning](https://github.com/kyrie-23/linear_task_arithmetic). We thank the authors for releasing their code.
+
diff --git a/assets/WVO-WD-TFS.png b/assets/WVO-WD-TFS.png
new file mode 100644
index 0000000000000000000000000000000000000000..8e1296166a218be6e3e652c466f2ea56da1ab3ce
--- /dev/null
+++ b/assets/WVO-WD-TFS.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:bc8a9efc76ecb495a5de03a98215606a8cbab5b38cdbb53ea5d2c2ed133e535a
+size 149842
diff --git a/assets/orthoreg_loss.png b/assets/orthoreg_loss.png
new file mode 100644
index 0000000000000000000000000000000000000000..c1370b3fe742341b87d034583f312ae5ac4327f7
--- /dev/null
+++ b/assets/orthoreg_loss.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:d7974c4003a412cc9a5134f26b71cc58739f5cd1f48301ab0d3d428cb17ecb8e
+size 157982
diff --git a/environment.yml b/environment.yml
new file mode 100644
index 0000000000000000000000000000000000000000..c3fa11f1d080d690ddf12ad5a58ff3c507e19b31
--- /dev/null
+++ b/environment.yml
@@ -0,0 +1,140 @@
+name: tangent-arithmetic
+channels:
+  - pytorch
+  - nvidia
+  - defaults
+dependencies:
+  - _libgcc_mutex=0.1
+  - _openmp_mutex=5.1
+  - blas=1.0
+  - brotlipy=0.7.0
+  - 
bzip2=1.0.8
+  - ca-certificates=2023.05.30
+  - certifi=2023.5.7
+  - cffi=1.15.1
+  - charset-normalizer=2.0.4
+  - cryptography=39.0.1
+  - cuda=11.6.1
+  - cuda-cccl=11.6.55
+  - cuda-command-line-tools=11.6.2
+  - cuda-compiler=11.6.2
+  - cuda-cudart=11.6.55
+  - cuda-cudart-dev=11.6.55
+  - cuda-cuobjdump=11.6.124
+  - cuda-cupti=11.6.124
+  - cuda-cuxxfilt=11.6.124
+  - cuda-driver-dev=11.6.55
+  - cuda-gdb=12.1.105
+  - cuda-libraries=11.6.1
+  - cuda-libraries-dev=11.6.1
+  - cuda-memcheck=11.8.86
+  - cuda-nsight=12.1.105
+  - cuda-nsight-compute=12.1.1
+  - cuda-nvcc=11.6.124
+  - cuda-nvdisasm=12.1.105
+  - cuda-nvml-dev=11.6.55
+  - cuda-nvprof=12.1.105
+  - cuda-nvprune=11.6.124
+  - cuda-nvrtc=11.6.124
+  - cuda-nvrtc-dev=11.6.124
+  - cuda-nvtx=11.6.124
+  - cuda-nvvp=12.1.105
+  - cuda-runtime=11.6.1
+  - cuda-samples=11.6.101
+  - cuda-sanitizer-api=12.1.105
+  - cuda-toolkit=11.6.1
+  - cuda-tools=11.6.1
+  - cuda-visual-tools=11.6.1
+  - ffmpeg=4.3
+  - freetype=2.12.1
+  - gds-tools=1.6.1.9
+  - giflib=5.2.1
+  - gmp=6.2.1
+  - gnutls=3.6.15
+  - idna=3.4
+  - intel-openmp=2023.1.0
+  - jpeg=9e
+  - lame=3.100
+  - lcms2=2.12
+  - ld_impl_linux-64=2.38
+  - lerc=3.0
+  - libcublas=11.9.2.110
+  - libcublas-dev=11.9.2.110
+  - libcufft=10.7.1.112
+  - libcufft-dev=10.7.1.112
+  - libcufile=1.6.1.9
+  - libcufile-dev=1.6.1.9
+  - libcurand=10.3.2.106
+  - libcurand-dev=10.3.2.106
+  - libcusolver=11.3.4.124
+  - libcusparse=11.7.2.124
+  - libcusparse-dev=11.7.2.124
+  - libdeflate=1.17
+  - libffi=3.4.4
+  - libgcc-ng=11.2.0
+  - libgomp=11.2.0
+  - libiconv=1.16
+  - libidn2=2.3.4
+  - libnpp=11.6.3.124
+  - libnpp-dev=11.6.3.124
+  - libnvjpeg=11.6.2.124
+  - libnvjpeg-dev=11.6.2.124
+  - libpng=1.6.39
+  - libstdcxx-ng=11.2.0
+  - libtasn1=4.19.0
+  - libtiff=4.5.0
+  - libunistring=0.9.10
+  - libuuid=1.41.5
+  - libwebp=1.2.4
+  - libwebp-base=1.2.4
+  - lz4-c=1.9.4
+  - mkl=2023.1.0
+  - mkl-service=2.4.0
+  - mkl_fft=1.3.6
+  - mkl_random=1.2.2
+  - ncurses=6.4
+  - nettle=3.7.3
+  - nsight-compute=2023.1.1.4
+  - 
numpy=1.24.3
+  - numpy-base=1.24.3
+  - openh264=2.1.1
+  - openssl=1.1.1t
+  - pillow=9.4.0
+  - pip=23.0.1
+  - pycparser=2.21
+  - pyopenssl=23.0.0
+  - pysocks=1.7.1
+  - python=3.10.11
+  - pytorch=1.13.1
+  - pytorch-cuda=11.6
+  - pytorch-mutex=1.0
+  - readline=8.2
+  - requests=2.29.0
+  - setuptools=67.8.0
+  - sqlite=3.41.2
+  - tbb=2021.8.0
+  - tk=8.6.12
+  - torchaudio=0.13.1
+  - torchvision=0.14.1
+  - typing_extensions=4.5.0
+  - tzdata=2023c
+  - urllib3=1.26.16
+  - wheel=0.38.4
+  - xz=5.4.2
+  - zlib=1.2.13
+  - zstd=1.5.5
+  - pip:
+    - filelock==3.12.0
+    - fsspec==2023.5.0
+    - ftfy==6.1.1
+    - huggingface-hub==0.15.1
+    - open-clip-torch==2.10.1
+    - packaging==23.1
+    - protobuf==3.20.3
+    - pyyaml==6.0
+    - regex==2023.6.3
+    - safetensors==0.3.1
+    - scipy==1.10.1
+    - sentencepiece==0.1.99
+    - timm==0.9.2
+    - wcwidth==0.2.6
diff --git a/src/__init__.py b/src/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/src/__pycache__/__init__.cpython-310.pyc b/src/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..8a7cc90a8c713cef874a8e3c40b8d24b5f361d3e
Binary files /dev/null and b/src/__pycache__/__init__.cpython-310.pyc differ
diff --git a/src/__pycache__/args.cpython-310.pyc b/src/__pycache__/args.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..c5f3a4d5e74a11757f90d656cb4bbec449a7dc94
Binary files /dev/null and b/src/__pycache__/args.cpython-310.pyc differ
diff --git a/src/__pycache__/attention_only_finetune.cpython-310.pyc b/src/__pycache__/attention_only_finetune.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..8ea5e49c92e2c2dd9ebb0791e73cb383d95a5c21
Binary files /dev/null and b/src/__pycache__/attention_only_finetune.cpython-310.pyc differ
diff --git a/src/__pycache__/distributed.cpython-310.pyc b/src/__pycache__/distributed.cpython-310.pyc
new file mode 100644
index 
0000000000000000000000000000000000000000..340b9dedddebb30967cf4e654dd64c089ca69d39 Binary files /dev/null and b/src/__pycache__/distributed.cpython-310.pyc differ diff --git a/src/__pycache__/eval.cpython-310.pyc b/src/__pycache__/eval.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..11c84b135a6e0563a6713b2fb974809f520a75a0 Binary files /dev/null and b/src/__pycache__/eval.cpython-310.pyc differ diff --git a/src/__pycache__/heads.cpython-310.pyc b/src/__pycache__/heads.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5560cfb464a58f4e14cee196889d273e886ac012 Binary files /dev/null and b/src/__pycache__/heads.cpython-310.pyc differ diff --git a/src/__pycache__/linearize.cpython-310.pyc b/src/__pycache__/linearize.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0333faf468cb7e4bdc4601287a562bd070f187e3 Binary files /dev/null and b/src/__pycache__/linearize.cpython-310.pyc differ diff --git a/src/__pycache__/modeling.cpython-310.pyc b/src/__pycache__/modeling.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b61a266a18103c25d4bf3023dc2bb7def057bd5c Binary files /dev/null and b/src/__pycache__/modeling.cpython-310.pyc differ diff --git a/src/__pycache__/task_vectors.cpython-310.pyc b/src/__pycache__/task_vectors.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b3bb60527b24f61861488ebe52b59d46d856b4cc Binary files /dev/null and b/src/__pycache__/task_vectors.cpython-310.pyc differ diff --git a/src/__pycache__/utils.cpython-310.pyc b/src/__pycache__/utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..257c32d1052c6b28a85df1a438f5c1f8b1e6ffb4 Binary files /dev/null and b/src/__pycache__/utils.cpython-310.pyc differ diff --git a/src/args.py b/src/args.py new file mode 100644 index 
0000000000000000000000000000000000000000..47e11d8c846276908874a5c8a1540d94a9dc2d74 --- /dev/null +++ b/src/args.py @@ -0,0 +1,153 @@ +import argparse +import os + +import torch + + +def parse_arguments(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--data-location", + type=str, + default=os.path.expanduser("/path/datasets/"), + help="The root directory for the datasets.", + ) + parser.add_argument( + "--eval-datasets", + default=None, + type=lambda x: x.split(","), + help="Which datasets to use for evaluation. Split by comma, e.g. MNIST,EuroSAT. ", + ) + parser.add_argument( + "--train-dataset", + default=None, + type=lambda x: x.split(","), + help="Which dataset(s) to patch on.", + ) + parser.add_argument( + "--exp_name", + type=str, + default=None, + help="Name of the experiment, for organization purposes only.", + ) + parser.add_argument( + "--results-db", + type=str, + default=None, + help="Where to store the results, else does not store", + ) + parser.add_argument( + "--model", + type=str, + default="ViT-B-32", + help="The type of model (e.g. RN50, ViT-B-32).", + ) + parser.add_argument( + "--batch-size", + type=int, + default=128, + ) + parser.add_argument( + "--num-grad-accumulation", + type=int, + default=1, + help="Number of gradient accumulation steps.", + ) + parser.add_argument("--lr", type=float, default=0.001, help="Learning rate.") + parser.add_argument("--wd", type=float, default=0.1, help="Weight decay") + parser.add_argument("--ls", type=float, default=0.0, help="Label smoothing.") + parser.add_argument( + "--warmup_length", + type=int, + default=500, + ) + parser.add_argument( + "--epochs", + type=int, + default=10, + ) + parser.add_argument( + "--load", + type=lambda x: x.split(","), + default=None, + help="Optionally load _classifiers_, e.g. a zero shot classifier or probe or ensemble both.", + ) + parser.add_argument( + "--save", + type=str, + default=None, + help="Optionally save a _classifier_, e.g. 
a zero shot classifier or probe.", + ) + parser.add_argument( + "--cache-dir", + type=str, + default=None, + help="Directory for caching features and encoder", + ) + parser.add_argument( + "--openclip-cachedir", + type=str, + default=os.path.expanduser("~/openclip-cachedir/open_clip"), + help="Directory for caching models from OpenCLIP", + ) + parser.add_argument( + "--world-size", + type=int, + default=1, + help="Number of processes for distributed training.", + ) + parser.add_argument( + "--checkpoint-every", + type=int, + default=-1, + help="How often to checkpoint the model.", + ) + parser.add_argument( + "--port", + type=int, + default=12355, + help="Port for distributed training.", + ) + parser.add_argument( + "--seed", + type=int, + default=1993, + help="Random seed.", + ) + parser.add_argument( + "--finetuning-mode", + choices=["standard", "standard_ortho", "linear", "linear_ortho", "linear-2", "linear-2_ortho"], + help="Finetuning mode: standard/linear/linear-2 with optional ortho regularization.", + ) + parser.add_argument( + "--n-eval-points", + type=int, + default=21, + help="Number of evaluation points used to find optimal coefficient in task arithmetic.", + ) + parser.add_argument( + "--ortho-lambda", + type=float, + default=0.0, + help="Weight of the orthogonality regularization term. Default 0.0 means no regularization.", + ) + parser.add_argument( + "--control_threshold", + type=float, + default=0.95, + help="Control dataset performance degradation threshold.", + ) + parser.add_argument( + "--alpha", + type=float, + default=None, + help="Manually specify the scaling coefficient for task vectors. 
If None, it will be optimized on validation set.", + ) + + parsed_args = parser.parse_args() + parsed_args.device = "cuda" if torch.cuda.is_available() else "cpu" + + if parsed_args.load is not None and len(parsed_args.load) == 1: + parsed_args.load = parsed_args.load[0] + + return parsed_args diff --git a/src/attention_only_finetune.py b/src/attention_only_finetune.py new file mode 100644 index 0000000000000000000000000000000000000000..3a77b965fddf2370fdb56b4fbd4e6f182b4bf2b1 --- /dev/null +++ b/src/attention_only_finetune.py @@ -0,0 +1,116 @@ +import os +import torch +import torch.nn as nn +from src.modeling import ImageEncoder +from src.utils import DotDict + +class AttentionOnlyFinetuneEncoder(ImageEncoder): + """ + A specialized ImageEncoder that fine-tunes only the attention module weights in the ViT. + Corresponds to the method described in Jin et al. (2025). + """ + def __init__(self, args, keep_lang=False): + # 1. Call the parent constructor to build the full model as usual + super().__init__(args, keep_lang=keep_lang) + + self.args = args + + # 2. Freeze all model parameters + # print("Freezing all parameters of the model initially...") + for param in self.model.parameters(): + param.requires_grad = False + + # 3. Unfreeze only the Attention module weights (Wq, Wk, Wv, Wo) + # print("Unfreezing Attention module weights for fine-tuning...") + self._unfreeze_attention_weights(self.model.visual) + + # 4. (Optional but recommended) Print trainable parameters for verification + # self._verify_trainable_params() + + def _unfreeze_attention_weights(self, vit_model): + """ + Iterate over all Transformer blocks and unfreeze the attention projection weights. 
+ """ + # Iterate over the model and unfreeze target parameters + for block in vit_model.transformer.resblocks: + # Unfreeze the combined input projection weight for Q, K, V + block.attn.in_proj_weight.requires_grad = True + + # Unfreeze the output projection weight + block.attn.out_proj.weight.requires_grad = True + + # Per the paper's ablation study, not fine-tuning biases yields better results; keep them frozen + # block.attn.in_proj_bias.requires_grad = True + # block.attn.out_proj.bias.requires_grad = True + + def _verify_trainable_params(self): + """Print all trainable parameters for debugging and verification.""" + print("="*80) + print("Trainable parameters in AttentionOnlyFinetuneEncoder:") + trainable_params_count = 0 + for name, param in self.model.named_parameters(): + if param.requires_grad: + print(f" - {name}") + trainable_params_count += param.numel() + print(f"Total trainable parameters: {trainable_params_count / 1e6:.2f}M") + print("="*80) + + def forward(self, images, calculate_ortho_loss=False, pretrained_state_dict=None): + """ + Extended forward method to optionally compute and return the orthogonal loss. + Consistent with the logic implemented for standard_ortho. + """ + # Original forward pass + features = self.model.encode_image(images) + + # Return features directly if orthogonal loss is not needed + if not calculate_ortho_loss: + return features + + # --- Compute orthogonal loss if requested --- + if pretrained_state_dict is None: + raise ValueError("pretrained_state_dict must be provided when calculate_ortho_loss is True") + + ortho_loss = 0.0 + # self.model is the open_clip model (e.g. 
ViT); iterate over its parameters + for name, p_finetuned in self.model.named_parameters(): + # Only compute loss for trainable parameters with gradients + if p_finetuned.requires_grad and p_finetuned.dim() == 2: + if name in pretrained_state_dict: + p_pretrained = pretrained_state_dict[name].to(p_finetuned.device) + + delta_W = p_finetuned - p_pretrained + + # Compute orthogonal loss (W^T * W - I) + rows, cols = delta_W.shape + if rows < cols: + mat = delta_W @ delta_W.T + identity = torch.eye(rows, device=delta_W.device) + else: + mat = delta_W.T @ delta_W + identity = torch.eye(cols, device=delta_W.device) + + ortho_loss += torch.norm(mat - identity, p='fro') + + return features, ortho_loss + + def __call__(self, inputs, calculate_ortho_loss=False, pretrained_state_dict=None): + # Ensure __call__ forwards all arguments + return self.forward(inputs, calculate_ortho_loss, pretrained_state_dict) + + def save(self, filename): + """Save model weights.""" + # print(f"Saving AttentionOnlyFinetuneEncoder state_dict to {filename}") + if os.path.dirname(filename): + os.makedirs(os.path.dirname(filename), exist_ok=True) + # Save only the state_dict; reconstruct the model on load + torch.save(self.state_dict(), filename) + + @classmethod + def load(cls, filename, args): + """Load model from a state_dict.""" + # print(f"Loading AttentionOnlyFinetuneEncoder from {filename}") + encoder = cls(args) # Create a new instance + state_dict = torch.load(filename, map_location='cpu') + encoder.load_state_dict(state_dict) # Load weights + return encoder \ No newline at end of file diff --git a/src/datasets/__pycache__/cars.cpython-310.pyc b/src/datasets/__pycache__/cars.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c9d4405bfc3a5b86e1f34823494ed704d655402e Binary files /dev/null and b/src/datasets/__pycache__/cars.cpython-310.pyc differ diff --git a/src/datasets/__pycache__/cifar10.cpython-310.pyc b/src/datasets/__pycache__/cifar10.cpython-310.pyc 
new file mode 100644 index 0000000000000000000000000000000000000000..392eb3ee6607d1678c45dc1de315783067186aa9 Binary files /dev/null and b/src/datasets/__pycache__/cifar10.cpython-310.pyc differ diff --git a/src/datasets/__pycache__/cifar100.cpython-310.pyc b/src/datasets/__pycache__/cifar100.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cfce4e0671aa1ca3848a5ebd644343c16d4cdeeb Binary files /dev/null and b/src/datasets/__pycache__/cifar100.cpython-310.pyc differ diff --git a/src/datasets/__pycache__/common.cpython-310.pyc b/src/datasets/__pycache__/common.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5283c26161f41f3a07d31f3e96aa755b05e4a4d1 Binary files /dev/null and b/src/datasets/__pycache__/common.cpython-310.pyc differ diff --git a/src/datasets/__pycache__/dtd.cpython-310.pyc b/src/datasets/__pycache__/dtd.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..75cde45975e35320f1fddf7ddc0f3d400d34e34c Binary files /dev/null and b/src/datasets/__pycache__/dtd.cpython-310.pyc differ diff --git a/src/datasets/__pycache__/emnist.cpython-310.pyc b/src/datasets/__pycache__/emnist.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0ba145d7792e5347fe678092596ca250d00a5d73 Binary files /dev/null and b/src/datasets/__pycache__/emnist.cpython-310.pyc differ diff --git a/src/datasets/__pycache__/eurosat.cpython-310.pyc b/src/datasets/__pycache__/eurosat.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7982e0eba87a3be3f0a2ec694fd22857a67a101f Binary files /dev/null and b/src/datasets/__pycache__/eurosat.cpython-310.pyc differ diff --git a/src/datasets/__pycache__/gtsrb.cpython-310.pyc b/src/datasets/__pycache__/gtsrb.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..54dd342d1682b0e5349233c103ad1e4aef31578b Binary files /dev/null and 
b/src/datasets/__pycache__/gtsrb.cpython-310.pyc differ diff --git a/src/datasets/__pycache__/imagenet.cpython-310.pyc b/src/datasets/__pycache__/imagenet.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d2c2ecbd705b1dd2672d571b1477f366f517671a Binary files /dev/null and b/src/datasets/__pycache__/imagenet.cpython-310.pyc differ diff --git a/src/datasets/__pycache__/kmnist.cpython-310.pyc b/src/datasets/__pycache__/kmnist.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3f3c07ef170873d91e6d0a7729d0d766ca25e9ce Binary files /dev/null and b/src/datasets/__pycache__/kmnist.cpython-310.pyc differ diff --git a/src/datasets/__pycache__/mnist.cpython-310.pyc b/src/datasets/__pycache__/mnist.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a6c961255829fd3125c65faf3f761217bfdad25f Binary files /dev/null and b/src/datasets/__pycache__/mnist.cpython-310.pyc differ diff --git a/src/datasets/__pycache__/oxfordpets.cpython-310.pyc b/src/datasets/__pycache__/oxfordpets.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f95c4a61bbb3e78942c8689fddf973dc6010bf26 Binary files /dev/null and b/src/datasets/__pycache__/oxfordpets.cpython-310.pyc differ diff --git a/src/datasets/__pycache__/registry.cpython-310.pyc b/src/datasets/__pycache__/registry.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..722d832aa156ec400101753f3b608d785faee229 Binary files /dev/null and b/src/datasets/__pycache__/registry.cpython-310.pyc differ diff --git a/src/datasets/__pycache__/resisc45.cpython-310.pyc b/src/datasets/__pycache__/resisc45.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4d115426a3da66263410795803530cf51d66233b Binary files /dev/null and b/src/datasets/__pycache__/resisc45.cpython-310.pyc differ diff --git a/src/datasets/__pycache__/stl10.cpython-310.pyc 
b/src/datasets/__pycache__/stl10.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b4723be7f4219d30244aa269f37c4b9a9953b022 Binary files /dev/null and b/src/datasets/__pycache__/stl10.cpython-310.pyc differ diff --git a/src/datasets/__pycache__/sun397.cpython-310.pyc b/src/datasets/__pycache__/sun397.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..43bc175deb4d4221c924b91ac1c54c7c238ddd33 Binary files /dev/null and b/src/datasets/__pycache__/sun397.cpython-310.pyc differ diff --git a/src/datasets/__pycache__/svhn.cpython-310.pyc b/src/datasets/__pycache__/svhn.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2f277d480777bae9330af34f83874359aa5c8d84 Binary files /dev/null and b/src/datasets/__pycache__/svhn.cpython-310.pyc differ diff --git a/src/datasets/__pycache__/templates.cpython-310.pyc b/src/datasets/__pycache__/templates.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2972bcb8cee9391d7f9dcd35ac6c53b62a881077 Binary files /dev/null and b/src/datasets/__pycache__/templates.cpython-310.pyc differ diff --git a/src/datasets/cars.py b/src/datasets/cars.py new file mode 100644 index 0000000000000000000000000000000000000000..4546031aa1bde18c22e753da607ee0db555f08f2 --- /dev/null +++ b/src/datasets/cars.py @@ -0,0 +1,155 @@ +import os +import torch +import torchvision.datasets as datasets + + +import pathlib +from typing import Callable, Optional, Any, Tuple + +from PIL import Image + +from torchvision.datasets.utils import download_and_extract_archive, download_url, verify_str_arg +from torchvision.datasets.vision import VisionDataset + + +class PytorchStanfordCars(VisionDataset): + """`Stanford Cars `_ Dataset + + The Cars dataset contains 16,185 images of 196 classes of cars. The data is + split into 8,144 training images and 8,041 testing images, where each class + has been split roughly in a 50-50 split + + .. 
note:: + + This class needs `scipy `_ to load target files from `.mat` format. + + Args: + root (string): Root directory of dataset + split (string, optional): The dataset split, supports ``"train"`` (default) or ``"test"``. + transform (callable, optional): A function/transform that takes in an PIL image + and returns a transformed version. E.g, ``transforms.RandomCrop`` + target_transform (callable, optional): A function/transform that takes in the + target and transforms it. + download (bool, optional): If True, downloads the dataset from the internet and + puts it in root directory. If dataset is already downloaded, it is not + downloaded again.""" + + def __init__( + self, + root: str, + split: str = "train", + transform: Optional[Callable] = None, + target_transform: Optional[Callable] = None, + download: bool = False, + ) -> None: + + try: + import scipy.io as sio + except ImportError: + raise RuntimeError("Scipy is not found. This dataset needs to have scipy installed: pip install scipy") + + super().__init__(root, transform=transform, target_transform=target_transform) + + self._split = verify_str_arg(split, "split", ("train", "test")) + self._base_folder = pathlib.Path(root) / "stanford_cars" + devkit = self._base_folder / "devkit" + + if self._split == "train": + self._annotations_mat_path = devkit / "cars_train_annos.mat" + self._images_base_path = self._base_folder / "cars_train" + else: + self._annotations_mat_path = self._base_folder / "cars_test_annos_withlabels.mat" + self._images_base_path = self._base_folder / "cars_test" + + if download: + self.download() + + if not self._check_exists(): + raise RuntimeError("Dataset not found. 
You can use download=True to download it") + + self._samples = [ + ( + str(self._images_base_path / annotation["fname"]), + annotation["class"] - 1, # Original target mapping starts from 1, hence -1 + ) + for annotation in sio.loadmat(self._annotations_mat_path, squeeze_me=True)["annotations"] + ] + + self.classes = sio.loadmat(str(devkit / "cars_meta.mat"), squeeze_me=True)["class_names"].tolist() + self.class_to_idx = {cls: i for i, cls in enumerate(self.classes)} + + def __len__(self) -> int: + return len(self._samples) + + def __getitem__(self, idx: int) -> Tuple[Any, Any]: + """Returns pil_image and class_id for given index""" + image_path, target = self._samples[idx] + pil_image = Image.open(image_path).convert("RGB") + + if self.transform is not None: + pil_image = self.transform(pil_image) + if self.target_transform is not None: + target = self.target_transform(target) + return pil_image, target + + + def download(self) -> None: + if self._check_exists(): + return + + download_and_extract_archive( + url="https://ai.stanford.edu/~jkrause/cars/car_devkit.tgz", + download_root=str(self._base_folder), + md5="c3b158d763b6e2245038c8ad08e45376", + ) + if self._split == "train": + download_and_extract_archive( + url="https://ai.stanford.edu/~jkrause/car196/cars_train.tgz", + download_root=str(self._base_folder), + md5="065e5b463ae28d29e77c1b4b166cfe61", + ) + else: + download_and_extract_archive( + url="https://ai.stanford.edu/~jkrause/car196/cars_test.tgz", + download_root=str(self._base_folder), + md5="4ce7ebf6a94d07f1952d94dd34c4d501", + ) + download_url( + url="https://ai.stanford.edu/~jkrause/car196/cars_test_annos_withlabels.mat", + root=str(self._base_folder), + md5="b0a2b23655a3edd16d84508592a98d10", + ) + + def _check_exists(self) -> bool: + if not (self._base_folder / "devkit").is_dir(): + return False + + return self._annotations_mat_path.exists() and self._images_base_path.is_dir() + + +class Cars: + def __init__(self, + preprocess, + 
location=os.path.expanduser('~/data'), + batch_size=32, + num_workers=16): + # Data loading code + + self.train_dataset = PytorchStanfordCars(location, 'train', preprocess, download=False) + self.train_loader = torch.utils.data.DataLoader( + self.train_dataset, + shuffle=True, + batch_size=batch_size, + num_workers=num_workers, + ) + + self.test_dataset = PytorchStanfordCars(location, 'test', preprocess, download=False) + self.test_loader = torch.utils.data.DataLoader( + self.test_dataset, + batch_size=batch_size, + num_workers=num_workers + ) + idx_to_class = dict((v, k) + for k, v in self.train_dataset.class_to_idx.items()) + self.classnames = [idx_to_class[i].replace( + '_', ' ') for i in range(len(idx_to_class))] diff --git a/src/datasets/cifar10.py b/src/datasets/cifar10.py new file mode 100644 index 0000000000000000000000000000000000000000..096913b59e5158c2eb507ebf066658547df5b533 --- /dev/null +++ b/src/datasets/cifar10.py @@ -0,0 +1,56 @@ +import os +import PIL +import torch +import numpy as np +import torchvision +from torchvision import transforms +from torchvision.datasets import CIFAR10 as PyTorchCIFAR10 +from torchvision.datasets import VisionDataset + +cifar_classnames = ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck'] + +class CIFAR10: + def __init__(self, preprocess, + location=os.path.expanduser('~/data'), + batch_size=128, + num_workers=16): + + + self.train_dataset = PyTorchCIFAR10( + root=location, download=True, train=True, transform=preprocess + ) + + self.train_loader = torch.utils.data.DataLoader( + self.train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers + ) + + self.test_dataset = PyTorchCIFAR10( + root=location, download=True, train=False, transform=preprocess + ) + + self.test_loader = torch.utils.data.DataLoader( + self.test_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers + ) + + self.classnames = self.test_dataset.classes + +def convert(x): + if 
isinstance(x, np.ndarray): + return torchvision.transforms.functional.to_pil_image(x) + return x + +class BasicVisionDataset(VisionDataset): + def __init__(self, images, targets, transform=None, target_transform=None): + if transform is not None: + transform.transforms.insert(0, convert) + super(BasicVisionDataset, self).__init__(root=None, transform=transform, target_transform=target_transform) + assert len(images) == len(targets) + + self.images = images + self.targets = targets + + def __getitem__(self, index): + return self.transform(self.images[index]), self.targets[index] + + def __len__(self): + return len(self.targets) diff --git a/src/datasets/cifar100.py b/src/datasets/cifar100.py new file mode 100644 index 0000000000000000000000000000000000000000..c7b3bb4953336e7ef3f24bc429f51df8953b4c47 --- /dev/null +++ b/src/datasets/cifar100.py @@ -0,0 +1,30 @@ +import os +import torch +from torchvision.datasets import CIFAR100 as PyTorchCIFAR100 + +class CIFAR100: + def __init__(self, + preprocess, + location=os.path.expanduser('~/data'), + batch_size=128, + num_workers=16): + + self.train_dataset = PyTorchCIFAR100( + root=location, download=True, train=True, transform=preprocess + ) + + self.train_loader = torch.utils.data.DataLoader( + self.train_dataset, batch_size=batch_size, num_workers=num_workers + ) + + self.test_dataset = PyTorchCIFAR100( + root=location, download=True, train=False, transform=preprocess + ) + + self.test_loader = torch.utils.data.DataLoader( + self.test_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers + ) + + self.classnames = self.test_dataset.classes + + diff --git a/src/datasets/common.py b/src/datasets/common.py new file mode 100644 index 0000000000000000000000000000000000000000..7cb6cce76f71f0eea7f770fa4508795703082a95 --- /dev/null +++ b/src/datasets/common.py @@ -0,0 +1,139 @@ +import os +import torch +import json +import glob +import collections +import random + +import numpy as np + +from tqdm import tqdm + 
+import torchvision.datasets as datasets +from torch.utils.data import Dataset, DataLoader, Sampler + + +class SubsetSampler(Sampler): + def __init__(self, indices): + self.indices = indices + + def __iter__(self): + return (i for i in self.indices) + + def __len__(self): + return len(self.indices) + +class ImageFolderWithPaths(datasets.ImageFolder): + def __init__(self, path, transform, flip_label_prob=0.0): + super().__init__(path, transform) + self.flip_label_prob = flip_label_prob + if self.flip_label_prob > 0: + print(f'Flipping labels with probability {self.flip_label_prob}') + num_classes = len(self.classes) + for i in range(len(self.samples)): + if random.random() < self.flip_label_prob: + new_label = random.randint(0, num_classes-1) + self.samples[i] = ( + self.samples[i][0], + new_label + ) + + def __getitem__(self, index): + image, label = super(ImageFolderWithPaths, self).__getitem__(index) + return { + 'images': image, + 'labels': label, + 'image_paths': self.samples[index][0] + } + + +def maybe_dictionarize(batch): + if isinstance(batch, dict): + return batch + + if len(batch) == 2: + batch = {'images': batch[0], 'labels': batch[1]} + elif len(batch) == 3: + batch = {'images': batch[0], 'labels': batch[1], 'metadata': batch[2]} + else: + raise ValueError(f'Unexpected number of elements: {len(batch)}') + + return batch + + +def get_features_helper(image_encoder, dataloader, device): + all_data = collections.defaultdict(list) + + image_encoder = image_encoder.to(device) + image_encoder = torch.nn.DataParallel(image_encoder, device_ids=[x for x in range(torch.cuda.device_count())]) + image_encoder.eval() + + with torch.no_grad(): + for batch in tqdm(dataloader): + batch = maybe_dictionarize(batch) + features = image_encoder(batch['images'].cuda()) + + all_data['features'].append(features.cpu()) + + for key, val in batch.items(): + if key == 'images': + continue + if hasattr(val, 'cpu'): + val = val.cpu() + all_data[key].append(val) + else: + 
all_data[key].extend(val) + + for key, val in all_data.items(): + if torch.is_tensor(val[0]): + all_data[key] = torch.cat(val).numpy() + + return all_data + + +def get_features(is_train, image_encoder, dataset, device): + split = 'train' if is_train else 'val' + dname = type(dataset).__name__ + if image_encoder.cache_dir is not None: + cache_dir = f'{image_encoder.cache_dir}/{dname}/{split}' + cached_files = glob.glob(f'{cache_dir}/*') + if image_encoder.cache_dir is not None and len(cached_files) > 0: + print(f'Getting features from {cache_dir}') + data = {} + for cached_file in cached_files: + name = os.path.splitext(os.path.basename(cached_file))[0] + data[name] = torch.load(cached_file) + else: + print('Did not find cached features. Building from scratch.') + loader = dataset.train_loader if is_train else dataset.test_loader + data = get_features_helper(image_encoder, loader, device) + if image_encoder.cache_dir is None: + print('Not caching because no cache directory was passed.') + else: + os.makedirs(cache_dir, exist_ok=True) + print(f'Caching data at {cache_dir}') + for name, val in data.items(): + torch.save(val, f'{cache_dir}/{name}.pt') + return data + + +class FeatureDataset(Dataset): + def __init__(self, is_train, image_encoder, dataset, device): + self.data = get_features(is_train, image_encoder, dataset, device) + + def __len__(self): + return len(self.data['features']) + + def __getitem__(self, idx): + data = {k: v[idx] for k, v in self.data.items()} + data['features'] = torch.from_numpy(data['features']).float() + return data + + +def get_dataloader(dataset, is_train, args, image_encoder=None): + if image_encoder is not None: + feature_dataset = FeatureDataset(is_train, image_encoder, dataset, args.device) + dataloader = DataLoader(feature_dataset, batch_size=args.batch_size, shuffle=is_train) + else: + dataloader = dataset.train_loader if is_train else dataset.test_loader + return dataloader \ No newline at end of file diff --git
a/src/datasets/dtd.py b/src/datasets/dtd.py new file mode 100644 index 0000000000000000000000000000000000000000..79fb3c3018634e9a9c7ecef38a86f0b3573f39a7 --- /dev/null +++ b/src/datasets/dtd.py @@ -0,0 +1,34 @@ +import os +import torch +import torchvision.datasets as datasets + + +class DTD: + def __init__(self, + preprocess, + location=os.path.expanduser('~/data'), + batch_size=32, + num_workers=16): + # Data loading code + traindir = os.path.join(location, 'dtd', 'train') + valdir = os.path.join(location, 'dtd', 'val') + + self.train_dataset = datasets.ImageFolder( + traindir, transform=preprocess) + self.train_loader = torch.utils.data.DataLoader( + self.train_dataset, + shuffle=True, + batch_size=batch_size, + num_workers=num_workers, + ) + + self.test_dataset = datasets.ImageFolder(valdir, transform=preprocess) + self.test_loader = torch.utils.data.DataLoader( + self.test_dataset, + batch_size=batch_size, + num_workers=num_workers + ) + idx_to_class = dict((v, k) + for k, v in self.train_dataset.class_to_idx.items()) + self.classnames = [idx_to_class[i].replace( + '_', ' ') for i in range(len(idx_to_class))] \ No newline at end of file diff --git a/src/datasets/emnist.py b/src/datasets/emnist.py new file mode 100644 index 0000000000000000000000000000000000000000..790eca50a2ef5c772b06fb60e2bebd5584eaa14a --- /dev/null +++ b/src/datasets/emnist.py @@ -0,0 +1,74 @@ +import os + +import torch + +import torchvision +import torchvision.datasets as datasets + + +def rotate_img(img): + return torchvision.transforms.functional.rotate(img, -90) + + +def flip_img(img): + return torchvision.transforms.functional.hflip(img) + + +def emnist_preprocess(): + return torchvision.transforms.Compose( + [ + rotate_img, + flip_img, + ] + ) + + +class EMNIST: + def __init__( + self, + preprocess, + location, + batch_size=128, + num_workers=8, + ): + preprocess1 = emnist_preprocess() + preprocess = torchvision.transforms.Compose( + [ + preprocess, + preprocess1, + ] + ) + # if not 
os.path.exists(location): + # os.makedirs(location, exist_ok=True) + + self.train_dataset = datasets.EMNIST( + root=location, + download=True, + split="digits", + transform=preprocess, + train=True, + ) + + self.train_loader = torch.utils.data.DataLoader( + self.train_dataset, + batch_size=batch_size, + shuffle=True, + num_workers=num_workers, + ) + + self.test_dataset = datasets.EMNIST( + root=location, + download=True, + split="digits", + transform=preprocess, + train=False, + ) + + self.test_loader = torch.utils.data.DataLoader( + self.test_dataset, + batch_size=32, + shuffle=False, + num_workers=num_workers, + ) + + self.classnames = self.train_dataset.classes diff --git a/src/datasets/eurosat.py b/src/datasets/eurosat.py new file mode 100644 index 0000000000000000000000000000000000000000..8ef20d90f16ef8cf046895e89c525ae09d37d1f0 --- /dev/null +++ b/src/datasets/eurosat.py @@ -0,0 +1,75 @@ +import os +import torch +import torchvision.datasets as datasets +import re + +def pretify_classname(classname): + l = re.findall(r'[A-Z](?:[a-z]+|[A-Z]*(?=[A-Z]|$))', classname) + l = [i.lower() for i in l] + out = ' '.join(l) + if out.endswith('al'): + return out + ' area' + return out + +class EuroSATBase: + def __init__(self, + preprocess, + test_split, + location='~/datasets', + batch_size=32, + num_workers=16): + # Data loading code + traindir = os.path.join(location, 'EuroSAT_splits', 'train') + testdir = os.path.join(location, 'EuroSAT_splits', test_split) + + + self.train_dataset = datasets.ImageFolder(traindir, transform=preprocess) + self.train_loader = torch.utils.data.DataLoader( + self.train_dataset, + shuffle=True, + batch_size=batch_size, + num_workers=num_workers, + ) + + self.test_dataset = datasets.ImageFolder(testdir, transform=preprocess) + self.test_loader = torch.utils.data.DataLoader( + self.test_dataset, + batch_size=batch_size, + num_workers=num_workers + ) + idx_to_class = dict((v, k) + for k, v in self.train_dataset.class_to_idx.items()) + 
self.classnames = [idx_to_class[i].replace('_', ' ') for i in range(len(idx_to_class))] + self.classnames = [pretify_classname(c) for c in self.classnames] + ours_to_open_ai = { + 'annual crop': 'annual crop land', + 'forest': 'forest', + 'herbaceous vegetation': 'brushland or shrubland', + 'highway': 'highway or road', + 'industrial area': 'industrial buildings or commercial buildings', + 'pasture': 'pasture land', + 'permanent crop': 'permanent crop land', + 'residential area': 'residential buildings or homes or apartments', + 'river': 'river', + 'sea lake': 'lake or sea', + } + for i in range(len(self.classnames)): + self.classnames[i] = ours_to_open_ai[self.classnames[i]] + + +class EuroSAT(EuroSATBase): + def __init__(self, + preprocess, + location='~/datasets', + batch_size=32, + num_workers=16): + super().__init__(preprocess, 'test', location, batch_size, num_workers) + + +class EuroSATVal(EuroSATBase): + def __init__(self, + preprocess, + location='~/datasets', + batch_size=32, + num_workers=16): + super().__init__(preprocess, 'val', location, batch_size, num_workers) diff --git a/src/datasets/gtsrb.py b/src/datasets/gtsrb.py new file mode 100644 index 0000000000000000000000000000000000000000..b089e128489f1526c27b542f3bb8dcf5a1683d8e --- /dev/null +++ b/src/datasets/gtsrb.py @@ -0,0 +1,205 @@ +import csv +import os +import pathlib +from typing import Any, Callable, Dict, List, Optional, Tuple + +import numpy as np +import PIL +import torch +from torchvision.datasets.folder import make_dataset +from torchvision.datasets.utils import (download_and_extract_archive, + verify_str_arg) +from torchvision.datasets.vision import VisionDataset + +def find_classes(directory: str) -> Tuple[List[str], Dict[str, int]]: + """Finds the class folders in a dataset. + + See :class:`DatasetFolder` for details. 
+ """ + classes = sorted(entry.name for entry in os.scandir(directory) if entry.is_dir()) + if not classes: + raise FileNotFoundError(f"Couldn't find any class folder in {directory}.") + + class_to_idx = {cls_name: i for i, cls_name in enumerate(classes)} + return classes, class_to_idx + +class PyTorchGTSRB(VisionDataset): + """`German Traffic Sign Recognition Benchmark (GTSRB) `_ Dataset. + + Modified from https://pytorch.org/vision/main/_modules/torchvision/datasets/gtsrb.html#GTSRB. + + Args: + root (string): Root directory of the dataset. + split (string, optional): The dataset split, supports ``"train"`` (default) or ``"test"``. + transform (callable, optional): A function/transform that takes in a PIL image and returns a transformed + version. E.g., ``transforms.RandomCrop``. + target_transform (callable, optional): A function/transform that takes in the target and transforms it. + download (bool, optional): If True, downloads the dataset from the internet and + puts it in the root directory. If the dataset is already downloaded, it is not + downloaded again. + """ + + def __init__( + self, + root: str, + split: str = "train", + transform: Optional[Callable] = None, + target_transform: Optional[Callable] = None, + download: bool = False, + ) -> None: + + super().__init__(root, transform=transform, target_transform=target_transform) + + self._split = verify_str_arg(split, "split", ("train", "test")) + self._base_folder = pathlib.Path(root) / "gtsrb" + self._target_folder = ( + self._base_folder / "GTSRB" / ("Training" if self._split == "train" else "Final_Test/Images") + ) + + if download: + self.download() + + if not self._check_exists(): + raise RuntimeError("Dataset not found.
You can use download=True to download it") + + if self._split == "train": + _, class_to_idx = find_classes(str(self._target_folder)) + samples = make_dataset(str(self._target_folder), extensions=(".ppm",), class_to_idx=class_to_idx) + else: + with open(self._base_folder / "GT-final_test.csv") as csv_file: + samples = [ + (str(self._target_folder / row["Filename"]), int(row["ClassId"])) + for row in csv.DictReader(csv_file, delimiter=";", skipinitialspace=True) + ] + + self._samples = samples + self.transform = transform + self.target_transform = target_transform + + def __len__(self) -> int: + return len(self._samples) + + def __getitem__(self, index: int) -> Tuple[Any, Any]: + + path, target = self._samples[index] + sample = PIL.Image.open(path).convert("RGB") + + if self.transform is not None: + sample = self.transform(sample) + + if self.target_transform is not None: + target = self.target_transform(target) + + return sample, target + + + def _check_exists(self) -> bool: + return self._target_folder.is_dir() + + def download(self) -> None: + if self._check_exists(): + return + + base_url = "https://sid.erda.dk/public/archives/daaeac0d7ce1152aea9b61d9f1e19370/" + + if self._split == "train": + download_and_extract_archive( + f"{base_url}GTSRB-Training_fixed.zip", + download_root=str(self._base_folder), + md5="513f3c79a4c5141765e10e952eaa2478", + ) + else: + download_and_extract_archive( + f"{base_url}GTSRB_Final_Test_Images.zip", + download_root=str(self._base_folder), + md5="c7e4e6327067d32654124b0fe9e82185", + ) + download_and_extract_archive( + f"{base_url}GTSRB_Final_Test_GT.zip", + download_root=str(self._base_folder), + md5="fe31e9c9270bbcd7b84b7f21a9d9d9e5", + ) + + +class GTSRB: + def __init__(self, + preprocess, + location=os.path.expanduser('~/data'), + batch_size=128, + num_workers=16): + + # to fit with repo conventions for location + self.train_dataset = PyTorchGTSRB( + root=location, + download=True, + split='train', + transform=preprocess + ) + + 
self.train_loader = torch.utils.data.DataLoader( + self.train_dataset, + batch_size=batch_size, + shuffle=True, + num_workers=num_workers + ) + + self.test_dataset = PyTorchGTSRB( + root=location, + download=True, + split='test', + transform=preprocess + ) + + self.test_loader = torch.utils.data.DataLoader( + self.test_dataset, + batch_size=batch_size, + shuffle=False, + num_workers=num_workers + ) + + # from https://github.com/openai/CLIP/blob/e184f608c5d5e58165682f7c332c3a8b4c1545f2/data/prompts.md + self.classnames = [ + 'red and white circle 20 kph speed limit', + 'red and white circle 30 kph speed limit', + 'red and white circle 50 kph speed limit', + 'red and white circle 60 kph speed limit', + 'red and white circle 70 kph speed limit', + 'red and white circle 80 kph speed limit', + 'end / de-restriction of 80 kph speed limit', + 'red and white circle 100 kph speed limit', + 'red and white circle 120 kph speed limit', + 'red and white circle red car and black car no passing', + 'red and white circle red truck and black car no passing', + 'red and white triangle road intersection warning', + 'white and yellow diamond priority road', + 'red and white upside down triangle yield right-of-way', + 'stop', + 'empty red and white circle', + 'red and white circle no truck entry', + 'red circle with white horizonal stripe no entry', + 'red and white triangle with exclamation mark warning', + 'red and white triangle with black left curve approaching warning', + 'red and white triangle with black right curve approaching warning', + 'red and white triangle with black double curve approaching warning', + 'red and white triangle rough / bumpy road warning', + 'red and white triangle car skidding / slipping warning', + 'red and white triangle with merging / narrow lanes warning', + 'red and white triangle with person digging / construction / road work warning', + 'red and white triangle with traffic light approaching warning', + 'red and white triangle with person walking 
warning', + 'red and white triangle with child and person walking warning', + 'red and white triangle with bicyle warning', + 'red and white triangle with snowflake / ice warning', + 'red and white triangle with deer warning', + 'white circle with gray strike bar no speed limit', + 'blue circle with white right turn arrow mandatory', + 'blue circle with white left turn arrow mandatory', + 'blue circle with white forward arrow mandatory', + 'blue circle with white forward or right turn arrow mandatory', + 'blue circle with white forward or left turn arrow mandatory', + 'blue circle with white keep right arrow mandatory', + 'blue circle with white keep left arrow mandatory', + 'blue circle with white arrows indicating a traffic circle', + 'white circle with gray strike bar indicating no passing for cars has ended', + 'white circle with gray strike bar indicating no passing for trucks has ended', + ] diff --git a/src/datasets/imagenet.py b/src/datasets/imagenet.py new file mode 100644 index 0000000000000000000000000000000000000000..e9d5a5613047c134ff81c33393d3e88cab3c471b --- /dev/null +++ b/src/datasets/imagenet.py @@ -0,0 +1,253 @@ +import os +import torch + +from .common import ImageFolderWithPaths, SubsetSampler +import numpy as np + + +imagenet_classnames = [ + "tench", "goldfish", "great white shark", "tiger shark", "hammerhead shark", "electric ray", + "stingray", "rooster", "hen", "ostrich", "brambling", "goldfinch", "house finch", "junco", + "indigo bunting", "American robin", "bulbul", "jay", "magpie", "chickadee", "American dipper", + "kite (bird of prey)", "bald eagle", "vulture", "great grey owl", "fire salamander", + "smooth newt", "newt", "spotted salamander", "axolotl", "American bullfrog", "tree frog", + "tailed frog", "loggerhead sea turtle", "leatherback sea turtle", "mud turtle", "terrapin", + "box turtle", "banded gecko", "green iguana", "Carolina anole", + "desert grassland whiptail lizard", "agama", "frilled-necked lizard", "alligator lizard", + 
"Gila monster", "European green lizard", "chameleon", "Komodo dragon", "Nile crocodile", + "American alligator", "triceratops", "worm snake", "ring-necked snake", + "eastern hog-nosed snake", "smooth green snake", "kingsnake", "garter snake", "water snake", + "vine snake", "night snake", "boa constrictor", "African rock python", "Indian cobra", + "green mamba", "sea snake", "Saharan horned viper", "eastern diamondback rattlesnake", + "sidewinder rattlesnake", "trilobite", "harvestman", "scorpion", "yellow garden spider", + "barn spider", "European garden spider", "southern black widow", "tarantula", "wolf spider", + "tick", "centipede", "black grouse", "ptarmigan", "ruffed grouse", "prairie grouse", "peafowl", + "quail", "partridge", "african grey parrot", "macaw", "sulphur-crested cockatoo", "lorikeet", + "coucal", "bee eater", "hornbill", "hummingbird", "jacamar", "toucan", "duck", + "red-breasted merganser", "goose", "black swan", "tusker", "echidna", "platypus", "wallaby", + "koala", "wombat", "jellyfish", "sea anemone", "brain coral", "flatworm", "nematode", "conch", + "snail", "slug", "sea slug", "chiton", "chambered nautilus", "Dungeness crab", "rock crab", + "fiddler crab", "red king crab", "American lobster", "spiny lobster", "crayfish", "hermit crab", + "isopod", "white stork", "black stork", "spoonbill", "flamingo", "little blue heron", + "great egret", "bittern bird", "crane bird", "limpkin", "common gallinule", "American coot", + "bustard", "ruddy turnstone", "dunlin", "common redshank", "dowitcher", "oystercatcher", + "pelican", "king penguin", "albatross", "grey whale", "killer whale", "dugong", "sea lion", + "Chihuahua", "Japanese Chin", "Maltese", "Pekingese", "Shih Tzu", "King Charles Spaniel", + "Papillon", "toy terrier", "Rhodesian Ridgeback", "Afghan Hound", "Basset Hound", "Beagle", + "Bloodhound", "Bluetick Coonhound", "Black and Tan Coonhound", "Treeing Walker Coonhound", + "English foxhound", "Redbone Coonhound", "borzoi", "Irish 
Wolfhound", "Italian Greyhound", + "Whippet", "Ibizan Hound", "Norwegian Elkhound", "Otterhound", "Saluki", "Scottish Deerhound", + "Weimaraner", "Staffordshire Bull Terrier", "American Staffordshire Terrier", + "Bedlington Terrier", "Border Terrier", "Kerry Blue Terrier", "Irish Terrier", + "Norfolk Terrier", "Norwich Terrier", "Yorkshire Terrier", "Wire Fox Terrier", + "Lakeland Terrier", "Sealyham Terrier", "Airedale Terrier", "Cairn Terrier", + "Australian Terrier", "Dandie Dinmont Terrier", "Boston Terrier", "Miniature Schnauzer", + "Giant Schnauzer", "Standard Schnauzer", "Scottish Terrier", "Tibetan Terrier", + "Australian Silky Terrier", "Soft-coated Wheaten Terrier", "West Highland White Terrier", + "Lhasa Apso", "Flat-Coated Retriever", "Curly-coated Retriever", "Golden Retriever", + "Labrador Retriever", "Chesapeake Bay Retriever", "German Shorthaired Pointer", "Vizsla", + "English Setter", "Irish Setter", "Gordon Setter", "Brittany dog", "Clumber Spaniel", + "English Springer Spaniel", "Welsh Springer Spaniel", "Cocker Spaniel", "Sussex Spaniel", + "Irish Water Spaniel", "Kuvasz", "Schipperke", "Groenendael dog", "Malinois", "Briard", + "Australian Kelpie", "Komondor", "Old English Sheepdog", "Shetland Sheepdog", "collie", + "Border Collie", "Bouvier des Flandres dog", "Rottweiler", "German Shepherd Dog", "Dobermann", + "Miniature Pinscher", "Greater Swiss Mountain Dog", "Bernese Mountain Dog", + "Appenzeller Sennenhund", "Entlebucher Sennenhund", "Boxer", "Bullmastiff", "Tibetan Mastiff", + "French Bulldog", "Great Dane", "St. 
Bernard", "husky", "Alaskan Malamute", "Siberian Husky", + "Dalmatian", "Affenpinscher", "Basenji", "pug", "Leonberger", "Newfoundland dog", + "Great Pyrenees dog", "Samoyed", "Pomeranian", "Chow Chow", "Keeshond", "brussels griffon", + "Pembroke Welsh Corgi", "Cardigan Welsh Corgi", "Toy Poodle", "Miniature Poodle", + "Standard Poodle", "Mexican hairless dog (xoloitzcuintli)", "grey wolf", "Alaskan tundra wolf", + "red wolf or maned wolf", "coyote", "dingo", "dhole", "African wild dog", "hyena", "red fox", + "kit fox", "Arctic fox", "grey fox", "tabby cat", "tiger cat", "Persian cat", "Siamese cat", + "Egyptian Mau", "cougar", "lynx", "leopard", "snow leopard", "jaguar", "lion", "tiger", + "cheetah", "brown bear", "American black bear", "polar bear", "sloth bear", "mongoose", + "meerkat", "tiger beetle", "ladybug", "ground beetle", "longhorn beetle", "leaf beetle", + "dung beetle", "rhinoceros beetle", "weevil", "fly", "bee", "ant", "grasshopper", + "cricket insect", "stick insect", "cockroach", "praying mantis", "cicada", "leafhopper", + "lacewing", "dragonfly", "damselfly", "red admiral butterfly", "ringlet butterfly", + "monarch butterfly", "small white butterfly", "sulphur butterfly", "gossamer-winged butterfly", + "starfish", "sea urchin", "sea cucumber", "cottontail rabbit", "hare", "Angora rabbit", + "hamster", "porcupine", "fox squirrel", "marmot", "beaver", "guinea pig", "common sorrel horse", + "zebra", "pig", "wild boar", "warthog", "hippopotamus", "ox", "water buffalo", "bison", + "ram (adult male sheep)", "bighorn sheep", "Alpine ibex", "hartebeest", "impala (antelope)", + "gazelle", "arabian camel", "llama", "weasel", "mink", "European polecat", + "black-footed ferret", "otter", "skunk", "badger", "armadillo", "three-toed sloth", "orangutan", + "gorilla", "chimpanzee", "gibbon", "siamang", "guenon", "patas monkey", "baboon", "macaque", + "langur", "black-and-white colobus", "proboscis monkey", "marmoset", "white-headed capuchin", + "howler monkey", 
"titi monkey", "Geoffroy's spider monkey", "common squirrel monkey", + "ring-tailed lemur", "indri", "Asian elephant", "African bush elephant", "red panda", + "giant panda", "snoek fish", "eel", "silver salmon", "rock beauty fish", "clownfish", + "sturgeon", "gar fish", "lionfish", "pufferfish", "abacus", "abaya", "academic gown", + "accordion", "acoustic guitar", "aircraft carrier", "airliner", "airship", "altar", "ambulance", + "amphibious vehicle", "analog clock", "apiary", "apron", "trash can", "assault rifle", + "backpack", "bakery", "balance beam", "balloon", "ballpoint pen", "Band-Aid", "banjo", + "baluster / handrail", "barbell", "barber chair", "barbershop", "barn", "barometer", "barrel", + "wheelbarrow", "baseball", "basketball", "bassinet", "bassoon", "swimming cap", "bath towel", + "bathtub", "station wagon", "lighthouse", "beaker", "military hat (bearskin or shako)", + "beer bottle", "beer glass", "bell tower", "baby bib", "tandem bicycle", "bikini", + "ring binder", "binoculars", "birdhouse", "boathouse", "bobsleigh", "bolo tie", "poke bonnet", + "bookcase", "bookstore", "bottle cap", "hunting bow", "bow tie", "brass memorial plaque", "bra", + "breakwater", "breastplate", "broom", "bucket", "buckle", "bulletproof vest", + "high-speed train", "butcher shop", "taxicab", "cauldron", "candle", "cannon", "canoe", + "can opener", "cardigan", "car mirror", "carousel", "tool kit", "cardboard box / carton", + "car wheel", "automated teller machine", "cassette", "cassette player", "castle", "catamaran", + "CD player", "cello", "mobile phone", "chain", "chain-link fence", "chain mail", "chainsaw", + "storage chest", "chiffonier", "bell or wind chime", "china cabinet", "Christmas stocking", + "church", "movie theater", "cleaver", "cliff dwelling", "cloak", "clogs", "cocktail shaker", + "coffee mug", "coffeemaker", "spiral or coil", "combination lock", "computer keyboard", + "candy store", "container ship", "convertible", "corkscrew", "cornet", "cowboy boot", + 
"cowboy hat", "cradle", "construction crane", "crash helmet", "crate", "infant bed", + "Crock Pot", "croquet ball", "crutch", "cuirass", "dam", "desk", "desktop computer", + "rotary dial telephone", "diaper", "digital clock", "digital watch", "dining table", + "dishcloth", "dishwasher", "disc brake", "dock", "dog sled", "dome", "doormat", "drilling rig", + "drum", "drumstick", "dumbbell", "Dutch oven", "electric fan", "electric guitar", + "electric locomotive", "entertainment center", "envelope", "espresso machine", "face powder", + "feather boa", "filing cabinet", "fireboat", "fire truck", "fire screen", "flagpole", "flute", + "folding chair", "football helmet", "forklift", "fountain", "fountain pen", "four-poster bed", + "freight car", "French horn", "frying pan", "fur coat", "garbage truck", + "gas mask or respirator", "gas pump", "goblet", "go-kart", "golf ball", "golf cart", "gondola", + "gong", "gown", "grand piano", "greenhouse", "radiator grille", "grocery store", "guillotine", + "hair clip", "hair spray", "half-track", "hammer", "hamper", "hair dryer", "hand-held computer", + "handkerchief", "hard disk drive", "harmonica", "harp", "combine harvester", "hatchet", + "holster", "home theater", "honeycomb", "hook", "hoop skirt", "gymnastic horizontal bar", + "horse-drawn vehicle", "hourglass", "iPod", "clothes iron", "carved pumpkin", "jeans", "jeep", + "T-shirt", "jigsaw puzzle", "rickshaw", "joystick", "kimono", "knee pad", "knot", "lab coat", + "ladle", "lampshade", "laptop computer", "lawn mower", "lens cap", "letter opener", "library", + "lifeboat", "lighter", "limousine", "ocean liner", "lipstick", "slip-on shoe", "lotion", + "music speaker", "loupe magnifying glass", "sawmill", "magnetic compass", "messenger bag", + "mailbox", "tights", "one-piece bathing suit", "manhole cover", "maraca", "marimba", "mask", + "matchstick", "maypole", "maze", "measuring cup", "medicine cabinet", "megalith", "microphone", + "microwave oven", "military uniform", "milk 
can", "minibus", "miniskirt", "minivan", "missile", + "mitten", "mixing bowl", "mobile home", "ford model t", "modem", "monastery", "monitor", + "moped", "mortar and pestle", "graduation cap", "mosque", "mosquito net", "vespa", + "mountain bike", "tent", "computer mouse", "mousetrap", "moving van", "muzzle", "metal nail", + "neck brace", "necklace", "baby pacifier", "notebook computer", "obelisk", "oboe", "ocarina", + "odometer", "oil filter", "pipe organ", "oscilloscope", "overskirt", "bullock cart", + "oxygen mask", "product packet / packaging", "paddle", "paddle wheel", "padlock", "paintbrush", + "pajamas", "palace", "pan flute", "paper towel", "parachute", "parallel bars", "park bench", + "parking meter", "railroad car", "patio", "payphone", "pedestal", "pencil case", + "pencil sharpener", "perfume", "Petri dish", "photocopier", "plectrum", "Pickelhaube", + "picket fence", "pickup truck", "pier", "piggy bank", "pill bottle", "pillow", "ping-pong ball", + "pinwheel", "pirate ship", "drink pitcher", "block plane", "planetarium", "plastic bag", + "plate rack", "farm plow", "plunger", "Polaroid camera", "pole", "police van", "poncho", + "pool table", "soda bottle", "plant pot", "potter's wheel", "power drill", "prayer rug", + "printer", "prison", "missile", "projector", "hockey puck", "punching bag", "purse", "quill", + "quilt", "race car", "racket", "radiator", "radio", "radio telescope", "rain barrel", + "recreational vehicle", "fishing casting reel", "reflex camera", "refrigerator", + "remote control", "restaurant", "revolver", "rifle", "rocking chair", "rotisserie", "eraser", + "rugby ball", "ruler measuring stick", "sneaker", "safe", "safety pin", "salt shaker", "sandal", + "sarong", "saxophone", "scabbard", "weighing scale", "school bus", "schooner", "scoreboard", + "CRT monitor", "screw", "screwdriver", "seat belt", "sewing machine", "shield", "shoe store", + "shoji screen / room divider", "shopping basket", "shopping cart", "shovel", "shower cap", + "shower 
curtain", "ski", "balaclava ski mask", "sleeping bag", "slide rule", "sliding door", + "slot machine", "snorkel", "snowmobile", "snowplow", "soap dispenser", "soccer ball", "sock", + "solar thermal collector", "sombrero", "soup bowl", "keyboard space bar", "space heater", + "space shuttle", "spatula", "motorboat", "spider web", "spindle", "sports car", "spotlight", + "stage", "steam locomotive", "through arch bridge", "steel drum", "stethoscope", "scarf", + "stone wall", "stopwatch", "stove", "strainer", "tram", "stretcher", "couch", "stupa", + "submarine", "suit", "sundial", "sunglasses", "sunglasses", "sunscreen", "suspension bridge", + "mop", "sweatshirt", "swim trunks / shorts", "swing", "electrical switch", "syringe", + "table lamp", "tank", "tape player", "teapot", "teddy bear", "television", "tennis ball", + "thatched roof", "front curtain", "thimble", "threshing machine", "throne", "tile roof", + "toaster", "tobacco shop", "toilet seat", "torch", "totem pole", "tow truck", "toy store", + "tractor", "semi-trailer truck", "tray", "trench coat", "tricycle", "trimaran", "tripod", + "triumphal arch", "trolleybus", "trombone", "hot tub", "turnstile", "typewriter keyboard", + "umbrella", "unicycle", "upright piano", "vacuum cleaner", "vase", "vaulted or arched ceiling", + "velvet fabric", "vending machine", "vestment", "viaduct", "violin", "volleyball", + "waffle iron", "wall clock", "wallet", "wardrobe", "military aircraft", "sink", + "washing machine", "water bottle", "water jug", "water tower", "whiskey jug", "whistle", + "hair wig", "window screen", "window shade", "Windsor tie", "wine bottle", "airplane wing", + "wok", "wooden spoon", "wool", "split-rail fence", "shipwreck", "sailboat", "yurt", "website", + "comic book", "crossword", "traffic or street sign", "traffic light", "dust jacket", "menu", + "plate", "guacamole", "consomme", "hot pot", "trifle", "ice cream", "popsicle", "baguette", + "bagel", "pretzel", "cheeseburger", "hot dog", "mashed potatoes", 
"cabbage", "broccoli", + "cauliflower", "zucchini", "spaghetti squash", "acorn squash", "butternut squash", "cucumber", + "artichoke", "bell pepper", "cardoon", "mushroom", "Granny Smith apple", "strawberry", "orange", + "lemon", "fig", "pineapple", "banana", "jackfruit", "cherimoya (custard apple)", "pomegranate", + "hay", "carbonara", "chocolate syrup", "dough", "meatloaf", "pizza", "pot pie", "burrito", + "red wine", "espresso", "tea cup", "eggnog", "mountain", "bubble", "cliff", "coral reef", + "geyser", "lakeshore", "promontory", "sandbar", "beach", "valley", "volcano", "baseball player", + "bridegroom", "scuba diver", "rapeseed", "daisy", "yellow lady's slipper", "corn", "acorn", + "rose hip", "horse chestnut seed", "coral fungus", "agaric", "gyromitra", "stinkhorn mushroom", + "earth star fungus", "hen of the woods mushroom", "bolete", "corn cob", "toilet paper" +] + +class ImageNet: + def __init__(self, + preprocess, + location='~/data', + batch_size=32, + num_workers=32): + self.preprocess = preprocess + self.location = '/path/ImageNet2012/' # TODO + self.batch_size = batch_size + self.num_workers = num_workers + self.classnames = imagenet_classnames + + self.populate_train() + self.populate_test() + + def populate_train(self): + traindir = os.path.join(self.location, 'train') + self.train_dataset = ImageFolderWithPaths( + traindir, + transform=self.preprocess) + sampler = self.get_train_sampler() + kwargs = {'shuffle' : True} if sampler is None else {} + self.train_loader = torch.utils.data.DataLoader( + self.train_dataset, + sampler=sampler, + batch_size=self.batch_size, + num_workers=self.num_workers, + **kwargs, + ) + + def populate_test(self): + self.test_dataset = self.get_test_dataset() + self.test_loader = torch.utils.data.DataLoader( + self.test_dataset, + batch_size=self.batch_size, + num_workers=self.num_workers, + sampler=self.get_test_sampler() + ) + + def get_test_path(self): + test_path = os.path.join(self.location, 'val_dir') + if not 
os.path.exists(test_path): + test_path = os.path.join(self.location,'val') + return test_path + + def get_train_sampler(self): + return None + + def get_test_sampler(self): + return None + + def get_test_dataset(self): + return ImageFolderWithPaths(self.get_test_path(), transform=self.preprocess) + + def name(self): + return 'imagenet' + +class ImageNetTrain(ImageNet): + + def get_test_dataset(self): + pass + +class ImageNetK(ImageNet): + + def get_train_sampler(self): + idxs = np.zeros(len(self.train_dataset.targets)) + target_array = np.array(self.train_dataset.targets) + for c in range(1000): + m = target_array == c + n = len(idxs[m]) + arr = np.zeros(n) + arr[:self.k()] = 1 + np.random.shuffle(arr) + idxs[m] = arr + + idxs = idxs.astype('int') + sampler = SubsetSampler(np.where(idxs)[0]) + return sampler \ No newline at end of file diff --git a/src/datasets/kmnist.py b/src/datasets/kmnist.py new file mode 100644 index 0000000000000000000000000000000000000000..243359d7a6cc7aec1bdd99126995daf3d4c4b414 --- /dev/null +++ b/src/datasets/kmnist.py @@ -0,0 +1,39 @@ +import os + +import torch +import torchvision.datasets as datasets + + +class KMNIST: + def __init__( + self, + preprocess, + location=os.path.expanduser("~/data"), + batch_size=128, + num_workers=6, + ): + + location = os.path.join(location, "KMNIST") + self.train_dataset = datasets.KMNIST( + root=location, download=True, train=True, transform=preprocess + ) + + self.train_loader = torch.utils.data.DataLoader( + self.train_dataset, + batch_size=batch_size, + shuffle=True, + num_workers=num_workers, + ) + + self.test_dataset = datasets.KMNIST( + root=location, download=True, train=False, transform=preprocess + ) + + self.test_loader = torch.utils.data.DataLoader( + self.test_dataset, + batch_size=batch_size, + shuffle=False, + num_workers=num_workers, + ) + + self.classnames = self.train_dataset.classes diff --git a/src/datasets/mnist.py b/src/datasets/mnist.py new file mode 100644 index 
0000000000000000000000000000000000000000..dd4819351db8edf370311851026eba40b698ad3e --- /dev/null +++ b/src/datasets/mnist.py @@ -0,0 +1,41 @@ +import os +import torch +import torchvision.datasets as datasets + +class MNIST: + def __init__(self, + preprocess, + location=os.path.expanduser('~/data'), + batch_size=128, + num_workers=16): + + + self.train_dataset = datasets.MNIST( + root=location, + download=True, + train=True, + transform=preprocess + ) + + self.train_loader = torch.utils.data.DataLoader( + self.train_dataset, + batch_size=batch_size, + shuffle=True, + num_workers=num_workers + ) + + self.test_dataset = datasets.MNIST( + root=location, + download=True, + train=False, + transform=preprocess + ) + + self.test_loader = torch.utils.data.DataLoader( + self.test_dataset, + batch_size=batch_size, + shuffle=False, + num_workers=num_workers + ) + + self.classnames = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9'] \ No newline at end of file diff --git a/src/datasets/oxfordpets.py b/src/datasets/oxfordpets.py new file mode 100644 index 0000000000000000000000000000000000000000..7285b424c6f970df692bd7eb3fd9656d423e7cb0 --- /dev/null +++ b/src/datasets/oxfordpets.py @@ -0,0 +1,38 @@ +import os +import torch +import torchvision.datasets as datasets + + +class OxfordIIITPet: + def __init__( + self, + preprocess, + location=os.path.expanduser("~/data"), + batch_size=128, + num_workers=6, + ): + + location = os.path.join(location, "OxfordIIITPet") + self.train_dataset = datasets.OxfordIIITPet( + root=location, download=True, split="trainval", transform=preprocess + ) + + self.train_loader = torch.utils.data.DataLoader( + self.train_dataset, + batch_size=batch_size, + shuffle=True, + num_workers=num_workers, + ) + + self.test_dataset = datasets.OxfordIIITPet( + root=location, download=True, split="test", transform=preprocess + ) + + self.test_loader = torch.utils.data.DataLoader( + self.test_dataset, + batch_size=batch_size, + shuffle=False, + 
num_workers=num_workers, + ) + + self.classnames = self.train_dataset.classes diff --git a/src/datasets/registry.py b/src/datasets/registry.py new file mode 100644 index 0000000000000000000000000000000000000000..3f732c998918194b21926aa177ffb19502d874a3 --- /dev/null +++ b/src/datasets/registry.py @@ -0,0 +1,103 @@ +import sys +import inspect +import random +import torch +import copy + +from torch.utils.data.dataset import random_split + +from src.datasets.cars import Cars +from src.datasets.cifar10 import CIFAR10 +from src.datasets.cifar100 import CIFAR100 +from src.datasets.dtd import DTD +from src.datasets.eurosat import EuroSAT, EuroSATVal +from src.datasets.gtsrb import GTSRB +from src.datasets.imagenet import ImageNet +from src.datasets.mnist import MNIST +from src.datasets.resisc45 import RESISC45 +from src.datasets.stl10 import STL10 +from src.datasets.svhn import SVHN +from src.datasets.sun397 import SUN397 +from src.datasets.emnist import EMNIST +from src.datasets.kmnist import KMNIST +from src.datasets.oxfordpets import OxfordIIITPet + +registry = { + name: obj for name, obj in inspect.getmembers(sys.modules[__name__], inspect.isclass) +} + + +class GenericDataset(object): + def __init__(self): + self.train_dataset = None + self.train_loader = None + self.test_dataset = None + self.test_loader = None + self.classnames = None + + +def split_train_into_train_val(dataset, new_dataset_class_name, batch_size, num_workers, val_fraction, max_val_samples=None, seed=0): + assert val_fraction > 0. and val_fraction < 1. 
+ total_size = len(dataset.train_dataset) + val_size = int(total_size * val_fraction) + if max_val_samples is not None: + val_size = min(val_size, max_val_samples) + train_size = total_size - val_size + + assert val_size > 0 + assert train_size > 0 + + lengths = [train_size, val_size] + + trainset, valset = random_split( + dataset.train_dataset, + lengths, + generator=torch.Generator().manual_seed(seed) + ) + if new_dataset_class_name == 'MNISTVal': + assert trainset.indices[0] == 36044 + + + new_dataset = None + + new_dataset_class = type(new_dataset_class_name, (GenericDataset, ), {}) + new_dataset = new_dataset_class() + + new_dataset.train_dataset = trainset + new_dataset.train_loader = torch.utils.data.DataLoader( + new_dataset.train_dataset, + shuffle=True, + batch_size=batch_size, + num_workers=num_workers, + ) + + new_dataset.test_dataset = valset + new_dataset.test_loader = torch.utils.data.DataLoader( + new_dataset.test_dataset, + batch_size=batch_size, + num_workers=num_workers + ) + + new_dataset.classnames = copy.copy(dataset.classnames) + + return new_dataset + + +def get_dataset(dataset_name, preprocess, location, batch_size=128, num_workers=16, val_fraction=0.1, max_val_samples=5000): + if dataset_name.endswith('Val'): + # Handle val splits + if dataset_name in registry: + dataset_class = registry[dataset_name] + else: + base_dataset_name = dataset_name.split('Val')[0] + base_dataset = get_dataset(base_dataset_name, preprocess, location, batch_size, num_workers) + dataset = split_train_into_train_val( + base_dataset, dataset_name, batch_size, num_workers, val_fraction, max_val_samples) + return dataset + else: + assert dataset_name in registry, f'Unsupported dataset: {dataset_name}. 
Supported datasets: {list(registry.keys())}' + dataset_class = registry[dataset_name] + dataset = dataset_class( + preprocess, location=location, batch_size=batch_size, num_workers=num_workers + ) + return dataset diff --git a/src/datasets/resisc45.py b/src/datasets/resisc45.py new file mode 100644 index 0000000000000000000000000000000000000000..056122dbe5d43751dbdb4ea82048d47b56c6fe38 --- /dev/null +++ b/src/datasets/resisc45.py @@ -0,0 +1,304 @@ +import os +import torch + +import abc +import os +from typing import Any, Callable, Dict, Optional, Tuple + +import numpy as np +import torch +from torch import Tensor +from torch.utils.data import Dataset +from torchvision.datasets import ImageFolder +from torchvision.datasets.folder import default_loader as pil_loader + + +# modified from: https://github.com/microsoft/torchgeo +class VisionDataset(Dataset[Dict[str, Any]], abc.ABC): + """Abstract base class for datasets lacking geospatial information. + This base class is designed for datasets with pre-defined image chips. + """ + + @abc.abstractmethod + def __getitem__(self, index: int) -> Dict[str, Any]: + """Return the item at the given index of the dataset. + Args: + index: index to return + Returns: + data and labels at that index + Raises: + IndexError: if index is out of range of the dataset + """ + + @abc.abstractmethod + def __len__(self) -> int: + """Return the length of the dataset. + Returns: + length of the dataset + """ + + def __str__(self) -> str: + """Return the informal string representation of the object. + Returns: + informal string representation + """ + return f"""\ +{self.__class__.__name__} Dataset + type: VisionDataset + size: {len(self)}""" + + +class VisionClassificationDataset(VisionDataset, ImageFolder): + """Abstract base class for classification datasets lacking geospatial information. + This base class is designed for datasets with pre-defined image chips which + are separated into separate folders per class. 
+ """ + + def __init__( + self, + root: str, + transforms: Optional[Callable[[Dict[str, Tensor]], Dict[str, Tensor]]] = None, + loader: Optional[Callable[[str], Any]] = pil_loader, + is_valid_file: Optional[Callable[[str], bool]] = None, + ) -> None: + """Initialize a new VisionClassificationDataset instance. + Args: + root: root directory where dataset can be found + transforms: a function/transform that takes input sample and its target as + entry and returns a transformed version + loader: a callable function which takes as input a path to an image and + returns a PIL Image or numpy array + is_valid_file: A function that takes the path of an Image file and checks if + the file is a valid file + """ + # When transform & target_transform are None, ImageFolder.__getitem__(index) + # returns a PIL.Image and int for image and label, respectively + super().__init__( + root=root, + transform=None, + target_transform=None, + loader=loader, + is_valid_file=is_valid_file, + ) + + # Must be set after calling super().__init__() + self.transforms = transforms + + def __getitem__(self, index: int) -> Dict[str, Tensor]: + """Return the item at the given index of the dataset. + Args: + index: index to return + Returns: + data and label at that index + """ + image, label = self._load_image(index) + + if self.transforms is not None: + return self.transforms(image), label + + return image, label + + def __len__(self) -> int: + """Return the number of data points in the dataset. + Returns: + length of the dataset + """ + return len(self.imgs) + + def _load_image(self, index: int) -> Tuple[Tensor, Tensor]: + """Load a single image and its class label. + Args: + index: index to return + Returns: + the image + the image class label + """ + img, label = ImageFolder.__getitem__(self, index) + label = torch.tensor(label) + return img, label + + +class RESISC45Dataset(VisionClassificationDataset): + """RESISC45 dataset. 
+ The `RESISC45 `_ + dataset is a dataset for remote sensing image scene classification. + Dataset features: + * 31,500 images with 0.2-30 m per pixel resolution (256x256 px) + * three spectral bands - RGB + * 45 scene classes, 700 images per class + * images extracted from Google Earth from over 100 countries + * image conditions with high variability (resolution, weather, illumination) + Dataset format: + * images are three-channel jpgs + Dataset classes: + 0. airplane + 1. airport + 2. baseball_diamond + 3. basketball_court + 4. beach + 5. bridge + 6. chaparral + 7. church + 8. circular_farmland + 9. cloud + 10. commercial_area + 11. dense_residential + 12. desert + 13. forest + 14. freeway + 15. golf_course + 16. ground_track_field + 17. harbor + 18. industrial_area + 19. intersection + 20. island + 21. lake + 22. meadow + 23. medium_residential + 24. mobile_home_park + 25. mountain + 26. overpass + 27. palace + 28. parking_lot + 29. railway + 30. railway_station + 31. rectangular_farmland + 32. river + 33. roundabout + 34. runway + 35. sea_ice + 36. ship + 37. snowberg + 38. sparse_residential + 39. stadium + 40. storage_tank + 41. tennis_court + 42. terrace + 43. thermal_power_station + 44. 
wetland + This dataset uses the train/val/test splits defined in the "In-domain representation + learning for remote sensing" paper: + * https://arxiv.org/abs/1911.06721 + If you use this dataset in your research, please cite the following paper: + * https://doi.org/10.1109/jproc.2017.2675998 + """ + + # url = "https://drive.google.com/file/d/1DnPSU5nVSN7xv95bpZ3XQ0JhKXZOKgIv" + # md5 = "d824acb73957502b00efd559fc6cfbbb" + # filename = "NWPU-RESISC45.rar" + directory = "resisc45/NWPU-RESISC45" + + splits = ["train", "val", "test"] + split_urls = { + "train": "https://storage.googleapis.com/remote_sensing_representations/resisc45-train.txt", # noqa: E501 + "val": "https://storage.googleapis.com/remote_sensing_representations/resisc45-val.txt", # noqa: E501 + "test": "https://storage.googleapis.com/remote_sensing_representations/resisc45-test.txt", # noqa: E501 + } + split_md5s = { + "train": "b5a4c05a37de15e4ca886696a85c403e", + "val": "a0770cee4c5ca20b8c32bbd61e114805", + "test": "3dda9e4988b47eb1de9f07993653eb08", + } + classes = [ + "airplane", + "airport", + "baseball_diamond", + "basketball_court", + "beach", + "bridge", + "chaparral", + "church", + "circular_farmland", + "cloud", + "commercial_area", + "dense_residential", + "desert", + "forest", + "freeway", + "golf_course", + "ground_track_field", + "harbor", + "industrial_area", + "intersection", + "island", + "lake", + "meadow", + "medium_residential", + "mobile_home_park", + "mountain", + "overpass", + "palace", + "parking_lot", + "railway", + "railway_station", + "rectangular_farmland", + "river", + "roundabout", + "runway", + "sea_ice", + "ship", + "snowberg", + "sparse_residential", + "stadium", + "storage_tank", + "tennis_court", + "terrace", + "thermal_power_station", + "wetland", + ] + + def __init__( + self, + root: str = "data", + split: str = "train", + transforms: Optional[Callable[[Dict[str, Tensor]], Dict[str, Tensor]]] = None, + ) -> None: + """Initialize a new RESISC45 dataset instance. 
+ Args: + root: root directory where dataset can be found + split: one of "train", "val", or "test" + transforms: a function/transform that takes input sample and its target as + entry and returns a transformed version + """ + assert split in self.splits + self.root = root + + valid_fns = set() + with open(os.path.join(self.root, "resisc45", f"resisc45-{split}.txt")) as f: + for fn in f: + valid_fns.add(fn.strip()) + is_in_split: Callable[[str], bool] = lambda x: os.path.basename( + x) in valid_fns + + super().__init__( + root=os.path.join(root, self.directory), + transforms=transforms, + is_valid_file=is_in_split, + ) + + + +class RESISC45: + def __init__(self, + preprocess, + location=os.path.expanduser('~/data'), + batch_size=32, + num_workers=16): + + self.train_dataset = RESISC45Dataset(root=location, split='train', transforms=preprocess) + self.train_loader = torch.utils.data.DataLoader( + self.train_dataset, + shuffle=True, + batch_size=batch_size, + num_workers=num_workers, + ) + + self.test_dataset = RESISC45Dataset(root=location, split='test', transforms=preprocess) + self.test_loader = torch.utils.data.DataLoader( + self.test_dataset, + batch_size=batch_size, + num_workers=num_workers + ) + + # class names have _ so split on this for better zero-shot head + self.classnames = [' '.join(c.split('_')) for c in RESISC45Dataset.classes] diff --git a/src/datasets/stl10.py b/src/datasets/stl10.py new file mode 100644 index 0000000000000000000000000000000000000000..0c7237f014e6b7983d46e2ab9866fa523b6a0ef4 --- /dev/null +++ b/src/datasets/stl10.py @@ -0,0 +1,41 @@ +import os +import torch +import torchvision.datasets as datasets + +class STL10: + def __init__(self, + preprocess, + location=os.path.expanduser('~/data'), + batch_size=128, + num_workers=16): + + location = os.path.join(location, 'stl10') + self.train_dataset = datasets.STL10( + root=location, + download=True, + split='train', + transform=preprocess + ) + + self.train_loader = 
torch.utils.data.DataLoader( + self.train_dataset, + batch_size=batch_size, + shuffle=True, + num_workers=num_workers + ) + + self.test_dataset = datasets.STL10( + root=location, + download=True, + split='test', + transform=preprocess + ) + + self.test_loader = torch.utils.data.DataLoader( + self.test_dataset, + batch_size=batch_size, + shuffle=False, + num_workers=num_workers + ) + + self.classnames = self.train_dataset.classes \ No newline at end of file diff --git a/src/datasets/sun397.py b/src/datasets/sun397.py new file mode 100644 index 0000000000000000000000000000000000000000..684c648146ed78c838a8b155f4ffb7278e3433b4 --- /dev/null +++ b/src/datasets/sun397.py @@ -0,0 +1,32 @@ +import os +import torch +import torchvision.datasets as datasets + +class SUN397: + def __init__(self, + preprocess, + location=os.path.expanduser('~/data'), + batch_size=32, + num_workers=16): + # Data loading code + traindir = os.path.join(location, 'sun397', 'train') + valdir = os.path.join(location, 'sun397', 'val') + + + self.train_dataset = datasets.ImageFolder(traindir, transform=preprocess) + self.train_loader = torch.utils.data.DataLoader( + self.train_dataset, + shuffle=True, + batch_size=batch_size, + num_workers=num_workers, + ) + + self.test_dataset = datasets.ImageFolder(valdir, transform=preprocess) + self.test_loader = torch.utils.data.DataLoader( + self.test_dataset, + batch_size=batch_size, + num_workers=num_workers + ) + idx_to_class = dict((v, k) + for k, v in self.train_dataset.class_to_idx.items()) + self.classnames = [idx_to_class[i][2:].replace('_', ' ') for i in range(len(idx_to_class))] diff --git a/src/datasets/svhn.py b/src/datasets/svhn.py new file mode 100644 index 0000000000000000000000000000000000000000..0e9b47c7b04b7c419d486afcf0cae467cec93728 --- /dev/null +++ b/src/datasets/svhn.py @@ -0,0 +1,45 @@ +import os +import torch +from torchvision.datasets import SVHN as PyTorchSVHN +import numpy as np + + +class SVHN: + def __init__(self, + preprocess, + 
location=os.path.expanduser('~/data'), + batch_size=128, + num_workers=16): + + # to fit with repo conventions for location + modified_location = os.path.join(location, 'svhn') + + self.train_dataset = PyTorchSVHN( + root=modified_location, + download=True, + split='train', + transform=preprocess + ) + + self.train_loader = torch.utils.data.DataLoader( + self.train_dataset, + batch_size=batch_size, + shuffle=True, + num_workers=num_workers + ) + + self.test_dataset = PyTorchSVHN( + root=modified_location, + download=True, + split='test', + transform=preprocess + ) + + self.test_loader = torch.utils.data.DataLoader( + self.test_dataset, + batch_size=batch_size, + shuffle=False, + num_workers=num_workers + ) + + self.classnames = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9'] diff --git a/src/datasets/templates.py b/src/datasets/templates.py new file mode 100644 index 0000000000000000000000000000000000000000..56fb4a6d8bf4ac341bd06b6165120a1fa72c4f21 --- /dev/null +++ b/src/datasets/templates.py @@ -0,0 +1,239 @@ +cars_template = [ + lambda c: f'a photo of a {c}.', + lambda c: f'a photo of the {c}.', + lambda c: f'a photo of my {c}.', + lambda c: f'i love my {c}!', + lambda c: f'a photo of my dirty {c}.', + lambda c: f'a photo of my clean {c}.', + lambda c: f'a photo of my new {c}.', + lambda c: f'a photo of my old {c}.', +] + +cifar10_template = [ + lambda c: f'a photo of a {c}.', + lambda c: f'a blurry photo of a {c}.', + lambda c: f'a black and white photo of a {c}.', + lambda c: f'a low contrast photo of a {c}.', + lambda c: f'a high contrast photo of a {c}.', + lambda c: f'a bad photo of a {c}.', + lambda c: f'a good photo of a {c}.', + lambda c: f'a photo of a small {c}.', + lambda c: f'a photo of a big {c}.', + lambda c: f'a photo of the {c}.', + lambda c: f'a blurry photo of the {c}.', + lambda c: f'a black and white photo of the {c}.', + lambda c: f'a low contrast photo of the {c}.', + lambda c: f'a high contrast photo of the {c}.', + lambda c: f'a bad 
photo of the {c}.', + lambda c: f'a good photo of the {c}.', + lambda c: f'a photo of the small {c}.', + lambda c: f'a photo of the big {c}.', +] + +cifar100_template = [ + lambda c: f'a photo of a {c}.', + lambda c: f'a blurry photo of a {c}.', + lambda c: f'a black and white photo of a {c}.', + lambda c: f'a low contrast photo of a {c}.', + lambda c: f'a high contrast photo of a {c}.', + lambda c: f'a bad photo of a {c}.', + lambda c: f'a good photo of a {c}.', + lambda c: f'a photo of a small {c}.', + lambda c: f'a photo of a big {c}.', + lambda c: f'a photo of the {c}.', + lambda c: f'a blurry photo of the {c}.', + lambda c: f'a black and white photo of the {c}.', + lambda c: f'a low contrast photo of the {c}.', + lambda c: f'a high contrast photo of the {c}.', + lambda c: f'a bad photo of the {c}.', + lambda c: f'a good photo of the {c}.', + lambda c: f'a photo of the small {c}.', + lambda c: f'a photo of the big {c}.', +] + +dtd_template = [ + lambda c: f'a photo of a {c} texture.', + lambda c: f'a photo of a {c} pattern.', + lambda c: f'a photo of a {c} thing.', + lambda c: f'a photo of a {c} object.', + lambda c: f'a photo of the {c} texture.', + lambda c: f'a photo of the {c} pattern.', + lambda c: f'a photo of the {c} thing.', + lambda c: f'a photo of the {c} object.', +] + +eurosat_template = [ + lambda c: f'a centered satellite photo of {c}.', + lambda c: f'a centered satellite photo of a {c}.', + lambda c: f'a centered satellite photo of the {c}.', +] + +food101_template = [ + lambda c: f'a photo of {c}, a type of food.', +] + +gtsrb_template = [ + lambda c: f'a zoomed in photo of a "{c}" traffic sign.', + lambda c: f'a centered photo of a "{c}" traffic sign.', + lambda c: f'a close up photo of a "{c}" traffic sign.', +] + +mnist_template = [ + lambda c: f'a photo of the number: "{c}".', +] + +imagenet_template = [ + lambda c: f'a bad photo of a {c}.', + lambda c: f'a photo of many {c}.', + lambda c: f'a sculpture of a {c}.', + lambda c: f'a photo of 
the hard to see {c}.', + lambda c: f'a low resolution photo of the {c}.', + lambda c: f'a rendering of a {c}.', + lambda c: f'graffiti of a {c}.', + lambda c: f'a bad photo of the {c}.', + lambda c: f'a cropped photo of the {c}.', + lambda c: f'a tattoo of a {c}.', + lambda c: f'the embroidered {c}.', + lambda c: f'a photo of a hard to see {c}.', + lambda c: f'a bright photo of a {c}.', + lambda c: f'a photo of a clean {c}.', + lambda c: f'a photo of a dirty {c}.', + lambda c: f'a dark photo of the {c}.', + lambda c: f'a drawing of a {c}.', + lambda c: f'a photo of my {c}.', + lambda c: f'the plastic {c}.', + lambda c: f'a photo of the cool {c}.', + lambda c: f'a close-up photo of a {c}.', + lambda c: f'a black and white photo of the {c}.', + lambda c: f'a painting of the {c}.', + lambda c: f'a painting of a {c}.', + lambda c: f'a pixelated photo of the {c}.', + lambda c: f'a sculpture of the {c}.', + lambda c: f'a bright photo of the {c}.', + lambda c: f'a cropped photo of a {c}.', + lambda c: f'a plastic {c}.', + lambda c: f'a photo of the dirty {c}.', + lambda c: f'a jpeg corrupted photo of a {c}.', + lambda c: f'a blurry photo of the {c}.', + lambda c: f'a photo of the {c}.', + lambda c: f'a good photo of the {c}.', + lambda c: f'a rendering of the {c}.', + lambda c: f'a {c} in a video game.', + lambda c: f'a photo of one {c}.', + lambda c: f'a doodle of a {c}.', + lambda c: f'a close-up photo of the {c}.', + lambda c: f'a photo of a {c}.', + lambda c: f'the origami {c}.', + lambda c: f'the {c} in a video game.', + lambda c: f'a sketch of a {c}.', + lambda c: f'a doodle of the {c}.', + lambda c: f'a origami {c}.', + lambda c: f'a low resolution photo of a {c}.', + lambda c: f'the toy {c}.', + lambda c: f'a rendition of the {c}.', + lambda c: f'a photo of the clean {c}.', + lambda c: f'a photo of a large {c}.', + lambda c: f'a rendition of a {c}.', + lambda c: f'a photo of a nice {c}.', + lambda c: f'a photo of a weird {c}.', + lambda c: f'a blurry photo of a 
{c}.', + lambda c: f'a cartoon {c}.', + lambda c: f'art of a {c}.', + lambda c: f'a sketch of the {c}.', + lambda c: f'a embroidered {c}.', + lambda c: f'a pixelated photo of a {c}.', + lambda c: f'itap of the {c}.', + lambda c: f'a jpeg corrupted photo of the {c}.', + lambda c: f'a good photo of a {c}.', + lambda c: f'a plushie {c}.', + lambda c: f'a photo of the nice {c}.', + lambda c: f'a photo of the small {c}.', + lambda c: f'a photo of the weird {c}.', + lambda c: f'the cartoon {c}.', + lambda c: f'art of the {c}.', + lambda c: f'a drawing of the {c}.', + lambda c: f'a photo of the large {c}.', + lambda c: f'a black and white photo of a {c}.', + lambda c: f'the plushie {c}.', + lambda c: f'a dark photo of a {c}.', + lambda c: f'itap of a {c}.', + lambda c: f'graffiti of the {c}.', + lambda c: f'a toy {c}.', + lambda c: f'itap of my {c}.', + lambda c: f'a photo of a cool {c}.', + lambda c: f'a photo of a small {c}.', + lambda c: f'a tattoo of the {c}.', +] + +resisc45_template = [ + lambda c: f'satellite imagery of {c}.', + lambda c: f'aerial imagery of {c}.', + lambda c: f'satellite photo of {c}.', + lambda c: f'aerial photo of {c}.', + lambda c: f'satellite view of {c}.', + lambda c: f'aerial view of {c}.', + lambda c: f'satellite imagery of a {c}.', + lambda c: f'aerial imagery of a {c}.', + lambda c: f'satellite photo of a {c}.', + lambda c: f'aerial photo of a {c}.', + lambda c: f'satellite view of a {c}.', + lambda c: f'aerial view of a {c}.', + lambda c: f'satellite imagery of the {c}.', + lambda c: f'aerial imagery of the {c}.', + lambda c: f'satellite photo of the {c}.', + lambda c: f'aerial photo of the {c}.', + lambda c: f'satellite view of the {c}.', + lambda c: f'aerial view of the {c}.', +] + +stl10_template = [ + lambda c: f'a photo of a {c}.', + lambda c: f'a photo of the {c}.', +] + +sun397_template = [ + lambda c: f'a photo of a {c}.', + lambda c: f'a photo of the {c}.', +] + +svhn_template = [ + lambda c: f'a photo of the number: "{c}".', +] 
+ +oxfordpets_template = [ + lambda c: f"a photo of a {c}, a type of pet.", +] + +emnist_template = [ + lambda c: f'a photo of the number: "{c}".', +] + +kmnist_template = [ + lambda c: f"a photo of the character {c}.", +] + +dataset_to_template = { + 'Cars': cars_template, + 'CIFAR10': cifar10_template, + 'CIFAR100': cifar100_template, + 'DTD': dtd_template, + 'EuroSAT': eurosat_template, + 'Food101': food101_template, + 'GTSRB': gtsrb_template, + 'MNIST': mnist_template, + 'ImageNet': imagenet_template, + 'RESISC45': resisc45_template, + 'STL10': stl10_template, + 'SUN397': sun397_template, + 'SVHN': svhn_template, + "OxfordIIITPet": oxfordpets_template, + "EMNIST": emnist_template, + "KMNIST": kmnist_template, +} + + +def get_templates(dataset_name): + if dataset_name.endswith('Val'): + return get_templates(dataset_name.replace('Val', '')) + assert dataset_name in dataset_to_template, f'Unsupported dataset: {dataset_name}' + return dataset_to_template[dataset_name] \ No newline at end of file diff --git a/src/distributed.py b/src/distributed.py new file mode 100644 index 0000000000000000000000000000000000000000..2f401bcf4694e1b05a90e8041b7d11af04316713 --- /dev/null +++ b/src/distributed.py @@ -0,0 +1,39 @@ +import os + +import torch + + +def setup_ddp(rank, world_size, port=12357): + os.environ["MASTER_ADDR"] = "localhost" + os.environ["MASTER_PORT"] = str(port) + + # initialize the process group + torch.distributed.init_process_group( + "nccl", + rank=rank, + world_size=world_size, + ) + torch.cuda.set_device(rank) + torch.distributed.barrier() + + +def cleanup_ddp(): + torch.distributed.destroy_process_group() + + +def is_main_process(): + return torch.distributed.get_rank() == 0 + + +def distribute_loader(loader): + return torch.utils.data.DataLoader( + loader.dataset, + batch_size=loader.batch_size // torch.distributed.get_world_size(), + sampler=torch.utils.data.distributed.DistributedSampler( + loader.dataset, + 
num_replicas=torch.distributed.get_world_size(), + rank=torch.distributed.get_rank(), + ), + num_workers=loader.num_workers, + pin_memory=loader.pin_memory, + ) diff --git a/src/eval.py b/src/eval.py new file mode 100644 index 0000000000000000000000000000000000000000..7ddbd85b6190dec9ae01fe43607fed1b02da76be --- /dev/null +++ b/src/eval.py @@ -0,0 +1,126 @@ +import numpy as np +import torch +import tqdm + +from src import utils +from src.datasets.common import get_dataloader, maybe_dictionarize +from src.datasets.registry import get_dataset +from src.heads import get_classification_head +from src.linearize import LinearizedImageEncoder +from src.modeling import ImageClassifier + + +def eval_single_dataset(image_encoder, dataset_name, args): + classification_head = get_classification_head(args, dataset_name) + model = ImageClassifier(image_encoder, classification_head) + + model.eval() + + dataset = get_dataset( + dataset_name, + model.val_preprocess, + location=args.data_location, + batch_size=args.batch_size, + ) + dataloader = get_dataloader(dataset, is_train=False, args=args, image_encoder=None) + device = args.device + + with torch.no_grad(): + top1, correct, n = 0.0, 0.0, 0.0 + for _, data in enumerate(tqdm.tqdm(dataloader)): + data = maybe_dictionarize(data) + x = data["images"].to(device) + y = data["labels"].to(device) + + logits = utils.get_logits(x, model) + + pred = logits.argmax(dim=1, keepdim=True).to(device) + + correct += pred.eq(y.view_as(pred)).sum().item() + + n += y.size(0) + + top1 = correct / n + + metrics = {"top1": top1} + print(f"Done evaluating on {dataset_name}. 
Accuracy: {100*top1:.2f}%") + + return metrics + + +def evaluate(image_encoder, args): + if args.eval_datasets is None: + return + per_dataset_results = {} + eval_datasets = ( + args.eval_datasets + if args.control_dataset is None + else args.eval_datasets + [args.control_dataset] + ) + for dataset_name in eval_datasets: + print("Evaluating on", dataset_name) + + results = eval_single_dataset(image_encoder, dataset_name, args) + + print(f"{dataset_name} Top-1 accuracy: {results['top1']:.4f}") + per_dataset_results[dataset_name + ":top1"] = results["top1"] + + return per_dataset_results + + +def evaluate_task_vector_at_coef( + task_vector, pretrained_checkpoint, args, scaling_coef, posthoc_linearization=False +): + image_encoder = task_vector.apply_to( + pretrained_checkpoint, scaling_coef=scaling_coef + ) + if posthoc_linearization: + pretrained_encoder = task_vector.apply_to( + pretrained_checkpoint, scaling_coef=0.0 + ) + image_encoder = LinearizedImageEncoder( + init_encoder=pretrained_encoder, image_encoder=image_encoder, args=args + ) + coef_info = evaluate(image_encoder, args) + + coef_info = add_normalized_accuracy(coef_info, args) + coef_info["avg_normalized_top1"] = np.mean( + [coef_info[dataset + ":normalized_top1"] for dataset in args.eval_datasets] + ) + coef_info["avg_top1"] = np.mean( + [coef_info[dataset + ":top1"] for dataset in args.eval_datasets] + ) + + return coef_info + + +def evaluate_task_vector( + task_vector, pretrained_checkpoint, args, posthoc_linearization=False +): + info = {} + for scaling_coef in np.linspace(0.0, 1.0, args.n_eval_points): + print(f"Evaluating for scaling coefficient {scaling_coef:.2f}") + info[scaling_coef] = evaluate_task_vector_at_coef( + task_vector, + pretrained_checkpoint, + args, + scaling_coef, + posthoc_linearization, + ) + + return info + + +def add_normalized_accuracy(results, args): + for dataset_name in args.eval_datasets: + results[dataset_name + ":normalized_top1"] = ( + results[dataset_name + ":top1"] / 
args.finetuning_accuracies[dataset_name] + ) + + return results + + +def nonlinear_advantage(acc_linear, acc_nonlinear, num_classes): + err_linear = 1 - acc_linear + err_nonlinear = 1 - acc_nonlinear + return (err_linear - err_nonlinear) * num_classes / (num_classes - 1) diff --git a/src/eval_single_task.py b/src/eval_single_task.py new file mode 100644 index 0000000000000000000000000000000000000000..1b2b7ab154058302a50c36c640c6a78bcc8f7726 --- /dev/null +++ b/src/eval_single_task.py @@ -0,0 +1,115 @@ +import json +import os + +from src.args import parse_arguments +from src.eval import eval_single_dataset +from src.linearize import LinearizedImageEncoder +from src.task_vectors import LinearizedTaskVector, NonLinearTaskVector +from src.attention_only_finetune import AttentionOnlyFinetuneEncoder + +args = parse_arguments() +if 'ortho' in args.finetuning_mode: + args.save = f"checkpoints_{args.seed}/{args.finetuning_mode}_{args.lr}_lambda{args.ortho_lambda}_{args.model}" +else: + if args.seed is not None: + args.save = f"checkpoints_{args.seed}/{args.finetuning_mode}_{args.lr}_{args.model}" + else: + args.save = f"checkpoints/{args.finetuning_mode}_{args.lr}_{args.model}" + +accuracies = {} + +print("*" * 100) +mode_labels = { + "standard": "Evaluating non-linear FT models.", + "standard_ortho": "Evaluating standard FT models with orthogonality regularization.", + "linear": "Evaluating linear FT models.", + "linear_ortho": "Evaluating linear FT models with orthogonality regularization.", + "linear-2": "Evaluating Attention-Only Finetune models.", + "linear-2_ortho": "Evaluating Attention-Only Finetune models with orthogonality regularization.", +} +print(mode_labels.get(args.finetuning_mode, f"Evaluating {args.finetuning_mode} models.")) + +for dataset in [ + "Cars", "DTD", "EuroSAT", "GTSRB", "MNIST", "RESISC45", "SUN397", "SVHN", +]: + print("*" * 100) + print(f"Evaluating on {dataset}") + + mode = args.finetuning_mode + + if mode == "standard": + 
pretrained_checkpoint = f"{args.save}/{dataset}Val/zeroshot.pt" + finetuned_checkpoint = f"{args.save}/{dataset}Val/finetuned.pt" + try: + task_vector = NonLinearTaskVector(pretrained_checkpoint, finetuned_checkpoint) + image_encoder = task_vector.apply_to(pretrained_checkpoint, scaling_coef=1.0) + except FileNotFoundError: + print(f"Error: Could not find checkpoints for {dataset}.") + continue + + elif mode == "standard_ortho": + pretrained_checkpoint = f"{args.save}/{dataset}Val/standard_ortho_zeroshot.pt" + finetuned_checkpoint = f"{args.save}/{dataset}Val/standard_ortho_finetuned.pt" + try: + task_vector = NonLinearTaskVector(pretrained_checkpoint, finetuned_checkpoint) + image_encoder = task_vector.apply_to(pretrained_checkpoint, scaling_coef=1.0) + except FileNotFoundError: + print(f"Error: Could not find checkpoints for {dataset}.") + continue + + elif mode == "linear": + pretrained_checkpoint = f"{args.save}/{dataset}Val/linear_zeroshot.pt" + finetuned_checkpoint = f"{args.save}/{dataset}Val/linear_finetuned.pt" + try: + task_vector = LinearizedTaskVector(pretrained_checkpoint, finetuned_checkpoint) + image_encoder = task_vector.apply_to(pretrained_checkpoint, scaling_coef=1.0) + except FileNotFoundError: + print(f"Error: Could not find checkpoints for {dataset}.") + continue + + elif mode == "linear_ortho": + pretrained_checkpoint = f"{args.save}/{dataset}Val/linear_ortho_zeroshot.pt" + finetuned_checkpoint = f"{args.save}/{dataset}Val/linear_ortho_finetuned.pt" + try: + task_vector = LinearizedTaskVector(pretrained_checkpoint, finetuned_checkpoint) + image_encoder = task_vector.apply_to(pretrained_checkpoint, scaling_coef=1.0) + except FileNotFoundError: + print(f"Error: Could not find checkpoints for {dataset}.") + continue + + elif mode in ("linear-2", "linear-2_ortho"): + prefix = mode + "_" + pretrained_checkpoint = f"{args.save}/{dataset}Val/{prefix}zeroshot.pt" + finetuned_checkpoint = f"{args.save}/{dataset}Val/{prefix}finetuned.pt" + try: + 
task_vector = NonLinearTaskVector(pretrained_checkpoint, finetuned_checkpoint) + image_encoder = task_vector.apply_to(pretrained_checkpoint, scaling_coef=1.0) + except FileNotFoundError: + print(f"Error: Could not find checkpoints for {dataset} with mode {mode}.") + continue + + else: + print(f"Unknown finetuning mode: {mode}") + continue + + for split in ["test", "val"]: + print("=" * 100) + print(f"Evaluating on {split} split.") + eval_dataset = dataset if split == "test" else f"{dataset}Val" + accuracies[eval_dataset] = eval_single_dataset(image_encoder, eval_dataset, args)["top1"] + +# Save results +save_name_map = { + "standard": "ft_accuracies.json", + "standard_ortho": "standard_ortho_ft_accuracies.json", + "linear": "linear_ft_accuracies.json", + "linear_ortho": "linear_ortho_ft_accuracies.json", + "linear-2": "linear-2_ft_accuracies.json", + "linear-2_ortho": "linear-2_ortho_ft_accuracies.json", +} + +save_path = os.path.join(args.save, save_name_map[args.finetuning_mode]) +os.makedirs(os.path.dirname(save_path), exist_ok=True) +with open(save_path, "w") as f: + json.dump(accuracies, f, indent=4) +print(f"Results saved to {save_path}") diff --git a/src/eval_task_addition.py b/src/eval_task_addition.py new file mode 100644 index 0000000000000000000000000000000000000000..987e89d65ad62d1c2356ba78ed45c2cd78275577 --- /dev/null +++ b/src/eval_task_addition.py @@ -0,0 +1,154 @@ +import json +import os + +from utils import find_optimal_coef + +from src.args import parse_arguments +from src.eval import evaluate_task_vector, evaluate_task_vector_at_coef +from src.task_vectors import LinearizedTaskVector, NonLinearTaskVector +from src.attention_only_finetune import AttentionOnlyFinetuneEncoder + +args = parse_arguments() + +if 'ortho' in args.finetuning_mode: + args.save = f"checkpoints_{args.seed}/{args.finetuning_mode}_{args.lr}_lambda{args.ortho_lambda}_{args.model}" +else: + if args.seed is not None: + args.save = 
f"checkpoints_{args.seed}/{args.finetuning_mode}_{args.lr}_{args.model}" + else: + args.save = f"checkpoints/{args.finetuning_mode}_{args.lr}_{args.model}" + +print("*" * 100) +mode_labels = { + "standard": "Evaluating non-linear FT models.", + "standard_ortho": "Evaluating standard FT models with orthogonality regularization.", + "linear": "Evaluating linear FT models.", + "linear_ortho": "Evaluating linear FT models with orthogonality regularization.", + "linear-2": "Evaluating Attention-Only Finetune models.", + "linear-2_ortho": "Evaluating Attention-Only Finetune models with orthogonality regularization.", +} +ft_accuracies_name_map = { + "standard": "ft_accuracies.json", + "standard_ortho": "standard_ortho_ft_accuracies.json", + "linear": "linear_ft_accuracies.json", + "linear_ortho": "linear_ortho_ft_accuracies.json", + "linear-2": "linear-2_ft_accuracies.json", + "linear-2_ortho": "linear-2_ortho_ft_accuracies.json", +} +print(mode_labels.get(args.finetuning_mode, f"Evaluating {args.finetuning_mode} models.")) +print("*" * 100) + +ft_accuracies_path = os.path.join(args.save, ft_accuracies_name_map[args.finetuning_mode]) +with open(ft_accuracies_path) as f: + args.finetuning_accuracies = json.load(f) + +if args.seed is not None: + base_model_save_path = f"checkpoints_{args.seed}/{args.model}" +else: + base_model_save_path = f"checkpoints/{args.model}" + +with open(os.path.join(base_model_save_path, "zeroshot_accuracies.json")) as f: + pretrained_accuracies = json.load(f) + +eval_datasets = [ + "Cars", "DTD", "EuroSAT", "GTSRB", "MNIST", "RESISC45", "SVHN", "SUN397", +] + +task_vectors = [] +mode = args.finetuning_mode + +for dataset in eval_datasets: + if mode == "linear": + pretrained_checkpoint = f"{args.save}/{dataset}Val/linear_zeroshot.pt" + finetuned_checkpoint = f"{args.save}/{dataset}Val/linear_finetuned.pt" + task_vectors.append(LinearizedTaskVector(pretrained_checkpoint, finetuned_checkpoint)) + + elif mode == "linear_ortho": + 
pretrained_checkpoint = f"{args.save}/{dataset}Val/linear_ortho_zeroshot.pt" + finetuned_checkpoint = f"{args.save}/{dataset}Val/linear_ortho_finetuned.pt" + task_vectors.append(LinearizedTaskVector(pretrained_checkpoint, finetuned_checkpoint)) + + elif mode == "standard_ortho": + pretrained_checkpoint = f"{args.save}/{dataset}Val/standard_ortho_zeroshot.pt" + finetuned_checkpoint = f"{args.save}/{dataset}Val/standard_ortho_finetuned.pt" + task_vectors.append(NonLinearTaskVector(pretrained_checkpoint, finetuned_checkpoint)) + + elif mode in ("linear-2", "linear-2_ortho"): + prefix = mode + "_" + pretrained_checkpoint = f"{args.save}/{dataset}Val/{prefix}zeroshot.pt" + finetuned_checkpoint = f"{args.save}/{dataset}Val/{prefix}finetuned.pt" + if not (os.path.exists(pretrained_checkpoint) and os.path.exists(finetuned_checkpoint)): + print(f"Warning: Missing checkpoints for {dataset}. Skipping.") + continue + task_vectors.append(NonLinearTaskVector(pretrained_checkpoint, finetuned_checkpoint)) + + else: # standard + pretrained_checkpoint = f"{args.save}/{dataset}Val/zeroshot.pt" + finetuned_checkpoint = f"{args.save}/{dataset}Val/finetuned.pt" + task_vectors.append(NonLinearTaskVector(pretrained_checkpoint, finetuned_checkpoint)) + +if not task_vectors: + print("No task vectors were created. 
Exiting.") + exit() + +task_vector = sum(task_vectors) + +# Determine the base pretrained checkpoint +mode_prefix_map = { + "standard": "", + "standard_ortho": "standard_ortho_", + "linear": "linear_", + "linear_ortho": "linear_ortho_", + "linear-2": "linear-2_", + "linear-2_ortho": "linear-2_ortho_", +} +mode_prefix = mode_prefix_map[mode] +pretrained_checkpoint = f"{args.save}/{eval_datasets[0]}Val/{mode_prefix}zeroshot.pt" + +if not os.path.exists(pretrained_checkpoint): + print(f"Error: Base pretrained checkpoint not found at {pretrained_checkpoint}") + exit() + +args.eval_datasets = [dataset + "Val" for dataset in eval_datasets] +args.control_dataset = None + +val_metrics = evaluate_task_vector( + task_vector, + pretrained_checkpoint, + args, + posthoc_linearization=False, +) + +optimal_coef = find_optimal_coef( + val_metrics, + metric="avg_normalized_top1", + minimize=False, +) + +args.eval_datasets = [dataset for dataset in eval_datasets] +test_metrics = evaluate_task_vector_at_coef( + task_vector, + pretrained_checkpoint, + args, + float(optimal_coef), + posthoc_linearization=False, +) + +print("=" * 100) +print(f"Optimal Coefficient: {optimal_coef}") +print(f"Test normalized accuracy: {test_metrics['avg_normalized_top1']}") +print(f"Test absolute accuracy: {test_metrics['avg_top1']}") +additive_accuracies = {"test": test_metrics, "val": val_metrics, "optimal_coef": optimal_coef} + +save_name_map = { + "standard": "additions.json", + "standard_ortho": "standard_ortho_additions.json", + "linear": "linear_additions.json", + "linear_ortho": "linear_ortho_additions.json", + "linear-2": "linear-2_additions.json", + "linear-2_ortho": "linear-2_ortho_additions.json", +} +save_file = os.path.join(args.save, save_name_map[mode]) +with open(save_file, "w") as f: + json.dump(additive_accuracies, f, indent=4) +print(f"Addition results saved to {save_file}") diff --git a/src/eval_task_negation.py b/src/eval_task_negation.py new file mode 100644 index 
0000000000000000000000000000000000000000..86d115405e992004bc9200e28fcaadbdbc3f5efe --- /dev/null +++ b/src/eval_task_negation.py @@ -0,0 +1,163 @@ +import json +import os + +from utils import find_optimal_coef + +from src.args import parse_arguments +from src.eval import evaluate_task_vector, evaluate_task_vector_at_coef +from src.task_vectors import LinearizedTaskVector, NonLinearTaskVector +from src.attention_only_finetune import AttentionOnlyFinetuneEncoder + +args = parse_arguments() + +if 'ortho' in args.finetuning_mode: + args.save = f"checkpoints_{args.seed}/{args.finetuning_mode}_{args.lr}_lambda{args.ortho_lambda}_{args.model}" +else: + if args.seed is not None: + args.save = f"checkpoints_{args.seed}/{args.finetuning_mode}_{args.lr}_{args.model}" + else: + args.save = f"checkpoints/{args.finetuning_mode}_{args.lr}_{args.model}" + +if args.seed is not None: + base_model_save_path = f"checkpoints_{args.seed}/{args.model}" +else: + base_model_save_path = f"checkpoints/{args.model}" + +with open(os.path.join(base_model_save_path, "zeroshot_accuracies.json")) as f: + pretrained_accuracies = json.load(f) + +eval_datasets = [ + "Cars", "DTD", "EuroSAT", "GTSRB", "MNIST", "RESISC45", "SUN397", "SVHN", +] + +print("*" * 100) +mode_labels = { + "standard": "Evaluating non-linear FT models.", + "standard_ortho": "Evaluating standard FT models with orthogonality regularization.", + "linear": "Evaluating linear FT models.", + "linear_ortho": "Evaluating linear FT models with orthogonality regularization.", + "linear-2": "Evaluating Attention-Only Finetune models.", + "linear-2_ortho": "Evaluating Attention-Only Finetune models with orthogonality regularization.", +} +ft_accuracies_name_map = { + "standard": "ft_accuracies.json", + "standard_ortho": "standard_ortho_ft_accuracies.json", + "linear": "linear_ft_accuracies.json", + "linear_ortho": "linear_ortho_ft_accuracies.json", + "linear-2": "linear-2_ft_accuracies.json", + "linear-2_ortho": 
"linear-2_ortho_ft_accuracies.json", +} +print(mode_labels.get(args.finetuning_mode, f"Evaluating {args.finetuning_mode} models.")) +print("*" * 100) + +ft_accuracies_path = os.path.join(args.save, ft_accuracies_name_map[args.finetuning_mode]) +with open(ft_accuracies_path) as f: + args.finetuning_accuracies = json.load(f) + +control_dataset = "ImageNet" +negation_accuracies = {} +mode = args.finetuning_mode + +for dataset in eval_datasets: + task_vector = None + pretrained_checkpoint = None + + if mode == "linear": + pretrained_checkpoint = f"{args.save}/{dataset}Val/linear_zeroshot.pt" + finetuned_checkpoint = f"{args.save}/{dataset}Val/linear_finetuned.pt" + if not (os.path.exists(pretrained_checkpoint) and os.path.exists(finetuned_checkpoint)): + print(f"Warning: Missing checkpoints for {dataset}. Skipping.") + continue + task_vector = LinearizedTaskVector(pretrained_checkpoint, finetuned_checkpoint) + + elif mode == "linear_ortho": + pretrained_checkpoint = f"{args.save}/{dataset}Val/linear_ortho_zeroshot.pt" + finetuned_checkpoint = f"{args.save}/{dataset}Val/linear_ortho_finetuned.pt" + if not (os.path.exists(pretrained_checkpoint) and os.path.exists(finetuned_checkpoint)): + print(f"Warning: Missing checkpoints for {dataset}. Skipping.") + continue + task_vector = LinearizedTaskVector(pretrained_checkpoint, finetuned_checkpoint) + + elif mode == "standard_ortho": + pretrained_checkpoint = f"{args.save}/{dataset}Val/standard_ortho_zeroshot.pt" + finetuned_checkpoint = f"{args.save}/{dataset}Val/standard_ortho_finetuned.pt" + if not (os.path.exists(pretrained_checkpoint) and os.path.exists(finetuned_checkpoint)): + print(f"Warning: Missing checkpoints for {dataset}. 
Skipping.") + continue + task_vector = NonLinearTaskVector(pretrained_checkpoint, finetuned_checkpoint) + + elif mode in ("linear-2", "linear-2_ortho"): + prefix = mode + "_" + pretrained_checkpoint = f"{args.save}/{dataset}Val/{prefix}zeroshot.pt" + finetuned_checkpoint = f"{args.save}/{dataset}Val/{prefix}finetuned.pt" + if not (os.path.exists(pretrained_checkpoint) and os.path.exists(finetuned_checkpoint)): + print(f"Warning: Missing checkpoints for {dataset}. Skipping.") + continue + task_vector = NonLinearTaskVector(pretrained_checkpoint, finetuned_checkpoint) + + else: # standard + pretrained_checkpoint = f"{args.save}/{dataset}Val/zeroshot.pt" + finetuned_checkpoint = f"{args.save}/{dataset}Val/finetuned.pt" + if not (os.path.exists(pretrained_checkpoint) and os.path.exists(finetuned_checkpoint)): + print(f"Warning: Missing checkpoints for {dataset}. Skipping.") + continue + task_vector = NonLinearTaskVector(pretrained_checkpoint, finetuned_checkpoint) + + if not os.path.exists(pretrained_checkpoint): + print(f"Error: Base pretrained checkpoint not found at {pretrained_checkpoint}. 
Skipping {dataset}.") + continue + + task_vector = -task_vector + + args.eval_datasets = [dataset + "Val"] + args.control_dataset = control_dataset + "Val" + val_metrics = evaluate_task_vector( + task_vector, + pretrained_checkpoint, + args, + posthoc_linearization=False, + ) + + optimal_coef = find_optimal_coef( + val_metrics, + metric=f"{dataset}Val:top1", + minimize=True, + control_metric=f"{control_dataset}Val:top1", + control_metric_threshold=args.control_threshold * pretrained_accuracies[control_dataset + "Val"], + ) + + args.eval_datasets = [dataset] + args.control_dataset = control_dataset + test_metrics = evaluate_task_vector_at_coef( + task_vector, + pretrained_checkpoint, + args, + optimal_coef, + posthoc_linearization=False, + ) + + print("=" * 100) + print(f"Results for dataset: {dataset}") + print(f"Optimal Coefficient: {optimal_coef}") + print(f"Test accuracy: {test_metrics.get(f'{dataset}:top1', 'N/A')}") + print(f"Control accuracy on {control_dataset}: {test_metrics.get(f'{control_dataset}:top1', 'N/A')}") + + negation_accuracies[dataset] = { + "test": test_metrics.get(f"{dataset}:top1"), + "test_control": test_metrics.get(f"{control_dataset}:top1"), + "val": val_metrics, + "optimal_coef": optimal_coef, + } + +save_name_map = { + "standard": "negations.json", + "standard_ortho": "standard_ortho_negations.json", + "linear": "linear_negations.json", + "linear_ortho": "linear_ortho_negations.json", + "linear-2": "linear-2_negations.json", + "linear-2_ortho": "linear-2_ortho_negations.json", +} +save_file = os.path.join(args.save, save_name_map[mode]) +with open(save_file, "w") as f: + json.dump(negation_accuracies, f, indent=4) +print(f"Negation results saved to {save_file}") diff --git a/src/finetune.py b/src/finetune.py new file mode 100644 index 0000000000000000000000000000000000000000..5d0bb949408d3bf6c00171af91be635bab3ab06b --- /dev/null +++ b/src/finetune.py @@ -0,0 +1,243 @@ +import os +import time + +import torch + +from src.args import 
parse_arguments +from src.datasets.common import get_dataloader, maybe_dictionarize +from src.datasets.registry import get_dataset +from src.distributed import cleanup_ddp, distribute_loader, is_main_process, setup_ddp +from src.eval import eval_single_dataset +from src.heads import get_classification_head +from src.linearize import LinearizedImageEncoder +from src.modeling import ImageClassifier, ImageEncoder +from src.attention_only_finetune import AttentionOnlyFinetuneEncoder +from src.utils import LabelSmoothing, cosine_lr, accuracy + + +def finetune(rank, args): + setup_ddp(rank, args.world_size, port=args.port) + + train_dataset = args.train_dataset + ckpdir = os.path.join(args.save, train_dataset) + + valid_modes = [ + "standard", "standard_ortho", + "linear", "linear_ortho", + "linear-2", "linear-2_ortho", + ] + assert args.finetuning_mode in valid_modes, f"Mode {args.finetuning_mode} not supported." + + is_linearized = args.finetuning_mode in ("linear", "linear_ortho") + is_linear2 = args.finetuning_mode in ("linear-2", "linear-2_ortho") + is_standard_ortho = args.finetuning_mode == "standard_ortho" + is_linear_ortho = args.finetuning_mode == "linear_ortho" + is_linear2_ortho = args.finetuning_mode == "linear-2_ortho" + needs_ortho = is_standard_ortho or is_linear_ortho or is_linear2_ortho + + print(f"Using fine-tuning mode: {args.finetuning_mode}") + if needs_ortho and args.ortho_lambda > 0: + print(f" -> With OrthoReg (lambda={args.ortho_lambda})") + + mode_prefix_map = { + "standard": "", + "standard_ortho": "standard_ortho", + "linear": "linear", + "linear_ortho": "linear_ortho", + "linear-2": "linear-2", + "linear-2_ortho": "linear-2_ortho", + } + mode_prefix = mode_prefix_map[args.finetuning_mode] + + ft_path = os.path.join(ckpdir, f"{mode_prefix}_finetuned.pt" if mode_prefix else "finetuned.pt") + zs_path = os.path.join(ckpdir, f"{mode_prefix}_zeroshot.pt" if mode_prefix else "zeroshot.pt") + + if os.path.exists(zs_path) and os.path.exists(ft_path): 
+ print(f"Skipping fine-tuning because {ft_path} exists.") + return zs_path, ft_path + + assert train_dataset is not None, "Please provide a training dataset." + + if args.load is not None and args.load.endswith("pt"): + if is_linearized: + image_encoder = LinearizedImageEncoder.load(args.load) + elif is_linear2: + image_encoder = AttentionOnlyFinetuneEncoder.load(args.load, args) + else: + image_encoder = ImageEncoder.load(args.load) + else: + print("Building image encoder.") + if is_linearized: + image_encoder = LinearizedImageEncoder(args, keep_lang=False) + elif is_linear2: + image_encoder = AttentionOnlyFinetuneEncoder(args, keep_lang=False) + else: + image_encoder = ImageEncoder(args) + + # Save a frozen copy of pretrained weights for ortho loss (standard_ortho / linear-2_ortho) + pretrained_state_dict_ref = None + if is_standard_ortho or is_linear2_ortho: + print("Saving pretrained state dict reference for ortho loss.") + pretrained_state_dict_ref = { + k: v.clone().detach() for k, v in image_encoder.model.state_dict().items() + } + + classification_head = get_classification_head(args, train_dataset) + model = ImageClassifier(image_encoder, classification_head) + model.freeze_head() + model = model.cuda() + + preprocess_fn = model.train_preprocess + print_every = 100 + + dataset = get_dataset( + train_dataset, + preprocess_fn, + location=args.data_location, + batch_size=args.batch_size, + ) + data_loader = get_dataloader(dataset, is_train=True, args=args, image_encoder=None) + num_batches = len(dataset.train_loader) + + ddp_loader = distribute_loader(data_loader) + ddp_model = torch.nn.parallel.DistributedDataParallel( + model, + device_ids=[rank], + find_unused_parameters=True, + output_device=rank, + ) + + loss_fn = LabelSmoothing(args.ls) if args.ls > 0 else torch.nn.CrossEntropyLoss() + + params = [p for p in ddp_model.parameters() if p.requires_grad] + optimizer = torch.optim.AdamW(params, lr=args.lr, weight_decay=args.wd) + scheduler = cosine_lr( + 
optimizer, + args.lr, + args.warmup_length, + args.epochs * num_batches // args.num_grad_accumulation, + ) + + if args.save is not None and is_main_process(): + os.makedirs(ckpdir, exist_ok=True) + ddp_model.module.image_encoder.save(zs_path) + + for epoch in range(args.epochs): + ddp_model.train() + + for i, batch in enumerate(ddp_loader): + start_time = time.time() + step = ( + i // args.num_grad_accumulation + + epoch * num_batches // args.num_grad_accumulation + ) + + batch = maybe_dictionarize(batch) + inputs = batch["images"].cuda() + labels = batch["labels"].cuda() + data_time = time.time() - start_time + + ortho_loss = 0.0 + if needs_ortho and args.ortho_lambda > 0: + logits, ortho_loss = ddp_model( + inputs, + calculate_ortho_loss=True, + pretrained_state_dict=pretrained_state_dict_ref, + ) + else: + logits = ddp_model(inputs) + + classification_loss = loss_fn(logits, labels) + loss = classification_loss + args.ortho_lambda * ortho_loss + + (acc1,) = accuracy(logits, labels, topk=(1,)) + acc1 /= labels.size(0) + + loss.backward() + + if (i + 1) % args.num_grad_accumulation == 0: + scheduler(step) + torch.nn.utils.clip_grad_norm_(params, 1.0) + optimizer.step() + optimizer.zero_grad() + + batch_time = time.time() - start_time + + if ( + args.checkpoint_every > 0 + and step % args.checkpoint_every == 0 + and is_main_process() + ): + ckpt_name = f"{mode_prefix}_checkpoint_{step}.pt" if mode_prefix else f"checkpoint_{step}.pt" + ddp_model.module.image_encoder.save(os.path.join(ckpdir, ckpt_name)) + + if ( + step % print_every == 0 + and ((i + 1) % args.num_grad_accumulation == 0) + and is_main_process() + ): + percent_complete = 100 * i / len(ddp_loader) + log_msg = ( + f"Train Epoch: {epoch} [{percent_complete:.0f}%]\t" + f"Total Loss: {loss.item():.6f}\t" + f"CE Loss: {classification_loss.item():.6f}\t" + ) + if needs_ortho and args.ortho_lambda > 0: + log_msg += f"Ortho Loss: {ortho_loss.item():.6f}\t" + log_msg += f"Acc@1: {100*acc1:.2f}%\tData (t) 
{data_time:.3f}" + print(log_msg, flush=True) + + if is_main_process(): + image_encoder = ddp_model.module.image_encoder + eval_single_dataset(image_encoder, train_dataset, args) + + if args.save is not None and is_main_process(): + image_encoder.save(ft_path) + return zs_path, ft_path + + cleanup_ddp() + + +if __name__ == "__main__": + train_datasets = [ + "Cars", + "DTD", + "EuroSAT", + "GTSRB", + "MNIST", + "RESISC45", + "SUN397", + "SVHN", + ] + epochs = { + "Cars": 35, + "DTD": 76, + "EuroSAT": 12, + "GTSRB": 11, + "MNIST": 5, + "RESISC45": 15, + "SUN397": 14, + "SVHN": 4, + } + + for dataset in train_datasets: + args = parse_arguments() + + args.epochs = epochs[dataset] + args.train_dataset = dataset + "Val" + + args.batch_size = 64 if args.model == "ViT-L-14" else 128 + args.num_grad_accumulation = 2 if args.model == "ViT-L-14" else 1 + + if 'ortho' in args.finetuning_mode: + args.save = f"checkpoints_{args.seed}/{args.finetuning_mode}_{args.lr}_lambda{args.ortho_lambda}_{args.model}" + else: + if args.seed is not None: + args.save = f"checkpoints_{args.seed}/{args.finetuning_mode}_{args.lr}_{args.model}" + else: + args.save = f"checkpoints/{args.finetuning_mode}_{args.lr}_{args.model}" + + print("=" * 100) + print(f"Finetuning {args.model} on {dataset}") + print("=" * 100) + torch.multiprocessing.spawn(finetune, args=(args,), nprocs=args.world_size) diff --git a/src/heads.py b/src/heads.py new file mode 100644 index 0000000000000000000000000000000000000000..c60d022284760b621e9b26147541f7b5988ea1ac --- /dev/null +++ b/src/heads.py @@ -0,0 +1,68 @@ +import os + +import open_clip +import torch +from tqdm import tqdm + +from src.datasets.registry import get_dataset +from src.datasets.templates import get_templates +from src.modeling import ClassificationHead, ImageEncoder + + +def build_classification_head(model, dataset_name, template, data_location, device): + template = get_templates(dataset_name) + + logit_scale = model.logit_scale + dataset = 
get_dataset(dataset_name, None, location=data_location) + model.eval() + model.to(device) + + print("Building classification head.") + with torch.no_grad(): + zeroshot_weights = [] + for classname in tqdm(dataset.classnames): + texts = [] + for t in template: + texts.append(t(classname)) + texts = open_clip.tokenize(texts).to(device) # tokenize + embeddings = model.encode_text(texts) # embed with text encoder + embeddings /= embeddings.norm(dim=-1, keepdim=True) + + embeddings = embeddings.mean(dim=0, keepdim=True) + embeddings /= embeddings.norm() + + zeroshot_weights.append(embeddings) + + zeroshot_weights = torch.stack(zeroshot_weights, dim=0).to(device) + zeroshot_weights = torch.transpose(zeroshot_weights, 0, 2) + + zeroshot_weights *= logit_scale.exp() + + zeroshot_weights = zeroshot_weights.squeeze().float() + zeroshot_weights = torch.transpose(zeroshot_weights, 0, 1) + + classification_head = ClassificationHead(normalize=True, weights=zeroshot_weights) + + return classification_head + + +def get_classification_head(args, dataset): + if not dataset.endswith("Val"): + # We want to load the head for the validation set always to be consistent with the one generated at training time. + dataset += "Val" + + filename = os.path.join(args.save, f"head_{dataset}.pt") + if os.path.exists(filename): + print(f"Classification head for {args.model} on {dataset} exists at {filename}") + return ClassificationHead.load(filename) + print( + f"Did not find classification head for {args.model} on {dataset} at {filename}, building one from scratch." 
# noqa: E501 + ) + model = ImageEncoder(args, keep_lang=True).model + template = get_templates(dataset) + classification_head = build_classification_head( + model, dataset, template, args.data_location, args.device + ) + os.makedirs(args.save, exist_ok=True) + classification_head.save(filename) + return classification_head diff --git a/src/linearize.py b/src/linearize.py new file mode 100644 index 0000000000000000000000000000000000000000..2523a39d3f3f14654105aa16c3499281a1150a49 --- /dev/null +++ b/src/linearize.py @@ -0,0 +1,179 @@ +import abc +import os + +import torch +import torch.nn as nn +from functorch import jvp, make_functional_with_buffers + +from src.modeling import ImageEncoder +from src.utils import DotDict + + +class LinearizedModel(nn.Module): + """Creates a linearized version of a nn.Module. + + The linearized version of a model is a proper PyTorch model and can be + trained as any other nn.Module. + + Args: + model (nn.Module): The model to linearize. The trainable parameters of + the linearized model will be initialized to the parameters of this + model. + init_model (nn.Module): A model of the same type as `model` containing + the parameters around which the model is initialized. If not + provided, `model` is used as the initialization model. + """ + + def __init__(self, model: nn.Module, init_model: nn.Module = None) -> None: + """Initializes the linearized model.""" + super().__init__() + if init_model is None: + init_model = model + + func0, params0, self.buffers0 = make_functional_with_buffers( + init_model.eval(), disable_autograd_tracking=True + ) + self.func0 = lambda params, x: func0(params, self.buffers0, x) + + _, params, _ = make_functional_with_buffers( + model, disable_autograd_tracking=True + ) + + self.params = nn.ParameterList(params) + self.params0 = nn.ParameterList(params0) + self._model_name = model.__class__.__name__ + + # The initial parameters are not trainable. 
+ for p in self.params0: + p.requires_grad = False + + # The params are. + for p in self.params: + p.requires_grad = True + + def forward(self, x, calculate_ortho_loss=False): + """ + Computes the linearized model output and optionally the orthogonality loss. + The method name is changed from __call__ to forward for clarity with DDP. + """ + dparams = [p - p0 for p, p0 in zip(self.params, self.params0)] + out, dp = jvp( + lambda param: self.func0(param, x), + (tuple(self.params0),), + (tuple(dparams),), + ) + output = out + dp + + if not calculate_ortho_loss: + return output + + # --- Integrate Orthogonality Loss Calculation --- + ortho_loss = 0.0 + for p_finetuned, p_pretrained in zip(self.params, self.params0): + if p_finetuned.requires_grad and p_finetuned.dim() == 2: + delta_W = p_finetuned - p_pretrained + + rows, cols = delta_W.shape + if rows < cols: + mat = delta_W @ delta_W.T + identity = torch.eye(rows, device=delta_W.device) + else: + mat = delta_W.T @ delta_W + identity = torch.eye(cols, device=delta_W.device) + + ortho_loss += torch.norm(mat - identity, p='fro') + + return output, ortho_loss + + # for tau_jp + def dp(self, x) -> torch.Tensor: + """ + Computes only the change in output (JVP) due to parameter shift. + Used for the 'Reg' penalty calculation. + """ + dparams = [p - p0 for p, p0 in zip(self.params, self.params0)] + _, dp = jvp( + lambda param: self.func0(param, x), + (tuple(self.params0),), + (tuple(dparams),), + ) + return dp + def __call__(self, x, calculate_ortho_loss=False): + return self.forward(x, calculate_ortho_loss) + + +class LinearizedImageEncoder(abc.ABC, nn.Module): + """Creates a linearized version of an image encoder.""" + + def __init__( + self, args=None, keep_lang=False, image_encoder=None, init_encoder=None + ): + super().__init__() + if image_encoder is None: + image_encoder = ImageEncoder(args, keep_lang) + if init_encoder is None: + init_encoder = image_encoder + + # Copy the attributes from the image encoder. 
+ self.train_preprocess = image_encoder.train_preprocess + self.val_preprocess = image_encoder.val_preprocess + self.cache_dir = image_encoder.cache_dir + + self._model_name = self._get_name(args.model) + self.model = LinearizedModel(init_model=init_encoder, model=image_encoder) + + def _get_name(self, model_name): + if "__pretrained__" in model_name: + model_name, _ = model_name.split("__pretrained__") + return model_name + + def forward(self, x, calculate_ortho_loss=False, pretrained_state_dict=None): + # Pass the flag down to the wrapped LinearizedModel + return self.model(x, calculate_ortho_loss=calculate_ortho_loss) + + def __call__(self, x, calculate_ortho_loss=False, pretrained_state_dict=None): + return self.forward(x, calculate_ortho_loss, pretrained_state_dict) + + def save(self, filename): + """Saves the linearized image encoder. + + We save the model name in the state dict so that we can load the + correct model when loading the linearized image encoder. Directly using + torch.save would not work because func0 is not serializable. + + Args: + filename (str): The path to save the taylorized image encoder. + """ + if os.path.dirname(filename) != "": + os.makedirs(os.path.dirname(filename), exist_ok=True) + + state_dict = self.state_dict() + state_dict["model_name"] = self._model_name + + torch.save(state_dict, filename) + + @classmethod + def load(cls, filename): + """Loads a linearized image encoder. + + It first loads the state dict with the model name and then creates the + correct model and loads the state dict. + + Args: + filename (str): The path to the taylorized image encoder. + + Returns: + LinearizedImageEncoder: The loaded taylorized image encoder. 
+ """ + print(f"Loading image encoder from {filename}") + state_dict = torch.load(filename, map_location="cpu") + + # ImageEncoder expects a DotDict + args = DotDict({"model": state_dict["model_name"]}) + taylorized_encoder = cls(args) + + # Remove the model name from the state dict so that we can load the + # model. + state_dict.pop("model_name") + taylorized_encoder.load_state_dict(state_dict) + return taylorized_encoder diff --git a/src/modeling.py b/src/modeling.py new file mode 100644 index 0000000000000000000000000000000000000000..846c2366e4ee8f03b3d94852be35c9c16dc37d48 --- /dev/null +++ b/src/modeling.py @@ -0,0 +1,209 @@ +import open_clip +import torch + +from src import utils + + +class ImageEncoder(torch.nn.Module): + def __init__(self, args, keep_lang=False): + super().__init__() + + print(f"Loading {args.model} pre-trained weights.") + if "__pretrained__" in args.model: + name, pretrained = args.model.split("__pretrained__") + elif "__init__" in args.model: + print("Using random initialization.") + name, pretrained = args.model.split("__init__")[0], None + else: + name = args.model + pretrained = "openai" + ( + self.model, + self.train_preprocess, + self.val_preprocess, + ) = open_clip.create_model_and_transforms( + name, pretrained=pretrained, cache_dir=args.openclip_cachedir + ) + + self.cache_dir = args.cache_dir + + if not keep_lang and hasattr(self.model, "transformer"): + delattr(self.model, "transformer") + + # def forward(self, images): + # assert self.model is not None + # return self.model.encode_image(images) + + # def __call__(self, inputs): + # return self.forward(inputs) + + def forward(self, images, calculate_ortho_loss=False, pretrained_state_dict=None): + """ + Extended forward method to optionally compute and return the orthogonal loss. 
+ """ + # Original forward pass + features = self.model.encode_image(images) + + # Return features directly if orthogonal loss is not needed + if not calculate_ortho_loss: + return features + + # --- Compute orthogonal loss if requested --- + # This logic is moved here from utils.py + if pretrained_state_dict is None: + raise ValueError("pretrained_state_dict must be provided when calculate_ortho_loss is True") + + ortho_loss = 0.0 + # self.model is the open_clip model (e.g. ViT); iterate over its parameters + for name, p_finetuned in self.model.named_parameters(): + if p_finetuned.requires_grad and p_finetuned.dim() == 2: + if name in pretrained_state_dict: + p_pretrained = pretrained_state_dict[name].to(p_finetuned.device) + + delta_W = p_finetuned - p_pretrained + + rows, cols = delta_W.shape + if rows < cols: + mat = delta_W @ delta_W.T + identity = torch.eye(rows, device=delta_W.device) + else: + mat = delta_W.T @ delta_W + identity = torch.eye(cols, device=delta_W.device) + + ortho_loss += torch.norm(mat - identity, p='fro') + + return features, ortho_loss + + def __call__(self, inputs, calculate_ortho_loss=False, pretrained_state_dict=None): + # Ensure __call__ forwards all arguments + return self.forward(inputs, calculate_ortho_loss, pretrained_state_dict) + + def save(self, filename): + print(f"Saving image encoder to {filename}") + utils.torch_save(self, filename) + + @classmethod + def load(cls, model_name, filename): + print(f"Loading image encoder from {filename}") + state_dict = torch.load(filename, map_location="cpu") + return cls.load_from_state_dict(model_name, state_dict) + + @classmethod + def load_from_state_dict(cls, model_name, state_dict): + # Build an encoder for model_name, then load the weights into the + # wrapped open_clip model. DotDict returns None for missing keys, + # so openclip_cachedir and cache_dir fall back to None here. + # Assumes state_dict holds the inner open_clip model's parameters. + args = utils.DotDict({"model": model_name}) + encoder = cls(args) + encoder.model.load_state_dict(state_dict) + return encoder + + +class ClassificationHead(torch.nn.Linear): + def __init__(self, normalize, weights, 
biases=None): + output_size, input_size = weights.shape + super().__init__(input_size, output_size) + self.normalize = normalize + if weights is not None: + self.weight = torch.nn.Parameter(weights.clone()) + if biases is not None: + self.bias = torch.nn.Parameter(biases.clone()) + else: + self.bias = torch.nn.Parameter(torch.zeros_like(self.bias)) + + def forward(self, inputs): + if self.normalize: + inputs = inputs / inputs.norm(dim=-1, keepdim=True) + return super().forward(inputs) + + def __call__(self, inputs): + return self.forward(inputs) + + def save(self, filename): + print(f"Saving classification head to {filename}") + utils.torch_save(self, filename) + + @classmethod + def load(cls, filename): + print(f"Loading classification head from {filename}") + return utils.torch_load(filename) + + +class ImageClassifier(torch.nn.Module): + def __init__(self, image_encoder, classification_head): + super().__init__() + self.image_encoder = image_encoder + self.classification_head = classification_head + if self.image_encoder is not None: + self.train_preprocess = self.image_encoder.train_preprocess + self.val_preprocess = self.image_encoder.val_preprocess + + def freeze_head(self): + self.classification_head.weight.requires_grad_(False) + self.classification_head.bias.requires_grad_(False) + + # def forward(self, inputs): + # features = self.image_encoder(inputs) + # outputs = self.classification_head(features) + # return outputs + + # def __call__(self, inputs): + # return self.forward(inputs) + + def forward(self, inputs, calculate_ortho_loss=False, pretrained_state_dict=None): + # Forward arguments to image_encoder + encoder_output = self.image_encoder(inputs, calculate_ortho_loss, pretrained_state_dict) + + if calculate_ortho_loss: + features, ortho_loss = encoder_output + outputs = self.classification_head(features) + return outputs, ortho_loss + else: + features = encoder_output + outputs = self.classification_head(features) + return outputs + + def 
__call__(self, inputs, calculate_ortho_loss=False, pretrained_state_dict=None): + return self.forward(inputs, calculate_ortho_loss, pretrained_state_dict) + + def save(self, filename): + print(f"Saving image classifier to {filename}") + utils.torch_save(self, filename) + + @classmethod + def load(cls, filename): + print(f"Loading image classifier from {filename}") + return utils.torch_load(filename) + + +class MultiHeadImageClassifier(torch.nn.Module): + def __init__(self, image_encoder, classification_heads): + super().__init__() + self.image_encoder = image_encoder + self.classification_heads = torch.nn.ModuleList(classification_heads) + if self.image_encoder is not None: + self.train_preprocess = self.image_encoder.train_preprocess + self.val_preprocess = self.image_encoder.val_preprocess + + def freeze_head(self): + for idx in range(len(self.classification_heads)): + self.classification_heads[idx].weight.requires_grad_(False) + self.classification_heads[idx].bias.requires_grad_(False) + + def forward(self, inputs, head_idx): + features = self.image_encoder(inputs) + outputs = self.classification_heads[head_idx](features) + return outputs + + def __call__(self, inputs, head_idx): + return self.forward(inputs, head_idx) + + def save(self, filename): + print(f"Saving image classifier to {filename}") + utils.torch_save(self, filename) + + @classmethod + def load(cls, filename): + print(f"Loading image classifier from {filename}") + return utils.torch_load(filename) diff --git a/src/task_vectors.py b/src/task_vectors.py new file mode 100644 index 0000000000000000000000000000000000000000..4b2489a230492f9fc28a41930d76b47ae1c85999 --- /dev/null +++ b/src/task_vectors.py @@ -0,0 +1,230 @@ +import abc +import torch + +from src.linearize import LinearizedImageEncoder +from src.modeling import ImageEncoder +from src.attention_only_finetune import AttentionOnlyFinetuneEncoder + + +class _TaskVector(abc.ABC): + def __init__( + self, pretrained_checkpoint=None, 
finetuned_checkpoint=None, vector=None + ): + """ + Initializes the task vector from a pretrained and a finetuned checkpoints. + This can either be done by passing two state dicts (one corresponding to the + pretrained model, and another to the finetuned model), or by directly passing in + the task vector state dict. + """ + if vector is not None: + self.vector = vector + else: + assert ( + pretrained_checkpoint is not None and finetuned_checkpoint is not None + ) + with torch.no_grad(): + pretrained_obj = self._load_checkpoint(pretrained_checkpoint) + finetuned_obj = self._load_checkpoint(finetuned_checkpoint) + + if hasattr(pretrained_obj, 'state_dict'): + pretrained_state_dict = pretrained_obj.state_dict() + else: + pretrained_state_dict = pretrained_obj + + if hasattr(finetuned_obj, 'state_dict'): + finetuned_state_dict = finetuned_obj.state_dict() + else: + finetuned_state_dict = finetuned_obj + + self.vector = {} + for key in pretrained_state_dict: + if pretrained_state_dict[key].dtype not in [torch.float32, torch.float16, torch.bfloat16]: + continue + if key in finetuned_state_dict: + self.vector[key] = ( + finetuned_state_dict[key] - pretrained_state_dict[key] + ) + + @abc.abstractmethod + def _load_checkpoint(self, checkpoint): + raise NotImplementedError + + @abc.abstractmethod + def _cast_to_same_type(self, other): + raise NotImplementedError + + def __add__(self, other): + other = self._cast_to_same_type(other) + with torch.no_grad(): + new_vector = {} + for key in self.vector: + if key not in other.vector: + print(f"Warning, key {key} is not present in both task vectors.") + continue + new_vector[key] = self.vector[key] + other.vector[key] + return self.__class__(vector=new_vector) + + def __sub__(self, other): + return self.__add__(-other) + + def __radd__(self, other): + if other is None or isinstance(other, int): + return self + return self.__add__(other) + + def __neg__(self): + with torch.no_grad(): + new_vector = {} + for key in self.vector: + 
new_vector[key] = -self.vector[key] + return self.__class__(vector=new_vector) + + def __pow__(self, power): + with torch.no_grad(): + new_vector = {} + for key in self.vector: + new_vector[key] = self.vector[key] ** power + return self.__class__(vector=new_vector) + + def __mul__(self, other): + with torch.no_grad(): + new_vector = {} + for key in self.vector: + new_vector[key] = other * self.vector[key] + return self.__class__(vector=new_vector) + + def dot(self, other): + other = self._cast_to_same_type(other) + with torch.no_grad(): + dot_product = 0.0 + for key in self.vector: + if key not in other.vector: + print(f"Warning, key {key} is not present in both task vectors.") + continue + dot_product += torch.sum(self.vector[key] * other.vector[key]) + return dot_product + + def norm(self): + return torch.sqrt(self.dot(self)) + + def apply_to(self, pretrained_checkpoint, scaling_coef=1.0): + """Apply a task vector to a pretrained model.""" + with torch.no_grad(): + pretrained_model = self._load_checkpoint(pretrained_checkpoint) + + if hasattr(pretrained_model, 'state_dict'): + new_state_dict = pretrained_model.state_dict() + else: + new_state_dict = pretrained_model.copy() + + pretrained_state_dict = new_state_dict.copy() + + for key in pretrained_state_dict: + if key in self.vector: + new_state_dict[key] = ( + pretrained_state_dict[key] + scaling_coef * self.vector[key] + ) + + if hasattr(pretrained_model, 'state_dict'): + pretrained_model.load_state_dict(new_state_dict) + return pretrained_model + else: + from src.args import parse_arguments + args = parse_arguments() + if isinstance(self, NonLinearTaskVector): + encoder = self._build_model_from_checkpoint(pretrained_checkpoint, args) + encoder.load_state_dict(new_state_dict) + return encoder + else: + pretrained_model.load_state_dict(new_state_dict) + return pretrained_model + + +class NonLinearTaskVector(_TaskVector): + """A task vector for nonlinear models.""" + + def _load_checkpoint(self, checkpoint): + 
return torch.load(checkpoint, map_location="cpu") + + def _build_model_from_checkpoint(self, checkpoint_path, args): + mode = args.finetuning_mode + if mode in ["linear-2", "linear-2_ortho"]: + return AttentionOnlyFinetuneEncoder(args) + return ImageEncoder(args) + + def apply_to(self, pretrained_checkpoint, scaling_coef=1.0): + with torch.no_grad(): + from src.args import parse_arguments + args = parse_arguments() + pretrained_model = self._build_model_from_checkpoint(pretrained_checkpoint, args) + pretrained_state_dict = torch.load(pretrained_checkpoint, map_location='cpu') + + if hasattr(pretrained_state_dict, 'state_dict'): + pretrained_state_dict = pretrained_state_dict.state_dict() + + new_state_dict = pretrained_state_dict.copy() + + for key in pretrained_state_dict: + if key in self.vector: + new_state_dict[key] += scaling_coef * self.vector[key] + + pretrained_model.load_state_dict(new_state_dict) + return pretrained_model + + def _cast_to_same_type(self, other): + if isinstance(other, LinearizedTaskVector): + return linear_to_nonlinear(other, self.vector.keys()) + return other + + +class LinearizedTaskVector(_TaskVector): + """A task vector for linearized models.""" + + def _load_checkpoint(self, checkpoint): + return LinearizedImageEncoder.load(checkpoint) + + def apply_to(self, pretrained_checkpoint, scaling_coef=1.0): + with torch.no_grad(): + pretrained_model = self._load_checkpoint(pretrained_checkpoint) + new_state_dict = pretrained_model.state_dict() + pretrained_state_dict = new_state_dict.copy() + + for key in pretrained_state_dict: + if key in self.vector: + new_state_dict[key] += scaling_coef * self.vector[key] + + pretrained_model.load_state_dict(new_state_dict) + return pretrained_model + + def get_named_parameters(self, param_names): + params = {k: v for k, v in self.vector.items() if "model.params0" not in k} + return {k: v for k, v in zip(param_names, params.values())} + + def _cast_to_same_type(self, other): + if isinstance(other, 
NonLinearTaskVector): + return nonlinear_to_linear(other) + return other + + +def nonlinear_to_linear(nonlinear_task_vector): + if isinstance(nonlinear_task_vector, LinearizedTaskVector): + return nonlinear_task_vector + else: + linear_params = { + f"model.params.{i}": v + for i, v in enumerate(nonlinear_task_vector.vector.values()) + } + linear_params.update({ + f"model.params0.{i}": torch.zeros_like(v) + for i, v in enumerate(nonlinear_task_vector.vector.values()) + }) + return LinearizedTaskVector(vector=linear_params) + + +def linear_to_nonlinear(linear_task_vector, param_names): + if isinstance(linear_task_vector, NonLinearTaskVector): + return linear_task_vector + else: + return NonLinearTaskVector( + vector=linear_task_vector.get_named_parameters(param_names) + ) diff --git a/src/utils.py b/src/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..8c6260bc271f972a25d4f58e1a4c413cb40f38c2 --- /dev/null +++ b/src/utils.py @@ -0,0 +1,181 @@ +import os +import pickle + +import numpy as np +import torch + + +def assign_learning_rate(param_group, new_lr): + param_group["lr"] = new_lr + + +def _warmup_lr(base_lr, warmup_length, step): + return base_lr * (step + 1) / warmup_length + + +def cosine_lr(optimizer, base_lrs, warmup_length, steps): + if not isinstance(base_lrs, list): + base_lrs = [base_lrs for _ in optimizer.param_groups] + assert len(base_lrs) == len(optimizer.param_groups) + + def _lr_adjuster(step): + for param_group, base_lr in zip(optimizer.param_groups, base_lrs): + if step < warmup_length: + lr = _warmup_lr(base_lr, warmup_length, step) + else: + e = step - warmup_length + es = steps - warmup_length + lr = 0.5 * (1 + np.cos(np.pi * e / es)) * base_lr + assign_learning_rate(param_group, lr) + + return _lr_adjuster + + +def accuracy(output, target, topk=(1,)): + pred = output.topk(max(topk), 1, True, True)[1].t() + correct = pred.eq(target.view(1, -1).expand_as(pred)) + return [ + float(correct[:k].reshape(-1).float().sum(0, 
keepdim=True).cpu().numpy()) + for k in topk + ] + + +def torch_load_old(save_path, device=None): + with open(save_path, "rb") as f: + classifier = pickle.load(f) + if device is not None: + classifier = classifier.to(device) + return classifier + + +def torch_save(model, save_path): + if os.path.dirname(save_path) != "": + os.makedirs(os.path.dirname(save_path), exist_ok=True) + torch.save(model, save_path) + + +def torch_load(save_path, device=None): + model = torch.load(save_path, map_location="cpu") + if device is not None: + model = model.to(device) + return model + + +def get_logits(inputs, classifier): + assert callable(classifier) + if hasattr(classifier, "to"): + classifier = classifier.to(inputs.device) + return classifier(inputs) + + +def get_probs(inputs, classifier): + if hasattr(classifier, "predict_proba"): + probs = classifier.predict_proba(inputs.detach().cpu().numpy()) + return torch.from_numpy(probs) + logits = get_logits(inputs, classifier) + return logits.softmax(dim=1) + + +class LabelSmoothing(torch.nn.Module): + def __init__(self, smoothing=0.0): + super(LabelSmoothing, self).__init__() + self.confidence = 1.0 - smoothing + self.smoothing = smoothing + + def forward(self, x, target): + logprobs = torch.nn.functional.log_softmax(x, dim=-1) + + nll_loss = -logprobs.gather(dim=-1, index=target.unsqueeze(1)) + nll_loss = nll_loss.squeeze(1) + smooth_loss = -logprobs.mean(dim=-1) + loss = self.confidence * nll_loss + self.smoothing * smooth_loss + return loss.mean() + + +class DotDict(dict): + """dot.notation access to dictionary attributes""" + + __getattr__ = dict.get + __setattr__ = dict.__setitem__ + __delattr__ = dict.__delitem__ + + +def find_optimal_coef( + results, + metric="avg_normalized_top1", + minimize=False, + control_metric=None, + control_metric_threshold=0.0, +): + best_coef = None + if minimize: + best_metric = 1 + else: + best_metric = 0 + for scaling_coef in results.keys(): + if control_metric is not None: + if 
results[scaling_coef][control_metric] < control_metric_threshold: + print(f"Control metric fell below {control_metric_threshold} threshold") + continue + if minimize: + if results[scaling_coef][metric] < best_metric: + best_metric = results[scaling_coef][metric] + best_coef = scaling_coef + else: + if results[scaling_coef][metric] > best_metric: + best_metric = results[scaling_coef][metric] + best_coef = scaling_coef + return best_coef + + +def nonlinear_advantage(nonlinear_acc, linear_acc, num_classes): + return (nonlinear_acc - linear_acc) / (1.0 - 1.0 / num_classes) + + +def calculate_linearized_orthogonality_loss(linearized_model): + """Compute orthogonality loss ||delta_W^T delta_W - I||_F for a LinearizedModel.""" + ortho_loss = 0.0 + for p_finetuned, p_pretrained in zip(linearized_model.params, linearized_model.params0): + if p_finetuned.requires_grad and p_finetuned.dim() == 2: + delta_W = p_finetuned - p_pretrained + + rows, cols = delta_W.shape + if rows < cols: + mat = delta_W @ delta_W.T + identity = torch.eye(rows, device=delta_W.device) + else: + mat = delta_W.T @ delta_W + identity = torch.eye(cols, device=delta_W.device) + + ortho_loss += torch.norm(mat - identity, p='fro') + + return ortho_loss + + +def calculate_standard_orthogonality_loss(model, pretrained_state_dict): + """Compute orthogonality loss ||delta_W^T delta_W - I||_F for standard/linear-2 finetuning. + + Args: + model: DDP-wrapped ImageClassifier (ddp_model). + pretrained_state_dict: snapshot of the pretrained model's inner ViT state_dict. 
+ """ + ortho_loss = 0.0 + + for name, p_finetuned in model.module.image_encoder.model.named_parameters(): + if p_finetuned.requires_grad and p_finetuned.dim() == 2: + if name in pretrained_state_dict: + p_pretrained = pretrained_state_dict[name].to(p_finetuned.device) + + delta_W = p_finetuned - p_pretrained + + rows, cols = delta_W.shape + if rows < cols: + mat = delta_W @ delta_W.T + identity = torch.eye(rows, device=delta_W.device) + else: + mat = delta_W.T @ delta_W + identity = torch.eye(cols, device=delta_W.device) + + ortho_loss += torch.norm(mat - identity, p='fro') + + return ortho_loss