Upload folder using huggingface_hub
- .gitattributes +2 -0
- README.md +207 -0
- assets/WVO-WD-TFS.png +3 -0
- assets/orthoreg_loss.png +3 -0
- environment.yml +140 -0
- src/__init__.py +0 -0
- src/__pycache__/__init__.cpython-310.pyc +0 -0
- src/__pycache__/args.cpython-310.pyc +0 -0
- src/__pycache__/attention_only_finetune.cpython-310.pyc +0 -0
- src/__pycache__/distributed.cpython-310.pyc +0 -0
- src/__pycache__/eval.cpython-310.pyc +0 -0
- src/__pycache__/heads.cpython-310.pyc +0 -0
- src/__pycache__/linearize.cpython-310.pyc +0 -0
- src/__pycache__/modeling.cpython-310.pyc +0 -0
- src/__pycache__/task_vectors.cpython-310.pyc +0 -0
- src/__pycache__/utils.cpython-310.pyc +0 -0
- src/args.py +153 -0
- src/attention_only_finetune.py +116 -0
- src/datasets/__pycache__/cars.cpython-310.pyc +0 -0
- src/datasets/__pycache__/cifar10.cpython-310.pyc +0 -0
- src/datasets/__pycache__/cifar100.cpython-310.pyc +0 -0
- src/datasets/__pycache__/common.cpython-310.pyc +0 -0
- src/datasets/__pycache__/dtd.cpython-310.pyc +0 -0
- src/datasets/__pycache__/emnist.cpython-310.pyc +0 -0
- src/datasets/__pycache__/eurosat.cpython-310.pyc +0 -0
- src/datasets/__pycache__/gtsrb.cpython-310.pyc +0 -0
- src/datasets/__pycache__/imagenet.cpython-310.pyc +0 -0
- src/datasets/__pycache__/kmnist.cpython-310.pyc +0 -0
- src/datasets/__pycache__/mnist.cpython-310.pyc +0 -0
- src/datasets/__pycache__/oxfordpets.cpython-310.pyc +0 -0
- src/datasets/__pycache__/registry.cpython-310.pyc +0 -0
- src/datasets/__pycache__/resisc45.cpython-310.pyc +0 -0
- src/datasets/__pycache__/stl10.cpython-310.pyc +0 -0
- src/datasets/__pycache__/sun397.cpython-310.pyc +0 -0
- src/datasets/__pycache__/svhn.cpython-310.pyc +0 -0
- src/datasets/__pycache__/templates.cpython-310.pyc +0 -0
- src/datasets/cars.py +155 -0
- src/datasets/cifar10.py +56 -0
- src/datasets/cifar100.py +30 -0
- src/datasets/common.py +139 -0
- src/datasets/dtd.py +34 -0
- src/datasets/emnist.py +74 -0
- src/datasets/eurosat.py +75 -0
- src/datasets/gtsrb.py +205 -0
- src/datasets/imagenet.py +253 -0
- src/datasets/kmnist.py +39 -0
- src/datasets/mnist.py +41 -0
- src/datasets/oxfordpets.py +38 -0
- src/datasets/registry.py +103 -0
- src/datasets/resisc45.py +304 -0
.gitattributes
CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+assets/WVO-WD-TFS.png filter=lfs diff=lfs merge=lfs -text
+assets/orthoreg_loss.png filter=lfs diff=lfs merge=lfs -text
README.md
ADDED
@@ -0,0 +1,207 @@
# Understanding and Enforcing Weight Disentanglement in Task Arithmetic

[CVPR 2026] Official code of the paper **"Understanding and Enforcing Weight Disentanglement in Task Arithmetic"**.

[[Paper](https://arxiv.org/abs/2604.17078)] [[Checkpoints](#-checkpoints)] [[Datasets](#-datasets)]

---

## 🎯 Abstract

Task arithmetic provides an efficient, training-free way to edit pre-trained models, yet lacks a fundamental theoretical explanation for its success. The existing concept of "weight disentanglement" describes the ideal outcome of non-interfering task composition but does not reveal its underlying cause. Crucially, what intrinsic properties of the pre-trained model ($\theta_0$) or the task vectors ($\tau_t$) enable this disentanglement remains underexplored. In this paper, we introduce Task-Feature Specialization (TFS), a model's ability to allocate distinct internal features to different tasks, as the fundamental principle. We first prove that TFS is a sufficient condition for weight disentanglement. More importantly, we find that TFS also gives rise to an observable geometric consequence: weight vector orthogonality. This positions TFS as the common cause of both the desired functional outcome (disentanglement) and a measurable geometric property (orthogonality). This relationship provides the key insight for our method: since the abstract TFS property is intractable to enforce directly, we can instead promote weight disentanglement by shaping its concrete geometric consequence, orthogonality. We therefore propose OrthoReg, a simple and effective regularization method that actively enforces an internal orthogonal structure on the weight updates ($\Delta W$) that constitute $\tau_t$ during fine-tuning, and we theoretically prove that OrthoReg promotes disentanglement. Extensive experiments demonstrate that OrthoReg consistently and significantly enhances the performance of various task arithmetic methods.

<p align="center">
<img src="assets/WVO-WD-TFS.png" width="500"/>
<br>
<em>TFS is the common cause connecting Weight Vector Orthogonality (WVO) with Weight Disentanglement (WD).</em>
</p>

### ✨ Key Contributions

- 📐 **Theory**: We identify TFS as a sufficient condition for weight disentanglement, and WVO as its geometric consequence, providing the first principled explanation for task arithmetic.
- 🔧 **Method (OrthoReg)**: A simple regularization term added to the fine-tuning loss that enforces column-wise orthogonality on ΔW, for which we prove theoretical efficacy.
- 🔗 **Connection to TTA**: We show that OrthoReg and Tangent Task Arithmetic (TTA) share the same underlying mechanism (i.e., inter-task vector orthogonality), but OrthoReg achieves this more efficiently.
- 📊 **Experiments**: Consistent and significant improvements over Non-linear FT, TTA, ATT-FT, and LoRA-ATT across ViT-B-32, ViT-B-16, and ViT-L-14.

---

### The OrthoReg Loss

<p align="center">
<img src="assets/orthoreg_loss.png" width="560"/>
</p>

The total loss adds a regularization term to the standard task objective:

$$\mathcal{L} = \mathcal{L}_{\text{task}}(\theta_0 + \Delta\theta) + \lambda \cdot \mathcal{L}_{\text{ortho}}(\Delta\theta)$$

$$\mathcal{L}_{\text{ortho}}(\Delta\theta) = \sum_l \left\|(\Delta W^{(l)})^\top \Delta W^{(l)} - I\right\|_F^2$$
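To make the penalty concrete, here is a minimal pure-Python sketch for a single update matrix ΔW (the helper name `ortho_penalty` is ours for illustration, not the repository's API; the actual implementation operates on PyTorch tensors):

```python
def ortho_penalty(delta_w):
    """|| (dW)^T dW - I ||_F^2 for one weight update, given as a list of rows."""
    rows, cols = len(delta_w), len(delta_w[0])
    # Gram matrix G = (dW)^T dW: a cols x cols matrix of column inner products
    gram = [[sum(delta_w[r][i] * delta_w[r][j] for r in range(rows))
             for j in range(cols)] for i in range(cols)]
    # Squared Frobenius distance from the identity
    return sum((gram[i][j] - (1.0 if i == j else 0.0)) ** 2
               for i in range(cols) for j in range(cols))

# A column-orthonormal update incurs zero penalty...
print(ortho_penalty([[1.0, 0.0], [0.0, 1.0], [0.0, 0.0]]))  # 0.0
# ...while badly scaled columns are penalized
print(ortho_penalty([[2.0, 0.0], [0.0, 1.0]]))  # 9.0
```

During fine-tuning this penalty is summed over layers and added to the task loss with weight λ (set via `--ortho-lambda`).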
---

## 🛠️ Installation

This codebase is built on top of [Tangent Task Arithmetic (TTA)](https://github.com/gortizji/tangent_task_arithmetic). Environment setup follows theirs exactly.

To run the code, please install all its dependencies:
```sh
conda env create
conda activate tangent-arithmetic
```
and add the `src` directory to the `PYTHONPATH`:
```sh
cd OrthoReg
export PYTHONPATH="$PYTHONPATH:$PWD"
```

---

## 📦 Datasets

We evaluate on 8 image classification benchmarks following [Task Arithmetic](https://github.com/mlfoundations/task_vectors) and [TTA](https://github.com/gortizji/tangent_task_arithmetic):

**Cars · DTD · EuroSAT · GTSRB · MNIST · RESISC45 · SUN397 · SVHN**

For dataset download and preparation, please follow the instructions in the [TTA repository](https://github.com/gortizji/tangent_task_arithmetic#datasets).

We also provide a pre-packaged dataset archive for convenience:

> 📥 **Dataset Download:** `https://pan.baidu.com/s/1PgLyjUrAhsmgSAz4ms5mcQ?pwd=fwf5`

Set the root path via `--data-location /path/to/datasets/`.

---

## 🚀 Quick Start

All scripts are run from the `OrthoReg/` directory. This repository implements **6 finetuning modes**:

| `--finetuning-mode` | Description |
|---|---|
| `standard` | Non-linear full fine-tuning (baseline) |
| `standard_ortho` | Non-linear FT + OrthoReg |
| `linear` | TTA — tangent space fine-tuning (baseline) |
| `linear_ortho` | TTA + OrthoReg |
| `linear-2` | ATT-FT — attention-only fine-tuning (baseline) |
| `linear-2_ortho` | ATT-FT + OrthoReg |

> **Note on LoRA-ATT:** The LoRA-ATT and LoRA-ATT+OrthoReg results from the paper are implemented in a separate repository due to the complexity of patching OpenCLIP's fused QKV projection. Code will be released at: `https://github.com/lshangge/OrthoReg_lora`

### Step 1 — Fine-tune

```bash
python src/finetune.py \
    --model ViT-B-32 \
    --finetuning-mode standard_ortho \
    --ortho-lambda 10 \
    --lr 1e-5 \
    --data-location /path/to/datasets/
```

Switch between all six modes by changing `--finetuning-mode` and `--ortho-lambda`:

```bash
--finetuning-mode standard        --ortho-lambda 0   # Non-linear FT
--finetuning-mode standard_ortho  --ortho-lambda xx  # Non-linear FT + OrthoReg
--finetuning-mode linear          --ortho-lambda 0   # TTA
--finetuning-mode linear_ortho    --ortho-lambda xx  # TTA + OrthoReg
--finetuning-mode linear-2        --ortho-lambda 0   # ATT-FT
--finetuning-mode linear-2_ortho  --ortho-lambda xx  # ATT-FT + OrthoReg
```

Checkpoints are saved to:
- `checkpoints_{seed}/{mode}_{lr}_{model}/` — for baselines
- `checkpoints_{seed}/{mode}_{lr}_lambda{lambda}_{model}/` — for OrthoReg variants

### Step 2 — Evaluate Single-Task Accuracy

```bash
python src/eval_single_task.py \
    --model ViT-B-32 \
    --finetuning-mode standard_ortho \
    --ortho-lambda 10 \
    --lr 1e-5 \
    --data-location /path/to/datasets/
```

> Run `eval_single_task` with `--finetuning-mode none --ortho-lambda 0` first to generate `zeroshot_accuracies.json`, which is required as the reference for normalized accuracy in Steps 3–4.
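As a hedged sketch of the normalized-accuracy bookkeeping (the helper below is hypothetical; we assume each task's accuracy is divided by a per-dataset reference such as the zero-shot accuracies generated above — the eval scripts define the exact reference):

```python
def normalized_accuracies(absolute, reference):
    """Divide each dataset's accuracy by its per-dataset reference accuracy."""
    return {name: absolute[name] / reference[name] for name in absolute}

zeroshot = {"MNIST": 0.48, "EuroSAT": 0.45}  # illustrative numbers only
merged = {"MNIST": 0.96, "EuroSAT": 0.90}
print(normalized_accuracies(merged, zeroshot))  # {'MNIST': 2.0, 'EuroSAT': 2.0}
```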
### Step 3 — Evaluate Task Addition

```bash
python src/eval_task_addition.py \
    --model ViT-B-32 \
    --finetuning-mode standard_ortho \
    --ortho-lambda 10 \
    --lr 1e-5 \
    --data-location /path/to/datasets/
```

### Step 4 — Evaluate Task Negation

```bash
python src/eval_task_negation.py \
    --model ViT-B-32 \
    --finetuning-mode standard_ortho \
    --ortho-lambda 10 \
    --lr 1e-5 \
    --data-location /path/to/datasets/
```
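Conceptually, both evaluations manipulate task vectors τ_t = θ_t − θ_0: task addition applies θ_0 + α Σ_t τ_t, and task negation applies θ_0 − α τ_t. A toy sketch on named scalar "parameters" (hypothetical helper names, not the repository's `src/task_vectors.py` API):

```python
def task_vector(pretrained, finetuned):
    """tau = theta_finetuned - theta_pretrained, per parameter name."""
    return {k: finetuned[k] - pretrained[k] for k in pretrained}

def apply_task_vectors(pretrained, vectors, alpha):
    """theta = theta_0 + alpha * sum of task vectors (negate a vector to forget a task)."""
    merged = dict(pretrained)
    for tau in vectors:
        for k, v in tau.items():
            merged[k] += alpha * v
    return merged

theta0 = {"w": 1.0}
tau_a = task_vector(theta0, {"w": 1.5})   # {'w': 0.5}
tau_b = task_vector(theta0, {"w": 0.75})  # {'w': -0.25}

addition = apply_task_vectors(theta0, [tau_a, tau_b], alpha=0.5)
negation = apply_task_vectors(theta0, [{k: -v for k, v in tau_a.items()}], alpha=0.5)
print(addition["w"], negation["w"])  # 1.125 0.75
```

The scaling coefficient α corresponds to `--alpha`; when it is `None`, the eval scripts search over `--n-eval-points` candidate values on the validation set.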
---

## 🔧 Key Arguments

| Argument | Default | Description |
|---|:---:|---|
| `--model` | `ViT-B-32` | CLIP model architecture |
| `--finetuning-mode` | — | One of the 6 modes above |
| `--ortho-lambda` | `0.0` | OrthoReg strength λ; set to `0` for baselines |
| `--lr` | `0.001` | Learning rate (the examples above use `1e-5`) |
| `--seed` | `1993` | Random seed |
| `--world-size` | `1` | Number of GPUs (DDP) |
| `--data-location` | — | Dataset root directory |
| `--batch-size` | `128` | Batch size per GPU |

---

## 📁 Checkpoints

We release fine-tuned checkpoints for ViT-B-32, ViT-B-16, and ViT-L-14 on all 8 tasks, covering all 6 modes.

> 📥 **Checkpoint Download:** `https://huggingface.co/gezi2333/OrthoReg_checkpoints`

Unzip into `OrthoReg/checkpoints_{seed}/` and pass the corresponding `--seed`, `--lr`, and `--ortho-lambda` to the eval scripts to reproduce the paper's results directly.

---

## 📝 Citation

If you find this work useful, please cite:

```bibtex
@inproceedings{liu2026orthoreg,
  title     = {Understanding and Enforcing Weight Disentanglement in Task Arithmetic},
  author    = {Liu, Shangge and Yin, Yuehan and Wang, Lei and Fan, Qi and
               Shi, Yinghuan and Li, Wenbin and Gao, Yang and Tao, Dacheng},
  booktitle = {CVPR},
  year      = {2026}
}
```

---

## 📞 Contact

For questions or issues, please:

- Open an issue on GitHub
- Contact the authors at [lshangge@smail.nju.edu.cn](mailto:lshangge@smail.nju.edu.cn)

---

## 📬 Acknowledgements

This codebase is built on top of [Task Arithmetic](https://github.com/mlfoundations/task_vectors), [Tangent Task Arithmetic](https://github.com/gortizji/tangent_task_arithmetic), and [Attention-Only Fine-tuning](https://github.com/kyrie-23/linear_task_arithmetic). We thank the authors for releasing their code.
assets/WVO-WD-TFS.png
ADDED
Git LFS Details
assets/orthoreg_loss.png
ADDED
Git LFS Details
environment.yml
ADDED
@@ -0,0 +1,140 @@
name: tangent-arithmetic
channels:
  - pytorch
  - nvidia
  - defaults
dependencies:
  - _libgcc_mutex=0.1
  - _openmp_mutex=5.1
  - blas=1.0
  - brotlipy=0.7.0
  - bzip2=1.0.8
  - ca-certificates=2023.05.30
  - certifi=2023.5.7
  - cffi=1.15.1
  - charset-normalizer=2.0.4
  - cryptography=39.0.1
  - cuda=11.6.1
  - cuda-cccl=11.6.55
  - cuda-command-line-tools=11.6.2
  - cuda-compiler=11.6.2
  - cuda-cudart=11.6.55
  - cuda-cudart-dev=11.6.55
  - cuda-cuobjdump=11.6.124
  - cuda-cupti=11.6.124
  - cuda-cuxxfilt=11.6.124
  - cuda-driver-dev=11.6.55
  - cuda-gdb=12.1.105
  - cuda-libraries=11.6.1
  - cuda-libraries-dev=11.6.1
  - cuda-memcheck=11.8.86
  - cuda-nsight=12.1.105
  - cuda-nsight-compute=12.1.1
  - cuda-nvcc=11.6.124
  - cuda-nvdisasm=12.1.105
  - cuda-nvml-dev=11.6.55
  - cuda-nvprof=12.1.105
  - cuda-nvprune=11.6.124
  - cuda-nvrtc=11.6.124
  - cuda-nvrtc-dev=11.6.124
  - cuda-nvtx=11.6.124
  - cuda-nvvp=12.1.105
  - cuda-runtime=11.6.1
  - cuda-samples=11.6.101
  - cuda-sanitizer-api=12.1.105
  - cuda-toolkit=11.6.1
  - cuda-tools=11.6.1
  - cuda-visual-tools=11.6.1
  - ffmpeg=4.3
  - freetype=2.12.1
  - gds-tools=1.6.1.9
  - giflib=5.2.1
  - gmp=6.2.1
  - gnutls=3.6.15
  - idna=3.4
  - intel-openmp=2023.1.0
  - jpeg=9e
  - lame=3.100
  - lcms2=2.12
  - ld_impl_linux-64=2.38
  - lerc=3.0
  - libcublas=11.9.2.110
  - libcublas-dev=11.9.2.110
  - libcufft=10.7.1.112
  - libcufft-dev=10.7.1.112
  - libcufile=1.6.1.9
  - libcufile-dev=1.6.1.9
  - libcurand=10.3.2.106
  - libcurand-dev=10.3.2.106
  - libcusolver=11.3.4.124
  - libcusparse=11.7.2.124
  - libcusparse-dev=11.7.2.124
  - libdeflate=1.17
  - libffi=3.4.4
  - libgcc-ng=11.2.0
  - libgomp=11.2.0
  - libiconv=1.16
  - libidn2=2.3.4
  - libnpp=11.6.3.124
  - libnpp-dev=11.6.3.124
  - libnvjpeg=11.6.2.124
  - libnvjpeg-dev=11.6.2.124
  - libpng=1.6.39
  - libstdcxx-ng=11.2.0
  - libtasn1=4.19.0
  - libtiff=4.5.0
  - libunistring=0.9.10
  - libuuid=1.41.5
  - libwebp=1.2.4
  - libwebp-base=1.2.4
  - lz4-c=1.9.4
  - mkl=2023.1.0
  - mkl-service=2.4.0
  - mkl_fft=1.3.6
  - mkl_random=1.2.2
  - ncurses=6.4
  - nettle=3.7.3
  - nsight-compute=2023.1.1.4
  - numpy=1.24.3
  - numpy-base=1.24.3
  - openh264=2.1.1
  - openssl=1.1.1t
  - pillow=9.4.0
  - pip=23.0.1
  - pycparser=2.21
  - pyopenssl=23.0.0
  - pysocks=1.7.1
  - python=3.10.11
  - pytorch=1.13.1
  - pytorch-cuda=11.6
  - pytorch-mutex=1.0
  - readline=8.2
  - requests=2.29.0
  - setuptools=67.8.0
  - sqlite=3.41.2
  - tbb=2021.8.0
  - tk=8.6.12
  - torchaudio=0.13.1
  - torchvision=0.14.1
  - typing_extensions=4.5.0
  - tzdata=2023c
  - urllib3=1.26.16
  - wheel=0.38.4
  - xz=5.4.2
  - zlib=1.2.13
  - zstd=1.5.5
  - pip:
      - filelock==3.12.0
      - fsspec==2023.5.0
      - ftfy==6.1.1
      - huggingface-hub==0.15.1
      - open-clip-torch==2.10.1
      - packaging==23.1
      - protobuf==3.20.3
      - pyyaml==6.0
      - regex==2023.6.3
      - safetensors==0.3.1
      - scipy==1.10.1
      - sentencepiece==0.1.99
      - timm==0.9.2
      - wcwidth==0.2.6
src/__init__.py
ADDED
File without changes
src/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (139 Bytes)
src/__pycache__/args.cpython-310.pyc
ADDED
Binary file (3.4 kB)
src/__pycache__/attention_only_finetune.cpython-310.pyc
ADDED
Binary file (3.6 kB)
src/__pycache__/distributed.cpython-310.pyc
ADDED
Binary file (1.19 kB)
src/__pycache__/eval.cpython-310.pyc
ADDED
Binary file (3.38 kB)
src/__pycache__/heads.cpython-310.pyc
ADDED
Binary file (1.92 kB)
src/__pycache__/linearize.cpython-310.pyc
ADDED
Binary file (6.29 kB)
src/__pycache__/modeling.cpython-310.pyc
ADDED
Binary file (6.52 kB)
src/__pycache__/task_vectors.cpython-310.pyc
ADDED
Binary file (7.75 kB)
src/__pycache__/utils.cpython-310.pyc
ADDED
Binary file (5.95 kB)
src/args.py
ADDED
@@ -0,0 +1,153 @@
import argparse
import os

import torch


def parse_arguments():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--data-location",
        type=str,
        default=os.path.expanduser("/path/datasets/"),
        help="The root directory for the datasets.",
    )
    parser.add_argument(
        "--eval-datasets",
        default=None,
        type=lambda x: x.split(","),
        help="Which datasets to use for evaluation. Split by comma, e.g. MNIST,EuroSAT.",
    )
    parser.add_argument(
        "--train-dataset",
        default=None,
        type=lambda x: x.split(","),
        help="Which dataset(s) to patch on.",
    )
    parser.add_argument(
        "--exp_name",
        type=str,
        default=None,
        help="Name of the experiment, for organization purposes only.",
    )
    parser.add_argument(
        "--results-db",
        type=str,
        default=None,
        help="Where to store the results; if None, results are not stored.",
    )
    parser.add_argument(
        "--model",
        type=str,
        default="ViT-B-32",
        help="The type of model (e.g. RN50, ViT-B-32).",
    )
    parser.add_argument(
        "--batch-size",
        type=int,
        default=128,
    )
    parser.add_argument(
        "--num-grad-accumulation",
        type=int,
        default=1,
        help="Number of gradient accumulation steps.",
    )
    parser.add_argument("--lr", type=float, default=0.001, help="Learning rate.")
    parser.add_argument("--wd", type=float, default=0.1, help="Weight decay.")
    parser.add_argument("--ls", type=float, default=0.0, help="Label smoothing.")
    parser.add_argument(
        "--warmup_length",
        type=int,
        default=500,
    )
    parser.add_argument(
        "--epochs",
        type=int,
        default=10,
    )
    parser.add_argument(
        "--load",
        type=lambda x: x.split(","),
        default=None,
        help="Optionally load _classifiers_, e.g. a zero-shot classifier or probe, or an ensemble of both.",
    )
    parser.add_argument(
        "--save",
        type=str,
        default=None,
        help="Optionally save a _classifier_, e.g. a zero-shot classifier or probe.",
    )
    parser.add_argument(
        "--cache-dir",
        type=str,
        default=None,
        help="Directory for caching features and encoder.",
    )
    parser.add_argument(
        "--openclip-cachedir",
        type=str,
        default=os.path.expanduser("~/openclip-cachedir/open_clip"),
        help="Directory for caching models from OpenCLIP.",
    )
    parser.add_argument(
        "--world-size",
        type=int,
        default=1,
        help="Number of processes for distributed training.",
    )
    parser.add_argument(
        "--checkpoint-every",
        type=int,
        default=-1,
        help="How often to checkpoint the model.",
    )
    parser.add_argument(
        "--port",
        type=int,
        default=12355,
        help="Port for distributed training.",
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=1993,
        help="Random seed.",
    )
    parser.add_argument(
        "--finetuning-mode",
        choices=["standard", "standard_ortho", "linear", "linear_ortho", "linear-2", "linear-2_ortho"],
        help="Finetuning mode: standard/linear/linear-2, with optional ortho regularization.",
    )
    parser.add_argument(
        "--n-eval-points",
        type=int,
        default=21,
        help="Number of evaluation points used to find the optimal coefficient in task arithmetic.",
    )
    parser.add_argument(
        "--ortho-lambda",
        type=float,
        default=0.0,
        help="Weight of the orthogonality regularization term. Default 0.0 means no regularization.",
    )
    parser.add_argument(
        "--control_threshold",
        type=float,
        default=0.95,
        help="Control-dataset performance degradation threshold.",
    )
    parser.add_argument(
        "--alpha",
        type=float,
        default=None,
        help="Manually specify the scaling coefficient for task vectors. If None, it is optimized on the validation set.",
    )

    parsed_args = parser.parse_args()
    parsed_args.device = "cuda" if torch.cuda.is_available() else "cpu"

    if parsed_args.load is not None and len(parsed_args.load) == 1:
        parsed_args.load = parsed_args.load[0]

    return parsed_args
src/attention_only_finetune.py
ADDED
@@ -0,0 +1,116 @@
| 1 |
+
import os
|
| 2 |
+
import torch
|
| 3 |
+
import torch.nn as nn
|
| 4 |
+
from src.modeling import ImageEncoder
|
| 5 |
+
from src.utils import DotDict
|
| 6 |
+
|
| 7 |
+
class AttentionOnlyFinetuneEncoder(ImageEncoder):
|
| 8 |
+
"""
|
| 9 |
+
A specialized ImageEncoder that fine-tunes only the attention module weights in the ViT.
|
| 10 |
+
Corresponds to the method described in Jin et al. (2025).
|
| 11 |
+
"""
|
| 12 |
+
def __init__(self, args, keep_lang=False):
|
| 13 |
+
# 1. Call the parent constructor to build the full model as usual
|
| 14 |
+
super().__init__(args, keep_lang=keep_lang)
|
| 15 |
+
|
| 16 |
+
self.args = args
|
| 17 |
+
|
| 18 |
+
# 2. Freeze all model parameters
|
| 19 |
+
# print("Freezing all parameters of the model initially...")
|
| 20 |
+
for param in self.model.parameters():
|
| 21 |
+
param.requires_grad = False
|
| 22 |
+
|
| 23 |
+
# 3. Unfreeze only the Attention module weights (Wq, Wk, Wv, Wo)
|
| 24 |
+
# print("Unfreezing Attention module weights for fine-tuning...")
|
| 25 |
+
self._unfreeze_attention_weights(self.model.visual)
|
| 26 |
+
|
| 27 |
+
# 4. (Optional but recommended) Print trainable parameters for verification
|
| 28 |
+
# self._verify_trainable_params()
|
| 29 |
+
|
| 30 |
+
def _unfreeze_attention_weights(self, vit_model):
|
| 31 |
+
"""
|
| 32 |
+
Iterate over all Transformer blocks and unfreeze the attention projection weights.
|
| 33 |
+
"""
|
| 34 |
+
# Iterate over the model and unfreeze target parameters
|
| 35 |
+
for block in vit_model.transformer.resblocks:
|
| 36 |
+
# Unfreeze the combined input projection weight for Q, K, V
|
| 37 |
+
block.attn.in_proj_weight.requires_grad = True
|
| 38 |
+
|
| 39 |
+
# Unfreeze the output projection weight
|
| 40 |
+
block.attn.out_proj.weight.requires_grad = True
|
| 41 |
+
|
| 42 |
+
# Per the paper's ablation study, not fine-tuning biases yields better results; keep them frozen
|
            # block.attn.in_proj_bias.requires_grad = True
            # block.attn.out_proj.bias.requires_grad = True

    def _verify_trainable_params(self):
        """Print all trainable parameters for debugging and verification."""
        print("=" * 80)
        print("Trainable parameters in AttentionOnlyFinetuneEncoder:")
        trainable_params_count = 0
        for name, param in self.model.named_parameters():
            if param.requires_grad:
                print(f"  - {name}")
                trainable_params_count += param.numel()
        print(f"Total trainable parameters: {trainable_params_count / 1e6:.2f}M")
        print("=" * 80)

    def forward(self, images, calculate_ortho_loss=False, pretrained_state_dict=None):
        """
        Extended forward method to optionally compute and return the orthogonal loss.
        Consistent with the logic implemented for standard_ortho.
        """
        # Original forward pass
        features = self.model.encode_image(images)

        # Return features directly if orthogonal loss is not needed
        if not calculate_ortho_loss:
            return features

        # --- Compute orthogonal loss if requested ---
        if pretrained_state_dict is None:
            raise ValueError("pretrained_state_dict must be provided when calculate_ortho_loss is True")

        ortho_loss = 0.0
        # self.model is the open_clip model (e.g. ViT); iterate over its parameters
        for name, p_finetuned in self.model.named_parameters():
            # Only compute loss for trainable 2-D weight matrices
            if p_finetuned.requires_grad and p_finetuned.dim() == 2:
                if name in pretrained_state_dict:
                    p_pretrained = pretrained_state_dict[name].to(p_finetuned.device)

                    delta_W = p_finetuned - p_pretrained

                    # Compute orthogonal loss (W^T * W - I), using the smaller Gram matrix
                    rows, cols = delta_W.shape
                    if rows < cols:
                        mat = delta_W @ delta_W.T
                        identity = torch.eye(rows, device=delta_W.device)
                    else:
                        mat = delta_W.T @ delta_W
                        identity = torch.eye(cols, device=delta_W.device)

                    ortho_loss += torch.norm(mat - identity, p='fro')

        return features, ortho_loss

    def __call__(self, inputs, calculate_ortho_loss=False, pretrained_state_dict=None):
        # Ensure __call__ forwards all arguments
        return self.forward(inputs, calculate_ortho_loss, pretrained_state_dict)

    def save(self, filename):
        """Save model weights."""
        # print(f"Saving AttentionOnlyFinetuneEncoder state_dict to {filename}")
        if os.path.dirname(filename):
            os.makedirs(os.path.dirname(filename), exist_ok=True)
        # Save only the state_dict; reconstruct the model on load
        torch.save(self.state_dict(), filename)

    @classmethod
    def load(cls, filename, args):
        """Load model from a state_dict."""
        # print(f"Loading AttentionOnlyFinetuneEncoder from {filename}")
        encoder = cls(args)  # Create a new instance
        state_dict = torch.load(filename, map_location='cpu')
        encoder.load_state_dict(state_dict)  # Load weights
        return encoder
src/datasets/__pycache__/cars.cpython-310.pyc ADDED (binary file, 5.94 kB)
src/datasets/__pycache__/cifar10.cpython-310.pyc ADDED (binary file, 2.15 kB)
src/datasets/__pycache__/cifar100.cpython-310.pyc ADDED (binary file, 924 Bytes)
src/datasets/__pycache__/common.cpython-310.pyc ADDED (binary file, 5.22 kB)
src/datasets/__pycache__/dtd.cpython-310.pyc ADDED (binary file, 1.39 kB)
src/datasets/__pycache__/emnist.cpython-310.pyc ADDED (binary file, 1.46 kB)
src/datasets/__pycache__/eurosat.cpython-310.pyc ADDED (binary file, 3.02 kB)
src/datasets/__pycache__/gtsrb.cpython-310.pyc ADDED (binary file, 7.68 kB)
src/datasets/__pycache__/imagenet.cpython-310.pyc ADDED (binary file, 15.9 kB)
src/datasets/__pycache__/kmnist.cpython-310.pyc ADDED (binary file, 952 Bytes)
src/datasets/__pycache__/mnist.cpython-310.pyc ADDED (binary file, 947 Bytes)
src/datasets/__pycache__/oxfordpets.cpython-310.pyc ADDED (binary file, 986 Bytes)
src/datasets/__pycache__/registry.cpython-310.pyc ADDED (binary file, 3.07 kB)
src/datasets/__pycache__/resisc45.cpython-310.pyc ADDED (binary file, 9.24 kB)
src/datasets/__pycache__/stl10.cpython-310.pyc ADDED (binary file, 976 Bytes)
src/datasets/__pycache__/sun397.cpython-310.pyc ADDED (binary file, 1.41 kB)
src/datasets/__pycache__/svhn.cpython-310.pyc ADDED (binary file, 1.04 kB)
src/datasets/__pycache__/templates.cpython-310.pyc ADDED (binary file, 18 kB)
src/datasets/cars.py ADDED
@@ -0,0 +1,155 @@
import os
import torch
import torchvision.datasets as datasets

import pathlib
from typing import Callable, Optional, Any, Tuple

from PIL import Image

from torchvision.datasets.utils import download_and_extract_archive, download_url, verify_str_arg
from torchvision.datasets.vision import VisionDataset


class PytorchStanfordCars(VisionDataset):
    """`Stanford Cars <https://ai.stanford.edu/~jkrause/cars/car_dataset.html>`_ Dataset

    The Cars dataset contains 16,185 images of 196 classes of cars. The data is
    split into 8,144 training images and 8,041 testing images, where each class
    has been split roughly in a 50-50 split

    .. note::

        This class needs `scipy <https://docs.scipy.org/doc/>`_ to load target files from `.mat` format.

    Args:
        root (string): Root directory of dataset
        split (string, optional): The dataset split, supports ``"train"`` (default) or ``"test"``.
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g, ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        download (bool, optional): If True, downloads the dataset from the internet and
            puts it in root directory. If dataset is already downloaded, it is not
            downloaded again."""

    def __init__(
        self,
        root: str,
        split: str = "train",
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        download: bool = False,
    ) -> None:

        try:
            import scipy.io as sio
        except ImportError:
            raise RuntimeError("Scipy is not found. This dataset needs to have scipy installed: pip install scipy")

        super().__init__(root, transform=transform, target_transform=target_transform)

        self._split = verify_str_arg(split, "split", ("train", "test"))
        self._base_folder = pathlib.Path(root) / "stanford_cars"
        devkit = self._base_folder / "devkit"

        if self._split == "train":
            self._annotations_mat_path = devkit / "cars_train_annos.mat"
            self._images_base_path = self._base_folder / "cars_train"
        else:
            self._annotations_mat_path = self._base_folder / "cars_test_annos_withlabels.mat"
            self._images_base_path = self._base_folder / "cars_test"

        if download:
            self.download()

        if not self._check_exists():
            raise RuntimeError("Dataset not found. You can use download=True to download it")

        self._samples = [
            (
                str(self._images_base_path / annotation["fname"]),
                annotation["class"] - 1,  # Original target mapping starts from 1, hence -1
            )
            for annotation in sio.loadmat(self._annotations_mat_path, squeeze_me=True)["annotations"]
        ]

        self.classes = sio.loadmat(str(devkit / "cars_meta.mat"), squeeze_me=True)["class_names"].tolist()
        self.class_to_idx = {cls: i for i, cls in enumerate(self.classes)}

    def __len__(self) -> int:
        return len(self._samples)

    def __getitem__(self, idx: int) -> Tuple[Any, Any]:
        """Returns pil_image and class_id for given index"""
        image_path, target = self._samples[idx]
        pil_image = Image.open(image_path).convert("RGB")

        if self.transform is not None:
            pil_image = self.transform(pil_image)
        if self.target_transform is not None:
            target = self.target_transform(target)
        return pil_image, target

    def download(self) -> None:
        if self._check_exists():
            return

        download_and_extract_archive(
            url="https://ai.stanford.edu/~jkrause/cars/car_devkit.tgz",
            download_root=str(self._base_folder),
            md5="c3b158d763b6e2245038c8ad08e45376",
        )
        if self._split == "train":
            download_and_extract_archive(
                url="https://ai.stanford.edu/~jkrause/car196/cars_train.tgz",
                download_root=str(self._base_folder),
                md5="065e5b463ae28d29e77c1b4b166cfe61",
            )
        else:
            download_and_extract_archive(
                url="https://ai.stanford.edu/~jkrause/car196/cars_test.tgz",
                download_root=str(self._base_folder),
                md5="4ce7ebf6a94d07f1952d94dd34c4d501",
            )
            download_url(
                url="https://ai.stanford.edu/~jkrause/car196/cars_test_annos_withlabels.mat",
                root=str(self._base_folder),
                md5="b0a2b23655a3edd16d84508592a98d10",
            )

    def _check_exists(self) -> bool:
        if not (self._base_folder / "devkit").is_dir():
            return False

        return self._annotations_mat_path.exists() and self._images_base_path.is_dir()


class Cars:
    def __init__(self,
                 preprocess,
                 location=os.path.expanduser('~/data'),
                 batch_size=32,
                 num_workers=16):
        # Data loading code

        self.train_dataset = PytorchStanfordCars(location, 'train', preprocess, download=False)
        self.train_loader = torch.utils.data.DataLoader(
            self.train_dataset,
            shuffle=True,
            batch_size=batch_size,
            num_workers=num_workers,
        )

        self.test_dataset = PytorchStanfordCars(location, 'test', preprocess, download=False)
        self.test_loader = torch.utils.data.DataLoader(
            self.test_dataset,
            batch_size=batch_size,
            num_workers=num_workers
        )
        idx_to_class = dict((v, k)
                            for k, v in self.train_dataset.class_to_idx.items())
        self.classnames = [idx_to_class[i].replace(
            '_', ' ') for i in range(len(idx_to_class))]
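The `Cars` wrapper derives its `classnames` list by inverting `class_to_idx` and replacing underscores with spaces; the same pattern recurs in the DTD and EuroSAT wrappers below. A standalone sketch of that inversion (the helper name `classnames_from_mapping` is hypothetical):

```python
def classnames_from_mapping(class_to_idx):
    """Invert a {classname: index} mapping into an index-ordered list,
    replacing underscores with spaces, as the dataset wrappers do."""
    idx_to_class = {v: k for k, v in class_to_idx.items()}
    return [idx_to_class[i].replace('_', ' ') for i in range(len(idx_to_class))]

print(classnames_from_mapping({'AM_General_Hummer': 0, 'Acura_RL': 1}))
# → ['AM General Hummer', 'Acura RL']
```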
src/datasets/cifar10.py ADDED
@@ -0,0 +1,56 @@
import os
import PIL
import torch
import numpy as np
import torchvision
from torchvision import transforms
from torchvision.datasets import CIFAR10 as PyTorchCIFAR10
from torchvision.datasets import VisionDataset

cifar_classnames = ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']

class CIFAR10:
    def __init__(self, preprocess,
                 location=os.path.expanduser('~/data'),
                 batch_size=128,
                 num_workers=16):

        self.train_dataset = PyTorchCIFAR10(
            root=location, download=True, train=True, transform=preprocess
        )

        self.train_loader = torch.utils.data.DataLoader(
            self.train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers
        )

        self.test_dataset = PyTorchCIFAR10(
            root=location, download=True, train=False, transform=preprocess
        )

        self.test_loader = torch.utils.data.DataLoader(
            self.test_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers
        )

        self.classnames = self.test_dataset.classes

def convert(x):
    if isinstance(x, np.ndarray):
        return torchvision.transforms.functional.to_pil_image(x)
    return x

class BasicVisionDataset(VisionDataset):
    def __init__(self, images, targets, transform=None, target_transform=None):
        if transform is not None:
            transform.transforms.insert(0, convert)
        super(BasicVisionDataset, self).__init__(root=None, transform=transform, target_transform=target_transform)
        assert len(images) == len(targets)

        self.images = images
        self.targets = targets

    def __getitem__(self, index):
        return self.transform(self.images[index]), self.targets[index]

    def __len__(self):
        return len(self.targets)
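`BasicVisionDataset` prepends a `convert` step to an existing transform pipeline so that numpy arrays are turned into PIL images before the remaining transforms run. A dependency-free sketch of that prepend-to-pipeline trick (the `Compose` stand-in and the tuple-to-list conversion are hypothetical substitutes for torchvision's `Compose` and the numpy-to-PIL `convert`):

```python
class Compose:
    """Minimal stand-in for torchvision.transforms.Compose."""
    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, x):
        for t in self.transforms:
            x = t(x)
        return x

def convert(x):
    # Stand-in for the numpy -> PIL conversion: here, tuple -> list
    return list(x) if isinstance(x, tuple) else x

pipeline = Compose([lambda xs: [v * 2 for v in xs]])
pipeline.transforms.insert(0, convert)   # same trick as BasicVisionDataset
print(pipeline((1, 2, 3)))  # → [2, 4, 6]
```

Mutating `transform.transforms` in place, as the dataset does, works because `Compose` stores its steps in a plain list.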
src/datasets/cifar100.py ADDED
@@ -0,0 +1,30 @@
import os
import torch
from torchvision.datasets import CIFAR100 as PyTorchCIFAR100

class CIFAR100:
    def __init__(self,
                 preprocess,
                 location=os.path.expanduser('~/data'),
                 batch_size=128,
                 num_workers=16):

        self.train_dataset = PyTorchCIFAR100(
            root=location, download=True, train=True, transform=preprocess
        )

        self.train_loader = torch.utils.data.DataLoader(
            self.train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers  # shuffle training data
        )

        self.test_dataset = PyTorchCIFAR100(
            root=location, download=True, train=False, transform=preprocess
        )

        self.test_loader = torch.utils.data.DataLoader(
            self.test_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers
        )

        self.classnames = self.test_dataset.classes
src/datasets/common.py ADDED
@@ -0,0 +1,139 @@
import os
import torch
import json
import glob
import collections
import random

import numpy as np

from tqdm import tqdm

import torchvision.datasets as datasets
from torch.utils.data import Dataset, DataLoader, Sampler


class SubsetSampler(Sampler):
    def __init__(self, indices):
        self.indices = indices

    def __iter__(self):
        return (i for i in self.indices)

    def __len__(self):
        return len(self.indices)

class ImageFolderWithPaths(datasets.ImageFolder):
    def __init__(self, path, transform, flip_label_prob=0.0):
        super().__init__(path, transform)
        self.flip_label_prob = flip_label_prob
        if self.flip_label_prob > 0:
            print(f'Flipping labels with probability {self.flip_label_prob}')
            num_classes = len(self.classes)
            for i in range(len(self.samples)):
                if random.random() < self.flip_label_prob:
                    new_label = random.randint(0, num_classes-1)
                    self.samples[i] = (
                        self.samples[i][0],
                        new_label
                    )

    def __getitem__(self, index):
        image, label = super(ImageFolderWithPaths, self).__getitem__(index)
        return {
            'images': image,
            'labels': label,
            'image_paths': self.samples[index][0]
        }


def maybe_dictionarize(batch):
    if isinstance(batch, dict):
        return batch

    if len(batch) == 2:
        batch = {'images': batch[0], 'labels': batch[1]}
    elif len(batch) == 3:
        batch = {'images': batch[0], 'labels': batch[1], 'metadata': batch[2]}
    else:
        raise ValueError(f'Unexpected number of elements: {len(batch)}')

    return batch


def get_features_helper(image_encoder, dataloader, device):
    all_data = collections.defaultdict(list)

    image_encoder = image_encoder.to(device)
    image_encoder = torch.nn.DataParallel(image_encoder, device_ids=[x for x in range(torch.cuda.device_count())])
    image_encoder.eval()

    with torch.no_grad():
        for batch in tqdm(dataloader):
            batch = maybe_dictionarize(batch)
            features = image_encoder(batch['images'].cuda())

            all_data['features'].append(features.cpu())

            for key, val in batch.items():
                if key == 'images':
                    continue
                if hasattr(val, 'cpu'):
                    val = val.cpu()
                    all_data[key].append(val)
                else:
                    all_data[key].extend(val)

    for key, val in all_data.items():
        if torch.is_tensor(val[0]):
            all_data[key] = torch.cat(val).numpy()

    return all_data


def get_features(is_train, image_encoder, dataset, device):
    split = 'train' if is_train else 'val'
    dname = type(dataset).__name__
    cache_dir = None  # avoid a NameError below when no cache directory is set
    if image_encoder.cache_dir is not None:
        cache_dir = f'{image_encoder.cache_dir}/{dname}/{split}'
        cached_files = glob.glob(f'{cache_dir}/*')
    if image_encoder.cache_dir is not None and len(cached_files) > 0:
        print(f'Getting features from {cache_dir}')
        data = {}
        for cached_file in cached_files:
            name = os.path.splitext(os.path.basename(cached_file))[0]
            data[name] = torch.load(cached_file)
    else:
        print(f'Did not find cached features at {cache_dir}. Building from scratch.')
        loader = dataset.train_loader if is_train else dataset.test_loader
        data = get_features_helper(image_encoder, loader, device)
        if image_encoder.cache_dir is None:
            print('Not caching because no cache directory was passed.')
        else:
            os.makedirs(cache_dir, exist_ok=True)
            print(f'Caching data at {cache_dir}')
            for name, val in data.items():
                torch.save(val, f'{cache_dir}/{name}.pt')
    return data


class FeatureDataset(Dataset):
    def __init__(self, is_train, image_encoder, dataset, device):
        self.data = get_features(is_train, image_encoder, dataset, device)

    def __len__(self):
        return len(self.data['features'])

    def __getitem__(self, idx):
        data = {k: v[idx] for k, v in self.data.items()}
        data['features'] = torch.from_numpy(data['features']).float()
        return data


def get_dataloader(dataset, is_train, args, image_encoder=None):
    if image_encoder is not None:
        feature_dataset = FeatureDataset(is_train, image_encoder, dataset, args.device)
        dataloader = DataLoader(feature_dataset, batch_size=args.batch_size, shuffle=is_train)
    else:
        dataloader = dataset.train_loader if is_train else dataset.test_loader
    return dataloader
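`maybe_dictionarize` is the small normalization layer the rest of the codebase relies on: loaders may yield dicts, `(images, labels)` pairs, or `(images, labels, metadata)` triples, and everything downstream sees a dict. Since it is pure Python, it can be exercised directly (reproduced here so the snippet is self-contained):

```python
def maybe_dictionarize(batch):
    # Same logic as src/datasets/common.py: pass dicts through,
    # wrap 2- and 3-element batches into named fields
    if isinstance(batch, dict):
        return batch
    if len(batch) == 2:
        batch = {'images': batch[0], 'labels': batch[1]}
    elif len(batch) == 3:
        batch = {'images': batch[0], 'labels': batch[1], 'metadata': batch[2]}
    else:
        raise ValueError(f'Unexpected number of elements: {len(batch)}')
    return batch

print(maybe_dictionarize(('img', 7)))  # → {'images': 'img', 'labels': 7}
print(maybe_dictionarize({'images': 'img'}))  # → {'images': 'img'} (passed through)
```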
src/datasets/dtd.py ADDED
@@ -0,0 +1,34 @@
import os
import torch
import torchvision.datasets as datasets


class DTD:
    def __init__(self,
                 preprocess,
                 location=os.path.expanduser('~/data'),
                 batch_size=32,
                 num_workers=16):
        # Data loading code
        traindir = os.path.join(location, 'dtd', 'train')
        valdir = os.path.join(location, 'dtd', 'val')

        self.train_dataset = datasets.ImageFolder(
            traindir, transform=preprocess)
        self.train_loader = torch.utils.data.DataLoader(
            self.train_dataset,
            shuffle=True,
            batch_size=batch_size,
            num_workers=num_workers,
        )

        self.test_dataset = datasets.ImageFolder(valdir, transform=preprocess)
        self.test_loader = torch.utils.data.DataLoader(
            self.test_dataset,
            batch_size=batch_size,
            num_workers=num_workers
        )
        idx_to_class = dict((v, k)
                            for k, v in self.train_dataset.class_to_idx.items())
        self.classnames = [idx_to_class[i].replace(
            '_', ' ') for i in range(len(idx_to_class))]
src/datasets/emnist.py ADDED
@@ -0,0 +1,74 @@
import os

import torch

import torchvision
import torchvision.datasets as datasets


def rotate_img(img):
    return torchvision.transforms.functional.rotate(img, -90)


def flip_img(img):
    return torchvision.transforms.functional.hflip(img)


def emnist_preprocess():
    return torchvision.transforms.Compose(
        [
            rotate_img,
            flip_img,
        ]
    )


class EMNIST:
    def __init__(
        self,
        preprocess,
        location,
        batch_size=128,
        num_workers=8,
    ):
        preprocess1 = emnist_preprocess()
        preprocess = torchvision.transforms.Compose(
            [
                preprocess,
                preprocess1,
            ]
        )
        # if not os.path.exists(location):
        #     os.makedirs(location, exist_ok=True)

        self.train_dataset = datasets.EMNIST(
            root=location,
            download=True,
            split="digits",
            transform=preprocess,
            train=True,
        )

        self.train_loader = torch.utils.data.DataLoader(
            self.train_dataset,
            batch_size=batch_size,
            shuffle=True,
            num_workers=num_workers,
        )

        self.test_dataset = datasets.EMNIST(
            root=location,
            download=True,
            split="digits",
            transform=preprocess,
            train=False,
        )

        self.test_loader = torch.utils.data.DataLoader(
            self.test_dataset,
            batch_size=32,
            shuffle=False,
            num_workers=num_workers,
        )

        self.classnames = self.train_dataset.classes
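The EMNIST wrapper composes a −90° rotation (clockwise, in torchvision's angle convention) with a horizontal flip; together these amount to a transpose, which undoes the transposed orientation EMNIST images are stored in relative to MNIST. A dependency-free check of that identity on a plain 2-D list (helper names are hypothetical):

```python
def rotate_cw_90(m):
    """Rotate a square matrix 90 degrees clockwise, as rotate(img, -90) does."""
    n = len(m)
    return [[m[n - 1 - j][i] for j in range(n)] for i in range(n)]

def hflip(m):
    """Mirror each row left-to-right."""
    return [row[::-1] for row in m]

def transpose(m):
    return [list(col) for col in zip(*m)]

img = [[1, 2, 3],
       [4, 5, 6],
       [7, 8, 9]]
print(hflip(rotate_cw_90(img)) == transpose(img))  # → True
```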
src/datasets/eurosat.py ADDED
@@ -0,0 +1,75 @@
import os
import torch
import torchvision.datasets as datasets
import re

def pretify_classname(classname):
    l = re.findall(r'[A-Z](?:[a-z]+|[A-Z]*(?=[A-Z]|$))', classname)
    l = [i.lower() for i in l]
    out = ' '.join(l)
    if out.endswith('al'):
        return out + ' area'
    return out

class EuroSATBase:
    def __init__(self,
                 preprocess,
                 test_split,
                 location='~/datasets',
                 batch_size=32,
                 num_workers=16):
        # Data loading code
        traindir = os.path.join(location, 'EuroSAT_splits', 'train')
        testdir = os.path.join(location, 'EuroSAT_splits', test_split)

        self.train_dataset = datasets.ImageFolder(traindir, transform=preprocess)
        self.train_loader = torch.utils.data.DataLoader(
            self.train_dataset,
            shuffle=True,
            batch_size=batch_size,
            num_workers=num_workers,
        )

        self.test_dataset = datasets.ImageFolder(testdir, transform=preprocess)
        self.test_loader = torch.utils.data.DataLoader(
            self.test_dataset,
            batch_size=batch_size,
            num_workers=num_workers
        )
        idx_to_class = dict((v, k)
                            for k, v in self.train_dataset.class_to_idx.items())
        self.classnames = [idx_to_class[i].replace('_', ' ') for i in range(len(idx_to_class))]
        self.classnames = [pretify_classname(c) for c in self.classnames]
        ours_to_open_ai = {
            'annual crop': 'annual crop land',
            'forest': 'forest',
            'herbaceous vegetation': 'brushland or shrubland',
            'highway': 'highway or road',
            'industrial area': 'industrial buildings or commercial buildings',
            'pasture': 'pasture land',
            'permanent crop': 'permanent crop land',
            'residential area': 'residential buildings or homes or apartments',
            'river': 'river',
            'sea lake': 'lake or sea',
        }
        for i in range(len(self.classnames)):
            self.classnames[i] = ours_to_open_ai[self.classnames[i]]


class EuroSAT(EuroSATBase):
    def __init__(self,
                 preprocess,
                 location='~/datasets',
                 batch_size=32,
                 num_workers=16):
        super().__init__(preprocess, 'test', location, batch_size, num_workers)


class EuroSATVal(EuroSATBase):
    def __init__(self,
                 preprocess,
                 location='~/datasets',
                 batch_size=32,
                 num_workers=16):
        super().__init__(preprocess, 'val', location, batch_size, num_workers)
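`pretify_classname` splits CamelCase EuroSAT folder names into lowercase words (the regex also keeps acronym runs together) and appends "area" to adjectives ending in "al", which is what turns the raw folder names into the keys of `ours_to_open_ai`. Since it is a pure function, its behavior is easy to see directly (reproduced so the snippet is self-contained):

```python
import re

def pretify_classname(classname):
    # Same logic as src/datasets/eurosat.py: split CamelCase, lowercase,
    # and append ' area' to adjectives ending in 'al'
    l = re.findall(r'[A-Z](?:[a-z]+|[A-Z]*(?=[A-Z]|$))', classname)
    l = [i.lower() for i in l]
    out = ' '.join(l)
    if out.endswith('al'):
        return out + ' area'
    return out

print(pretify_classname('AnnualCrop'))   # → 'annual crop'
print(pretify_classname('Industrial'))   # → 'industrial area'
```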
src/datasets/gtsrb.py
ADDED
|
@@ -0,0 +1,205 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import csv
|
| 2 |
+
import os
|
| 3 |
+
import pathlib
|
| 4 |
+
from typing import Any, Callable, Dict, List, Optional, Tuple
|
| 5 |
+
|
| 6 |
+
import numpy as np
|
| 7 |
+
import PIL
|
| 8 |
+
import torch
|
| 9 |
+
from torchvision.datasets.folder import make_dataset
|
| 10 |
+
from torchvision.datasets.utils import (download_and_extract_archive,
|
| 11 |
+
verify_str_arg)
|
| 12 |
+
from torchvision.datasets.vision import VisionDataset
|
| 13 |
+
|
| 14 |
+
def find_classes(directory: str) -> Tuple[List[str], Dict[str, int]]:
|
| 15 |
+
"""Finds the class folders in a dataset.
|
| 16 |
+
|
| 17 |
+
See :class:`DatasetFolder` for details.
|
| 18 |
+
"""
|
| 19 |
+
classes = sorted(entry.name for entry in os.scandir(directory) if entry.is_dir())
|
| 20 |
+
if not classes:
|
| 21 |
+
raise FileNotFoundError(f"Couldn't find any class folder in {directory}.")
|
| 22 |
+
|
| 23 |
+
class_to_idx = {cls_name: i for i, cls_name in enumerate(classes)}
|
| 24 |
+
return classes, class_to_idx
|
| 25 |
+
|
| 26 |
+
class PyTorchGTSRB(VisionDataset):
    """`German Traffic Sign Recognition Benchmark (GTSRB) <https://benchmark.ini.rub.de/>`_ Dataset.

    Modified from https://pytorch.org/vision/main/_modules/torchvision/datasets/gtsrb.html#GTSRB.

    Args:
        root (string): Root directory of the dataset.
        split (string, optional): The dataset split, supports ``"train"`` (default), or ``"test"``.
        transform (callable, optional): A function/transform that takes in a PIL image and returns a transformed
            version. E.g., ``transforms.RandomCrop``.
        target_transform (callable, optional): A function/transform that takes in the target and transforms it.
        download (bool, optional): If True, downloads the dataset from the internet and
            puts it in the root directory. If the dataset is already downloaded, it is not
            downloaded again.
    """

    def __init__(
        self,
        root: str,
        split: str = "train",
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        download: bool = False,
    ) -> None:
        super().__init__(root, transform=transform, target_transform=target_transform)

        self._split = verify_str_arg(split, "split", ("train", "test"))
        self._base_folder = pathlib.Path(root) / "gtsrb"
        self._target_folder = (
            self._base_folder / "GTSRB" / ("Training" if self._split == "train" else "Final_Test/Images")
        )

        if download:
            self.download()

        if not self._check_exists():
            raise RuntimeError("Dataset not found. You can use download=True to download it.")

        if self._split == "train":
            _, class_to_idx = find_classes(str(self._target_folder))
            samples = make_dataset(str(self._target_folder), extensions=(".ppm",), class_to_idx=class_to_idx)
        else:
            with open(self._base_folder / "GT-final_test.csv") as csv_file:
                samples = [
                    (str(self._target_folder / row["Filename"]), int(row["ClassId"]))
                    for row in csv.DictReader(csv_file, delimiter=";", skipinitialspace=True)
                ]

        self._samples = samples
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self) -> int:
        return len(self._samples)

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        path, target = self._samples[index]
        sample = PIL.Image.open(path).convert("RGB")

        if self.transform is not None:
            sample = self.transform(sample)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return sample, target

    def _check_exists(self) -> bool:
        return self._target_folder.is_dir()

    def download(self) -> None:
        if self._check_exists():
            return

        base_url = "https://sid.erda.dk/public/archives/daaeac0d7ce1152aea9b61d9f1e19370/"

        if self._split == "train":
            download_and_extract_archive(
                f"{base_url}GTSRB-Training_fixed.zip",
                download_root=str(self._base_folder),
                md5="513f3c79a4c5141765e10e952eaa2478",
            )
        else:
            download_and_extract_archive(
                f"{base_url}GTSRB_Final_Test_Images.zip",
                download_root=str(self._base_folder),
                md5="c7e4e6327067d32654124b0fe9e82185",
            )
            download_and_extract_archive(
                f"{base_url}GTSRB_Final_Test_GT.zip",
                download_root=str(self._base_folder),
                md5="fe31e9c9270bbcd7b84b7f21a9d9d9e5",
            )


class GTSRB:
    def __init__(self,
                 preprocess,
                 location=os.path.expanduser('~/data'),
                 batch_size=128,
                 num_workers=16):

        # to fit with repo conventions for location
        self.train_dataset = PyTorchGTSRB(
            root=location,
            download=True,
            split='train',
            transform=preprocess
        )

        self.train_loader = torch.utils.data.DataLoader(
            self.train_dataset,
            batch_size=batch_size,
            shuffle=True,
            num_workers=num_workers
        )

        self.test_dataset = PyTorchGTSRB(
            root=location,
            download=True,
            split='test',
            transform=preprocess
        )

        self.test_loader = torch.utils.data.DataLoader(
            self.test_dataset,
            batch_size=batch_size,
            shuffle=False,
            num_workers=num_workers
        )

        # from https://github.com/openai/CLIP/blob/e184f608c5d5e58165682f7c332c3a8b4c1545f2/data/prompts.md
        self.classnames = [
            'red and white circle 20 kph speed limit',
            'red and white circle 30 kph speed limit',
            'red and white circle 50 kph speed limit',
            'red and white circle 60 kph speed limit',
            'red and white circle 70 kph speed limit',
            'red and white circle 80 kph speed limit',
            'end / de-restriction of 80 kph speed limit',
            'red and white circle 100 kph speed limit',
            'red and white circle 120 kph speed limit',
            'red and white circle red car and black car no passing',
            'red and white circle red truck and black car no passing',
            'red and white triangle road intersection warning',
            'white and yellow diamond priority road',
            'red and white upside down triangle yield right-of-way',
            'stop',
            'empty red and white circle',
            'red and white circle no truck entry',
            'red circle with white horizonal stripe no entry',
            'red and white triangle with exclamation mark warning',
            'red and white triangle with black left curve approaching warning',
            'red and white triangle with black right curve approaching warning',
            'red and white triangle with black double curve approaching warning',
            'red and white triangle rough / bumpy road warning',
            'red and white triangle car skidding / slipping warning',
            'red and white triangle with merging / narrow lanes warning',
            'red and white triangle with person digging / construction / road work warning',
            'red and white triangle with traffic light approaching warning',
            'red and white triangle with person walking warning',
            'red and white triangle with child and person walking warning',
            'red and white triangle with bicyle warning',
            'red and white triangle with snowflake / ice warning',
            'red and white triangle with deer warning',
            'white circle with gray strike bar no speed limit',
            'blue circle with white right turn arrow mandatory',
            'blue circle with white left turn arrow mandatory',
            'blue circle with white forward arrow mandatory',
            'blue circle with white forward or right turn arrow mandatory',
            'blue circle with white forward or left turn arrow mandatory',
            'blue circle with white keep right arrow mandatory',
            'blue circle with white keep left arrow mandatory',
            'blue circle with white arrows indicating a traffic circle',
            'white circle with gray strike bar indicating no passing for cars has ended',
            'white circle with gray strike bar indicating no passing for trucks has ended',
        ]
src/datasets/imagenet.py
ADDED
@@ -0,0 +1,253 @@
import os
import torch

from .common import ImageFolderWithPaths, SubsetSampler
import numpy as np


imagenet_classnames = [
    "tench", "goldfish", "great white shark", "tiger shark", "hammerhead shark", "electric ray",
    "stingray", "rooster", "hen", "ostrich", "brambling", "goldfinch", "house finch", "junco",
    "indigo bunting", "American robin", "bulbul", "jay", "magpie", "chickadee", "American dipper",
    "kite (bird of prey)", "bald eagle", "vulture", "great grey owl", "fire salamander",
    "smooth newt", "newt", "spotted salamander", "axolotl", "American bullfrog", "tree frog",
    "tailed frog", "loggerhead sea turtle", "leatherback sea turtle", "mud turtle", "terrapin",
    "box turtle", "banded gecko", "green iguana", "Carolina anole",
    "desert grassland whiptail lizard", "agama", "frilled-necked lizard", "alligator lizard",
    "Gila monster", "European green lizard", "chameleon", "Komodo dragon", "Nile crocodile",
    "American alligator", "triceratops", "worm snake", "ring-necked snake",
    "eastern hog-nosed snake", "smooth green snake", "kingsnake", "garter snake", "water snake",
    "vine snake", "night snake", "boa constrictor", "African rock python", "Indian cobra",
    "green mamba", "sea snake", "Saharan horned viper", "eastern diamondback rattlesnake",
    "sidewinder rattlesnake", "trilobite", "harvestman", "scorpion", "yellow garden spider",
    "barn spider", "European garden spider", "southern black widow", "tarantula", "wolf spider",
    "tick", "centipede", "black grouse", "ptarmigan", "ruffed grouse", "prairie grouse", "peafowl",
    "quail", "partridge", "african grey parrot", "macaw", "sulphur-crested cockatoo", "lorikeet",
    "coucal", "bee eater", "hornbill", "hummingbird", "jacamar", "toucan", "duck",
    "red-breasted merganser", "goose", "black swan", "tusker", "echidna", "platypus", "wallaby",
    "koala", "wombat", "jellyfish", "sea anemone", "brain coral", "flatworm", "nematode", "conch",
    "snail", "slug", "sea slug", "chiton", "chambered nautilus", "Dungeness crab", "rock crab",
    "fiddler crab", "red king crab", "American lobster", "spiny lobster", "crayfish", "hermit crab",
    "isopod", "white stork", "black stork", "spoonbill", "flamingo", "little blue heron",
    "great egret", "bittern bird", "crane bird", "limpkin", "common gallinule", "American coot",
    "bustard", "ruddy turnstone", "dunlin", "common redshank", "dowitcher", "oystercatcher",
    "pelican", "king penguin", "albatross", "grey whale", "killer whale", "dugong", "sea lion",
    "Chihuahua", "Japanese Chin", "Maltese", "Pekingese", "Shih Tzu", "King Charles Spaniel",
    "Papillon", "toy terrier", "Rhodesian Ridgeback", "Afghan Hound", "Basset Hound", "Beagle",
    "Bloodhound", "Bluetick Coonhound", "Black and Tan Coonhound", "Treeing Walker Coonhound",
    "English foxhound", "Redbone Coonhound", "borzoi", "Irish Wolfhound", "Italian Greyhound",
    "Whippet", "Ibizan Hound", "Norwegian Elkhound", "Otterhound", "Saluki", "Scottish Deerhound",
    "Weimaraner", "Staffordshire Bull Terrier", "American Staffordshire Terrier",
    "Bedlington Terrier", "Border Terrier", "Kerry Blue Terrier", "Irish Terrier",
    "Norfolk Terrier", "Norwich Terrier", "Yorkshire Terrier", "Wire Fox Terrier",
    "Lakeland Terrier", "Sealyham Terrier", "Airedale Terrier", "Cairn Terrier",
    "Australian Terrier", "Dandie Dinmont Terrier", "Boston Terrier", "Miniature Schnauzer",
    "Giant Schnauzer", "Standard Schnauzer", "Scottish Terrier", "Tibetan Terrier",
    "Australian Silky Terrier", "Soft-coated Wheaten Terrier", "West Highland White Terrier",
    "Lhasa Apso", "Flat-Coated Retriever", "Curly-coated Retriever", "Golden Retriever",
    "Labrador Retriever", "Chesapeake Bay Retriever", "German Shorthaired Pointer", "Vizsla",
    "English Setter", "Irish Setter", "Gordon Setter", "Brittany dog", "Clumber Spaniel",
    "English Springer Spaniel", "Welsh Springer Spaniel", "Cocker Spaniel", "Sussex Spaniel",
    "Irish Water Spaniel", "Kuvasz", "Schipperke", "Groenendael dog", "Malinois", "Briard",
    "Australian Kelpie", "Komondor", "Old English Sheepdog", "Shetland Sheepdog", "collie",
    "Border Collie", "Bouvier des Flandres dog", "Rottweiler", "German Shepherd Dog", "Dobermann",
    "Miniature Pinscher", "Greater Swiss Mountain Dog", "Bernese Mountain Dog",
    "Appenzeller Sennenhund", "Entlebucher Sennenhund", "Boxer", "Bullmastiff", "Tibetan Mastiff",
    "French Bulldog", "Great Dane", "St. Bernard", "husky", "Alaskan Malamute", "Siberian Husky",
    "Dalmatian", "Affenpinscher", "Basenji", "pug", "Leonberger", "Newfoundland dog",
    "Great Pyrenees dog", "Samoyed", "Pomeranian", "Chow Chow", "Keeshond", "brussels griffon",
    "Pembroke Welsh Corgi", "Cardigan Welsh Corgi", "Toy Poodle", "Miniature Poodle",
    "Standard Poodle", "Mexican hairless dog (xoloitzcuintli)", "grey wolf", "Alaskan tundra wolf",
    "red wolf or maned wolf", "coyote", "dingo", "dhole", "African wild dog", "hyena", "red fox",
    "kit fox", "Arctic fox", "grey fox", "tabby cat", "tiger cat", "Persian cat", "Siamese cat",
    "Egyptian Mau", "cougar", "lynx", "leopard", "snow leopard", "jaguar", "lion", "tiger",
    "cheetah", "brown bear", "American black bear", "polar bear", "sloth bear", "mongoose",
    "meerkat", "tiger beetle", "ladybug", "ground beetle", "longhorn beetle", "leaf beetle",
    "dung beetle", "rhinoceros beetle", "weevil", "fly", "bee", "ant", "grasshopper",
    "cricket insect", "stick insect", "cockroach", "praying mantis", "cicada", "leafhopper",
    "lacewing", "dragonfly", "damselfly", "red admiral butterfly", "ringlet butterfly",
    "monarch butterfly", "small white butterfly", "sulphur butterfly", "gossamer-winged butterfly",
    "starfish", "sea urchin", "sea cucumber", "cottontail rabbit", "hare", "Angora rabbit",
    "hamster", "porcupine", "fox squirrel", "marmot", "beaver", "guinea pig", "common sorrel horse",
    "zebra", "pig", "wild boar", "warthog", "hippopotamus", "ox", "water buffalo", "bison",
    "ram (adult male sheep)", "bighorn sheep", "Alpine ibex", "hartebeest", "impala (antelope)",
    "gazelle", "arabian camel", "llama", "weasel", "mink", "European polecat",
    "black-footed ferret", "otter", "skunk", "badger", "armadillo", "three-toed sloth", "orangutan",
    "gorilla", "chimpanzee", "gibbon", "siamang", "guenon", "patas monkey", "baboon", "macaque",
    "langur", "black-and-white colobus", "proboscis monkey", "marmoset", "white-headed capuchin",
    "howler monkey", "titi monkey", "Geoffroy's spider monkey", "common squirrel monkey",
    "ring-tailed lemur", "indri", "Asian elephant", "African bush elephant", "red panda",
    "giant panda", "snoek fish", "eel", "silver salmon", "rock beauty fish", "clownfish",
    "sturgeon", "gar fish", "lionfish", "pufferfish", "abacus", "abaya", "academic gown",
    "accordion", "acoustic guitar", "aircraft carrier", "airliner", "airship", "altar", "ambulance",
    "amphibious vehicle", "analog clock", "apiary", "apron", "trash can", "assault rifle",
    "backpack", "bakery", "balance beam", "balloon", "ballpoint pen", "Band-Aid", "banjo",
    "baluster / handrail", "barbell", "barber chair", "barbershop", "barn", "barometer", "barrel",
    "wheelbarrow", "baseball", "basketball", "bassinet", "bassoon", "swimming cap", "bath towel",
    "bathtub", "station wagon", "lighthouse", "beaker", "military hat (bearskin or shako)",
    "beer bottle", "beer glass", "bell tower", "baby bib", "tandem bicycle", "bikini",
    "ring binder", "binoculars", "birdhouse", "boathouse", "bobsleigh", "bolo tie", "poke bonnet",
    "bookcase", "bookstore", "bottle cap", "hunting bow", "bow tie", "brass memorial plaque", "bra",
    "breakwater", "breastplate", "broom", "bucket", "buckle", "bulletproof vest",
    "high-speed train", "butcher shop", "taxicab", "cauldron", "candle", "cannon", "canoe",
    "can opener", "cardigan", "car mirror", "carousel", "tool kit", "cardboard box / carton",
    "car wheel", "automated teller machine", "cassette", "cassette player", "castle", "catamaran",
    "CD player", "cello", "mobile phone", "chain", "chain-link fence", "chain mail", "chainsaw",
    "storage chest", "chiffonier", "bell or wind chime", "china cabinet", "Christmas stocking",
    "church", "movie theater", "cleaver", "cliff dwelling", "cloak", "clogs", "cocktail shaker",
    "coffee mug", "coffeemaker", "spiral or coil", "combination lock", "computer keyboard",
    "candy store", "container ship", "convertible", "corkscrew", "cornet", "cowboy boot",
    "cowboy hat", "cradle", "construction crane", "crash helmet", "crate", "infant bed",
    "Crock Pot", "croquet ball", "crutch", "cuirass", "dam", "desk", "desktop computer",
    "rotary dial telephone", "diaper", "digital clock", "digital watch", "dining table",
    "dishcloth", "dishwasher", "disc brake", "dock", "dog sled", "dome", "doormat", "drilling rig",
    "drum", "drumstick", "dumbbell", "Dutch oven", "electric fan", "electric guitar",
    "electric locomotive", "entertainment center", "envelope", "espresso machine", "face powder",
    "feather boa", "filing cabinet", "fireboat", "fire truck", "fire screen", "flagpole", "flute",
    "folding chair", "football helmet", "forklift", "fountain", "fountain pen", "four-poster bed",
    "freight car", "French horn", "frying pan", "fur coat", "garbage truck",
    "gas mask or respirator", "gas pump", "goblet", "go-kart", "golf ball", "golf cart", "gondola",
    "gong", "gown", "grand piano", "greenhouse", "radiator grille", "grocery store", "guillotine",
    "hair clip", "hair spray", "half-track", "hammer", "hamper", "hair dryer", "hand-held computer",
    "handkerchief", "hard disk drive", "harmonica", "harp", "combine harvester", "hatchet",
    "holster", "home theater", "honeycomb", "hook", "hoop skirt", "gymnastic horizontal bar",
    "horse-drawn vehicle", "hourglass", "iPod", "clothes iron", "carved pumpkin", "jeans", "jeep",
    "T-shirt", "jigsaw puzzle", "rickshaw", "joystick", "kimono", "knee pad", "knot", "lab coat",
    "ladle", "lampshade", "laptop computer", "lawn mower", "lens cap", "letter opener", "library",
    "lifeboat", "lighter", "limousine", "ocean liner", "lipstick", "slip-on shoe", "lotion",
    "music speaker", "loupe magnifying glass", "sawmill", "magnetic compass", "messenger bag",
    "mailbox", "tights", "one-piece bathing suit", "manhole cover", "maraca", "marimba", "mask",
    "matchstick", "maypole", "maze", "measuring cup", "medicine cabinet", "megalith", "microphone",
    "microwave oven", "military uniform", "milk can", "minibus", "miniskirt", "minivan", "missile",
    "mitten", "mixing bowl", "mobile home", "ford model t", "modem", "monastery", "monitor",
    "moped", "mortar and pestle", "graduation cap", "mosque", "mosquito net", "vespa",
    "mountain bike", "tent", "computer mouse", "mousetrap", "moving van", "muzzle", "metal nail",
    "neck brace", "necklace", "baby pacifier", "notebook computer", "obelisk", "oboe", "ocarina",
    "odometer", "oil filter", "pipe organ", "oscilloscope", "overskirt", "bullock cart",
    "oxygen mask", "product packet / packaging", "paddle", "paddle wheel", "padlock", "paintbrush",
    "pajamas", "palace", "pan flute", "paper towel", "parachute", "parallel bars", "park bench",
    "parking meter", "railroad car", "patio", "payphone", "pedestal", "pencil case",
    "pencil sharpener", "perfume", "Petri dish", "photocopier", "plectrum", "Pickelhaube",
    "picket fence", "pickup truck", "pier", "piggy bank", "pill bottle", "pillow", "ping-pong ball",
    "pinwheel", "pirate ship", "drink pitcher", "block plane", "planetarium", "plastic bag",
    "plate rack", "farm plow", "plunger", "Polaroid camera", "pole", "police van", "poncho",
    "pool table", "soda bottle", "plant pot", "potter's wheel", "power drill", "prayer rug",
    "printer", "prison", "missile", "projector", "hockey puck", "punching bag", "purse", "quill",
    "quilt", "race car", "racket", "radiator", "radio", "radio telescope", "rain barrel",
    "recreational vehicle", "fishing casting reel", "reflex camera", "refrigerator",
    "remote control", "restaurant", "revolver", "rifle", "rocking chair", "rotisserie", "eraser",
    "rugby ball", "ruler measuring stick", "sneaker", "safe", "safety pin", "salt shaker", "sandal",
    "sarong", "saxophone", "scabbard", "weighing scale", "school bus", "schooner", "scoreboard",
    "CRT monitor", "screw", "screwdriver", "seat belt", "sewing machine", "shield", "shoe store",
    "shoji screen / room divider", "shopping basket", "shopping cart", "shovel", "shower cap",
    "shower curtain", "ski", "balaclava ski mask", "sleeping bag", "slide rule", "sliding door",
    "slot machine", "snorkel", "snowmobile", "snowplow", "soap dispenser", "soccer ball", "sock",
    "solar thermal collector", "sombrero", "soup bowl", "keyboard space bar", "space heater",
    "space shuttle", "spatula", "motorboat", "spider web", "spindle", "sports car", "spotlight",
    "stage", "steam locomotive", "through arch bridge", "steel drum", "stethoscope", "scarf",
    "stone wall", "stopwatch", "stove", "strainer", "tram", "stretcher", "couch", "stupa",
    "submarine", "suit", "sundial", "sunglasses", "sunglasses", "sunscreen", "suspension bridge",
    "mop", "sweatshirt", "swim trunks / shorts", "swing", "electrical switch", "syringe",
    "table lamp", "tank", "tape player", "teapot", "teddy bear", "television", "tennis ball",
    "thatched roof", "front curtain", "thimble", "threshing machine", "throne", "tile roof",
    "toaster", "tobacco shop", "toilet seat", "torch", "totem pole", "tow truck", "toy store",
    "tractor", "semi-trailer truck", "tray", "trench coat", "tricycle", "trimaran", "tripod",
    "triumphal arch", "trolleybus", "trombone", "hot tub", "turnstile", "typewriter keyboard",
    "umbrella", "unicycle", "upright piano", "vacuum cleaner", "vase", "vaulted or arched ceiling",
    "velvet fabric", "vending machine", "vestment", "viaduct", "violin", "volleyball",
    "waffle iron", "wall clock", "wallet", "wardrobe", "military aircraft", "sink",
    "washing machine", "water bottle", "water jug", "water tower", "whiskey jug", "whistle",
    "hair wig", "window screen", "window shade", "Windsor tie", "wine bottle", "airplane wing",
    "wok", "wooden spoon", "wool", "split-rail fence", "shipwreck", "sailboat", "yurt", "website",
    "comic book", "crossword", "traffic or street sign", "traffic light", "dust jacket", "menu",
    "plate", "guacamole", "consomme", "hot pot", "trifle", "ice cream", "popsicle", "baguette",
    "bagel", "pretzel", "cheeseburger", "hot dog", "mashed potatoes", "cabbage", "broccoli",
    "cauliflower", "zucchini", "spaghetti squash", "acorn squash", "butternut squash", "cucumber",
    "artichoke", "bell pepper", "cardoon", "mushroom", "Granny Smith apple", "strawberry", "orange",
    "lemon", "fig", "pineapple", "banana", "jackfruit", "cherimoya (custard apple)", "pomegranate",
    "hay", "carbonara", "chocolate syrup", "dough", "meatloaf", "pizza", "pot pie", "burrito",
    "red wine", "espresso", "tea cup", "eggnog", "mountain", "bubble", "cliff", "coral reef",
    "geyser", "lakeshore", "promontory", "sandbar", "beach", "valley", "volcano", "baseball player",
    "bridegroom", "scuba diver", "rapeseed", "daisy", "yellow lady's slipper", "corn", "acorn",
    "rose hip", "horse chestnut seed", "coral fungus", "agaric", "gyromitra", "stinkhorn mushroom",
    "earth star fungus", "hen of the woods mushroom", "bolete", "corn cob", "toilet paper"
]

class ImageNet:
    def __init__(self,
                 preprocess,
                 location='~/data',
                 batch_size=32,
                 num_workers=32):
        self.preprocess = preprocess
        self.location = '/path/ImageNet2012/'  # TODO
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.classnames = imagenet_classnames

        self.populate_train()
        self.populate_test()

    def populate_train(self):
        traindir = os.path.join(self.location, 'train')
        self.train_dataset = ImageFolderWithPaths(
            traindir,
            transform=self.preprocess)
        sampler = self.get_train_sampler()
        kwargs = {'shuffle': True} if sampler is None else {}
        self.train_loader = torch.utils.data.DataLoader(
            self.train_dataset,
            sampler=sampler,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            **kwargs,
        )

    def populate_test(self):
        self.test_dataset = self.get_test_dataset()
        self.test_loader = torch.utils.data.DataLoader(
            self.test_dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            sampler=self.get_test_sampler()
        )

    def get_test_path(self):
        test_path = os.path.join(self.location, 'val_dir')
        if not os.path.exists(test_path):
            test_path = os.path.join(self.location, 'val')
        return test_path

    def get_train_sampler(self):
        return None

    def get_test_sampler(self):
        return None

    def get_test_dataset(self):
        return ImageFolderWithPaths(self.get_test_path(), transform=self.preprocess)

    def name(self):
        return 'imagenet'

class ImageNetTrain(ImageNet):

    def get_test_dataset(self):
        pass

class ImageNetK(ImageNet):

    def get_train_sampler(self):
        idxs = np.zeros(len(self.train_dataset.targets))
        target_array = np.array(self.train_dataset.targets)
        for c in range(1000):
            m = target_array == c
            n = len(idxs[m])
            arr = np.zeros(n)
            arr[:self.k()] = 1
            np.random.shuffle(arr)
            idxs[m] = arr

        idxs = idxs.astype('int')
        sampler = SubsetSampler(np.where(idxs)[0])
        return sampler
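The `ImageNetK.get_train_sampler` logic above selects k random examples per class by shuffling a 0/1 mask within each class. The same idea can be sketched standalone on toy labels; `k_shot_indices` is an illustrative helper, not part of the repo.

```python
import numpy as np

def k_shot_indices(targets, num_classes, k, seed=0):
    # Mirrors ImageNetK.get_train_sampler above: within each class, build a
    # mask with exactly k ones, shuffle it, and keep the selected indices.
    rng = np.random.default_rng(seed)
    targets = np.asarray(targets)
    mask = np.zeros(len(targets), dtype=int)
    for c in range(num_classes):
        m = targets == c
        arr = np.zeros(m.sum(), dtype=int)
        arr[:k] = 1
        rng.shuffle(arr)
        mask[m] = arr
    return np.where(mask)[0]

targets = [0, 0, 0, 1, 1, 1, 2, 2, 2]
idxs = k_shot_indices(targets, num_classes=3, k=2)
print(len(idxs))  # 6: two examples per class
```

The resulting index array would feed a subset sampler (here, the repo's `SubsetSampler`) so the DataLoader only visits the chosen examples.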
src/datasets/kmnist.py
ADDED
@@ -0,0 +1,39 @@
import os

import torch
import torchvision.datasets as datasets


class KMNIST:
    def __init__(
        self,
        preprocess,
        location=os.path.expanduser("~/data"),
        batch_size=128,
        num_workers=6,
    ):
        location = os.path.join(location, "KMNIST")
        self.train_dataset = datasets.KMNIST(
            root=location, download=True, train=True, transform=preprocess
        )

        self.train_loader = torch.utils.data.DataLoader(
            self.train_dataset,
            batch_size=batch_size,
            shuffle=True,
            num_workers=num_workers,
        )

        self.test_dataset = datasets.KMNIST(
            root=location, download=True, train=False, transform=preprocess
        )

        self.test_loader = torch.utils.data.DataLoader(
            self.test_dataset,
            batch_size=batch_size,
            shuffle=False,
            num_workers=num_workers,
        )

        self.classnames = self.train_dataset.classes
src/datasets/mnist.py
ADDED
@@ -0,0 +1,41 @@
import os
import torch
import torchvision.datasets as datasets

class MNIST:
    def __init__(self,
                 preprocess,
                 location=os.path.expanduser('~/data'),
                 batch_size=128,
                 num_workers=16):

        self.train_dataset = datasets.MNIST(
            root=location,
            download=True,
            train=True,
            transform=preprocess
        )

        self.train_loader = torch.utils.data.DataLoader(
            self.train_dataset,
            batch_size=batch_size,
            shuffle=True,
            num_workers=num_workers
        )

        self.test_dataset = datasets.MNIST(
            root=location,
            download=True,
            train=False,
            transform=preprocess
        )

        self.test_loader = torch.utils.data.DataLoader(
            self.test_dataset,
            batch_size=batch_size,
            shuffle=False,
            num_workers=num_workers
        )

        self.classnames = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9']
src/datasets/oxfordpets.py
ADDED
@@ -0,0 +1,38 @@
import os
import torch
import torchvision.datasets as datasets


class OxfordIIITPet:
    def __init__(
        self,
        preprocess,
        location=os.path.expanduser("~/data"),
        batch_size=128,
        num_workers=6,
    ):

        location = os.path.join(location, "OxfordIIITPet")
        self.train_dataset = datasets.OxfordIIITPet(
            root=location, download=True, split="trainval", transform=preprocess
        )

        self.train_loader = torch.utils.data.DataLoader(
            self.train_dataset,
            batch_size=batch_size,
            shuffle=True,
            num_workers=num_workers,
        )

        self.test_dataset = datasets.OxfordIIITPet(
            root=location, download=True, split="test", transform=preprocess
        )

        self.test_loader = torch.utils.data.DataLoader(
            self.test_dataset,
            batch_size=batch_size,
            shuffle=False,
            num_workers=num_workers,
        )

        self.classnames = self.train_dataset.classes
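Like the other wrappers in src/datasets, OxfordIIITPet exposes the same five attributes (train_dataset, train_loader, test_dataset, test_loader, classnames); the registry and evaluation code rely on this duck-typed contract rather than a shared base class. A minimal sketch of that contract (FakeWrapper and its toy data are illustrative stand-ins, not part of the repo):

```python
# Attribute names copied from the wrapper classes in src/datasets.
REQUIRED_ATTRS = ("train_dataset", "train_loader", "test_dataset", "test_loader", "classnames")


class FakeWrapper:
    # Illustrative stand-in: real wrappers build torchvision datasets and DataLoaders.
    def __init__(self, classnames):
        self.train_dataset = [(0, 0), (1, 1)]
        self.train_loader = list(self.train_dataset)
        self.test_dataset = [(2, 0)]
        self.test_loader = list(self.test_dataset)
        self.classnames = classnames


wrapper = FakeWrapper(["Abyssinian", "Bengal"])
missing = [a for a in REQUIRED_ATTRS if not hasattr(wrapper, a)]
```

Any object with these five attributes can stand in for a dataset wrapper downstream.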
src/datasets/registry.py
ADDED
@@ -0,0 +1,103 @@
import sys
import inspect
import random
import torch
import copy

from torch.utils.data.dataset import random_split

from src.datasets.cars import Cars
from src.datasets.cifar10 import CIFAR10
from src.datasets.cifar100 import CIFAR100
from src.datasets.dtd import DTD
from src.datasets.eurosat import EuroSAT, EuroSATVal
from src.datasets.gtsrb import GTSRB
from src.datasets.imagenet import ImageNet
from src.datasets.mnist import MNIST
from src.datasets.resisc45 import RESISC45
from src.datasets.stl10 import STL10
from src.datasets.svhn import SVHN
from src.datasets.sun397 import SUN397
from src.datasets.emnist import EMNIST
from src.datasets.kmnist import KMNIST
from src.datasets.oxfordpets import OxfordIIITPet

registry = {
    name: obj for name, obj in inspect.getmembers(sys.modules[__name__], inspect.isclass)
}


class GenericDataset(object):
    def __init__(self):
        self.train_dataset = None
        self.train_loader = None
        self.test_dataset = None
        self.test_loader = None
        self.classnames = None


def split_train_into_train_val(dataset, new_dataset_class_name, batch_size, num_workers, val_fraction, max_val_samples=None, seed=0):
    assert val_fraction > 0. and val_fraction < 1.
    total_size = len(dataset.train_dataset)
    val_size = int(total_size * val_fraction)
    if max_val_samples is not None:
        val_size = min(val_size, max_val_samples)
    train_size = total_size - val_size

    assert val_size > 0
    assert train_size > 0

    lengths = [train_size, val_size]

    trainset, valset = random_split(
        dataset.train_dataset,
        lengths,
        generator=torch.Generator().manual_seed(seed)
    )
    if new_dataset_class_name == 'MNISTVal':
        assert trainset.indices[0] == 36044

    new_dataset = None

    new_dataset_class = type(new_dataset_class_name, (GenericDataset, ), {})
    new_dataset = new_dataset_class()

    new_dataset.train_dataset = trainset
    new_dataset.train_loader = torch.utils.data.DataLoader(
        new_dataset.train_dataset,
        shuffle=True,
        batch_size=batch_size,
        num_workers=num_workers,
    )

    new_dataset.test_dataset = valset
    new_dataset.test_loader = torch.utils.data.DataLoader(
        new_dataset.test_dataset,
        batch_size=batch_size,
        num_workers=num_workers
    )

    new_dataset.classnames = copy.copy(dataset.classnames)

    return new_dataset


def get_dataset(dataset_name, preprocess, location, batch_size=128, num_workers=16, val_fraction=0.1, max_val_samples=5000):
    if dataset_name.endswith('Val'):
        # Handle val splits
        if dataset_name in registry:
            dataset_class = registry[dataset_name]
        else:
            base_dataset_name = dataset_name.split('Val')[0]
            base_dataset = get_dataset(base_dataset_name, preprocess, location, batch_size, num_workers)
            dataset = split_train_into_train_val(
                base_dataset, dataset_name, batch_size, num_workers, val_fraction, max_val_samples)
            return dataset
    else:
        assert dataset_name in registry, f'Unsupported dataset: {dataset_name}. Supported datasets: {list(registry.keys())}'
        dataset_class = registry[dataset_name]
    dataset = dataset_class(
        preprocess, location=location, batch_size=batch_size, num_workers=num_workers
    )
    return dataset
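registry.py leans on two dynamic-Python tricks: it builds its dataset registry by introspecting the classes in its own module, and the `Val` path manufactures a new dataset class at runtime with `type(name, (GenericDataset,), {})`. A self-contained sketch of both mechanisms (ToyDataset and the slicing are illustrative stand-ins; the real file introspects via `inspect.getmembers(sys.modules[__name__], inspect.isclass)` and carves the split with `random_split`):

```python
import inspect


class GenericDataset:
    # Mirror of registry.py's empty container class.
    def __init__(self):
        self.train_dataset = None
        self.classnames = None


class ToyDataset(GenericDataset):
    # Illustrative stand-in so the sketch runs without torchvision downloads.
    def __init__(self):
        super().__init__()
        self.train_dataset = list(range(10))
        self.classnames = ["a", "b"]


# Same idea as registry.py: collect every class defined in this namespace.
registry = {
    name: obj for name, obj in list(globals().items()) if inspect.isclass(obj)
}

# The 'Val' path synthesizes a new dataset class at runtime, then fills it
# with a subset of the base dataset's training data.
val_cls = type("ToyDatasetVal", (GenericDataset,), {})
val = val_cls()
base = ToyDataset()
val.train_dataset = base.train_dataset[:8]
val.classnames = list(base.classnames)
```

Because the synthesized class inherits from GenericDataset, anything that accepts a registry dataset also accepts the Val variant.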
src/datasets/resisc45.py
ADDED
@@ -0,0 +1,304 @@
import os
import torch

import abc
import os
from typing import Any, Callable, Dict, Optional, Tuple

import numpy as np
import torch
from torch import Tensor
from torch.utils.data import Dataset
from torchvision.datasets import ImageFolder
from torchvision.datasets.folder import default_loader as pil_loader


# modified from: https://github.com/microsoft/torchgeo
class VisionDataset(Dataset[Dict[str, Any]], abc.ABC):
    """Abstract base class for datasets lacking geospatial information.

    This base class is designed for datasets with pre-defined image chips.
    """

    @abc.abstractmethod
    def __getitem__(self, index: int) -> Dict[str, Any]:
        """Return an index within the dataset.

        Args:
            index: index to return

        Returns:
            data and labels at that index

        Raises:
            IndexError: if index is out of range of the dataset
        """

    @abc.abstractmethod
    def __len__(self) -> int:
        """Return the length of the dataset.

        Returns:
            length of the dataset
        """

    def __str__(self) -> str:
        """Return the informal string representation of the object.

        Returns:
            informal string representation
        """
        return f"""\
{self.__class__.__name__} Dataset
    type: VisionDataset
    size: {len(self)}"""


class VisionClassificationDataset(VisionDataset, ImageFolder):
    """Abstract base class for classification datasets lacking geospatial information.

    This base class is designed for datasets with pre-defined image chips which
    are separated into separate folders per class.
    """

    def __init__(
        self,
        root: str,
        transforms: Optional[Callable[[Dict[str, Tensor]], Dict[str, Tensor]]] = None,
        loader: Optional[Callable[[str], Any]] = pil_loader,
        is_valid_file: Optional[Callable[[str], bool]] = None,
    ) -> None:
        """Initialize a new VisionClassificationDataset instance.

        Args:
            root: root directory where dataset can be found
            transforms: a function/transform that takes input sample and its target as
                entry and returns a transformed version
            loader: a callable function which takes as input a path to an image and
                returns a PIL Image or numpy array
            is_valid_file: A function that takes the path of an Image file and checks if
                the file is a valid file
        """
        # When transform & target_transform are None, ImageFolder.__getitem__(index)
        # returns a PIL.Image and int for image and label, respectively
        super().__init__(
            root=root,
            transform=None,
            target_transform=None,
            loader=loader,
            is_valid_file=is_valid_file,
        )

        # Must be set after calling super().__init__()
        self.transforms = transforms

    def __getitem__(self, index: int) -> Dict[str, Tensor]:
        """Return an index within the dataset.

        Args:
            index: index to return

        Returns:
            data and label at that index
        """
        image, label = self._load_image(index)

        if self.transforms is not None:
            return self.transforms(image), label

        return image, label

    def __len__(self) -> int:
        """Return the number of data points in the dataset.

        Returns:
            length of the dataset
        """
        return len(self.imgs)

    def _load_image(self, index: int) -> Tuple[Tensor, Tensor]:
        """Load a single image and its class label.

        Args:
            index: index to return

        Returns:
            the image
            the image class label
        """
        img, label = ImageFolder.__getitem__(self, index)
        label = torch.tensor(label)
        return img, label


class RESISC45Dataset(VisionClassificationDataset):
    """RESISC45 dataset.

    The `RESISC45 <http://www.escience.cn/people/JunweiHan/NWPU-RESISC45.html>`_
    dataset is a dataset for remote sensing image scene classification.

    Dataset features:

    * 31,500 images with 0.2-30 m per pixel resolution (256x256 px)
    * three spectral bands - RGB
    * 45 scene classes, 700 images per class
    * images extracted from Google Earth from over 100 countries
    * image conditions with high variability (resolution, weather, illumination)

    Dataset format:

    * images are three-channel jpgs

    Dataset classes:

    0. airplane
    1. airport
    2. baseball_diamond
    3. basketball_court
    4. beach
    5. bridge
    6. chaparral
    7. church
    8. circular_farmland
    9. cloud
    10. commercial_area
    11. dense_residential
    12. desert
    13. forest
    14. freeway
    15. golf_course
    16. ground_track_field
    17. harbor
    18. industrial_area
    19. intersection
    20. island
    21. lake
    22. meadow
    23. medium_residential
    24. mobile_home_park
    25. mountain
    26. overpass
    27. palace
    28. parking_lot
    29. railway
    30. railway_station
    31. rectangular_farmland
    32. river
    33. roundabout
    34. runway
    35. sea_ice
    36. ship
    37. snowberg
    38. sparse_residential
    39. stadium
    40. storage_tank
    41. tennis_court
    42. terrace
    43. thermal_power_station
    44. wetland

    This dataset uses the train/val/test splits defined in the "In-domain representation
    learning for remote sensing" paper:

    * https://arxiv.org/abs/1911.06721

    If you use this dataset in your research, please cite the following paper:

    * https://doi.org/10.1109/jproc.2017.2675998
    """

    # url = "https://drive.google.com/file/d/1DnPSU5nVSN7xv95bpZ3XQ0JhKXZOKgIv"
    # md5 = "d824acb73957502b00efd559fc6cfbbb"
    # filename = "NWPU-RESISC45.rar"
    directory = "resisc45/NWPU-RESISC45"

    splits = ["train", "val", "test"]
    split_urls = {
        "train": "https://storage.googleapis.com/remote_sensing_representations/resisc45-train.txt",  # noqa: E501
        "val": "https://storage.googleapis.com/remote_sensing_representations/resisc45-val.txt",  # noqa: E501
        "test": "https://storage.googleapis.com/remote_sensing_representations/resisc45-test.txt",  # noqa: E501
    }
    split_md5s = {
        "train": "b5a4c05a37de15e4ca886696a85c403e",
        "val": "a0770cee4c5ca20b8c32bbd61e114805",
        "test": "3dda9e4988b47eb1de9f07993653eb08",
    }
    classes = [
        "airplane",
        "airport",
        "baseball_diamond",
        "basketball_court",
        "beach",
        "bridge",
        "chaparral",
        "church",
        "circular_farmland",
        "cloud",
        "commercial_area",
        "dense_residential",
        "desert",
        "forest",
        "freeway",
        "golf_course",
        "ground_track_field",
        "harbor",
        "industrial_area",
        "intersection",
        "island",
        "lake",
        "meadow",
        "medium_residential",
        "mobile_home_park",
        "mountain",
        "overpass",
        "palace",
        "parking_lot",
        "railway",
        "railway_station",
        "rectangular_farmland",
        "river",
        "roundabout",
        "runway",
        "sea_ice",
        "ship",
        "snowberg",
        "sparse_residential",
        "stadium",
        "storage_tank",
        "tennis_court",
        "terrace",
        "thermal_power_station",
        "wetland",
    ]

    def __init__(
        self,
        root: str = "data",
        split: str = "train",
        transforms: Optional[Callable[[Dict[str, Tensor]], Dict[str, Tensor]]] = None,
    ) -> None:
        """Initialize a new RESISC45 dataset instance.

        Args:
            root: root directory where dataset can be found
            split: one of "train", "val", or "test"
            transforms: a function/transform that takes input sample and its target as
                entry and returns a transformed version
        """
        assert split in self.splits
        self.root = root

        valid_fns = set()
        with open(os.path.join(self.root, "resisc45", f"resisc45-{split}.txt")) as f:
            for fn in f:
                valid_fns.add(fn.strip())
        is_in_split: Callable[[str], bool] = lambda x: os.path.basename(x) in valid_fns

        super().__init__(
            root=os.path.join(root, self.directory),
            transforms=transforms,
            is_valid_file=is_in_split,
        )


class RESISC45:
    def __init__(self,
                 preprocess,
                 location=os.path.expanduser('~/data'),
                 batch_size=32,
                 num_workers=16):

        self.train_dataset = RESISC45Dataset(root=location, split='train', transforms=preprocess)
        self.train_loader = torch.utils.data.DataLoader(
            self.train_dataset,
            shuffle=True,
            batch_size=batch_size,
            num_workers=num_workers,
        )

        self.test_dataset = RESISC45Dataset(root=location, split='test', transforms=preprocess)
        self.test_loader = torch.utils.data.DataLoader(
            self.test_dataset,
            batch_size=batch_size,
            num_workers=num_workers
        )

        # class names have _ so split on this for better zero-shot head
        self.classnames = [' '.join(c.split('_')) for c in RESISC45Dataset.classes]
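The last line of resisc45.py replaces underscores in the raw folder names with spaces so that CLIP-style zero-shot prompts read as natural phrases. The transformation in isolation (class names taken from the list above):

```python
# Underscored folder names -> spaced phrases for text prompts.
classes = ["baseball_diamond", "thermal_power_station", "airplane"]
classnames = [' '.join(c.split('_')) for c in classes]
print(classnames)  # -> ['baseball diamond', 'thermal power station', 'airplane']
```

Names without underscores pass through unchanged, since splitting on '_' then rejoining with spaces is a no-op for them.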