diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torchvision/datasets/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torchvision/datasets/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..4005b4a90729c9fe1b811f7388bd8453998d2322
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torchvision/datasets/__init__.py
@@ -0,0 +1,147 @@
+from ._optical_flow import FlyingChairs, FlyingThings3D, HD1K, KittiFlow, Sintel
+from ._stereo_matching import (
+ CarlaStereo,
+ CREStereo,
+ ETH3DStereo,
+ FallingThingsStereo,
+ InStereo2k,
+ Kitti2012Stereo,
+ Kitti2015Stereo,
+ Middlebury2014Stereo,
+ SceneFlowStereo,
+ SintelStereo,
+)
+from .caltech import Caltech101, Caltech256
+from .celeba import CelebA
+from .cifar import CIFAR10, CIFAR100
+from .cityscapes import Cityscapes
+from .clevr import CLEVRClassification
+from .coco import CocoCaptions, CocoDetection
+from .country211 import Country211
+from .dtd import DTD
+from .eurosat import EuroSAT
+from .fakedata import FakeData
+from .fer2013 import FER2013
+from .fgvc_aircraft import FGVCAircraft
+from .flickr import Flickr30k, Flickr8k
+from .flowers102 import Flowers102
+from .folder import DatasetFolder, ImageFolder
+from .food101 import Food101
+from .gtsrb import GTSRB
+from .hmdb51 import HMDB51
+from .imagenet import ImageNet
+from .imagenette import Imagenette
+from .inaturalist import INaturalist
+from .kinetics import Kinetics
+from .kitti import Kitti
+from .lfw import LFWPairs, LFWPeople
+from .lsun import LSUN, LSUNClass
+from .mnist import EMNIST, FashionMNIST, KMNIST, MNIST, QMNIST
+from .moving_mnist import MovingMNIST
+from .omniglot import Omniglot
+from .oxford_iiit_pet import OxfordIIITPet
+from .pcam import PCAM
+from .phototour import PhotoTour
+from .places365 import Places365
+from .rendered_sst2 import RenderedSST2
+from .sbd import SBDataset
+from .sbu import SBU
+from .semeion import SEMEION
+from .stanford_cars import StanfordCars
+from .stl10 import STL10
+from .sun397 import SUN397
+from .svhn import SVHN
+from .ucf101 import UCF101
+from .usps import USPS
+from .vision import VisionDataset
+from .voc import VOCDetection, VOCSegmentation
+from .widerface import WIDERFace
+
+__all__ = (
+ "LSUN",
+ "LSUNClass",
+ "ImageFolder",
+ "DatasetFolder",
+ "FakeData",
+ "CocoCaptions",
+ "CocoDetection",
+ "CIFAR10",
+ "CIFAR100",
+ "EMNIST",
+ "FashionMNIST",
+ "QMNIST",
+ "MNIST",
+ "KMNIST",
+ "MovingMNIST",
+ "StanfordCars",
+ "STL10",
+ "SUN397",
+ "SVHN",
+ "PhotoTour",
+ "SEMEION",
+ "Omniglot",
+ "SBU",
+ "Flickr8k",
+ "Flickr30k",
+ "Flowers102",
+ "VOCSegmentation",
+ "VOCDetection",
+ "Cityscapes",
+ "ImageNet",
+ "Caltech101",
+ "Caltech256",
+ "CelebA",
+ "WIDERFace",
+ "SBDataset",
+ "VisionDataset",
+ "USPS",
+ "Kinetics",
+ "HMDB51",
+ "UCF101",
+ "Places365",
+ "Kitti",
+ "INaturalist",
+ "LFWPeople",
+ "LFWPairs",
+ "KittiFlow",
+ "Sintel",
+ "FlyingChairs",
+ "FlyingThings3D",
+ "HD1K",
+ "Food101",
+ "DTD",
+ "FER2013",
+ "GTSRB",
+ "CLEVRClassification",
+ "OxfordIIITPet",
+ "PCAM",
+ "Country211",
+ "FGVCAircraft",
+ "EuroSAT",
+ "RenderedSST2",
+ "Kitti2012Stereo",
+ "Kitti2015Stereo",
+ "CarlaStereo",
+ "Middlebury2014Stereo",
+ "CREStereo",
+ "FallingThingsStereo",
+ "SceneFlowStereo",
+ "SintelStereo",
+ "InStereo2k",
+ "ETH3DStereo",
+ "wrap_dataset_for_transforms_v2",
+ "Imagenette",
+)
+
+
+# We override current module's attributes to handle the import:
+# from torchvision.datasets import wrap_dataset_for_transforms_v2
+# without a cyclic error.
+# Ref: https://peps.python.org/pep-0562/
+def __getattr__(name):
+ if name in ("wrap_dataset_for_transforms_v2",):
+ from torchvision.tv_tensors._dataset_wrapper import wrap_dataset_for_transforms_v2
+
+ return wrap_dataset_for_transforms_v2
+
+ raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
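+
+
+# Illustrative sketch (not part of the upstream module): with the PEP 562 hook
+# above, the following import works without triggering a circular import. The
+# dataset root used here is a hypothetical placeholder.
+#
+#   from torchvision.datasets import CIFAR10, wrap_dataset_for_transforms_v2
+#
+#   dataset = CIFAR10(root="data", download=True)
+#   dataset = wrap_dataset_for_transforms_v2(dataset)  # samples become transforms-v2 compatible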
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torchvision/datasets/_optical_flow.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torchvision/datasets/_optical_flow.py
new file mode 100644
index 0000000000000000000000000000000000000000..af8e17ad95c937dd679cf5f5c14f5e277ab1b1ab
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torchvision/datasets/_optical_flow.py
@@ -0,0 +1,520 @@
+import itertools
+import os
+from abc import ABC, abstractmethod
+from glob import glob
+from pathlib import Path
+from typing import Any, Callable, Optional, Union
+
+import numpy as np
+import torch
+from PIL import Image
+
+from ..io.image import decode_png, read_file
+from .folder import default_loader
+from .utils import _read_pfm, verify_str_arg
+from .vision import VisionDataset
+
+T1 = tuple[Image.Image, Image.Image, Optional[np.ndarray], Optional[np.ndarray]]
+T2 = tuple[Image.Image, Image.Image, Optional[np.ndarray]]
+
+
+__all__ = (
+ "KittiFlow",
+ "Sintel",
+ "FlyingThings3D",
+ "FlyingChairs",
+ "HD1K",
+)
+
+
+class FlowDataset(ABC, VisionDataset):
+ # Some datasets like Kitti have a built-in valid_flow_mask, indicating which flow values are valid
+ # For those we return (img1, img2, flow, valid_flow_mask), and for the rest we return (img1, img2, flow),
+ # and it's up to whatever consumes the dataset to decide what valid_flow_mask should be.
+ _has_builtin_flow_mask = False
+
+ def __init__(
+ self,
+ root: Union[str, Path],
+ transforms: Optional[Callable] = None,
+ loader: Callable[[str], Any] = default_loader,
+ ) -> None:
+
+ super().__init__(root=root)
+ self.transforms = transforms
+
+ self._flow_list: list[str] = []
+ self._image_list: list[list[str]] = []
+ self._loader = loader
+
+ def _read_img(self, file_name: str) -> Union[Image.Image, torch.Tensor]:
+ return self._loader(file_name)
+
+ @abstractmethod
+ def _read_flow(self, file_name: str):
+ # Return the flow or a tuple with the flow and the valid_flow_mask if _has_builtin_flow_mask is True
+ pass
+
+ def __getitem__(self, index: int) -> Union[T1, T2]:
+
+ img1 = self._read_img(self._image_list[index][0])
+ img2 = self._read_img(self._image_list[index][1])
+
+        if self._flow_list:  # it will be empty for some datasets when split="test"
+ flow = self._read_flow(self._flow_list[index])
+ if self._has_builtin_flow_mask:
+ flow, valid_flow_mask = flow
+ else:
+ valid_flow_mask = None
+ else:
+ flow = valid_flow_mask = None
+
+ if self.transforms is not None:
+ img1, img2, flow, valid_flow_mask = self.transforms(img1, img2, flow, valid_flow_mask)
+
+ if self._has_builtin_flow_mask or valid_flow_mask is not None:
+ # The `or valid_flow_mask is not None` part is here because the mask can be generated within a transform
+ return img1, img2, flow, valid_flow_mask # type: ignore[return-value]
+ else:
+ return img1, img2, flow # type: ignore[return-value]
+
+ def __len__(self) -> int:
+ return len(self._image_list)
+
+ def __rmul__(self, v: int) -> torch.utils.data.ConcatDataset:
+ return torch.utils.data.ConcatDataset([self] * v)
+
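+
+# Illustrative sketch (not part of the upstream module): because of ``__rmul__``
+# above, a flow dataset can be repeated by an integer factor to rebalance a
+# training mix. ``sintel`` and ``kitti`` below are hypothetical FlowDataset instances.
+#
+#   train_set = 100 * sintel + kitti   # ConcatDataset with Sintel oversampled 100x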
+
+class Sintel(FlowDataset):
+ """`Sintel `_ Dataset for optical flow.
+
+ The dataset is expected to have the following structure: ::
+
+ root
+ Sintel
+ testing
+ clean
+ scene_1
+ scene_2
+ ...
+ final
+ scene_1
+ scene_2
+ ...
+ training
+ clean
+ scene_1
+ scene_2
+ ...
+ final
+ scene_1
+ scene_2
+ ...
+ flow
+ scene_1
+ scene_2
+ ...
+
+ Args:
+ root (str or ``pathlib.Path``): Root directory of the Sintel Dataset.
+ split (string, optional): The dataset split, either "train" (default) or "test"
+ pass_name (string, optional): The pass to use, either "clean" (default), "final", or "both". See link above for
+ details on the different passes.
+ transforms (callable, optional): A function/transform that takes in
+ ``img1, img2, flow, valid_flow_mask`` and returns a transformed version.
+ ``valid_flow_mask`` is expected for consistency with other datasets which
+ return a built-in valid mask, such as :class:`~torchvision.datasets.KittiFlow`.
+ loader (callable, optional): A function to load an image given its path.
+ By default, it uses PIL as its image loader, but users could also pass in
+ ``torchvision.io.decode_image`` for decoding image data into tensors directly.
+ """
+
+ def __init__(
+ self,
+ root: Union[str, Path],
+ split: str = "train",
+ pass_name: str = "clean",
+ transforms: Optional[Callable] = None,
+ loader: Callable[[str], Any] = default_loader,
+ ) -> None:
+ super().__init__(root=root, transforms=transforms, loader=loader)
+
+ verify_str_arg(split, "split", valid_values=("train", "test"))
+ verify_str_arg(pass_name, "pass_name", valid_values=("clean", "final", "both"))
+ passes = ["clean", "final"] if pass_name == "both" else [pass_name]
+
+ root = Path(root) / "Sintel"
+ flow_root = root / "training" / "flow"
+
+ for pass_name in passes:
+ split_dir = "training" if split == "train" else split
+ image_root = root / split_dir / pass_name
+ for scene in os.listdir(image_root):
+ image_list = sorted(glob(str(image_root / scene / "*.png")))
+ for i in range(len(image_list) - 1):
+ self._image_list += [[image_list[i], image_list[i + 1]]]
+
+ if split == "train":
+ self._flow_list += sorted(glob(str(flow_root / scene / "*.flo")))
+
+ def __getitem__(self, index: int) -> Union[T1, T2]:
+ """Return example at given index.
+
+ Args:
+ index(int): The index of the example to retrieve
+
+ Returns:
+ tuple: A 3-tuple with ``(img1, img2, flow)``.
+ The flow is a numpy array of shape (2, H, W) and the images are PIL images.
+ ``flow`` is None if ``split="test"``.
+ If a valid flow mask is generated within the ``transforms`` parameter,
+ a 4-tuple with ``(img1, img2, flow, valid_flow_mask)`` is returned.
+ """
+ return super().__getitem__(index)
+
+ def _read_flow(self, file_name: str) -> np.ndarray:
+ return _read_flo(file_name)
+
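+
+# Illustrative usage sketch (not part of the upstream module); the "datasets" root
+# below is a hypothetical placeholder laid out as in the class docstring above.
+#
+#   ds = Sintel(root="datasets", split="train", pass_name="clean")
+#   img1, img2, flow = ds[0]   # two PIL images and a (2, H, W) numpy flow array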
+
+class KittiFlow(FlowDataset):
+ """`KITTI `__ dataset for optical flow (2015).
+
+ The dataset is expected to have the following structure: ::
+
+ root
+ KittiFlow
+ testing
+ image_2
+ training
+ image_2
+ flow_occ
+
+ Args:
+ root (str or ``pathlib.Path``): Root directory of the KittiFlow Dataset.
+ split (string, optional): The dataset split, either "train" (default) or "test"
+ transforms (callable, optional): A function/transform that takes in
+ ``img1, img2, flow, valid_flow_mask`` and returns a transformed version.
+ loader (callable, optional): A function to load an image given its path.
+ By default, it uses PIL as its image loader, but users could also pass in
+ ``torchvision.io.decode_image`` for decoding image data into tensors directly.
+ """
+
+ _has_builtin_flow_mask = True
+
+ def __init__(
+ self,
+ root: Union[str, Path],
+ split: str = "train",
+ transforms: Optional[Callable] = None,
+ loader: Callable[[str], Any] = default_loader,
+ ) -> None:
+ super().__init__(root=root, transforms=transforms, loader=loader)
+
+ verify_str_arg(split, "split", valid_values=("train", "test"))
+
+ root = Path(root) / "KittiFlow" / (split + "ing")
+ images1 = sorted(glob(str(root / "image_2" / "*_10.png")))
+ images2 = sorted(glob(str(root / "image_2" / "*_11.png")))
+
+ if not images1 or not images2:
+ raise FileNotFoundError(
+ "Could not find the Kitti flow images. Please make sure the directory structure is correct."
+ )
+
+ for img1, img2 in zip(images1, images2):
+ self._image_list += [[img1, img2]]
+
+ if split == "train":
+ self._flow_list = sorted(glob(str(root / "flow_occ" / "*_10.png")))
+
+ def __getitem__(self, index: int) -> Union[T1, T2]:
+ """Return example at given index.
+
+ Args:
+ index(int): The index of the example to retrieve
+
+ Returns:
+ tuple: A 4-tuple with ``(img1, img2, flow, valid_flow_mask)``
+ where ``valid_flow_mask`` is a numpy boolean mask of shape (H, W)
+ indicating which flow values are valid. The flow is a numpy array of
+ shape (2, H, W) and the images are PIL images. ``flow`` and ``valid_flow_mask`` are None if
+ ``split="test"``.
+ """
+ return super().__getitem__(index)
+
+ def _read_flow(self, file_name: str) -> tuple[np.ndarray, np.ndarray]:
+ return _read_16bits_png_with_flow_and_valid_mask(file_name)
+
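+
+# Illustrative usage sketch (not part of the upstream module); the "datasets" root
+# below is a hypothetical placeholder.
+#
+#   ds = KittiFlow(root="datasets", split="train")
+#   img1, img2, flow, valid_flow_mask = ds[0]   # the mask is a (H, W) boolean numpy array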
+
+class FlyingChairs(FlowDataset):
+ """`FlyingChairs `_ Dataset for optical flow.
+
+ You will also need to download the FlyingChairs_train_val.txt file from the dataset page.
+
+ The dataset is expected to have the following structure: ::
+
+ root
+ FlyingChairs
+ data
+ 00001_flow.flo
+ 00001_img1.ppm
+ 00001_img2.ppm
+ ...
+ FlyingChairs_train_val.txt
+
+
+ Args:
+ root (str or ``pathlib.Path``): Root directory of the FlyingChairs Dataset.
+ split (string, optional): The dataset split, either "train" (default) or "val"
+ transforms (callable, optional): A function/transform that takes in
+ ``img1, img2, flow, valid_flow_mask`` and returns a transformed version.
+ ``valid_flow_mask`` is expected for consistency with other datasets which
+ return a built-in valid mask, such as :class:`~torchvision.datasets.KittiFlow`.
+ """
+
+ def __init__(self, root: Union[str, Path], split: str = "train", transforms: Optional[Callable] = None) -> None:
+ super().__init__(root=root, transforms=transforms)
+
+ verify_str_arg(split, "split", valid_values=("train", "val"))
+
+ root = Path(root) / "FlyingChairs"
+ images = sorted(glob(str(root / "data" / "*.ppm")))
+ flows = sorted(glob(str(root / "data" / "*.flo")))
+
+ split_file_name = "FlyingChairs_train_val.txt"
+
+ if not os.path.exists(root / split_file_name):
+ raise FileNotFoundError(
+ "The FlyingChairs_train_val.txt file was not found - please download it from the dataset page (see docstring)."
+ )
+
+ split_list = np.loadtxt(str(root / split_file_name), dtype=np.int32)
+ for i in range(len(flows)):
+ split_id = split_list[i]
+ if (split == "train" and split_id == 1) or (split == "val" and split_id == 2):
+ self._flow_list += [flows[i]]
+ self._image_list += [[images[2 * i], images[2 * i + 1]]]
+
+ def __getitem__(self, index: int) -> Union[T1, T2]:
+ """Return example at given index.
+
+ Args:
+ index(int): The index of the example to retrieve
+
+ Returns:
+ tuple: A 3-tuple with ``(img1, img2, flow)``.
+ The flow is a numpy array of shape (2, H, W) and the images are PIL images.
+ ``flow`` is None if ``split="val"``.
+ If a valid flow mask is generated within the ``transforms`` parameter,
+ a 4-tuple with ``(img1, img2, flow, valid_flow_mask)`` is returned.
+ """
+ return super().__getitem__(index)
+
+ def _read_flow(self, file_name: str) -> np.ndarray:
+ return _read_flo(file_name)
+
+
+class FlyingThings3D(FlowDataset):
+ """`FlyingThings3D `_ dataset for optical flow.
+
+ The dataset is expected to have the following structure: ::
+
+ root
+ FlyingThings3D
+ frames_cleanpass
+ TEST
+ TRAIN
+ frames_finalpass
+ TEST
+ TRAIN
+ optical_flow
+ TEST
+ TRAIN
+
+ Args:
+        root (str or ``pathlib.Path``): Root directory of the FlyingThings3D Dataset.
+ split (string, optional): The dataset split, either "train" (default) or "test"
+ pass_name (string, optional): The pass to use, either "clean" (default) or "final" or "both". See link above for
+ details on the different passes.
+ camera (string, optional): Which camera to return images from. Can be either "left" (default) or "right" or "both".
+ transforms (callable, optional): A function/transform that takes in
+ ``img1, img2, flow, valid_flow_mask`` and returns a transformed version.
+ ``valid_flow_mask`` is expected for consistency with other datasets which
+ return a built-in valid mask, such as :class:`~torchvision.datasets.KittiFlow`.
+ loader (callable, optional): A function to load an image given its path.
+ By default, it uses PIL as its image loader, but users could also pass in
+ ``torchvision.io.decode_image`` for decoding image data into tensors directly.
+ """
+
+ def __init__(
+ self,
+ root: Union[str, Path],
+ split: str = "train",
+ pass_name: str = "clean",
+ camera: str = "left",
+ transforms: Optional[Callable] = None,
+ loader: Callable[[str], Any] = default_loader,
+ ) -> None:
+ super().__init__(root=root, transforms=transforms, loader=loader)
+
+ verify_str_arg(split, "split", valid_values=("train", "test"))
+ split = split.upper()
+
+ verify_str_arg(pass_name, "pass_name", valid_values=("clean", "final", "both"))
+ passes = {
+ "clean": ["frames_cleanpass"],
+ "final": ["frames_finalpass"],
+ "both": ["frames_cleanpass", "frames_finalpass"],
+ }[pass_name]
+
+ verify_str_arg(camera, "camera", valid_values=("left", "right", "both"))
+ cameras = ["left", "right"] if camera == "both" else [camera]
+
+ root = Path(root) / "FlyingThings3D"
+
+ directions = ("into_future", "into_past")
+ for pass_name, camera, direction in itertools.product(passes, cameras, directions):
+ image_dirs = sorted(glob(str(root / pass_name / split / "*/*")))
+ image_dirs = sorted(Path(image_dir) / camera for image_dir in image_dirs)
+
+ flow_dirs = sorted(glob(str(root / "optical_flow" / split / "*/*")))
+ flow_dirs = sorted(Path(flow_dir) / direction / camera for flow_dir in flow_dirs)
+
+ if not image_dirs or not flow_dirs:
+ raise FileNotFoundError(
+ "Could not find the FlyingThings3D flow images. "
+ "Please make sure the directory structure is correct."
+ )
+
+ for image_dir, flow_dir in zip(image_dirs, flow_dirs):
+ images = sorted(glob(str(image_dir / "*.png")))
+ flows = sorted(glob(str(flow_dir / "*.pfm")))
+ for i in range(len(flows) - 1):
+ if direction == "into_future":
+ self._image_list += [[images[i], images[i + 1]]]
+ self._flow_list += [flows[i]]
+ elif direction == "into_past":
+ self._image_list += [[images[i + 1], images[i]]]
+ self._flow_list += [flows[i + 1]]
+
+ def __getitem__(self, index: int) -> Union[T1, T2]:
+ """Return example at given index.
+
+ Args:
+ index(int): The index of the example to retrieve
+
+ Returns:
+ tuple: A 3-tuple with ``(img1, img2, flow)``.
+ The flow is a numpy array of shape (2, H, W) and the images are PIL images.
+ ``flow`` is None if ``split="test"``.
+ If a valid flow mask is generated within the ``transforms`` parameter,
+ a 4-tuple with ``(img1, img2, flow, valid_flow_mask)`` is returned.
+ """
+ return super().__getitem__(index)
+
+ def _read_flow(self, file_name: str) -> np.ndarray:
+ return _read_pfm(file_name)
+
+
+class HD1K(FlowDataset):
+ """`HD1K `__ dataset for optical flow.
+
+ The dataset is expected to have the following structure: ::
+
+ root
+ hd1k
+ hd1k_challenge
+ image_2
+ hd1k_flow_gt
+ flow_occ
+ hd1k_input
+ image_2
+
+ Args:
+ root (str or ``pathlib.Path``): Root directory of the HD1K Dataset.
+ split (string, optional): The dataset split, either "train" (default) or "test"
+ transforms (callable, optional): A function/transform that takes in
+ ``img1, img2, flow, valid_flow_mask`` and returns a transformed version.
+ loader (callable, optional): A function to load an image given its path.
+ By default, it uses PIL as its image loader, but users could also pass in
+ ``torchvision.io.decode_image`` for decoding image data into tensors directly.
+ """
+
+ _has_builtin_flow_mask = True
+
+ def __init__(
+ self,
+ root: Union[str, Path],
+ split: str = "train",
+ transforms: Optional[Callable] = None,
+ loader: Callable[[str], Any] = default_loader,
+ ) -> None:
+ super().__init__(root=root, transforms=transforms, loader=loader)
+
+ verify_str_arg(split, "split", valid_values=("train", "test"))
+
+ root = Path(root) / "hd1k"
+ if split == "train":
+ # There are 36 "sequences" and we don't want seq i to overlap with seq i + 1, so we need this for loop
+ for seq_idx in range(36):
+ flows = sorted(glob(str(root / "hd1k_flow_gt" / "flow_occ" / f"{seq_idx:06d}_*.png")))
+ images = sorted(glob(str(root / "hd1k_input" / "image_2" / f"{seq_idx:06d}_*.png")))
+ for i in range(len(flows) - 1):
+ self._flow_list += [flows[i]]
+ self._image_list += [[images[i], images[i + 1]]]
+ else:
+ images1 = sorted(glob(str(root / "hd1k_challenge" / "image_2" / "*10.png")))
+ images2 = sorted(glob(str(root / "hd1k_challenge" / "image_2" / "*11.png")))
+ for image1, image2 in zip(images1, images2):
+ self._image_list += [[image1, image2]]
+
+ if not self._image_list:
+ raise FileNotFoundError(
+ "Could not find the HD1K images. Please make sure the directory structure is correct."
+ )
+
+ def _read_flow(self, file_name: str) -> tuple[np.ndarray, np.ndarray]:
+ return _read_16bits_png_with_flow_and_valid_mask(file_name)
+
+ def __getitem__(self, index: int) -> Union[T1, T2]:
+ """Return example at given index.
+
+ Args:
+ index(int): The index of the example to retrieve
+
+ Returns:
+ tuple: A 4-tuple with ``(img1, img2, flow, valid_flow_mask)`` where ``valid_flow_mask``
+ is a numpy boolean mask of shape (H, W)
+ indicating which flow values are valid. The flow is a numpy array of
+ shape (2, H, W) and the images are PIL images. ``flow`` and ``valid_flow_mask`` are None if
+ ``split="test"``.
+ """
+ return super().__getitem__(index)
+
+
+def _read_flo(file_name: str) -> np.ndarray:
+ """Read .flo file in Middlebury format"""
+ # Code adapted from:
+ # http://stackoverflow.com/questions/28013200/reading-middlebury-flow-files-with-python-bytes-array-numpy
+ # Everything needs to be in little Endian according to
+ # https://vision.middlebury.edu/flow/code/flow-code/README.txt
+ with open(file_name, "rb") as f:
+ magic = np.fromfile(f, "c", count=4).tobytes()
+ if magic != b"PIEH":
+ raise ValueError("Magic number incorrect. Invalid .flo file")
+
+        w = int(np.fromfile(f, "<i4", count=1))
+        h = int(np.fromfile(f, "<i4", count=1))
+        data = np.fromfile(f, "<f4", count=2 * w * h)
+        # Flow is stored as interleaved (u, v) values, row-major; reshape into a (2, H, W) array
+        return data.reshape(h, w, 2).transpose(2, 0, 1)
+
+
+def _read_16bits_png_with_flow_and_valid_mask(file_name: str) -> tuple[np.ndarray, np.ndarray]:
+
+ flow_and_valid = decode_png(read_file(file_name)).to(torch.float32)
+ flow, valid_flow_mask = flow_and_valid[:2, :, :], flow_and_valid[2, :, :]
+    flow = (flow - 2**15) / 64  # KITTI stores flow as uint16 with a 2**15 offset and a 1/64 scale (see the worked example below)
+ valid_flow_mask = valid_flow_mask.bool()
+
+ # For consistency with other datasets, we convert to numpy
+ return flow.numpy(), valid_flow_mask.numpy()
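+
+
+# Illustrative worked example (not part of the upstream module): a raw uint16 value
+# of 32832 decodes to (32832 - 2**15) / 64 = 1.0 pixel of flow, and the third PNG
+# channel is non-zero exactly where the ground-truth flow is valid.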
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torchvision/datasets/_stereo_matching.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torchvision/datasets/_stereo_matching.py
new file mode 100644
index 0000000000000000000000000000000000000000..bc2236e97b85c7647bf10b507ad83f0f34e83987
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torchvision/datasets/_stereo_matching.py
@@ -0,0 +1,1223 @@
+import functools
+import json
+import os
+import random
+import shutil
+from abc import ABC, abstractmethod
+from glob import glob
+from pathlib import Path
+from typing import Callable, cast, Optional, Union
+
+import numpy as np
+from PIL import Image
+
+from .utils import _read_pfm, download_and_extract_archive, verify_str_arg
+from .vision import VisionDataset
+
+T1 = tuple[Image.Image, Image.Image, Optional[np.ndarray], np.ndarray]
+T2 = tuple[Image.Image, Image.Image, Optional[np.ndarray]]
+
+__all__ = ()
+
+_read_pfm_file = functools.partial(_read_pfm, slice_channels=1)
+
+
+class StereoMatchingDataset(ABC, VisionDataset):
+ """Base interface for Stereo matching datasets"""
+
+ _has_built_in_disparity_mask = False
+
+ def __init__(self, root: Union[str, Path], transforms: Optional[Callable] = None) -> None:
+ """
+ Args:
+ root(str): Root directory of the dataset.
+ transforms(callable, optional): A function/transform that takes in Tuples of
+ (images, disparities, valid_masks) and returns a transformed version of each of them.
+ images is a Tuple of (``PIL.Image``, ``PIL.Image``)
+ disparities is a Tuple of (``np.ndarray``, ``np.ndarray``) with shape (1, H, W)
+ valid_masks is a Tuple of (``np.ndarray``, ``np.ndarray``) with shape (H, W)
+ In some cases, when a dataset does not provide disparities, the ``disparities`` and
+ ``valid_masks`` can be Tuples containing None values.
+ For training splits generally the datasets provide a minimal guarantee of
+ images: (``PIL.Image``, ``PIL.Image``)
+ disparities: (``np.ndarray``, ``None``) with shape (1, H, W)
+ Optionally, based on the dataset, it can return a ``mask`` as well:
+ valid_masks: (``np.ndarray | None``, ``None``) with shape (H, W)
+                For some test splits, the datasets provide outputs that look like:
+                    images: (``PIL.Image``, ``PIL.Image``)
+ disparities: (``None``, ``None``)
+ Optionally, based on the dataset, it can return a ``mask`` as well:
+ valid_masks: (``None``, ``None``)
+ """
+ super().__init__(root=root)
+ self.transforms = transforms
+
+ self._images = [] # type: ignore
+ self._disparities = [] # type: ignore
+
+ def _read_img(self, file_path: Union[str, Path]) -> Image.Image:
+ img = Image.open(file_path)
+ if img.mode != "RGB":
+ img = img.convert("RGB") # type: ignore [assignment]
+ return img
+
+ def _scan_pairs(
+ self,
+ paths_left_pattern: str,
+ paths_right_pattern: Optional[str] = None,
+ ) -> list[tuple[str, Optional[str]]]:
+
+ left_paths = list(sorted(glob(paths_left_pattern)))
+
+ right_paths: list[Union[None, str]]
+ if paths_right_pattern:
+ right_paths = list(sorted(glob(paths_right_pattern)))
+ else:
+ right_paths = list(None for _ in left_paths)
+
+ if not left_paths:
+ raise FileNotFoundError(f"Could not find any files matching the patterns: {paths_left_pattern}")
+
+ if not right_paths:
+ raise FileNotFoundError(f"Could not find any files matching the patterns: {paths_right_pattern}")
+
+ if len(left_paths) != len(right_paths):
+ raise ValueError(
+ f"Found {len(left_paths)} left files but {len(right_paths)} right files using:\n "
+ f"left pattern: {paths_left_pattern}\n"
+ f"right pattern: {paths_right_pattern}\n"
+ )
+
+ paths = list((left, right) for left, right in zip(left_paths, right_paths))
+ return paths
+
+ @abstractmethod
+ def _read_disparity(self, file_path: str) -> tuple[Optional[np.ndarray], Optional[np.ndarray]]:
+ # function that returns a disparity map and an occlusion map
+ pass
+
+ def __getitem__(self, index: int) -> Union[T1, T2]:
+ """Return example at given index.
+
+ Args:
+ index(int): The index of the example to retrieve
+
+ Returns:
+ tuple: A 3 or 4-tuple with ``(img_left, img_right, disparity, Optional[valid_mask])`` where ``valid_mask``
+ can be a numpy boolean mask of shape (H, W) if the dataset provides a file
+ indicating which disparity pixels are valid. The disparity is a numpy array of
+ shape (1, H, W) and the images are PIL images. ``disparity`` is None for
+ datasets on which for ``split="test"`` the authors did not provide annotations.
+ """
+ img_left = self._read_img(self._images[index][0])
+ img_right = self._read_img(self._images[index][1])
+
+ dsp_map_left, valid_mask_left = self._read_disparity(self._disparities[index][0])
+ dsp_map_right, valid_mask_right = self._read_disparity(self._disparities[index][1])
+
+ imgs = (img_left, img_right)
+ dsp_maps = (dsp_map_left, dsp_map_right)
+ valid_masks = (valid_mask_left, valid_mask_right)
+
+ if self.transforms is not None:
+ (
+ imgs,
+ dsp_maps,
+ valid_masks,
+ ) = self.transforms(imgs, dsp_maps, valid_masks)
+
+ if self._has_built_in_disparity_mask or valid_masks[0] is not None:
+ return imgs[0], imgs[1], dsp_maps[0], cast(np.ndarray, valid_masks[0])
+ else:
+ return imgs[0], imgs[1], dsp_maps[0]
+
+ def __len__(self) -> int:
+ return len(self._images)
+
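+
+# Illustrative sketch (not part of the upstream module): subclasses populate
+# ``_images`` and ``_disparities`` through ``_scan_pairs`` and are consumed like any
+# other dataset, e.g. with the CarlaStereo class defined below ("datasets" is a
+# hypothetical root path):
+#
+#   ds = CarlaStereo(root="datasets")
+#   img_left, img_right, disparity = ds[0]   # disparity is a (1, H, W) numpy array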
+
+class CarlaStereo(StereoMatchingDataset):
+ """
+ Carla simulator data linked in the `CREStereo github repo `_.
+
+ The dataset is expected to have the following structure: ::
+
+ root
+ carla-highres
+ trainingF
+ scene1
+ img0.png
+ img1.png
+ disp0GT.pfm
+ disp1GT.pfm
+ calib.txt
+ scene2
+ img0.png
+ img1.png
+ disp0GT.pfm
+ disp1GT.pfm
+ calib.txt
+ ...
+
+ Args:
+ root (str or ``pathlib.Path``): Root directory where `carla-highres` is located.
+ transforms (callable, optional): A function/transform that takes in a sample and returns a transformed version.
+ """
+
+ def __init__(self, root: Union[str, Path], transforms: Optional[Callable] = None) -> None:
+ super().__init__(root, transforms)
+
+ root = Path(root) / "carla-highres"
+
+ left_image_pattern = str(root / "trainingF" / "*" / "im0.png")
+ right_image_pattern = str(root / "trainingF" / "*" / "im1.png")
+ imgs = self._scan_pairs(left_image_pattern, right_image_pattern)
+ self._images = imgs
+
+ left_disparity_pattern = str(root / "trainingF" / "*" / "disp0GT.pfm")
+ right_disparity_pattern = str(root / "trainingF" / "*" / "disp1GT.pfm")
+ disparities = self._scan_pairs(left_disparity_pattern, right_disparity_pattern)
+ self._disparities = disparities
+
+ def _read_disparity(self, file_path: str) -> tuple[np.ndarray, None]:
+ disparity_map = _read_pfm_file(file_path)
+ disparity_map = np.abs(disparity_map) # ensure that the disparity is positive
+ valid_mask = None
+ return disparity_map, valid_mask
+
+ def __getitem__(self, index: int) -> T1:
+ """Return example at given index.
+
+ Args:
+ index(int): The index of the example to retrieve
+
+ Returns:
+ tuple: A 3-tuple with ``(img_left, img_right, disparity)``.
+ The disparity is a numpy array of shape (1, H, W) and the images are PIL images.
+ If a ``valid_mask`` is generated within the ``transforms`` parameter,
+ a 4-tuple with ``(img_left, img_right, disparity, valid_mask)`` is returned.
+ """
+ return cast(T1, super().__getitem__(index))
+
+
+class Kitti2012Stereo(StereoMatchingDataset):
+ """
+ KITTI dataset from the `2012 stereo evaluation benchmark `_.
+ Uses the RGB images for consistency with KITTI 2015.
+
+ The dataset is expected to have the following structure: ::
+
+ root
+ Kitti2012
+ testing
+ colored_0
+ 1_10.png
+ 2_10.png
+ ...
+ colored_1
+ 1_10.png
+ 2_10.png
+ ...
+ training
+ colored_0
+ 1_10.png
+ 2_10.png
+ ...
+ colored_1
+ 1_10.png
+ 2_10.png
+ ...
+ disp_noc
+ 1.png
+ 2.png
+ ...
+ calib
+
+ Args:
+ root (str or ``pathlib.Path``): Root directory where `Kitti2012` is located.
+ split (string, optional): The dataset split of scenes, either "train" (default) or "test".
+ transforms (callable, optional): A function/transform that takes in a sample and returns a transformed version.
+ """
+
+ _has_built_in_disparity_mask = True
+
+ def __init__(self, root: Union[str, Path], split: str = "train", transforms: Optional[Callable] = None) -> None:
+ super().__init__(root, transforms)
+
+ verify_str_arg(split, "split", valid_values=("train", "test"))
+
+ root = Path(root) / "Kitti2012" / (split + "ing")
+
+ left_img_pattern = str(root / "colored_0" / "*_10.png")
+ right_img_pattern = str(root / "colored_1" / "*_10.png")
+ self._images = self._scan_pairs(left_img_pattern, right_img_pattern)
+
+ if split == "train":
+ disparity_pattern = str(root / "disp_noc" / "*.png")
+ self._disparities = self._scan_pairs(disparity_pattern, None)
+ else:
+ self._disparities = list((None, None) for _ in self._images)
+
+ def _read_disparity(self, file_path: str) -> tuple[Optional[np.ndarray], None]:
+ # test split has no disparity maps
+ if file_path is None:
+ return None, None
+
+ disparity_map = np.asarray(Image.open(file_path)) / 256.0
+ # unsqueeze the disparity map into (C, H, W) format
+ disparity_map = disparity_map[None, :, :]
+ valid_mask = None
+ return disparity_map, valid_mask
+
+ def __getitem__(self, index: int) -> T1:
+ """Return example at given index.
+
+ Args:
+ index(int): The index of the example to retrieve
+
+ Returns:
+ tuple: A 4-tuple with ``(img_left, img_right, disparity, valid_mask)``.
+ The disparity is a numpy array of shape (1, H, W) and the images are PIL images.
+ ``valid_mask`` is implicitly ``None`` if the ``transforms`` parameter does not
+ generate a valid mask.
+ Both ``disparity`` and ``valid_mask`` are ``None`` if the dataset split is test.
+ """
+ return cast(T1, super().__getitem__(index))
+
+
+class Kitti2015Stereo(StereoMatchingDataset):
+ """
+ KITTI dataset from the `2015 stereo evaluation benchmark `_.
+
+ The dataset is expected to have the following structure: ::
+
+ root
+ Kitti2015
+ testing
+ image_2
+ img1.png
+ img2.png
+ ...
+ image_3
+ img1.png
+ img2.png
+ ...
+ training
+ image_2
+ img1.png
+ img2.png
+ ...
+ image_3
+ img1.png
+ img2.png
+ ...
+ disp_occ_0
+ img1.png
+ img2.png
+ ...
+ disp_occ_1
+ img1.png
+ img2.png
+ ...
+ calib
+
+ Args:
+ root (str or ``pathlib.Path``): Root directory where `Kitti2015` is located.
+ split (string, optional): The dataset split of scenes, either "train" (default) or "test".
+ transforms (callable, optional): A function/transform that takes in a sample and returns a transformed version.
+ """
+
+ _has_built_in_disparity_mask = True
+
+ def __init__(self, root: Union[str, Path], split: str = "train", transforms: Optional[Callable] = None) -> None:
+ super().__init__(root, transforms)
+
+ verify_str_arg(split, "split", valid_values=("train", "test"))
+
+ root = Path(root) / "Kitti2015" / (split + "ing")
+ left_img_pattern = str(root / "image_2" / "*.png")
+ right_img_pattern = str(root / "image_3" / "*.png")
+ self._images = self._scan_pairs(left_img_pattern, right_img_pattern)
+
+ if split == "train":
+ left_disparity_pattern = str(root / "disp_occ_0" / "*.png")
+ right_disparity_pattern = str(root / "disp_occ_1" / "*.png")
+ self._disparities = self._scan_pairs(left_disparity_pattern, right_disparity_pattern)
+ else:
+ self._disparities = list((None, None) for _ in self._images)
+
+ def _read_disparity(self, file_path: str) -> tuple[Optional[np.ndarray], None]:
+ # test split has no disparity maps
+ if file_path is None:
+ return None, None
+
+ disparity_map = np.asarray(Image.open(file_path)) / 256.0
+ # unsqueeze the disparity map into (C, H, W) format
+ disparity_map = disparity_map[None, :, :]
+ valid_mask = None
+ return disparity_map, valid_mask
+
+ def __getitem__(self, index: int) -> T1:
+ """Return example at given index.
+
+ Args:
+ index(int): The index of the example to retrieve
+
+ Returns:
+ tuple: A 4-tuple with ``(img_left, img_right, disparity, valid_mask)``.
+ The disparity is a numpy array of shape (1, H, W) and the images are PIL images.
+ ``valid_mask`` is implicitly ``None`` if the ``transforms`` parameter does not
+ generate a valid mask.
+ Both ``disparity`` and ``valid_mask`` are ``None`` if the dataset split is test.
+ """
+ return cast(T1, super().__getitem__(index))
+
+
+class Middlebury2014Stereo(StereoMatchingDataset):
+ """Publicly available scenes from the Middlebury dataset `2014 version `.
+
+ The dataset mostly follows the original format, without containing the ambient subdirectories. : ::
+
+ root
+ Middlebury2014
+ train
+ scene1-{perfect,imperfect}
+ calib.txt
+ im{0,1}.png
+ im1E.png
+ im1L.png
+ disp{0,1}.pfm
+ disp{0,1}-n.png
+ disp{0,1}-sd.pfm
+ disp{0,1}y.pfm
+ scene2-{perfect,imperfect}
+ calib.txt
+ im{0,1}.png
+ im1E.png
+ im1L.png
+ disp{0,1}.pfm
+ disp{0,1}-n.png
+ disp{0,1}-sd.pfm
+ disp{0,1}y.pfm
+ ...
+ additional
+ scene1-{perfect,imperfect}
+ calib.txt
+ im{0,1}.png
+ im1E.png
+ im1L.png
+ disp{0,1}.pfm
+ disp{0,1}-n.png
+ disp{0,1}-sd.pfm
+ disp{0,1}y.pfm
+ ...
+ test
+ scene1
+ calib.txt
+ im{0,1}.png
+ scene2
+ calib.txt
+ im{0,1}.png
+ ...
+
+ Args:
+        root (str or ``pathlib.Path``): Root directory of the Middlebury 2014 Dataset.
+        split (string, optional): The dataset split of scenes, either "train" (default), "test", or "additional"
+        use_ambient_views (boolean, optional): Whether to use different exposure or lighting views when possible.
+ The dataset samples with equal probability between ``[im1.png, im1E.png, im1L.png]``.
+        calibration (string, optional): Which calibration setting to use: "perfect" (default), "imperfect",
+            "both", or None. The "test" split provides no calibration settings, so pass ``calibration=None`` for it.
+ transforms (callable, optional): A function/transform that takes in a sample and returns a transformed version.
+ download (boolean, optional): Whether or not to download the dataset in the ``root`` directory.
+ """
+
+ splits = {
+ "train": [
+ "Adirondack",
+ "Jadeplant",
+ "Motorcycle",
+ "Piano",
+ "Pipes",
+ "Playroom",
+ "Playtable",
+ "Recycle",
+ "Shelves",
+ "Vintage",
+ ],
+ "additional": [
+ "Backpack",
+ "Bicycle1",
+ "Cable",
+ "Classroom1",
+ "Couch",
+ "Flowers",
+ "Mask",
+ "Shopvac",
+ "Sticks",
+ "Storage",
+ "Sword1",
+ "Sword2",
+ "Umbrella",
+ ],
+ "test": [
+ "Plants",
+ "Classroom2E",
+ "Classroom2",
+ "Australia",
+ "DjembeL",
+ "CrusadeP",
+ "Crusade",
+ "Hoops",
+ "Bicycle2",
+ "Staircase",
+ "Newkuba",
+ "AustraliaP",
+ "Djembe",
+ "Livingroom",
+ "Computer",
+ ],
+ }
+
+ _has_built_in_disparity_mask = True
+
+ def __init__(
+ self,
+ root: Union[str, Path],
+ split: str = "train",
+ calibration: Optional[str] = "perfect",
+ use_ambient_views: bool = False,
+ transforms: Optional[Callable] = None,
+ download: bool = False,
+ ) -> None:
+ super().__init__(root, transforms)
+
+ verify_str_arg(split, "split", valid_values=("train", "test", "additional"))
+ self.split = split
+
+ if calibration:
+ verify_str_arg(calibration, "calibration", valid_values=("perfect", "imperfect", "both", None)) # type: ignore
+ if split == "test":
+                raise ValueError("Split 'test' has no calibration settings, please set `calibration=None`.")
+ else:
+ if split != "test":
+ raise ValueError(
+                    f"Split '{split}' has calibration settings, but `calibration=None` was provided. "
+                    "Please set `calibration` to one of: 'perfect', 'imperfect', 'both'.",
+ )
+
+ if download:
+ self._download_dataset(root)
+
+ root = Path(root) / "Middlebury2014"
+
+ if not os.path.exists(root / split):
+ raise FileNotFoundError(f"The {split} directory was not found in the provided root directory")
+
+ split_scenes = self.splits[split]
+ # check that the provided root folder contains the scene splits
+ if not any(
+            # using startswith to account for perfect / imperfect calibration
+ scene.startswith(s)
+ for scene in os.listdir(root / split)
+ for s in split_scenes
+ ):
+ raise FileNotFoundError(f"Provided root folder does not contain any scenes from the {split} split.")
+
+        calibration_suffixes = {
+            None: [""],
+            "perfect": ["-perfect"],
+            "imperfect": ["-imperfect"],
+            "both": ["-perfect", "-imperfect"],
+        }[calibration]
+
+        for calibration_suffix in calibration_suffixes:
+ scene_pattern = "*" + calibration_suffix
+ left_img_pattern = str(root / split / scene_pattern / "im0.png")
+ right_img_pattern = str(root / split / scene_pattern / "im1.png")
+ self._images += self._scan_pairs(left_img_pattern, right_img_pattern)
+
+ if split == "test":
+ self._disparities = list((None, None) for _ in self._images)
+ else:
+                left_disparity_pattern = str(root / split / scene_pattern / "disp0.pfm")
+                right_disparity_pattern = str(root / split / scene_pattern / "disp1.pfm")
+                self._disparities += self._scan_pairs(left_disparity_pattern, right_disparity_pattern)
+
+ self.use_ambient_views = use_ambient_views
+
+ def _read_img(self, file_path: Union[str, Path]) -> Image.Image:
+ """
+ Function that reads either the original right image or an augmented view when ``use_ambient_views`` is True.
+ When ``use_ambient_views`` is True, the dataset will return at random one of ``[im1.png, im1E.png, im1L.png]``
+ as the right image.
+ """
+ ambient_file_paths: list[Union[str, Path]] # make mypy happy
+
+ if not isinstance(file_path, Path):
+ file_path = Path(file_path)
+
+ if file_path.name == "im1.png" and self.use_ambient_views:
+ base_path = file_path.parent
+ # initialize sampleable container
+ ambient_file_paths = list(base_path / view_name for view_name in ["im1E.png", "im1L.png"])
+ # double check that we're not going to try to read from an invalid file path
+ ambient_file_paths = list(filter(lambda p: os.path.exists(p), ambient_file_paths))
+ # keep the original image as an option as well for uniform sampling between base views
+ ambient_file_paths.append(file_path)
+ file_path = random.choice(ambient_file_paths) # type: ignore
+ return super()._read_img(file_path)
+
+ def _read_disparity(self, file_path: str) -> Union[tuple[None, None], tuple[np.ndarray, np.ndarray]]:
+        # test split has no disparity maps
+ if file_path is None:
+ return None, None
+
+ disparity_map = _read_pfm_file(file_path)
+ disparity_map = np.abs(disparity_map) # ensure that the disparity is positive
+ disparity_map[disparity_map == np.inf] = 0 # remove infinite disparities
+ valid_mask = (disparity_map > 0).squeeze(0) # mask out invalid disparities
+ return disparity_map, valid_mask
+
+ def _download_dataset(self, root: Union[str, Path]) -> None:
+ base_url = "https://vision.middlebury.edu/stereo/data/scenes2014/zip"
+ # train and additional splits have 2 different calibration settings
+ root = Path(root) / "Middlebury2014"
+ split_name = self.split
+
+ if split_name != "test":
+ for split_scene in self.splits[split_name]:
+ split_root = root / split_name
+ for calibration in ["perfect", "imperfect"]:
+ scene_name = f"{split_scene}-{calibration}"
+ scene_url = f"{base_url}/{scene_name}.zip"
+ # download the scene only if it doesn't exist
+ if not (split_root / scene_name).exists():
+ download_and_extract_archive(
+ url=scene_url,
+ filename=f"{scene_name}.zip",
+ download_root=str(split_root),
+ remove_finished=True,
+ )
+ else:
+ os.makedirs(root / "test")
+ if any(s not in os.listdir(root / "test") for s in self.splits["test"]):
+ # test split is downloaded from a different location
+ test_set_url = "https://vision.middlebury.edu/stereo/submit3/zip/MiddEval3-data-F.zip"
+ # the unzip is going to produce a directory MiddEval3 with two subdirectories trainingF and testF
+                # we want to move the contents from testF into the test directory
+ download_and_extract_archive(url=test_set_url, download_root=str(root), remove_finished=True)
+ for scene_dir, scene_names, _ in os.walk(str(root / "MiddEval3/testF")):
+ for scene in scene_names:
+ scene_dst_dir = root / "test"
+ scene_src_dir = Path(scene_dir) / scene
+ os.makedirs(scene_dst_dir, exist_ok=True)
+ shutil.move(str(scene_src_dir), str(scene_dst_dir))
+
+ # cleanup MiddEval3 directory
+ shutil.rmtree(str(root / "MiddEval3"))
+
+ def __getitem__(self, index: int) -> T2:
+ """Return example at given index.
+
+ Args:
+ index(int): The index of the example to retrieve
+
+ Returns:
+ tuple: A 4-tuple with ``(img_left, img_right, disparity, valid_mask)``.
+ The disparity is a numpy array of shape (1, H, W) and the images are PIL images.
+ ``valid_mask`` is implicitly ``None`` for `split=test`.
+ """
+ return cast(T2, super().__getitem__(index))
+
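+
+# Illustrative usage sketch (not part of the upstream module); the "datasets" root
+# below is a hypothetical placeholder.
+#
+#   ds = Middlebury2014Stereo(root="datasets", split="train", calibration="perfect",
+#                             use_ambient_views=True, download=True)
+#   img_left, img_right, disparity, valid_mask = ds[0]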
+
+class CREStereo(StereoMatchingDataset):
+ """Synthetic dataset used in training the `CREStereo `_ architecture.
+ Dataset details on the official paper `repo `_.
+
+ The dataset is expected to have the following structure: ::
+
+ root
+ CREStereo
+ tree
+ img1_left.jpg
+ img1_right.jpg
+ img1_left.disp.jpg
+ img1_right.disp.jpg
+ img2_left.jpg
+ img2_right.jpg
+ img2_left.disp.jpg
+ img2_right.disp.jpg
+ ...
+ shapenet
+ img1_left.jpg
+ img1_right.jpg
+ img1_left.disp.jpg
+ img1_right.disp.jpg
+ ...
+ reflective
+ img1_left.jpg
+ img1_right.jpg
+ img1_left.disp.jpg
+ img1_right.disp.jpg
+ ...
+ hole
+ img1_left.jpg
+ img1_right.jpg
+ img1_left.disp.jpg
+ img1_right.disp.jpg
+ ...
+
+ Args:
+ root (str): Root directory of the dataset.
+ transforms (callable, optional): A function/transform that takes in a sample and returns a transformed version.
+ """
+
+ _has_built_in_disparity_mask = True
+
+ def __init__(
+ self,
+ root: Union[str, Path],
+ transforms: Optional[Callable] = None,
+ ) -> None:
+ super().__init__(root, transforms)
+
+ root = Path(root) / "CREStereo"
+
+ dirs = ["shapenet", "reflective", "tree", "hole"]
+
+ for s in dirs:
+ left_image_pattern = str(root / s / "*_left.jpg")
+ right_image_pattern = str(root / s / "*_right.jpg")
+ imgs = self._scan_pairs(left_image_pattern, right_image_pattern)
+ self._images += imgs
+
+ left_disparity_pattern = str(root / s / "*_left.disp.png")
+ right_disparity_pattern = str(root / s / "*_right.disp.png")
+ disparities = self._scan_pairs(left_disparity_pattern, right_disparity_pattern)
+ self._disparities += disparities
+
+ def _read_disparity(self, file_path: str) -> tuple[np.ndarray, None]:
+ disparity_map = np.asarray(Image.open(file_path), dtype=np.float32)
+ # unsqueeze the disparity map into (C, H, W) format
+ disparity_map = disparity_map[None, :, :] / 32.0
+ valid_mask = None
+ return disparity_map, valid_mask
+
+ def __getitem__(self, index: int) -> T1:
+ """Return example at given index.
+
+ Args:
+ index(int): The index of the example to retrieve
+
+ Returns:
+ tuple: A 4-tuple with ``(img_left, img_right, disparity, valid_mask)``.
+ The disparity is a numpy array of shape (1, H, W) and the images are PIL images.
+ ``valid_mask`` is implicitly ``None`` if the ``transforms`` parameter does not
+ generate a valid mask.
+ """
+ return cast(T1, super().__getitem__(index))
+
+
+class FallingThingsStereo(StereoMatchingDataset):
+ """`FallingThings `_ dataset.
+
+ The dataset is expected to have the following structure: ::
+
+ root
+ FallingThings
+ single
+ dir1
+ scene1
+ _object_settings.json
+ _camera_settings.json
+ image1.left.depth.png
+ image1.right.depth.png
+ image1.left.jpg
+ image1.right.jpg
+ image2.left.depth.png
+ image2.right.depth.png
+ image2.left.jpg
+ image2.right
+ ...
+ scene2
+ ...
+ mixed
+ scene1
+ _object_settings.json
+ _camera_settings.json
+ image1.left.depth.png
+ image1.right.depth.png
+ image1.left.jpg
+ image1.right.jpg
+ image2.left.depth.png
+ image2.right.depth.png
+ image2.left.jpg
+ image2.right
+ ...
+ scene2
+ ...
+
+ Args:
+ root (str or ``pathlib.Path``): Root directory where FallingThings is located.
+ variant (string): Which variant to use. Either "single", "mixed", or "both".
+ transforms (callable, optional): A function/transform that takes in a sample and returns a transformed version.
+ """
+
+ def __init__(self, root: Union[str, Path], variant: str = "single", transforms: Optional[Callable] = None) -> None:
+ super().__init__(root, transforms)
+
+ root = Path(root) / "FallingThings"
+
+ verify_str_arg(variant, "variant", valid_values=("single", "mixed", "both"))
+
+ variants = {
+ "single": ["single"],
+ "mixed": ["mixed"],
+ "both": ["single", "mixed"],
+ }[variant]
+
+ split_prefix = {
+ "single": Path("*") / "*",
+ "mixed": Path("*"),
+ }
+
+ for s in variants:
+ left_img_pattern = str(root / s / split_prefix[s] / "*.left.jpg")
+ right_img_pattern = str(root / s / split_prefix[s] / "*.right.jpg")
+ self._images += self._scan_pairs(left_img_pattern, right_img_pattern)
+
+ left_disparity_pattern = str(root / s / split_prefix[s] / "*.left.depth.png")
+ right_disparity_pattern = str(root / s / split_prefix[s] / "*.right.depth.png")
+ self._disparities += self._scan_pairs(left_disparity_pattern, right_disparity_pattern)
+
+ def _read_disparity(self, file_path: str) -> tuple[np.ndarray, None]:
+ # (H, W) image
+ depth = np.asarray(Image.open(file_path))
+ # as per https://research.nvidia.com/sites/default/files/pubs/2018-06_Falling-Things/readme_0.txt
+ # in order to extract disparity from depth maps
+ camera_settings_path = Path(file_path).parent / "_camera_settings.json"
+ with open(camera_settings_path) as f:
+ # inverse of depth-from-disparity equation: depth = (baseline * focal) / (disparity * pixel_constant)
+ intrinsics = json.load(f)
+ focal = intrinsics["camera_settings"][0]["intrinsic_settings"]["fx"]
+ baseline, pixel_constant = 6, 100 # pixel constant is inverted
+ disparity_map = (baseline * focal * pixel_constant) / depth.astype(np.float32)
+ # unsqueeze disparity to (C, H, W)
+ disparity_map = disparity_map[None, :, :]
+ valid_mask = None
+ return disparity_map, valid_mask
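+
+    # Illustrative worked example (not part of the upstream module): with a camera
+    # focal length fx = 768.0 and the fixed baseline of 6 and pixel_constant of 100 above,
+    # a depth value of 4608 maps to a disparity of (6 * 768.0 * 100) / 4608 = 100.0.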
+
+ def __getitem__(self, index: int) -> T1:
+ """Return example at given index.
+
+ Args:
+ index(int): The index of the example to retrieve
+
+ Returns:
+ tuple: A 3-tuple with ``(img_left, img_right, disparity)``.
+ The disparity is a numpy array of shape (1, H, W) and the images are PIL images.
+ If a ``valid_mask`` is generated within the ``transforms`` parameter,
+ a 4-tuple with ``(img_left, img_right, disparity, valid_mask)`` is returned.
+ """
+ return cast(T1, super().__getitem__(index))
+
+
+class SceneFlowStereo(StereoMatchingDataset):
+ """Dataset interface for `Scene Flow `_ datasets.
+    This interface provides access to the `FlyingThings3D`, `Monkaa` and `Driving` datasets.
+
+ The dataset is expected to have the following structure: ::
+
+ root
+ SceneFlow
+ Monkaa
+ frames_cleanpass
+ scene1
+ left
+ img1.png
+ img2.png
+ right
+ img1.png
+ img2.png
+ scene2
+ left
+ img1.png
+ img2.png
+ right
+ img1.png
+ img2.png
+ frames_finalpass
+ scene1
+ left
+ img1.png
+ img2.png
+ right
+ img1.png
+ img2.png
+ ...
+ ...
+ disparity
+ scene1
+ left
+ img1.pfm
+ img2.pfm
+ right
+ img1.pfm
+ img2.pfm
+ FlyingThings3D
+ ...
+ ...
+
+ Args:
+ root (str or ``pathlib.Path``): Root directory where SceneFlow is located.
+        variant (string): Which dataset variant to use, "FlyingThings3D" (default), "Monkaa" or "Driving".
+ pass_name (string): Which pass to use, "clean" (default), "final" or "both".
+ transforms (callable, optional): A function/transform that takes in a sample and returns a transformed version.
+
+ """
+
+ def __init__(
+ self,
+ root: Union[str, Path],
+ variant: str = "FlyingThings3D",
+ pass_name: str = "clean",
+ transforms: Optional[Callable] = None,
+ ) -> None:
+ super().__init__(root, transforms)
+
+ root = Path(root) / "SceneFlow"
+
+ verify_str_arg(variant, "variant", valid_values=("FlyingThings3D", "Driving", "Monkaa"))
+ verify_str_arg(pass_name, "pass_name", valid_values=("clean", "final", "both"))
+
+ passes = {
+ "clean": ["frames_cleanpass"],
+ "final": ["frames_finalpass"],
+ "both": ["frames_cleanpass", "frames_finalpass"],
+ }[pass_name]
+
+ root = root / variant
+
+ prefix_directories = {
+ "Monkaa": Path("*"),
+ "FlyingThings3D": Path("*") / "*" / "*",
+ "Driving": Path("*") / "*" / "*",
+ }
+
+ for p in passes:
+ left_image_pattern = str(root / p / prefix_directories[variant] / "left" / "*.png")
+ right_image_pattern = str(root / p / prefix_directories[variant] / "right" / "*.png")
+ self._images += self._scan_pairs(left_image_pattern, right_image_pattern)
+
+ left_disparity_pattern = str(root / "disparity" / prefix_directories[variant] / "left" / "*.pfm")
+ right_disparity_pattern = str(root / "disparity" / prefix_directories[variant] / "right" / "*.pfm")
+ self._disparities += self._scan_pairs(left_disparity_pattern, right_disparity_pattern)
+
+ def _read_disparity(self, file_path: str) -> tuple[np.ndarray, None]:
+ disparity_map = _read_pfm_file(file_path)
+ disparity_map = np.abs(disparity_map) # ensure that the disparity is positive
+ valid_mask = None
+ return disparity_map, valid_mask
+
+ def __getitem__(self, index: int) -> T1:
+ """Return example at given index.
+
+ Args:
+ index(int): The index of the example to retrieve
+
+ Returns:
+ tuple: A 3-tuple with ``(img_left, img_right, disparity)``.
+ The disparity is a numpy array of shape (1, H, W) and the images are PIL images.
+ If a ``valid_mask`` is generated within the ``transforms`` parameter,
+ a 4-tuple with ``(img_left, img_right, disparity, valid_mask)`` is returned.
+ """
+ return cast(T1, super().__getitem__(index))
+
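+
+# Illustrative usage sketch (not part of the upstream module); the "datasets" root
+# below is a hypothetical placeholder.
+#
+#   ds = SceneFlowStereo(root="datasets", variant="Monkaa", pass_name="both")
+#   img_left, img_right, disparity = ds[0]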
+
+class SintelStereo(StereoMatchingDataset):
+ """Sintel `Stereo Dataset `_.
+
+ The dataset is expected to have the following structure: ::
+
+ root
+ Sintel
+ training
+ final_left
+ scene1
+ img1.png
+ img2.png
+ ...
+ ...
+ final_right
+ scene2
+ img1.png
+ img2.png
+ ...
+ ...
+ disparities
+ scene1
+ img1.png
+ img2.png
+ ...
+ ...
+ occlusions
+ scene1
+ img1.png
+ img2.png
+ ...
+ ...
+ outofframe
+ scene1
+ img1.png
+ img2.png
+ ...
+ ...
+
+ Args:
+ root (str or ``pathlib.Path``): Root directory where Sintel Stereo is located.
+ pass_name (string): The name of the pass to use, either "final", "clean" or "both".
+ transforms (callable, optional): A function/transform that takes in a sample and returns a transformed version.
+ """
+
+ _has_built_in_disparity_mask = True
+
+ def __init__(self, root: Union[str, Path], pass_name: str = "final", transforms: Optional[Callable] = None) -> None:
+ super().__init__(root, transforms)
+
+ verify_str_arg(pass_name, "pass_name", valid_values=("final", "clean", "both"))
+
+ root = Path(root) / "Sintel"
+ pass_names = {
+ "final": ["final"],
+ "clean": ["clean"],
+ "both": ["final", "clean"],
+ }[pass_name]
+
+ for p in pass_names:
+ left_img_pattern = str(root / "training" / f"{p}_left" / "*" / "*.png")
+ right_img_pattern = str(root / "training" / f"{p}_right" / "*" / "*.png")
+ self._images += self._scan_pairs(left_img_pattern, right_img_pattern)
+
+ disparity_pattern = str(root / "training" / "disparities" / "*" / "*.png")
+ self._disparities += self._scan_pairs(disparity_pattern, None)
+
+ def _get_occlussion_mask_paths(self, file_path: str) -> tuple[str, str]:
+ # helper function to get the occlusion mask paths
+ # a path will look like .../.../.../training/disparities/scene1/img1.png
+ # we want to get something like .../.../.../training/occlusions/scene1/img1.png
+ fpath = Path(file_path)
+ basename = fpath.name
+ scenedir = fpath.parent
+ # the parent of the scenedir is actually the disparity dir
+ sampledir = scenedir.parent.parent
+
+ occlusion_path = str(sampledir / "occlusions" / scenedir.name / basename)
+ outofframe_path = str(sampledir / "outofframe" / scenedir.name / basename)
+
+ if not os.path.exists(occlusion_path):
+ raise FileNotFoundError(f"Occlusion mask {occlusion_path} does not exist")
+
+ if not os.path.exists(outofframe_path):
+ raise FileNotFoundError(f"Out of frame mask {outofframe_path} does not exist")
+
+ return occlusion_path, outofframe_path
+
+ def _read_disparity(self, file_path: str) -> Union[tuple[None, None], tuple[np.ndarray, np.ndarray]]:
+ if file_path is None:
+ return None, None
+
+ # disparity decoding as per Sintel instructions in the README provided with the dataset
+ disparity_map = np.asarray(Image.open(file_path), dtype=np.float32)
+ r, g, b = np.split(disparity_map, 3, axis=-1)
+ disparity_map = r * 4 + g / (2**6) + b / (2**14)
+ # reshape into (C, H, W) format
+ disparity_map = np.transpose(disparity_map, (2, 0, 1))
+ # find the appropriate file paths
+ occlued_mask_path, out_of_frame_mask_path = self._get_occlussion_mask_paths(file_path)
+ # occlusion masks
+ valid_mask = np.asarray(Image.open(occlued_mask_path)) == 0
+ # out of frame masks
+ off_mask = np.asarray(Image.open(out_of_frame_mask_path)) == 0
+ # combine the masks together
+ valid_mask = np.logical_and(off_mask, valid_mask)
+ return disparity_map, valid_mask
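+
+    # Illustrative worked example (not part of the upstream module): an RGB disparity
+    # pixel of (1, 64, 0) decodes to 1 * 4 + 64 / 2**6 + 0 / 2**14 = 5.0 pixels, and a
+    # position is kept in ``valid_mask`` only where both the occlusion and the
+    # out-of-frame masks are zero.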
+
+ def __getitem__(self, index: int) -> T2:
+ """Return example at given index.
+
+ Args:
+ index(int): The index of the example to retrieve
+
+ Returns:
+ tuple: A 4-tuple with ``(img_left, img_right, disparity, valid_mask)`` is returned.
+ The disparity is a numpy array of shape (1, H, W) and the images are PIL images whilst
+ the valid_mask is a numpy array of shape (H, W).
+ """
+ return cast(T2, super().__getitem__(index))
+
+
+class InStereo2k(StereoMatchingDataset):
+ """`InStereo2k `_ dataset.
+
+ The dataset is expected to have the following structure: ::
+
+ root
+ InStereo2k
+ train
+ scene1
+ left.png
+ right.png
+ left_disp.png
+ right_disp.png
+ ...
+ scene2
+ ...
+ test
+ scene1
+ left.png
+ right.png
+ left_disp.png
+ right_disp.png
+ ...
+ scene2
+ ...
+
+ Args:
+ root (str or ``pathlib.Path``): Root directory where InStereo2k is located.
+ split (string): Either "train" or "test".
+ transforms (callable, optional): A function/transform that takes in a sample and returns a transformed version.
+ """
+
+ def __init__(self, root: Union[str, Path], split: str = "train", transforms: Optional[Callable] = None) -> None:
+ super().__init__(root, transforms)
+
+ root = Path(root) / "InStereo2k" / split
+
+ verify_str_arg(split, "split", valid_values=("train", "test"))
+
+ left_img_pattern = str(root / "*" / "left.png")
+ right_img_pattern = str(root / "*" / "right.png")
+ self._images = self._scan_pairs(left_img_pattern, right_img_pattern)
+
+ left_disparity_pattern = str(root / "*" / "left_disp.png")
+ right_disparity_pattern = str(root / "*" / "right_disp.png")
+ self._disparities = self._scan_pairs(left_disparity_pattern, right_disparity_pattern)
+
+ def _read_disparity(self, file_path: str) -> tuple[np.ndarray, None]:
+ disparity_map = np.asarray(Image.open(file_path), dtype=np.float32)
+ # unsqueeze disparity to (C, H, W)
+ disparity_map = disparity_map[None, :, :] / 1024.0
+ valid_mask = None
+ return disparity_map, valid_mask
+
+ def __getitem__(self, index: int) -> T1:
+ """Return example at given index.
+
+ Args:
+ index(int): The index of the example to retrieve
+
+ Returns:
+ tuple: A 3-tuple with ``(img_left, img_right, disparity)``.
+ The disparity is a numpy array of shape (1, H, W) and the images are PIL images.
+ If a ``valid_mask`` is generated within the ``transforms`` parameter,
+ a 4-tuple with ``(img_left, img_right, disparity, valid_mask)`` is returned.
+ """
+ return cast(T1, super().__getitem__(index))
+
+
+class ETH3DStereo(StereoMatchingDataset):
+ """ETH3D `Low-Res Two-View `_ dataset.
+
+ The dataset is expected to have the following structure: ::
+
+ root
+ ETH3D
+ two_view_training
+ scene1
+ im1.png
+ im0.png
+ images.txt
+ cameras.txt
+ calib.txt
+ scene2
+ im1.png
+ im0.png
+ images.txt
+ cameras.txt
+ calib.txt
+ ...
+ two_view_training_gt
+ scene1
+ disp0GT.pfm
+ mask0nocc.png
+ scene2
+ disp0GT.pfm
+ mask0nocc.png
+ ...
+ two_view_testing
+ scene1
+ im1.png
+ im0.png
+ images.txt
+ cameras.txt
+ calib.txt
+ scene2
+ im1.png
+ im0.png
+ images.txt
+ cameras.txt
+ calib.txt
+ ...
+
+ Args:
+ root (str or ``pathlib.Path``): Root directory of the ETH3D Dataset.
+ split (string, optional): The dataset split of scenes, either "train" (default) or "test".
+ transforms (callable, optional): A function/transform that takes in a sample and returns a transformed version.
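+
+ Example (a minimal usage sketch; the ``root`` path is a hypothetical placeholder):
+
+ .. code-block:: python
+
+ dataset = ETH3DStereo(root="./data", split="train")
+ img_left, img_right, disparity, valid_mask = dataset[0]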
+ """
+
+ _has_built_in_disparity_mask = True
+
+ def __init__(self, root: Union[str, Path], split: str = "train", transforms: Optional[Callable] = None) -> None:
+ super().__init__(root, transforms)
+
+ verify_str_arg(split, "split", valid_values=("train", "test"))
+
+ root = Path(root) / "ETH3D"
+
+ img_dir = "two_view_training" if split == "train" else "two_view_testing"
+ anot_dir = "two_view_training_gt"
+
+ left_img_pattern = str(root / img_dir / "*" / "im0.png")
+ right_img_pattern = str(root / img_dir / "*" / "im1.png")
+ self._images = self._scan_pairs(left_img_pattern, right_img_pattern)
+
+ if split == "test":
+ self._disparities = list((None, None) for _ in self._images)
+ else:
+ disparity_pattern = str(root / anot_dir / "*" / "disp0GT.pfm")
+ self._disparities = self._scan_pairs(disparity_pattern, None)
+
+ def _read_disparity(self, file_path: str) -> Union[tuple[None, None], tuple[np.ndarray, np.ndarray]]:
+ # test split has no disparity maps
+ if file_path is None:
+ return None, None
+
+ disparity_map = _read_pfm_file(file_path)
+ disparity_map = np.abs(disparity_map) # ensure that the disparity is positive
+ mask_path = Path(file_path).parent / "mask0nocc.png"
+ valid_mask = Image.open(mask_path)
+ valid_mask = np.asarray(valid_mask).astype(bool)
+ return disparity_map, valid_mask
+
+ def __getitem__(self, index: int) -> T2:
+ """Return example at given index.
+
+ Args:
+ index(int): The index of the example to retrieve
+
+ Returns:
+ tuple: A 4-tuple with ``(img_left, img_right, disparity, valid_mask)``.
+ The disparity is a numpy array of shape (1, H, W) and the images are PIL images.
+ ``valid_mask`` is implicitly ``None`` if the ``transforms`` parameter does not
+ generate a valid mask.
+ Both ``disparity`` and ``valid_mask`` are ``None`` if the dataset split is test.
+ """
+ return cast(T2, super().__getitem__(index))
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torchvision/datasets/caltech.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torchvision/datasets/caltech.py
new file mode 100644
index 0000000000000000000000000000000000000000..7498f67400158f1c0da8a6bb66866153735120ce
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torchvision/datasets/caltech.py
@@ -0,0 +1,241 @@
+import os
+import os.path
+import shutil
+from pathlib import Path
+from typing import Any, Callable, Optional, Union
+
+from PIL import Image
+
+from .utils import download_and_extract_archive, extract_archive, verify_str_arg
+from .vision import VisionDataset
+
+
+class Caltech101(VisionDataset):
+ """`Caltech 101 `_ Dataset.
+
+ .. warning::
+
+ This class needs `scipy `_ to load target files from `.mat` format.
+
+ Args:
+ root (str or ``pathlib.Path``): Root directory of dataset where directory
+ ``caltech101`` exists or will be saved to if download is set to True.
+ target_type (string or list, optional): Type of target to use, ``category`` or
+ ``annotation``. Can also be a list to output a tuple with all specified
+ target types. ``category`` represents the target class, and
+ ``annotation`` is a list of points from a hand-generated outline.
+ Defaults to ``category``.
+ transform (callable, optional): A function/transform that takes in a PIL image
+ and returns a transformed version. E.g, ``transforms.RandomCrop``
+ target_transform (callable, optional): A function/transform that takes in the
+ target and transforms it.
+ download (bool, optional): If true, downloads the dataset from the internet and
+ puts it in root directory. If dataset is already downloaded, it is not
+ downloaded again.
+
+ .. warning::
+
+ To download the dataset `gdown `_ is required.
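+
+ Example (a minimal usage sketch; the ``root`` path is a hypothetical placeholder):
+
+ .. code-block:: python
+
+ dataset = Caltech101(root="./data", target_type=["category", "annotation"])
+ img, (label, contour) = dataset[0]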
+ """
+
+ def __init__(
+ self,
+ root: Union[str, Path],
+ target_type: Union[list[str], str] = "category",
+ transform: Optional[Callable] = None,
+ target_transform: Optional[Callable] = None,
+ download: bool = False,
+ ) -> None:
+ super().__init__(os.path.join(root, "caltech101"), transform=transform, target_transform=target_transform)
+ os.makedirs(self.root, exist_ok=True)
+ if isinstance(target_type, str):
+ target_type = [target_type]
+ self.target_type = [verify_str_arg(t, "target_type", ("category", "annotation")) for t in target_type]
+
+ if download:
+ self.download()
+
+ if not self._check_integrity():
+ raise RuntimeError("Dataset not found or corrupted. You can use download=True to download it")
+
+ self.categories = sorted(os.listdir(os.path.join(self.root, "101_ObjectCategories")))
+ self.categories.remove("BACKGROUND_Google") # this is not a real class
+
+ # For some reason, the category names in "101_ObjectCategories" and
+ # "Annotations" do not always match. This is a manual map between the
+ # two. Defaults to using same name, since most names are fine.
+ name_map = {
+ "Faces": "Faces_2",
+ "Faces_easy": "Faces_3",
+ "Motorbikes": "Motorbikes_16",
+ "airplanes": "Airplanes_Side_2",
+ }
+ self.annotation_categories = list(map(lambda x: name_map[x] if x in name_map else x, self.categories))
+
+ self.index: list[int] = []
+ self.y = []
+ for i, c in enumerate(self.categories):
+ n = len(os.listdir(os.path.join(self.root, "101_ObjectCategories", c)))
+ self.index.extend(range(1, n + 1))
+ self.y.extend(n * [i])
+
+ def __getitem__(self, index: int) -> tuple[Any, Any]:
+ """
+ Args:
+ index (int): Index
+
+ Returns:
+ tuple: (image, target) where the type of target specified by target_type.
+ """
+ import scipy.io
+
+ img = Image.open(
+ os.path.join(
+ self.root,
+ "101_ObjectCategories",
+ self.categories[self.y[index]],
+ f"image_{self.index[index]:04d}.jpg",
+ )
+ )
+
+ target: Any = []
+ for t in self.target_type:
+ if t == "category":
+ target.append(self.y[index])
+ elif t == "annotation":
+ data = scipy.io.loadmat(
+ os.path.join(
+ self.root,
+ "Annotations",
+ self.annotation_categories[self.y[index]],
+ f"annotation_{self.index[index]:04d}.mat",
+ )
+ )
+ target.append(data["obj_contour"])
+ target = tuple(target) if len(target) > 1 else target[0]
+
+ if self.transform is not None:
+ img = self.transform(img)
+
+ if self.target_transform is not None:
+ target = self.target_transform(target)
+
+ return img, target
+
+ def _check_integrity(self) -> bool:
+ # can be more robust and check hash of files
+ return os.path.exists(os.path.join(self.root, "101_ObjectCategories"))
+
+ def __len__(self) -> int:
+ return len(self.index)
+
+ def download(self) -> None:
+ if self._check_integrity():
+ return
+
+ download_and_extract_archive(
+ "https://data.caltech.edu/records/mzrjq-6wc02/files/caltech-101.zip",
+ download_root=self.root,
+ filename="caltech-101.zip",
+ md5="3138e1922a9193bfa496528edbbc45d0",
+ )
+ gzip_folder = os.path.join(self.root, "caltech-101")
+ for gzip_file in os.listdir(gzip_folder):
+ if gzip_file.endswith(".gz"):
+ extract_archive(os.path.join(gzip_folder, gzip_file), self.root)
+ shutil.rmtree(gzip_folder)
+ os.remove(os.path.join(self.root, "caltech-101.zip"))
+
+ def extra_repr(self) -> str:
+ return "Target type: {target_type}".format(**self.__dict__)
+
+
+class Caltech256(VisionDataset):
+ """`Caltech 256 `_ Dataset.
+
+ Args:
+ root (str or ``pathlib.Path``): Root directory of dataset where directory
+ ``caltech256`` exists or will be saved to if download is set to True.
+ transform (callable, optional): A function/transform that takes in a PIL image
+ and returns a transformed version. E.g, ``transforms.RandomCrop``
+ target_transform (callable, optional): A function/transform that takes in the
+ target and transforms it.
+ download (bool, optional): If true, downloads the dataset from the internet and
+ puts it in root directory. If dataset is already downloaded, it is not
+ downloaded again.
+ """
+
+ def __init__(
+ self,
+ root: str,
+ transform: Optional[Callable] = None,
+ target_transform: Optional[Callable] = None,
+ download: bool = False,
+ ) -> None:
+ super().__init__(os.path.join(root, "caltech256"), transform=transform, target_transform=target_transform)
+ os.makedirs(self.root, exist_ok=True)
+
+ if download:
+ self.download()
+
+ if not self._check_integrity():
+ raise RuntimeError("Dataset not found or corrupted. You can use download=True to download it")
+
+ self.categories = sorted(os.listdir(os.path.join(self.root, "256_ObjectCategories")))
+ self.index: list[int] = []
+ self.y = []
+ for i, c in enumerate(self.categories):
+ n = len(
+ [
+ item
+ for item in os.listdir(os.path.join(self.root, "256_ObjectCategories", c))
+ if item.endswith(".jpg")
+ ]
+ )
+ self.index.extend(range(1, n + 1))
+ self.y.extend(n * [i])
+
+ def __getitem__(self, index: int) -> tuple[Any, Any]:
+ """
+ Args:
+ index (int): Index
+
+ Returns:
+ tuple: (image, target) where target is index of the target class.
+ """
+ img = Image.open(
+ os.path.join(
+ self.root,
+ "256_ObjectCategories",
+ self.categories[self.y[index]],
+ f"{self.y[index] + 1:03d}_{self.index[index]:04d}.jpg",
+ )
+ )
+
+ target = self.y[index]
+
+ if self.transform is not None:
+ img = self.transform(img)
+
+ if self.target_transform is not None:
+ target = self.target_transform(target)
+
+ return img, target
+
+ def _check_integrity(self) -> bool:
+ # can be more robust and check hash of files
+ return os.path.exists(os.path.join(self.root, "256_ObjectCategories"))
+
+ def __len__(self) -> int:
+ return len(self.index)
+
+ def download(self) -> None:
+ if self._check_integrity():
+ return
+
+ download_and_extract_archive(
+ "https://data.caltech.edu/records/nyy15-4j048/files/256_ObjectCategories.tar",
+ self.root,
+ filename="256_ObjectCategories.tar",
+ md5="67b4f42ca05d46448c6bb8ecd2220f6d",
+ )
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torchvision/datasets/celeba.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torchvision/datasets/celeba.py
new file mode 100644
index 0000000000000000000000000000000000000000..469af6ed3b7efa433e2a7e488e8017a78710bad5
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torchvision/datasets/celeba.py
@@ -0,0 +1,199 @@
+import csv
+import os
+from collections import namedtuple
+from pathlib import Path
+from typing import Any, Callable, Optional, Union
+
+import PIL
+import torch
+
+from .utils import check_integrity, download_file_from_google_drive, extract_archive, verify_str_arg
+from .vision import VisionDataset
+
+CSV = namedtuple("CSV", ["header", "index", "data"])
+
+
+class CelebA(VisionDataset):
+ """`Large-scale CelebFaces Attributes (CelebA) Dataset `_ Dataset.
+
+ Args:
+ root (str or ``pathlib.Path``): Root directory where images are downloaded to.
+ split (string): One of {'train', 'valid', 'test', 'all'}; selects the
+ corresponding subset of the dataset.
+ target_type (string or list, optional): Type of target to use, ``attr``, ``identity``, ``bbox``,
+ or ``landmarks``. Can also be a list to output a tuple with all specified target types.
+ The targets represent:
+
+ - ``attr`` (Tensor shape=(40,) dtype=int): binary (0, 1) labels for attributes
+ - ``identity`` (int): label for each person (data points with the same identity are the same person)
+ - ``bbox`` (Tensor shape=(4,) dtype=int): bounding box (x, y, width, height)
+ - ``landmarks`` (Tensor shape=(10,) dtype=int): landmark points (lefteye_x, lefteye_y, righteye_x,
+ righteye_y, nose_x, nose_y, leftmouth_x, leftmouth_y, rightmouth_x, rightmouth_y)
+
+ Defaults to ``attr``. If empty, ``None`` will be returned as target.
+
+ transform (callable, optional): A function/transform that takes in a PIL image
+ and returns a transformed version. E.g, ``transforms.PILToTensor``
+ target_transform (callable, optional): A function/transform that takes in the
+ target and transforms it.
+ download (bool, optional): If true, downloads the dataset from the internet and
+ puts it in root directory. If dataset is already downloaded, it is not
+ downloaded again.
+
+ .. warning::
+
+ To download the dataset `gdown `_ is required.
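+
+ Example (a minimal usage sketch; the ``root`` path is a hypothetical placeholder):
+
+ .. code-block:: python
+
+ dataset = CelebA(root="./data", split="train", target_type="attr")
+ img, attr = dataset[0] # attr is a (40,) tensor of binary attribute labels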
+ """
+
+ base_folder = "celeba"
+ # There currently does not appear to be an easy way to extract 7z in python (without introducing additional
+ # dependencies). The "in-the-wild" (not aligned+cropped) images are only in 7z, so they are not available
+ # right now.
+ file_list = [
+ # File ID MD5 Hash Filename
+ ("0B7EVK8r0v71pZjFTYXZWM3FlRnM", "00d2c5bc6d35e252742224ab0c1e8fcb", "img_align_celeba.zip"),
+ # ("0B7EVK8r0v71pbWNEUjJKdDQ3dGc","b6cd7e93bc7a96c2dc33f819aa3ac651", "img_align_celeba_png.7z"),
+ # ("0B7EVK8r0v71peklHb0pGdDl6R28", "b6cd7e93bc7a96c2dc33f819aa3ac651", "img_celeba.7z"),
+ ("0B7EVK8r0v71pblRyaVFSWGxPY0U", "75e246fa4810816ffd6ee81facbd244c", "list_attr_celeba.txt"),
+ ("1_ee_0u7vcNLOfNLegJRHmolfH5ICW-XS", "32bd1bd63d3c78cd57e08160ec5ed1e2", "identity_CelebA.txt"),
+ ("0B7EVK8r0v71pbThiMVRxWXZ4dU0", "00566efa6fedff7a56946cd1c10f1c16", "list_bbox_celeba.txt"),
+ ("0B7EVK8r0v71pd0FJY3Blby1HUTQ", "cc24ecafdb5b50baae59b03474781f8c", "list_landmarks_align_celeba.txt"),
+ # ("0B7EVK8r0v71pTzJIdlJWdHczRlU", "063ee6ddb681f96bc9ca28c6febb9d1a", "list_landmarks_celeba.txt"),
+ ("0B7EVK8r0v71pY0NSMzRuSXJEVkk", "d32c9cbf5e040fd4025c592c306e6668", "list_eval_partition.txt"),
+ ]
+
+ def __init__(
+ self,
+ root: Union[str, Path],
+ split: str = "train",
+ target_type: Union[list[str], str] = "attr",
+ transform: Optional[Callable] = None,
+ target_transform: Optional[Callable] = None,
+ download: bool = False,
+ ) -> None:
+ super().__init__(root, transform=transform, target_transform=target_transform)
+ self.split = split
+ if isinstance(target_type, list):
+ self.target_type = target_type
+ else:
+ self.target_type = [target_type]
+
+ if not self.target_type and self.target_transform is not None:
+ raise RuntimeError("target_transform is specified but target_type is empty")
+
+ if download:
+ self.download()
+
+ if not self._check_integrity():
+ raise RuntimeError("Dataset not found or corrupted. You can use download=True to download it")
+
+ split_map = {
+ "train": 0,
+ "valid": 1,
+ "test": 2,
+ "all": None,
+ }
+ split_ = split_map[
+ verify_str_arg(
+ split.lower() if isinstance(split, str) else split,
+ "split",
+ ("train", "valid", "test", "all"),
+ )
+ ]
+ splits = self._load_csv("list_eval_partition.txt")
+ identity = self._load_csv("identity_CelebA.txt")
+ bbox = self._load_csv("list_bbox_celeba.txt", header=1)
+ landmarks_align = self._load_csv("list_landmarks_align_celeba.txt", header=1)
+ attr = self._load_csv("list_attr_celeba.txt", header=1)
+
+ mask = slice(None) if split_ is None else (splits.data == split_).squeeze()
+
+ if mask == slice(None): # if split == "all"
+ self.filename = splits.index
+ else:
+ self.filename = [splits.index[i] for i in torch.squeeze(torch.nonzero(mask))] # type: ignore[arg-type]
+ self.identity = identity.data[mask]
+ self.bbox = bbox.data[mask]
+ self.landmarks_align = landmarks_align.data[mask]
+ self.attr = attr.data[mask]
+ # map from {-1, 1} to {0, 1}
+ self.attr = torch.div(self.attr + 1, 2, rounding_mode="floor")
+ self.attr_names = attr.header
+
+ def _load_csv(
+ self,
+ filename: str,
+ header: Optional[int] = None,
+ ) -> CSV:
+ with open(os.path.join(self.root, self.base_folder, filename)) as csv_file:
+ data = list(csv.reader(csv_file, delimiter=" ", skipinitialspace=True))
+
+ if header is not None:
+ headers = data[header]
+ data = data[header + 1 :]
+ else:
+ headers = []
+
+ indices = [row[0] for row in data]
+ data = [row[1:] for row in data]
+ data_int = [list(map(int, i)) for i in data]
+
+ return CSV(headers, indices, torch.tensor(data_int))
+
+ def _check_integrity(self) -> bool:
+ for _, md5, filename in self.file_list:
+ fpath = os.path.join(self.root, self.base_folder, filename)
+ _, ext = os.path.splitext(filename)
+ # Allow original archive to be deleted (zip and 7z)
+ # Only need the extracted images
+ if ext not in [".zip", ".7z"] and not check_integrity(fpath, md5):
+ return False
+
+ # Should check a hash of the images
+ return os.path.isdir(os.path.join(self.root, self.base_folder, "img_align_celeba"))
+
+ def download(self) -> None:
+ if self._check_integrity():
+ return
+
+ for file_id, md5, filename in self.file_list:
+ download_file_from_google_drive(file_id, os.path.join(self.root, self.base_folder), filename, md5)
+
+ extract_archive(os.path.join(self.root, self.base_folder, "img_align_celeba.zip"))
+
+ def __getitem__(self, index: int) -> tuple[Any, Any]:
+ X = PIL.Image.open(os.path.join(self.root, self.base_folder, "img_align_celeba", self.filename[index]))
+
+ target: Any = []
+ for t in self.target_type:
+ if t == "attr":
+ target.append(self.attr[index, :])
+ elif t == "identity":
+ target.append(self.identity[index, 0])
+ elif t == "bbox":
+ target.append(self.bbox[index, :])
+ elif t == "landmarks":
+ target.append(self.landmarks_align[index, :])
+ else:
+ # TODO: refactor with utils.verify_str_arg
+ raise ValueError(f'Target type "{t}" is not recognized.')
+
+ if self.transform is not None:
+ X = self.transform(X)
+
+ if target:
+ target = tuple(target) if len(target) > 1 else target[0]
+
+ if self.target_transform is not None:
+ target = self.target_transform(target)
+ else:
+ target = None
+
+ return X, target
+
+ def __len__(self) -> int:
+ return len(self.attr)
+
+ def extra_repr(self) -> str:
+ lines = ["Target type: {target_type}", "Split: {split}"]
+ return "\n".join(lines).format(**self.__dict__)
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torchvision/datasets/cifar.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torchvision/datasets/cifar.py
new file mode 100644
index 0000000000000000000000000000000000000000..45893a4499506a43323bf53d9552adec2a457261
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torchvision/datasets/cifar.py
@@ -0,0 +1,167 @@
+import os.path
+import pickle
+from pathlib import Path
+from typing import Any, Callable, Optional, Union
+
+import numpy as np
+from PIL import Image
+
+from .utils import check_integrity, download_and_extract_archive
+from .vision import VisionDataset
+
+
+class CIFAR10(VisionDataset):
+ """`CIFAR10 `_ Dataset.
+
+ Args:
+ root (str or ``pathlib.Path``): Root directory of dataset where directory
+ ``cifar-10-batches-py`` exists or will be saved to if download is set to True.
+ train (bool, optional): If True, creates dataset from training set, otherwise
+ creates from test set.
+ transform (callable, optional): A function/transform that takes in a PIL image
+ and returns a transformed version. E.g, ``transforms.RandomCrop``
+ target_transform (callable, optional): A function/transform that takes in the
+ target and transforms it.
+ download (bool, optional): If true, downloads the dataset from the internet and
+ puts it in root directory. If dataset is already downloaded, it is not
+ downloaded again.
+
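+ Example (a minimal usage sketch; the ``root`` path is a hypothetical placeholder):
+
+ .. code-block:: python
+
+ dataset = CIFAR10(root="./data", train=True, download=True)
+ img, label = dataset[0] # img is a 32x32 PIL image, label an int in [0, 10)
+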
+ """
+
+ base_folder = "cifar-10-batches-py"
+ url = "https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz"
+ filename = "cifar-10-python.tar.gz"
+ tgz_md5 = "c58f30108f718f92721af3b95e74349a"
+ train_list = [
+ ["data_batch_1", "c99cafc152244af753f735de768cd75f"],
+ ["data_batch_2", "d4bba439e000b95fd0a9bffe97cbabec"],
+ ["data_batch_3", "54ebc095f3ab1f0389bbae665268c751"],
+ ["data_batch_4", "634d18415352ddfa80567beed471001a"],
+ ["data_batch_5", "482c414d41f54cd18b22e5b47cb7c3cb"],
+ ]
+
+ test_list = [
+ ["test_batch", "40351d587109b95175f43aff81a1287e"],
+ ]
+ meta = {
+ "filename": "batches.meta",
+ "key": "label_names",
+ "md5": "5ff9c542aee3614f3951f8cda6e48888",
+ }
+
+ def __init__(
+ self,
+ root: Union[str, Path],
+ train: bool = True,
+ transform: Optional[Callable] = None,
+ target_transform: Optional[Callable] = None,
+ download: bool = False,
+ ) -> None:
+
+ super().__init__(root, transform=transform, target_transform=target_transform)
+
+ self.train = train # training set or test set
+
+ if download:
+ self.download()
+
+ if not self._check_integrity():
+ raise RuntimeError("Dataset not found or corrupted. You can use download=True to download it")
+
+ if self.train:
+ downloaded_list = self.train_list
+ else:
+ downloaded_list = self.test_list
+
+ self.data: Any = []
+ self.targets = []
+
+ # now load the pickled numpy arrays
+ for file_name, checksum in downloaded_list:
+ file_path = os.path.join(self.root, self.base_folder, file_name)
+ with open(file_path, "rb") as f:
+ entry = pickle.load(f, encoding="latin1")
+ self.data.append(entry["data"])
+ if "labels" in entry:
+ self.targets.extend(entry["labels"])
+ else:
+ self.targets.extend(entry["fine_labels"])
+
+ self.data = np.vstack(self.data).reshape(-1, 3, 32, 32)
+ self.data = self.data.transpose((0, 2, 3, 1)) # convert to HWC
+
+ self._load_meta()
+
+ def _load_meta(self) -> None:
+ path = os.path.join(self.root, self.base_folder, self.meta["filename"])
+ if not check_integrity(path, self.meta["md5"]):
+ raise RuntimeError("Dataset metadata file not found or corrupted. You can use download=True to download it")
+ with open(path, "rb") as infile:
+ data = pickle.load(infile, encoding="latin1")
+ self.classes = data[self.meta["key"]]
+ self.class_to_idx = {_class: i for i, _class in enumerate(self.classes)}
+
+ def __getitem__(self, index: int) -> tuple[Any, Any]:
+ """
+ Args:
+ index (int): Index
+
+ Returns:
+ tuple: (image, target) where target is index of the target class.
+ """
+ img, target = self.data[index], self.targets[index]
+
+ # doing this so that it is consistent with all other datasets
+ # to return a PIL Image
+ img = Image.fromarray(img)
+
+ if self.transform is not None:
+ img = self.transform(img)
+
+ if self.target_transform is not None:
+ target = self.target_transform(target)
+
+ return img, target
+
+ def __len__(self) -> int:
+ return len(self.data)
+
+ def _check_integrity(self) -> bool:
+ for filename, md5 in self.train_list + self.test_list:
+ fpath = os.path.join(self.root, self.base_folder, filename)
+ if not check_integrity(fpath, md5):
+ return False
+ return True
+
+ def download(self) -> None:
+ if self._check_integrity():
+ return
+ download_and_extract_archive(self.url, self.root, filename=self.filename, md5=self.tgz_md5)
+
+ def extra_repr(self) -> str:
+ split = "Train" if self.train is True else "Test"
+ return f"Split: {split}"
+
+
+class CIFAR100(CIFAR10):
+ """`CIFAR100 `_ Dataset.
+
+ This is a subclass of the `CIFAR10` Dataset.
+ """
+
+ base_folder = "cifar-100-python"
+ url = "https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz"
+ filename = "cifar-100-python.tar.gz"
+ tgz_md5 = "eb9058c3a382ffc7106e4002c42a8d85"
+ train_list = [
+ ["train", "16019d7e3df5f24257cddd939b257f8d"],
+ ]
+
+ test_list = [
+ ["test", "f0ef6b0ae62326f3e7ffdfab6717acfc"],
+ ]
+ meta = {
+ "filename": "meta",
+ "key": "fine_label_names",
+ "md5": "7973b15100ade9c7d40fb424638fde48",
+ }
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torchvision/datasets/cityscapes.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torchvision/datasets/cityscapes.py
new file mode 100644
index 0000000000000000000000000000000000000000..a124439932f98b53d88e9ebc1db59068ae910989
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torchvision/datasets/cityscapes.py
@@ -0,0 +1,222 @@
+import json
+import os
+from collections import namedtuple
+from pathlib import Path
+from typing import Any, Callable, Optional, Union
+
+from PIL import Image
+
+from .utils import extract_archive, iterable_to_str, verify_str_arg
+from .vision import VisionDataset
+
+
+class Cityscapes(VisionDataset):
+ """`Cityscapes `_ Dataset.
+
+ Args:
+ root (str or ``pathlib.Path``): Root directory of dataset where directory ``leftImg8bit``
+ and ``gtFine`` or ``gtCoarse`` are located.
+ split (string, optional): The image split to use, ``train``, ``test`` or ``val`` if mode="fine"
+ otherwise ``train``, ``train_extra`` or ``val``
+ mode (string, optional): The quality mode to use, ``fine`` or ``coarse``
+ target_type (string or list, optional): Type of target to use, ``instance``, ``semantic``, ``polygon``
+ or ``color``. Can also be a list to output a tuple with all specified target types.
+ transform (callable, optional): A function/transform that takes in a PIL image
+ and returns a transformed version. E.g, ``transforms.RandomCrop``
+ target_transform (callable, optional): A function/transform that takes in the
+ target and transforms it.
+ transforms (callable, optional): A function/transform that takes input sample and its target as entry
+ and returns a transformed version.
+
+ Examples:
+
+ Get semantic segmentation target
+
+ .. code-block:: python
+
+ dataset = Cityscapes('./data/cityscapes', split='train', mode='fine',
+ target_type='semantic')
+
+ img, smnt = dataset[0]
+
+ Get multiple targets
+
+ .. code-block:: python
+
+ dataset = Cityscapes('./data/cityscapes', split='train', mode='fine',
+ target_type=['instance', 'color', 'polygon'])
+
+ img, (inst, col, poly) = dataset[0]
+
+ Validate on the "coarse" set
+
+ .. code-block:: python
+
+ dataset = Cityscapes('./data/cityscapes', split='val', mode='coarse',
+ target_type='semantic')
+
+ img, smnt = dataset[0]
+ """
+
+ # Based on https://github.com/mcordts/cityscapesScripts
+ CityscapesClass = namedtuple(
+ "CityscapesClass",
+ ["name", "id", "train_id", "category", "category_id", "has_instances", "ignore_in_eval", "color"],
+ )
+
+ classes = [
+ CityscapesClass("unlabeled", 0, 255, "void", 0, False, True, (0, 0, 0)),
+ CityscapesClass("ego vehicle", 1, 255, "void", 0, False, True, (0, 0, 0)),
+ CityscapesClass("rectification border", 2, 255, "void", 0, False, True, (0, 0, 0)),
+ CityscapesClass("out of roi", 3, 255, "void", 0, False, True, (0, 0, 0)),
+ CityscapesClass("static", 4, 255, "void", 0, False, True, (0, 0, 0)),
+ CityscapesClass("dynamic", 5, 255, "void", 0, False, True, (111, 74, 0)),
+ CityscapesClass("ground", 6, 255, "void", 0, False, True, (81, 0, 81)),
+ CityscapesClass("road", 7, 0, "flat", 1, False, False, (128, 64, 128)),
+ CityscapesClass("sidewalk", 8, 1, "flat", 1, False, False, (244, 35, 232)),
+ CityscapesClass("parking", 9, 255, "flat", 1, False, True, (250, 170, 160)),
+ CityscapesClass("rail track", 10, 255, "flat", 1, False, True, (230, 150, 140)),
+ CityscapesClass("building", 11, 2, "construction", 2, False, False, (70, 70, 70)),
+ CityscapesClass("wall", 12, 3, "construction", 2, False, False, (102, 102, 156)),
+ CityscapesClass("fence", 13, 4, "construction", 2, False, False, (190, 153, 153)),
+ CityscapesClass("guard rail", 14, 255, "construction", 2, False, True, (180, 165, 180)),
+ CityscapesClass("bridge", 15, 255, "construction", 2, False, True, (150, 100, 100)),
+ CityscapesClass("tunnel", 16, 255, "construction", 2, False, True, (150, 120, 90)),
+ CityscapesClass("pole", 17, 5, "object", 3, False, False, (153, 153, 153)),
+ CityscapesClass("polegroup", 18, 255, "object", 3, False, True, (153, 153, 153)),
+ CityscapesClass("traffic light", 19, 6, "object", 3, False, False, (250, 170, 30)),
+ CityscapesClass("traffic sign", 20, 7, "object", 3, False, False, (220, 220, 0)),
+ CityscapesClass("vegetation", 21, 8, "nature", 4, False, False, (107, 142, 35)),
+ CityscapesClass("terrain", 22, 9, "nature", 4, False, False, (152, 251, 152)),
+ CityscapesClass("sky", 23, 10, "sky", 5, False, False, (70, 130, 180)),
+ CityscapesClass("person", 24, 11, "human", 6, True, False, (220, 20, 60)),
+ CityscapesClass("rider", 25, 12, "human", 6, True, False, (255, 0, 0)),
+ CityscapesClass("car", 26, 13, "vehicle", 7, True, False, (0, 0, 142)),
+ CityscapesClass("truck", 27, 14, "vehicle", 7, True, False, (0, 0, 70)),
+ CityscapesClass("bus", 28, 15, "vehicle", 7, True, False, (0, 60, 100)),
+ CityscapesClass("caravan", 29, 255, "vehicle", 7, True, True, (0, 0, 90)),
+ CityscapesClass("trailer", 30, 255, "vehicle", 7, True, True, (0, 0, 110)),
+ CityscapesClass("train", 31, 16, "vehicle", 7, True, False, (0, 80, 100)),
+ CityscapesClass("motorcycle", 32, 17, "vehicle", 7, True, False, (0, 0, 230)),
+ CityscapesClass("bicycle", 33, 18, "vehicle", 7, True, False, (119, 11, 32)),
+ CityscapesClass("license plate", -1, -1, "vehicle", 7, False, True, (0, 0, 142)),
+ ]
+
+ def __init__(
+ self,
+ root: Union[str, Path],
+ split: str = "train",
+ mode: str = "fine",
+ target_type: Union[list[str], str] = "instance",
+ transform: Optional[Callable] = None,
+ target_transform: Optional[Callable] = None,
+ transforms: Optional[Callable] = None,
+ ) -> None:
+ super().__init__(root, transforms, transform, target_transform)
+ self.mode = "gtFine" if mode == "fine" else "gtCoarse"
+ self.images_dir = os.path.join(self.root, "leftImg8bit", split)
+ self.targets_dir = os.path.join(self.root, self.mode, split)
+ self.target_type = target_type
+ self.split = split
+ self.images = []
+ self.targets = []
+
+ verify_str_arg(mode, "mode", ("fine", "coarse"))
+ if mode == "fine":
+ valid_modes = ("train", "test", "val")
+ else:
+ valid_modes = ("train", "train_extra", "val")
+ msg = "Unknown value '{}' for argument split if mode is '{}'. Valid values are {{{}}}."
+ msg = msg.format(split, mode, iterable_to_str(valid_modes))
+ verify_str_arg(split, "split", valid_modes, msg)
+
+ if not isinstance(target_type, list):
+ self.target_type = [target_type]
+ for value in self.target_type:
+ verify_str_arg(value, "target_type", ("instance", "semantic", "polygon", "color"))
+
+ if not os.path.isdir(self.images_dir) or not os.path.isdir(self.targets_dir):
+
+ if split == "train_extra":
+ image_dir_zip = os.path.join(self.root, "leftImg8bit_trainextra.zip")
+ else:
+ image_dir_zip = os.path.join(self.root, "leftImg8bit_trainvaltest.zip")
+
+ if self.mode == "gtFine":
+ target_dir_zip = os.path.join(self.root, f"{self.mode}_trainvaltest.zip")
+ elif self.mode == "gtCoarse":
+ target_dir_zip = os.path.join(self.root, f"{self.mode}.zip")
+
+ if os.path.isfile(image_dir_zip) and os.path.isfile(target_dir_zip):
+ extract_archive(from_path=image_dir_zip, to_path=self.root)
+ extract_archive(from_path=target_dir_zip, to_path=self.root)
+ else:
+ raise RuntimeError(
+ "Dataset not found or incomplete. Please make sure all required folders for the"
+ ' specified "split" and "mode" are inside the "root" directory'
+ )
+
+ for city in os.listdir(self.images_dir):
+ img_dir = os.path.join(self.images_dir, city)
+ target_dir = os.path.join(self.targets_dir, city)
+ for file_name in os.listdir(img_dir):
+ target_types = []
+ for t in self.target_type:
+ target_name = "{}_{}".format(
+ file_name.split("_leftImg8bit")[0], self._get_target_suffix(self.mode, t)
+ )
+ target_types.append(os.path.join(target_dir, target_name))
+
+ self.images.append(os.path.join(img_dir, file_name))
+ self.targets.append(target_types)
+
+ def __getitem__(self, index: int) -> tuple[Any, Any]:
+ """
+ Args:
+ index (int): Index
+ Returns:
+ tuple: (image, target) where target is a tuple of all target types if target_type is a list with more
+ than one item. Otherwise, target is a json object if target_type="polygon", else the image segmentation.
+ """
+
+ image = Image.open(self.images[index]).convert("RGB")
+
+ targets: Any = []
+ for i, t in enumerate(self.target_type):
+ if t == "polygon":
+ target = self._load_json(self.targets[index][i])
+ else:
+ target = Image.open(self.targets[index][i]) # type: ignore[assignment]
+
+ targets.append(target)
+
+ target = tuple(targets) if len(targets) > 1 else targets[0] # type: ignore[assignment]
+
+ if self.transforms is not None:
+ image, target = self.transforms(image, target)
+
+ return image, target
+
+ def __len__(self) -> int:
+ return len(self.images)
+
+ def extra_repr(self) -> str:
+ lines = ["Split: {split}", "Mode: {mode}", "Type: {target_type}"]
+ return "\n".join(lines).format(**self.__dict__)
+
+ def _load_json(self, path: str) -> dict[str, Any]:
+ with open(path) as file:
+ data = json.load(file)
+ return data
+
+ def _get_target_suffix(self, mode: str, target_type: str) -> str:
+ if target_type == "instance":
+ return f"{mode}_instanceIds.png"
+ elif target_type == "semantic":
+ return f"{mode}_labelIds.png"
+ elif target_type == "color":
+ return f"{mode}_color.png"
+ else:
+ return f"{mode}_polygons.json"
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torchvision/datasets/clevr.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torchvision/datasets/clevr.py
new file mode 100644
index 0000000000000000000000000000000000000000..2bf24bc3c80a94aa2ca56b26fd0e1495374d03ab
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torchvision/datasets/clevr.py
@@ -0,0 +1,93 @@
+import json
+import pathlib
+from typing import Any, Callable, Optional, Union
+from urllib.parse import urlparse
+
+from .folder import default_loader
+
+from .utils import download_and_extract_archive, verify_str_arg
+from .vision import VisionDataset
+
+
+class CLEVRClassification(VisionDataset):
+ """`CLEVR `_ classification dataset.
+
+ The number of objects in a scene are used as label.
+
+ Args:
+ root (str or ``pathlib.Path``): Root directory of dataset where directory ``root/clevr`` exists or will be saved to if download is
+ set to True.
+ split (string, optional): The dataset split, supports ``"train"`` (default), ``"val"``, or ``"test"``.
+ transform (callable, optional): A function/transform that takes in a PIL image or torch.Tensor, depends on the given loader,
+ and returns a transformed version. E.g, ``transforms.RandomCrop``
+ target_transform (callable, optional): A function/transform that takes in the target and transforms it.
+ download (bool, optional): If true, downloads the dataset from the internet and puts it in root directory. If
+ dataset is already downloaded, it is not downloaded again.
+ loader (callable, optional): A function to load an image given its path.
+ By default, it uses PIL as its image loader, but users could also pass in
+ ``torchvision.io.decode_image`` for decoding image data into tensors directly.
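+
+ Example (a minimal usage sketch; the ``root`` path is a hypothetical placeholder):
+
+ .. code-block:: python
+
+ dataset = CLEVRClassification(root="./data", split="train")
+ img, num_objects = dataset[0] # the label is the number of objects in the scene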
+ """
+
+ _URL = "https://dl.fbaipublicfiles.com/clevr/CLEVR_v1.0.zip"
+ _MD5 = "b11922020e72d0cd9154779b2d3d07d2"
+
+ def __init__(
+ self,
+ root: Union[str, pathlib.Path],
+ split: str = "train",
+ transform: Optional[Callable] = None,
+ target_transform: Optional[Callable] = None,
+ download: bool = False,
+ loader: Callable[[Union[str, pathlib.Path]], Any] = default_loader,
+ ) -> None:
+ self._split = verify_str_arg(split, "split", ("train", "val", "test"))
+ super().__init__(root, transform=transform, target_transform=target_transform)
+ self.loader = loader
+ self._base_folder = pathlib.Path(self.root) / "clevr"
+ self._data_folder = self._base_folder / pathlib.Path(urlparse(self._URL).path).stem
+
+ if download:
+ self._download()
+
+ if not self._check_exists():
+ raise RuntimeError("Dataset not found or corrupted. You can use download=True to download it")
+
+ self._image_files = sorted(self._data_folder.joinpath("images", self._split).glob("*"))
+
+ self._labels: list[Optional[int]]
+ if self._split != "test":
+ with open(self._data_folder / "scenes" / f"CLEVR_{self._split}_scenes.json") as file:
+ content = json.load(file)
+ num_objects = {scene["image_filename"]: len(scene["objects"]) for scene in content["scenes"]}
+ self._labels = [num_objects[image_file.name] for image_file in self._image_files]
+ else:
+ self._labels = [None] * len(self._image_files)
+
+ def __len__(self) -> int:
+ return len(self._image_files)
+
+ def __getitem__(self, idx: int) -> tuple[Any, Any]:
+ image_file = self._image_files[idx]
+ label = self._labels[idx]
+
+ image = self.loader(image_file)
+
+ if self.transform:
+ image = self.transform(image)
+
+ if self.target_transform:
+ label = self.target_transform(label)
+
+ return image, label
+
+ def _check_exists(self) -> bool:
+ return self._data_folder.exists() and self._data_folder.is_dir()
+
+ def _download(self) -> None:
+ if self._check_exists():
+ return
+
+ download_and_extract_archive(self._URL, str(self._base_folder), md5=self._MD5)
+
+ def extra_repr(self) -> str:
+ return f"split={self._split}"
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torchvision/datasets/coco.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torchvision/datasets/coco.py
new file mode 100644
index 0000000000000000000000000000000000000000..8f3b5d2dfe4a9047ef49322501582ed9d09cb5a1
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torchvision/datasets/coco.py
@@ -0,0 +1,111 @@
+import os.path
+from pathlib import Path
+from typing import Any, Callable, Optional, Union
+
+from PIL import Image
+
+from .vision import VisionDataset
+
+
+class CocoDetection(VisionDataset):
+ """`MS Coco Detection `_ Dataset.
+
+ It requires `pycocotools `_ to be installed,
+ which could be installed via ``pip install pycocotools`` or ``conda install conda-forge::pycocotools``.
+
+ Args:
+ root (str or ``pathlib.Path``): Root directory where images are downloaded to.
+ annFile (string): Path to json annotation file.
+ transform (callable, optional): A function/transform that takes in a PIL image
+ and returns a transformed version. E.g, ``transforms.PILToTensor``
+ target_transform (callable, optional): A function/transform that takes in the
+ target and transforms it.
+ transforms (callable, optional): A function/transform that takes input sample and its target as entry
+ and returns a transformed version.
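+
+ Example (a minimal usage sketch; the ``root`` and ``annFile`` paths are hypothetical placeholders):
+
+ .. code-block:: python
+
+ dataset = CocoDetection(root="path/to/images", annFile="path/to/instances.json")
+ img, anns = dataset[0] # anns is a list of COCO annotation dicts for the image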
+ """
+
+ def __init__(
+ self,
+ root: Union[str, Path],
+ annFile: str,
+ transform: Optional[Callable] = None,
+ target_transform: Optional[Callable] = None,
+ transforms: Optional[Callable] = None,
+ ) -> None:
+ super().__init__(root, transforms, transform, target_transform)
+ from pycocotools.coco import COCO
+
+ self.coco = COCO(annFile)
+ self.ids = list(sorted(self.coco.imgs.keys()))
+
+ def _load_image(self, id: int) -> Image.Image:
+ path = self.coco.loadImgs(id)[0]["file_name"]
+ return Image.open(os.path.join(self.root, path)).convert("RGB")
+
+ def _load_target(self, id: int) -> list[Any]:
+ return self.coco.loadAnns(self.coco.getAnnIds(id))
+
+ def __getitem__(self, index: int) -> tuple[Any, Any]:
+
+ if not isinstance(index, int):
+ raise ValueError(f"Index must be of type integer, got {type(index)} instead.")
+
+ id = self.ids[index]
+ image = self._load_image(id)
+ target = self._load_target(id)
+
+ if self.transforms is not None:
+ image, target = self.transforms(image, target)
+
+ return image, target
+
+ def __len__(self) -> int:
+ return len(self.ids)
+
+
+class CocoCaptions(CocoDetection):
+ """`MS Coco Captions `_ Dataset.
+
+ It requires `pycocotools `_ to be installed,
+ which could be installed via ``pip install pycocotools`` or ``conda install conda-forge::pycocotools``.
+
+ Args:
+ root (str or ``pathlib.Path``): Root directory where images are downloaded to.
+ annFile (string): Path to json annotation file.
+ transform (callable, optional): A function/transform that takes in a PIL image
+ and returns a transformed version. E.g, ``transforms.PILToTensor``
+ target_transform (callable, optional): A function/transform that takes in the
+ target and transforms it.
+ transforms (callable, optional): A function/transform that takes input sample and its target as entry
+ and returns a transformed version.
+
+ Example:
+
+ .. code:: python
+
+ import torchvision.datasets as dset
+ import torchvision.transforms as transforms
+ cap = dset.CocoCaptions(root = 'dir where images are',
+ annFile = 'json annotation file',
+ transform=transforms.PILToTensor())
+
+ print('Number of samples: ', len(cap))
+ img, target = cap[3] # load 4th sample
+
+ print("Image Size: ", img.size())
+ print(target)
+
+ Output: ::
+
+ Number of samples: 82783
+ Image Size: (3L, 427L, 640L)
+ [u'A plane emitting smoke stream flying over a mountain.',
+ u'A plane darts across a bright blue sky behind a mountain covered in snow',
+ u'A plane leaves a contrail above the snowy mountain top.',
+ u'A mountain that has a plane flying overheard in the distance.',
+ u'A mountain view with a plume of smoke in the background']
+
+ """
+
+ def _load_target(self, id: int) -> list[str]:
+ return [ann["caption"] for ann in super()._load_target(id)]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torchvision/datasets/country211.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torchvision/datasets/country211.py
new file mode 100644
index 0000000000000000000000000000000000000000..50d49db00a72e2592f15329b70f4f0cdbfa6b128
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torchvision/datasets/country211.py
@@ -0,0 +1,67 @@
+from pathlib import Path
+from typing import Any, Callable, Optional, Union
+
+from .folder import default_loader, ImageFolder
+from .utils import download_and_extract_archive, verify_str_arg
+
+
+class Country211(ImageFolder):
+ """`The Country211 Data Set `_ from OpenAI.
+
+ This dataset was built by filtering the images from the YFCC100m dataset
+ that have GPS coordinate corresponding to a ISO-3166 country code. The
+ dataset is balanced by sampling 150 train images, 50 validation images, and
+ 100 test images for each country.
+
+ Args:
+ root (str or ``pathlib.Path``): Root directory of the dataset.
+ split (string, optional): The dataset split, supports ``"train"`` (default), ``"valid"`` and ``"test"``.
+ transform (callable, optional): A function/transform that takes in a PIL image or torch.Tensor, depends on the given loader,
+ and returns a transformed version. E.g, ``transforms.RandomCrop``
+ target_transform (callable, optional): A function/transform that takes in the target and transforms it.
+ download (bool, optional): If True, downloads the dataset from the internet and puts it into
+ ``root/country211/``. If dataset is already downloaded, it is not downloaded again.
+ loader (callable, optional): A function to load an image given its path.
+ By default, it uses PIL as its image loader, but users could also pass in
+ ``torchvision.io.decode_image`` for decoding image data into tensors directly.
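+
+ Example (a minimal usage sketch; the ``root`` path is a hypothetical placeholder):
+
+ .. code-block:: python
+
+ dataset = Country211(root="./data", split="valid")
+ img, label = dataset[0] # label indexes one of the 211 country classes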
+ """
+
+ _URL = "https://openaipublic.azureedge.net/clip/data/country211.tgz"
+ _MD5 = "84988d7644798601126c29e9877aab6a"
+
+ def __init__(
+ self,
+ root: Union[str, Path],
+ split: str = "train",
+ transform: Optional[Callable] = None,
+ target_transform: Optional[Callable] = None,
+ download: bool = False,
+ loader: Callable[[str], Any] = default_loader,
+ ) -> None:
+ self._split = verify_str_arg(split, "split", ("train", "valid", "test"))
+
+ root = Path(root).expanduser()
+ self.root = str(root)
+ self._base_folder = root / "country211"
+
+ if download:
+ self._download()
+
+ if not self._check_exists():
+ raise RuntimeError("Dataset not found. You can use download=True to download it")
+
+ super().__init__(
+ str(self._base_folder / self._split),
+ transform=transform,
+ target_transform=target_transform,
+ loader=loader,
+ )
+ self.root = str(root)
+
+ def _check_exists(self) -> bool:
+ return self._base_folder.exists() and self._base_folder.is_dir()
+
+ def _download(self) -> None:
+ if self._check_exists():
+ return
+ download_and_extract_archive(self._URL, download_root=self.root, md5=self._MD5)
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torchvision/datasets/dtd.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torchvision/datasets/dtd.py
new file mode 100644
index 0000000000000000000000000000000000000000..8fb347955d420e04a68cb7055c46409293235b62
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torchvision/datasets/dtd.py
@@ -0,0 +1,105 @@
+import os
+import pathlib
+from typing import Any, Callable, Optional, Union
+
+from .folder import default_loader
+
+from .utils import download_and_extract_archive, verify_str_arg
+from .vision import VisionDataset
+
+
+class DTD(VisionDataset):
+ """`Describable Textures Dataset (DTD) `_.
+
+ Args:
+ root (str or ``pathlib.Path``): Root directory of the dataset.
+ split (string, optional): The dataset split, supports ``"train"`` (default), ``"val"``, or ``"test"``.
+ partition (int, optional): The dataset partition. Should be ``1 <= partition <= 10``. Defaults to ``1``.
+
+ .. note::
+
+ The partition only changes which split each image belongs to. Thus, regardless of the selected
+ partition, combining all splits will result in all images.
+
+ transform (callable, optional): A function/transform that takes in a PIL image or torch.Tensor, depends on the given loader,
+ and returns a transformed version. E.g, ``transforms.RandomCrop``
+ target_transform (callable, optional): A function/transform that takes in the target and transforms it.
+ download (bool, optional): If True, downloads the dataset from the internet and
+ puts it in root directory. If dataset is already downloaded, it is not
+ downloaded again. Default is False.
+ loader (callable, optional): A function to load an image given its path.
+ By default, it uses PIL as its image loader, but users could also pass in
+ ``torchvision.io.decode_image`` for decoding image data into tensors directly.
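+
+ Example (a minimal usage sketch; the ``root`` path is a hypothetical placeholder):
+
+ .. code-block:: python
+
+ dataset = DTD(root="./data", split="train", partition=1, download=True)
+ img, label = dataset[0] # label indexes one of the 47 texture classes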
+ """
+
+ _URL = "https://www.robots.ox.ac.uk/~vgg/data/dtd/download/dtd-r1.0.1.tar.gz"
+ _MD5 = "fff73e5086ae6bdbea199a49dfb8a4c1"
+
+ def __init__(
+ self,
+ root: Union[str, pathlib.Path],
+ split: str = "train",
+ partition: int = 1,
+ transform: Optional[Callable] = None,
+ target_transform: Optional[Callable] = None,
+ download: bool = False,
+ loader: Callable[[Union[str, pathlib.Path]], Any] = default_loader,
+ ) -> None:
+ self._split = verify_str_arg(split, "split", ("train", "val", "test"))
+ if not isinstance(partition, int) or not (1 <= partition <= 10):
+ raise ValueError(
+ f"Parameter 'partition' should be an integer with `1 <= partition <= 10`, "
+ f"but got {partition} instead"
+ )
+ self._partition = partition
+
+ super().__init__(root, transform=transform, target_transform=target_transform)
+ self._base_folder = pathlib.Path(self.root) / type(self).__name__.lower()
+ self._data_folder = self._base_folder / "dtd"
+ self._meta_folder = self._data_folder / "labels"
+ self._images_folder = self._data_folder / "images"
+
+ if download:
+ self._download()
+
+ if not self._check_exists():
+ raise RuntimeError("Dataset not found. You can use download=True to download it")
+
+ self._image_files = []
+ classes = []
+ with open(self._meta_folder / f"{self._split}{self._partition}.txt") as file:
+ for line in file:
+ cls, name = line.strip().split("/")
+ self._image_files.append(self._images_folder.joinpath(cls, name))
+ classes.append(cls)
+
+ self.classes = sorted(set(classes))
+ self.class_to_idx = dict(zip(self.classes, range(len(self.classes))))
+ self._labels = [self.class_to_idx[cls] for cls in classes]
+ self.loader = loader
+
+ def __len__(self) -> int:
+ return len(self._image_files)
+
+ def __getitem__(self, idx: int) -> tuple[Any, Any]:
+ image_file, label = self._image_files[idx], self._labels[idx]
+ image = self.loader(image_file)
+
+ if self.transform:
+ image = self.transform(image)
+
+ if self.target_transform:
+ label = self.target_transform(label)
+
+ return image, label
+
+ def extra_repr(self) -> str:
+ return f"split={self._split}, partition={self._partition}"
+
+ def _check_exists(self) -> bool:
+ return os.path.exists(self._data_folder) and os.path.isdir(self._data_folder)
+
+ def _download(self) -> None:
+ if self._check_exists():
+ return
+ download_and_extract_archive(self._URL, download_root=str(self._base_folder), md5=self._MD5)
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torchvision/datasets/eurosat.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torchvision/datasets/eurosat.py
new file mode 100644
index 0000000000000000000000000000000000000000..4efec57029f617b04b5822489e396bb60ba9b639
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torchvision/datasets/eurosat.py
@@ -0,0 +1,71 @@
+import os
+from pathlib import Path
+from typing import Any, Callable, Optional, Union
+
+from .folder import default_loader, ImageFolder
+from .utils import download_and_extract_archive
+
+
+class EuroSAT(ImageFolder):
+ """RGB version of the `EuroSAT `_ Dataset.
+
+ For the MS version of the dataset, see
+ `TorchGeo `__.
+
+ Args:
+ root (str or ``pathlib.Path``): Root directory of dataset where ``root/eurosat`` exists.
+ transform (callable, optional): A function/transform that takes in a PIL image or torch.Tensor, depends on the given loader,
+ and returns a transformed version. E.g, ``transforms.RandomCrop``
+ target_transform (callable, optional): A function/transform that takes in the
+ target and transforms it.
+ download (bool, optional): If True, downloads the dataset from the internet and
+ puts it in root directory. If dataset is already downloaded, it is not
+ downloaded again. Default is False.
+ loader (callable, optional): A function to load an image given its path.
+ By default, it uses PIL as its image loader, but users could also pass in
+ ``torchvision.io.decode_image`` for decoding image data into tensors directly.
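+
+ Example (a minimal usage sketch; the ``root`` path is a hypothetical placeholder):
+
+ .. code-block:: python
+
+ dataset = EuroSAT(root="./data", download=True)
+ img, label = dataset[0]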
+ """
+
+ def __init__(
+ self,
+ root: Union[str, Path],
+ transform: Optional[Callable] = None,
+ target_transform: Optional[Callable] = None,
+ download: bool = False,
+ loader: Callable[[str], Any] = default_loader,
+ ) -> None:
+ self.root = os.path.expanduser(root)
+ self._base_folder = os.path.join(self.root, "eurosat")
+ self._data_folder = os.path.join(self._base_folder, "2750")
+
+ if download:
+ self.download()
+
+ if not self._check_exists():
+ raise RuntimeError("Dataset not found. You can use download=True to download it")
+
+ super().__init__(
+ self._data_folder,
+ transform=transform,
+ target_transform=target_transform,
+ loader=loader,
+ )
+ self.root = os.path.expanduser(root)
+
+ def __len__(self) -> int:
+ return len(self.samples)
+
+ def _check_exists(self) -> bool:
+ return os.path.exists(self._data_folder)
+
+ def download(self) -> None:
+
+ if self._check_exists():
+ return
+
+ os.makedirs(self._base_folder, exist_ok=True)
+ download_and_extract_archive(
+ "https://huggingface.co/datasets/torchgeo/eurosat/resolve/c877bcd43f099cd0196738f714544e355477f3fd/EuroSAT.zip",
+ download_root=self._base_folder,
+ md5="c8fa014336c82ac7804f0398fcb19387",
+ )
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torchvision/datasets/fakedata.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torchvision/datasets/fakedata.py
new file mode 100644
index 0000000000000000000000000000000000000000..bcb413cdd32e784d962b9be46d53cf319fd677e3
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torchvision/datasets/fakedata.py
@@ -0,0 +1,67 @@
+from typing import Any, Callable, Optional
+
+import torch
+
+from .. import transforms
+from .vision import VisionDataset
+
+
+class FakeData(VisionDataset):
+ """A fake dataset that returns randomly generated images and returns them as PIL images
+
+ Args:
+ size (int, optional): Size of the dataset. Default: 1000 images
+ image_size(tuple, optional): Size of the returned images. Default: (3, 224, 224)
+ num_classes(int, optional): Number of classes in the dataset. Default: 10
+ transform (callable, optional): A function/transform that takes in a PIL image
+ and returns a transformed version. E.g, ``transforms.RandomCrop``
+ target_transform (callable, optional): A function/transform that takes in the
+ target and transforms it.
+ random_offset (int): Offsets the index-based random seed used to
+ generate each image. Default: 0
+
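+ Example (a minimal sketch; no files are needed since the images are generated on the fly):
+
+ .. code-block:: python
+
+ dataset = FakeData(size=10, image_size=(3, 64, 64), num_classes=5)
+ img, target = dataset[0] # img is a 64x64 PIL image, target an int in [0, 5)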
+ """
+
+ def __init__(
+ self,
+ size: int = 1000,
+ image_size: tuple[int, int, int] = (3, 224, 224),
+ num_classes: int = 10,
+ transform: Optional[Callable] = None,
+ target_transform: Optional[Callable] = None,
+ random_offset: int = 0,
+ ) -> None:
+ super().__init__(transform=transform, target_transform=target_transform)
+ self.size = size
+ self.num_classes = num_classes
+ self.image_size = image_size
+ self.random_offset = random_offset
+
+ def __getitem__(self, index: int) -> tuple[Any, Any]:
+ """
+ Args:
+ index (int): Index
+
+ Returns:
+ tuple: (image, target) where target is class_index of the target class.
+ """
+ # create random image that is consistent with the index id
+ if index >= len(self):
+ raise IndexError(f"{self.__class__.__name__} index out of range")
+ rng_state = torch.get_rng_state()
+ torch.manual_seed(index + self.random_offset)
+ img = torch.randn(*self.image_size)
+ target = torch.randint(0, self.num_classes, size=(1,), dtype=torch.long)[0]
+ torch.set_rng_state(rng_state)
+
+ # convert to PIL Image
+ img = transforms.ToPILImage()(img)
+ if self.transform is not None:
+ img = self.transform(img)
+ if self.target_transform is not None:
+ target = self.target_transform(target)
+
+ return img, target.item()
+
+ def __len__(self) -> int:
+ return self.size
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torchvision/datasets/fer2013.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torchvision/datasets/fer2013.py
new file mode 100644
index 0000000000000000000000000000000000000000..f33afbeebc82e5bc62feb23bdefffe7a1472e22f
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torchvision/datasets/fer2013.py
@@ -0,0 +1,120 @@
+import csv
+import pathlib
+from typing import Any, Callable, Optional, Union
+
+import torch
+from PIL import Image
+
+from .utils import check_integrity, verify_str_arg
+from .vision import VisionDataset
+
+
+class FER2013(VisionDataset):
+ """`FER2013
+ `_ Dataset.
+
+ .. note::
+ This dataset can return test labels only if ``fer2013.csv`` OR
+ ``icml_face_data.csv`` are present in ``root/fer2013/``. If only
+ ``train.csv`` and ``test.csv`` are present, the test labels are set to
+ ``None``.
+
+ Args:
+ root (str or ``pathlib.Path``): Root directory of dataset where directory
+ ``root/fer2013`` exists. This directory may contain either
+ ``fer2013.csv``, ``icml_face_data.csv``, or both ``train.csv`` and
+ ``test.csv``. Precedence is given in that order, i.e. if
+ ``fer2013.csv`` is present then the rest of the files will be
+ ignored. All these (combinations of) files contain the same data and
+ are supported for convenience, but only ``fer2013.csv`` and
+ ``icml_face_data.csv`` are able to return non-None test labels.
+ split (string, optional): The dataset split, supports ``"train"`` (default), or ``"test"``.
+ transform (callable, optional): A function/transform that takes in a PIL image and returns a transformed
+ version. E.g, ``transforms.RandomCrop``
+ target_transform (callable, optional): A function/transform that takes in the target and transforms it.
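+
+ Example (a minimal usage sketch; the ``root`` path is a hypothetical placeholder):
+
+ .. code-block:: python
+
+ dataset = FER2013(root="./data", split="train")
+ img, target = dataset[0] # a 48x48 grayscale PIL image and its emotion label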
+ """
+
+ _RESOURCES = {
+ "train": ("train.csv", "3f0dfb3d3fd99c811a1299cb947e3131"),
+ "test": ("test.csv", "b02c2298636a634e8c2faabbf3ea9a23"),
+ # The fer2013.csv and icml_face_data.csv files contain both train and
+ # tests instances, and unlike test.csv they contain the labels for the
+ # test instances. We give these 2 files precedence over train.csv and
+ # test.csv. And yes, they both contain the same data, but with different
+ # column names (note the spaces) and ordering:
+ # $ head -n 1 fer2013.csv icml_face_data.csv train.csv test.csv
+ # ==> fer2013.csv <==
+ # emotion,pixels,Usage
+ #
+ # ==> icml_face_data.csv <==
+ # emotion, Usage, pixels
+ #
+ # ==> train.csv <==
+ # emotion,pixels
+ #
+ # ==> test.csv <==
+ # pixels
+ "fer": ("fer2013.csv", "f8428a1edbd21e88f42c73edd2a14f95"),
+ "icml": ("icml_face_data.csv", "b114b9e04e6949e5fe8b6a98b3892b1d"),
+ }
+
+ def __init__(
+ self,
+ root: Union[str, pathlib.Path],
+ split: str = "train",
+ transform: Optional[Callable] = None,
+ target_transform: Optional[Callable] = None,
+ ) -> None:
+ self._split = verify_str_arg(split, "split", ("train", "test"))
+ super().__init__(root, transform=transform, target_transform=target_transform)
+
+ base_folder = pathlib.Path(self.root) / "fer2013"
+ use_fer_file = (base_folder / self._RESOURCES["fer"][0]).exists()
+ use_icml_file = not use_fer_file and (base_folder / self._RESOURCES["icml"][0]).exists()
+ file_name, md5 = self._RESOURCES["fer" if use_fer_file else "icml" if use_icml_file else self._split]
+ data_file = base_folder / file_name
+ if not check_integrity(str(data_file), md5=md5):
+ raise RuntimeError(
+ f"{file_name} not found in {base_folder} or corrupted. "
+ f"You can download it from "
+ f"https://www.kaggle.com/c/challenges-in-representation-learning-facial-expression-recognition-challenge"
+ )
+
+ pixels_key = " pixels" if use_icml_file else "pixels"
+ usage_key = " Usage" if use_icml_file else "Usage"
+
+ def get_img(row):
+ return torch.tensor([int(idx) for idx in row[pixels_key].split()], dtype=torch.uint8).reshape(48, 48)
+
+ def get_label(row):
+ if use_fer_file or use_icml_file or self._split == "train":
+ return int(row["emotion"])
+ else:
+ return None
+
+ with open(data_file, newline="") as file:
+ rows = (row for row in csv.DictReader(file))
+
+ if use_fer_file or use_icml_file:
+ valid_keys = ("Training",) if self._split == "train" else ("PublicTest", "PrivateTest")
+ rows = (row for row in rows if row[usage_key] in valid_keys)
+
+ self._samples = [(get_img(row), get_label(row)) for row in rows]
+
+ def __len__(self) -> int:
+ return len(self._samples)
+
+ def __getitem__(self, idx: int) -> tuple[Any, Any]:
+ image_tensor, target = self._samples[idx]
+ image = Image.fromarray(image_tensor.numpy())
+
+ if self.transform is not None:
+ image = self.transform(image)
+
+ if self.target_transform is not None:
+ target = self.target_transform(target)
+
+ return image, target
+
+ def extra_repr(self) -> str:
+ return f"split={self._split}"
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torchvision/datasets/fgvc_aircraft.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torchvision/datasets/fgvc_aircraft.py
new file mode 100644
index 0000000000000000000000000000000000000000..a3f2277b23353fda4191bc1e6df87a805600e10d
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torchvision/datasets/fgvc_aircraft.py
@@ -0,0 +1,120 @@
+from __future__ import annotations
+
+import os
+from pathlib import Path
+from typing import Any, Callable
+
+from .folder import default_loader
+
+from .utils import download_and_extract_archive, verify_str_arg
+from .vision import VisionDataset
+
+
+class FGVCAircraft(VisionDataset):
+    """`FGVC Aircraft <https://www.robots.ox.ac.uk/~vgg/data/fgvc-aircraft/>`_ Dataset.
+
+ The dataset contains 10,000 images of aircraft, with 100 images for each of 100
+ different aircraft model variants, most of which are airplanes.
+    Aircraft models are organized in a three-level hierarchy. The three levels, from
+ finer to coarser, are:
+
+ - ``variant``, e.g. Boeing 737-700. A variant collapses all the models that are visually
+ indistinguishable into one class. The dataset comprises 100 different variants.
+ - ``family``, e.g. Boeing 737. The dataset comprises 70 different families.
+ - ``manufacturer``, e.g. Boeing. The dataset comprises 30 different manufacturers.
+
+ Args:
+ root (str or ``pathlib.Path``): Root directory of the FGVC Aircraft dataset.
+ split (string, optional): The dataset split, supports ``train``, ``val``,
+ ``trainval`` and ``test``.
+ annotation_level (str, optional): The annotation level, supports ``variant``,
+ ``family`` and ``manufacturer``.
+ transform (callable, optional): A function/transform that takes in a PIL image or torch.Tensor, depends on the given loader,
+ and returns a transformed version. E.g, ``transforms.RandomCrop``
+ target_transform (callable, optional): A function/transform that takes in the
+ target and transforms it.
+ download (bool, optional): If True, downloads the dataset from the internet and
+ puts it in root directory. If dataset is already downloaded, it is not
+ downloaded again.
+ loader (callable, optional): A function to load an image given its path.
+ By default, it uses PIL as its image loader, but users could also pass in
+ ``torchvision.io.decode_image`` for decoding image data into tensors directly.
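+
+    Example (a minimal usage sketch; it assumes downloading the full archive is acceptable)::
+
+        from torchvision.datasets import FGVCAircraft
+
+        train_set = FGVCAircraft(root="data", split="trainval", annotation_level="family", download=True)
+        image, label = train_set[0]
+        print(train_set.classes[label])  # e.g. "Boeing 737"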
+ """
+
+ _URL = "https://www.robots.ox.ac.uk/~vgg/data/fgvc-aircraft/archives/fgvc-aircraft-2013b.tar.gz"
+
+ def __init__(
+ self,
+ root: str | Path,
+ split: str = "trainval",
+ annotation_level: str = "variant",
+ transform: Callable | None = None,
+ target_transform: Callable | None = None,
+ download: bool = False,
+ loader: Callable[[str], Any] = default_loader,
+ ) -> None:
+ super().__init__(root, transform=transform, target_transform=target_transform)
+ self._split = verify_str_arg(split, "split", ("train", "val", "trainval", "test"))
+ self._annotation_level = verify_str_arg(
+ annotation_level, "annotation_level", ("variant", "family", "manufacturer")
+ )
+
+ self._data_path = os.path.join(self.root, "fgvc-aircraft-2013b")
+ if download:
+ self._download()
+
+ if not self._check_exists():
+ raise RuntimeError("Dataset not found. You can use download=True to download it")
+
+ annotation_file = os.path.join(
+ self._data_path,
+ "data",
+ {
+ "variant": "variants.txt",
+ "family": "families.txt",
+ "manufacturer": "manufacturers.txt",
+ }[self._annotation_level],
+ )
+ with open(annotation_file) as f:
+ self.classes = [line.strip() for line in f]
+
+ self.class_to_idx = dict(zip(self.classes, range(len(self.classes))))
+
+ image_data_folder = os.path.join(self._data_path, "data", "images")
+ labels_file = os.path.join(self._data_path, "data", f"images_{self._annotation_level}_{self._split}.txt")
+
+ self._image_files = []
+ self._labels = []
+
+ with open(labels_file) as f:
+ for line in f:
+ image_name, label_name = line.strip().split(" ", 1)
+ self._image_files.append(os.path.join(image_data_folder, f"{image_name}.jpg"))
+ self._labels.append(self.class_to_idx[label_name])
+ self.loader = loader
+
+ def __len__(self) -> int:
+ return len(self._image_files)
+
+ def __getitem__(self, idx: int) -> tuple[Any, Any]:
+ image_file, label = self._image_files[idx], self._labels[idx]
+ image = self.loader(image_file)
+
+ if self.transform:
+ image = self.transform(image)
+
+ if self.target_transform:
+ label = self.target_transform(label)
+
+ return image, label
+
+ def _download(self) -> None:
+ """
+ Download the FGVC Aircraft dataset archive and extract it under root.
+ """
+ if self._check_exists():
+ return
+ download_and_extract_archive(self._URL, self.root)
+
+ def _check_exists(self) -> bool:
+ return os.path.exists(self._data_path) and os.path.isdir(self._data_path)
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torchvision/datasets/flickr.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torchvision/datasets/flickr.py
new file mode 100644
index 0000000000000000000000000000000000000000..84f1dc0e1702d0a263d8c8a05dcaad47dde35a14
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torchvision/datasets/flickr.py
@@ -0,0 +1,176 @@
+import glob
+import os
+from collections import defaultdict
+from html.parser import HTMLParser
+from pathlib import Path
+from typing import Any, Callable, Optional, Union
+
+from .folder import default_loader
+from .vision import VisionDataset
+
+
+class Flickr8kParser(HTMLParser):
+ """Parser for extracting captions from the Flickr8k dataset web page."""
+
+ def __init__(self, root: Union[str, Path]) -> None:
+ super().__init__()
+
+ self.root = root
+
+ # Data structure to store captions
+ self.annotations: dict[str, list[str]] = {}
+
+ # State variables
+ self.in_table = False
+ self.current_tag: Optional[str] = None
+ self.current_img: Optional[str] = None
+
+ def handle_starttag(self, tag: str, attrs: list[tuple[str, Optional[str]]]) -> None:
+ self.current_tag = tag
+
+ if tag == "table":
+ self.in_table = True
+
+ def handle_endtag(self, tag: str) -> None:
+ self.current_tag = None
+
+ if tag == "table":
+ self.in_table = False
+
+ def handle_data(self, data: str) -> None:
+ if self.in_table:
+ if data == "Image Not Found":
+ self.current_img = None
+ elif self.current_tag == "a":
+ img_id = data.split("/")[-2]
+ img_id = os.path.join(self.root, img_id + "_*.jpg")
+ img_id = glob.glob(img_id)[0]
+ self.current_img = img_id
+ self.annotations[img_id] = []
+ elif self.current_tag == "li" and self.current_img:
+ img_id = self.current_img
+ self.annotations[img_id].append(data.strip())
+
+
+class Flickr8k(VisionDataset):
+ """`Flickr8k Entities `_ Dataset.
+
+ Args:
+ root (str or ``pathlib.Path``): Root directory where images are downloaded to.
+ ann_file (string): Path to annotation file.
+ transform (callable, optional): A function/transform that takes in a PIL image or torch.Tensor, depends on the given loader,
+ and returns a transformed version. E.g, ``transforms.RandomCrop``
+ target_transform (callable, optional): A function/transform that takes in the
+ target and transforms it.
+ loader (callable, optional): A function to load an image given its path.
+ By default, it uses PIL as its image loader, but users could also pass in
+ ``torchvision.io.decode_image`` for decoding image data into tensors directly.
+ """
+
+ def __init__(
+ self,
+ root: Union[str, Path],
+ ann_file: str,
+ transform: Optional[Callable] = None,
+ target_transform: Optional[Callable] = None,
+ loader: Callable[[str], Any] = default_loader,
+ ) -> None:
+ super().__init__(root, transform=transform, target_transform=target_transform)
+ self.ann_file = os.path.expanduser(ann_file)
+
+ # Read annotations and store in a dict
+ parser = Flickr8kParser(self.root)
+ with open(self.ann_file) as fh:
+ parser.feed(fh.read())
+ self.annotations = parser.annotations
+
+ self.ids = list(sorted(self.annotations.keys()))
+ self.loader = loader
+
+ def __getitem__(self, index: int) -> tuple[Any, Any]:
+ """
+ Args:
+ index (int): Index
+
+ Returns:
+ tuple: Tuple (image, target). target is a list of captions for the image.
+ """
+ img_id = self.ids[index]
+
+ # Image
+ img = self.loader(img_id)
+ if self.transform is not None:
+ img = self.transform(img)
+
+ # Captions
+ target = self.annotations[img_id]
+ if self.target_transform is not None:
+ target = self.target_transform(target)
+
+ return img, target
+
+ def __len__(self) -> int:
+ return len(self.ids)
+
+
+class Flickr30k(VisionDataset):
+ """`Flickr30k Entities `_ Dataset.
+
+ Args:
+ root (str or ``pathlib.Path``): Root directory where images are downloaded to.
+ ann_file (string): Path to annotation file.
+ transform (callable, optional): A function/transform that takes in a PIL image or torch.Tensor, depends on the given loader,
+ and returns a transformed version. E.g, ``transforms.RandomCrop``
+ target_transform (callable, optional): A function/transform that takes in the
+ target and transforms it.
+ loader (callable, optional): A function to load an image given its path.
+ By default, it uses PIL as its image loader, but users could also pass in
+ ``torchvision.io.decode_image`` for decoding image data into tensors directly.
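+
+    Example (a minimal usage sketch; ``results_20130124.token`` is assumed to be the
+    tab-separated caption file that ships with the Flickr30k annotations)::
+
+        from torchvision.datasets import Flickr30k
+
+        dataset = Flickr30k(root="data/flickr30k-images", ann_file="data/results_20130124.token")
+        image, captions = dataset[0]  # PIL image and a list of caption strings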
+ """
+
+ def __init__(
+ self,
+ root: str,
+ ann_file: str,
+ transform: Optional[Callable] = None,
+ target_transform: Optional[Callable] = None,
+ loader: Callable[[str], Any] = default_loader,
+ ) -> None:
+ super().__init__(root, transform=transform, target_transform=target_transform)
+ self.ann_file = os.path.expanduser(ann_file)
+
+ # Read annotations and store in a dict
+ self.annotations = defaultdict(list)
+ with open(self.ann_file) as fh:
+ for line in fh:
+ img_id, caption = line.strip().split("\t")
+ self.annotations[img_id[:-2]].append(caption)
+
+ self.ids = list(sorted(self.annotations.keys()))
+ self.loader = loader
+
+ def __getitem__(self, index: int) -> tuple[Any, Any]:
+ """
+ Args:
+ index (int): Index
+
+ Returns:
+ tuple: Tuple (image, target). target is a list of captions for the image.
+ """
+ img_id = self.ids[index]
+
+ # Image
+ filename = os.path.join(self.root, img_id)
+ img = self.loader(filename)
+ if self.transform is not None:
+ img = self.transform(img)
+
+ # Captions
+ target = self.annotations[img_id]
+ if self.target_transform is not None:
+ target = self.target_transform(target)
+
+ return img, target
+
+ def __len__(self) -> int:
+ return len(self.ids)
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torchvision/datasets/flowers102.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torchvision/datasets/flowers102.py
new file mode 100644
index 0000000000000000000000000000000000000000..80bca71e9676869c49a9f9f01d8b6e6df7323a23
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torchvision/datasets/flowers102.py
@@ -0,0 +1,225 @@
+from pathlib import Path
+from typing import Any, Callable, Optional, Union
+
+from .folder import default_loader
+
+from .utils import check_integrity, download_and_extract_archive, download_url, verify_str_arg
+from .vision import VisionDataset
+
+
+class Flowers102(VisionDataset):
+    """`Oxford 102 Flower <https://www.robots.ox.ac.uk/~vgg/data/flowers/102/>`_ Dataset.
+
+ .. warning::
+
+        This class needs `scipy <https://scipy.org/>`_ to load target files from `.mat` format.
+
+ Oxford 102 Flower is an image classification dataset consisting of 102 flower categories. The
+ flowers were chosen to be flowers commonly occurring in the United Kingdom. Each class consists of
+ between 40 and 258 images.
+
+ The images have large scale, pose and light variations. In addition, there are categories that
+ have large variations within the category, and several very similar categories.
+
+ Args:
+ root (str or ``pathlib.Path``): Root directory of the dataset.
+ split (string, optional): The dataset split, supports ``"train"`` (default), ``"val"``, or ``"test"``.
+ transform (callable, optional): A function/transform that takes in a PIL image or torch.Tensor, depends on the given loader,
+ and returns a transformed version. E.g, ``transforms.RandomCrop``
+ target_transform (callable, optional): A function/transform that takes in the target and transforms it.
+ download (bool, optional): If true, downloads the dataset from the internet and
+ puts it in root directory. If dataset is already downloaded, it is not
+ downloaded again.
+ loader (callable, optional): A function to load an image given its path.
+ By default, it uses PIL as its image loader, but users could also pass in
+ ``torchvision.io.decode_image`` for decoding image data into tensors directly.
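+
+    Example (a minimal usage sketch; it assumes ``scipy`` is installed and downloading is allowed)::
+
+        from torchvision.datasets import Flowers102
+
+        val_set = Flowers102(root="data", split="val", download=True)
+        image, label = val_set[0]  # PIL image and an int label in [0, 101]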
+ """
+
+ _download_url_prefix = "https://www.robots.ox.ac.uk/~vgg/data/flowers/102/"
+ _file_dict = { # filename, md5
+ "image": ("102flowers.tgz", "52808999861908f626f3c1f4e79d11fa"),
+ "label": ("imagelabels.mat", "e0620be6f572b9609742df49c70aed4d"),
+ "setid": ("setid.mat", "a5357ecc9cb78c4bef273ce3793fc85c"),
+ }
+ _splits_map = {"train": "trnid", "val": "valid", "test": "tstid"}
+
+ def __init__(
+ self,
+ root: Union[str, Path],
+ split: str = "train",
+ transform: Optional[Callable] = None,
+ target_transform: Optional[Callable] = None,
+ download: bool = False,
+ loader: Callable[[Union[str, Path]], Any] = default_loader,
+ ) -> None:
+ super().__init__(root, transform=transform, target_transform=target_transform)
+ self._split = verify_str_arg(split, "split", ("train", "val", "test"))
+ self._base_folder = Path(self.root) / "flowers-102"
+ self._images_folder = self._base_folder / "jpg"
+
+ if download:
+ self.download()
+
+ if not self._check_integrity():
+ raise RuntimeError("Dataset not found or corrupted. You can use download=True to download it")
+
+ from scipy.io import loadmat
+
+ set_ids = loadmat(self._base_folder / self._file_dict["setid"][0], squeeze_me=True)
+ image_ids = set_ids[self._splits_map[self._split]].tolist()
+
+ labels = loadmat(self._base_folder / self._file_dict["label"][0], squeeze_me=True)
+ image_id_to_label = dict(enumerate((labels["labels"] - 1).tolist(), 1))
+
+ self._labels = []
+ self._image_files = []
+ for image_id in image_ids:
+ self._labels.append(image_id_to_label[image_id])
+ self._image_files.append(self._images_folder / f"image_{image_id:05d}.jpg")
+
+ self.loader = loader
+
+ def __len__(self) -> int:
+ return len(self._image_files)
+
+ def __getitem__(self, idx: int) -> tuple[Any, Any]:
+ image_file, label = self._image_files[idx], self._labels[idx]
+ image = self.loader(image_file)
+
+ if self.transform:
+ image = self.transform(image)
+
+ if self.target_transform:
+ label = self.target_transform(label)
+
+ return image, label
+
+ def extra_repr(self) -> str:
+ return f"split={self._split}"
+
+ def _check_integrity(self):
+ if not (self._images_folder.exists() and self._images_folder.is_dir()):
+ return False
+
+ for id in ["label", "setid"]:
+ filename, md5 = self._file_dict[id]
+ if not check_integrity(str(self._base_folder / filename), md5):
+ return False
+ return True
+
+ def download(self):
+ if self._check_integrity():
+ return
+ download_and_extract_archive(
+ f"{self._download_url_prefix}{self._file_dict['image'][0]}",
+ str(self._base_folder),
+ md5=self._file_dict["image"][1],
+ )
+ for id in ["label", "setid"]:
+ filename, md5 = self._file_dict[id]
+ download_url(self._download_url_prefix + filename, str(self._base_folder), md5=md5)
+
+ classes = [
+ "pink primrose",
+ "hard-leaved pocket orchid",
+ "canterbury bells",
+ "sweet pea",
+ "english marigold",
+ "tiger lily",
+ "moon orchid",
+ "bird of paradise",
+ "monkshood",
+ "globe thistle",
+ "snapdragon",
+ "colt's foot",
+ "king protea",
+ "spear thistle",
+ "yellow iris",
+ "globe-flower",
+ "purple coneflower",
+ "peruvian lily",
+ "balloon flower",
+ "giant white arum lily",
+ "fire lily",
+ "pincushion flower",
+ "fritillary",
+ "red ginger",
+ "grape hyacinth",
+ "corn poppy",
+ "prince of wales feathers",
+ "stemless gentian",
+ "artichoke",
+ "sweet william",
+ "carnation",
+ "garden phlox",
+ "love in the mist",
+ "mexican aster",
+ "alpine sea holly",
+ "ruby-lipped cattleya",
+ "cape flower",
+ "great masterwort",
+ "siam tulip",
+ "lenten rose",
+ "barbeton daisy",
+ "daffodil",
+ "sword lily",
+ "poinsettia",
+ "bolero deep blue",
+ "wallflower",
+ "marigold",
+ "buttercup",
+ "oxeye daisy",
+ "common dandelion",
+ "petunia",
+ "wild pansy",
+ "primula",
+ "sunflower",
+ "pelargonium",
+ "bishop of llandaff",
+ "gaura",
+ "geranium",
+ "orange dahlia",
+ "pink-yellow dahlia?",
+ "cautleya spicata",
+ "japanese anemone",
+ "black-eyed susan",
+ "silverbush",
+ "californian poppy",
+ "osteospermum",
+ "spring crocus",
+ "bearded iris",
+ "windflower",
+ "tree poppy",
+ "gazania",
+ "azalea",
+ "water lily",
+ "rose",
+ "thorn apple",
+ "morning glory",
+ "passion flower",
+ "lotus",
+ "toad lily",
+ "anthurium",
+ "frangipani",
+ "clematis",
+ "hibiscus",
+ "columbine",
+ "desert-rose",
+ "tree mallow",
+ "magnolia",
+ "cyclamen",
+ "watercress",
+ "canna lily",
+ "hippeastrum",
+ "bee balm",
+ "ball moss",
+ "foxglove",
+ "bougainvillea",
+ "camellia",
+ "mallow",
+ "mexican petunia",
+ "bromelia",
+ "blanket flower",
+ "trumpet creeper",
+ "blackberry lily",
+ ]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torchvision/datasets/folder.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torchvision/datasets/folder.py
new file mode 100644
index 0000000000000000000000000000000000000000..387439c0433e8fa9f16163b1ad9629591639d09e
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torchvision/datasets/folder.py
@@ -0,0 +1,337 @@
+import os
+import os.path
+from pathlib import Path
+from typing import Any, Callable, cast, Optional, Union
+
+from PIL import Image
+
+from .vision import VisionDataset
+
+
+def has_file_allowed_extension(filename: str, extensions: Union[str, tuple[str, ...]]) -> bool:
+ """Checks if a file is an allowed extension.
+
+ Args:
+ filename (string): path to a file
+ extensions (tuple of strings): extensions to consider (lowercase)
+
+ Returns:
+ bool: True if the filename ends with one of given extensions
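+
+    Example (illustrative values)::
+
+        >>> has_file_allowed_extension("dog.PNG", (".jpg", ".png"))
+        True
+        >>> has_file_allowed_extension("notes.txt", (".jpg", ".png"))
+        False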
+ """
+ return filename.lower().endswith(extensions if isinstance(extensions, str) else tuple(extensions))
+
+
+def is_image_file(filename: str) -> bool:
+ """Checks if a file is an allowed image extension.
+
+ Args:
+ filename (string): path to a file
+
+ Returns:
+ bool: True if the filename ends with a known image extension
+ """
+ return has_file_allowed_extension(filename, IMG_EXTENSIONS)
+
+
+def find_classes(directory: Union[str, Path]) -> tuple[list[str], dict[str, int]]:
+ """Finds the class folders in a dataset.
+
+ See :class:`DatasetFolder` for details.
+ """
+ classes = sorted(entry.name for entry in os.scandir(directory) if entry.is_dir())
+ if not classes:
+ raise FileNotFoundError(f"Couldn't find any class folder in {directory}.")
+
+ class_to_idx = {cls_name: i for i, cls_name in enumerate(classes)}
+ return classes, class_to_idx
+
+
+def make_dataset(
+ directory: Union[str, Path],
+ class_to_idx: Optional[dict[str, int]] = None,
+ extensions: Optional[Union[str, tuple[str, ...]]] = None,
+ is_valid_file: Optional[Callable[[str], bool]] = None,
+ allow_empty: bool = False,
+) -> list[tuple[str, int]]:
+ """Generates a list of samples of a form (path_to_sample, class).
+
+ See :class:`DatasetFolder` for details.
+
+    Note: The class_to_idx parameter is optional here; if it is omitted, the logic of the
+    ``find_classes`` function is used by default.
+ """
+ directory = os.path.expanduser(directory)
+
+ if class_to_idx is None:
+ _, class_to_idx = find_classes(directory)
+ elif not class_to_idx:
+ raise ValueError("'class_to_index' must have at least one entry to collect any samples.")
+
+ both_none = extensions is None and is_valid_file is None
+ both_something = extensions is not None and is_valid_file is not None
+ if both_none or both_something:
+ raise ValueError("Both extensions and is_valid_file cannot be None or not None at the same time")
+
+ if extensions is not None:
+
+ def is_valid_file(x: str) -> bool:
+ return has_file_allowed_extension(x, extensions) # type: ignore[arg-type]
+
+ is_valid_file = cast(Callable[[str], bool], is_valid_file)
+
+ instances = []
+ available_classes = set()
+ for target_class in sorted(class_to_idx.keys()):
+ class_index = class_to_idx[target_class]
+ target_dir = os.path.join(directory, target_class)
+ if not os.path.isdir(target_dir):
+ continue
+ for root, _, fnames in sorted(os.walk(target_dir, followlinks=True)):
+ for fname in sorted(fnames):
+ path = os.path.join(root, fname)
+ if is_valid_file(path):
+ item = path, class_index
+ instances.append(item)
+
+ if target_class not in available_classes:
+ available_classes.add(target_class)
+
+ empty_classes = set(class_to_idx.keys()) - available_classes
+ if empty_classes and not allow_empty:
+ msg = f"Found no valid file for the classes {', '.join(sorted(empty_classes))}. "
+ if extensions is not None:
+ msg += f"Supported extensions are: {extensions if isinstance(extensions, str) else ', '.join(extensions)}"
+ raise FileNotFoundError(msg)
+
+ return instances
+
+
+class DatasetFolder(VisionDataset):
+ """A generic data loader.
+
+ This default directory structure can be customized by overriding the
+ :meth:`find_classes` method.
+
+ Args:
+ root (str or ``pathlib.Path``): Root directory path.
+ loader (callable): A function to load a sample given its path.
+ extensions (tuple[string]): A list of allowed extensions.
+ both extensions and is_valid_file should not be passed.
+ transform (callable, optional): A function/transform that takes in
+ a sample and returns a transformed version.
+ E.g, ``transforms.RandomCrop`` for images.
+ target_transform (callable, optional): A function/transform that takes
+ in the target and transforms it.
+        is_valid_file (callable, optional): A function that takes the path of a file
+            and checks if the file is a valid file (used to check for corrupt files).
+            Both extensions and is_valid_file should not be passed.
+        allow_empty (bool, optional): If True, empty folders are considered to be valid classes.
+ An error is raised on empty folders if False (default).
+
+ Attributes:
+ classes (list): List of the class names sorted alphabetically.
+ class_to_idx (dict): Dict with items (class_name, class_index).
+ samples (list): List of (sample path, class_index) tuples
+ targets (list): The class_index value for each image in the dataset
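+
+    Example (a minimal usage sketch; ``data/`` with one sub-folder per class is an assumed layout)::
+
+        from torchvision.datasets import DatasetFolder
+        from torchvision.datasets.folder import IMG_EXTENSIONS, pil_loader
+
+        dataset = DatasetFolder(root="data", loader=pil_loader, extensions=IMG_EXTENSIONS)
+        sample, target = dataset[0]  # PIL image and the class index of its folder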
+ """
+
+ def __init__(
+ self,
+ root: Union[str, Path],
+ loader: Callable[[str], Any],
+ extensions: Optional[tuple[str, ...]] = None,
+ transform: Optional[Callable] = None,
+ target_transform: Optional[Callable] = None,
+ is_valid_file: Optional[Callable[[str], bool]] = None,
+ allow_empty: bool = False,
+ ) -> None:
+ super().__init__(root, transform=transform, target_transform=target_transform)
+ classes, class_to_idx = self.find_classes(self.root)
+ samples = self.make_dataset(
+ self.root,
+ class_to_idx=class_to_idx,
+ extensions=extensions,
+ is_valid_file=is_valid_file,
+ allow_empty=allow_empty,
+ )
+
+ self.loader = loader
+ self.extensions = extensions
+
+ self.classes = classes
+ self.class_to_idx = class_to_idx
+ self.samples = samples
+ self.targets = [s[1] for s in samples]
+
+ @staticmethod
+ def make_dataset(
+ directory: Union[str, Path],
+ class_to_idx: dict[str, int],
+ extensions: Optional[tuple[str, ...]] = None,
+ is_valid_file: Optional[Callable[[str], bool]] = None,
+ allow_empty: bool = False,
+ ) -> list[tuple[str, int]]:
+ """Generates a list of samples of a form (path_to_sample, class).
+
+ This can be overridden to e.g. read files from a compressed zip file instead of from the disk.
+
+ Args:
+ directory (str): root dataset directory, corresponding to ``self.root``.
+ class_to_idx (Dict[str, int]): Dictionary mapping class name to class index.
+ extensions (optional): A list of allowed extensions.
+ Either extensions or is_valid_file should be passed. Defaults to None.
+            is_valid_file (optional): A function that takes the path of a file
+                and checks if the file is a valid file
+                (used to check for corrupt files). Both extensions and
+                is_valid_file should not be passed. Defaults to None.
+            allow_empty (bool, optional): If True, empty folders are considered to be valid classes.
+ An error is raised on empty folders if False (default).
+
+ Raises:
+ ValueError: In case ``class_to_idx`` is empty.
+ ValueError: In case ``extensions`` and ``is_valid_file`` are None or both are not None.
+ FileNotFoundError: In case no valid file was found for any class.
+
+ Returns:
+ List[Tuple[str, int]]: samples of a form (path_to_sample, class)
+ """
+ if class_to_idx is None:
+ # prevent potential bug since make_dataset() would use the class_to_idx logic of the
+ # find_classes() function, instead of using that of the find_classes() method, which
+ # is potentially overridden and thus could have a different logic.
+ raise ValueError("The class_to_idx parameter cannot be None.")
+ return make_dataset(
+ directory, class_to_idx, extensions=extensions, is_valid_file=is_valid_file, allow_empty=allow_empty
+ )
+
+ def find_classes(self, directory: Union[str, Path]) -> tuple[list[str], dict[str, int]]:
+ """Find the class folders in a dataset structured as follows::
+
+ directory/
+ ├── class_x
+ │ ├── xxx.ext
+ │ ├── xxy.ext
+ │ └── ...
+ │ └── xxz.ext
+ └── class_y
+ ├── 123.ext
+ ├── nsdf3.ext
+ └── ...
+ └── asd932_.ext
+
+ This method can be overridden to only consider
+ a subset of classes, or to adapt to a different dataset directory structure.
+
+ Args:
+ directory(str): Root directory path, corresponding to ``self.root``
+
+ Raises:
+            FileNotFoundError: If ``directory`` has no class folders.
+
+ Returns:
+ (Tuple[List[str], Dict[str, int]]): List of all classes and dictionary mapping each class to an index.
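+
+        Example (illustrative; assumes ``directory`` contains only the ``class_x`` and
+        ``class_y`` sub-folders shown above)::
+
+            >>> find_classes(directory)
+            (['class_x', 'class_y'], {'class_x': 0, 'class_y': 1})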
+ """
+ return find_classes(directory)
+
+ def __getitem__(self, index: int) -> tuple[Any, Any]:
+ """
+ Args:
+ index (int): Index
+
+ Returns:
+ tuple: (sample, target) where target is class_index of the target class.
+ """
+ path, target = self.samples[index]
+ sample = self.loader(path)
+ if self.transform is not None:
+ sample = self.transform(sample)
+ if self.target_transform is not None:
+ target = self.target_transform(target)
+
+ return sample, target
+
+ def __len__(self) -> int:
+ return len(self.samples)
+
+
+IMG_EXTENSIONS = (".jpg", ".jpeg", ".png", ".ppm", ".bmp", ".pgm", ".tif", ".tiff", ".webp")
+
+
+def pil_loader(path: Union[str, Path]) -> Image.Image:
+ # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835)
+ with open(path, "rb") as f:
+ img = Image.open(f)
+ return img.convert("RGB")
+
+
+# TODO: specify the return type
+def accimage_loader(path: Union[str, Path]) -> Any:
+ import accimage
+
+ try:
+ return accimage.Image(path)
+ except OSError:
+ # Potentially a decoding problem, fall back to PIL.Image
+ return pil_loader(path)
+
+
+def default_loader(path: Union[str, Path]) -> Any:
+ from torchvision import get_image_backend
+
+ if get_image_backend() == "accimage":
+ return accimage_loader(path)
+ else:
+ return pil_loader(path)
+
+
+class ImageFolder(DatasetFolder):
+ """A generic data loader where the images are arranged in this way by default: ::
+
+ root/dog/xxx.png
+ root/dog/xxy.png
+ root/dog/[...]/xxz.png
+
+ root/cat/123.png
+ root/cat/nsdf3.png
+ root/cat/[...]/asd932_.png
+
+ This class inherits from :class:`~torchvision.datasets.DatasetFolder` so
+ the same methods can be overridden to customize the dataset.
+
+ Args:
+ root (str or ``pathlib.Path``): Root directory path.
+ transform (callable, optional): A function/transform that takes in a PIL image or torch.Tensor, depends on the given loader,
+ and returns a transformed version. E.g, ``transforms.RandomCrop``
+ target_transform (callable, optional): A function/transform that takes in the
+ target and transforms it.
+ loader (callable, optional): A function to load an image given its path.
+        is_valid_file (callable, optional): A function that takes the path of an image file
+            and checks if the file is a valid file (used to check for corrupt files).
+        allow_empty (bool, optional): If True, empty folders are considered to be valid classes.
+ An error is raised on empty folders if False (default).
+
+ Attributes:
+ classes (list): List of the class names sorted alphabetically.
+ class_to_idx (dict): Dict with items (class_name, class_index).
+ imgs (list): List of (image path, class_index) tuples
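+
+    Example (a minimal usage sketch; it assumes ``root/`` is laid out as shown above)::
+
+        from torchvision import transforms
+        from torchvision.datasets import ImageFolder
+
+        dataset = ImageFolder(root="root", transform=transforms.ToTensor())
+        img, target = dataset[0]        # CxHxW float tensor and the class index
+        print(dataset.classes[target])  # e.g. "cat"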
+ """
+
+ def __init__(
+ self,
+ root: Union[str, Path],
+ transform: Optional[Callable] = None,
+ target_transform: Optional[Callable] = None,
+ loader: Callable[[str], Any] = default_loader,
+ is_valid_file: Optional[Callable[[str], bool]] = None,
+ allow_empty: bool = False,
+ ):
+ super().__init__(
+ root,
+ loader,
+ IMG_EXTENSIONS if is_valid_file is None else None,
+ transform=transform,
+ target_transform=target_transform,
+ is_valid_file=is_valid_file,
+ allow_empty=allow_empty,
+ )
+ self.imgs = self.samples
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torchvision/datasets/food101.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torchvision/datasets/food101.py
new file mode 100644
index 0000000000000000000000000000000000000000..fee23680b05255029c1e3b433e7890df754f0fe0
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torchvision/datasets/food101.py
@@ -0,0 +1,98 @@
+import json
+from pathlib import Path
+from typing import Any, Callable, Optional, Union
+
+from .folder import default_loader
+
+from .utils import download_and_extract_archive, verify_str_arg
+from .vision import VisionDataset
+
+
+class Food101(VisionDataset):
+    """`The Food-101 Data Set <https://data.vision.ee.ethz.ch/cvl/datasets_extra/food-101/>`_.
+
+ The Food-101 is a challenging data set of 101 food categories with 101,000 images.
+ For each class, 250 manually reviewed test images are provided as well as 750 training images.
+ On purpose, the training images were not cleaned, and thus still contain some amount of noise.
+ This comes mostly in the form of intense colors and sometimes wrong labels. All images were
+ rescaled to have a maximum side length of 512 pixels.
+
+
+ Args:
+ root (str or ``pathlib.Path``): Root directory of the dataset.
+ split (string, optional): The dataset split, supports ``"train"`` (default) and ``"test"``.
+ transform (callable, optional): A function/transform that takes in a PIL image or torch.Tensor, depends on the given loader,
+ and returns a transformed version. E.g, ``transforms.RandomCrop``
+ target_transform (callable, optional): A function/transform that takes in the target and transforms it.
+ download (bool, optional): If True, downloads the dataset from the internet and
+ puts it in root directory. If dataset is already downloaded, it is not
+ downloaded again. Default is False.
+ loader (callable, optional): A function to load an image given its path.
+ By default, it uses PIL as its image loader, but users could also pass in
+ ``torchvision.io.decode_image`` for decoding image data into tensors directly.
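+
+    Example (a minimal usage sketch; the ~5 GB archive is downloaded on first use)::
+
+        from torchvision.datasets import Food101
+
+        test_set = Food101(root="data", split="test", download=True)
+        image, label = test_set[0]
+        print(test_set.classes[label])  # e.g. "apple_pie"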
+ """
+
+ _URL = "http://data.vision.ee.ethz.ch/cvl/food-101.tar.gz"
+ _MD5 = "85eeb15f3717b99a5da872d97d918f87"
+
+ def __init__(
+ self,
+ root: Union[str, Path],
+ split: str = "train",
+ transform: Optional[Callable] = None,
+ target_transform: Optional[Callable] = None,
+ download: bool = False,
+ loader: Callable[[Union[str, Path]], Any] = default_loader,
+ ) -> None:
+ super().__init__(root, transform=transform, target_transform=target_transform)
+ self._split = verify_str_arg(split, "split", ("train", "test"))
+ self._base_folder = Path(self.root) / "food-101"
+ self._meta_folder = self._base_folder / "meta"
+ self._images_folder = self._base_folder / "images"
+
+ if download:
+ self._download()
+
+ if not self._check_exists():
+ raise RuntimeError("Dataset not found. You can use download=True to download it")
+
+ self._labels = []
+ self._image_files = []
+ with open(self._meta_folder / f"{split}.json") as f:
+ metadata = json.loads(f.read())
+
+ self.classes = sorted(metadata.keys())
+ self.class_to_idx = dict(zip(self.classes, range(len(self.classes))))
+
+ for class_label, im_rel_paths in metadata.items():
+ self._labels += [self.class_to_idx[class_label]] * len(im_rel_paths)
+ self._image_files += [
+ self._images_folder.joinpath(*f"{im_rel_path}.jpg".split("/")) for im_rel_path in im_rel_paths
+ ]
+ self.loader = loader
+
+ def __len__(self) -> int:
+ return len(self._image_files)
+
+ def __getitem__(self, idx: int) -> tuple[Any, Any]:
+ image_file, label = self._image_files[idx], self._labels[idx]
+ image = self.loader(image_file)
+
+ if self.transform:
+ image = self.transform(image)
+
+ if self.target_transform:
+ label = self.target_transform(label)
+
+ return image, label
+
+ def extra_repr(self) -> str:
+ return f"split={self._split}"
+
+ def _check_exists(self) -> bool:
+ return all(folder.exists() and folder.is_dir() for folder in (self._meta_folder, self._images_folder))
+
+ def _download(self) -> None:
+ if self._check_exists():
+ return
+ download_and_extract_archive(self._URL, download_root=self.root, md5=self._MD5)
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torchvision/datasets/gtsrb.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torchvision/datasets/gtsrb.py
new file mode 100644
index 0000000000000000000000000000000000000000..e6b60116c401dd7819f527f095990dea2193b8ec
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torchvision/datasets/gtsrb.py
@@ -0,0 +1,103 @@
+import csv
+import pathlib
+from typing import Any, Callable, Optional, Union
+
+import PIL
+
+from .folder import make_dataset
+from .utils import download_and_extract_archive, verify_str_arg
+from .vision import VisionDataset
+
+
+class GTSRB(VisionDataset):
+    """`German Traffic Sign Recognition Benchmark (GTSRB) <https://benchmark.ini.rub.de/>`_ Dataset.
+
+ Args:
+ root (str or ``pathlib.Path``): Root directory of the dataset.
+ split (string, optional): The dataset split, supports ``"train"`` (default), or ``"test"``.
+ transform (callable, optional): A function/transform that takes in a PIL image and returns a transformed
+ version. E.g, ``transforms.RandomCrop``.
+ target_transform (callable, optional): A function/transform that takes in the target and transforms it.
+ download (bool, optional): If True, downloads the dataset from the internet and
+ puts it in root directory. If dataset is already downloaded, it is not
+ downloaded again.
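+
+    Example (a minimal usage sketch; downloading the test images and ground truth is assumed to be acceptable)::
+
+        from torchvision.datasets import GTSRB
+
+        test_set = GTSRB(root="data", split="test", download=True)
+        image, class_id = test_set[0]  # RGB PIL image and an int class id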
+ """
+
+ def __init__(
+ self,
+ root: Union[str, pathlib.Path],
+ split: str = "train",
+ transform: Optional[Callable] = None,
+ target_transform: Optional[Callable] = None,
+ download: bool = False,
+ ) -> None:
+
+ super().__init__(root, transform=transform, target_transform=target_transform)
+
+ self._split = verify_str_arg(split, "split", ("train", "test"))
+ self._base_folder = pathlib.Path(root) / "gtsrb"
+ self._target_folder = (
+ self._base_folder / "GTSRB" / ("Training" if self._split == "train" else "Final_Test/Images")
+ )
+
+ if download:
+ self.download()
+
+ if not self._check_exists():
+ raise RuntimeError("Dataset not found. You can use download=True to download it")
+
+ if self._split == "train":
+ samples = make_dataset(str(self._target_folder), extensions=(".ppm",))
+ else:
+ with open(self._base_folder / "GT-final_test.csv") as csv_file:
+ samples = [
+ (str(self._target_folder / row["Filename"]), int(row["ClassId"]))
+ for row in csv.DictReader(csv_file, delimiter=";", skipinitialspace=True)
+ ]
+
+ self._samples = samples
+ self.transform = transform
+ self.target_transform = target_transform
+
+ def __len__(self) -> int:
+ return len(self._samples)
+
+ def __getitem__(self, index: int) -> tuple[Any, Any]:
+
+ path, target = self._samples[index]
+ sample = PIL.Image.open(path).convert("RGB")
+
+ if self.transform is not None:
+ sample = self.transform(sample)
+
+ if self.target_transform is not None:
+ target = self.target_transform(target)
+
+ return sample, target
+
+ def _check_exists(self) -> bool:
+ return self._target_folder.is_dir()
+
+ def download(self) -> None:
+ if self._check_exists():
+ return
+
+ base_url = "https://sid.erda.dk/public/archives/daaeac0d7ce1152aea9b61d9f1e19370/"
+
+ if self._split == "train":
+ download_and_extract_archive(
+ f"{base_url}GTSRB-Training_fixed.zip",
+ download_root=str(self._base_folder),
+ md5="513f3c79a4c5141765e10e952eaa2478",
+ )
+ else:
+ download_and_extract_archive(
+ f"{base_url}GTSRB_Final_Test_Images.zip",
+ download_root=str(self._base_folder),
+ md5="c7e4e6327067d32654124b0fe9e82185",
+ )
+ download_and_extract_archive(
+ f"{base_url}GTSRB_Final_Test_GT.zip",
+ download_root=str(self._base_folder),
+ md5="fe31e9c9270bbcd7b84b7f21a9d9d9e5",
+ )
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torchvision/datasets/hmdb51.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torchvision/datasets/hmdb51.py
new file mode 100644
index 0000000000000000000000000000000000000000..b9b84771cac21e41cc27b2e18f18922ec7e74952
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torchvision/datasets/hmdb51.py
@@ -0,0 +1,152 @@
+import glob
+import os
+from pathlib import Path
+from typing import Any, Callable, Optional, Union
+
+from torch import Tensor
+
+from .folder import find_classes, make_dataset
+from .video_utils import VideoClips
+from .vision import VisionDataset
+
+
+class HMDB51(VisionDataset):
+ """
+    `HMDB51 <https://serre-lab.clps.brown.edu/resource/hmdb-a-large-human-motion-database/>`_
+ dataset.
+
+ HMDB51 is an action recognition video dataset.
+    This dataset considers every video as a collection of video clips of fixed size, specified
+ by ``frames_per_clip``, where the step in frames between each clip is given by
+ ``step_between_clips``.
+
+ To give an example, for 2 videos with 10 and 15 frames respectively, if ``frames_per_clip=5``
+ and ``step_between_clips=5``, the dataset size will be (2 + 3) = 5, where the first two
+ elements will come from video 1, and the next three elements from video 2.
+ Note that we drop clips which do not have exactly ``frames_per_clip`` elements, so not all
+ frames in a video might be present.
+
+ Internally, it uses a VideoClips object to handle clip creation.
+
+ Args:
+ root (str or ``pathlib.Path``): Root directory of the HMDB51 Dataset.
+ annotation_path (str): Path to the folder containing the split files.
+ frames_per_clip (int): Number of frames in a clip.
+ step_between_clips (int): Number of frames between each clip.
+ fold (int, optional): Which fold to use. Should be between 1 and 3.
+ train (bool, optional): If ``True``, creates a dataset from the train split,
+ otherwise from the ``test`` split.
+ transform (callable, optional): A function/transform that takes in a TxHxWxC video
+ and returns a transformed version.
+ output_format (str, optional): The format of the output video tensors (before transforms).
+ Can be either "THWC" (default) or "TCHW".
+
+ Returns:
+ tuple: A 3-tuple with the following entries:
+
+ - video (Tensor[T, H, W, C] or Tensor[T, C, H, W]): The `T` video frames
+ - audio(Tensor[K, L]): the audio frames, where `K` is the number of channels
+ and `L` is the number of points
+ - label (int): class of the video clip
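+
+    Example (a minimal usage sketch; ``data/hmdb51`` and ``data/splits`` are assumed to hold
+    the extracted videos and the official split files, which must be downloaded separately)::
+
+        from torchvision.datasets import HMDB51
+
+        train_set = HMDB51(
+            root="data/hmdb51",
+            annotation_path="data/splits",
+            frames_per_clip=16,
+            step_between_clips=8,
+            fold=1,
+            train=True,
+        )
+        video, audio, label = train_set[0]  # video is a Tensor[T, H, W, C] by default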
+ """
+
+ data_url = "https://serre-lab.clps.brown.edu/wp-content/uploads/2013/10/hmdb51_org.rar"
+ splits = {
+ "url": "https://serre-lab.clps.brown.edu/wp-content/uploads/2013/10/test_train_splits.rar",
+ "md5": "15e67781e70dcfbdce2d7dbb9b3344b5",
+ }
+ TRAIN_TAG = 1
+ TEST_TAG = 2
+
+ def __init__(
+ self,
+ root: Union[str, Path],
+ annotation_path: str,
+ frames_per_clip: int,
+ step_between_clips: int = 1,
+ frame_rate: Optional[int] = None,
+ fold: int = 1,
+ train: bool = True,
+ transform: Optional[Callable] = None,
+ _precomputed_metadata: Optional[dict[str, Any]] = None,
+ num_workers: int = 1,
+ _video_width: int = 0,
+ _video_height: int = 0,
+ _video_min_dimension: int = 0,
+ _audio_samples: int = 0,
+ output_format: str = "THWC",
+ ) -> None:
+ super().__init__(root)
+ if fold not in (1, 2, 3):
+ raise ValueError(f"fold should be between 1 and 3, got {fold}")
+
+ extensions = ("avi",)
+ self.classes, class_to_idx = find_classes(self.root)
+ self.samples = make_dataset(
+ self.root,
+ class_to_idx,
+ extensions,
+ )
+
+ video_paths = [path for (path, _) in self.samples]
+ video_clips = VideoClips(
+ video_paths,
+ frames_per_clip,
+ step_between_clips,
+ frame_rate,
+ _precomputed_metadata,
+ num_workers=num_workers,
+ _video_width=_video_width,
+ _video_height=_video_height,
+ _video_min_dimension=_video_min_dimension,
+ _audio_samples=_audio_samples,
+ output_format=output_format,
+ )
+ # we bookkeep the full version of video clips because we want to be able
+ # to return the metadata of full version rather than the subset version of
+ # video clips
+ self.full_video_clips = video_clips
+ self.fold = fold
+ self.train = train
+ self.indices = self._select_fold(video_paths, annotation_path, fold, train)
+ self.video_clips = video_clips.subset(self.indices)
+ self.transform = transform
+
+ @property
+ def metadata(self) -> dict[str, Any]:
+ return self.full_video_clips.metadata
+
+ def _select_fold(self, video_list: list[str], annotations_dir: str, fold: int, train: bool) -> list[int]:
+ target_tag = self.TRAIN_TAG if train else self.TEST_TAG
+ split_pattern_name = f"*test_split{fold}.txt"
+ split_pattern_path = os.path.join(annotations_dir, split_pattern_name)
+ annotation_paths = glob.glob(split_pattern_path)
+ selected_files = set()
+ for filepath in annotation_paths:
+ with open(filepath) as fid:
+ lines = fid.readlines()
+ for line in lines:
+ video_filename, tag_string = line.split()
+ tag = int(tag_string)
+ if tag == target_tag:
+ selected_files.add(video_filename)
+
+ indices = []
+ for video_index, video_path in enumerate(video_list):
+ if os.path.basename(video_path) in selected_files:
+ indices.append(video_index)
+
+ return indices
+
+ def __len__(self) -> int:
+ return self.video_clips.num_clips()
+
+ def __getitem__(self, idx: int) -> tuple[Tensor, Tensor, int]:
+ video, audio, _, video_idx = self.video_clips.get_clip(idx)
+ sample_index = self.indices[video_idx]
+ _, class_index = self.samples[sample_index]
+
+ if self.transform is not None:
+ video = self.transform(video)
+
+ return video, audio, class_index
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torchvision/datasets/imagenet.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torchvision/datasets/imagenet.py
new file mode 100644
index 0000000000000000000000000000000000000000..1808dc4f85b0bb77ac2fa469f17b5f903621f608
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torchvision/datasets/imagenet.py
@@ -0,0 +1,222 @@
+import os
+import shutil
+import tempfile
+from collections.abc import Iterator
+from contextlib import contextmanager
+from pathlib import Path
+from typing import Any, Optional, Union
+
+import torch
+
+from .folder import ImageFolder
+from .utils import check_integrity, extract_archive, verify_str_arg
+
+ARCHIVE_META = {
+ "train": ("ILSVRC2012_img_train.tar", "1d675b47d978889d74fa0da5fadfb00e"),
+ "val": ("ILSVRC2012_img_val.tar", "29b22e2961454d5413ddabcf34fc5622"),
+ "devkit": ("ILSVRC2012_devkit_t12.tar.gz", "fa75699e90414af021442c21a62c3abf"),
+}
+
+META_FILE = "meta.bin"
+
+
+class ImageNet(ImageFolder):
+    """`ImageNet <https://image-net.org/>`_ 2012 Classification Dataset.
+
+ .. note::
+ Before using this class, it is required to download ImageNet 2012 dataset from
+ `here `_ and
+ place the files ``ILSVRC2012_devkit_t12.tar.gz`` and ``ILSVRC2012_img_train.tar``
+ or ``ILSVRC2012_img_val.tar`` based on ``split`` in the root directory.
+
+ Args:
+ root (str or ``pathlib.Path``): Root directory of the ImageNet Dataset.
+ split (string, optional): The dataset split, supports ``train``, or ``val``.
+ transform (callable, optional): A function/transform that takes in a PIL image or torch.Tensor, depends on the given loader,
+ and returns a transformed version. E.g, ``transforms.RandomCrop``
+ target_transform (callable, optional): A function/transform that takes in the
+ target and transforms it.
+ loader (callable, optional): A function to load an image given its path.
+ By default, it uses PIL as its image loader, but users could also pass in
+ ``torchvision.io.decode_image`` for decoding image data into tensors directly.
+
+ Attributes:
+ classes (list): List of the class name tuples.
+ class_to_idx (dict): Dict with items (class_name, class_index).
+ wnids (list): List of the WordNet IDs.
+ wnid_to_idx (dict): Dict with items (wordnet_id, class_index).
+ imgs (list): List of (image path, class_index) tuples
+ targets (list): The class_index value for each image in the dataset
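+
+    Example (a minimal usage sketch; it assumes ``ILSVRC2012_img_val.tar`` and
+    ``ILSVRC2012_devkit_t12.tar.gz`` were placed in ``data/imagenet`` beforehand)::
+
+        from torchvision.datasets import ImageNet
+
+        val_set = ImageNet(root="data/imagenet", split="val")
+        image, label = val_set[0]
+        print(val_set.classes[label])  # tuple of names, e.g. ("tench", "Tinca tinca")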
+ """
+
+ def __init__(self, root: Union[str, Path], split: str = "train", **kwargs: Any) -> None:
+ root = self.root = os.path.expanduser(root)
+ self.split = verify_str_arg(split, "split", ("train", "val"))
+
+ self.parse_archives()
+ wnid_to_classes = load_meta_file(self.root)[0]
+
+ super().__init__(self.split_folder, **kwargs)
+ self.root = root
+
+ self.wnids = self.classes
+ self.wnid_to_idx = self.class_to_idx
+ self.classes = [wnid_to_classes[wnid] for wnid in self.wnids]
+ self.class_to_idx = {cls: idx for idx, clss in enumerate(self.classes) for cls in clss}
+
+ def parse_archives(self) -> None:
+ if not check_integrity(os.path.join(self.root, META_FILE)):
+ parse_devkit_archive(self.root)
+
+ if not os.path.isdir(self.split_folder):
+ if self.split == "train":
+ parse_train_archive(self.root)
+ elif self.split == "val":
+ parse_val_archive(self.root)
+
+ @property
+ def split_folder(self) -> str:
+ return os.path.join(self.root, self.split)
+
+ def extra_repr(self) -> str:
+ return "Split: {split}".format(**self.__dict__)
+
+
+def load_meta_file(root: Union[str, Path], file: Optional[str] = None) -> tuple[dict[str, str], list[str]]:
+ if file is None:
+ file = META_FILE
+ file = os.path.join(root, file)
+
+ if check_integrity(file):
+ return torch.load(file, weights_only=True)
+ else:
+ msg = (
+ "The meta file {} is not present in the root directory or is corrupted. "
+ "This file is automatically created by the ImageNet dataset."
+ )
+ raise RuntimeError(msg.format(file, root))
+
+
+def _verify_archive(root: Union[str, Path], file: str, md5: str) -> None:
+ if not check_integrity(os.path.join(root, file), md5):
+ msg = (
+ "The archive {} is not present in the root directory or is corrupted. "
+ "You need to download it externally and place it in {}."
+ )
+ raise RuntimeError(msg.format(file, root))
+
+
+def parse_devkit_archive(root: Union[str, Path], file: Optional[str] = None) -> None:
+ """Parse the devkit archive of the ImageNet2012 classification dataset and save
+ the meta information in a binary file.
+
+ Args:
+ root (str or ``pathlib.Path``): Root directory containing the devkit archive
+ file (str, optional): Name of devkit archive. Defaults to
+ 'ILSVRC2012_devkit_t12.tar.gz'
+ """
+ import scipy.io as sio
+
+ def parse_meta_mat(devkit_root: str) -> tuple[dict[int, str], dict[str, tuple[str, ...]]]:
+ metafile = os.path.join(devkit_root, "data", "meta.mat")
+ meta = sio.loadmat(metafile, squeeze_me=True)["synsets"]
+ nums_children = list(zip(*meta))[4]
+ meta = [meta[idx] for idx, num_children in enumerate(nums_children) if num_children == 0]
+ idcs, wnids, classes = list(zip(*meta))[:3]
+ classes = [tuple(clss.split(", ")) for clss in classes]
+ idx_to_wnid = {idx: wnid for idx, wnid in zip(idcs, wnids)}
+ wnid_to_classes = {wnid: clss for wnid, clss in zip(wnids, classes)}
+ return idx_to_wnid, wnid_to_classes
+
+ def parse_val_groundtruth_txt(devkit_root: str) -> list[int]:
+ file = os.path.join(devkit_root, "data", "ILSVRC2012_validation_ground_truth.txt")
+ with open(file) as txtfh:
+ val_idcs = txtfh.readlines()
+ return [int(val_idx) for val_idx in val_idcs]
+
+ @contextmanager
+ def get_tmp_dir() -> Iterator[str]:
+ tmp_dir = tempfile.mkdtemp()
+ try:
+ yield tmp_dir
+ finally:
+ shutil.rmtree(tmp_dir)
+
+ archive_meta = ARCHIVE_META["devkit"]
+ if file is None:
+ file = archive_meta[0]
+ md5 = archive_meta[1]
+
+ _verify_archive(root, file, md5)
+
+ with get_tmp_dir() as tmp_dir:
+ extract_archive(os.path.join(root, file), tmp_dir)
+
+ devkit_root = os.path.join(tmp_dir, "ILSVRC2012_devkit_t12")
+ idx_to_wnid, wnid_to_classes = parse_meta_mat(devkit_root)
+ val_idcs = parse_val_groundtruth_txt(devkit_root)
+ val_wnids = [idx_to_wnid[idx] for idx in val_idcs]
+
+ torch.save((wnid_to_classes, val_wnids), os.path.join(root, META_FILE))
+
+
+def parse_train_archive(root: Union[str, Path], file: Optional[str] = None, folder: str = "train") -> None:
+ """Parse the train images archive of the ImageNet2012 classification dataset and
+ prepare it for usage with the ImageNet dataset.
+
+ Args:
+ root (str or ``pathlib.Path``): Root directory containing the train images archive
+ file (str, optional): Name of train images archive. Defaults to
+ 'ILSVRC2012_img_train.tar'
+ folder (str, optional): Optional name for train images folder. Defaults to
+ 'train'
+ """
+ archive_meta = ARCHIVE_META["train"]
+ if file is None:
+ file = archive_meta[0]
+ md5 = archive_meta[1]
+
+ _verify_archive(root, file, md5)
+
+ train_root = os.path.join(root, folder)
+ extract_archive(os.path.join(root, file), train_root)
+
+ archives = [os.path.join(train_root, archive) for archive in os.listdir(train_root)]
+ for archive in archives:
+ extract_archive(archive, os.path.splitext(archive)[0], remove_finished=True)
+
+
+def parse_val_archive(
+ root: Union[str, Path], file: Optional[str] = None, wnids: Optional[list[str]] = None, folder: str = "val"
+) -> None:
+ """Parse the validation images archive of the ImageNet2012 classification dataset
+ and prepare it for usage with the ImageNet dataset.
+
+ Args:
+ root (str or ``pathlib.Path``): Root directory containing the validation images archive
+ file (str, optional): Name of validation images archive. Defaults to
+ 'ILSVRC2012_img_val.tar'
+ wnids (list, optional): List of WordNet IDs of the validation images. If None
+ is given, the IDs are loaded from the meta file in the root directory
+ folder (str, optional): Optional name for validation images folder. Defaults to
+ 'val'
+ """
+ archive_meta = ARCHIVE_META["val"]
+ if file is None:
+ file = archive_meta[0]
+ md5 = archive_meta[1]
+ if wnids is None:
+ wnids = load_meta_file(root)[1]
+
+ _verify_archive(root, file, md5)
+
+ val_root = os.path.join(root, folder)
+ extract_archive(os.path.join(root, file), val_root)
+
+ images = sorted(os.path.join(val_root, image) for image in os.listdir(val_root))
+
+ for wnid in set(wnids):
+ os.mkdir(os.path.join(val_root, wnid))
+
+ for wnid, img_file in zip(wnids, images):
+ shutil.move(img_file, os.path.join(val_root, wnid, os.path.basename(img_file)))
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torchvision/datasets/imagenette.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torchvision/datasets/imagenette.py
new file mode 100644
index 0000000000000000000000000000000000000000..16bac9bfadcb99ebf16736cfa89bebc1dcc32e46
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torchvision/datasets/imagenette.py
@@ -0,0 +1,104 @@
+from pathlib import Path
+from typing import Any, Callable, Optional, Union
+
+from .folder import default_loader, find_classes, make_dataset
+from .utils import download_and_extract_archive, verify_str_arg
+from .vision import VisionDataset
+
+
+class Imagenette(VisionDataset):
+    """`Imagenette <https://github.com/fastai/imagenette>`_ image classification dataset.
+
+ Args:
+ root (str or ``pathlib.Path``): Root directory of the Imagenette dataset.
+ split (string, optional): The dataset split. Supports ``"train"`` (default), and ``"val"``.
+ size (string, optional): The image size. Supports ``"full"`` (default), ``"320px"``, and ``"160px"``.
+ download (bool, optional): If ``True``, downloads the dataset components and places them in ``root``. Already
+ downloaded archives are not downloaded again.
+ transform (callable, optional): A function/transform that takes in a PIL image or torch.Tensor, depends on the given loader,
+ and returns a transformed version. E.g, ``transforms.RandomCrop``
+ target_transform (callable, optional): A function/transform that takes in the target and transforms it.
+ loader (callable, optional): A function to load an image given its path.
+ By default, it uses PIL as its image loader, but users could also pass in
+ ``torchvision.io.decode_image`` for decoding image data into tensors directly.
+
+ Attributes:
+ classes (list): List of the class name tuples.
+ class_to_idx (dict): Dict with items (class name, class index).
+ wnids (list): List of the WordNet IDs.
+ wnid_to_idx (dict): Dict with items (WordNet ID, class index).
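+
+    Example (a minimal usage sketch; the ``160px`` variant is the smallest download)::
+
+        from torchvision.datasets import Imagenette
+
+        train_set = Imagenette(root="data", split="train", size="160px", download=True)
+        image, label = train_set[0]
+        print(train_set.classes[label])  # e.g. ("tench", "Tinca tinca")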
+ """
+
+ _ARCHIVES = {
+ "full": ("https://s3.amazonaws.com/fast-ai-imageclas/imagenette2.tgz", "fe2fc210e6bb7c5664d602c3cd71e612"),
+ "320px": ("https://s3.amazonaws.com/fast-ai-imageclas/imagenette2-320.tgz", "3df6f0d01a2c9592104656642f5e78a3"),
+ "160px": ("https://s3.amazonaws.com/fast-ai-imageclas/imagenette2-160.tgz", "e793b78cc4c9e9a4ccc0c1155377a412"),
+ }
+ _WNID_TO_CLASS = {
+ "n01440764": ("tench", "Tinca tinca"),
+ "n02102040": ("English springer", "English springer spaniel"),
+ "n02979186": ("cassette player",),
+ "n03000684": ("chain saw", "chainsaw"),
+ "n03028079": ("church", "church building"),
+ "n03394916": ("French horn", "horn"),
+ "n03417042": ("garbage truck", "dustcart"),
+ "n03425413": ("gas pump", "gasoline pump", "petrol pump", "island dispenser"),
+ "n03445777": ("golf ball",),
+ "n03888257": ("parachute", "chute"),
+ }
+
+ def __init__(
+ self,
+ root: Union[str, Path],
+ split: str = "train",
+ size: str = "full",
+ download=False,
+ transform: Optional[Callable] = None,
+ target_transform: Optional[Callable] = None,
+ loader: Callable[[str], Any] = default_loader,
+ ) -> None:
+ super().__init__(root, transform=transform, target_transform=target_transform)
+
+ self._split = verify_str_arg(split, "split", ["train", "val"])
+ self._size = verify_str_arg(size, "size", ["full", "320px", "160px"])
+
+ self._url, self._md5 = self._ARCHIVES[self._size]
+ self._size_root = Path(self.root) / Path(self._url).stem
+ self._image_root = str(self._size_root / self._split)
+
+ if download:
+ self._download()
+ elif not self._check_exists():
+ raise RuntimeError("Dataset not found. You can use download=True to download it.")
+
+ self.wnids, self.wnid_to_idx = find_classes(self._image_root)
+ self.classes = [self._WNID_TO_CLASS[wnid] for wnid in self.wnids]
+ self.class_to_idx = {
+ class_name: idx for wnid, idx in self.wnid_to_idx.items() for class_name in self._WNID_TO_CLASS[wnid]
+ }
+ self._samples = make_dataset(self._image_root, self.wnid_to_idx, extensions=".jpeg")
+ self.loader = loader
+
+ def _check_exists(self) -> bool:
+ return self._size_root.exists()
+
+ def _download(self):
+ if self._check_exists():
+ return
+
+ download_and_extract_archive(self._url, self.root, md5=self._md5)
+
+ def __getitem__(self, idx: int) -> tuple[Any, Any]:
+ path, label = self._samples[idx]
+ image = self.loader(path)
+
+ if self.transform is not None:
+ image = self.transform(image)
+
+ if self.target_transform is not None:
+ label = self.target_transform(label)
+
+ return image, label
+
+ def __len__(self) -> int:
+ return len(self._samples)
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torchvision/datasets/inaturalist.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torchvision/datasets/inaturalist.py
new file mode 100644
index 0000000000000000000000000000000000000000..a47483e158d04830b607d2f2cca42650f5b077e7
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torchvision/datasets/inaturalist.py
@@ -0,0 +1,245 @@
+import os
+import os.path
+from pathlib import Path
+from typing import Any, Callable, Optional, Union
+
+from PIL import Image
+
+from .utils import download_and_extract_archive, verify_str_arg
+from .vision import VisionDataset
+
+CATEGORIES_2021 = ["kingdom", "phylum", "class", "order", "family", "genus"]
+
+DATASET_URLS = {
+ "2017": "https://ml-inat-competition-datasets.s3.amazonaws.com/2017/train_val_images.tar.gz",
+ "2018": "https://ml-inat-competition-datasets.s3.amazonaws.com/2018/train_val2018.tar.gz",
+ "2019": "https://ml-inat-competition-datasets.s3.amazonaws.com/2019/train_val2019.tar.gz",
+ "2021_train": "https://ml-inat-competition-datasets.s3.amazonaws.com/2021/train.tar.gz",
+ "2021_train_mini": "https://ml-inat-competition-datasets.s3.amazonaws.com/2021/train_mini.tar.gz",
+ "2021_valid": "https://ml-inat-competition-datasets.s3.amazonaws.com/2021/val.tar.gz",
+}
+
+DATASET_MD5 = {
+ "2017": "7c784ea5e424efaec655bd392f87301f",
+ "2018": "b1c6952ce38f31868cc50ea72d066cc3",
+ "2019": "c60a6e2962c9b8ccbd458d12c8582644",
+ "2021_train": "e0526d53c7f7b2e3167b2b43bb2690ed",
+ "2021_train_mini": "db6ed8330e634445efc8fec83ae81442",
+ "2021_valid": "f6f6e0e242e3d4c9569ba56400938afc",
+}
+
+
+class INaturalist(VisionDataset):
+ """`iNaturalist `_ Dataset.
+
+ Args:
+ root (str or ``pathlib.Path``): Root directory of dataset where the image files are stored.
+ This class does not require/use annotation files.
+ version (string, optional): Which version of the dataset to download/use. One of
+ '2017', '2018', '2019', '2021_train', '2021_train_mini', '2021_valid'.
+ Default: `2021_train`.
+ target_type (string or list, optional): Type of target to use, for 2021 versions, one of:
+
+ - ``full``: the full category (species)
+ - ``kingdom``: e.g. "Animalia"
+ - ``phylum``: e.g. "Arthropoda"
+ - ``class``: e.g. "Insecta"
+ - ``order``: e.g. "Coleoptera"
+ - ``family``: e.g. "Cleridae"
+ - ``genus``: e.g. "Trichodes"
+
+ for 2017-2019 versions, one of:
+
+ - ``full``: the full (numeric) category
+ - ``super``: the super category, e.g. "Amphibians"
+
+ Can also be a list to output a tuple with all specified target types.
+ Defaults to ``full``.
+ transform (callable, optional): A function/transform that takes in a PIL image
+ and returns a transformed version. E.g, ``transforms.RandomCrop``
+ target_transform (callable, optional): A function/transform that takes in the
+ target and transforms it.
+ download (bool, optional): If true, downloads the dataset from the internet and
+ puts it in root directory. If dataset is already downloaded, it is not
+ downloaded again.
+ loader (callable, optional): A function to load an image given its path.
+ By default, it uses PIL as its image loader, but users could also pass in
+ ``torchvision.io.decode_image`` for decoding image data into tensors directly.
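+
+ Example (illustrative sketch; the root path is a placeholder and the data is assumed to be present)::
+
+ from torchvision.datasets import INaturalist
+
+ dataset = INaturalist("path/to/data", version="2021_valid", target_type=["full", "family"])
+ img, (species_id, family_id) = dataset[0] # tuple target because two target types were requested
+ print(dataset.category_name("family", family_id))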
+ """
+
+ def __init__(
+ self,
+ root: Union[str, Path],
+ version: str = "2021_train",
+ target_type: Union[list[str], str] = "full",
+ transform: Optional[Callable] = None,
+ target_transform: Optional[Callable] = None,
+ download: bool = False,
+ loader: Optional[Callable[[Union[str, Path]], Any]] = None,
+ ) -> None:
+ self.version = verify_str_arg(version, "version", DATASET_URLS.keys())
+
+ super().__init__(os.path.join(root, version), transform=transform, target_transform=target_transform)
+
+ os.makedirs(root, exist_ok=True)
+ if download:
+ self.download()
+
+ if not self._check_exists():
+ raise RuntimeError("Dataset not found or corrupted. You can use download=True to download it")
+
+ self.all_categories: list[str] = []
+
+ # map: category type -> name of category -> index
+ self.categories_index: dict[str, dict[str, int]] = {}
+
+ # list indexed by category id, containing mapping from category type -> index
+ self.categories_map: list[dict[str, int]] = []
+
+ if not isinstance(target_type, list):
+ target_type = [target_type]
+ if self.version[:4] == "2021":
+ self.target_type = [verify_str_arg(t, "target_type", ("full", *CATEGORIES_2021)) for t in target_type]
+ self._init_2021()
+ else:
+ self.target_type = [verify_str_arg(t, "target_type", ("full", "super")) for t in target_type]
+ self._init_pre2021()
+
+ # index of all files: (full category id, filename)
+ self.index: list[tuple[int, str]] = []
+
+ for dir_index, dir_name in enumerate(self.all_categories):
+ files = os.listdir(os.path.join(self.root, dir_name))
+ for fname in files:
+ self.index.append((dir_index, fname))
+
+ self.loader = loader
+
+ def _init_2021(self) -> None:
+ """Initialize based on 2021 layout"""
+
+ self.all_categories = sorted(os.listdir(self.root))
+
+ # map: category type -> name of category -> index
+ self.categories_index = {k: {} for k in CATEGORIES_2021}
+
+ for dir_index, dir_name in enumerate(self.all_categories):
+ pieces = dir_name.split("_")
+ if len(pieces) != 8:
+ raise RuntimeError(f"Unexpected category name {dir_name}, wrong number of pieces")
+ if pieces[0] != f"{dir_index:05d}":
+ raise RuntimeError(f"Unexpected category id {pieces[0]}, expecting {dir_index:05d}")
+ cat_map = {}
+ for cat, name in zip(CATEGORIES_2021, pieces[1:7]):
+ if name in self.categories_index[cat]:
+ cat_id = self.categories_index[cat][name]
+ else:
+ cat_id = len(self.categories_index[cat])
+ self.categories_index[cat][name] = cat_id
+ cat_map[cat] = cat_id
+ self.categories_map.append(cat_map)
+
+ def _init_pre2021(self) -> None:
+ """Initialize based on 2017-2019 layout"""
+
+ # map: category type -> name of category -> index
+ self.categories_index = {"super": {}}
+
+ cat_index = 0
+ super_categories = sorted(os.listdir(self.root))
+ for sindex, scat in enumerate(super_categories):
+ self.categories_index["super"][scat] = sindex
+ subcategories = sorted(os.listdir(os.path.join(self.root, scat)))
+ for subcat in subcategories:
+ if self.version == "2017":
+ # this version does not use ids as directory names
+ subcat_i = cat_index
+ cat_index += 1
+ else:
+ try:
+ subcat_i = int(subcat)
+ except ValueError:
+ raise RuntimeError(f"Unexpected non-numeric dir name: {subcat}")
+ if subcat_i >= len(self.categories_map):
+ old_len = len(self.categories_map)
+ self.categories_map.extend([{}] * (subcat_i - old_len + 1))
+ self.all_categories.extend([""] * (subcat_i - old_len + 1))
+ if self.categories_map[subcat_i]:
+ raise RuntimeError(f"Duplicate category {subcat}")
+ self.categories_map[subcat_i] = {"super": sindex}
+ self.all_categories[subcat_i] = os.path.join(scat, subcat)
+
+ # validate the dictionary
+ for cindex, c in enumerate(self.categories_map):
+ if not c:
+ raise RuntimeError(f"Missing category {cindex}")
+
+ def __getitem__(self, index: int) -> tuple[Any, Any]:
+ """
+ Args:
+ index (int): Index
+
+ Returns:
+ tuple: (image, target) where the type of target specified by target_type.
+ """
+
+ cat_id, fname = self.index[index]
+ image_path = os.path.join(self.root, self.all_categories[cat_id], fname)
+ img = self.loader(image_path) if self.loader is not None else Image.open(image_path)
+
+ target: Any = []
+ for t in self.target_type:
+ if t == "full":
+ target.append(cat_id)
+ else:
+ target.append(self.categories_map[cat_id][t])
+ target = tuple(target) if len(target) > 1 else target[0]
+
+ if self.transform is not None:
+ img = self.transform(img)
+
+ if self.target_transform is not None:
+ target = self.target_transform(target)
+
+ return img, target
+
+ def __len__(self) -> int:
+ return len(self.index)
+
+ def category_name(self, category_type: str, category_id: int) -> str:
+ """
+ Args:
+ category_type(str): one of "full", "kingdom", "phylum", "class", "order", "family", "genus" or "super"
+ category_id(int): an index (class id) from this category
+
+ Returns:
+ the name of the category
+ """
+ if category_type == "full":
+ return self.all_categories[category_id]
+ else:
+ if category_type not in self.categories_index:
+ raise ValueError(f"Invalid category type '{category_type}'")
+ else:
+ for name, id in self.categories_index[category_type].items():
+ if id == category_id:
+ return name
+ raise ValueError(f"Invalid category id {category_id} for {category_type}")
+
+ def _check_exists(self) -> bool:
+ return os.path.exists(self.root) and len(os.listdir(self.root)) > 0
+
+ def download(self) -> None:
+ if self._check_exists():
+ return
+
+ base_root = os.path.dirname(self.root)
+
+ download_and_extract_archive(
+ DATASET_URLS[self.version], base_root, filename=f"{self.version}.tgz", md5=DATASET_MD5[self.version]
+ )
+
+ orig_dir_name = os.path.join(base_root, os.path.basename(DATASET_URLS[self.version]).rstrip(".tar.gz"))
+ if not os.path.exists(orig_dir_name):
+ raise RuntimeError(f"Unable to find downloaded files at {orig_dir_name}")
+ os.rename(orig_dir_name, self.root)
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torchvision/datasets/kinetics.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torchvision/datasets/kinetics.py
new file mode 100644
index 0000000000000000000000000000000000000000..c568e46a62d5d8f92c0bfcdb7ce79b6b60f234ce
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torchvision/datasets/kinetics.py
@@ -0,0 +1,237 @@
+import csv
+import os
+import urllib
+from functools import partial
+from multiprocessing import Pool
+from os import path
+from pathlib import Path
+from typing import Any, Callable, Optional, Union
+
+from torch import Tensor
+
+from .folder import find_classes, make_dataset
+from .utils import check_integrity, download_and_extract_archive, download_url, verify_str_arg
+from .video_utils import VideoClips
+from .vision import VisionDataset
+
+
+def _dl_wrap(tarpath: Union[str, Path], videopath: Union[str, Path], line: str) -> None:
+ download_and_extract_archive(line, tarpath, videopath)
+
+
+class Kinetics(VisionDataset):
+ """`Generic Kinetics `_
+ dataset.
+
+ Kinetics-400/600/700 are action recognition video datasets.
+ This dataset considers every video as a collection of video clips of fixed size, specified
+ by ``frames_per_clip``, where the step in frames between each clip is given by
+ ``step_between_clips``.
+
+ To give an example, for 2 videos with 10 and 15 frames respectively, if ``frames_per_clip=5``
+ and ``step_between_clips=5``, the dataset size will be (2 + 3) = 5, where the first two
+ elements will come from video 1, and the next three elements from video 2.
+ Note that we drop clips which do not have exactly ``frames_per_clip`` elements, so not all
+ frames in a video might be present.
+
+ Args:
+ root (str or ``pathlib.Path``): Root directory of the Kinetics Dataset.
+ Directory should be structured as follows:
+ .. code::
+
+ root/
+ ├── split
+ │ ├── class1
+ │ │ ├── vid1.mp4
+ │ │ ├── vid2.mp4
+ │ │ ├── vid3.mp4
+ │ │ ├── ...
+ │ ├── class2
+ │ │ ├── vidx.mp4
+ │ │ └── ...
+
+ Note: split is appended automatically using the split argument.
+ frames_per_clip (int): number of frames in a clip
+ num_classes (int): select between Kinetics-400 (default), Kinetics-600, and Kinetics-700
+ split (str): split of the dataset to consider; supports ``"train"`` (default), ``"val"`` and ``"test"``
+ frame_rate (float): If omitted, a different frame rate is interpolated for each clip.
+ step_between_clips (int): number of frames between each clip
+ transform (callable, optional): A function/transform that takes in a TxHxWxC video
+ and returns a transformed version.
+ download (bool): Download the official version of the dataset to root folder.
+ num_workers (int): Use multiple workers for VideoClips creation
+ num_download_workers (int): Use multiprocessing in order to speed up download.
+ output_format (str, optional): The format of the output video tensors (before transforms).
+ Can be either "THWC" or "TCHW" (default).
+ Note that in most other utils and datasets, the default is actually "THWC".
+
+ Returns:
+ tuple: A 3-tuple with the following entries:
+
+ - video (Tensor[T, C, H, W] or Tensor[T, H, W, C]): the `T` video frames in torch.uint8 tensor
+ - audio(Tensor[K, L]): the audio frames, where `K` is the number of channels
+ and `L` is the number of points in torch.float tensor
+ - label (int): class of the video clip
+
+ Raises:
+ RuntimeError: If ``download is True`` and the video archives are already extracted.
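+
+ Example (illustrative sketch; the root path is a placeholder)::
+
+ from torchvision.datasets import Kinetics
+
+ dataset = Kinetics("path/to/kinetics", frames_per_clip=16, num_classes="400", split="val", num_workers=4)
+ video, audio, label = dataset[0] # video is a Tensor[T, C, H, W] with the default "TCHW" output_format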
+ """
+
+ _TAR_URLS = {
+ "400": "https://s3.amazonaws.com/kinetics/400/{split}/k400_{split}_path.txt",
+ "600": "https://s3.amazonaws.com/kinetics/600/{split}/k600_{split}_path.txt",
+ "700": "https://s3.amazonaws.com/kinetics/700_2020/{split}/k700_2020_{split}_path.txt",
+ }
+ _ANNOTATION_URLS = {
+ "400": "https://s3.amazonaws.com/kinetics/400/annotations/{split}.csv",
+ "600": "https://s3.amazonaws.com/kinetics/600/annotations/{split}.csv",
+ "700": "https://s3.amazonaws.com/kinetics/700_2020/annotations/{split}.csv",
+ }
+
+ def __init__(
+ self,
+ root: Union[str, Path],
+ frames_per_clip: int,
+ num_classes: str = "400",
+ split: str = "train",
+ frame_rate: Optional[int] = None,
+ step_between_clips: int = 1,
+ transform: Optional[Callable] = None,
+ extensions: tuple[str, ...] = ("avi", "mp4"),
+ download: bool = False,
+ num_download_workers: int = 1,
+ num_workers: int = 1,
+ _precomputed_metadata: Optional[dict[str, Any]] = None,
+ _video_width: int = 0,
+ _video_height: int = 0,
+ _video_min_dimension: int = 0,
+ _audio_samples: int = 0,
+ _audio_channels: int = 0,
+ _legacy: bool = False,
+ output_format: str = "TCHW",
+ ) -> None:
+
+ # TODO: support test
+ self.num_classes = verify_str_arg(num_classes, arg="num_classes", valid_values=["400", "600", "700"])
+ self.extensions = extensions
+ self.num_download_workers = num_download_workers
+
+ self.root = root
+ self._legacy = _legacy
+
+ if _legacy:
+ self.split_folder = root
+ self.split = "unknown"
+ output_format = "THWC"
+ if download:
+ raise ValueError("Cannot download the videos using legacy_structure.")
+ else:
+ self.split_folder = path.join(root, split)
+ self.split = verify_str_arg(split, arg="split", valid_values=["train", "val", "test"])
+
+ if download:
+ self.download_and_process_videos()
+
+ super().__init__(self.root)
+
+ self.classes, class_to_idx = find_classes(self.split_folder)
+ self.samples = make_dataset(self.split_folder, class_to_idx, extensions, is_valid_file=None)
+ video_list = [x[0] for x in self.samples]
+ self.video_clips = VideoClips(
+ video_list,
+ frames_per_clip,
+ step_between_clips,
+ frame_rate,
+ _precomputed_metadata,
+ num_workers=num_workers,
+ _video_width=_video_width,
+ _video_height=_video_height,
+ _video_min_dimension=_video_min_dimension,
+ _audio_samples=_audio_samples,
+ _audio_channels=_audio_channels,
+ output_format=output_format,
+ )
+ self.transform = transform
+
+ def download_and_process_videos(self) -> None:
+ """Downloads all the videos to the _root_ folder in the expected format."""
+ self._download_videos()
+ self._make_ds_structure()
+
+ def _download_videos(self) -> None:
+ """download tarballs containing the video to "tars" folder and extract them into the _split_ folder where
+ split is one of the official dataset splits.
+
+ Raises:
+ RuntimeError: if download folder exists, break to prevent downloading entire dataset again.
+ """
+ if path.exists(self.split_folder):
+ return
+ tar_path = path.join(self.root, "tars")
+ file_list_path = path.join(self.root, "files")
+
+ split_url = self._TAR_URLS[self.num_classes].format(split=self.split)
+ split_url_filepath = path.join(file_list_path, path.basename(split_url))
+ if not check_integrity(split_url_filepath):
+ download_url(split_url, file_list_path)
+ with open(split_url_filepath) as file:
+ list_video_urls = [urllib.parse.quote(line, safe="/,:") for line in file.read().splitlines()]
+
+ if self.num_download_workers == 1:
+ for line in list_video_urls:
+ download_and_extract_archive(line, tar_path, self.split_folder)
+ else:
+ part = partial(_dl_wrap, tar_path, self.split_folder)
+ poolproc = Pool(self.num_download_workers)
+ poolproc.map(part, list_video_urls)
+
+ def _make_ds_structure(self) -> None:
+ """move videos from
+ split_folder/
+ ├── clip1.avi
+ ├── clip2.avi
+
+ to the correct format as described below:
+ split_folder/
+ ├── class1
+ │ ├── clip1.avi
+
+ """
+ annotation_path = path.join(self.root, "annotations")
+ if not check_integrity(path.join(annotation_path, f"{self.split}.csv")):
+ download_url(self._ANNOTATION_URLS[self.num_classes].format(split=self.split), annotation_path)
+ annotations = path.join(annotation_path, f"{self.split}.csv")
+
+ file_fmtstr = "{ytid}_{start:06}_{end:06}.mp4"
+ with open(annotations) as csvfile:
+ reader = csv.DictReader(csvfile)
+ for row in reader:
+ f = file_fmtstr.format(
+ ytid=row["youtube_id"],
+ start=int(row["time_start"]),
+ end=int(row["time_end"]),
+ )
+ label = row["label"].replace(" ", "_").replace("'", "").replace("(", "").replace(")", "")
+ os.makedirs(path.join(self.split_folder, label), exist_ok=True)
+ downloaded_file = path.join(self.split_folder, f)
+ if path.isfile(downloaded_file):
+ os.replace(
+ downloaded_file,
+ path.join(self.split_folder, label, f),
+ )
+
+ @property
+ def metadata(self) -> dict[str, Any]:
+ return self.video_clips.metadata
+
+ def __len__(self) -> int:
+ return self.video_clips.num_clips()
+
+ def __getitem__(self, idx: int) -> tuple[Tensor, Tensor, int]:
+ video, audio, info, video_idx = self.video_clips.get_clip(idx)
+ label = self.samples[video_idx][1]
+
+ if self.transform is not None:
+ video = self.transform(video)
+
+ return video, audio, label
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torchvision/datasets/kitti.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torchvision/datasets/kitti.py
new file mode 100644
index 0000000000000000000000000000000000000000..d275248d92a5cad4efe8dfaf7ec89c6dda6dd8ef
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torchvision/datasets/kitti.py
@@ -0,0 +1,158 @@
+import csv
+import os
+from pathlib import Path
+from typing import Any, Callable, Optional, Union
+
+from PIL import Image
+
+from .utils import download_and_extract_archive
+from .vision import VisionDataset
+
+
+class Kitti(VisionDataset):
+ """`KITTI `_ Dataset.
+
+ It corresponds to the "left color images of object" dataset, for object detection.
+
+ Args:
+ root (str or ``pathlib.Path``): Root directory where images are downloaded to.
+ Expects the following folder structure if download=False:
+
+ .. code::
+
+
+ └── Kitti
+ └─ raw
+ ├── training
+ | ├── image_2
+ | └── label_2
+ └── testing
+ └── image_2
+ train (bool, optional): Use ``train`` split if true, else ``test`` split.
+ Defaults to ``train``.
+ transform (callable, optional): A function/transform that takes in a PIL image
+ and returns a transformed version. E.g, ``transforms.PILToTensor``
+ target_transform (callable, optional): A function/transform that takes in the
+ target and transforms it.
+ transforms (callable, optional): A function/transform that takes input sample
+ and its target as entry and returns a transformed version.
+ download (bool, optional): If true, downloads the dataset from the internet and
+ puts it in root directory. If dataset is already downloaded, it is not
+ downloaded again.
+
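+ Example (illustrative sketch; the root path is a placeholder)::
+
+ from torchvision.datasets import Kitti
+
+ dataset = Kitti("path/to/data", train=True, download=True)
+ image, target = dataset[0] # target is a list of dicts with keys such as "type" and "bbox"
+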
+ """
+
+ data_url = "https://s3.eu-central-1.amazonaws.com/avg-kitti/"
+ resources = [
+ "data_object_image_2.zip",
+ "data_object_label_2.zip",
+ ]
+ image_dir_name = "image_2"
+ labels_dir_name = "label_2"
+
+ def __init__(
+ self,
+ root: Union[str, Path],
+ train: bool = True,
+ transform: Optional[Callable] = None,
+ target_transform: Optional[Callable] = None,
+ transforms: Optional[Callable] = None,
+ download: bool = False,
+ ):
+ super().__init__(
+ root,
+ transform=transform,
+ target_transform=target_transform,
+ transforms=transforms,
+ )
+ self.images = []
+ self.targets = []
+ self.train = train
+ self._location = "training" if self.train else "testing"
+
+ if download:
+ self.download()
+ if not self._check_exists():
+ raise RuntimeError("Dataset not found. You may use download=True to download it.")
+
+ image_dir = os.path.join(self._raw_folder, self._location, self.image_dir_name)
+ if self.train:
+ labels_dir = os.path.join(self._raw_folder, self._location, self.labels_dir_name)
+ for img_file in os.listdir(image_dir):
+ self.images.append(os.path.join(image_dir, img_file))
+ if self.train:
+ self.targets.append(os.path.join(labels_dir, f"{img_file.split('.')[0]}.txt"))
+
+ def __getitem__(self, index: int) -> tuple[Any, Any]:
+ """Get item at a given index.
+
+ Args:
+ index (int): Index
+ Returns:
+ tuple: (image, target), where
+ target is a list of dictionaries with the following keys:
+
+ - type: str
+ - truncated: float
+ - occluded: int
+ - alpha: float
+ - bbox: float[4]
+ - dimensions: float[3]
+ - location: float[3]
+ - rotation_y: float
+
+ """
+ image = Image.open(self.images[index])
+ target = self._parse_target(index) if self.train else None
+ if self.transforms:
+ image, target = self.transforms(image, target)
+ return image, target
+
+ def _parse_target(self, index: int) -> list:
+ target = []
+ with open(self.targets[index]) as inp:
+ content = csv.reader(inp, delimiter=" ")
+ for line in content:
+ target.append(
+ {
+ "type": line[0],
+ "truncated": float(line[1]),
+ "occluded": int(line[2]),
+ "alpha": float(line[3]),
+ "bbox": [float(x) for x in line[4:8]],
+ "dimensions": [float(x) for x in line[8:11]],
+ "location": [float(x) for x in line[11:14]],
+ "rotation_y": float(line[14]),
+ }
+ )
+ return target
+
+ def __len__(self) -> int:
+ return len(self.images)
+
+ @property
+ def _raw_folder(self) -> str:
+ return os.path.join(self.root, self.__class__.__name__, "raw")
+
+ def _check_exists(self) -> bool:
+ """Check if the data directory exists."""
+ folders = [self.image_dir_name]
+ if self.train:
+ folders.append(self.labels_dir_name)
+ return all(os.path.isdir(os.path.join(self._raw_folder, self._location, fname)) for fname in folders)
+
+ def download(self) -> None:
+ """Download the KITTI data if it doesn't exist already."""
+
+ if self._check_exists():
+ return
+
+ os.makedirs(self._raw_folder, exist_ok=True)
+
+ # download files
+ for fname in self.resources:
+ download_and_extract_archive(
+ url=f"{self.data_url}{fname}",
+ download_root=self._raw_folder,
+ filename=fname,
+ )
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torchvision/datasets/lfw.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torchvision/datasets/lfw.py
new file mode 100644
index 0000000000000000000000000000000000000000..2ff17af5328cbc0995432560c86288f405cd5a46
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torchvision/datasets/lfw.py
@@ -0,0 +1,268 @@
+import os
+from pathlib import Path
+from typing import Any, Callable, Optional, Union
+
+from .folder import default_loader
+from .utils import check_integrity, download_and_extract_archive, download_url, verify_str_arg
+from .vision import VisionDataset
+
+
+class _LFW(VisionDataset):
+
+ base_folder = "lfw-py"
+ download_url_prefix = "http://vis-www.cs.umass.edu/lfw/"
+
+ file_dict = {
+ "original": ("lfw", "lfw.tgz", "a17d05bd522c52d84eca14327a23d494"),
+ "funneled": ("lfw_funneled", "lfw-funneled.tgz", "1b42dfed7d15c9b2dd63d5e5840c86ad"),
+ "deepfunneled": ("lfw-deepfunneled", "lfw-deepfunneled.tgz", "68331da3eb755a505a502b5aacb3c201"),
+ }
+ checksums = {
+ "pairs.txt": "9f1ba174e4e1c508ff7cdf10ac338a7d",
+ "pairsDevTest.txt": "5132f7440eb68cf58910c8a45a2ac10b",
+ "pairsDevTrain.txt": "4f27cbf15b2da4a85c1907eb4181ad21",
+ "people.txt": "450f0863dd89e85e73936a6d71a3474b",
+ "peopleDevTest.txt": "e4bf5be0a43b5dcd9dc5ccfcb8fb19c5",
+ "peopleDevTrain.txt": "54eaac34beb6d042ed3a7d883e247a21",
+ "lfw-names.txt": "a6d0a479bd074669f656265a6e693f6d",
+ }
+ annot_file = {"10fold": "", "train": "DevTrain", "test": "DevTest"}
+ names = "lfw-names.txt"
+
+ def __init__(
+ self,
+ root: Union[str, Path],
+ split: str,
+ image_set: str,
+ view: str,
+ transform: Optional[Callable] = None,
+ target_transform: Optional[Callable] = None,
+ download: bool = False,
+ loader: Callable[[str], Any] = default_loader,
+ ) -> None:
+ super().__init__(os.path.join(root, self.base_folder), transform=transform, target_transform=target_transform)
+
+ self.image_set = verify_str_arg(image_set.lower(), "image_set", self.file_dict.keys())
+ images_dir, self.filename, self.md5 = self.file_dict[self.image_set]
+
+ self.view = verify_str_arg(view.lower(), "view", ["people", "pairs"])
+ self.split = verify_str_arg(split.lower(), "split", ["10fold", "train", "test"])
+ self.labels_file = f"{self.view}{self.annot_file[self.split]}.txt"
+ self.data: list[Any] = []
+
+ if download:
+ raise ValueError(
+ "LFW dataset is no longer available for download."
+ "Please download the dataset manually and place it in the specified directory"
+ )
+ self.download()
+
+ if not self._check_integrity():
+ raise RuntimeError("Dataset not found or corrupted. You can use download=True to download it")
+
+ self.images_dir = os.path.join(self.root, images_dir)
+ self._loader = loader
+
+ def _check_integrity(self) -> bool:
+ st1 = check_integrity(os.path.join(self.root, self.filename), self.md5)
+ st2 = check_integrity(os.path.join(self.root, self.labels_file), self.checksums[self.labels_file])
+ if not st1 or not st2:
+ return False
+ if self.view == "people":
+ return check_integrity(os.path.join(self.root, self.names), self.checksums[self.names])
+ return True
+
+ def download(self) -> None:
+ if self._check_integrity():
+ return
+ url = f"{self.download_url_prefix}{self.filename}"
+ download_and_extract_archive(url, self.root, filename=self.filename, md5=self.md5)
+ download_url(f"{self.download_url_prefix}{self.labels_file}", self.root)
+ if self.view == "people":
+ download_url(f"{self.download_url_prefix}{self.names}", self.root)
+
+ def _get_path(self, identity: str, no: Union[int, str]) -> str:
+ return os.path.join(self.images_dir, identity, f"{identity}_{int(no):04d}.jpg")
+
+ def extra_repr(self) -> str:
+ return f"Alignment: {self.image_set}\nSplit: {self.split}"
+
+ def __len__(self) -> int:
+ return len(self.data)
+
+
+class LFWPeople(_LFW):
+ """`LFW `_ Dataset.
+
+ .. warning:
+
+ The LFW dataset is no longer available for automatic download. Please
+ download it manually and place it in the specified directory.
+
+ Args:
+ root (str or ``pathlib.Path``): Root directory of dataset where directory
+ ``lfw-py`` exists or will be saved to if download is set to True.
+ split (string, optional): The image split to use. Can be one of ``train``, ``test``,
+ ``10fold`` (default).
+ image_set (str, optional): Type of image funneling to use, ``original``, ``funneled`` or
+ ``deepfunneled``. Defaults to ``funneled``.
+ transform (callable, optional): A function/transform that takes in a PIL image or torch.Tensor, depends on the given loader,
+ and returns a transformed version. E.g, ``transforms.RandomCrop``
+ target_transform (callable, optional): A function/transform that takes in the
+ target and transforms it.
+ download (bool, optional): NOT SUPPORTED ANYMORE, leave to False.
+ loader (callable, optional): A function to load an image given its path.
+ By default, it uses PIL as its image loader, but users could also pass in
+ ``torchvision.io.decode_image`` for decoding image data into tensors directly.
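+
+ Example (illustrative sketch; assumes the archive was manually extracted under ``root/lfw-py``)::
+
+ from torchvision.datasets import LFWPeople
+
+ dataset = LFWPeople("path/to/data", split="test", image_set="funneled")
+ img, identity = dataset[0] # identity is the integer index of the person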
+ """
+
+ def __init__(
+ self,
+ root: str,
+ split: str = "10fold",
+ image_set: str = "funneled",
+ transform: Optional[Callable] = None,
+ target_transform: Optional[Callable] = None,
+ download: bool = False,
+ loader: Callable[[str], Any] = default_loader,
+ ) -> None:
+ super().__init__(root, split, image_set, "people", transform, target_transform, download, loader=loader)
+
+ self.class_to_idx = self._get_classes()
+ self.data, self.targets = self._get_people()
+
+ def _get_people(self) -> tuple[list[str], list[int]]:
+ data, targets = [], []
+ with open(os.path.join(self.root, self.labels_file)) as f:
+ lines = f.readlines()
+ n_folds, s = (int(lines[0]), 1) if self.split == "10fold" else (1, 0)
+
+ for fold in range(n_folds):
+ n_lines = int(lines[s])
+ people = [line.strip().split("\t") for line in lines[s + 1 : s + n_lines + 1]]
+ s += n_lines + 1
+ for i, (identity, num_imgs) in enumerate(people):
+ for num in range(1, int(num_imgs) + 1):
+ img = self._get_path(identity, num)
+ data.append(img)
+ targets.append(self.class_to_idx[identity])
+
+ return data, targets
+
+ def _get_classes(self) -> dict[str, int]:
+ with open(os.path.join(self.root, self.names)) as f:
+ lines = f.readlines()
+ names = [line.strip().split()[0] for line in lines]
+ class_to_idx = {name: i for i, name in enumerate(names)}
+ return class_to_idx
+
+ def __getitem__(self, index: int) -> tuple[Any, Any]:
+ """
+ Args:
+ index (int): Index
+
+ Returns:
+ tuple: Tuple (image, target) where target is the identity of the person.
+ """
+ img = self._loader(self.data[index])
+ target = self.targets[index]
+
+ if self.transform is not None:
+ img = self.transform(img)
+
+ if self.target_transform is not None:
+ target = self.target_transform(target)
+
+ return img, target
+
+ def extra_repr(self) -> str:
+ return super().extra_repr() + f"\nClasses (identities): {len(self.class_to_idx)}"
+
+
+class LFWPairs(_LFW):
+ """`LFW `_ Dataset.
+
+ .. warning:
+
+ The LFW dataset is no longer available for automatic download. Please
+ download it manually and place it in the specified directory.
+
+ Args:
+ root (str or ``pathlib.Path``): Root directory of dataset where directory
+ ``lfw-py`` exists or will be saved to if download is set to True.
+ split (string, optional): The image split to use. Can be one of ``train``, ``test``,
+ ``10fold``. Defaults to ``10fold``.
+ image_set (str, optional): Type of image funneling to use, ``original``, ``funneled`` or
+ ``deepfunneled``. Defaults to ``funneled``.
+ transform (callable, optional): A function/transform that takes in a PIL image
+ and returns a transformed version. E.g, ``transforms.RandomRotation``
+ target_transform (callable, optional): A function/transform that takes in the
+ target and transforms it.
+ download (bool, optional): NOT SUPPORTED ANYMORE, leave to False.
+ loader (callable, optional): A function to load an image given its path.
+ By default, it uses PIL as its image loader, but users could also pass in
+ ``torchvision.io.decode_image`` for decoding image data into tensors directly.
+
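+ Example (illustrative sketch; assumes the archive was manually extracted under ``root/lfw-py``)::
+
+ from torchvision.datasets import LFWPairs
+
+ dataset = LFWPairs("path/to/data", split="10fold", image_set="funneled")
+ img1, img2, same = dataset[0] # same is 1 for a matched pair, 0 otherwise
+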
+ """
+
+ def __init__(
+ self,
+ root: str,
+ split: str = "10fold",
+ image_set: str = "funneled",
+ transform: Optional[Callable] = None,
+ target_transform: Optional[Callable] = None,
+ download: bool = False,
+ loader: Callable[[str], Any] = default_loader,
+ ) -> None:
+ super().__init__(root, split, image_set, "pairs", transform, target_transform, download, loader=loader)
+
+ self.pair_names, self.data, self.targets = self._get_pairs(self.images_dir)
+
+ def _get_pairs(self, images_dir: str) -> tuple[list[tuple[str, str]], list[tuple[str, str]], list[int]]:
+ pair_names, data, targets = [], [], []
+ with open(os.path.join(self.root, self.labels_file)) as f:
+ lines = f.readlines()
+ if self.split == "10fold":
+ n_folds, n_pairs = lines[0].split("\t")
+ n_folds, n_pairs = int(n_folds), int(n_pairs)
+ else:
+ n_folds, n_pairs = 1, int(lines[0])
+ s = 1
+
+ for fold in range(n_folds):
+ matched_pairs = [line.strip().split("\t") for line in lines[s : s + n_pairs]]
+ unmatched_pairs = [line.strip().split("\t") for line in lines[s + n_pairs : s + (2 * n_pairs)]]
+ s += 2 * n_pairs
+ for pair in matched_pairs:
+ img1, img2, same = self._get_path(pair[0], pair[1]), self._get_path(pair[0], pair[2]), 1
+ pair_names.append((pair[0], pair[0]))
+ data.append((img1, img2))
+ targets.append(same)
+ for pair in unmatched_pairs:
+ img1, img2, same = self._get_path(pair[0], pair[1]), self._get_path(pair[2], pair[3]), 0
+ pair_names.append((pair[0], pair[2]))
+ data.append((img1, img2))
+ targets.append(same)
+
+ return pair_names, data, targets
+
+ def __getitem__(self, index: int) -> tuple[Any, Any, int]:
+ """
+ Args:
+ index (int): Index
+
+ Returns:
+ tuple: (image1, image2, target) where target is `0` for different identities and `1` for same identities.
+ """
+ img1, img2 = self.data[index]
+ img1, img2 = self._loader(img1), self._loader(img2)
+ target = self.targets[index]
+
+ if self.transform is not None:
+ img1, img2 = self.transform(img1), self.transform(img2)
+
+ if self.target_transform is not None:
+ target = self.target_transform(target)
+
+ return img1, img2, target
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torchvision/datasets/lsun.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torchvision/datasets/lsun.py
new file mode 100644
index 0000000000000000000000000000000000000000..6f6c7a5eb63c21e042b4be0e059fa5df581acbaf
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torchvision/datasets/lsun.py
@@ -0,0 +1,168 @@
+import io
+import os.path
+import pickle
+import string
+from collections.abc import Iterable
+from pathlib import Path
+from typing import Any, Callable, cast, Optional, Union
+
+from PIL import Image
+
+from .utils import iterable_to_str, verify_str_arg
+from .vision import VisionDataset
+
+
+class LSUNClass(VisionDataset):
+ def __init__(
+ self, root: str, transform: Optional[Callable] = None, target_transform: Optional[Callable] = None
+ ) -> None:
+ import lmdb
+
+ super().__init__(root, transform=transform, target_transform=target_transform)
+
+ self.env = lmdb.open(root, max_readers=1, readonly=True, lock=False, readahead=False, meminit=False)
+ with self.env.begin(write=False) as txn:
+ self.length = txn.stat()["entries"]
+ cache_file = "_cache_" + "".join(c for c in root if c in string.ascii_letters)
+ if os.path.isfile(cache_file):
+ self.keys = pickle.load(open(cache_file, "rb"))
+ else:
+ with self.env.begin(write=False) as txn:
+ self.keys = [key for key in txn.cursor().iternext(keys=True, values=False)]
+ pickle.dump(self.keys, open(cache_file, "wb"))
+
+ def __getitem__(self, index: int) -> tuple[Any, Any]:
+ img, target = None, None
+ env = self.env
+ with env.begin(write=False) as txn:
+ imgbuf = txn.get(self.keys[index])
+
+ buf = io.BytesIO()
+ buf.write(imgbuf)
+ buf.seek(0)
+ img = Image.open(buf).convert("RGB")
+
+ if self.transform is not None:
+ img = self.transform(img)
+
+ if self.target_transform is not None:
+ target = self.target_transform(target)
+
+ return img, target
+
+ def __len__(self) -> int:
+ return self.length
+
+
+class LSUN(VisionDataset):
+ """`LSUN `_ dataset.
+
+ You will need to install the ``lmdb`` package to use this dataset: run
+ ``pip install lmdb``
+
+ Args:
+ root (str or ``pathlib.Path``): Root directory for the database files.
+ classes (string or list): One of {'train', 'val', 'test'} or a list of
+ categories to load, e.g. ['bedroom_train', 'church_outdoor_train'].
+ transform (callable, optional): A function/transform that takes in a PIL image
+ and returns a transformed version. E.g, ``transforms.RandomCrop``
+ target_transform (callable, optional): A function/transform that takes in the
+ target and transforms it.
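+
+ Example (illustrative sketch; the root path is a placeholder and must contain the ``*_lmdb`` folders)::
+
+ from torchvision.datasets import LSUN
+
+ dataset = LSUN("path/to/lsun", classes=["bedroom_train", "church_outdoor_train"])
+ img, target = dataset[0] # target is the index into ``dataset.classes``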
+ """
+
+ def __init__(
+ self,
+ root: Union[str, Path],
+ classes: Union[str, list[str]] = "train",
+ transform: Optional[Callable] = None,
+ target_transform: Optional[Callable] = None,
+ ) -> None:
+ super().__init__(root, transform=transform, target_transform=target_transform)
+ self.classes = self._verify_classes(classes)
+
+ # for each class, create an LSUNClassDataset
+ self.dbs = []
+ for c in self.classes:
+ self.dbs.append(LSUNClass(root=os.path.join(root, f"{c}_lmdb"), transform=transform))
+
+ self.indices = []
+ count = 0
+ for db in self.dbs:
+ count += len(db)
+ self.indices.append(count)
+
+ self.length = count
+
+ def _verify_classes(self, classes: Union[str, list[str]]) -> list[str]:
+ categories = [
+ "bedroom",
+ "bridge",
+ "church_outdoor",
+ "classroom",
+ "conference_room",
+ "dining_room",
+ "kitchen",
+ "living_room",
+ "restaurant",
+ "tower",
+ ]
+ dset_opts = ["train", "val", "test"]
+
+ try:
+ classes = cast(str, classes)
+ verify_str_arg(classes, "classes", dset_opts)
+ if classes == "test":
+ classes = [classes]
+ else:
+ classes = [c + "_" + classes for c in categories]
+ except ValueError:
+ if not isinstance(classes, Iterable):
+ msg = "Expected type str or Iterable for argument classes, but got type {}."
+ raise ValueError(msg.format(type(classes)))
+
+ classes = list(classes)
+ msg_fmtstr_type = "Expected type str for elements in argument classes, but got type {}."
+ for c in classes:
+ verify_str_arg(c, custom_msg=msg_fmtstr_type.format(type(c)))
+ c_short = c.split("_")
+ category, dset_opt = "_".join(c_short[:-1]), c_short[-1]
+
+ msg_fmtstr = "Unknown value '{}' for {}. Valid values are {{{}}}."
+ msg = msg_fmtstr.format(category, "LSUN class", iterable_to_str(categories))
+ verify_str_arg(category, valid_values=categories, custom_msg=msg)
+
+ msg = msg_fmtstr.format(dset_opt, "postfix", iterable_to_str(dset_opts))
+ verify_str_arg(dset_opt, valid_values=dset_opts, custom_msg=msg)
+
+ return classes
+
+ def __getitem__(self, index: int) -> tuple[Any, Any]:
+ """
+ Args:
+ index (int): Index
+
+ Returns:
+ tuple: Tuple (image, target) where target is the index of the target category.
+ """
+ target = 0
+ sub = 0
+ for ind in self.indices:
+ if index < ind:
+ break
+ target += 1
+ sub = ind
+
+ db = self.dbs[target]
+ index = index - sub
+
+ if self.target_transform is not None:
+ target = self.target_transform(target)
+
+ img, _ = db[index]
+ return img, target
+
+ def __len__(self) -> int:
+ return self.length
+
+ def extra_repr(self) -> str:
+ return "Classes: {classes}".format(**self.__dict__)
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torchvision/datasets/mnist.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torchvision/datasets/mnist.py
new file mode 100644
index 0000000000000000000000000000000000000000..06a658cbea476aaa5a286b8902649944889998d5
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torchvision/datasets/mnist.py
@@ -0,0 +1,560 @@
+import codecs
+import os
+import os.path
+import shutil
+import string
+import sys
+import warnings
+from pathlib import Path
+from typing import Any, Callable, Optional, Union
+from urllib.error import URLError
+
+import numpy as np
+import torch
+
+from ..utils import _Image_fromarray
+from .utils import _flip_byte_order, check_integrity, download_and_extract_archive, extract_archive, verify_str_arg
+from .vision import VisionDataset
+
+
+class MNIST(VisionDataset):
+ """`MNIST `_ Dataset.
+
+ Args:
+ root (str or ``pathlib.Path``): Root directory of dataset where ``MNIST/raw/train-images-idx3-ubyte``
+ and ``MNIST/raw/t10k-images-idx3-ubyte`` exist.
+ train (bool, optional): If True, creates dataset from ``train-images-idx3-ubyte``,
+ otherwise from ``t10k-images-idx3-ubyte``.
+ transform (callable, optional): A function/transform that takes in a PIL image
+ and returns a transformed version. E.g, ``transforms.RandomCrop``
+ target_transform (callable, optional): A function/transform that takes in the
+ target and transforms it.
+ download (bool, optional): If True, downloads the dataset from the internet and
+ puts it in root directory. If dataset is already downloaded, it is not
+ downloaded again.
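+
+ Example (illustrative sketch; the root path is a placeholder)::
+
+ from torchvision.datasets import MNIST
+
+ dataset = MNIST("path/to/data", train=True, download=True)
+ img, label = dataset[0] # 28x28 PIL image and its integer class label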
+ """
+
+ mirrors = [
+ "https://ossci-datasets.s3.amazonaws.com/mnist/",
+ "http://yann.lecun.com/exdb/mnist/",
+ ]
+
+ resources = [
+ ("train-images-idx3-ubyte.gz", "f68b3c2dcbeaaa9fbdd348bbdeb94873"),
+ ("train-labels-idx1-ubyte.gz", "d53e105ee54ea40749a09fcbcd1e9432"),
+ ("t10k-images-idx3-ubyte.gz", "9fb629c4189551a2d022fa330f9573f3"),
+ ("t10k-labels-idx1-ubyte.gz", "ec29112dd5afa0611ce80d1b7f02629c"),
+ ]
+
+ training_file = "training.pt"
+ test_file = "test.pt"
+ classes = [
+ "0 - zero",
+ "1 - one",
+ "2 - two",
+ "3 - three",
+ "4 - four",
+ "5 - five",
+ "6 - six",
+ "7 - seven",
+ "8 - eight",
+ "9 - nine",
+ ]
+
+ @property
+ def train_labels(self):
+ warnings.warn("train_labels has been renamed targets")
+ return self.targets
+
+ @property
+ def test_labels(self):
+ warnings.warn("test_labels has been renamed targets")
+ return self.targets
+
+ @property
+ def train_data(self):
+ warnings.warn("train_data has been renamed data")
+ return self.data
+
+ @property
+ def test_data(self):
+ warnings.warn("test_data has been renamed data")
+ return self.data
+
+ def __init__(
+ self,
+ root: Union[str, Path],
+ train: bool = True,
+ transform: Optional[Callable] = None,
+ target_transform: Optional[Callable] = None,
+ download: bool = False,
+ ) -> None:
+ super().__init__(root, transform=transform, target_transform=target_transform)
+ self.train = train # training set or test set
+
+ if self._check_legacy_exist():
+ self.data, self.targets = self._load_legacy_data()
+ return
+
+ if download:
+ self.download()
+
+ if not self._check_exists():
+ raise RuntimeError("Dataset not found. You can use download=True to download it")
+
+ self.data, self.targets = self._load_data()
+
+ def _check_legacy_exist(self):
+ processed_folder_exists = os.path.exists(self.processed_folder)
+ if not processed_folder_exists:
+ return False
+
+ return all(
+ check_integrity(os.path.join(self.processed_folder, file)) for file in (self.training_file, self.test_file)
+ )
+
+ def _load_legacy_data(self):
+ # This is for BC only. We no longer cache the data in a custom binary, but simply read from the raw data
+ # directly.
+ data_file = self.training_file if self.train else self.test_file
+ return torch.load(os.path.join(self.processed_folder, data_file), weights_only=True)
+
+ def _load_data(self):
+ image_file = f"{'train' if self.train else 't10k'}-images-idx3-ubyte"
+ data = read_image_file(os.path.join(self.raw_folder, image_file))
+
+ label_file = f"{'train' if self.train else 't10k'}-labels-idx1-ubyte"
+ targets = read_label_file(os.path.join(self.raw_folder, label_file))
+
+ return data, targets
+
+ def __getitem__(self, index: int) -> tuple[Any, Any]:
+ """
+ Args:
+ index (int): Index
+
+ Returns:
+ tuple: (image, target) where target is index of the target class.
+ """
+ img, target = self.data[index], int(self.targets[index])
+
+ # doing this so that it is consistent with all other datasets
+ # to return a PIL Image
+ img = _Image_fromarray(img.numpy(), mode="L")
+
+ if self.transform is not None:
+ img = self.transform(img)
+
+ if self.target_transform is not None:
+ target = self.target_transform(target)
+
+ return img, target
+
+ def __len__(self) -> int:
+ return len(self.data)
+
+ @property
+ def raw_folder(self) -> str:
+ return os.path.join(self.root, self.__class__.__name__, "raw")
+
+ @property
+ def processed_folder(self) -> str:
+ return os.path.join(self.root, self.__class__.__name__, "processed")
+
+ @property
+ def class_to_idx(self) -> dict[str, int]:
+ return {_class: i for i, _class in enumerate(self.classes)}
+
+ def _check_exists(self) -> bool:
+ return all(
+ check_integrity(os.path.join(self.raw_folder, os.path.splitext(os.path.basename(url))[0]))
+ for url, _ in self.resources
+ )
+
+ def download(self) -> None:
+ """Download the MNIST data if it doesn't exist already."""
+
+ if self._check_exists():
+ return
+
+ os.makedirs(self.raw_folder, exist_ok=True)
+
+ # download files
+ for filename, md5 in self.resources:
+ errors = []
+ for mirror in self.mirrors:
+ url = f"{mirror}{filename}"
+ try:
+ download_and_extract_archive(url, download_root=self.raw_folder, filename=filename, md5=md5)
+ except URLError as e:
+ errors.append(e)
+ continue
+ break
+ else:
+ s = f"Error downloading {filename}:\n"
+ for mirror, err in zip(self.mirrors, errors):
+ s += f"Tried {mirror}, got:\n{str(err)}\n"
+ raise RuntimeError(s)
+
+ def extra_repr(self) -> str:
+ split = "Train" if self.train is True else "Test"
+ return f"Split: {split}"
+
+
+class FashionMNIST(MNIST):
+ """`Fashion-MNIST `_ Dataset.
+
+ Args:
+ root (str or ``pathlib.Path``): Root directory of dataset where ``FashionMNIST/raw/train-images-idx3-ubyte``
+ and ``FashionMNIST/raw/t10k-images-idx3-ubyte`` exist.
+ train (bool, optional): If True, creates dataset from ``train-images-idx3-ubyte``,
+ otherwise from ``t10k-images-idx3-ubyte``.
+ transform (callable, optional): A function/transform that takes in a PIL image
+ and returns a transformed version. E.g, ``transforms.RandomCrop``
+ target_transform (callable, optional): A function/transform that takes in the
+ target and transforms it.
+ download (bool, optional): If True, downloads the dataset from the internet and
+ puts it in root directory. If dataset is already downloaded, it is not
+ downloaded again.
+ """
+
+ mirrors = ["http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/"]
+
+ resources = [
+ ("train-images-idx3-ubyte.gz", "8d4fb7e6c68d591d4c3dfef9ec88bf0d"),
+ ("train-labels-idx1-ubyte.gz", "25c81989df183df01b3e8a0aad5dffbe"),
+ ("t10k-images-idx3-ubyte.gz", "bef4ecab320f06d8554ea6380940ec79"),
+ ("t10k-labels-idx1-ubyte.gz", "bb300cfdad3c16e7a12a480ee83cd310"),
+ ]
+ classes = ["T-shirt/top", "Trouser", "Pullover", "Dress", "Coat", "Sandal", "Shirt", "Sneaker", "Bag", "Ankle boot"]
+
+
+class KMNIST(MNIST):
+ """`Kuzushiji-MNIST `_ Dataset.
+
+ Args:
+ root (str or ``pathlib.Path``): Root directory of dataset where ``KMNIST/raw/train-images-idx3-ubyte``
+ and ``KMNIST/raw/t10k-images-idx3-ubyte`` exist.
+ train (bool, optional): If True, creates dataset from ``train-images-idx3-ubyte``,
+ otherwise from ``t10k-images-idx3-ubyte``.
+ transform (callable, optional): A function/transform that takes in a PIL image
+ and returns a transformed version. E.g, ``transforms.RandomCrop``
+ target_transform (callable, optional): A function/transform that takes in the
+ target and transforms it.
+ download (bool, optional): If True, downloads the dataset from the internet and
+ puts it in root directory. If dataset is already downloaded, it is not
+ downloaded again.
+ """
+
+ mirrors = ["http://codh.rois.ac.jp/kmnist/dataset/kmnist/"]
+
+ resources = [
+ ("train-images-idx3-ubyte.gz", "bdb82020997e1d708af4cf47b453dcf7"),
+ ("train-labels-idx1-ubyte.gz", "e144d726b3acfaa3e44228e80efcd344"),
+ ("t10k-images-idx3-ubyte.gz", "5c965bf0a639b31b8f53240b1b52f4d7"),
+ ("t10k-labels-idx1-ubyte.gz", "7320c461ea6c1c855c0b718fb2a4b134"),
+ ]
+ classes = ["o", "ki", "su", "tsu", "na", "ha", "ma", "ya", "re", "wo"]
+
+
+class EMNIST(MNIST):
+ """`EMNIST `_ Dataset.
+
+ Args:
+ root (str or ``pathlib.Path``): Root directory of dataset where ``EMNIST/raw/train-images-idx3-ubyte``
+ and ``EMNIST/raw/t10k-images-idx3-ubyte`` exist.
+ split (string): The dataset has 6 different splits: ``byclass``, ``bymerge``,
+ ``balanced``, ``letters``, ``digits`` and ``mnist``. This argument specifies
+ which one to use.
+ train (bool, optional): If True, creates dataset from ``training.pt``,
+ otherwise from ``test.pt``.
+ download (bool, optional): If True, downloads the dataset from the internet and
+ puts it in root directory. If dataset is already downloaded, it is not
+ downloaded again.
+ transform (callable, optional): A function/transform that takes in a PIL image
+ and returns a transformed version. E.g, ``transforms.RandomCrop``
+ target_transform (callable, optional): A function/transform that takes in the
+ target and transforms it.
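+
+ Example (illustrative sketch; the root path is a placeholder)::
+
+ from torchvision.datasets import EMNIST
+
+ dataset = EMNIST("path/to/data", split="balanced", train=True, download=True)
+ img, label = dataset[0]
+ print(dataset.classes[label]) # character corresponding to the label in the chosen split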
+ """
+
+ url = "https://biometrics.nist.gov/cs_links/EMNIST/gzip.zip"
+ md5 = "58c8d27c78d21e728a6bc7b3cc06412e"
+ splits = ("byclass", "bymerge", "balanced", "letters", "digits", "mnist")
+ # Merged Classes assumes Same structure for both uppercase and lowercase version
+ _merged_classes = {"c", "i", "j", "k", "l", "m", "o", "p", "s", "u", "v", "w", "x", "y", "z"}
+ _all_classes = set(string.digits + string.ascii_letters)
+ classes_split_dict = {
+ "byclass": sorted(list(_all_classes)),
+ "bymerge": sorted(list(_all_classes - _merged_classes)),
+ "balanced": sorted(list(_all_classes - _merged_classes)),
+ "letters": ["N/A"] + list(string.ascii_lowercase),
+ "digits": list(string.digits),
+ "mnist": list(string.digits),
+ }
+
+ def __init__(self, root: Union[str, Path], split: str, **kwargs: Any) -> None:
+ self.split = verify_str_arg(split, "split", self.splits)
+ self.training_file = self._training_file(split)
+ self.test_file = self._test_file(split)
+ super().__init__(root, **kwargs)
+ self.classes = self.classes_split_dict[self.split]
+
+ @staticmethod
+ def _training_file(split) -> str:
+ return f"training_{split}.pt"
+
+ @staticmethod
+ def _test_file(split) -> str:
+ return f"test_{split}.pt"
+
+ @property
+ def _file_prefix(self) -> str:
+ return f"emnist-{self.split}-{'train' if self.train else 'test'}"
+
+ @property
+ def images_file(self) -> str:
+ return os.path.join(self.raw_folder, f"{self._file_prefix}-images-idx3-ubyte")
+
+ @property
+ def labels_file(self) -> str:
+ return os.path.join(self.raw_folder, f"{self._file_prefix}-labels-idx1-ubyte")
+
+ def _load_data(self):
+ return read_image_file(self.images_file), read_label_file(self.labels_file)
+
+ def _check_exists(self) -> bool:
+ return all(check_integrity(file) for file in (self.images_file, self.labels_file))
+
+ def download(self) -> None:
+ """Download the EMNIST data if it doesn't exist already."""
+
+ if self._check_exists():
+ return
+
+ os.makedirs(self.raw_folder, exist_ok=True)
+
+ download_and_extract_archive(self.url, download_root=self.raw_folder, md5=self.md5)
+ gzip_folder = os.path.join(self.raw_folder, "gzip")
+ for gzip_file in os.listdir(gzip_folder):
+ if gzip_file.endswith(".gz"):
+ extract_archive(os.path.join(gzip_folder, gzip_file), self.raw_folder)
+ shutil.rmtree(gzip_folder)
+
+
+class QMNIST(MNIST):
+ """`QMNIST `_ Dataset.
+
+ Args:
+ root (str or ``pathlib.Path``): Root directory of dataset whose ``raw``
+ subdir contains binary files of the datasets.
+ what (string,optional): Can be 'train', 'test', 'test10k',
+ 'test50k', or 'nist' for respectively the mnist compatible
+ training set, the 60k qmnist testing set, the 10k qmnist
+ examples that match the mnist testing set, the 50k
+ remaining qmnist testing examples, or all the nist
+ digits. The default is to select 'train' or 'test'
+ according to the compatibility argument 'train'.
+ compat (bool,optional): A boolean that says whether the target
+ for each example is class number (for compatibility with
+ the MNIST dataloader) or a torch vector containing the
+ full qmnist information. Default=True.
+ train (bool,optional,compatibility): When argument 'what' is
+ not specified, this boolean decides whether to load the
+ training set or the testing set. Default: True.
+ download (bool, optional): If True, downloads the dataset from
+ the internet and puts it in root directory. If dataset is
+ already downloaded, it is not downloaded again.
+ transform (callable, optional): A function/transform that
+ takes in a PIL image and returns a transformed
+ version. E.g, ``transforms.RandomCrop``
+ target_transform (callable, optional): A function/transform
+ that takes in the target and transforms it.
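+
+ Example (illustrative sketch; the root path is a placeholder)::
+
+ from torchvision.datasets import QMNIST
+
+ dataset = QMNIST("path/to/data", what="test10k", compat=True, download=True)
+ img, label = dataset[0] # with compat=True the target is the plain class index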
+ """
+
+ subsets = {"train": "train", "test": "test", "test10k": "test", "test50k": "test", "nist": "nist"}
+ resources: dict[str, list[tuple[str, str]]] = { # type: ignore[assignment]
+ "train": [
+ (
+ "https://raw.githubusercontent.com/facebookresearch/qmnist/master/qmnist-train-images-idx3-ubyte.gz",
+ "ed72d4157d28c017586c42bc6afe6370",
+ ),
+ (
+ "https://raw.githubusercontent.com/facebookresearch/qmnist/master/qmnist-train-labels-idx2-int.gz",
+ "0058f8dd561b90ffdd0f734c6a30e5e4",
+ ),
+ ],
+ "test": [
+ (
+ "https://raw.githubusercontent.com/facebookresearch/qmnist/master/qmnist-test-images-idx3-ubyte.gz",
+ "1394631089c404de565df7b7aeaf9412",
+ ),
+ (
+ "https://raw.githubusercontent.com/facebookresearch/qmnist/master/qmnist-test-labels-idx2-int.gz",
+ "5b5b05890a5e13444e108efe57b788aa",
+ ),
+ ],
+ "nist": [
+ (
+ "https://raw.githubusercontent.com/facebookresearch/qmnist/master/xnist-images-idx3-ubyte.xz",
+ "7f124b3b8ab81486c9d8c2749c17f834",
+ ),
+ (
+ "https://raw.githubusercontent.com/facebookresearch/qmnist/master/xnist-labels-idx2-int.xz",
+ "5ed0e788978e45d4a8bd4b7caec3d79d",
+ ),
+ ],
+ }
+ classes = [
+ "0 - zero",
+ "1 - one",
+ "2 - two",
+ "3 - three",
+ "4 - four",
+ "5 - five",
+ "6 - six",
+ "7 - seven",
+ "8 - eight",
+ "9 - nine",
+ ]
+
+ def __init__(
+ self, root: Union[str, Path], what: Optional[str] = None, compat: bool = True, train: bool = True, **kwargs: Any
+ ) -> None:
+ if what is None:
+ what = "train" if train else "test"
+ self.what = verify_str_arg(what, "what", tuple(self.subsets.keys()))
+ self.compat = compat
+ self.data_file = what + ".pt"
+ self.training_file = self.data_file
+ self.test_file = self.data_file
+ super().__init__(root, train, **kwargs)
+
+ @property
+ def images_file(self) -> str:
+ (url, _), _ = self.resources[self.subsets[self.what]]
+ return os.path.join(self.raw_folder, os.path.splitext(os.path.basename(url))[0])
+
+ @property
+ def labels_file(self) -> str:
+ _, (url, _) = self.resources[self.subsets[self.what]]
+ return os.path.join(self.raw_folder, os.path.splitext(os.path.basename(url))[0])
+
+ def _check_exists(self) -> bool:
+ return all(check_integrity(file) for file in (self.images_file, self.labels_file))
+
+ def _load_data(self):
+ data = read_sn3_pascalvincent_tensor(self.images_file)
+ if data.dtype != torch.uint8:
+ raise TypeError(f"data should be of dtype torch.uint8 instead of {data.dtype}")
+ if data.ndimension() != 3:
+ raise ValueError("data should have 3 dimensions instead of {data.ndimension()}")
+
+ targets = read_sn3_pascalvincent_tensor(self.labels_file).long()
+ if targets.ndimension() != 2:
+ raise ValueError(f"targets should have 2 dimensions instead of {targets.ndimension()}")
+
+ if self.what == "test10k":
+ data = data[0:10000, :, :].clone()
+ targets = targets[0:10000, :].clone()
+ elif self.what == "test50k":
+ data = data[10000:, :, :].clone()
+ targets = targets[10000:, :].clone()
+
+ return data, targets
+
+ def download(self) -> None:
+ """Download the QMNIST data if it doesn't exist already.
+ Note that we only download what has been asked for (argument 'what').
+ """
+ if self._check_exists():
+ return
+
+ os.makedirs(self.raw_folder, exist_ok=True)
+ split = self.resources[self.subsets[self.what]]
+
+ for url, md5 in split:
+ download_and_extract_archive(url, self.raw_folder, md5=md5)
+
+ def __getitem__(self, index: int) -> tuple[Any, Any]:
+ # redefined to handle the compat flag
+ img, target = self.data[index], self.targets[index]
+ img = _Image_fromarray(img.numpy(), mode="L")
+ if self.transform is not None:
+ img = self.transform(img)
+ if self.compat:
+ target = int(target[0])
+ if self.target_transform is not None:
+ target = self.target_transform(target)
+ return img, target
+
+ def extra_repr(self) -> str:
+ return f"Split: {self.what}"
+
+
+def get_int(b: bytes) -> int:
+ return int(codecs.encode(b, "hex"), 16)
+
+
+SN3_PASCALVINCENT_TYPEMAP = {
+ 8: torch.uint8,
+ 9: torch.int8,
+ 11: torch.int16,
+ 12: torch.int32,
+ 13: torch.float32,
+ 14: torch.float64,
+}
+
+
+def read_sn3_pascalvincent_tensor(path: str, strict: bool = True) -> torch.Tensor:
+ """Read a SN3 file in "Pascal Vincent" format (Lush file 'libidx/idx-io.lsh').
+ Argument may be a filename, compressed filename, or file object.
+ """
+ # read
+ with open(path, "rb") as f:
+ data = f.read()
+
+ # parse
+ if sys.byteorder == "little" or sys.platform == "aix":
+ magic = get_int(data[0:4])
+ nd = magic % 256
+ ty = magic // 256
+ else:
+ nd = get_int(data[0:1])
+ ty = get_int(data[1:2]) + get_int(data[2:3]) * 256 + get_int(data[3:4]) * 256 * 256
+
+ assert 1 <= nd <= 3
+ assert 8 <= ty <= 14
+ torch_type = SN3_PASCALVINCENT_TYPEMAP[ty]
+ s = [get_int(data[4 * (i + 1) : 4 * (i + 2)]) for i in range(nd)]
+
+ if sys.byteorder == "big" and not sys.platform == "aix":
+ for i in range(len(s)):
+ s[i] = int.from_bytes(s[i].to_bytes(4, byteorder="little"), byteorder="big", signed=False)
+
+ parsed = torch.frombuffer(bytearray(data), dtype=torch_type, offset=(4 * (nd + 1)))
+
+ # The MNIST format uses the big endian byte order, while `torch.frombuffer` uses whatever the system uses. In case
+ # that is little endian and the dtype has more than one byte, we need to flip them.
+ if sys.byteorder == "little" and parsed.element_size() > 1:
+ parsed = _flip_byte_order(parsed)
+
+ assert parsed.shape[0] == np.prod(s) or not strict
+ return parsed.view(*s)
+
+
+def read_label_file(path: str) -> torch.Tensor:
+ x = read_sn3_pascalvincent_tensor(path, strict=False)
+ if x.dtype != torch.uint8:
+ raise TypeError(f"x should be of dtype torch.uint8 instead of {x.dtype}")
+ if x.ndimension() != 1:
+ raise ValueError(f"x should have 1 dimension instead of {x.ndimension()}")
+ return x.long()
+
+
+def read_image_file(path: str) -> torch.Tensor:
+ x = read_sn3_pascalvincent_tensor(path, strict=False)
+ if x.dtype != torch.uint8:
+ raise TypeError(f"x should be of dtype torch.uint8 instead of {x.dtype}")
+ if x.ndimension() != 3:
+ raise ValueError(f"x should have 3 dimension instead of {x.ndimension()}")
+ return x
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torchvision/datasets/moving_mnist.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torchvision/datasets/moving_mnist.py
new file mode 100644
index 0000000000000000000000000000000000000000..4466d82291bfa908aff424bb66ae704289b97274
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torchvision/datasets/moving_mnist.py
@@ -0,0 +1,94 @@
+import os.path
+from pathlib import Path
+from typing import Callable, Optional, Union
+
+import numpy as np
+import torch
+from torchvision.datasets.utils import download_url, verify_str_arg
+from torchvision.datasets.vision import VisionDataset
+
+
+class MovingMNIST(VisionDataset):
+ """`MovingMNIST `_ Dataset.
+
+ Args:
+ root (str or ``pathlib.Path``): Root directory of dataset where ``MovingMNIST/mnist_test_seq.npy`` exists.
+ split (string, optional): The dataset split, supports ``None`` (default), ``"train"`` and ``"test"``.
+ If ``split=None``, the full data is returned.
+ split_ratio (int, optional): The split ratio of number of frames. If ``split="train"``, the first ``split_ratio``
+ frames ``data[:, :split_ratio]`` are returned. If ``split="test"``, the remaining frames ``data[:, split_ratio:]``
+ are returned. If ``split=None``, this parameter is ignored and all frames are returned.
+ download (bool, optional): If true, downloads the dataset from the internet and
+ puts it in root directory. If dataset is already downloaded, it is not
+ downloaded again.
+ transform (callable, optional): A function/transform that takes in a torch Tensor
+ and returns a transformed version. E.g, ``transforms.RandomCrop``
+ """
+
+ _URL = "http://www.cs.toronto.edu/~nitish/unsupervised_video/mnist_test_seq.npy"
+
+ def __init__(
+ self,
+ root: Union[str, Path],
+ split: Optional[str] = None,
+ split_ratio: int = 10,
+ download: bool = False,
+ transform: Optional[Callable] = None,
+ ) -> None:
+ super().__init__(root, transform=transform)
+
+ self._base_folder = os.path.join(self.root, self.__class__.__name__)
+ self._filename = self._URL.split("/")[-1]
+
+ if split is not None:
+ verify_str_arg(split, "split", ("train", "test"))
+ self.split = split
+
+ if not isinstance(split_ratio, int):
+ raise TypeError(f"`split_ratio` should be an integer, but got {type(split_ratio)}")
+ elif not (1 <= split_ratio <= 19):
+ raise ValueError(f"`split_ratio` should be `1 <= split_ratio <= 19`, but got {split_ratio} instead.")
+ self.split_ratio = split_ratio
+
+ if download:
+ self.download()
+
+ if not self._check_exists():
+ raise RuntimeError("Dataset not found. You can use download=True to download it.")
+
+ data = torch.from_numpy(np.load(os.path.join(self._base_folder, self._filename)))
+ if self.split == "train":
+ data = data[: self.split_ratio]
+ elif self.split == "test":
+ data = data[self.split_ratio :]
+ self.data = data.transpose(0, 1).unsqueeze(2).contiguous()
+
+ def __getitem__(self, idx: int) -> torch.Tensor:
+ """
+ Args:
+ idx (int): Index
+ Returns:
+ torch.Tensor: Video frames (torch Tensor[T, C, H, W]). The `T` is the number of frames.
+ """
+ data = self.data[idx]
+ if self.transform is not None:
+ data = self.transform(data)
+
+ return data
+
+ def __len__(self) -> int:
+ return len(self.data)
+
+ def _check_exists(self) -> bool:
+ return os.path.exists(os.path.join(self._base_folder, self._filename))
+
+ def download(self) -> None:
+ if self._check_exists():
+ return
+
+ download_url(
+ url=self._URL,
+ root=self._base_folder,
+ filename=self._filename,
+ md5="be083ec986bfe91a449d63653c411eb2",
+ )
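A minimal usage sketch for the ``MovingMNIST`` class above; ``root="data"`` is an illustrative path, and ``download=True`` fetches ``mnist_test_seq.npy`` on the first run.

```python
import torch
from torchvision.datasets import MovingMNIST

# With split="train" and the default split_ratio=10, each sample is a
# uint8 video tensor of shape [10, 1, 64, 64].
dataset = MovingMNIST(root="data", split="train", download=True)
loader = torch.utils.data.DataLoader(dataset, batch_size=4, shuffle=True)

clips = next(iter(loader))  # [4, 10, 1, 64, 64]
print(clips.shape, clips.dtype)
```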
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torchvision/datasets/omniglot.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torchvision/datasets/omniglot.py
new file mode 100644
index 0000000000000000000000000000000000000000..22fd59aa9c2f107864eda6a79f1bea7ac643710c
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torchvision/datasets/omniglot.py
@@ -0,0 +1,107 @@
+from os.path import join
+from pathlib import Path
+from typing import Any, Callable, Optional, Union
+
+from PIL import Image
+
+from .utils import check_integrity, download_and_extract_archive, list_dir, list_files
+from .vision import VisionDataset
+
+
+class Omniglot(VisionDataset):
+ """`Omniglot `_ Dataset.
+
+ Args:
+ root (str or ``pathlib.Path``): Root directory of dataset where directory
+ ``omniglot-py`` exists.
+ background (bool, optional): If True, creates dataset from the "background" set, otherwise
+ creates from the "evaluation" set. This terminology is defined by the authors.
+ transform (callable, optional): A function/transform that takes in a PIL image
+ and returns a transformed version. E.g, ``transforms.RandomCrop``
+ target_transform (callable, optional): A function/transform that takes in the
+ target and transforms it.
+ download (bool, optional): If true, downloads the dataset zip files from the internet and
+ puts them in the root directory. If the zip files are already downloaded, they are not
+ downloaded again.
+ loader (callable, optional): A function to load an image given its path.
+ By default, it uses PIL as its image loader, but users could also pass in
+ ``torchvision.io.decode_image`` for decoding image data into tensors directly.
+ """
+
+ folder = "omniglot-py"
+ download_url_prefix = "https://raw.githubusercontent.com/brendenlake/omniglot/master/python"
+ zips_md5 = {
+ "images_background": "68d2efa1b9178cc56df9314c21c6e718",
+ "images_evaluation": "6b91aef0f799c5bb55b94e3f2daec811",
+ }
+
+ def __init__(
+ self,
+ root: Union[str, Path],
+ background: bool = True,
+ transform: Optional[Callable] = None,
+ target_transform: Optional[Callable] = None,
+ download: bool = False,
+ loader: Optional[Callable[[Union[str, Path]], Any]] = None,
+ ) -> None:
+ super().__init__(join(root, self.folder), transform=transform, target_transform=target_transform)
+ self.background = background
+
+ if download:
+ self.download()
+
+ if not self._check_integrity():
+ raise RuntimeError("Dataset not found or corrupted. You can use download=True to download it")
+
+ self.target_folder = join(self.root, self._get_target_folder())
+ self._alphabets = list_dir(self.target_folder)
+ self._characters: list[str] = sum(
+ ([join(a, c) for c in list_dir(join(self.target_folder, a))] for a in self._alphabets), []
+ )
+ self._character_images = [
+ [(image, idx) for image in list_files(join(self.target_folder, character), ".png")]
+ for idx, character in enumerate(self._characters)
+ ]
+ self._flat_character_images: list[tuple[str, int]] = sum(self._character_images, [])
+ self.loader = loader
+
+ def __len__(self) -> int:
+ return len(self._flat_character_images)
+
+ def __getitem__(self, index: int) -> tuple[Any, Any]:
+ """
+ Args:
+ index (int): Index
+
+ Returns:
+ tuple: (image, target) where target is index of the target character class.
+ """
+ image_name, character_class = self._flat_character_images[index]
+ image_path = join(self.target_folder, self._characters[character_class], image_name)
+ image = Image.open(image_path, mode="r").convert("L") if self.loader is None else self.loader(image_path)
+
+ if self.transform:
+ image = self.transform(image)
+
+ if self.target_transform:
+ character_class = self.target_transform(character_class)
+
+ return image, character_class
+
+ def _check_integrity(self) -> bool:
+ zip_filename = self._get_target_folder()
+ if not check_integrity(join(self.root, zip_filename + ".zip"), self.zips_md5[zip_filename]):
+ return False
+ return True
+
+ def download(self) -> None:
+ if self._check_integrity():
+ return
+
+ filename = self._get_target_folder()
+ zip_filename = filename + ".zip"
+ url = self.download_url_prefix + "/" + zip_filename
+ download_and_extract_archive(url, self.root, filename=zip_filename, md5=self.zips_md5[filename])
+
+ def _get_target_folder(self) -> str:
+ return "images_background" if self.background else "images_evaluation"
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torchvision/datasets/oxford_iiit_pet.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torchvision/datasets/oxford_iiit_pet.py
new file mode 100644
index 0000000000000000000000000000000000000000..e598920f8fe392f45a212dc7251ec84d5bb399b4
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torchvision/datasets/oxford_iiit_pet.py
@@ -0,0 +1,135 @@
+import os
+import os.path
+import pathlib
+from collections.abc import Sequence
+from typing import Any, Callable, Optional, Union
+
+from PIL import Image
+
+from .utils import download_and_extract_archive, verify_str_arg
+from .vision import VisionDataset
+
+
+class OxfordIIITPet(VisionDataset):
+ """`Oxford-IIIT Pet Dataset `_.
+
+ Args:
+ root (str or ``pathlib.Path``): Root directory of the dataset.
+ split (string, optional): The dataset split, supports ``"trainval"`` (default) or ``"test"``.
+ target_types (string, sequence of strings, optional): Types of target to use. Can be ``category`` (default),
+ ``binary-category`` or ``segmentation``. Can also be a list to output a tuple with all specified target types. The types represent:
+
+ - ``category`` (int): Label for one of the 37 pet categories.
+ - ``binary-category`` (int): Binary label for cat or dog.
+ - ``segmentation`` (PIL image): Segmentation trimap of the image.
+
+ If empty, ``None`` will be returned as target.
+
+ transform (callable, optional): A function/transform that takes in a PIL image and returns a transformed
+ version. E.g, ``transforms.RandomCrop``.
+ target_transform (callable, optional): A function/transform that takes in the target and transforms it.
+ transforms (callable, optional): A function/transform that takes input sample
+ and its target as entry and returns a transformed version.
+ download (bool, optional): If True, downloads the dataset from the internet and puts it into
+ ``root/oxford-iiit-pet``. If dataset is already downloaded, it is not downloaded again.
+ """
+
+ _RESOURCES = (
+ ("https://www.robots.ox.ac.uk/~vgg/data/pets/data/images.tar.gz", "5c4f3ee8e5d25df40f4fd59a7f44e54c"),
+ ("https://www.robots.ox.ac.uk/~vgg/data/pets/data/annotations.tar.gz", "95a8c909bbe2e81eed6a22bccdf3f68f"),
+ )
+ _VALID_TARGET_TYPES = ("category", "binary-category", "segmentation")
+
+ def __init__(
+ self,
+ root: Union[str, pathlib.Path],
+ split: str = "trainval",
+ target_types: Union[Sequence[str], str] = "category",
+ transforms: Optional[Callable] = None,
+ transform: Optional[Callable] = None,
+ target_transform: Optional[Callable] = None,
+ download: bool = False,
+ ):
+ self._split = verify_str_arg(split, "split", ("trainval", "test"))
+ if isinstance(target_types, str):
+ target_types = [target_types]
+ self._target_types = [
+ verify_str_arg(target_type, "target_types", self._VALID_TARGET_TYPES) for target_type in target_types
+ ]
+
+ super().__init__(root, transforms=transforms, transform=transform, target_transform=target_transform)
+ self._base_folder = pathlib.Path(self.root) / "oxford-iiit-pet"
+ self._images_folder = self._base_folder / "images"
+ self._anns_folder = self._base_folder / "annotations"
+ self._segs_folder = self._anns_folder / "trimaps"
+
+ if download:
+ self._download()
+
+ if not self._check_exists():
+ raise RuntimeError("Dataset not found. You can use download=True to download it")
+
+ image_ids = []
+ self._labels = []
+ self._bin_labels = []
+ with open(self._anns_folder / f"{self._split}.txt") as file:
+ for line in file:
+ image_id, label, bin_label, _ = line.strip().split()
+ image_ids.append(image_id)
+ self._labels.append(int(label) - 1)
+ self._bin_labels.append(int(bin_label) - 1)
+
+ self.bin_classes = ["Cat", "Dog"]
+ self.classes = [
+ " ".join(part.title() for part in raw_cls.split("_"))
+ for raw_cls, _ in sorted(
+ {(image_id.rsplit("_", 1)[0], label) for image_id, label in zip(image_ids, self._labels)},
+ key=lambda image_id_and_label: image_id_and_label[1],
+ )
+ ]
+ self.bin_class_to_idx = dict(zip(self.bin_classes, range(len(self.bin_classes))))
+ self.class_to_idx = dict(zip(self.classes, range(len(self.classes))))
+
+ self._images = [self._images_folder / f"{image_id}.jpg" for image_id in image_ids]
+ self._segs = [self._segs_folder / f"{image_id}.png" for image_id in image_ids]
+
+ def __len__(self) -> int:
+ return len(self._images)
+
+ def __getitem__(self, idx: int) -> tuple[Any, Any]:
+ image = Image.open(self._images[idx]).convert("RGB")
+
+ target: Any = []
+ for target_type in self._target_types:
+ if target_type == "category":
+ target.append(self._labels[idx])
+ elif target_type == "binary-category":
+ target.append(self._bin_labels[idx])
+ else: # target_type == "segmentation"
+ target.append(Image.open(self._segs[idx]))
+
+ if not target:
+ target = None
+ elif len(target) == 1:
+ target = target[0]
+ else:
+ target = tuple(target)
+
+ if self.transforms:
+ image, target = self.transforms(image, target)
+
+ return image, target
+
+ def _check_exists(self) -> bool:
+ for folder in (self._images_folder, self._anns_folder):
+ if not (os.path.exists(folder) and os.path.isdir(folder)):
+ return False
+ else:
+ return True
+
+ def _download(self) -> None:
+ if self._check_exists():
+ return
+
+ for url, md5 in self._RESOURCES:
+ download_and_extract_archive(url, download_root=str(self._base_folder), md5=md5)
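A sketch of requesting multiple target types from ``OxfordIIITPet``; with two entries in ``target_types``, the target becomes a tuple, as implemented in ``__getitem__`` above. The root path is an assumption.

```python
from torchvision.datasets import OxfordIIITPet

dataset = OxfordIIITPet(
    root="data",
    split="trainval",
    target_types=["category", "segmentation"],  # target -> (class index, trimap PIL image)
    download=True,
)
image, (label, trimap) = dataset[0]
print(dataset.classes[label], image.size, trimap.size)
```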
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torchvision/datasets/pcam.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torchvision/datasets/pcam.py
new file mode 100644
index 0000000000000000000000000000000000000000..00d10f6a01035bf4cafa57231176c826c463c6f3
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torchvision/datasets/pcam.py
@@ -0,0 +1,134 @@
+import pathlib
+from typing import Any, Callable, Optional, Union
+
+from PIL import Image
+
+from .utils import _decompress, download_file_from_google_drive, verify_str_arg
+from .vision import VisionDataset
+
+
+class PCAM(VisionDataset):
+ """`PCAM Dataset `_.
+
+ The PatchCamelyon dataset is a binary classification dataset with 327,680
+ color images (96px x 96px), extracted from histopathologic scans of lymph node
+ sections. Each image is annotated with a binary label indicating presence of
+ metastatic tissue.
+
+ This dataset requires the ``h5py`` package which you can install with ``pip install h5py``.
+
+ Args:
+ root (str or ``pathlib.Path``): Root directory of the dataset.
+ split (string, optional): The dataset split, supports ``"train"`` (default), ``"test"`` or ``"val"``.
+ transform (callable, optional): A function/transform that takes in a PIL image and returns a transformed
+ version. E.g, ``transforms.RandomCrop``.
+ target_transform (callable, optional): A function/transform that takes in the target and transforms it.
+ download (bool, optional): If True, downloads the dataset from the internet and puts it into ``root/pcam``. If
+ dataset is already downloaded, it is not downloaded again.
+
+ .. warning::
+
+ To download the dataset, `gdown <https://github.com/wkentaro/gdown>`_ is required.
+ """
+
+ _FILES = {
+ "train": {
+ "images": (
+ "camelyonpatch_level_2_split_train_x.h5", # Data file name
+ "1Ka0XfEMiwgCYPdTI-vv6eUElOBnKFKQ2", # Google Drive ID
+ "1571f514728f59376b705fc836ff4b63", # md5 hash
+ ),
+ "targets": (
+ "camelyonpatch_level_2_split_train_y.h5",
+ "1269yhu3pZDP8UYFQs-NYs3FPwuK-nGSG",
+ "35c2d7259d906cfc8143347bb8e05be7",
+ ),
+ },
+ "test": {
+ "images": (
+ "camelyonpatch_level_2_split_test_x.h5",
+ "1qV65ZqZvWzuIVthK8eVDhIwrbnsJdbg_",
+ "d8c2d60d490dbd479f8199bdfa0cf6ec",
+ ),
+ "targets": (
+ "camelyonpatch_level_2_split_test_y.h5",
+ "17BHrSrwWKjYsOgTMmoqrIjDy6Fa2o_gP",
+ "60a7035772fbdb7f34eb86d4420cf66a",
+ ),
+ },
+ "val": {
+ "images": (
+ "camelyonpatch_level_2_split_valid_x.h5",
+ "1hgshYGWK8V-eGRy8LToWJJgDU_rXWVJ3",
+ "d5b63470df7cfa627aeec8b9dc0c066e",
+ ),
+ "targets": (
+ "camelyonpatch_level_2_split_valid_y.h5",
+ "1bH8ZRbhSVAhScTS0p9-ZzGnX91cHT3uO",
+ "2b85f58b927af9964a4c15b8f7e8f179",
+ ),
+ },
+ }
+
+ def __init__(
+ self,
+ root: Union[str, pathlib.Path],
+ split: str = "train",
+ transform: Optional[Callable] = None,
+ target_transform: Optional[Callable] = None,
+ download: bool = False,
+ ):
+ try:
+ import h5py
+
+ self.h5py = h5py
+ except ImportError:
+ raise RuntimeError(
+ "h5py is not found. This dataset needs to have h5py installed: please run pip install h5py"
+ )
+
+ self._split = verify_str_arg(split, "split", ("train", "test", "val"))
+
+ super().__init__(root, transform=transform, target_transform=target_transform)
+ self._base_folder = pathlib.Path(self.root) / "pcam"
+
+ if download:
+ self._download()
+
+ if not self._check_exists():
+ raise RuntimeError("Dataset not found. You can use download=True to download it")
+
+ def __len__(self) -> int:
+ images_file = self._FILES[self._split]["images"][0]
+ with self.h5py.File(self._base_folder / images_file) as images_data:
+ return images_data["x"].shape[0]
+
+ def __getitem__(self, idx: int) -> tuple[Any, Any]:
+ images_file = self._FILES[self._split]["images"][0]
+ with self.h5py.File(self._base_folder / images_file) as images_data:
+ image = Image.fromarray(images_data["x"][idx]).convert("RGB")
+
+ targets_file = self._FILES[self._split]["targets"][0]
+ with self.h5py.File(self._base_folder / targets_file) as targets_data:
+ target = int(targets_data["y"][idx, 0, 0, 0]) # shape is [num_images, 1, 1, 1]
+
+ if self.transform:
+ image = self.transform(image)
+ if self.target_transform:
+ target = self.target_transform(target)
+
+ return image, target
+
+ def _check_exists(self) -> bool:
+ images_file = self._FILES[self._split]["images"][0]
+ targets_file = self._FILES[self._split]["targets"][0]
+ return all(self._base_folder.joinpath(h5_file).exists() for h5_file in (images_file, targets_file))
+
+ def _download(self) -> None:
+ if self._check_exists():
+ return
+
+ for file_name, file_id, md5 in self._FILES[self._split].values():
+ archive_name = file_name + ".gz"
+ download_file_from_google_drive(file_id, str(self._base_folder), filename=archive_name, md5=md5)
+ _decompress(str(self._base_folder / archive_name))
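A usage sketch for ``PCAM``; it assumes ``h5py`` is installed (and ``gdown`` for the download) and that ``root="data"`` is a writable, illustrative path.

```python
from torchvision import transforms
from torchvision.datasets import PCAM

dataset = PCAM(root="data", split="val", transform=transforms.ToTensor(), download=True)
image, target = dataset[0]  # [3, 96, 96] tensor, binary label (0/1)
print(len(dataset), image.shape, target)
```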
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torchvision/datasets/phototour.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torchvision/datasets/phototour.py
new file mode 100644
index 0000000000000000000000000000000000000000..5d625b51ecef08164b82328ec0d18338eecda31c
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torchvision/datasets/phototour.py
@@ -0,0 +1,230 @@
+import os
+from pathlib import Path
+from typing import Any, Callable, Optional, Union
+
+import numpy as np
+import torch
+from PIL import Image
+
+from .utils import download_url
+from .vision import VisionDataset
+
+
+class PhotoTour(VisionDataset):
+ """`Multi-view Stereo Correspondence `_ Dataset.
+
+ .. note::
+
+ We only provide the newer version of the dataset, since the authors state that it
+
+ is more suitable for training descriptors based on difference of Gaussian, or Harris corners, as the
+ patches are centred on real interest point detections, rather than being projections of 3D points as is the
+ case in the old dataset.
+
+ The original dataset is available under http://phototour.cs.washington.edu/patches/default.htm.
+
+
+ Args:
+ root (str or ``pathlib.Path``): Root directory where images are.
+ name (string): Name of the dataset to load.
+ transform (callable, optional): A function/transform that takes in a PIL image
+ and returns a transformed version.
+ download (bool, optional): If true, downloads the dataset from the internet and
+ puts it in root directory. If dataset is already downloaded, it is not
+ downloaded again.
+
+ """
+
+ urls = {
+ "notredame_harris": [
+ "http://matthewalunbrown.com/patchdata/notredame_harris.zip",
+ "notredame_harris.zip",
+ "69f8c90f78e171349abdf0307afefe4d",
+ ],
+ "yosemite_harris": [
+ "http://matthewalunbrown.com/patchdata/yosemite_harris.zip",
+ "yosemite_harris.zip",
+ "a73253d1c6fbd3ba2613c45065c00d46",
+ ],
+ "liberty_harris": [
+ "http://matthewalunbrown.com/patchdata/liberty_harris.zip",
+ "liberty_harris.zip",
+ "c731fcfb3abb4091110d0ae8c7ba182c",
+ ],
+ "notredame": [
+ "http://icvl.ee.ic.ac.uk/vbalnt/notredame.zip",
+ "notredame.zip",
+ "509eda8535847b8c0a90bbb210c83484",
+ ],
+ "yosemite": ["http://icvl.ee.ic.ac.uk/vbalnt/yosemite.zip", "yosemite.zip", "533b2e8eb7ede31be40abc317b2fd4f0"],
+ "liberty": ["http://icvl.ee.ic.ac.uk/vbalnt/liberty.zip", "liberty.zip", "fdd9152f138ea5ef2091746689176414"],
+ }
+ means = {
+ "notredame": 0.4854,
+ "yosemite": 0.4844,
+ "liberty": 0.4437,
+ "notredame_harris": 0.4854,
+ "yosemite_harris": 0.4844,
+ "liberty_harris": 0.4437,
+ }
+ stds = {
+ "notredame": 0.1864,
+ "yosemite": 0.1818,
+ "liberty": 0.2019,
+ "notredame_harris": 0.1864,
+ "yosemite_harris": 0.1818,
+ "liberty_harris": 0.2019,
+ }
+ lens = {
+ "notredame": 468159,
+ "yosemite": 633587,
+ "liberty": 450092,
+ "liberty_harris": 379587,
+ "yosemite_harris": 450912,
+ "notredame_harris": 325295,
+ }
+ image_ext = "bmp"
+ info_file = "info.txt"
+ matches_files = "m50_100000_100000_0.txt"
+
+ def __init__(
+ self,
+ root: Union[str, Path],
+ name: str,
+ train: bool = True,
+ transform: Optional[Callable] = None,
+ download: bool = False,
+ ) -> None:
+ super().__init__(root, transform=transform)
+ self.name = name
+ self.data_dir = os.path.join(self.root, name)
+ self.data_down = os.path.join(self.root, f"{name}.zip")
+ self.data_file = os.path.join(self.root, f"{name}.pt")
+
+ self.train = train
+ self.mean = self.means[name]
+ self.std = self.stds[name]
+
+ if download:
+ self.download()
+
+ if not self._check_datafile_exists():
+ self.cache()
+
+ # load the serialized data
+ self.data, self.labels, self.matches = torch.load(self.data_file, weights_only=True)
+
+ def __getitem__(self, index: int) -> Union[torch.Tensor, tuple[Any, Any, torch.Tensor]]:
+ """
+ Args:
+ index (int): Index
+
+ Returns:
+ tuple: (data1, data2, matches)
+ """
+ if self.train:
+ data = self.data[index]
+ if self.transform is not None:
+ data = self.transform(data)
+ return data
+ m = self.matches[index]
+ data1, data2 = self.data[m[0]], self.data[m[1]]
+ if self.transform is not None:
+ data1 = self.transform(data1)
+ data2 = self.transform(data2)
+ return data1, data2, m[2]
+
+ def __len__(self) -> int:
+ return len(self.data if self.train else self.matches)
+
+ def _check_datafile_exists(self) -> bool:
+ return os.path.exists(self.data_file)
+
+ def _check_downloaded(self) -> bool:
+ return os.path.exists(self.data_dir)
+
+ def download(self) -> None:
+ if self._check_datafile_exists():
+ return
+
+ if not self._check_downloaded():
+ # download files
+ url = self.urls[self.name][0]
+ filename = self.urls[self.name][1]
+ md5 = self.urls[self.name][2]
+ fpath = os.path.join(self.root, filename)
+
+ download_url(url, self.root, filename, md5)
+
+ import zipfile
+
+ with zipfile.ZipFile(fpath, "r") as z:
+ z.extractall(self.data_dir)
+
+ os.unlink(fpath)
+
+ def cache(self) -> None:
+ # process and save as torch files
+
+ dataset = (
+ read_image_file(self.data_dir, self.image_ext, self.lens[self.name]),
+ read_info_file(self.data_dir, self.info_file),
+ read_matches_files(self.data_dir, self.matches_files),
+ )
+
+ with open(self.data_file, "wb") as f:
+ torch.save(dataset, f)
+
+ def extra_repr(self) -> str:
+ split = "Train" if self.train is True else "Test"
+ return f"Split: {split}"
+
+
+def read_image_file(data_dir: str, image_ext: str, n: int) -> torch.Tensor:
+ """Return a Tensor containing the patches"""
+
+ def PIL2array(_img: Image.Image) -> np.ndarray:
+ """Convert PIL image type to numpy 2D array"""
+ return np.array(_img.getdata(), dtype=np.uint8).reshape(64, 64)
+
+ def find_files(_data_dir: str, _image_ext: str) -> list[str]:
+ """Return a list with the file names of the images containing the patches"""
+ files = []
+ # find those files with the specified extension
+ for file_dir in os.listdir(_data_dir):
+ if file_dir.endswith(_image_ext):
+ files.append(os.path.join(_data_dir, file_dir))
+ return sorted(files) # sort files in ascend order to keep relations
+
+ patches = []
+ list_files = find_files(data_dir, image_ext)
+
+ for fpath in list_files:
+ img = Image.open(fpath)
+ for y in range(0, img.height, 64):
+ for x in range(0, img.width, 64):
+ patch = img.crop((x, y, x + 64, y + 64))
+ patches.append(PIL2array(patch))
+ return torch.ByteTensor(np.array(patches[:n]))
+
+
+def read_info_file(data_dir: str, info_file: str) -> torch.Tensor:
+ """Return a Tensor containing the list of labels
+ Read the file and keep only the ID of the 3D point.
+ """
+ with open(os.path.join(data_dir, info_file)) as f:
+ labels = [int(line.split()[0]) for line in f]
+ return torch.LongTensor(labels)
+
+
+def read_matches_files(data_dir: str, matches_file: str) -> torch.Tensor:
+ """Return a Tensor containing the ground truth matches
+ Read the file and keep only 3D point ID.
+ Matches are represented with a 1, non matches with a 0.
+ """
+ matches = []
+ with open(os.path.join(data_dir, matches_file)) as f:
+ for line in f:
+ line_split = line.split()
+ matches.append([int(line_split[0]), int(line_split[3]), int(line_split[1] == line_split[4])])
+ return torch.LongTensor(matches)
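A sketch of the two access modes of ``PhotoTour``: ``train=True`` yields single 64x64 patches, while ``train=False`` yields ``(patch1, patch2, match)`` triples read from the ``m50_*`` ground-truth file. The root path is illustrative and the archives are large.

```python
from torchvision.datasets import PhotoTour

train_set = PhotoTour(root="data/phototour", name="notredame", train=True, download=True)
test_set = PhotoTour(root="data/phototour", name="notredame", train=False, download=True)

patch = train_set[0]            # uint8 tensor [64, 64]
p1, p2, is_match = test_set[0]  # two patches plus a 0/1 match label
print(patch.shape, p1.shape, int(is_match))
```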
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torchvision/datasets/places365.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torchvision/datasets/places365.py
new file mode 100644
index 0000000000000000000000000000000000000000..51b845de7234635a17a6ef87a9649898c9cdc0b2
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torchvision/datasets/places365.py
@@ -0,0 +1,176 @@
+import os
+from os import path
+from pathlib import Path
+from typing import Any, Callable, cast, Optional, Union
+from urllib.parse import urljoin
+
+from .folder import default_loader
+from .utils import check_integrity, download_and_extract_archive, verify_str_arg
+from .vision import VisionDataset
+
+
+class Places365(VisionDataset):
+ r"""`Places365 `_ classification dataset.
+
+ Args:
+ root (str or ``pathlib.Path``): Root directory of the Places365 dataset.
+ split (string, optional): The dataset split. Can be one of ``train-standard`` (default), ``train-challenge``,
+ ``val``, ``test``.
+ small (bool, optional): If ``True``, uses the small images, i.e. resized to 256 x 256 pixels, instead of the
+ high resolution ones.
+ download (bool, optional): If ``True``, downloads the dataset components and places them in ``root``. Already
+ downloaded archives are not downloaded again.
+ transform (callable, optional): A function/transform that takes in a PIL image
+ and returns a transformed version. E.g, ``transforms.RandomCrop``
+ target_transform (callable, optional): A function/transform that takes in the
+ target and transforms it.
+ loader (callable, optional): A function to load an image given its path.
+
+ Attributes:
+ classes (list): List of the class names.
+ class_to_idx (dict): Dict with items (class_name, class_index).
+ imgs (list): List of (image path, class_index) tuples
+ targets (list): The class_index value for each image in the dataset
+
+ Raises:
+ RuntimeError: If ``download is False`` and the meta files, i.e. the devkit, are not present or corrupted.
+ RuntimeError: If ``download is True`` and the image archive is already extracted.
+ """
+
+ _SPLITS = ("train-standard", "train-challenge", "val", "test")
+ _BASE_URL = "http://data.csail.mit.edu/places/places365/"
+ # {variant: (archive, md5)}
+ _DEVKIT_META = {
+ "standard": ("filelist_places365-standard.tar", "35a0585fee1fa656440f3ab298f8479c"),
+ "challenge": ("filelist_places365-challenge.tar", "70a8307e459c3de41690a7c76c931734"),
+ }
+ # (file, md5)
+ _CATEGORIES_META = ("categories_places365.txt", "06c963b85866bd0649f97cb43dd16673")
+ # {split: (file, md5)}
+ _FILE_LIST_META = {
+ "train-standard": ("places365_train_standard.txt", "30f37515461640559006b8329efbed1a"),
+ "train-challenge": ("places365_train_challenge.txt", "b2931dc997b8c33c27e7329c073a6b57"),
+ "val": ("places365_val.txt", "e9f2fd57bfd9d07630173f4e8708e4b1"),
+ "test": ("places365_test.txt", "2fce8233fe493576d724142e45d93653"),
+ }
+ # {(split, small): (file, md5)}
+ _IMAGES_META = {
+ ("train-standard", False): ("train_large_places365standard.tar", "67e186b496a84c929568076ed01a8aa1"),
+ ("train-challenge", False): ("train_large_places365challenge.tar", "605f18e68e510c82b958664ea134545f"),
+ ("val", False): ("val_large.tar", "9b71c4993ad89d2d8bcbdc4aef38042f"),
+ ("test", False): ("test_large.tar", "41a4b6b724b1d2cd862fb3871ed59913"),
+ ("train-standard", True): ("train_256_places365standard.tar", "53ca1c756c3d1e7809517cc47c5561c5"),
+ ("train-challenge", True): ("train_256_places365challenge.tar", "741915038a5e3471ec7332404dfb64ef"),
+ ("val", True): ("val_256.tar", "e27b17d8d44f4af9a78502beb927f808"),
+ ("test", True): ("test_256.tar", "f532f6ad7b582262a2ec8009075e186b"),
+ }
+
+ def __init__(
+ self,
+ root: Union[str, Path],
+ split: str = "train-standard",
+ small: bool = False,
+ download: bool = False,
+ transform: Optional[Callable] = None,
+ target_transform: Optional[Callable] = None,
+ loader: Callable[[str], Any] = default_loader,
+ ) -> None:
+ super().__init__(root, transform=transform, target_transform=target_transform)
+
+ self.split = self._verify_split(split)
+ self.small = small
+ self.loader = loader
+
+ self.classes, self.class_to_idx = self.load_categories(download)
+ self.imgs, self.targets = self.load_file_list(download)
+
+ if download:
+ self.download_images()
+
+ def __getitem__(self, index: int) -> tuple[Any, Any]:
+ file, target = self.imgs[index]
+ image = self.loader(file)
+
+ if self.transforms is not None:
+ image, target = self.transforms(image, target)
+
+ return image, target
+
+ def __len__(self) -> int:
+ return len(self.imgs)
+
+ @property
+ def variant(self) -> str:
+ return "challenge" if "challenge" in self.split else "standard"
+
+ @property
+ def images_dir(self) -> str:
+ size = "256" if self.small else "large"
+ if self.split.startswith("train"):
+ dir = f"data_{size}_{self.variant}"
+ else:
+ dir = f"{self.split}_{size}"
+ return path.join(self.root, dir)
+
+ def load_categories(self, download: bool = True) -> tuple[list[str], dict[str, int]]:
+ def process(line: str) -> tuple[str, int]:
+ cls, idx = line.split()
+ return cls, int(idx)
+
+ file, md5 = self._CATEGORIES_META
+ file = path.join(self.root, file)
+ if not self._check_integrity(file, md5, download):
+ self.download_devkit()
+
+ with open(file) as fh:
+ class_to_idx = dict(process(line) for line in fh)
+
+ return sorted(class_to_idx.keys()), class_to_idx
+
+ def load_file_list(
+ self, download: bool = True
+ ) -> tuple[list[tuple[str, Union[int, None]]], list[Union[int, None]]]:
+ def process(line: str, sep="/") -> tuple[str, Union[int, None]]:
+ image, idx = (line.split() + [None])[:2]
+ image = cast(str, image)
+ idx = int(idx) if idx is not None else None
+ return path.join(self.images_dir, image.lstrip(sep).replace(sep, os.sep)), idx
+
+ file, md5 = self._FILE_LIST_META[self.split]
+ file = path.join(self.root, file)
+ if not self._check_integrity(file, md5, download):
+ self.download_devkit()
+
+ with open(file) as fh:
+ images = [process(line) for line in fh]
+
+ _, targets = zip(*images)
+ return images, list(targets)
+
+ def download_devkit(self) -> None:
+ file, md5 = self._DEVKIT_META[self.variant]
+ download_and_extract_archive(urljoin(self._BASE_URL, file), self.root, md5=md5)
+
+ def download_images(self) -> None:
+ if path.exists(self.images_dir):
+ return
+
+ file, md5 = self._IMAGES_META[(self.split, self.small)]
+ download_and_extract_archive(urljoin(self._BASE_URL, file), self.root, md5=md5)
+
+ if self.split.startswith("train"):
+ os.rename(self.images_dir.rsplit("_", 1)[0], self.images_dir)
+
+ def extra_repr(self) -> str:
+ return "\n".join(("Split: {split}", "Small: {small}")).format(**self.__dict__)
+
+ def _verify_split(self, split: str) -> str:
+ return verify_str_arg(split, "split", self._SPLITS)
+
+ def _check_integrity(self, file: str, md5: str, download: bool) -> bool:
+ integrity = check_integrity(file, md5=md5)
+ if not integrity and not download:
+ raise RuntimeError(
+ f"The file {file} does not exist or is corrupted. You can set download=True to download it."
+ )
+ return integrity
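A usage sketch for ``Places365``; ``small=True`` selects the 256x256 images, and ``download=True`` fetches the devkit plus the image archive for the chosen split on first use. The root directory is an assumption.

```python
from torchvision.datasets import Places365

dataset = Places365(root="data/places365", split="val", small=True, download=True)
image, target = dataset[0]  # PIL image, class index
print(len(dataset.classes), dataset.classes[target], image.size)
```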
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torchvision/datasets/rendered_sst2.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torchvision/datasets/rendered_sst2.py
new file mode 100644
index 0000000000000000000000000000000000000000..62ad3bc6d0018a3c297607d6ad4e221ed0b7595a
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torchvision/datasets/rendered_sst2.py
@@ -0,0 +1,89 @@
+from pathlib import Path
+from typing import Any, Callable, Optional, Union
+
+from .folder import default_loader, make_dataset
+from .utils import download_and_extract_archive, verify_str_arg
+from .vision import VisionDataset
+
+
+class RenderedSST2(VisionDataset):
+ """`The Rendered SST2 Dataset `_.
+
+ Rendered SST2 is an image classification dataset used to evaluate a model's capability for optical
+ character recognition. This dataset was generated by rendering sentences from the Stanford Sentiment
+ Treebank v2 dataset.
+
+ This dataset contains two classes (positive and negative) and is divided into three splits: a train
+ split containing 6920 images (3610 positive and 3310 negative), a validation split containing 872 images
+ (444 positive and 428 negative), and a test split containing 1821 images (909 positive and 912 negative).
+
+ Args:
+ root (str or ``pathlib.Path``): Root directory of the dataset.
+ split (string, optional): The dataset split, supports ``"train"`` (default), `"val"` and ``"test"``.
+ transform (callable, optional): A function/transform that takes in a PIL image or torch.Tensor, depends on the given loader,
+ and returns a transformed version. E.g, ``transforms.RandomCrop``
+ target_transform (callable, optional): A function/transform that takes in the target and transforms it.
+ download (bool, optional): If True, downloads the dataset from the internet and
+ puts it in root directory. If dataset is already downloaded, it is not
+ downloaded again. Default is False.
+ loader (callable, optional): A function to load an image given its path.
+ By default, it uses PIL as its image loader, but users could also pass in
+ ``torchvision.io.decode_image`` for decoding image data into tensors directly.
+ """
+
+ _URL = "https://openaipublic.azureedge.net/clip/data/rendered-sst2.tgz"
+ _MD5 = "2384d08e9dcfa4bd55b324e610496ee5"
+
+ def __init__(
+ self,
+ root: Union[str, Path],
+ split: str = "train",
+ transform: Optional[Callable] = None,
+ target_transform: Optional[Callable] = None,
+ download: bool = False,
+ loader: Callable[[str], Any] = default_loader,
+ ) -> None:
+ super().__init__(root, transform=transform, target_transform=target_transform)
+ self._split = verify_str_arg(split, "split", ("train", "val", "test"))
+ self._split_to_folder = {"train": "train", "val": "valid", "test": "test"}
+ self._base_folder = Path(self.root) / "rendered-sst2"
+ self.classes = ["negative", "positive"]
+ self.class_to_idx = {"negative": 0, "positive": 1}
+
+ if download:
+ self._download()
+
+ if not self._check_exists():
+ raise RuntimeError("Dataset not found. You can use download=True to download it")
+
+ self._samples = make_dataset(str(self._base_folder / self._split_to_folder[self._split]), extensions=("png",))
+ self.loader = loader
+
+ def __len__(self) -> int:
+ return len(self._samples)
+
+ def __getitem__(self, idx: int) -> tuple[Any, Any]:
+ image_file, label = self._samples[idx]
+ image = self.loader(image_file)
+
+ if self.transform:
+ image = self.transform(image)
+
+ if self.target_transform:
+ label = self.target_transform(label)
+
+ return image, label
+
+ def extra_repr(self) -> str:
+ return f"split={self._split}"
+
+ def _check_exists(self) -> bool:
+ for class_label in set(self.classes):
+ if not (self._base_folder / self._split_to_folder[self._split] / class_label).is_dir():
+ return False
+ return True
+
+ def _download(self) -> None:
+ if self._check_exists():
+ return
+ download_and_extract_archive(self._URL, download_root=self.root, md5=self._MD5)
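A usage sketch for ``RenderedSST2`` with an illustrative root path.

```python
from torchvision.datasets import RenderedSST2

dataset = RenderedSST2(root="data", split="val", download=True)
image, label = dataset[0]  # PIL image, 0 = negative / 1 = positive
print(dataset.classes[label], image.size)
```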
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torchvision/datasets/sbd.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torchvision/datasets/sbd.py
new file mode 100644
index 0000000000000000000000000000000000000000..091e8698197584064974664474083a87d64f2908
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torchvision/datasets/sbd.py
@@ -0,0 +1,126 @@
+import os
+import shutil
+from pathlib import Path
+from typing import Any, Callable, Optional, Union
+
+import numpy as np
+from PIL import Image
+
+from .utils import download_and_extract_archive, download_url, verify_str_arg
+from .vision import VisionDataset
+
+
+class SBDataset(VisionDataset):
+ """`Semantic Boundaries Dataset `_
+
+ The SBD currently contains annotations from 11355 images taken from the PASCAL VOC 2011 dataset.
+
+ .. note ::
+
+ Please note that the train and val splits included with this dataset are different from
+ the splits in the PASCAL VOC dataset. In particular some "train" images might be part of
+ VOC2012 val.
+ If you are interested in testing on VOC 2012 val, then use `image_set='train_noval'`,
+ which excludes all val images.
+
+ .. warning::
+
+ This class needs `scipy <https://scipy.org>`_ to load target files from `.mat` format.
+
+ Args:
+ root (str or ``pathlib.Path``): Root directory of the Semantic Boundaries Dataset
+ image_set (string, optional): Select the image_set to use, ``train``, ``val`` or ``train_noval``.
+ Image set ``train_noval`` excludes VOC 2012 val images.
+ mode (string, optional): Select target type. Possible values 'boundaries' or 'segmentation'.
+ In case of 'boundaries', the target is an array of shape `[num_classes, H, W]`,
+ where `num_classes=20`.
+ download (bool, optional): If true, downloads the dataset from the internet and
+ puts it in root directory. If dataset is already downloaded, it is not
+ downloaded again.
+ transforms (callable, optional): A function/transform that takes input sample and its target as entry
+ and returns a transformed version. Input sample is PIL image and target is a numpy array
+ if `mode='boundaries'` or PIL image if `mode='segmentation'`.
+ """
+
+ url = "https://www2.eecs.berkeley.edu/Research/Projects/CS/vision/grouping/semantic_contours/benchmark.tgz"
+ md5 = "82b4d87ceb2ed10f6038a1cba92111cb"
+ filename = "benchmark.tgz"
+
+ voc_train_url = "https://www.cs.cornell.edu/~bharathh/train_noval.txt"
+ voc_split_filename = "train_noval.txt"
+ voc_split_md5 = "79bff800c5f0b1ec6b21080a3c066722"
+
+ def __init__(
+ self,
+ root: Union[str, Path],
+ image_set: str = "train",
+ mode: str = "boundaries",
+ download: bool = False,
+ transforms: Optional[Callable] = None,
+ ) -> None:
+
+ try:
+ from scipy.io import loadmat
+
+ self._loadmat = loadmat
+ except ImportError:
+ raise RuntimeError("Scipy is not found. This dataset needs to have scipy installed: pip install scipy")
+
+ super().__init__(root, transforms)
+ self.image_set = verify_str_arg(image_set, "image_set", ("train", "val", "train_noval"))
+ self.mode = verify_str_arg(mode, "mode", ("segmentation", "boundaries"))
+ self.num_classes = 20
+
+ sbd_root = self.root
+ image_dir = os.path.join(sbd_root, "img")
+ mask_dir = os.path.join(sbd_root, "cls")
+
+ if download:
+ download_and_extract_archive(self.url, self.root, filename=self.filename, md5=self.md5)
+ extracted_ds_root = os.path.join(self.root, "benchmark_RELEASE", "dataset")
+ for f in ["cls", "img", "inst", "train.txt", "val.txt"]:
+ old_path = os.path.join(extracted_ds_root, f)
+ shutil.move(old_path, sbd_root)
+ if self.image_set == "train_noval":
+ # Note: this is failing as of June 2024 https://github.com/pytorch/vision/issues/8471
+ download_url(self.voc_train_url, sbd_root, self.voc_split_filename, self.voc_split_md5)
+
+ if not os.path.isdir(sbd_root):
+ raise RuntimeError("Dataset not found or corrupted. You can use download=True to download it")
+
+ split_f = os.path.join(sbd_root, image_set.rstrip("\n") + ".txt")
+
+ with open(os.path.join(split_f)) as fh:
+ file_names = [x.strip() for x in fh.readlines()]
+
+ self.images = [os.path.join(image_dir, x + ".jpg") for x in file_names]
+ self.masks = [os.path.join(mask_dir, x + ".mat") for x in file_names]
+
+ self._get_target = self._get_segmentation_target if self.mode == "segmentation" else self._get_boundaries_target
+
+ def _get_segmentation_target(self, filepath: str) -> Image.Image:
+ mat = self._loadmat(filepath)
+ return Image.fromarray(mat["GTcls"][0]["Segmentation"][0])
+
+ def _get_boundaries_target(self, filepath: str) -> np.ndarray:
+ mat = self._loadmat(filepath)
+ return np.concatenate(
+ [np.expand_dims(mat["GTcls"][0]["Boundaries"][0][i][0].toarray(), axis=0) for i in range(self.num_classes)],
+ axis=0,
+ )
+
+ def __getitem__(self, index: int) -> tuple[Any, Any]:
+ img = Image.open(self.images[index]).convert("RGB")
+ target = self._get_target(self.masks[index])
+
+ if self.transforms is not None:
+ img, target = self.transforms(img, target)
+
+ return img, target
+
+ def __len__(self) -> int:
+ return len(self.images)
+
+ def extra_repr(self) -> str:
+ lines = ["Image set: {image_set}", "Mode: {mode}"]
+ return "\n".join(lines).format(**self.__dict__)
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torchvision/datasets/sbu.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torchvision/datasets/sbu.py
new file mode 100644
index 0000000000000000000000000000000000000000..c0c97503eec5f886fe1a188bdd797c710f93daa6
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torchvision/datasets/sbu.py
@@ -0,0 +1,114 @@
+import os
+from pathlib import Path
+from typing import Any, Callable, Optional, Union
+
+from .folder import default_loader
+
+from .utils import check_integrity, download_and_extract_archive, download_url
+from .vision import VisionDataset
+
+
+class SBU(VisionDataset):
+ """`SBU Captioned Photo `_ Dataset.
+
+ Args:
+ root (str or ``pathlib.Path``): Root directory of dataset where tarball
+ ``SBUCaptionedPhotoDataset.tar.gz`` exists.
+ transform (callable, optional): A function/transform that takes in a PIL image or torch.Tensor, depends on the given loader,
+ and returns a transformed version. E.g, ``transforms.RandomCrop``
+ target_transform (callable, optional): A function/transform that takes in the
+ target and transforms it.
+ download (bool, optional): If True, downloads the dataset from the internet and
+ puts it in root directory. If dataset is already downloaded, it is not
+ downloaded again.
+ loader (callable, optional): A function to load an image given its path.
+ By default, it uses PIL as its image loader, but users could also pass in
+ ``torchvision.io.decode_image`` for decoding image data into tensors directly.
+ """
+
+ url = "https://www.cs.rice.edu/~vo9/sbucaptions/SBUCaptionedPhotoDataset.tar.gz"
+ filename = "SBUCaptionedPhotoDataset.tar.gz"
+ md5_checksum = "9aec147b3488753cf758b4d493422285"
+
+ def __init__(
+ self,
+ root: Union[str, Path],
+ transform: Optional[Callable] = None,
+ target_transform: Optional[Callable] = None,
+ download: bool = True,
+ loader: Callable[[str], Any] = default_loader,
+ ) -> None:
+ super().__init__(root, transform=transform, target_transform=target_transform)
+ self.loader = loader
+
+ if download:
+ self.download()
+
+ if not self._check_integrity():
+ raise RuntimeError("Dataset not found or corrupted. You can use download=True to download it")
+
+ # Read the caption for each photo
+ self.photos = []
+ self.captions = []
+
+ file1 = os.path.join(self.root, "dataset", "SBU_captioned_photo_dataset_urls.txt")
+ file2 = os.path.join(self.root, "dataset", "SBU_captioned_photo_dataset_captions.txt")
+
+ for line1, line2 in zip(open(file1), open(file2)):
+ url = line1.rstrip()
+ photo = os.path.basename(url)
+ filename = os.path.join(self.root, "dataset", photo)
+ if os.path.exists(filename):
+ caption = line2.rstrip()
+ self.photos.append(photo)
+ self.captions.append(caption)
+
+ def __getitem__(self, index: int) -> tuple[Any, Any]:
+ """
+ Args:
+ index (int): Index
+
+ Returns:
+ tuple: (image, target) where target is a caption for the photo.
+ """
+ filename = os.path.join(self.root, "dataset", self.photos[index])
+ img = self.loader(filename)
+ if self.transform is not None:
+ img = self.transform(img)
+
+ target = self.captions[index]
+ if self.target_transform is not None:
+ target = self.target_transform(target)
+
+ return img, target
+
+ def __len__(self) -> int:
+ """The number of photos in the dataset."""
+ return len(self.photos)
+
+ def _check_integrity(self) -> bool:
+ """Check the md5 checksum of the downloaded tarball."""
+ root = self.root
+ fpath = os.path.join(root, self.filename)
+ if not check_integrity(fpath, self.md5_checksum):
+ return False
+ return True
+
+ def download(self) -> None:
+ """Download and extract the tarball, and download each individual photo."""
+
+ if self._check_integrity():
+ return
+
+ download_and_extract_archive(self.url, self.root, self.root, self.filename, self.md5_checksum)
+
+ # Download individual photos
+ with open(os.path.join(self.root, "dataset", "SBU_captioned_photo_dataset_urls.txt")) as fh:
+ for line in fh:
+ url = line.rstrip()
+ try:
+ download_url(url, os.path.join(self.root, "dataset"))
+ except OSError:
+ # The images point to public images on Flickr.
+ # Note: Images might be removed by users at anytime.
+ pass
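A usage sketch for ``SBU``; note that ``download=True`` fetches the caption tarball and then tries to download every photo from Flickr, which is slow, and photos that are no longer available are simply omitted from the dataset. The root path is illustrative.

```python
from torchvision.datasets import SBU

dataset = SBU(root="data/sbu", download=True)  # large download; only still-available photos are kept
image, caption = dataset[0]                    # PIL image and its caption string
print(len(dataset), caption)
```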
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torchvision/datasets/semeion.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torchvision/datasets/semeion.py
new file mode 100644
index 0000000000000000000000000000000000000000..cd8d139cb21fcd4e2a8157f7071ac43f62b72288
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torchvision/datasets/semeion.py
@@ -0,0 +1,92 @@
+import os.path
+from pathlib import Path
+from typing import Any, Callable, Optional, Union
+
+import numpy as np
+
+from ..utils import _Image_fromarray
+from .utils import check_integrity, download_url
+from .vision import VisionDataset
+
+
+class SEMEION(VisionDataset):
+ r"""`SEMEION `_ Dataset.
+
+ Args:
+ root (str or ``pathlib.Path``): Root directory of dataset where the data file
+ ``semeion.data`` exists.
+ transform (callable, optional): A function/transform that takes in a PIL image
+ and returns a transformed version. E.g, ``transforms.RandomCrop``
+ target_transform (callable, optional): A function/transform that takes in the
+ target and transforms it.
+ download (bool, optional): If true, downloads the dataset from the internet and
+ puts it in root directory. If dataset is already downloaded, it is not
+ downloaded again.
+
+ """
+
+ url = "http://archive.ics.uci.edu/ml/machine-learning-databases/semeion/semeion.data"
+ filename = "semeion.data"
+ md5_checksum = "cb545d371d2ce14ec121470795a77432"
+
+ def __init__(
+ self,
+ root: Union[str, Path],
+ transform: Optional[Callable] = None,
+ target_transform: Optional[Callable] = None,
+ download: bool = True,
+ ) -> None:
+ super().__init__(root, transform=transform, target_transform=target_transform)
+
+ if download:
+ self.download()
+
+ if not self._check_integrity():
+ raise RuntimeError("Dataset not found or corrupted. You can use download=True to download it")
+
+ fp = os.path.join(self.root, self.filename)
+ data = np.loadtxt(fp)
+ # scale the binary pixel values to 8-bit unsigned integers (white = 255)
+ self.data = (data[:, :256] * 255).astype("uint8")
+ self.data = np.reshape(self.data, (-1, 16, 16))
+ self.labels = np.nonzero(data[:, 256:])[1]
+
+ def __getitem__(self, index: int) -> tuple[Any, Any]:
+ """
+ Args:
+ index (int): Index
+
+ Returns:
+ tuple: (image, target) where target is index of the target class.
+ """
+ img, target = self.data[index], int(self.labels[index])
+
+ # doing this so that it is consistent with all other datasets
+ # to return a PIL Image
+ img = _Image_fromarray(img, mode="L")
+
+ if self.transform is not None:
+ img = self.transform(img)
+
+ if self.target_transform is not None:
+ target = self.target_transform(target)
+
+ return img, target
+
+ def __len__(self) -> int:
+ return len(self.data)
+
+ def _check_integrity(self) -> bool:
+ root = self.root
+ fpath = os.path.join(root, self.filename)
+ if not check_integrity(fpath, self.md5_checksum):
+ return False
+ return True
+
+ def download(self) -> None:
+ if self._check_integrity():
+ return
+
+ root = self.root
+ download_url(self.url, root, self.filename, self.md5_checksum)
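A usage sketch for ``SEMEION`` with an illustrative root path.

```python
from torchvision import transforms
from torchvision.datasets import SEMEION

dataset = SEMEION(root="data/semeion", transform=transforms.ToTensor(), download=True)
image, target = dataset[0]  # [1, 16, 16] tensor, digit label 0-9
print(len(dataset), image.shape, target)
```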
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torchvision/datasets/stanford_cars.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torchvision/datasets/stanford_cars.py
new file mode 100644
index 0000000000000000000000000000000000000000..e73fb1f3141dad7689ad3ed0ef0a580a2bc02b14
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torchvision/datasets/stanford_cars.py
@@ -0,0 +1,105 @@
+import pathlib
+from typing import Any, Callable, Optional, Union
+
+from .folder import default_loader
+
+from .utils import verify_str_arg
+from .vision import VisionDataset
+
+
+class StanfordCars(VisionDataset):
+ """Stanford Cars Dataset
+
+ The Cars dataset contains 16,185 images of 196 classes of cars. The data is
+ split into 8,144 training images and 8,041 testing images, where each class
+ has been split roughly 50-50.
+
+ The original URL was https://ai.stanford.edu/~jkrause/cars/car_dataset.html,
+ but the dataset is no longer available online.
+
+ .. note::
+
+ This class needs `scipy <https://scipy.org>`_ to load target files from `.mat` format.
+
+ Args:
+ root (str or ``pathlib.Path``): Root directory of dataset
+ split (string, optional): The dataset split, supports ``"train"`` (default) or ``"test"``.
+ transform (callable, optional): A function/transform that takes in a PIL image or torch.Tensor, depends on the given loader,
+ and returns a transformed version. E.g, ``transforms.RandomCrop``
+ target_transform (callable, optional): A function/transform that takes in the
+ target and transforms it.
+ download (bool, optional): This parameter exists for backward compatibility but it does not
+ download the dataset, since the original URL is not available anymore.
+ loader (callable, optional): A function to load an image given its path.
+ By default, it uses PIL as its image loader, but users could also pass in
+ ``torchvision.io.decode_image`` for decoding image data into tensors directly.
+ """
+
+ def __init__(
+ self,
+ root: Union[str, pathlib.Path],
+ split: str = "train",
+ transform: Optional[Callable] = None,
+ target_transform: Optional[Callable] = None,
+ download: bool = False,
+ loader: Callable[[str], Any] = default_loader,
+ ) -> None:
+
+ try:
+ import scipy.io as sio
+ except ImportError:
+ raise RuntimeError("Scipy is not found. This dataset needs to have scipy installed: pip install scipy")
+
+ super().__init__(root, transform=transform, target_transform=target_transform)
+
+ self._split = verify_str_arg(split, "split", ("train", "test"))
+ self._base_folder = pathlib.Path(root) / "stanford_cars"
+ devkit = self._base_folder / "devkit"
+
+ if self._split == "train":
+ self._annotations_mat_path = devkit / "cars_train_annos.mat"
+ self._images_base_path = self._base_folder / "cars_train"
+ else:
+ self._annotations_mat_path = self._base_folder / "cars_test_annos_withlabels.mat"
+ self._images_base_path = self._base_folder / "cars_test"
+
+ if download:
+ self.download()
+
+ if not self._check_exists():
+ raise RuntimeError("Dataset not found.")
+
+ self._samples = [
+ (
+ str(self._images_base_path / annotation["fname"]),
+ annotation["class"] - 1, # Original target mapping starts from 1, hence -1
+ )
+ for annotation in sio.loadmat(self._annotations_mat_path, squeeze_me=True)["annotations"]
+ ]
+
+ self.classes = sio.loadmat(str(devkit / "cars_meta.mat"), squeeze_me=True)["class_names"].tolist()
+ self.class_to_idx = {cls: i for i, cls in enumerate(self.classes)}
+ self.loader = loader
+
+ def __len__(self) -> int:
+ return len(self._samples)
+
+ def __getitem__(self, idx: int) -> tuple[Any, Any]:
+ """Returns pil_image and class_id for given index"""
+ image_path, target = self._samples[idx]
+ image = self.loader(image_path)
+
+ if self.transform is not None:
+ image = self.transform(image)
+ if self.target_transform is not None:
+ target = self.target_transform(target)
+ return image, target
+
+ def _check_exists(self) -> bool:
+ if not (self._base_folder / "devkit").is_dir():
+ return False
+
+ return self._annotations_mat_path.exists() and self._images_base_path.is_dir()
+
+ def download(self):
+ raise ValueError("The original URL is broken so the StanfordCars dataset cannot be downloaded anymore.")
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torchvision/datasets/stl10.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torchvision/datasets/stl10.py
new file mode 100644
index 0000000000000000000000000000000000000000..6d7212a1b55578334efcfe55861165d4b196326c
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torchvision/datasets/stl10.py
@@ -0,0 +1,174 @@
+import os.path
+from pathlib import Path
+from typing import Any, Callable, cast, Optional, Union
+
+import numpy as np
+from PIL import Image
+
+from .utils import check_integrity, download_and_extract_archive, verify_str_arg
+from .vision import VisionDataset
+
+
+class STL10(VisionDataset):
+ """`STL10 `_ Dataset.
+
+ Args:
+ root (str or ``pathlib.Path``): Root directory of dataset where directory
+ ``stl10_binary`` exists.
+ split (string): One of {'train', 'test', 'unlabeled', 'train+unlabeled'}.
+ Accordingly, the dataset is selected.
+ folds (int, optional): One of {0-9} or None.
+ For training, loads one of the 10 pre-defined folds of 1k samples for the
+ standard evaluation procedure. If no value is passed, loads the 5k samples.
+ transform (callable, optional): A function/transform that takes in a PIL image
+ and returns a transformed version. E.g, ``transforms.RandomCrop``
+ target_transform (callable, optional): A function/transform that takes in the
+ target and transforms it.
+ download (bool, optional): If true, downloads the dataset from the internet and
+ puts it in root directory. If dataset is already downloaded, it is not
+ downloaded again.
+ """
+
+ base_folder = "stl10_binary"
+ url = "http://ai.stanford.edu/~acoates/stl10/stl10_binary.tar.gz"
+ filename = "stl10_binary.tar.gz"
+ tgz_md5 = "91f7769df0f17e558f3565bffb0c7dfb"
+ class_names_file = "class_names.txt"
+ folds_list_file = "fold_indices.txt"
+ train_list = [
+ ["train_X.bin", "918c2871b30a85fa023e0c44e0bee87f"],
+ ["train_y.bin", "5a34089d4802c674881badbb80307741"],
+ ["unlabeled_X.bin", "5242ba1fed5e4be9e1e742405eb56ca4"],
+ ]
+
+ test_list = [["test_X.bin", "7f263ba9f9e0b06b93213547f721ac82"], ["test_y.bin", "36f9794fa4beb8a2c72628de14fa638e"]]
+ splits = ("train", "train+unlabeled", "unlabeled", "test")
+
+ def __init__(
+ self,
+ root: Union[str, Path],
+ split: str = "train",
+ folds: Optional[int] = None,
+ transform: Optional[Callable] = None,
+ target_transform: Optional[Callable] = None,
+ download: bool = False,
+ ) -> None:
+ super().__init__(root, transform=transform, target_transform=target_transform)
+ self.split = verify_str_arg(split, "split", self.splits)
+ self.folds = self._verify_folds(folds)
+
+ if download:
+ self.download()
+ elif not self._check_integrity():
+ raise RuntimeError("Dataset not found or corrupted. You can use download=True to download it")
+
+ # now load the numpy arrays from the binary files
+ self.labels: Optional[np.ndarray]
+ if self.split == "train":
+ self.data, self.labels = self.__loadfile(self.train_list[0][0], self.train_list[1][0])
+ self.labels = cast(np.ndarray, self.labels)
+ self.__load_folds(folds)
+
+ elif self.split == "train+unlabeled":
+ self.data, self.labels = self.__loadfile(self.train_list[0][0], self.train_list[1][0])
+ self.labels = cast(np.ndarray, self.labels)
+ self.__load_folds(folds)
+ unlabeled_data, _ = self.__loadfile(self.train_list[2][0])
+ self.data = np.concatenate((self.data, unlabeled_data))
+ self.labels = np.concatenate((self.labels, np.asarray([-1] * unlabeled_data.shape[0])))
+
+ elif self.split == "unlabeled":
+ self.data, _ = self.__loadfile(self.train_list[2][0])
+ self.labels = np.asarray([-1] * self.data.shape[0])
+ else: # self.split == 'test':
+ self.data, self.labels = self.__loadfile(self.test_list[0][0], self.test_list[1][0])
+
+ class_file = os.path.join(self.root, self.base_folder, self.class_names_file)
+ if os.path.isfile(class_file):
+ with open(class_file) as f:
+ self.classes = f.read().splitlines()
+
+ def _verify_folds(self, folds: Optional[int]) -> Optional[int]:
+ if folds is None:
+ return folds
+ elif isinstance(folds, int):
+ if folds in range(10):
+ return folds
+ msg = "Value for argument folds should be in the range [0, 10), but got {}."
+ raise ValueError(msg.format(folds))
+ else:
+ msg = "Expected type None or int for argument folds, but got type {}."
+ raise ValueError(msg.format(type(folds)))
+
+ def __getitem__(self, index: int) -> tuple[Any, Any]:
+ """
+ Args:
+ index (int): Index
+
+ Returns:
+ tuple: (image, target) where target is index of the target class.
+ """
+ target: Optional[int]
+ if self.labels is not None:
+ img, target = self.data[index], int(self.labels[index])
+ else:
+ img, target = self.data[index], None
+
+ # doing this so that it is consistent with all other datasets
+ # to return a PIL Image
+ img = Image.fromarray(np.transpose(img, (1, 2, 0)))
+
+ if self.transform is not None:
+ img = self.transform(img)
+
+ if self.target_transform is not None:
+ target = self.target_transform(target)
+
+ return img, target
+
+ def __len__(self) -> int:
+ return self.data.shape[0]
+
+ def __loadfile(self, data_file: str, labels_file: Optional[str] = None) -> tuple[np.ndarray, Optional[np.ndarray]]:
+ labels = None
+ if labels_file:
+ path_to_labels = os.path.join(self.root, self.base_folder, labels_file)
+ with open(path_to_labels, "rb") as f:
+ labels = np.fromfile(f, dtype=np.uint8) - 1 # 0-based
+
+ path_to_data = os.path.join(self.root, self.base_folder, data_file)
+ with open(path_to_data, "rb") as f:
+ # read whole file in uint8 chunks
+ everything = np.fromfile(f, dtype=np.uint8)
+ images = np.reshape(everything, (-1, 3, 96, 96))
+ images = np.transpose(images, (0, 1, 3, 2))
+
+ return images, labels
+
+ def _check_integrity(self) -> bool:
+ for filename, md5 in self.train_list + self.test_list:
+ fpath = os.path.join(self.root, self.base_folder, filename)
+ if not check_integrity(fpath, md5):
+ return False
+ return True
+
+ def download(self) -> None:
+ if self._check_integrity():
+ return
+ download_and_extract_archive(self.url, self.root, filename=self.filename, md5=self.tgz_md5)
+ self._check_integrity()
+
+ def extra_repr(self) -> str:
+ return "Split: {split}".format(**self.__dict__)
+
+ def __load_folds(self, folds: Optional[int]) -> None:
+ # loads one of the folds if specified
+ if folds is None:
+ return
+ path_to_folds = os.path.join(self.root, self.base_folder, self.folds_list_file)
+ with open(path_to_folds) as f:
+ str_idx = f.read().splitlines()[folds]
+ list_idx = np.fromstring(str_idx, dtype=np.int64, sep=" ")
+ self.data = self.data[list_idx, :, :, :]
+ if self.labels is not None:
+ self.labels = self.labels[list_idx]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torchvision/datasets/sun397.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torchvision/datasets/sun397.py
new file mode 100644
index 0000000000000000000000000000000000000000..a27f86d95795641c475bbf508c22734e2cf37412
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torchvision/datasets/sun397.py
@@ -0,0 +1,81 @@
+from pathlib import Path
+from typing import Any, Callable, Optional, Union
+
+from .folder import default_loader
+
+from .utils import download_and_extract_archive
+from .vision import VisionDataset
+
+
+class SUN397(VisionDataset):
+ """`The SUN397 Data Set `_.
+
+ The SUN397 or Scene UNderstanding (SUN) is a dataset for scene recognition consisting of
+ 397 categories with 108'754 images.
+
+ Args:
+ root (str or ``pathlib.Path``): Root directory of the dataset.
+ transform (callable, optional): A function/transform that takes in a PIL image or torch.Tensor, depending on the given loader,
+ and returns a transformed version. E.g, ``transforms.RandomCrop``
+ target_transform (callable, optional): A function/transform that takes in the target and transforms it.
+ download (bool, optional): If true, downloads the dataset from the internet and
+ puts it in root directory. If dataset is already downloaded, it is not
+ downloaded again.
+ loader (callable, optional): A function to load an image given its path.
+ By default, it uses PIL as its image loader, but users could also pass in
+ ``torchvision.io.decode_image`` for decoding image data into tensors directly.
+ """
+
+ _DATASET_URL = "http://vision.princeton.edu/projects/2010/SUN/SUN397.tar.gz"
+ _DATASET_MD5 = "8ca2778205c41d23104230ba66911c7a"
+
+ def __init__(
+ self,
+ root: Union[str, Path],
+ transform: Optional[Callable] = None,
+ target_transform: Optional[Callable] = None,
+ download: bool = False,
+ loader: Callable[[Union[str, Path]], Any] = default_loader,
+ ) -> None:
+ super().__init__(root, transform=transform, target_transform=target_transform)
+ self._data_dir = Path(self.root) / "SUN397"
+
+ if download:
+ self._download()
+
+ if not self._check_exists():
+ raise RuntimeError("Dataset not found. You can use download=True to download it")
+
+ with open(self._data_dir / "ClassName.txt") as f:
+ self.classes = [c[3:].strip() for c in f]
+
+ self.class_to_idx = dict(zip(self.classes, range(len(self.classes))))
+ self._image_files = list(self._data_dir.rglob("sun_*.jpg"))
+
+ self._labels = [
+ self.class_to_idx["/".join(path.relative_to(self._data_dir).parts[1:-1])] for path in self._image_files
+ ]
+ self.loader = loader
+
+ def __len__(self) -> int:
+ return len(self._image_files)
+
+ def __getitem__(self, idx: int) -> tuple[Any, Any]:
+ image_file, label = self._image_files[idx], self._labels[idx]
+ image = self.loader(image_file)
+
+ if self.transform:
+ image = self.transform(image)
+
+ if self.target_transform:
+ label = self.target_transform(label)
+
+ return image, label
+
+ def _check_exists(self) -> bool:
+ return self._data_dir.is_dir()
+
+ def _download(self) -> None:
+ if self._check_exists():
+ return
+ download_and_extract_archive(self._DATASET_URL, download_root=self.root, md5=self._DATASET_MD5)
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torchvision/datasets/svhn.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torchvision/datasets/svhn.py
new file mode 100644
index 0000000000000000000000000000000000000000..b59f78ec050d045bbf8099434b8cd579bba12c72
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torchvision/datasets/svhn.py
@@ -0,0 +1,130 @@
+import os.path
+from pathlib import Path
+from typing import Any, Callable, Optional, Union
+
+import numpy as np
+from PIL import Image
+
+from .utils import check_integrity, download_url, verify_str_arg
+from .vision import VisionDataset
+
+
+class SVHN(VisionDataset):
+ """`SVHN `_ Dataset.
+ Note: The SVHN dataset assigns the label `10` to the digit `0`. However, in this Dataset,
+ we assign the label `0` to the digit `0` to be compatible with PyTorch loss functions which
+ expect the class labels to be in the range `[0, C-1]`
+
+ .. warning::
+
+ This class needs `scipy <https://docs.scipy.org/doc/>`_ to load data from `.mat` format.
+
+ Args:
+ root (str or ``pathlib.Path``): Root directory of the dataset where the data is stored.
+ split (string): One of {'train', 'test', 'extra'}.
+ Accordingly dataset is selected. 'extra' is Extra training set.
+ transform (callable, optional): A function/transform that takes in a PIL image
+ and returns a transformed version. E.g, ``transforms.RandomCrop``
+ target_transform (callable, optional): A function/transform that takes in the
+ target and transforms it.
+ download (bool, optional): If true, downloads the dataset from the internet and
+ puts it in root directory. If dataset is already downloaded, it is not
+ downloaded again.
+
+ """
+
+ split_list = {
+ "train": [
+ "http://ufldl.stanford.edu/housenumbers/train_32x32.mat",
+ "train_32x32.mat",
+ "e26dedcc434d2e4c54c9b2d4a06d8373",
+ ],
+ "test": [
+ "http://ufldl.stanford.edu/housenumbers/test_32x32.mat",
+ "test_32x32.mat",
+ "eb5a983be6a315427106f1b164d9cef3",
+ ],
+ "extra": [
+ "http://ufldl.stanford.edu/housenumbers/extra_32x32.mat",
+ "extra_32x32.mat",
+ "a93ce644f1a588dc4d68dda5feec44a7",
+ ],
+ }
+
+ def __init__(
+ self,
+ root: Union[str, Path],
+ split: str = "train",
+ transform: Optional[Callable] = None,
+ target_transform: Optional[Callable] = None,
+ download: bool = False,
+ ) -> None:
+ super().__init__(root, transform=transform, target_transform=target_transform)
+ self.split = verify_str_arg(split, "split", tuple(self.split_list.keys()))
+ self.url = self.split_list[split][0]
+ self.filename = self.split_list[split][1]
+ self.file_md5 = self.split_list[split][2]
+
+ if download:
+ self.download()
+
+ if not self._check_integrity():
+ raise RuntimeError("Dataset not found or corrupted. You can use download=True to download it")
+
+ # import here rather than at top of file because this is
+ # an optional dependency for torchvision
+ import scipy.io as sio
+
+ # reading(loading) mat file as array
+ loaded_mat = sio.loadmat(os.path.join(self.root, self.filename))
+
+ self.data = loaded_mat["X"]
+ # loading from the .mat file gives an np.ndarray of type np.uint8
+ # converting to np.int64, so that we have a LongTensor after
+ # the conversion from the numpy array
+ # the squeeze is needed to obtain a 1D tensor
+ self.labels = loaded_mat["y"].astype(np.int64).squeeze()
+
+ # the svhn dataset assigns the class label "10" to the digit 0
+ # this makes it inconsistent with several loss functions
+ # which expect the class labels to be in the range [0, C-1]
+ np.place(self.labels, self.labels == 10, 0)
+ self.data = np.transpose(self.data, (3, 2, 0, 1))
+
+ def __getitem__(self, index: int) -> tuple[Any, Any]:
+ """
+ Args:
+ index (int): Index
+
+ Returns:
+ tuple: (image, target) where target is index of the target class.
+ """
+ img, target = self.data[index], int(self.labels[index])
+
+ # doing this so that it is consistent with all other datasets
+ # to return a PIL Image
+ img = Image.fromarray(np.transpose(img, (1, 2, 0)))
+
+ if self.transform is not None:
+ img = self.transform(img)
+
+ if self.target_transform is not None:
+ target = self.target_transform(target)
+
+ return img, target
+
+ def __len__(self) -> int:
+ return len(self.data)
+
+ def _check_integrity(self) -> bool:
+ root = self.root
+ md5 = self.split_list[self.split][2]
+ fpath = os.path.join(root, self.filename)
+ return check_integrity(fpath, md5)
+
+ def download(self) -> None:
+ md5 = self.split_list[self.split][2]
+ download_url(self.url, self.root, self.filename, md5)
+
+ def extra_repr(self) -> str:
+ return "Split: {split}".format(**self.__dict__)
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torchvision/datasets/ucf101.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torchvision/datasets/ucf101.py
new file mode 100644
index 0000000000000000000000000000000000000000..85930dbc742beb0dcfdac6e515f16966b92b9634
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torchvision/datasets/ucf101.py
@@ -0,0 +1,131 @@
+import os
+from pathlib import Path
+from typing import Any, Callable, Optional, Union
+
+from torch import Tensor
+
+from .folder import find_classes, make_dataset
+from .video_utils import VideoClips
+from .vision import VisionDataset
+
+
+class UCF101(VisionDataset):
+ """
+ `UCF101 <https://www.crcv.ucf.edu/data/UCF101.php>`_ dataset.
+
+ UCF101 is an action recognition video dataset.
+ This dataset consider every video as a collection of video clips of fixed size, specified
+ by ``frames_per_clip``, where the step in frames between each clip is given by
+ ``step_between_clips``. The dataset itself can be downloaded from the dataset website;
+ annotations that ``annotation_path`` should be pointing to can be downloaded from `here
+ <https://www.crcv.ucf.edu/data/UCF101.php>`_.
+
+ To give an example, for 2 videos with 10 and 15 frames respectively, if ``frames_per_clip=5``
+ and ``step_between_clips=5``, the dataset size will be (2 + 3) = 5, where the first two
+ elements will come from video 1, and the next three elements from video 2.
+ Note that we drop clips which do not have exactly ``frames_per_clip`` elements, so not all
+ frames in a video might be present.
+
+ Internally, it uses a VideoClips object to handle clip creation.
+
+ Args:
+ root (str or ``pathlib.Path``): Root directory of the UCF101 Dataset.
+ annotation_path (str): path to the folder containing the split files;
+ see docstring above for download instructions of these files
+ frames_per_clip (int): number of frames in a clip.
+ step_between_clips (int, optional): number of frames between each clip.
+ fold (int, optional): which fold to use. Should be between 1 and 3.
+ train (bool, optional): if ``True``, creates a dataset from the train split,
+ otherwise from the ``test`` split.
+ transform (callable, optional): A function/transform that takes in a TxHxWxC video
+ and returns a transformed version.
+ output_format (str, optional): The format of the output video tensors (before transforms).
+ Can be either "THWC" (default) or "TCHW".
+
+ Returns:
+ tuple: A 3-tuple with the following entries:
+
+ - video (Tensor[T, H, W, C] or Tensor[T, C, H, W]): The `T` video frames
+ - audio(Tensor[K, L]): the audio frames, where `K` is the number of channels
+ and `L` is the number of points
+ - label (int): class of the video clip
+ """
+
+ def __init__(
+ self,
+ root: Union[str, Path],
+ annotation_path: str,
+ frames_per_clip: int,
+ step_between_clips: int = 1,
+ frame_rate: Optional[int] = None,
+ fold: int = 1,
+ train: bool = True,
+ transform: Optional[Callable] = None,
+ _precomputed_metadata: Optional[dict[str, Any]] = None,
+ num_workers: int = 1,
+ _video_width: int = 0,
+ _video_height: int = 0,
+ _video_min_dimension: int = 0,
+ _audio_samples: int = 0,
+ output_format: str = "THWC",
+ ) -> None:
+ super().__init__(root)
+ if not 1 <= fold <= 3:
+ raise ValueError(f"fold should be between 1 and 3, got {fold}")
+
+ extensions = ("avi",)
+ self.fold = fold
+ self.train = train
+
+ self.classes, class_to_idx = find_classes(self.root)
+ self.samples = make_dataset(self.root, class_to_idx, extensions, is_valid_file=None)
+ video_list = [x[0] for x in self.samples]
+ video_clips = VideoClips(
+ video_list,
+ frames_per_clip,
+ step_between_clips,
+ frame_rate,
+ _precomputed_metadata,
+ num_workers=num_workers,
+ _video_width=_video_width,
+ _video_height=_video_height,
+ _video_min_dimension=_video_min_dimension,
+ _audio_samples=_audio_samples,
+ output_format=output_format,
+ )
+ # we bookkeep the full version of video clips because we want to be able
+ # to return the metadata of full version rather than the subset version of
+ # video clips
+ self.full_video_clips = video_clips
+ self.indices = self._select_fold(video_list, annotation_path, fold, train)
+ self.video_clips = video_clips.subset(self.indices)
+ self.transform = transform
+
+ @property
+ def metadata(self) -> dict[str, Any]:
+ return self.full_video_clips.metadata
+
+ def _select_fold(self, video_list: list[str], annotation_path: str, fold: int, train: bool) -> list[int]:
+ name = "train" if train else "test"
+ name = f"{name}list{fold:02d}.txt"
+ f = os.path.join(annotation_path, name)
+ selected_files = set()
+ with open(f) as fid:
+ data = fid.readlines()
+ data = [x.strip().split(" ")[0] for x in data]
+ data = [os.path.join(self.root, *x.split("/")) for x in data]
+ selected_files.update(data)
+ indices = [i for i in range(len(video_list)) if video_list[i] in selected_files]
+ return indices
+
+ def __len__(self) -> int:
+ return self.video_clips.num_clips()
+
+ def __getitem__(self, idx: int) -> tuple[Tensor, Tensor, int]:
+ video, audio, info, video_idx = self.video_clips.get_clip(idx)
+ label = self.samples[self.indices[video_idx]][1]
+
+ if self.transform is not None:
+ video = self.transform(video)
+
+ return video, audio, label
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torchvision/datasets/usps.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torchvision/datasets/usps.py
new file mode 100644
index 0000000000000000000000000000000000000000..e09ac96e45eefd8ae2458a196baa4f07630d3d43
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torchvision/datasets/usps.py
@@ -0,0 +1,96 @@
+import os
+from pathlib import Path
+from typing import Any, Callable, Optional, Union
+
+import numpy as np
+
+from ..utils import _Image_fromarray
+from .utils import download_url
+from .vision import VisionDataset
+
+
+class USPS(VisionDataset):
+ """`USPS `_ Dataset.
+ The data-format is : [label [index:value ]*256 \\n] * num_lines, where ``label`` lies in ``[1, 10]``.
+ The value for each pixel lies in ``[-1, 1]``. Here we transform the ``label`` into ``[0, 9]``
+ and make pixel values in ``[0, 255]``.
+
+ Args:
+ root (str or ``pathlib.Path``): Root directory of dataset to store ``USPS`` data files.
+ train (bool, optional): If True, creates dataset from ``usps.bz2``,
+ otherwise from ``usps.t.bz2``.
+ transform (callable, optional): A function/transform that takes in a PIL image
+ and returns a transformed version. E.g, ``transforms.RandomCrop``
+ target_transform (callable, optional): A function/transform that takes in the
+ target and transforms it.
+ download (bool, optional): If true, downloads the dataset from the internet and
+ puts it in root directory. If dataset is already downloaded, it is not
+ downloaded again.
+
+ """
+
+ split_list = {
+ "train": [
+ "https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/multiclass/usps.bz2",
+ "usps.bz2",
+ "ec16c51db3855ca6c91edd34d0e9b197",
+ ],
+ "test": [
+ "https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/multiclass/usps.t.bz2",
+ "usps.t.bz2",
+ "8ea070ee2aca1ac39742fdd1ef5ed118",
+ ],
+ }
+
+ def __init__(
+ self,
+ root: Union[str, Path],
+ train: bool = True,
+ transform: Optional[Callable] = None,
+ target_transform: Optional[Callable] = None,
+ download: bool = False,
+ ) -> None:
+ super().__init__(root, transform=transform, target_transform=target_transform)
+ split = "train" if train else "test"
+ url, filename, checksum = self.split_list[split]
+ full_path = os.path.join(self.root, filename)
+
+ if download and not os.path.exists(full_path):
+ download_url(url, self.root, filename, md5=checksum)
+
+ import bz2
+
+ with bz2.open(full_path) as fp:
+ raw_data = [line.decode().split() for line in fp.readlines()]
+ tmp_list = [[x.split(":")[-1] for x in data[1:]] for data in raw_data]
+ imgs = np.asarray(tmp_list, dtype=np.float32).reshape((-1, 16, 16))
+ imgs = ((imgs + 1) / 2 * 255).astype(dtype=np.uint8)
+ targets = [int(d[0]) - 1 for d in raw_data]
+
+ self.data = imgs
+ self.targets = targets
+
+ def __getitem__(self, index: int) -> tuple[Any, Any]:
+ """
+ Args:
+ index (int): Index
+
+ Returns:
+ tuple: (image, target) where target is index of the target class.
+ """
+ img, target = self.data[index], int(self.targets[index])
+
+ # doing this so that it is consistent with all other datasets
+ # to return a PIL Image
+ img = _Image_fromarray(img, mode="L")
+
+ if self.transform is not None:
+ img = self.transform(img)
+
+ if self.target_transform is not None:
+ target = self.target_transform(target)
+
+ return img, target
+
+ def __len__(self) -> int:
+ return len(self.data)
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torchvision/datasets/utils.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torchvision/datasets/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..0b6670800d2b012829bdf06b887f82ff3f554108
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torchvision/datasets/utils.py
@@ -0,0 +1,468 @@
+import bz2
+import gzip
+import hashlib
+import lzma
+import os
+import os.path
+import pathlib
+import re
+import tarfile
+import urllib
+import urllib.error
+import urllib.request
+import zipfile
+from collections.abc import Iterable
+from typing import Any, Callable, IO, Optional, TypeVar, Union
+from urllib.parse import urlparse
+
+import numpy as np
+import torch
+from torch.utils.model_zoo import tqdm
+
+from .._internally_replaced_utils import _download_file_from_remote_location, _is_remote_location_available
+
+USER_AGENT = "pytorch/vision"
+
+
+def _urlretrieve(url: str, filename: Union[str, pathlib.Path], chunk_size: int = 1024 * 32) -> None:
+ with urllib.request.urlopen(urllib.request.Request(url, headers={"User-Agent": USER_AGENT})) as response:
+ with open(filename, "wb") as fh, tqdm(total=response.length, unit="B", unit_scale=True) as pbar:
+ while chunk := response.read(chunk_size):
+ fh.write(chunk)
+ pbar.update(len(chunk))
+
+
+def calculate_md5(fpath: Union[str, pathlib.Path], chunk_size: int = 1024 * 1024) -> str:
+ # Setting the `usedforsecurity` flag does not change anything about the functionality, but indicates that we are
+ # not using the MD5 checksum for cryptography. This enables its usage in restricted environments like FIPS. Without
+ # it torchvision.datasets is unusable in these environments since we perform a MD5 check everywhere.
+ md5 = hashlib.md5(usedforsecurity=False)
+ with open(fpath, "rb") as f:
+ while chunk := f.read(chunk_size):
+ md5.update(chunk)
+ return md5.hexdigest()
+
+
+def check_md5(fpath: Union[str, pathlib.Path], md5: str, **kwargs: Any) -> bool:
+ return md5 == calculate_md5(fpath, **kwargs)
+
+
+def check_integrity(fpath: Union[str, pathlib.Path], md5: Optional[str] = None) -> bool:
+ if not os.path.isfile(fpath):
+ return False
+ if md5 is None:
+ return True
+ return check_md5(fpath, md5)
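+
+
+# Usage sketch for the checksum helpers above (path and digest are placeholders):
+# if not check_integrity("data/archive.tar.gz", md5="d41d8cd98f00b204e9800998ecf8427e"):
+# print(calculate_md5("data/archive.tar.gz")) # inspect the actual digest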
+
+
+def _get_redirect_url(url: str, max_hops: int = 3) -> str:
+ initial_url = url
+ headers = {"Method": "HEAD", "User-Agent": USER_AGENT}
+
+ for _ in range(max_hops + 1):
+ with urllib.request.urlopen(urllib.request.Request(url, headers=headers)) as response:
+ if response.url == url or response.url is None:
+ return url
+
+ url = response.url
+ else:
+ raise RecursionError(
+ f"Request to {initial_url} exceeded {max_hops} redirects. The last redirect points to {url}."
+ )
+
+
+def _get_google_drive_file_id(url: str) -> Optional[str]:
+ parts = urlparse(url)
+
+ if re.match(r"(drive|docs)[.]google[.]com", parts.netloc) is None:
+ return None
+
+ match = re.match(r"/file/d/(?P<id>[^/]*)", parts.path)
+ if match is None:
+ return None
+
+ return match.group("id")
+
+
+def download_url(
+ url: str,
+ root: Union[str, pathlib.Path],
+ filename: Optional[Union[str, pathlib.Path]] = None,
+ md5: Optional[str] = None,
+ max_redirect_hops: int = 3,
+) -> None:
+ """Download a file from a url and place it in root.
+
+ Args:
+ url (str): URL to download file from
+ root (str): Directory to place downloaded file in
+ filename (str, optional): Name to save the file under. If None, use the basename of the URL
+ md5 (str, optional): MD5 checksum of the download. If None, do not check
+ max_redirect_hops (int, optional): Maximum number of redirect hops allowed
+ """
+ root = os.path.expanduser(root)
+ if not filename:
+ filename = os.path.basename(url)
+ fpath = os.fspath(os.path.join(root, filename))
+
+ os.makedirs(root, exist_ok=True)
+
+ # check if file is already present locally
+ if check_integrity(fpath, md5):
+ return
+
+ if _is_remote_location_available():
+ _download_file_from_remote_location(fpath, url)
+ else:
+ # expand redirect chain if needed
+ url = _get_redirect_url(url, max_hops=max_redirect_hops)
+
+ # check if file is located on Google Drive
+ file_id = _get_google_drive_file_id(url)
+ if file_id is not None:
+ return download_file_from_google_drive(file_id, root, filename, md5)
+
+ # download the file
+ try:
+ _urlretrieve(url, fpath)
+ except (urllib.error.URLError, OSError) as e: # type: ignore[attr-defined]
+ if url[:5] == "https":
+ url = url.replace("https:", "http:")
+ _urlretrieve(url, fpath)
+ else:
+ raise e
+
+ # check integrity of downloaded file
+ if not check_integrity(fpath, md5):
+ raise RuntimeError("File not found or corrupted.")
+
+
+def list_dir(root: Union[str, pathlib.Path], prefix: bool = False) -> list[str]:
+ """List all directories at a given root
+
+ Args:
+ root (str): Path to directory whose folders need to be listed
+ prefix (bool, optional): If true, prepends the path to each result, otherwise
+ only returns the name of the directories found
+ """
+ root = os.path.expanduser(root)
+ directories = [p for p in os.listdir(root) if os.path.isdir(os.path.join(root, p))]
+ if prefix is True:
+ directories = [os.path.join(root, d) for d in directories]
+ return directories
+
+
+def list_files(root: Union[str, pathlib.Path], suffix: str, prefix: bool = False) -> list[str]:
+ """List all files ending with a suffix at a given root
+
+ Args:
+ root (str): Path to directory whose folders need to be listed
+ suffix (str or tuple): Suffix of the files to match, e.g. '.png' or ('.jpg', '.png').
+ It uses the Python "str.endswith" method and is passed directly
+ prefix (bool, optional): If true, prepends the path to each result, otherwise
+ only returns the name of the files found
+ """
+ root = os.path.expanduser(root)
+ files = [p for p in os.listdir(root) if os.path.isfile(os.path.join(root, p)) and p.endswith(suffix)]
+ if prefix is True:
+ files = [os.path.join(root, d) for d in files]
+ return files
+
+
+def download_file_from_google_drive(
+ file_id: str,
+ root: Union[str, pathlib.Path],
+ filename: Optional[Union[str, pathlib.Path]] = None,
+ md5: Optional[str] = None,
+):
+ """Download a Google Drive file from and place it in root.
+
+ Args:
+ file_id (str): id of file to be downloaded
+ root (str): Directory to place downloaded file in
+ filename (str, optional): Name to save the file under. If None, use the id of the file.
+ md5 (str, optional): MD5 checksum of the download. If None, do not check
+ """
+ try:
+ import gdown
+ except ModuleNotFoundError:
+ raise RuntimeError(
+ "To download files from GDrive, 'gdown' is required. You can install it with 'pip install gdown'."
+ )
+
+ root = os.path.expanduser(root)
+ if not filename:
+ filename = file_id
+ fpath = os.fspath(os.path.join(root, filename))
+
+ os.makedirs(root, exist_ok=True)
+
+ if check_integrity(fpath, md5):
+ return
+
+ gdown.download(id=file_id, output=fpath, quiet=False, user_agent=USER_AGENT)
+
+ if not check_integrity(fpath, md5):
+ raise RuntimeError("File not found or corrupted.")
+
+
+def _extract_tar(
+ from_path: Union[str, pathlib.Path], to_path: Union[str, pathlib.Path], compression: Optional[str]
+) -> None:
+ with tarfile.open(from_path, f"r:{compression[1:]}" if compression else "r") as tar:
+ tar.extractall(to_path)
+
+
+_ZIP_COMPRESSION_MAP: dict[str, int] = {
+ ".bz2": zipfile.ZIP_BZIP2,
+ ".xz": zipfile.ZIP_LZMA,
+}
+
+
+def _extract_zip(
+ from_path: Union[str, pathlib.Path], to_path: Union[str, pathlib.Path], compression: Optional[str]
+) -> None:
+ with zipfile.ZipFile(
+ from_path, "r", compression=_ZIP_COMPRESSION_MAP[compression] if compression else zipfile.ZIP_STORED
+ ) as zip:
+ zip.extractall(to_path)
+
+
+_ARCHIVE_EXTRACTORS: dict[str, Callable[[Union[str, pathlib.Path], Union[str, pathlib.Path], Optional[str]], None]] = {
+ ".tar": _extract_tar,
+ ".zip": _extract_zip,
+}
+_COMPRESSED_FILE_OPENERS: dict[str, Callable[..., IO]] = {
+ ".bz2": bz2.open,
+ ".gz": gzip.open,
+ ".xz": lzma.open,
+}
+_FILE_TYPE_ALIASES: dict[str, tuple[Optional[str], Optional[str]]] = {
+ ".tbz": (".tar", ".bz2"),
+ ".tbz2": (".tar", ".bz2"),
+ ".tgz": (".tar", ".gz"),
+}
+
+
+def _detect_file_type(file: Union[str, pathlib.Path]) -> tuple[str, Optional[str], Optional[str]]:
+ """Detect the archive type and/or compression of a file.
+
+ Args:
+ file (str): the filename
+
+ Returns:
+ (tuple): tuple of suffix, archive type, and compression
+
+ Raises:
+ RuntimeError: if file has no suffix or suffix is not supported
+ """
+ suffixes = pathlib.Path(file).suffixes
+ if not suffixes:
+ raise RuntimeError(
+ f"File '{file}' has no suffixes that could be used to detect the archive type and compression."
+ )
+ suffix = suffixes[-1]
+
+ # check if the suffix is a known alias
+ if suffix in _FILE_TYPE_ALIASES:
+ return (suffix, *_FILE_TYPE_ALIASES[suffix])
+
+ # check if the suffix is an archive type
+ if suffix in _ARCHIVE_EXTRACTORS:
+ return suffix, suffix, None
+
+ # check if the suffix is a compression
+ if suffix in _COMPRESSED_FILE_OPENERS:
+ # check for suffix hierarchy
+ if len(suffixes) > 1:
+ suffix2 = suffixes[-2]
+
+ # check if the suffix2 is an archive type
+ if suffix2 in _ARCHIVE_EXTRACTORS:
+ return suffix2 + suffix, suffix2, suffix
+
+ return suffix, None, suffix
+
+ valid_suffixes = sorted(set(_FILE_TYPE_ALIASES) | set(_ARCHIVE_EXTRACTORS) | set(_COMPRESSED_FILE_OPENERS))
+ raise RuntimeError(f"Unknown compression or archive type: '{suffix}'.\nKnown suffixes are: '{valid_suffixes}'.")
+
+
+def _decompress(
+ from_path: Union[str, pathlib.Path],
+ to_path: Optional[Union[str, pathlib.Path]] = None,
+ remove_finished: bool = False,
+) -> pathlib.Path:
+ r"""Decompress a file.
+
+ The compression is automatically detected from the file name.
+
+ Args:
+ from_path (str): Path to the file to be decompressed.
+ to_path (str): Path to the decompressed file. If omitted, ``from_path`` without compression extension is used.
+ remove_finished (bool): If ``True``, remove the file after the extraction.
+
+ Returns:
+ (str): Path to the decompressed file.
+ """
+ suffix, archive_type, compression = _detect_file_type(from_path)
+ if not compression:
+ raise RuntimeError(f"Couldn't detect a compression from suffix {suffix}.")
+
+ if to_path is None:
+ to_path = pathlib.Path(os.fspath(from_path).replace(suffix, archive_type if archive_type is not None else ""))
+
+ # We don't need to check for a missing key here, since this was already done in _detect_file_type()
+ compressed_file_opener = _COMPRESSED_FILE_OPENERS[compression]
+
+ with compressed_file_opener(from_path, "rb") as rfh, open(to_path, "wb") as wfh:
+ wfh.write(rfh.read())
+
+ if remove_finished:
+ os.remove(from_path)
+
+ return pathlib.Path(to_path)
+
+
+def extract_archive(
+ from_path: Union[str, pathlib.Path],
+ to_path: Optional[Union[str, pathlib.Path]] = None,
+ remove_finished: bool = False,
+) -> Union[str, pathlib.Path]:
+ """Extract an archive.
+
+ The archive type and a possible compression is automatically detected from the file name. If the file is compressed
+ but not an archive the call is dispatched to :func:`decompress`.
+
+ Args:
+ from_path (str): Path to the file to be extracted.
+ to_path (str): Path to the directory the file will be extracted to. If omitted, the directory of the file is
+ used.
+ remove_finished (bool): If ``True``, remove the file after the extraction.
+
+ Returns:
+ (str): Path to the directory the file was extracted to.
+ """
+
+ def path_or_str(ret_path: pathlib.Path) -> Union[str, pathlib.Path]:
+ if isinstance(from_path, str):
+ return os.fspath(ret_path)
+ else:
+ return ret_path
+
+ if to_path is None:
+ to_path = os.path.dirname(from_path)
+
+ suffix, archive_type, compression = _detect_file_type(from_path)
+ if not archive_type:
+ ret_path = _decompress(
+ from_path,
+ os.path.join(to_path, os.path.basename(from_path).replace(suffix, "")),
+ remove_finished=remove_finished,
+ )
+ return path_or_str(ret_path)
+
+ # We don't need to check for a missing key here, since this was already done in _detect_file_type()
+ extractor = _ARCHIVE_EXTRACTORS[archive_type]
+
+ extractor(from_path, to_path, compression)
+ if remove_finished:
+ os.remove(from_path)
+
+ return path_or_str(pathlib.Path(to_path))
+
+
+def download_and_extract_archive(
+ url: str,
+ download_root: Union[str, pathlib.Path],
+ extract_root: Optional[Union[str, pathlib.Path]] = None,
+ filename: Optional[Union[str, pathlib.Path]] = None,
+ md5: Optional[str] = None,
+ remove_finished: bool = False,
+) -> None:
+ download_root = os.path.expanduser(download_root)
+ if extract_root is None:
+ extract_root = download_root
+ if not filename:
+ filename = os.path.basename(url)
+
+ download_url(url, download_root, filename, md5)
+
+ archive = os.path.join(download_root, filename)
+ extract_archive(archive, extract_root, remove_finished)
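+
+
+# Usage sketch (the URL is a placeholder, not a real mirror):
+# download_and_extract_archive(
+# "https://example.com/dataset.tar.gz",
+# download_root="./data",
+# md5=None, # skip the integrity check when no checksum is known
+# remove_finished=True,
+# )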
+
+
+def iterable_to_str(iterable: Iterable) -> str:
+ return "'" + "', '".join([str(item) for item in iterable]) + "'"
+
+
+T = TypeVar("T", str, bytes)
+
+
+def verify_str_arg(
+ value: T,
+ arg: Optional[str] = None,
+ valid_values: Optional[Iterable[T]] = None,
+ custom_msg: Optional[str] = None,
+) -> T:
+ if not isinstance(value, str):
+ if arg is None:
+ msg = "Expected type str, but got type {type}."
+ else:
+ msg = "Expected type str for argument {arg}, but got type {type}."
+ msg = msg.format(type=type(value), arg=arg)
+ raise ValueError(msg)
+
+ if valid_values is None:
+ return value
+
+ if value not in valid_values:
+ if custom_msg is not None:
+ msg = custom_msg
+ else:
+ msg = "Unknown value '{value}' for argument {arg}. Valid values are {{{valid_values}}}."
+ msg = msg.format(value=value, arg=arg, valid_values=iterable_to_str(valid_values))
+ raise ValueError(msg)
+
+ return value
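+
+
+# Usage sketch: validates a split name, returning it unchanged or raising a
+# descriptive ValueError that lists the valid values.
+# split = verify_str_arg("train", "split", ("train", "test")) # -> "train"
+# verify_str_arg("dev", "split", ("train", "test")) # raises ValueError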
+
+
+def _read_pfm(file_name: Union[str, pathlib.Path], slice_channels: int = 2) -> np.ndarray:
+ """Read file in .pfm format. Might contain either 1 or 3 channels of data.
+
+ Args:
+ file_name (str): Path to the file.
+ slice_channels (int): Number of channels to slice out of the file.
+ Useful for reading different data formats stored in .pfm files: Optical Flows, Stereo Disparity Maps, etc.
+ """
+
+ with open(file_name, "rb") as f:
+ header = f.readline().rstrip()
+ if header not in [b"PF", b"Pf"]:
+ raise ValueError("Invalid PFM file")
+
+ dim_match = re.match(rb"^(\d+)\s(\d+)\s$", f.readline())
+ if not dim_match:
+ raise Exception("Malformed PFM header.")
+ w, h = (int(dim) for dim in dim_match.groups())
+
+ scale = float(f.readline().rstrip())
+ if scale < 0: # little-endian
+ endian = "<"
+ scale = -scale
+ else:
+ endian = ">" # big-endian
+
+ data = np.fromfile(f, dtype=endian + "f")
+
+ pfm_channels = 3 if header == b"PF" else 1
+
+ data = data.reshape(h, w, pfm_channels).transpose(2, 0, 1)
+ data = np.flip(data, axis=1) # flip on h dimension
+ data = data[:slice_channels, :, :]
+ return data.astype(np.float32)
+
+
+def _flip_byte_order(t: torch.Tensor) -> torch.Tensor:
+ return (
+ t.contiguous().view(torch.uint8).view(*t.shape, t.element_size()).flip(-1).view(*t.shape[:-1], -1).view(t.dtype)
+ )
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torchvision/datasets/video_utils.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torchvision/datasets/video_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..d9214beaa680057ae10a414244b6c88310be8513
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torchvision/datasets/video_utils.py
@@ -0,0 +1,419 @@
+import bisect
+import math
+import warnings
+from fractions import Fraction
+from typing import Any, Callable, cast, Optional, TypeVar, Union
+
+import torch
+from torchvision.io import _probe_video_from_file, _read_video_from_file, read_video, read_video_timestamps
+
+from .utils import tqdm
+
+T = TypeVar("T")
+
+
+def pts_convert(pts: int, timebase_from: Fraction, timebase_to: Fraction, round_func: Callable = math.floor) -> int:
+ """convert pts between different time bases
+ Args:
+ pts: presentation timestamp, float
+ timebase_from: original timebase. Fraction
+ timebase_to: new timebase. Fraction
+ round_func: rounding function.
+ """
+ new_pts = Fraction(pts, 1) * timebase_from / timebase_to
+ return round_func(new_pts)
+
+
+def unfold(tensor: torch.Tensor, size: int, step: int, dilation: int = 1) -> torch.Tensor:
+ """
+ similar to tensor.unfold, but with the dilation
+ and specialized for 1d tensors
+
+ Returns all consecutive windows of `size` elements, with
+ `step` between windows. The distance between each element
+ in a window is given by `dilation`.
+ """
+ if tensor.dim() != 1:
+ raise ValueError(f"tensor should have 1 dimension instead of {tensor.dim()}")
+ o_stride = tensor.stride(0)
+ numel = tensor.numel()
+ new_stride = (step * o_stride, dilation * o_stride)
+ new_size = ((numel - (dilation * (size - 1) + 1)) // step + 1, size)
+ if new_size[0] < 1:
+ new_size = (0, size)
+ return torch.as_strided(tensor, new_size, new_stride)
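+
+
+# Behaviour sketch: unfold(torch.arange(10), size=5, step=5) returns the two windows
+# tensor([[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]]); with dilation=2 the elements inside a
+# window are taken two apart.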
+
+
+class _VideoTimestampsDataset:
+ """
+ Dataset used to parallelize the reading of the timestamps
+ of a list of videos, given their paths in the filesystem.
+
+ Used in VideoClips and defined at top level, so it can be
+ pickled when forking.
+ """
+
+ def __init__(self, video_paths: list[str]) -> None:
+ self.video_paths = video_paths
+
+ def __len__(self) -> int:
+ return len(self.video_paths)
+
+ def __getitem__(self, idx: int) -> tuple[list[int], Optional[float]]:
+ return read_video_timestamps(self.video_paths[idx])
+
+
+def _collate_fn(x: T) -> T:
+ """
+ Dummy collate function to be used with _VideoTimestampsDataset
+ """
+ return x
+
+
+class VideoClips:
+ """
+ Given a list of video files, computes all consecutive subvideos of size
+ `clip_length_in_frames`, where the distance between each subvideo in the
+ same video is defined by `frames_between_clips`.
+ If `frame_rate` is specified, it will also resample all the videos to have
+ the same frame rate, and the clips will refer to this frame rate.
+
+ Creating this instance the first time is time-consuming, as it needs to
+ decode all the videos in `video_paths`. It is recommended that you
+ cache the results after instantiation of the class.
+
+ Recreating the clips for different clip lengths is fast, and can be done
+ with the `compute_clips` method.
+
+ Args:
+ video_paths (List[str]): paths to the video files
+ clip_length_in_frames (int): size of a clip in number of frames
+ frames_between_clips (int): step (in frames) between each clip
+ frame_rate (float, optional): if specified, it will resample the video
+ so that it has `frame_rate`, and then the clips will be defined
+ on the resampled video
+ num_workers (int): how many subprocesses to use for data loading.
+ 0 means that the data will be loaded in the main process. (default: 0)
+ output_format (str): The format of the output video tensors. Can be either "THWC" (default) or "TCHW".
+ """
+
+ def __init__(
+ self,
+ video_paths: list[str],
+ clip_length_in_frames: int = 16,
+ frames_between_clips: int = 1,
+ frame_rate: Optional[float] = None,
+ _precomputed_metadata: Optional[dict[str, Any]] = None,
+ num_workers: int = 0,
+ _video_width: int = 0,
+ _video_height: int = 0,
+ _video_min_dimension: int = 0,
+ _video_max_dimension: int = 0,
+ _audio_samples: int = 0,
+ _audio_channels: int = 0,
+ output_format: str = "THWC",
+ ) -> None:
+
+ self.video_paths = video_paths
+ self.num_workers = num_workers
+
+ # these options are not valid for pyav backend
+ self._video_width = _video_width
+ self._video_height = _video_height
+ self._video_min_dimension = _video_min_dimension
+ self._video_max_dimension = _video_max_dimension
+ self._audio_samples = _audio_samples
+ self._audio_channels = _audio_channels
+ self.output_format = output_format.upper()
+ if self.output_format not in ("THWC", "TCHW"):
+ raise ValueError(f"output_format should be either 'THWC' or 'TCHW', got {output_format}.")
+
+ if _precomputed_metadata is None:
+ self._compute_frame_pts()
+ else:
+ self._init_from_metadata(_precomputed_metadata)
+ self.compute_clips(clip_length_in_frames, frames_between_clips, frame_rate)
+
+ def _compute_frame_pts(self) -> None:
+ self.video_pts = [] # len = num_videos. Each entry is a tensor of shape (num_frames_in_video,)
+ self.video_fps: list[float] = [] # len = num_videos
+
+ # strategy: use a DataLoader to parallelize read_video_timestamps
+ # so need to create a dummy dataset first
+ import torch.utils.data
+
+ dl: torch.utils.data.DataLoader = torch.utils.data.DataLoader(
+ _VideoTimestampsDataset(self.video_paths), # type: ignore[arg-type]
+ batch_size=16,
+ num_workers=self.num_workers,
+ collate_fn=_collate_fn,
+ )
+
+ with tqdm(total=len(dl)) as pbar:
+ for batch in dl:
+ pbar.update(1)
+ batch_pts, batch_fps = list(zip(*batch))
+ # we need to specify dtype=torch.long because for empty list,
+ # torch.as_tensor will use torch.float as default dtype. This
+ # happens when decoding fails and no pts is returned in the list.
+ batch_pts = [torch.as_tensor(pts, dtype=torch.long) for pts in batch_pts]
+ self.video_pts.extend(batch_pts)
+ self.video_fps.extend(batch_fps)
+
+ def _init_from_metadata(self, metadata: dict[str, Any]) -> None:
+ self.video_paths = metadata["video_paths"]
+ assert len(self.video_paths) == len(metadata["video_pts"])
+ self.video_pts = metadata["video_pts"]
+ assert len(self.video_paths) == len(metadata["video_fps"])
+ self.video_fps = metadata["video_fps"]
+
+ @property
+ def metadata(self) -> dict[str, Any]:
+ _metadata = {
+ "video_paths": self.video_paths,
+ "video_pts": self.video_pts,
+ "video_fps": self.video_fps,
+ }
+ return _metadata
+
+ def subset(self, indices: list[int]) -> "VideoClips":
+ video_paths = [self.video_paths[i] for i in indices]
+ video_pts = [self.video_pts[i] for i in indices]
+ video_fps = [self.video_fps[i] for i in indices]
+ metadata = {
+ "video_paths": video_paths,
+ "video_pts": video_pts,
+ "video_fps": video_fps,
+ }
+ return type(self)(
+ video_paths,
+ clip_length_in_frames=self.num_frames,
+ frames_between_clips=self.step,
+ frame_rate=self.frame_rate,
+ _precomputed_metadata=metadata,
+ num_workers=self.num_workers,
+ _video_width=self._video_width,
+ _video_height=self._video_height,
+ _video_min_dimension=self._video_min_dimension,
+ _video_max_dimension=self._video_max_dimension,
+ _audio_samples=self._audio_samples,
+ _audio_channels=self._audio_channels,
+ output_format=self.output_format,
+ )
+
+ @staticmethod
+ def compute_clips_for_video(
+ video_pts: torch.Tensor, num_frames: int, step: int, fps: Optional[float], frame_rate: Optional[float] = None
+ ) -> tuple[torch.Tensor, Union[list[slice], torch.Tensor]]:
+ if fps is None:
+ # if for some reason the video doesn't have fps (because doesn't have a video stream)
+ # set the fps to 1. The value doesn't matter, because video_pts is empty anyway
+ fps = 1
+ if frame_rate is None:
+ frame_rate = fps
+ total_frames = len(video_pts) * frame_rate / fps
+ _idxs = VideoClips._resample_video_idx(int(math.floor(total_frames)), fps, frame_rate)
+ video_pts = video_pts[_idxs]
+ clips = unfold(video_pts, num_frames, step)
+ if not clips.numel():
+ warnings.warn(
+ "There aren't enough frames in the current video to get a clip for the given clip length and "
+ "frames between clips. The video (and potentially others) will be skipped."
+ )
+ idxs: Union[list[slice], torch.Tensor]
+ if isinstance(_idxs, slice):
+ idxs = [_idxs] * len(clips)
+ else:
+ idxs = unfold(_idxs, num_frames, step)
+ return clips, idxs
+
+ def compute_clips(self, num_frames: int, step: int, frame_rate: Optional[float] = None) -> None:
+ """
+ Compute all consecutive sequences of clips from video_pts.
+ Always returns clips of size `num_frames`, meaning that the
+ last few frames in a video can potentially be dropped.
+
+ Args:
+ num_frames (int): number of frames for the clip
+ step (int): distance between two clips
+ frame_rate (int, optional): The frame rate
+ """
+ self.num_frames = num_frames
+ self.step = step
+ self.frame_rate = frame_rate
+ self.clips = []
+ self.resampling_idxs = []
+ for video_pts, fps in zip(self.video_pts, self.video_fps):
+ clips, idxs = self.compute_clips_for_video(video_pts, num_frames, step, fps, frame_rate)
+ self.clips.append(clips)
+ self.resampling_idxs.append(idxs)
+ clip_lengths = torch.as_tensor([len(v) for v in self.clips])
+ self.cumulative_sizes = clip_lengths.cumsum(0).tolist()
+
+ def __len__(self) -> int:
+ return self.num_clips()
+
+ def num_videos(self) -> int:
+ return len(self.video_paths)
+
+ def num_clips(self) -> int:
+ """
+ Number of subclips that are available in the video list.
+ """
+ return self.cumulative_sizes[-1]
+
+ def get_clip_location(self, idx: int) -> tuple[int, int]:
+ """
+ Converts a flattened representation of the indices into a video_idx, clip_idx
+ representation.
+ """
+ video_idx = bisect.bisect_right(self.cumulative_sizes, idx)
+ if video_idx == 0:
+ clip_idx = idx
+ else:
+ clip_idx = idx - self.cumulative_sizes[video_idx - 1]
+ return video_idx, clip_idx
+
+ @staticmethod
+ def _resample_video_idx(num_frames: int, original_fps: float, new_fps: float) -> Union[slice, torch.Tensor]:
+ step = original_fps / new_fps
+ if step.is_integer():
+ # optimization: if step is integer, don't need to perform
+ # advanced indexing
+ step = int(step)
+ return slice(None, None, step)
+ idxs = torch.arange(num_frames, dtype=torch.float32) * step
+ idxs = idxs.floor().to(torch.int64)
+ return idxs
+
+ def get_clip(self, idx: int) -> tuple[torch.Tensor, torch.Tensor, dict[str, Any], int]:
+ """
+ Gets a subclip from a list of videos.
+
+ Args:
+ idx (int): index of the subclip. Must be between 0 and num_clips().
+
+ Returns:
+ video (Tensor)
+ audio (Tensor)
+ info (Dict)
+ video_idx (int): index of the video in `video_paths`
+ """
+ if idx >= self.num_clips():
+ raise IndexError(f"Index {idx} out of range ({self.num_clips()} number of clips)")
+ video_idx, clip_idx = self.get_clip_location(idx)
+ video_path = self.video_paths[video_idx]
+ clip_pts = self.clips[video_idx][clip_idx]
+
+ from torchvision import get_video_backend
+
+ backend = get_video_backend()
+
+ if backend == "pyav":
+ # check for invalid options
+ if self._video_width != 0:
+ raise ValueError("pyav backend doesn't support _video_width != 0")
+ if self._video_height != 0:
+ raise ValueError("pyav backend doesn't support _video_height != 0")
+ if self._video_min_dimension != 0:
+ raise ValueError("pyav backend doesn't support _video_min_dimension != 0")
+ if self._video_max_dimension != 0:
+ raise ValueError("pyav backend doesn't support _video_max_dimension != 0")
+ if self._audio_samples != 0:
+ raise ValueError("pyav backend doesn't support _audio_samples != 0")
+
+ if backend == "pyav":
+ start_pts = clip_pts[0].item()
+ end_pts = clip_pts[-1].item()
+ video, audio, info = read_video(video_path, start_pts, end_pts)
+ else:
+ _info = _probe_video_from_file(video_path)
+ video_fps = _info.video_fps
+ audio_fps = None
+
+ video_start_pts = cast(int, clip_pts[0].item())
+ video_end_pts = cast(int, clip_pts[-1].item())
+
+ audio_start_pts, audio_end_pts = 0, -1
+ audio_timebase = Fraction(0, 1)
+ video_timebase = Fraction(_info.video_timebase.numerator, _info.video_timebase.denominator)
+ if _info.has_audio:
+ audio_timebase = Fraction(_info.audio_timebase.numerator, _info.audio_timebase.denominator)
+ audio_start_pts = pts_convert(video_start_pts, video_timebase, audio_timebase, math.floor)
+ audio_end_pts = pts_convert(video_end_pts, video_timebase, audio_timebase, math.ceil)
+ audio_fps = _info.audio_sample_rate
+ video, audio, _ = _read_video_from_file(
+ video_path,
+ video_width=self._video_width,
+ video_height=self._video_height,
+ video_min_dimension=self._video_min_dimension,
+ video_max_dimension=self._video_max_dimension,
+ video_pts_range=(video_start_pts, video_end_pts),
+ video_timebase=video_timebase,
+ audio_samples=self._audio_samples,
+ audio_channels=self._audio_channels,
+ audio_pts_range=(audio_start_pts, audio_end_pts),
+ audio_timebase=audio_timebase,
+ )
+
+ info = {"video_fps": video_fps}
+ if audio_fps is not None:
+ info["audio_fps"] = audio_fps
+
+ if self.frame_rate is not None:
+ resampling_idx = self.resampling_idxs[video_idx][clip_idx]
+ if isinstance(resampling_idx, torch.Tensor):
+ resampling_idx = resampling_idx - resampling_idx[0]
+ video = video[resampling_idx]
+ info["video_fps"] = self.frame_rate
+ assert len(video) == self.num_frames, f"{video.shape} x {self.num_frames}"
+
+ if self.output_format == "TCHW":
+ # [T,H,W,C] --> [T,C,H,W]
+ video = video.permute(0, 3, 1, 2)
+
+ return video, audio, info, video_idx
+
+ def __getstate__(self) -> dict[str, Any]:
+ video_pts_sizes = [len(v) for v in self.video_pts]
+ # To be back-compatible, we convert data to dtype torch.long as needed
+ # because for empty list, in legacy implementation, torch.as_tensor will
+ # use torch.float as default dtype. This happens when decoding fails and
+ # no pts is returned in the list.
+ video_pts = [x.to(torch.int64) for x in self.video_pts]
+ # video_pts can be an empty list if no frames have been decoded
+ if video_pts:
+ video_pts = torch.cat(video_pts) # type: ignore[assignment]
+ # avoid bug in https://github.com/pytorch/pytorch/issues/32351
+ # TODO: Revert it once the bug is fixed.
+ video_pts = video_pts.numpy() # type: ignore[attr-defined]
+
+ # make a copy of the fields of self
+ d = self.__dict__.copy()
+ d["video_pts_sizes"] = video_pts_sizes
+ d["video_pts"] = video_pts
+ # delete the following attributes to reduce the size of dictionary. They
+ # will be re-computed in "__setstate__()"
+ del d["clips"]
+ del d["resampling_idxs"]
+ del d["cumulative_sizes"]
+
+ # for backwards-compatibility
+ d["_version"] = 2
+ return d
+
+ def __setstate__(self, d: dict[str, Any]) -> None:
+ # for backwards-compatibility
+ if "_version" not in d:
+ self.__dict__ = d
+ return
+
+ video_pts = torch.as_tensor(d["video_pts"], dtype=torch.int64)
+ video_pts = torch.split(video_pts, d["video_pts_sizes"], dim=0)
+ # don't need this info anymore
+ del d["video_pts_sizes"]
+
+ d["video_pts"] = video_pts
+ self.__dict__ = d
+ # recompute attributes "clips", "resampling_idxs" and other derivative ones
+ self.compute_clips(self.num_frames, self.step, self.frame_rate)
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torchvision/datasets/vision.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torchvision/datasets/vision.py
new file mode 100644
index 0000000000000000000000000000000000000000..c43f7814c6c4462489b18348dd95078eb0e05c0a
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torchvision/datasets/vision.py
@@ -0,0 +1,111 @@
+import os
+from pathlib import Path
+from typing import Any, Callable, Optional, Union
+
+import torch.utils.data as data
+
+from ..utils import _log_api_usage_once
+
+
+class VisionDataset(data.Dataset):
+ """
+ Base Class For making datasets which are compatible with torchvision.
+ It is necessary to override the ``__getitem__`` and ``__len__`` methods.
+
+ Args:
+ root (string, optional): Root directory of dataset. Only used for `__repr__`.
+ transforms (callable, optional): A function/transforms that takes in
+ an image and a label and returns the transformed versions of both.
+ transform (callable, optional): A function/transform that takes in a PIL image
+ and returns a transformed version. E.g, ``transforms.RandomCrop``
+ target_transform (callable, optional): A function/transform that takes in the
+ target and transforms it.
+
+ .. note::
+
+ :attr:`transforms` and the combination of :attr:`transform` and :attr:`target_transform` are mutually exclusive.
+ """
+
+ _repr_indent = 4
+
+ def __init__(
+ self,
+ root: Union[str, Path] = None, # type: ignore[assignment]
+ transforms: Optional[Callable] = None,
+ transform: Optional[Callable] = None,
+ target_transform: Optional[Callable] = None,
+ ) -> None:
+ _log_api_usage_once(self)
+ if isinstance(root, str):
+ root = os.path.expanduser(root)
+ self.root = root
+
+ has_transforms = transforms is not None
+ has_separate_transform = transform is not None or target_transform is not None
+ if has_transforms and has_separate_transform:
+ raise ValueError("Only transforms or transform/target_transform can be passed as argument")
+
+ # for backwards-compatibility
+ self.transform = transform
+ self.target_transform = target_transform
+
+ if has_separate_transform:
+ transforms = StandardTransform(transform, target_transform)
+ self.transforms = transforms
+
+ def __getitem__(self, index: int) -> Any:
+ """
+ Args:
+ index (int): Index
+
+ Returns:
+ (Any): Sample and meta data, optionally transformed by the respective transforms.
+ """
+ raise NotImplementedError
+
+ def __len__(self) -> int:
+ raise NotImplementedError
+
+ def __repr__(self) -> str:
+ head = "Dataset " + self.__class__.__name__
+ body = [f"Number of datapoints: {self.__len__()}"]
+ if self.root is not None:
+ body.append(f"Root location: {self.root}")
+ body += self.extra_repr().splitlines()
+ if hasattr(self, "transforms") and self.transforms is not None:
+ body += [repr(self.transforms)]
+ lines = [head] + [" " * self._repr_indent + line for line in body]
+ return "\n".join(lines)
+
+ def _format_transform_repr(self, transform: Callable, head: str) -> list[str]:
+ lines = transform.__repr__().splitlines()
+ return [f"{head}{lines[0]}"] + ["{}{}".format(" " * len(head), line) for line in lines[1:]]
+
+ def extra_repr(self) -> str:
+ return ""
+
+
+class StandardTransform:
+ def __init__(self, transform: Optional[Callable] = None, target_transform: Optional[Callable] = None) -> None:
+ self.transform = transform
+ self.target_transform = target_transform
+
+ def __call__(self, input: Any, target: Any) -> tuple[Any, Any]:
+ if self.transform is not None:
+ input = self.transform(input)
+ if self.target_transform is not None:
+ target = self.target_transform(target)
+ return input, target
+
+ def _format_transform_repr(self, transform: Callable, head: str) -> list[str]:
+ lines = transform.__repr__().splitlines()
+ return [f"{head}{lines[0]}"] + ["{}{}".format(" " * len(head), line) for line in lines[1:]]
+
+ def __repr__(self) -> str:
+ body = [self.__class__.__name__]
+ if self.transform is not None:
+ body += self._format_transform_repr(self.transform, "Transform: ")
+ if self.target_transform is not None:
+ body += self._format_transform_repr(self.target_transform, "Target transform: ")
+
+ return "\n".join(body)
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torchvision/datasets/voc.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torchvision/datasets/voc.py
new file mode 100644
index 0000000000000000000000000000000000000000..4d3e502d84e4153bc57a7f2a431a20ecd35348e3
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torchvision/datasets/voc.py
@@ -0,0 +1,224 @@
+import collections
+import os
+from pathlib import Path
+from typing import Any, Callable, Optional, Union
+from xml.etree.ElementTree import Element as ET_Element
+
+try:
+ from defusedxml.ElementTree import parse as ET_parse
+except ImportError:
+ from xml.etree.ElementTree import parse as ET_parse
+
+from PIL import Image
+
+from .utils import download_and_extract_archive, verify_str_arg
+from .vision import VisionDataset
+
+DATASET_YEAR_DICT = {
+ "2012": {
+ "url": "http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar",
+ "filename": "VOCtrainval_11-May-2012.tar",
+ "md5": "6cd6e144f989b92b3379bac3b3de84fd",
+ "base_dir": os.path.join("VOCdevkit", "VOC2012"),
+ },
+ "2011": {
+ "url": "http://host.robots.ox.ac.uk/pascal/VOC/voc2011/VOCtrainval_25-May-2011.tar",
+ "filename": "VOCtrainval_25-May-2011.tar",
+ "md5": "6c3384ef61512963050cb5d687e5bf1e",
+ "base_dir": os.path.join("TrainVal", "VOCdevkit", "VOC2011"),
+ },
+ "2010": {
+ "url": "http://host.robots.ox.ac.uk/pascal/VOC/voc2010/VOCtrainval_03-May-2010.tar",
+ "filename": "VOCtrainval_03-May-2010.tar",
+ "md5": "da459979d0c395079b5c75ee67908abb",
+ "base_dir": os.path.join("VOCdevkit", "VOC2010"),
+ },
+ "2009": {
+ "url": "http://host.robots.ox.ac.uk/pascal/VOC/voc2009/VOCtrainval_11-May-2009.tar",
+ "filename": "VOCtrainval_11-May-2009.tar",
+ "md5": "a3e00b113cfcfebf17e343f59da3caa1",
+ "base_dir": os.path.join("VOCdevkit", "VOC2009"),
+ },
+ "2008": {
+ "url": "http://host.robots.ox.ac.uk/pascal/VOC/voc2008/VOCtrainval_14-Jul-2008.tar",
+ "filename": "VOCtrainval_11-May-2012.tar",
+ "md5": "2629fa636546599198acfcfbfcf1904a",
+ "base_dir": os.path.join("VOCdevkit", "VOC2008"),
+ },
+ "2007": {
+ "url": "http://host.robots.ox.ac.uk/pascal/VOC/voc2007/VOCtrainval_06-Nov-2007.tar",
+ "filename": "VOCtrainval_06-Nov-2007.tar",
+ "md5": "c52e279531787c972589f7e41ab4ae64",
+ "base_dir": os.path.join("VOCdevkit", "VOC2007"),
+ },
+ "2007-test": {
+ "url": "http://host.robots.ox.ac.uk/pascal/VOC/voc2007/VOCtest_06-Nov-2007.tar",
+ "filename": "VOCtest_06-Nov-2007.tar",
+ "md5": "b6e924de25625d8de591ea690078ad9f",
+ "base_dir": os.path.join("VOCdevkit", "VOC2007"),
+ },
+}
+
+
+class _VOCBase(VisionDataset):
+ _SPLITS_DIR: str
+ _TARGET_DIR: str
+ _TARGET_FILE_EXT: str
+
+ def __init__(
+ self,
+ root: Union[str, Path],
+ year: str = "2012",
+ image_set: str = "train",
+ download: bool = False,
+ transform: Optional[Callable] = None,
+ target_transform: Optional[Callable] = None,
+ transforms: Optional[Callable] = None,
+ ):
+ super().__init__(root, transforms, transform, target_transform)
+
+ self.year = verify_str_arg(year, "year", valid_values=[str(yr) for yr in range(2007, 2013)])
+
+ valid_image_sets = ["train", "trainval", "val"]
+ if year == "2007":
+ valid_image_sets.append("test")
+ self.image_set = verify_str_arg(image_set, "image_set", valid_image_sets)
+
+ key = "2007-test" if year == "2007" and image_set == "test" else year
+ dataset_year_dict = DATASET_YEAR_DICT[key]
+
+ self.url = dataset_year_dict["url"]
+ self.filename = dataset_year_dict["filename"]
+ self.md5 = dataset_year_dict["md5"]
+
+ base_dir = dataset_year_dict["base_dir"]
+ voc_root = os.path.join(self.root, base_dir)
+
+ if download:
+ download_and_extract_archive(self.url, self.root, filename=self.filename, md5=self.md5)
+
+ if not os.path.isdir(voc_root):
+ raise RuntimeError("Dataset not found or corrupted. You can use download=True to download it")
+
+ splits_dir = os.path.join(voc_root, "ImageSets", self._SPLITS_DIR)
+ split_f = os.path.join(splits_dir, image_set.rstrip("\n") + ".txt")
+ with open(split_f) as f:
+ file_names = [x.strip() for x in f.readlines()]
+
+ image_dir = os.path.join(voc_root, "JPEGImages")
+ self.images = [os.path.join(image_dir, x + ".jpg") for x in file_names]
+
+ target_dir = os.path.join(voc_root, self._TARGET_DIR)
+ self.targets = [os.path.join(target_dir, x + self._TARGET_FILE_EXT) for x in file_names]
+
+ assert len(self.images) == len(self.targets)
+
+ def __len__(self) -> int:
+ return len(self.images)
+
+
+class VOCSegmentation(_VOCBase):
+ """`Pascal VOC `_ Segmentation Dataset.
+
+ Args:
+ root (str or ``pathlib.Path``): Root directory of the VOC Dataset.
+ year (string, optional): The dataset year, supports years ``"2007"`` to ``"2012"``.
+ image_set (string, optional): Select the image_set to use, ``"train"``, ``"trainval"`` or ``"val"``. If
+ ``year=="2007"``, can also be ``"test"``.
+ download (bool, optional): If true, downloads the dataset from the internet and
+ puts it in root directory. If dataset is already downloaded, it is not
+ downloaded again.
+ transform (callable, optional): A function/transform that takes in a PIL image
+ and returns a transformed version. E.g., ``transforms.RandomCrop``
+ target_transform (callable, optional): A function/transform that takes in the
+ target and transforms it.
+ transforms (callable, optional): A function/transform that takes input sample and its target as entry
+ and returns a transformed version.
+ """
+
+ _SPLITS_DIR = "Segmentation"
+ _TARGET_DIR = "SegmentationClass"
+ _TARGET_FILE_EXT = ".png"
+
+ @property
+ def masks(self) -> list[str]:
+ return self.targets
+
+ def __getitem__(self, index: int) -> tuple[Any, Any]:
+ """
+ Args:
+ index (int): Index
+
+ Returns:
+ tuple: (image, target) where target is the image segmentation.
+ """
+ img = Image.open(self.images[index]).convert("RGB")
+ target = Image.open(self.masks[index])
+
+ if self.transforms is not None:
+ img, target = self.transforms(img, target)
+
+ return img, target
+
+
+class VOCDetection(_VOCBase):
+ """`Pascal VOC `_ Detection Dataset.
+
+ Args:
+ root (str or ``pathlib.Path``): Root directory of the VOC Dataset.
+ year (string, optional): The dataset year, supports years ``"2007"`` to ``"2012"``.
+ image_set (string, optional): Select the image_set to use, ``"train"``, ``"trainval"`` or ``"val"``. If
+ ``year=="2007"``, can also be ``"test"``.
+ download (bool, optional): If true, downloads the dataset from the internet and
+ puts it in root directory. If dataset is already downloaded, it is not
+ downloaded again.
+ transform (callable, optional): A function/transform that takes in a PIL image
+ and returns a transformed version. E.g., ``transforms.RandomCrop``
+ target_transform (callable, optional): A function/transform that takes in the
+ target and transforms it.
+ transforms (callable, optional): A function/transform that takes input sample and its target as entry
+ and returns a transformed version.
+ """
+
+ _SPLITS_DIR = "Main"
+ _TARGET_DIR = "Annotations"
+ _TARGET_FILE_EXT = ".xml"
+
+ @property
+ def annotations(self) -> list[str]:
+ return self.targets
+
+ def __getitem__(self, index: int) -> tuple[Any, Any]:
+ """
+ Args:
+ index (int): Index
+
+ Returns:
+ tuple: (image, target) where target is a dictionary of the XML tree.
+ """
+ img = Image.open(self.images[index]).convert("RGB")
+ target = self.parse_voc_xml(ET_parse(self.annotations[index]).getroot())
+
+ if self.transforms is not None:
+ img, target = self.transforms(img, target)
+
+ return img, target
+
+ @staticmethod
+ def parse_voc_xml(node: ET_Element) -> dict[str, Any]:
+ voc_dict: dict[str, Any] = {}
+ children = list(node)
+ if children:
+ def_dic: dict[str, Any] = collections.defaultdict(list)
+ for dc in map(VOCDetection.parse_voc_xml, children):
+ for ind, v in dc.items():
+ def_dic[ind].append(v)
+ if node.tag == "annotation":
+ def_dic["object"] = [def_dic["object"]]
+ voc_dict = {node.tag: {ind: v[0] if len(v) == 1 else v for ind, v in def_dic.items()}}
+ if node.text:
+ text = node.text.strip()
+ if not children:
+ voc_dict[node.tag] = text
+ return voc_dict
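
A minimal usage sketch for the voc.py module added above; the root path is a placeholder and is not part of the patch:

    from torchvision.datasets import VOCSegmentation, VOCDetection

    # Expects <root>/VOCdevkit/VOC2012 to exist, or download=True to fetch the archive.
    seg = VOCSegmentation(root="/data/voc", year="2012", image_set="train", download=False)
    img, mask = seg[0]          # PIL image and PIL segmentation mask

    det = VOCDetection(root="/data/voc", year="2007", image_set="trainval")
    img, target = det[0]        # target is the annotation XML parsed into a nested dict
    print(target["annotation"]["filename"])
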
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torchvision/datasets/widerface.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torchvision/datasets/widerface.py
new file mode 100644
index 0000000000000000000000000000000000000000..31ab28ebdba2660ba5ec0a16b19361ad30a8a692
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torchvision/datasets/widerface.py
@@ -0,0 +1,196 @@
+import os
+from os.path import abspath, expanduser
+from pathlib import Path
+
+from typing import Any, Callable, Optional, Union
+
+import torch
+from PIL import Image
+
+from .utils import download_and_extract_archive, download_file_from_google_drive, extract_archive, verify_str_arg
+from .vision import VisionDataset
+
+
+class WIDERFace(VisionDataset):
+ """`WIDERFace `_ Dataset.
+
+ Args:
+ root (str or ``pathlib.Path``): Root directory where images and annotations are downloaded to.
+ Expects the following folder structure if download=False:
+
+ .. code::
+
+ <root>
+ └── widerface
+ ├── wider_face_split ('wider_face_split.zip' if compressed)
+ ├── WIDER_train ('WIDER_train.zip' if compressed)
+ ├── WIDER_val ('WIDER_val.zip' if compressed)
+ └── WIDER_test ('WIDER_test.zip' if compressed)
+ split (string): The dataset split to use. One of {``train``, ``val``, ``test``}.
+ Defaults to ``train``.
+ transform (callable, optional): A function/transform that takes in a PIL image
+ and returns a transformed version. E.g., ``transforms.RandomCrop``
+ target_transform (callable, optional): A function/transform that takes in the
+ target and transforms it.
+ download (bool, optional): If true, downloads the dataset from the internet and
+ puts it in root directory. If dataset is already downloaded, it is not
+ downloaded again.
+
+ .. warning::
+
+ To download the dataset `gdown <https://github.com/wkentaro/gdown>`_ is required.
+
+ """
+
+ BASE_FOLDER = "widerface"
+ FILE_LIST = [
+ # File ID MD5 Hash Filename
+ ("15hGDLhsx8bLgLcIRD5DhYt5iBxnjNF1M", "3fedf70df600953d25982bcd13d91ba2", "WIDER_train.zip"),
+ ("1GUCogbp16PMGa39thoMMeWxp7Rp5oM8Q", "dfa7d7e790efa35df3788964cf0bbaea", "WIDER_val.zip"),
+ ("1HIfDbVEWKmsYKJZm4lchTBDLW5N7dY5T", "e5d8f4248ed24c334bbd12f49c29dd40", "WIDER_test.zip"),
+ ]
+ ANNOTATIONS_FILE = (
+ "http://shuoyang1213.me/WIDERFACE/support/bbx_annotation/wider_face_split.zip",
+ "0e3767bcf0e326556d407bf5bff5d27c",
+ "wider_face_split.zip",
+ )
+
+ def __init__(
+ self,
+ root: Union[str, Path],
+ split: str = "train",
+ transform: Optional[Callable] = None,
+ target_transform: Optional[Callable] = None,
+ download: bool = False,
+ ) -> None:
+ super().__init__(
+ root=os.path.join(root, self.BASE_FOLDER), transform=transform, target_transform=target_transform
+ )
+ # check arguments
+ self.split = verify_str_arg(split, "split", ("train", "val", "test"))
+
+ if download:
+ self.download()
+
+ if not self._check_integrity():
+ raise RuntimeError("Dataset not found or corrupted. You can use download=True to download and prepare it")
+
+ self.img_info: list[dict[str, Union[str, dict[str, torch.Tensor]]]] = []
+ if self.split in ("train", "val"):
+ self.parse_train_val_annotations_file()
+ else:
+ self.parse_test_annotations_file()
+
+ def __getitem__(self, index: int) -> tuple[Any, Any]:
+ """
+ Args:
+ index (int): Index
+
+ Returns:
+ tuple: (image, target) where target is a dict of annotations for all faces in the image.
+ target=None for the test split.
+ """
+
+ # stay consistent with other datasets and return a PIL Image
+ img = Image.open(self.img_info[index]["img_path"]) # type: ignore[arg-type]
+
+ if self.transform is not None:
+ img = self.transform(img)
+
+ target = None if self.split == "test" else self.img_info[index]["annotations"]
+ if self.target_transform is not None:
+ target = self.target_transform(target)
+
+ return img, target
+
+ def __len__(self) -> int:
+ return len(self.img_info)
+
+ def extra_repr(self) -> str:
+ lines = ["Split: {split}"]
+ return "\n".join(lines).format(**self.__dict__)
+
+ def parse_train_val_annotations_file(self) -> None:
+ filename = "wider_face_train_bbx_gt.txt" if self.split == "train" else "wider_face_val_bbx_gt.txt"
+ filepath = os.path.join(self.root, "wider_face_split", filename)
+
+ with open(filepath) as f:
+ lines = f.readlines()
+ file_name_line, num_boxes_line, box_annotation_line = True, False, False
+ num_boxes, box_counter = 0, 0
+ labels = []
+ for line in lines:
+ line = line.rstrip()
+ if file_name_line:
+ img_path = os.path.join(self.root, "WIDER_" + self.split, "images", line)
+ img_path = abspath(expanduser(img_path))
+ file_name_line = False
+ num_boxes_line = True
+ elif num_boxes_line:
+ num_boxes = int(line)
+ num_boxes_line = False
+ box_annotation_line = True
+ elif box_annotation_line:
+ box_counter += 1
+ line_split = line.split(" ")
+ line_values = [int(x) for x in line_split]
+ labels.append(line_values)
+ if box_counter >= num_boxes:
+ box_annotation_line = False
+ file_name_line = True
+ labels_tensor = torch.tensor(labels)
+ self.img_info.append(
+ {
+ "img_path": img_path,
+ "annotations": {
+ "bbox": labels_tensor[:, 0:4].clone(), # x, y, width, height
+ "blur": labels_tensor[:, 4].clone(),
+ "expression": labels_tensor[:, 5].clone(),
+ "illumination": labels_tensor[:, 6].clone(),
+ "occlusion": labels_tensor[:, 7].clone(),
+ "pose": labels_tensor[:, 8].clone(),
+ "invalid": labels_tensor[:, 9].clone(),
+ },
+ }
+ )
+ box_counter = 0
+ labels.clear()
+ else:
+ raise RuntimeError(f"Error parsing annotation file {filepath}")
+
+ def parse_test_annotations_file(self) -> None:
+ filepath = os.path.join(self.root, "wider_face_split", "wider_face_test_filelist.txt")
+ filepath = abspath(expanduser(filepath))
+ with open(filepath) as f:
+ lines = f.readlines()
+ for line in lines:
+ line = line.rstrip()
+ img_path = os.path.join(self.root, "WIDER_test", "images", line)
+ img_path = abspath(expanduser(img_path))
+ self.img_info.append({"img_path": img_path})
+
+ def _check_integrity(self) -> bool:
+ # Allow original archive to be deleted (zip). Only need the extracted images
+ all_files = self.FILE_LIST.copy()
+ all_files.append(self.ANNOTATIONS_FILE)
+ for _, md5, filename in all_files:
+ file, ext = os.path.splitext(filename)
+ extracted_dir = os.path.join(self.root, file)
+ if not os.path.exists(extracted_dir):
+ return False
+ return True
+
+ def download(self) -> None:
+ if self._check_integrity():
+ return
+
+ # download and extract image data
+ for file_id, md5, filename in self.FILE_LIST:
+ download_file_from_google_drive(file_id, self.root, filename, md5)
+ filepath = os.path.join(self.root, filename)
+ extract_archive(filepath)
+
+ # download and extract annotation files
+ download_and_extract_archive(
+ url=self.ANNOTATIONS_FILE[0], download_root=self.root, md5=self.ANNOTATIONS_FILE[1]
+ )
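
A hedged usage sketch for the WIDERFace class above; the root directory is a placeholder, and download=True additionally needs the gdown package mentioned in the docstring:

    from torchvision.datasets import WIDERFace

    # With download=False this expects /data/widerface/{wider_face_split, WIDER_train, ...}.
    ds = WIDERFace(root="/data", split="train", download=False)
    img, target = ds[0]
    # target["bbox"] is an N x 4 tensor of (x, y, width, height) boxes for the image;
    # blur/expression/illumination/occlusion/pose/invalid are per-face attribute tensors.
    print(len(ds), target["bbox"].shape)
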
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/torchvision/transforms/_functional_video.py b/URSA/.venv_ursa/lib/python3.12/site-packages/torchvision/transforms/_functional_video.py
new file mode 100644
index 0000000000000000000000000000000000000000..91df7d42cd71fc554aba51fcf5e90db30e3c3851
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/torchvision/transforms/_functional_video.py
@@ -0,0 +1,114 @@
+import warnings
+
+import torch
+
+
+warnings.warn(
+ "The 'torchvision.transforms._functional_video' module is deprecated since 0.12 and will be removed in the future. "
+ "Please use the 'torchvision.transforms.functional' module instead."
+)
+
+
+def _is_tensor_video_clip(clip):
+ if not torch.is_tensor(clip):
+ raise TypeError("clip should be Tensor. Got %s" % type(clip))
+
+ if not clip.ndimension() == 4:
+ raise ValueError("clip should be 4D. Got %dD" % clip.dim())
+
+ return True
+
+
+def crop(clip, i, j, h, w):
+ """
+ Args:
+ clip (torch.tensor): Video clip to be cropped. Size is (C, T, H, W)
+ """
+ if len(clip.size()) != 4:
+ raise ValueError("clip should be a 4D tensor")
+ return clip[..., i : i + h, j : j + w]
+
+
+def resize(clip, target_size, interpolation_mode):
+ if len(target_size) != 2:
+ raise ValueError(f"target size should be tuple (height, width), instead got {target_size}")
+ return torch.nn.functional.interpolate(clip, size=target_size, mode=interpolation_mode, align_corners=False)
+
+
+def resized_crop(clip, i, j, h, w, size, interpolation_mode="bilinear"):
+ """
+ Do spatial cropping and resizing to the video clip
+ Args:
+ clip (torch.tensor): Video clip to be cropped. Size is (C, T, H, W)
+ i (int): i in (i,j) i.e coordinates of the upper left corner.
+ j (int): j in (i,j) i.e coordinates of the upper left corner.
+ h (int): Height of the cropped region.
+ w (int): Width of the cropped region.
+ size (tuple(int, int)): height and width of resized clip
+ Returns:
+ clip (torch.tensor): Resized and cropped clip. Size is (C, T, H, W)
+ """
+ if not _is_tensor_video_clip(clip):
+ raise ValueError("clip should be a 4D torch.tensor")
+ clip = crop(clip, i, j, h, w)
+ clip = resize(clip, size, interpolation_mode)
+ return clip
+
+
+def center_crop(clip, crop_size):
+ if not _is_tensor_video_clip(clip):
+ raise ValueError("clip should be a 4D torch.tensor")
+ h, w = clip.size(-2), clip.size(-1)
+ th, tw = crop_size
+ if h < th or w < tw:
+ raise ValueError("height and width must be no smaller than crop_size")
+
+ i = int(round((h - th) / 2.0))
+ j = int(round((w - tw) / 2.0))
+ return crop(clip, i, j, th, tw)
+
+
+def to_tensor(clip):
+ """
+ Convert tensor data type from uint8 to float, divide value by 255.0 and
+ permute the dimensions of clip tensor
+ Args:
+ clip (torch.tensor, dtype=torch.uint8): Size is (T, H, W, C)
+ Return:
+ clip (torch.tensor, dtype=torch.float): Size is (C, T, H, W)
+ """
+ _is_tensor_video_clip(clip)
+ if not clip.dtype == torch.uint8:
+ raise TypeError("clip tensor should have data type uint8. Got %s" % str(clip.dtype))
+ return clip.float().permute(3, 0, 1, 2) / 255.0
+
+
+def normalize(clip, mean, std, inplace=False):
+ """
+ Args:
+ clip (torch.tensor): Video clip to be normalized. Size is (C, T, H, W)
+ mean (tuple): pixel RGB mean. Size is (3)
+ std (tuple): pixel standard deviation. Size is (3)
+ Returns:
+ normalized clip (torch.tensor): Size is (C, T, H, W)
+ """
+ if not _is_tensor_video_clip(clip):
+ raise ValueError("clip should be a 4D torch.tensor")
+ if not inplace:
+ clip = clip.clone()
+ mean = torch.as_tensor(mean, dtype=clip.dtype, device=clip.device)
+ std = torch.as_tensor(std, dtype=clip.dtype, device=clip.device)
+ clip.sub_(mean[:, None, None, None]).div_(std[:, None, None, None])
+ return clip
+
+
+def hflip(clip):
+ """
+ Args:
+ clip (torch.tensor): Video clip to be normalized. Size is (C, T, H, W)
+ Returns:
+ flipped clip (torch.tensor): Size is (C, T, H, W)
+ """
+ if not _is_tensor_video_clip(clip):
+ raise ValueError("clip should be a 4D torch.tensor")
+ return clip.flip(-1)
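
For reference, a small self-contained example of the deprecated helpers above, run on a random uint8 clip rather than a real video (importing the module emits the deprecation warning defined at the top of the file):

    import torch
    from torchvision.transforms import _functional_video as F

    clip = torch.randint(0, 256, (8, 128, 171, 3), dtype=torch.uint8)  # (T, H, W, C)
    clip = F.to_tensor(clip)                                           # float in [0, 1], (C, T, H, W)
    clip = F.resized_crop(clip, i=10, j=10, h=100, w=100, size=(112, 112))
    clip = F.normalize(clip, mean=[0.43, 0.40, 0.37], std=[0.22, 0.22, 0.22])
    clip = F.hflip(clip)
    print(clip.shape)  # torch.Size([3, 8, 112, 112])
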
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/commands/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/commands/__pycache__/__init__.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ef8d9170755cb4ea7ff93b0a7c6a53c2d25a142e
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/commands/__pycache__/__init__.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/commands/__pycache__/add_fast_image_processor.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/commands/__pycache__/add_fast_image_processor.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..2537481a4feebd4d708823c2242eec17b7da1b7c
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/commands/__pycache__/add_fast_image_processor.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/commands/__pycache__/add_new_model_like.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/commands/__pycache__/add_new_model_like.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..885fbf2bff4eccf7db82f1095bc59cbf3d575f0a
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/commands/__pycache__/add_new_model_like.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/commands/__pycache__/chat.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/commands/__pycache__/chat.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..5e28961d888261506af7933f875b56e5a7c2abd9
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/commands/__pycache__/chat.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/commands/__pycache__/convert.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/commands/__pycache__/convert.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..55c9c869a32ad9f7d0f2e786801d9a799107733d
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/commands/__pycache__/convert.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/commands/__pycache__/download.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/commands/__pycache__/download.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f54597cd8e6cd429194dc5f23b340b52517b45a0
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/commands/__pycache__/download.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/commands/__pycache__/env.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/commands/__pycache__/env.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..116e20cc30431271726efeffabb127b36ecaef2e
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/commands/__pycache__/env.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/commands/__pycache__/run.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/commands/__pycache__/run.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..a6917014de235546a9486e4f508732a24e763f1d
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/commands/__pycache__/run.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/commands/__pycache__/serving.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/commands/__pycache__/serving.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..7d8cee7f444308f0bbfe3e9b71c44f5b8e0f3eed
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/commands/__pycache__/serving.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/commands/__pycache__/train.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/commands/__pycache__/train.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..74613d7dfdc7efa22aaf7acffd1088af1d8b61f7
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/commands/__pycache__/train.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/commands/__pycache__/transformers_cli.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/commands/__pycache__/transformers_cli.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f431529dc97b68df5e6481116dc4ee8ab489e629
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/commands/__pycache__/transformers_cli.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/data/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/data/__pycache__/__init__.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..fe461d3f957caa25eefa3f70907a3d8ae6a96ef0
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/data/__pycache__/__init__.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/data/datasets/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/data/datasets/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..378894ab4bbb4704b67b1de4ab512f145b889d46
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/data/datasets/__init__.py
@@ -0,0 +1,23 @@
+# Copyright 2020 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .glue import GlueDataset, GlueDataTrainingArguments
+from .language_modeling import (
+ LineByLineTextDataset,
+ LineByLineWithRefDataset,
+ LineByLineWithSOPTextDataset,
+ TextDataset,
+ TextDatasetForNextSentencePrediction,
+)
+from .squad import SquadDataset, SquadDataTrainingArguments
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/data/datasets/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/data/datasets/__pycache__/__init__.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e3f473a42d441e9857c64c876fe9cd08c3745125
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/data/datasets/__pycache__/__init__.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/data/datasets/__pycache__/glue.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/data/datasets/__pycache__/glue.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..66905ad08910afda905726e0ae9844093b5648b2
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/data/datasets/__pycache__/glue.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/data/datasets/__pycache__/language_modeling.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/data/datasets/__pycache__/language_modeling.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..5bce63748bd1f9cc42948d4f052c1e84da0a4f18
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/data/datasets/__pycache__/language_modeling.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/data/datasets/__pycache__/squad.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/data/datasets/__pycache__/squad.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..57e6c1e861a2458d47bb102b775ffdf38cccf3c6
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/data/datasets/__pycache__/squad.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/data/datasets/glue.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/data/datasets/glue.py
new file mode 100644
index 0000000000000000000000000000000000000000..d8db0dfebac1a8432d18320df4f1f4eba4eb4030
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/data/datasets/glue.py
@@ -0,0 +1,162 @@
+# Copyright 2020 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import time
+import warnings
+from dataclasses import dataclass, field
+from enum import Enum
+from typing import Optional, Union
+
+import torch
+from filelock import FileLock
+from torch.utils.data import Dataset
+
+from ...tokenization_utils_base import PreTrainedTokenizerBase
+from ...utils import check_torch_load_is_safe, logging
+from ..processors.glue import glue_convert_examples_to_features, glue_output_modes, glue_processors
+from ..processors.utils import InputFeatures
+
+
+logger = logging.get_logger(__name__)
+
+
+@dataclass
+class GlueDataTrainingArguments:
+ """
+ Arguments pertaining to what data we are going to input our model for training and eval.
+
+ Using `HfArgumentParser` we can turn this class into argparse arguments to be able to specify them on the command
+ line.
+ """
+
+ task_name: str = field(metadata={"help": "The name of the task to train on: " + ", ".join(glue_processors.keys())})
+ data_dir: str = field(
+ metadata={"help": "The input data dir. Should contain the .tsv files (or other data files) for the task."}
+ )
+ max_seq_length: int = field(
+ default=128,
+ metadata={
+ "help": (
+ "The maximum total input sequence length after tokenization. Sequences longer "
+ "than this will be truncated, sequences shorter will be padded."
+ )
+ },
+ )
+ overwrite_cache: bool = field(
+ default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
+ )
+
+ def __post_init__(self):
+ self.task_name = self.task_name.lower()
+
+
+class Split(Enum):
+ train = "train"
+ dev = "dev"
+ test = "test"
+
+
+class GlueDataset(Dataset):
+ """
+ This will be superseded by a framework-agnostic approach soon.
+ """
+
+ args: GlueDataTrainingArguments
+ output_mode: str
+ features: list[InputFeatures]
+
+ def __init__(
+ self,
+ args: GlueDataTrainingArguments,
+ tokenizer: PreTrainedTokenizerBase,
+ limit_length: Optional[int] = None,
+ mode: Union[str, Split] = Split.train,
+ cache_dir: Optional[str] = None,
+ ):
+ warnings.warn(
+ "This dataset will be removed from the library soon, preprocessing should be handled with the 🤗 Datasets "
+ "library. You can have a look at this example script for pointers: "
+ "https://github.com/huggingface/transformers/blob/main/examples/pytorch/text-classification/run_glue.py",
+ FutureWarning,
+ )
+ self.args = args
+ self.processor = glue_processors[args.task_name]()
+ self.output_mode = glue_output_modes[args.task_name]
+ if isinstance(mode, str):
+ try:
+ mode = Split[mode]
+ except KeyError:
+ raise KeyError("mode is not a valid split name")
+ # Load data features from cache or dataset file
+ cached_features_file = os.path.join(
+ cache_dir if cache_dir is not None else args.data_dir,
+ f"cached_{mode.value}_{tokenizer.__class__.__name__}_{args.max_seq_length}_{args.task_name}",
+ )
+ label_list = self.processor.get_labels()
+ if args.task_name in ["mnli", "mnli-mm"] and tokenizer.__class__.__name__ in (
+ "RobertaTokenizer",
+ "RobertaTokenizerFast",
+ "XLMRobertaTokenizer",
+ "BartTokenizer",
+ "BartTokenizerFast",
+ ):
+ # HACK(label indices are swapped in RoBERTa pretrained model)
+ label_list[1], label_list[2] = label_list[2], label_list[1]
+ self.label_list = label_list
+
+ # Make sure only the first process in distributed training processes the dataset,
+ # and the others will use the cache.
+ lock_path = cached_features_file + ".lock"
+ with FileLock(lock_path):
+ if os.path.exists(cached_features_file) and not args.overwrite_cache:
+ start = time.time()
+ check_torch_load_is_safe()
+ self.features = torch.load(cached_features_file, weights_only=True)
+ logger.info(
+ f"Loading features from cached file {cached_features_file} [took %.3f s]", time.time() - start
+ )
+ else:
+ logger.info(f"Creating features from dataset file at {args.data_dir}")
+
+ if mode == Split.dev:
+ examples = self.processor.get_dev_examples(args.data_dir)
+ elif mode == Split.test:
+ examples = self.processor.get_test_examples(args.data_dir)
+ else:
+ examples = self.processor.get_train_examples(args.data_dir)
+ if limit_length is not None:
+ examples = examples[:limit_length]
+ self.features = glue_convert_examples_to_features(
+ examples,
+ tokenizer,
+ max_length=args.max_seq_length,
+ label_list=label_list,
+ output_mode=self.output_mode,
+ )
+ start = time.time()
+ torch.save(self.features, cached_features_file)
+ # ^ This seems to take a lot of time so I want to investigate why and how we can improve.
+ logger.info(
+ f"Saving features into cached file {cached_features_file} [took {time.time() - start:.3f} s]"
+ )
+
+ def __len__(self):
+ return len(self.features)
+
+ def __getitem__(self, i) -> InputFeatures:
+ return self.features[i]
+
+ def get_labels(self):
+ return self.label_list
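
A sketch of how the (deprecated) GlueDataset above is typically wired up; the data_dir is a placeholder and must contain the GLUE .tsv files for the chosen task:

    from transformers import AutoTokenizer
    from transformers.data.datasets import GlueDataset, GlueDataTrainingArguments

    data_args = GlueDataTrainingArguments(task_name="mrpc", data_dir="/data/glue/MRPC", max_seq_length=128)
    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
    train_dataset = GlueDataset(data_args, tokenizer=tokenizer, mode="train")
    print(len(train_dataset), train_dataset.get_labels())
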
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/data/datasets/language_modeling.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/data/datasets/language_modeling.py
new file mode 100644
index 0000000000000000000000000000000000000000..07250ef3cb5402603c75ed2c1a4c2e2200fb3dbe
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/data/datasets/language_modeling.py
@@ -0,0 +1,530 @@
+# Copyright 2020 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import json
+import os
+import pickle
+import random
+import time
+import warnings
+from typing import Optional
+
+import torch
+from filelock import FileLock
+from torch.utils.data import Dataset
+
+from ...tokenization_utils import PreTrainedTokenizer
+from ...utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+
+DEPRECATION_WARNING = (
+ "This dataset will be removed from the library soon, preprocessing should be handled with the 🤗 Datasets "
+ "library. You can have a look at this example script for pointers: {0}"
+)
+
+
+class TextDataset(Dataset):
+ """
+ This will be superseded by a framework-agnostic approach soon.
+ """
+
+ def __init__(
+ self,
+ tokenizer: PreTrainedTokenizer,
+ file_path: str,
+ block_size: int,
+ overwrite_cache=False,
+ cache_dir: Optional[str] = None,
+ ):
+ warnings.warn(
+ DEPRECATION_WARNING.format(
+ "https://github.com/huggingface/transformers/blob/main/examples/pytorch/language-modeling/run_mlm.py"
+ ),
+ FutureWarning,
+ )
+ if os.path.isfile(file_path) is False:
+ raise ValueError(f"Input file path {file_path} not found")
+
+ block_size = block_size - tokenizer.num_special_tokens_to_add(pair=False)
+
+ directory, filename = os.path.split(file_path)
+ cached_features_file = os.path.join(
+ cache_dir if cache_dir is not None else directory,
+ f"cached_lm_{tokenizer.__class__.__name__}_{block_size}_{filename}",
+ )
+
+ # Make sure only the first process in distributed training processes the dataset,
+ # and the others will use the cache.
+ lock_path = cached_features_file + ".lock"
+ with FileLock(lock_path):
+ if os.path.exists(cached_features_file) and not overwrite_cache:
+ start = time.time()
+ with open(cached_features_file, "rb") as handle:
+ self.examples = pickle.load(handle)
+ logger.info(
+ f"Loading features from cached file {cached_features_file} [took %.3f s]", time.time() - start
+ )
+
+ else:
+ logger.info(f"Creating features from dataset file at {directory}")
+
+ self.examples = []
+ with open(file_path, encoding="utf-8") as f:
+ text = f.read()
+
+ tokenized_text = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(text))
+
+ for i in range(0, len(tokenized_text) - block_size + 1, block_size): # Truncate in block of block_size
+ self.examples.append(
+ tokenizer.build_inputs_with_special_tokens(tokenized_text[i : i + block_size])
+ )
+ # Note that we are losing the last truncated example here for the sake of simplicity (no padding)
+ # If your dataset is small, first you should look for a bigger one :-) and second you
+ # can change this behavior by adding (model specific) padding.
+
+ start = time.time()
+ with open(cached_features_file, "wb") as handle:
+ pickle.dump(self.examples, handle, protocol=pickle.HIGHEST_PROTOCOL)
+ logger.info(
+ f"Saving features into cached file {cached_features_file} [took {time.time() - start:.3f} s]"
+ )
+
+ def __len__(self):
+ return len(self.examples)
+
+ def __getitem__(self, i) -> torch.Tensor:
+ return torch.tensor(self.examples[i], dtype=torch.long)
+
+
+class LineByLineTextDataset(Dataset):
+ """
+ This will be superseded by a framework-agnostic approach soon.
+ """
+
+ def __init__(self, tokenizer: PreTrainedTokenizer, file_path: str, block_size: int):
+ warnings.warn(
+ DEPRECATION_WARNING.format(
+ "https://github.com/huggingface/transformers/blob/main/examples/pytorch/language-modeling/run_mlm.py"
+ ),
+ FutureWarning,
+ )
+ if os.path.isfile(file_path) is False:
+ raise ValueError(f"Input file path {file_path} not found")
+ # Here, we do not cache the features, operating under the assumption
+ # that we will soon use fast multithreaded tokenizers from the
+ # `tokenizers` repo everywhere =)
+ logger.info(f"Creating features from dataset file at {file_path}")
+
+ with open(file_path, encoding="utf-8") as f:
+ lines = [line for line in f.read().splitlines() if (len(line) > 0 and not line.isspace())]
+
+ batch_encoding = tokenizer(lines, add_special_tokens=True, truncation=True, max_length=block_size)
+ self.examples = batch_encoding["input_ids"]
+ self.examples = [{"input_ids": torch.tensor(e, dtype=torch.long)} for e in self.examples]
+
+ def __len__(self):
+ return len(self.examples)
+
+ def __getitem__(self, i) -> dict[str, torch.tensor]:
+ return self.examples[i]
+
+
+class LineByLineWithRefDataset(Dataset):
+ """
+ This will be superseded by a framework-agnostic approach soon.
+ """
+
+ def __init__(self, tokenizer: PreTrainedTokenizer, file_path: str, block_size: int, ref_path: str):
+ warnings.warn(
+ DEPRECATION_WARNING.format(
+ "https://github.com/huggingface/transformers/blob/main/examples/pytorch/language-modeling/run_mlm_wwm.py"
+ ),
+ FutureWarning,
+ )
+ if os.path.isfile(file_path) is False:
+ raise ValueError(f"Input file path {file_path} not found")
+ if os.path.isfile(ref_path) is False:
+ raise ValueError(f"Ref file path {file_path} not found")
+ # Here, we do not cache the features, operating under the assumption
+ # that we will soon use fast multithreaded tokenizers from the
+ # `tokenizers` repo everywhere =)
+ logger.info(f"Creating features from dataset file at {file_path}")
+ logger.info(f"Use ref segment results at {ref_path}")
+ with open(file_path, encoding="utf-8") as f:
+ data = f.readlines() # use this method to avoid delimiter '\u2029' to split a line
+ data = [line.strip() for line in data if len(line) > 0 and not line.isspace()]
+ # Get ref inf from file
+ with open(ref_path, encoding="utf-8") as f:
+ ref = [json.loads(line) for line in f.read().splitlines() if (len(line) > 0 and not line.isspace())]
+ if len(data) != len(ref):
+ raise ValueError(
+ f"Length of Input file should be equal to Ref file. But the length of {file_path} is {len(data)} "
+ f"while length of {ref_path} is {len(ref)}"
+ )
+
+ batch_encoding = tokenizer(data, add_special_tokens=True, truncation=True, max_length=block_size)
+ self.examples = batch_encoding["input_ids"]
+ self.examples = [{"input_ids": torch.tensor(e, dtype=torch.long)} for e in self.examples]
+
+ n = len(self.examples)
+ for i in range(n):
+ self.examples[i]["chinese_ref"] = torch.tensor(ref[i], dtype=torch.long)
+
+ def __len__(self):
+ return len(self.examples)
+
+ def __getitem__(self, i) -> dict[str, torch.tensor]:
+ return self.examples[i]
+
+
+class LineByLineWithSOPTextDataset(Dataset):
+ """
+ Dataset for sentence order prediction task, prepare sentence pairs for SOP task
+ """
+
+ def __init__(self, tokenizer: PreTrainedTokenizer, file_dir: str, block_size: int):
+ warnings.warn(
+ DEPRECATION_WARNING.format(
+ "https://github.com/huggingface/transformers/blob/main/examples/pytorch/language-modeling/run_mlm.py"
+ ),
+ FutureWarning,
+ )
+ if os.path.isdir(file_dir) is False:
+ raise ValueError(f"{file_dir} is not a directory")
+ logger.info(f"Creating features from dataset file folder at {file_dir}")
+ self.examples = []
+ # TODO: randomness could apply a random seed, ex. rng = random.Random(random_seed)
+ # file path looks like ./dataset/wiki_1, ./dataset/wiki_2
+ for file_name in os.listdir(file_dir):
+ file_path = os.path.join(file_dir, file_name)
+ if os.path.isfile(file_path) is False:
+ raise ValueError(f"{file_path} is not a file")
+ article_open = False
+ with open(file_path, encoding="utf-8") as f:
+ original_lines = f.readlines()
+ article_lines = []
+ for line in original_lines:
+ if "" in line:
+ article_open = False
+ document = [
+ tokenizer.convert_tokens_to_ids(tokenizer.tokenize(line))
+ for line in article_lines[1:]
+ if (len(line) > 0 and not line.isspace())
+ ]
+
+ examples = self.create_examples_from_document(document, block_size, tokenizer)
+ self.examples.extend(examples)
+ article_lines = []
+ else:
+ if article_open:
+ article_lines.append(line)
+
+ logger.info("Dataset parse finished.")
+
+ def create_examples_from_document(self, document, block_size, tokenizer, short_seq_prob=0.1):
+ """Creates examples for a single document."""
+
+ # Account for special tokens
+ max_num_tokens = block_size - tokenizer.num_special_tokens_to_add(pair=True)
+
+ # We *usually* want to fill up the entire sequence since we are padding
+ # to `block_size` anyways, so short sequences are generally wasted
+ # computation. However, we *sometimes*
+ # (i.e., short_seq_prob == 0.1 == 10% of the time) want to use shorter
+ # sequences to minimize the mismatch between pretraining and fine-tuning.
+ # The `target_seq_length` is just a rough target however, whereas
+ # `block_size` is a hard limit.
+ target_seq_length = max_num_tokens
+ if random.random() < short_seq_prob:
+ target_seq_length = random.randint(2, max_num_tokens)
+
+ # We DON'T just concatenate all of the tokens from a document into a long
+ # sequence and choose an arbitrary split point because this would make the
+ # next sentence prediction task too easy. Instead, we split the input into
+ # segments "A" and "B" based on the actual "sentences" provided by the user
+ # input.
+ examples = []
+ current_chunk = [] # a buffer stored current working segments
+ current_length = 0
+ i = 0
+ while i < len(document):
+ segment = document[i] # get a segment
+ if not segment:
+ i += 1
+ continue
+ current_chunk.append(segment) # add a segment to current chunk
+ current_length += len(segment) # overall token length
+ # if current length goes to the target length or reaches the end of file, start building token a and b
+ if i == len(document) - 1 or current_length >= target_seq_length:
+ if current_chunk:
+ # `a_end` is how many segments from `current_chunk` go into the `A` (first) sentence.
+ a_end = 1
+ # if current chunk has more than 2 sentences, pick part of it `A` (first) sentence
+ if len(current_chunk) >= 2:
+ a_end = random.randint(1, len(current_chunk) - 1)
+ # token a
+ tokens_a = []
+ for j in range(a_end):
+ tokens_a.extend(current_chunk[j])
+
+ # token b
+ tokens_b = []
+ for j in range(a_end, len(current_chunk)):
+ tokens_b.extend(current_chunk[j])
+
+ if len(tokens_a) == 0 or len(tokens_b) == 0:
+ continue
+
+ # switch tokens_a and tokens_b randomly
+ if random.random() < 0.5:
+ is_next = False
+ tokens_a, tokens_b = tokens_b, tokens_a
+ else:
+ is_next = True
+
+ def truncate_seq_pair(tokens_a, tokens_b, max_num_tokens):
+ """Truncates a pair of sequences to a maximum sequence length."""
+ while True:
+ total_length = len(tokens_a) + len(tokens_b)
+ if total_length <= max_num_tokens:
+ break
+ trunc_tokens = tokens_a if len(tokens_a) > len(tokens_b) else tokens_b
+ if not (len(trunc_tokens) >= 1):
+ raise ValueError("Sequence length to be truncated must be no less than one")
+ # We want to sometimes truncate from the front and sometimes from the
+ # back to add more randomness and avoid biases.
+ if random.random() < 0.5:
+ del trunc_tokens[0]
+ else:
+ trunc_tokens.pop()
+
+ truncate_seq_pair(tokens_a, tokens_b, max_num_tokens)
+ if not (len(tokens_a) >= 1):
+ raise ValueError(f"Length of sequence a is {len(tokens_a)} which must be no less than 1")
+ if not (len(tokens_b) >= 1):
+ raise ValueError(f"Length of sequence b is {len(tokens_b)} which must be no less than 1")
+
+ # add special tokens
+ input_ids = tokenizer.build_inputs_with_special_tokens(tokens_a, tokens_b)
+ # add token type ids, 0 for sentence a, 1 for sentence b
+ token_type_ids = tokenizer.create_token_type_ids_from_sequences(tokens_a, tokens_b)
+
+ example = {
+ "input_ids": torch.tensor(input_ids, dtype=torch.long),
+ "token_type_ids": torch.tensor(token_type_ids, dtype=torch.long),
+ "sentence_order_label": torch.tensor(0 if is_next else 1, dtype=torch.long),
+ }
+ examples.append(example)
+ current_chunk = [] # clear current chunk
+ current_length = 0 # reset current text length
+ i += 1 # go to next line
+ return examples
+
+ def __len__(self):
+ return len(self.examples)
+
+ def __getitem__(self, i) -> dict[str, torch.tensor]:
+ return self.examples[i]
+
+
+class TextDatasetForNextSentencePrediction(Dataset):
+ """
+ This will be superseded by a framework-agnostic approach soon.
+ """
+
+ def __init__(
+ self,
+ tokenizer: PreTrainedTokenizer,
+ file_path: str,
+ block_size: int,
+ overwrite_cache=False,
+ short_seq_probability=0.1,
+ nsp_probability=0.5,
+ ):
+ warnings.warn(
+ DEPRECATION_WARNING.format(
+ "https://github.com/huggingface/transformers/blob/main/examples/pytorch/language-modeling/run_mlm.py"
+ ),
+ FutureWarning,
+ )
+ if not os.path.isfile(file_path):
+ raise ValueError(f"Input file path {file_path} not found")
+
+ self.short_seq_probability = short_seq_probability
+ self.nsp_probability = nsp_probability
+
+ directory, filename = os.path.split(file_path)
+ cached_features_file = os.path.join(
+ directory,
+ f"cached_nsp_{tokenizer.__class__.__name__}_{block_size}_{filename}",
+ )
+
+ self.tokenizer = tokenizer
+
+ # Make sure only the first process in distributed training processes the dataset,
+ # and the others will use the cache.
+ lock_path = cached_features_file + ".lock"
+
+ # Input file format:
+ # (1) One sentence per line. These should ideally be actual sentences, not
+ # entire paragraphs or arbitrary spans of text. (Because we use the
+ # sentence boundaries for the "next sentence prediction" task).
+ # (2) Blank lines between documents. Document boundaries are needed so
+ # that the "next sentence prediction" task doesn't span between documents.
+ #
+ # Example:
+ # I am very happy.
+ # Here is the second sentence.
+ #
+ # A new document.
+
+ with FileLock(lock_path):
+ if os.path.exists(cached_features_file) and not overwrite_cache:
+ start = time.time()
+ with open(cached_features_file, "rb") as handle:
+ self.examples = pickle.load(handle)
+ logger.info(
+ f"Loading features from cached file {cached_features_file} [took %.3f s]", time.time() - start
+ )
+ else:
+ logger.info(f"Creating features from dataset file at {directory}")
+
+ self.documents = [[]]
+ with open(file_path, encoding="utf-8") as f:
+ while True:
+ line = f.readline()
+ if not line:
+ break
+ line = line.strip()
+
+ # Empty lines are used as document delimiters
+ if not line and len(self.documents[-1]) != 0:
+ self.documents.append([])
+ tokens = tokenizer.tokenize(line)
+ tokens = tokenizer.convert_tokens_to_ids(tokens)
+ if tokens:
+ self.documents[-1].append(tokens)
+
+ logger.info(f"Creating examples from {len(self.documents)} documents.")
+ self.examples = []
+ for doc_index, document in enumerate(self.documents):
+ self.create_examples_from_document(document, doc_index, block_size)
+
+ start = time.time()
+ with open(cached_features_file, "wb") as handle:
+ pickle.dump(self.examples, handle, protocol=pickle.HIGHEST_PROTOCOL)
+ logger.info(
+ f"Saving features into cached file {cached_features_file} [took {time.time() - start:.3f} s]"
+ )
+
+ def create_examples_from_document(self, document: list[list[int]], doc_index: int, block_size: int):
+ """Creates examples for a single document."""
+
+ max_num_tokens = block_size - self.tokenizer.num_special_tokens_to_add(pair=True)
+
+ # We *usually* want to fill up the entire sequence since we are padding
+ # to `block_size` anyways, so short sequences are generally wasted
+ # computation. However, we *sometimes*
+ # (i.e., short_seq_prob == 0.1 == 10% of the time) want to use shorter
+ # sequences to minimize the mismatch between pretraining and fine-tuning.
+ # The `target_seq_length` is just a rough target however, whereas
+ # `block_size` is a hard limit.
+ target_seq_length = max_num_tokens
+ if random.random() < self.short_seq_probability:
+ target_seq_length = random.randint(2, max_num_tokens)
+
+ current_chunk = [] # a buffer stored current working segments
+ current_length = 0
+ i = 0
+
+ while i < len(document):
+ segment = document[i]
+ current_chunk.append(segment)
+ current_length += len(segment)
+ if i == len(document) - 1 or current_length >= target_seq_length:
+ if current_chunk:
+ # `a_end` is how many segments from `current_chunk` go into the `A`
+ # (first) sentence.
+ a_end = 1
+ if len(current_chunk) >= 2:
+ a_end = random.randint(1, len(current_chunk) - 1)
+
+ tokens_a = []
+ for j in range(a_end):
+ tokens_a.extend(current_chunk[j])
+
+ tokens_b = []
+
+ if len(current_chunk) == 1 or random.random() < self.nsp_probability:
+ is_random_next = True
+ target_b_length = target_seq_length - len(tokens_a)
+
+ # This should rarely go for more than one iteration for large
+ # corpora. However, just to be careful, we try to make sure that
+ # the random document is not the same as the document
+ # we're processing.
+ for _ in range(10):
+ random_document_index = random.randint(0, len(self.documents) - 1)
+ if random_document_index != doc_index:
+ break
+
+ random_document = self.documents[random_document_index]
+ random_start = random.randint(0, len(random_document) - 1)
+ for j in range(random_start, len(random_document)):
+ tokens_b.extend(random_document[j])
+ if len(tokens_b) >= target_b_length:
+ break
+ # We didn't actually use these segments so we "put them back" so
+ # they don't go to waste.
+ num_unused_segments = len(current_chunk) - a_end
+ i -= num_unused_segments
+ # Actual next
+ else:
+ is_random_next = False
+ for j in range(a_end, len(current_chunk)):
+ tokens_b.extend(current_chunk[j])
+
+ if not (len(tokens_a) >= 1):
+ raise ValueError(f"Length of sequence a is {len(tokens_a)} which must be no less than 1")
+ if not (len(tokens_b) >= 1):
+ raise ValueError(f"Length of sequence b is {len(tokens_b)} which must be no less than 1")
+
+ # add special tokens
+ input_ids = self.tokenizer.build_inputs_with_special_tokens(tokens_a, tokens_b)
+ # add token type ids, 0 for sentence a, 1 for sentence b
+ token_type_ids = self.tokenizer.create_token_type_ids_from_sequences(tokens_a, tokens_b)
+
+ example = {
+ "input_ids": torch.tensor(input_ids, dtype=torch.long),
+ "token_type_ids": torch.tensor(token_type_ids, dtype=torch.long),
+ "next_sentence_label": torch.tensor(1 if is_random_next else 0, dtype=torch.long),
+ }
+
+ self.examples.append(example)
+
+ current_chunk = []
+ current_length = 0
+
+ i += 1
+
+ def __len__(self):
+ return len(self.examples)
+
+ def __getitem__(self, i):
+ return self.examples[i]
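
A minimal sketch pairing LineByLineTextDataset above with a masked-LM collator; train.txt is a placeholder plain-text file with one sentence per line:

    from transformers import AutoTokenizer, DataCollatorForLanguageModeling
    from transformers.data.datasets import LineByLineTextDataset

    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
    dataset = LineByLineTextDataset(tokenizer=tokenizer, file_path="train.txt", block_size=128)
    collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=True, mlm_probability=0.15)
    batch = collator([dataset[i] for i in range(min(4, len(dataset)))])
    print(batch["input_ids"].shape, batch["labels"].shape)
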
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/data/datasets/squad.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/data/datasets/squad.py
new file mode 100644
index 0000000000000000000000000000000000000000..d4f76a51f422b00ffce5215b65c9274ce5646489
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/data/datasets/squad.py
@@ -0,0 +1,230 @@
+# Copyright 2020 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import time
+from dataclasses import dataclass, field
+from enum import Enum
+from typing import Optional, Union
+
+import torch
+from filelock import FileLock
+from torch.utils.data import Dataset
+
+from ...models.auto.modeling_auto import MODEL_FOR_QUESTION_ANSWERING_MAPPING
+from ...tokenization_utils import PreTrainedTokenizer
+from ...utils import check_torch_load_is_safe, logging
+from ..processors.squad import SquadFeatures, SquadV1Processor, SquadV2Processor, squad_convert_examples_to_features
+
+
+logger = logging.get_logger(__name__)
+
+MODEL_CONFIG_CLASSES = list(MODEL_FOR_QUESTION_ANSWERING_MAPPING.keys())
+MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)
+
+
+@dataclass
+class SquadDataTrainingArguments:
+ """
+ Arguments pertaining to what data we are going to input our model for training and eval.
+ """
+
+ model_type: str = field(
+ default=None, metadata={"help": "Model type selected in the list: " + ", ".join(MODEL_TYPES)}
+ )
+ data_dir: str = field(
+ default=None, metadata={"help": "The input data dir. Should contain the .json files for the SQuAD task."}
+ )
+ max_seq_length: int = field(
+ default=128,
+ metadata={
+ "help": (
+ "The maximum total input sequence length after tokenization. Sequences longer "
+ "than this will be truncated, sequences shorter will be padded."
+ )
+ },
+ )
+ doc_stride: int = field(
+ default=128,
+ metadata={"help": "When splitting up a long document into chunks, how much stride to take between chunks."},
+ )
+ max_query_length: int = field(
+ default=64,
+ metadata={
+ "help": (
+ "The maximum number of tokens for the question. Questions longer than this will "
+ "be truncated to this length."
+ )
+ },
+ )
+ max_answer_length: int = field(
+ default=30,
+ metadata={
+ "help": (
+ "The maximum length of an answer that can be generated. This is needed because the start "
+ "and end predictions are not conditioned on one another."
+ )
+ },
+ )
+ overwrite_cache: bool = field(
+ default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
+ )
+ version_2_with_negative: bool = field(
+ default=False, metadata={"help": "If true, the SQuAD examples contain some that do not have an answer."}
+ )
+ null_score_diff_threshold: float = field(
+ default=0.0, metadata={"help": "If null_score - best_non_null is greater than the threshold predict null."}
+ )
+ n_best_size: int = field(
+ default=20, metadata={"help": "The total number of n-best predictions to generate."}
+ )
+ lang_id: int = field(
+ default=0,
+ metadata={
+ "help": (
+ "language id of input for language-specific xlm models (see"
+ " tokenization_xlm.PRETRAINED_INIT_CONFIGURATION)"
+ )
+ },
+ )
+ threads: int = field(default=1, metadata={"help": "multiple threads for converting example to features"})
+
+
+class Split(Enum):
+ train = "train"
+ dev = "dev"
+
+
+class SquadDataset(Dataset):
+ """
+ This will be superseded by a framework-agnostic approach soon.
+ """
+
+ args: SquadDataTrainingArguments
+ features: list[SquadFeatures]
+ mode: Split
+ is_language_sensitive: bool
+
+ def __init__(
+ self,
+ args: SquadDataTrainingArguments,
+ tokenizer: PreTrainedTokenizer,
+ limit_length: Optional[int] = None,
+ mode: Union[str, Split] = Split.train,
+ is_language_sensitive: bool = False,
+ cache_dir: Optional[str] = None,
+ dataset_format: str = "pt",
+ ):
+ self.args = args
+ self.is_language_sensitive = is_language_sensitive
+ self.processor = SquadV2Processor() if args.version_2_with_negative else SquadV1Processor()
+ if isinstance(mode, str):
+ try:
+ mode = Split[mode]
+ except KeyError:
+ raise KeyError("mode is not a valid split name")
+ self.mode = mode
+ # Load data features from cache or dataset file
+ version_tag = "v2" if args.version_2_with_negative else "v1"
+ cached_features_file = os.path.join(
+ cache_dir if cache_dir is not None else args.data_dir,
+ f"cached_{mode.value}_{tokenizer.__class__.__name__}_{args.max_seq_length}_{version_tag}",
+ )
+
+ # Make sure only the first process in distributed training processes the dataset,
+ # and the others will use the cache.
+ lock_path = cached_features_file + ".lock"
+ with FileLock(lock_path):
+ if os.path.exists(cached_features_file) and not args.overwrite_cache:
+ start = time.time()
+ check_torch_load_is_safe()
+ self.old_features = torch.load(cached_features_file, weights_only=True)
+
+ # Legacy cache files have only features, while new cache files
+ # will have dataset and examples also.
+ self.features = self.old_features["features"]
+ self.dataset = self.old_features.get("dataset", None)
+ self.examples = self.old_features.get("examples", None)
+ logger.info(
+ f"Loading features from cached file {cached_features_file} [took %.3f s]", time.time() - start
+ )
+
+ if self.dataset is None or self.examples is None:
+ logger.warning(
+ f"Deleting cached file {cached_features_file} will allow dataset and examples to be cached in"
+ " future run"
+ )
+ else:
+ if mode == Split.dev:
+ self.examples = self.processor.get_dev_examples(args.data_dir)
+ else:
+ self.examples = self.processor.get_train_examples(args.data_dir)
+
+ self.features, self.dataset = squad_convert_examples_to_features(
+ examples=self.examples,
+ tokenizer=tokenizer,
+ max_seq_length=args.max_seq_length,
+ doc_stride=args.doc_stride,
+ max_query_length=args.max_query_length,
+ is_training=mode == Split.train,
+ threads=args.threads,
+ return_dataset=dataset_format,
+ )
+
+ start = time.time()
+ torch.save(
+ {"features": self.features, "dataset": self.dataset, "examples": self.examples},
+ cached_features_file,
+ )
+ # ^ This seems to take a lot of time so I want to investigate why and how we can improve.
+ logger.info(
+ f"Saving features into cached file {cached_features_file} [took {time.time() - start:.3f} s]"
+ )
+
+ def __len__(self):
+ return len(self.features)
+
+ def __getitem__(self, i) -> dict[str, torch.Tensor]:
+ # Convert to Tensors and build dataset
+ feature = self.features[i]
+
+ input_ids = torch.tensor(feature.input_ids, dtype=torch.long)
+ attention_mask = torch.tensor(feature.attention_mask, dtype=torch.long)
+ token_type_ids = torch.tensor(feature.token_type_ids, dtype=torch.long)
+ cls_index = torch.tensor(feature.cls_index, dtype=torch.long)
+ p_mask = torch.tensor(feature.p_mask, dtype=torch.float)
+ is_impossible = torch.tensor(feature.is_impossible, dtype=torch.float)
+
+ inputs = {
+ "input_ids": input_ids,
+ "attention_mask": attention_mask,
+ "token_type_ids": token_type_ids,
+ }
+
+ if self.args.model_type in ["xlm", "roberta", "distilbert", "camembert"]:
+ del inputs["token_type_ids"]
+
+ if self.args.model_type in ["xlnet", "xlm"]:
+ inputs.update({"cls_index": cls_index, "p_mask": p_mask})
+ if self.args.version_2_with_negative:
+ inputs.update({"is_impossible": is_impossible})
+ if self.is_language_sensitive:
+ inputs.update({"langs": (torch.ones(input_ids.shape, dtype=torch.int64) * self.args.lang_id)})
+
+ if self.mode == Split.train:
+ start_positions = torch.tensor(feature.start_position, dtype=torch.long)
+ end_positions = torch.tensor(feature.end_position, dtype=torch.long)
+ inputs.update({"start_positions": start_positions, "end_positions": end_positions})
+
+ return inputs
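A minimal usage sketch for the dataset class built above (assuming the surrounding file is `transformers/data/datasets/squad.py`, which defines `SquadDataset` alongside `SquadDataTrainingArguments`). The model name, data directory and argument values below are illustrative assumptions, not part of the diff:

```python
from torch.utils.data import DataLoader
from transformers import AutoTokenizer
from transformers.data.datasets.squad import SquadDataset, SquadDataTrainingArguments

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
args = SquadDataTrainingArguments(
    model_type="bert",
    data_dir="./squad_data",      # hypothetical folder with SQuAD-style train/dev JSON files
    max_seq_length=384,
    doc_stride=128,
    max_query_length=64,
    threads=4,
)

train_dataset = SquadDataset(args, tokenizer=tokenizer, mode="train")
loader = DataLoader(train_dataset, batch_size=8)
batch = next(iter(loader))
# Training items carry start_positions/end_positions; token_type_ids are dropped for
# model types such as roberta or distilbert, as implemented in __getitem__ above.
print(batch["input_ids"].shape, batch["start_positions"].shape)
```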
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/data/metrics/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/data/metrics/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..ebd0d17aa55bb4529820ce347f6275d38f6c0caa
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/data/metrics/__init__.py
@@ -0,0 +1,98 @@
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import warnings
+
+from ...utils import is_sklearn_available, requires_backends
+
+
+if is_sklearn_available():
+ from scipy.stats import pearsonr, spearmanr
+ from sklearn.metrics import f1_score, matthews_corrcoef
+
+
+DEPRECATION_WARNING = (
+ "This metric will be removed from the library soon, metrics should be handled with the 🤗 Evaluate "
+ "library. You can have a look at this example script for pointers: "
+ "https://github.com/huggingface/transformers/blob/main/examples/pytorch/text-classification/run_glue.py"
+)
+
+
+def simple_accuracy(preds, labels):
+ warnings.warn(DEPRECATION_WARNING, FutureWarning)
+ requires_backends(simple_accuracy, "sklearn")
+ return (preds == labels).mean()
+
+
+def acc_and_f1(preds, labels):
+ warnings.warn(DEPRECATION_WARNING, FutureWarning)
+ requires_backends(acc_and_f1, "sklearn")
+ acc = simple_accuracy(preds, labels)
+ f1 = f1_score(y_true=labels, y_pred=preds)
+ return {
+ "acc": acc,
+ "f1": f1,
+ "acc_and_f1": (acc + f1) / 2,
+ }
+
+
+def pearson_and_spearman(preds, labels):
+ warnings.warn(DEPRECATION_WARNING, FutureWarning)
+ requires_backends(pearson_and_spearman, "sklearn")
+ pearson_corr = pearsonr(preds, labels)[0]
+ spearman_corr = spearmanr(preds, labels)[0]
+ return {
+ "pearson": pearson_corr,
+ "spearmanr": spearman_corr,
+ "corr": (pearson_corr + spearman_corr) / 2,
+ }
+
+
+def glue_compute_metrics(task_name, preds, labels):
+ warnings.warn(DEPRECATION_WARNING, FutureWarning)
+ requires_backends(glue_compute_metrics, "sklearn")
+ if len(preds) != len(labels):
+ raise ValueError(f"Predictions and labels have mismatched lengths {len(preds)} and {len(labels)}")
+ if task_name == "cola":
+ return {"mcc": matthews_corrcoef(labels, preds)}
+ elif task_name == "sst-2":
+ return {"acc": simple_accuracy(preds, labels)}
+ elif task_name == "mrpc":
+ return acc_and_f1(preds, labels)
+ elif task_name == "sts-b":
+ return pearson_and_spearman(preds, labels)
+ elif task_name == "qqp":
+ return acc_and_f1(preds, labels)
+ elif task_name == "mnli":
+ return {"mnli/acc": simple_accuracy(preds, labels)}
+ elif task_name == "mnli-mm":
+ return {"mnli-mm/acc": simple_accuracy(preds, labels)}
+ elif task_name == "qnli":
+ return {"acc": simple_accuracy(preds, labels)}
+ elif task_name == "rte":
+ return {"acc": simple_accuracy(preds, labels)}
+ elif task_name == "wnli":
+ return {"acc": simple_accuracy(preds, labels)}
+ elif task_name == "hans":
+ return {"acc": simple_accuracy(preds, labels)}
+ else:
+ raise KeyError(task_name)
+
+
+def xnli_compute_metrics(task_name, preds, labels):
+ warnings.warn(DEPRECATION_WARNING, FutureWarning)
+ requires_backends(xnli_compute_metrics, "sklearn")
+ if len(preds) != len(labels):
+ raise ValueError(f"Predictions and labels have mismatched lengths {len(preds)} and {len(labels)}")
+ if task_name == "xnli":
+ return {"acc": simple_accuracy(preds, labels)}
+ else:
+ raise KeyError(task_name)
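A quick, hedged illustration of the helpers above (they require scikit-learn/scipy, as checked by `is_sklearn_available`, and each call emits the FutureWarning defined at the top of the file); the arrays are made up:

```python
import numpy as np
from transformers.data.metrics import glue_compute_metrics, xnli_compute_metrics

preds = np.array([1, 0, 1, 1])
labels = np.array([1, 0, 0, 1])

print(glue_compute_metrics("mrpc", preds, labels))  # {'acc': 0.75, 'f1': 0.8, 'acc_and_f1': 0.775}
print(glue_compute_metrics("cola", preds, labels))  # {'mcc': ...} via matthews_corrcoef
print(xnli_compute_metrics("xnli", preds, labels))  # {'acc': 0.75}
```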
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/data/metrics/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/data/metrics/__pycache__/__init__.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..14fc819ba778cae6040759d64cf1679147fb4207
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/data/metrics/__pycache__/__init__.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/data/metrics/__pycache__/squad_metrics.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/data/metrics/__pycache__/squad_metrics.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d34a488072a84aa4409ad04ccc036b9e7bd747d6
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/data/metrics/__pycache__/squad_metrics.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/data/metrics/squad_metrics.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/data/metrics/squad_metrics.py
new file mode 100644
index 0000000000000000000000000000000000000000..0ffc025b65a0451523004df12f5a4ae5e9d17b9a
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/data/metrics/squad_metrics.py
@@ -0,0 +1,779 @@
+# Copyright 2020 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Very heavily inspired by the official evaluation script for SQuAD version 2.0, which was modified by the XLNet
+authors to update the `find_best_threshold` scripts for SQuAD v2.0.
+
+In addition to the basic functionality, we also compute additional statistics and plot precision-recall curves if an
+additional na_prob.json file is provided. This file is expected to map question IDs to the model's predicted
+probability that a question is unanswerable.
+"""
+
+import collections
+import json
+import math
+import re
+import string
+
+from ...models.bert import BasicTokenizer
+from ...utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+
+def normalize_answer(s):
+ """Lower text and remove punctuation, articles and extra whitespace."""
+
+ def remove_articles(text):
+ regex = re.compile(r"\b(a|an|the)\b", re.UNICODE)
+ return re.sub(regex, " ", text)
+
+ def white_space_fix(text):
+ return " ".join(text.split())
+
+ def remove_punc(text):
+ exclude = set(string.punctuation)
+ return "".join(ch for ch in text if ch not in exclude)
+
+ def lower(text):
+ return text.lower()
+
+ return white_space_fix(remove_articles(remove_punc(lower(s))))
+
+
+def get_tokens(s):
+ if not s:
+ return []
+ return normalize_answer(s).split()
+
+
+def compute_exact(a_gold, a_pred):
+ return int(normalize_answer(a_gold) == normalize_answer(a_pred))
+
+
+def compute_f1(a_gold, a_pred):
+ gold_toks = get_tokens(a_gold)
+ pred_toks = get_tokens(a_pred)
+ common = collections.Counter(gold_toks) & collections.Counter(pred_toks)
+ num_same = sum(common.values())
+ if len(gold_toks) == 0 or len(pred_toks) == 0:
+ # If either is no-answer, then F1 is 1 if they agree, 0 otherwise
+ return int(gold_toks == pred_toks)
+ if num_same == 0:
+ return 0
+ precision = 1.0 * num_same / len(pred_toks)
+ recall = 1.0 * num_same / len(gold_toks)
+ f1 = (2 * precision * recall) / (precision + recall)
+ return f1
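A toy run of the two scoring helpers above; both answers are normalized (lower-cased, punctuation and articles stripped) before comparison:

```python
from transformers.data.metrics.squad_metrics import compute_exact, compute_f1

print(compute_exact("The Eiffel Tower", "eiffel tower"))  # 1: normalization removes case and the article
print(compute_f1("The Eiffel Tower", "the tower"))        # ~0.667: one shared token (precision 1.0, recall 0.5)
```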
+
+
+def get_raw_scores(examples, preds):
+ """
+ Computes the exact and f1 scores from the examples and the model predictions
+ """
+ exact_scores = {}
+ f1_scores = {}
+
+ for example in examples:
+ qas_id = example.qas_id
+ gold_answers = [answer["text"] for answer in example.answers if normalize_answer(answer["text"])]
+
+ if not gold_answers:
+ # For unanswerable questions, the only correct answer is the empty string
+ gold_answers = [""]
+
+ if qas_id not in preds:
+ print(f"Missing prediction for {qas_id}")
+ continue
+
+ prediction = preds[qas_id]
+ exact_scores[qas_id] = max(compute_exact(a, prediction) for a in gold_answers)
+ f1_scores[qas_id] = max(compute_f1(a, prediction) for a in gold_answers)
+
+ return exact_scores, f1_scores
+
+
+def apply_no_ans_threshold(scores, na_probs, qid_to_has_ans, na_prob_thresh):
+ new_scores = {}
+ for qid, s in scores.items():
+ pred_na = na_probs[qid] > na_prob_thresh
+ if pred_na:
+ new_scores[qid] = float(not qid_to_has_ans[qid])
+ else:
+ new_scores[qid] = s
+ return new_scores
+
+
+def make_eval_dict(exact_scores, f1_scores, qid_list=None):
+ if not qid_list:
+ total = len(exact_scores)
+ return collections.OrderedDict(
+ [
+ ("exact", 100.0 * sum(exact_scores.values()) / total),
+ ("f1", 100.0 * sum(f1_scores.values()) / total),
+ ("total", total),
+ ]
+ )
+ else:
+ total = len(qid_list)
+ return collections.OrderedDict(
+ [
+ ("exact", 100.0 * sum(exact_scores[k] for k in qid_list) / total),
+ ("f1", 100.0 * sum(f1_scores[k] for k in qid_list) / total),
+ ("total", total),
+ ]
+ )
+
+
+def merge_eval(main_eval, new_eval, prefix):
+ for k in new_eval:
+ main_eval[f"{prefix}_{k}"] = new_eval[k]
+
+
+def find_best_thresh_v2(preds, scores, na_probs, qid_to_has_ans):
+ num_no_ans = sum(1 for k in qid_to_has_ans if not qid_to_has_ans[k])
+ cur_score = num_no_ans
+ best_score = cur_score
+ best_thresh = 0.0
+ qid_list = sorted(na_probs, key=lambda k: na_probs[k])
+ for qid in qid_list:
+ if qid not in scores:
+ continue
+ if qid_to_has_ans[qid]:
+ diff = scores[qid]
+ else:
+ if preds[qid]:
+ diff = -1
+ else:
+ diff = 0
+ cur_score += diff
+ if cur_score > best_score:
+ best_score = cur_score
+ best_thresh = na_probs[qid]
+
+ has_ans_score, has_ans_cnt = 0, 0
+ for qid in qid_list:
+ if not qid_to_has_ans[qid]:
+ continue
+ has_ans_cnt += 1
+
+ if qid not in scores:
+ continue
+ has_ans_score += scores[qid]
+
+ return 100.0 * best_score / len(scores), best_thresh, 1.0 * has_ans_score / has_ans_cnt
+
+
+def find_all_best_thresh_v2(main_eval, preds, exact_raw, f1_raw, na_probs, qid_to_has_ans):
+ best_exact, exact_thresh, has_ans_exact = find_best_thresh_v2(preds, exact_raw, na_probs, qid_to_has_ans)
+ best_f1, f1_thresh, has_ans_f1 = find_best_thresh_v2(preds, f1_raw, na_probs, qid_to_has_ans)
+ main_eval["best_exact"] = best_exact
+ main_eval["best_exact_thresh"] = exact_thresh
+ main_eval["best_f1"] = best_f1
+ main_eval["best_f1_thresh"] = f1_thresh
+ main_eval["has_ans_exact"] = has_ans_exact
+ main_eval["has_ans_f1"] = has_ans_f1
+
+
+def find_best_thresh(preds, scores, na_probs, qid_to_has_ans):
+ num_no_ans = sum(1 for k in qid_to_has_ans if not qid_to_has_ans[k])
+ cur_score = num_no_ans
+ best_score = cur_score
+ best_thresh = 0.0
+ qid_list = sorted(na_probs, key=lambda k: na_probs[k])
+ for _, qid in enumerate(qid_list):
+ if qid not in scores:
+ continue
+ if qid_to_has_ans[qid]:
+ diff = scores[qid]
+ else:
+ if preds[qid]:
+ diff = -1
+ else:
+ diff = 0
+ cur_score += diff
+ if cur_score > best_score:
+ best_score = cur_score
+ best_thresh = na_probs[qid]
+ return 100.0 * best_score / len(scores), best_thresh
+
+
+def find_all_best_thresh(main_eval, preds, exact_raw, f1_raw, na_probs, qid_to_has_ans):
+ best_exact, exact_thresh = find_best_thresh(preds, exact_raw, na_probs, qid_to_has_ans)
+ best_f1, f1_thresh = find_best_thresh(preds, f1_raw, na_probs, qid_to_has_ans)
+
+ main_eval["best_exact"] = best_exact
+ main_eval["best_exact_thresh"] = exact_thresh
+ main_eval["best_f1"] = best_f1
+ main_eval["best_f1_thresh"] = f1_thresh
+
+
+def squad_evaluate(examples, preds, no_answer_probs=None, no_answer_probability_threshold=1.0):
+ qas_id_to_has_answer = {example.qas_id: bool(example.answers) for example in examples}
+ has_answer_qids = [qas_id for qas_id, has_answer in qas_id_to_has_answer.items() if has_answer]
+ no_answer_qids = [qas_id for qas_id, has_answer in qas_id_to_has_answer.items() if not has_answer]
+
+ if no_answer_probs is None:
+ no_answer_probs = dict.fromkeys(preds, 0.0)
+
+ exact, f1 = get_raw_scores(examples, preds)
+
+ exact_threshold = apply_no_ans_threshold(
+ exact, no_answer_probs, qas_id_to_has_answer, no_answer_probability_threshold
+ )
+ f1_threshold = apply_no_ans_threshold(f1, no_answer_probs, qas_id_to_has_answer, no_answer_probability_threshold)
+
+ evaluation = make_eval_dict(exact_threshold, f1_threshold)
+
+ if has_answer_qids:
+ has_ans_eval = make_eval_dict(exact_threshold, f1_threshold, qid_list=has_answer_qids)
+ merge_eval(evaluation, has_ans_eval, "HasAns")
+
+ if no_answer_qids:
+ no_ans_eval = make_eval_dict(exact_threshold, f1_threshold, qid_list=no_answer_qids)
+ merge_eval(evaluation, no_ans_eval, "NoAns")
+
+ if no_answer_probs:
+ find_all_best_thresh(evaluation, preds, exact, f1, no_answer_probs, qas_id_to_has_answer)
+
+ return evaluation
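A sketch of how `squad_evaluate` is typically driven; the data directory is a placeholder and the constant prediction text is obviously not meaningful, it only shows the expected shapes (a list of examples plus a `qas_id -> answer text` dict):

```python
from transformers.data.metrics.squad_metrics import squad_evaluate
from transformers.data.processors.squad import SquadV1Processor

examples = SquadV1Processor().get_dev_examples("./squad_data")      # hypothetical local dev set
preds = {ex.qas_id: "a predicted answer span" for ex in examples}   # normally produced by the model

results = squad_evaluate(examples, preds)
print(results["exact"], results["f1"], results["total"])            # percentages plus question count
```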
+
+
+def get_final_text(pred_text, orig_text, do_lower_case, verbose_logging=False):
+ """Project the tokenized prediction back to the original text."""
+
+ # When we created the data, we kept track of the alignment between original
+ # (whitespace tokenized) tokens and our WordPiece tokenized tokens. So
+ # now `orig_text` contains the span of our original text corresponding to the
+ # span that we predicted.
+ #
+ # However, `orig_text` may contain extra characters that we don't want in
+ # our prediction.
+ #
+ # For example, let's say:
+ # pred_text = steve smith
+ # orig_text = Steve Smith's
+ #
+ # We don't want to return `orig_text` because it contains the extra "'s".
+ #
+ # We don't want to return `pred_text` because it's already been normalized
+ # (the SQuAD eval script also does punctuation stripping/lower casing but
+ # our tokenizer does additional normalization like stripping accent
+ # characters).
+ #
+ # What we really want to return is "Steve Smith".
+ #
+ # Therefore, we have to apply a semi-complicated alignment heuristic between
+ # `pred_text` and `orig_text` to get a character-to-character alignment. This
+ # can fail in certain cases in which case we just return `orig_text`.
+
+ def _strip_spaces(text):
+ ns_chars = []
+ ns_to_s_map = collections.OrderedDict()
+ for i, c in enumerate(text):
+ if c == " ":
+ continue
+ ns_to_s_map[len(ns_chars)] = i
+ ns_chars.append(c)
+ ns_text = "".join(ns_chars)
+ return (ns_text, ns_to_s_map)
+
+ # We first tokenize `orig_text`, strip whitespace from the result
+ # and `pred_text`, and check if they are the same length. If they are
+ # NOT the same length, the heuristic has failed. If they are the same
+ # length, we assume the characters are one-to-one aligned.
+ tokenizer = BasicTokenizer(do_lower_case=do_lower_case)
+
+ tok_text = " ".join(tokenizer.tokenize(orig_text))
+
+ start_position = tok_text.find(pred_text)
+ if start_position == -1:
+ if verbose_logging:
+ logger.info(f"Unable to find text: '{pred_text}' in '{orig_text}'")
+ return orig_text
+ end_position = start_position + len(pred_text) - 1
+
+ (orig_ns_text, orig_ns_to_s_map) = _strip_spaces(orig_text)
+ (tok_ns_text, tok_ns_to_s_map) = _strip_spaces(tok_text)
+
+ if len(orig_ns_text) != len(tok_ns_text):
+ if verbose_logging:
+ logger.info(f"Length not equal after stripping spaces: '{orig_ns_text}' vs '{tok_ns_text}'")
+ return orig_text
+
+ # We then project the characters in `pred_text` back to `orig_text` using
+ # the character-to-character alignment.
+ tok_s_to_ns_map = {}
+ for i, tok_index in tok_ns_to_s_map.items():
+ tok_s_to_ns_map[tok_index] = i
+
+ orig_start_position = None
+ if start_position in tok_s_to_ns_map:
+ ns_start_position = tok_s_to_ns_map[start_position]
+ if ns_start_position in orig_ns_to_s_map:
+ orig_start_position = orig_ns_to_s_map[ns_start_position]
+
+ if orig_start_position is None:
+ if verbose_logging:
+ logger.info("Couldn't map start position")
+ return orig_text
+
+ orig_end_position = None
+ if end_position in tok_s_to_ns_map:
+ ns_end_position = tok_s_to_ns_map[end_position]
+ if ns_end_position in orig_ns_to_s_map:
+ orig_end_position = orig_ns_to_s_map[ns_end_position]
+
+ if orig_end_position is None:
+ if verbose_logging:
+ logger.info("Couldn't map end position")
+ return orig_text
+
+ output_text = orig_text[orig_start_position : (orig_end_position + 1)]
+ return output_text
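The "Steve Smith" case walked through in the comments above, expressed as a direct call (in real use `do_lower_case` comes from the tokenizer configuration):

```python
from transformers.data.metrics.squad_metrics import get_final_text

# The tokenized prediction lacks the trailing "'s"; the heuristic maps it back to the original casing.
print(get_final_text("steve smith", "Steve Smith's", do_lower_case=True))  # "Steve Smith"
```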
+
+
+def _get_best_indexes(logits, n_best_size):
+ """Get the n-best logits from a list."""
+ index_and_score = sorted(enumerate(logits), key=lambda x: x[1], reverse=True)
+
+ best_indexes = []
+ for i in range(len(index_and_score)):
+ if i >= n_best_size:
+ break
+ best_indexes.append(index_and_score[i][0])
+ return best_indexes
+
+
+def _compute_softmax(scores):
+ """Compute softmax probability over raw logits."""
+ if not scores:
+ return []
+
+ max_score = None
+ for score in scores:
+ if max_score is None or score > max_score:
+ max_score = score
+
+ exp_scores = []
+ total_sum = 0.0
+ for score in scores:
+ x = math.exp(score - max_score)
+ exp_scores.append(x)
+ total_sum += x
+
+ probs = []
+ for score in exp_scores:
+ probs.append(score / total_sum)
+ return probs
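The helper above subtracts the maximum score before exponentiating, so very large logits do not overflow; the result matches an ordinary softmax. A one-line check (this is a module-internal helper, imported here only for illustration):

```python
from transformers.data.metrics.squad_metrics import _compute_softmax

print(_compute_softmax([1000.0, 999.0]))  # ~[0.731, 0.269], i.e. softmax of [1, 0]; a naive exp(1000) would overflow
```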
+
+
+def compute_predictions_logits(
+ all_examples,
+ all_features,
+ all_results,
+ n_best_size,
+ max_answer_length,
+ do_lower_case,
+ output_prediction_file,
+ output_nbest_file,
+ output_null_log_odds_file,
+ verbose_logging,
+ version_2_with_negative,
+ null_score_diff_threshold,
+ tokenizer,
+):
+ """Write final predictions to the json file and log-odds of null if needed."""
+ if output_prediction_file:
+ logger.info(f"Writing predictions to: {output_prediction_file}")
+ if output_nbest_file:
+ logger.info(f"Writing nbest to: {output_nbest_file}")
+ if output_null_log_odds_file and version_2_with_negative:
+ logger.info(f"Writing null_log_odds to: {output_null_log_odds_file}")
+
+ example_index_to_features = collections.defaultdict(list)
+ for feature in all_features:
+ example_index_to_features[feature.example_index].append(feature)
+
+ unique_id_to_result = {}
+ for result in all_results:
+ unique_id_to_result[result.unique_id] = result
+
+ _PrelimPrediction = collections.namedtuple( # pylint: disable=invalid-name
+ "PrelimPrediction", ["feature_index", "start_index", "end_index", "start_logit", "end_logit"]
+ )
+
+ all_predictions = collections.OrderedDict()
+ all_nbest_json = collections.OrderedDict()
+ scores_diff_json = collections.OrderedDict()
+
+ for example_index, example in enumerate(all_examples):
+ features = example_index_to_features[example_index]
+
+ prelim_predictions = []
+ # keep track of the minimum score of null start+end of position 0
+ score_null = 1000000 # large and positive
+ min_null_feature_index = 0 # the paragraph slice with min null score
+ null_start_logit = 0 # the start logit at the slice with min null score
+ null_end_logit = 0 # the end logit at the slice with min null score
+ for feature_index, feature in enumerate(features):
+ result = unique_id_to_result[feature.unique_id]
+ start_indexes = _get_best_indexes(result.start_logits, n_best_size)
+ end_indexes = _get_best_indexes(result.end_logits, n_best_size)
+ # if we could have irrelevant answers, get the min score of irrelevant
+ if version_2_with_negative:
+ feature_null_score = result.start_logits[0] + result.end_logits[0]
+ if feature_null_score < score_null:
+ score_null = feature_null_score
+ min_null_feature_index = feature_index
+ null_start_logit = result.start_logits[0]
+ null_end_logit = result.end_logits[0]
+ for start_index in start_indexes:
+ for end_index in end_indexes:
+ # We could hypothetically create invalid predictions, e.g., predict
+ # that the start of the span is in the question. We throw out all
+ # invalid predictions.
+ if start_index >= len(feature.tokens):
+ continue
+ if end_index >= len(feature.tokens):
+ continue
+ if start_index not in feature.token_to_orig_map:
+ continue
+ if end_index not in feature.token_to_orig_map:
+ continue
+ if not feature.token_is_max_context.get(start_index, False):
+ continue
+ if end_index < start_index:
+ continue
+ length = end_index - start_index + 1
+ if length > max_answer_length:
+ continue
+ prelim_predictions.append(
+ _PrelimPrediction(
+ feature_index=feature_index,
+ start_index=start_index,
+ end_index=end_index,
+ start_logit=result.start_logits[start_index],
+ end_logit=result.end_logits[end_index],
+ )
+ )
+ if version_2_with_negative:
+ prelim_predictions.append(
+ _PrelimPrediction(
+ feature_index=min_null_feature_index,
+ start_index=0,
+ end_index=0,
+ start_logit=null_start_logit,
+ end_logit=null_end_logit,
+ )
+ )
+ prelim_predictions = sorted(prelim_predictions, key=lambda x: (x.start_logit + x.end_logit), reverse=True)
+
+ _NbestPrediction = collections.namedtuple( # pylint: disable=invalid-name
+ "NbestPrediction", ["text", "start_logit", "end_logit"]
+ )
+
+ seen_predictions = {}
+ nbest = []
+ for pred in prelim_predictions:
+ if len(nbest) >= n_best_size:
+ break
+ feature = features[pred.feature_index]
+ if pred.start_index > 0: # this is a non-null prediction
+ tok_tokens = feature.tokens[pred.start_index : (pred.end_index + 1)]
+ orig_doc_start = feature.token_to_orig_map[pred.start_index]
+ orig_doc_end = feature.token_to_orig_map[pred.end_index]
+ orig_tokens = example.doc_tokens[orig_doc_start : (orig_doc_end + 1)]
+
+ tok_text = tokenizer.convert_tokens_to_string(tok_tokens)
+
+ # tok_text = " ".join(tok_tokens)
+ #
+ # # De-tokenize WordPieces that have been split off.
+ # tok_text = tok_text.replace(" ##", "")
+ # tok_text = tok_text.replace("##", "")
+
+ # Clean whitespace
+ tok_text = tok_text.strip()
+ tok_text = " ".join(tok_text.split())
+ orig_text = " ".join(orig_tokens)
+
+ final_text = get_final_text(tok_text, orig_text, do_lower_case, verbose_logging)
+ if final_text in seen_predictions:
+ continue
+
+ seen_predictions[final_text] = True
+ else:
+ final_text = ""
+ seen_predictions[final_text] = True
+
+ nbest.append(_NbestPrediction(text=final_text, start_logit=pred.start_logit, end_logit=pred.end_logit))
+ # if we didn't include the empty option in the n-best, include it
+ if version_2_with_negative:
+ if "" not in seen_predictions:
+ nbest.append(_NbestPrediction(text="", start_logit=null_start_logit, end_logit=null_end_logit))
+
+ # In very rare edge cases we could have only a single null prediction.
+ # So we just create a nonce prediction in this case to avoid failure.
+ if len(nbest) == 1:
+ nbest.insert(0, _NbestPrediction(text="empty", start_logit=0.0, end_logit=0.0))
+
+ # In very rare edge cases we could have no valid predictions. So we
+ # just create a nonce prediction in this case to avoid failure.
+ if not nbest:
+ nbest.append(_NbestPrediction(text="empty", start_logit=0.0, end_logit=0.0))
+
+ if len(nbest) < 1:
+ raise ValueError("No valid predictions")
+
+ total_scores = []
+ best_non_null_entry = None
+ for entry in nbest:
+ total_scores.append(entry.start_logit + entry.end_logit)
+ if not best_non_null_entry:
+ if entry.text:
+ best_non_null_entry = entry
+
+ probs = _compute_softmax(total_scores)
+
+ nbest_json = []
+ for i, entry in enumerate(nbest):
+ output = collections.OrderedDict()
+ output["text"] = entry.text
+ output["probability"] = probs[i]
+ output["start_logit"] = entry.start_logit
+ output["end_logit"] = entry.end_logit
+ nbest_json.append(output)
+
+ if len(nbest_json) < 1:
+ raise ValueError("No valid predictions")
+
+ if not version_2_with_negative:
+ all_predictions[example.qas_id] = nbest_json[0]["text"]
+ else:
+ # predict "" iff the null score - the score of best non-null > threshold
+ score_diff = score_null - best_non_null_entry.start_logit - (best_non_null_entry.end_logit)
+ scores_diff_json[example.qas_id] = score_diff
+ if score_diff > null_score_diff_threshold:
+ all_predictions[example.qas_id] = ""
+ else:
+ all_predictions[example.qas_id] = best_non_null_entry.text
+ all_nbest_json[example.qas_id] = nbest_json
+
+ if output_prediction_file:
+ with open(output_prediction_file, "w") as writer:
+ writer.write(json.dumps(all_predictions, indent=4) + "\n")
+
+ if output_nbest_file:
+ with open(output_nbest_file, "w") as writer:
+ writer.write(json.dumps(all_nbest_json, indent=4) + "\n")
+
+ if output_null_log_odds_file and version_2_with_negative:
+ with open(output_null_log_odds_file, "w") as writer:
+ writer.write(json.dumps(scores_diff_json, indent=4) + "\n")
+
+ return all_predictions
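For SQuAD v2 the function above emits the empty string whenever the null score beats the best non-null span by more than `null_score_diff_threshold`. That decision in isolation, with made-up numbers:

```python
score_null = 2.0                     # start_logits[0] + end_logits[0] at the slice with the minimum null score
best_span_score = 1.0 + 0.5          # start_logit + end_logit of the best non-null prediction
null_score_diff_threshold = 0.0

score_diff = score_null - best_span_score
prediction = "" if score_diff > null_score_diff_threshold else "best span text"
print(score_diff, repr(prediction))  # 0.5 '' -> the question is treated as unanswerable
```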
+
+
+def compute_predictions_log_probs(
+ all_examples,
+ all_features,
+ all_results,
+ n_best_size,
+ max_answer_length,
+ output_prediction_file,
+ output_nbest_file,
+ output_null_log_odds_file,
+ start_n_top,
+ end_n_top,
+ version_2_with_negative,
+ tokenizer,
+ verbose_logging,
+):
+ """
+ XLNet write prediction logic (more complex than Bert's). Write final predictions to the json file and log-odds of
+ null if needed.
+
+ Requires utils_squad_evaluate.py
+ """
+ _PrelimPrediction = collections.namedtuple( # pylint: disable=invalid-name
+ "PrelimPrediction", ["feature_index", "start_index", "end_index", "start_log_prob", "end_log_prob"]
+ )
+
+ _NbestPrediction = collections.namedtuple( # pylint: disable=invalid-name
+ "NbestPrediction", ["text", "start_log_prob", "end_log_prob"]
+ )
+
+ logger.info(f"Writing predictions to: {output_prediction_file}")
+
+ example_index_to_features = collections.defaultdict(list)
+ for feature in all_features:
+ example_index_to_features[feature.example_index].append(feature)
+
+ unique_id_to_result = {}
+ for result in all_results:
+ unique_id_to_result[result.unique_id] = result
+
+ all_predictions = collections.OrderedDict()
+ all_nbest_json = collections.OrderedDict()
+ scores_diff_json = collections.OrderedDict()
+
+ for example_index, example in enumerate(all_examples):
+ features = example_index_to_features[example_index]
+
+ prelim_predictions = []
+ # keep track of the minimum score of null start+end of position 0
+ score_null = 1000000 # large and positive
+
+ for feature_index, feature in enumerate(features):
+ result = unique_id_to_result[feature.unique_id]
+
+ cur_null_score = result.cls_logits
+
+ # if we could have irrelevant answers, get the min score of irrelevant
+ score_null = min(score_null, cur_null_score)
+
+ for i in range(start_n_top):
+ for j in range(end_n_top):
+ start_log_prob = result.start_logits[i]
+ start_index = result.start_top_index[i]
+
+ j_index = i * end_n_top + j
+
+ end_log_prob = result.end_logits[j_index]
+ end_index = result.end_top_index[j_index]
+
+ # We could hypothetically create invalid predictions, e.g., predict
+ # that the start of the span is in the question. We throw out all
+ # invalid predictions.
+ if start_index >= feature.paragraph_len - 1:
+ continue
+ if end_index >= feature.paragraph_len - 1:
+ continue
+
+ if not feature.token_is_max_context.get(start_index, False):
+ continue
+ if end_index < start_index:
+ continue
+ length = end_index - start_index + 1
+ if length > max_answer_length:
+ continue
+
+ prelim_predictions.append(
+ _PrelimPrediction(
+ feature_index=feature_index,
+ start_index=start_index,
+ end_index=end_index,
+ start_log_prob=start_log_prob,
+ end_log_prob=end_log_prob,
+ )
+ )
+
+ prelim_predictions = sorted(
+ prelim_predictions, key=lambda x: (x.start_log_prob + x.end_log_prob), reverse=True
+ )
+
+ seen_predictions = {}
+ nbest = []
+ for pred in prelim_predictions:
+ if len(nbest) >= n_best_size:
+ break
+ feature = features[pred.feature_index]
+
+ # XLNet un-tokenizer
+ # Let's keep it simple for now and see if we need all this later.
+ #
+ # tok_start_to_orig_index = feature.tok_start_to_orig_index
+ # tok_end_to_orig_index = feature.tok_end_to_orig_index
+ # start_orig_pos = tok_start_to_orig_index[pred.start_index]
+ # end_orig_pos = tok_end_to_orig_index[pred.end_index]
+ # paragraph_text = example.paragraph_text
+ # final_text = paragraph_text[start_orig_pos: end_orig_pos + 1].strip()
+
+ # Previously used Bert untokenizer
+ tok_tokens = feature.tokens[pred.start_index : (pred.end_index + 1)]
+ orig_doc_start = feature.token_to_orig_map[pred.start_index]
+ orig_doc_end = feature.token_to_orig_map[pred.end_index]
+ orig_tokens = example.doc_tokens[orig_doc_start : (orig_doc_end + 1)]
+ tok_text = tokenizer.convert_tokens_to_string(tok_tokens)
+
+ # Clean whitespace
+ tok_text = tok_text.strip()
+ tok_text = " ".join(tok_text.split())
+ orig_text = " ".join(orig_tokens)
+
+ if hasattr(tokenizer, "do_lower_case"):
+ do_lower_case = tokenizer.do_lower_case
+ else:
+ do_lower_case = tokenizer.do_lowercase_and_remove_accent
+
+ final_text = get_final_text(tok_text, orig_text, do_lower_case, verbose_logging)
+
+ if final_text in seen_predictions:
+ continue
+
+ seen_predictions[final_text] = True
+
+ nbest.append(
+ _NbestPrediction(text=final_text, start_log_prob=pred.start_log_prob, end_log_prob=pred.end_log_prob)
+ )
+
+ # In very rare edge cases we could have no valid predictions. So we
+ # just create a nonce prediction in this case to avoid failure.
+ if not nbest:
+ nbest.append(_NbestPrediction(text="", start_log_prob=-1e6, end_log_prob=-1e6))
+
+ total_scores = []
+ best_non_null_entry = None
+ for entry in nbest:
+ total_scores.append(entry.start_log_prob + entry.end_log_prob)
+ if not best_non_null_entry:
+ best_non_null_entry = entry
+
+ probs = _compute_softmax(total_scores)
+
+ nbest_json = []
+ for i, entry in enumerate(nbest):
+ output = collections.OrderedDict()
+ output["text"] = entry.text
+ output["probability"] = probs[i]
+ output["start_log_prob"] = entry.start_log_prob
+ output["end_log_prob"] = entry.end_log_prob
+ nbest_json.append(output)
+
+ if len(nbest_json) < 1:
+ raise ValueError("No valid predictions")
+ if best_non_null_entry is None:
+ raise ValueError("No valid predictions")
+
+ score_diff = score_null
+ scores_diff_json[example.qas_id] = score_diff
+ # note(zhiliny): always predict best_non_null_entry
+ # and the evaluation script will search for the best threshold
+ all_predictions[example.qas_id] = best_non_null_entry.text
+
+ all_nbest_json[example.qas_id] = nbest_json
+
+ with open(output_prediction_file, "w") as writer:
+ writer.write(json.dumps(all_predictions, indent=4) + "\n")
+
+ with open(output_nbest_file, "w") as writer:
+ writer.write(json.dumps(all_nbest_json, indent=4) + "\n")
+
+ if version_2_with_negative:
+ with open(output_null_log_odds_file, "w") as writer:
+ writer.write(json.dumps(scores_diff_json, indent=4) + "\n")
+
+ return all_predictions
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/data/processors/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/data/processors/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..a26ab5776d74715428b10c4d9cd943e53b253785
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/data/processors/__init__.py
@@ -0,0 +1,18 @@
+# Copyright 2020 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .glue import glue_convert_examples_to_features, glue_output_modes, glue_processors, glue_tasks_num_labels
+from .squad import SquadExample, SquadFeatures, SquadV1Processor, SquadV2Processor, squad_convert_examples_to_features
+from .utils import DataProcessor, InputExample, InputFeatures, SingleSentenceClassificationProcessor
+from .xnli import xnli_output_modes, xnli_processors, xnli_tasks_num_labels
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/data/processors/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/data/processors/__pycache__/__init__.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..29ed5b7ef35d4634a8cde77dce8c9c2ed18c7063
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/data/processors/__pycache__/__init__.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/data/processors/__pycache__/glue.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/data/processors/__pycache__/glue.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..cc4dee11d5b40257bfae832d1e6a1bbcbf4c82c1
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/data/processors/__pycache__/glue.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/data/processors/__pycache__/squad.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/data/processors/__pycache__/squad.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..4ef822fe054c94d9ce0e3771ea118eee9e31b223
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/data/processors/__pycache__/squad.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/data/processors/__pycache__/utils.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/data/processors/__pycache__/utils.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b93ed4a62e8efae9a7a0b24494e0b05f2395679a
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/data/processors/__pycache__/utils.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/data/processors/__pycache__/xnli.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/data/processors/__pycache__/xnli.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e7c99cfef8a4c74b7e066ae7a42fd859e2b87321
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/data/processors/__pycache__/xnli.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/data/processors/glue.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/data/processors/glue.py
new file mode 100644
index 0000000000000000000000000000000000000000..e005c9bcda13d15bc3aa32a50c79941166d0ba28
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/data/processors/glue.py
@@ -0,0 +1,643 @@
+# coding=utf-8
+# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
+# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""GLUE processors and helpers"""
+
+import os
+import warnings
+from dataclasses import asdict
+from enum import Enum
+from typing import Optional, Union
+
+from ...tokenization_utils import PreTrainedTokenizer
+from ...utils import is_tf_available, logging
+from .utils import DataProcessor, InputExample, InputFeatures
+
+
+if is_tf_available():
+ import tensorflow as tf
+
+logger = logging.get_logger(__name__)
+
+DEPRECATION_WARNING = (
+ "This {0} will be removed from the library soon, preprocessing should be handled with the 🤗 Datasets "
+ "library. You can have a look at this example script for pointers: "
+ "https://github.com/huggingface/transformers/blob/main/examples/pytorch/text-classification/run_glue.py"
+)
+
+
+def glue_convert_examples_to_features(
+ examples: Union[list[InputExample], "tf.data.Dataset"],
+ tokenizer: PreTrainedTokenizer,
+ max_length: Optional[int] = None,
+ task=None,
+ label_list=None,
+ output_mode=None,
+):
+ """
+ Loads a data file into a list of `InputFeatures`
+
+ Args:
+ examples: List of `InputExamples` or `tf.data.Dataset` containing the examples.
+ tokenizer: Instance of a tokenizer that will tokenize the examples
+ max_length: Maximum example length. Defaults to the tokenizer's `model_max_length`
+ task: GLUE task
+ label_list: List of labels. Can be obtained from the processor using the `processor.get_labels()` method
+ output_mode: String indicating the output mode. Either `regression` or `classification`
+
+ Returns:
+ If the `examples` input is a `tf.data.Dataset`, will return a `tf.data.Dataset` containing the task-specific
+ features. If the input is a list of `InputExamples`, will return a list of task-specific `InputFeatures` which
+ can be fed to the model.
+
+ """
+ warnings.warn(DEPRECATION_WARNING.format("function"), FutureWarning)
+ if is_tf_available() and isinstance(examples, tf.data.Dataset):
+ if task is None:
+ raise ValueError("When calling glue_convert_examples_to_features from TF, the task parameter is required.")
+ return _tf_glue_convert_examples_to_features(examples, tokenizer, max_length=max_length, task=task)
+ return _glue_convert_examples_to_features(
+ examples, tokenizer, max_length=max_length, task=task, label_list=label_list, output_mode=output_mode
+ )
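A hedged usage sketch for the converter above, using the MRPC processor defined later in this file; the GLUE data directory and model name are placeholders:

```python
from transformers import AutoTokenizer
from transformers.data.processors.glue import MrpcProcessor, glue_convert_examples_to_features

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
examples = MrpcProcessor().get_dev_examples("./glue_data/MRPC")   # hypothetical local path

features = glue_convert_examples_to_features(examples, tokenizer, max_length=128, task="mrpc")
print(len(features), features[0].label, features[0].input_ids[:8])
```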
+
+
+if is_tf_available():
+
+ def _tf_glue_convert_examples_to_features(
+ examples: tf.data.Dataset,
+ tokenizer: PreTrainedTokenizer,
+ task: str,
+ max_length: Optional[int] = None,
+ ) -> tf.data.Dataset:
+ """
+ Returns:
+ A `tf.data.Dataset` containing the task-specific features.
+
+ """
+ processor = glue_processors[task]()
+ examples = [processor.tfds_map(processor.get_example_from_tensor_dict(example)) for example in examples]
+ features = glue_convert_examples_to_features(examples, tokenizer, max_length=max_length, task=task)
+ label_type = tf.float32 if task == "sts-b" else tf.int64
+
+ def gen():
+ for ex in features:
+ d = {k: v for k, v in asdict(ex).items() if v is not None}
+ label = d.pop("label")
+ yield (d, label)
+
+ input_names = tokenizer.model_input_names
+
+ return tf.data.Dataset.from_generator(
+ gen,
+ (dict.fromkeys(input_names, tf.int32), label_type),
+ ({k: tf.TensorShape([None]) for k in input_names}, tf.TensorShape([])),
+ )
+
+
+def _glue_convert_examples_to_features(
+ examples: list[InputExample],
+ tokenizer: PreTrainedTokenizer,
+ max_length: Optional[int] = None,
+ task=None,
+ label_list=None,
+ output_mode=None,
+):
+ if max_length is None:
+ max_length = tokenizer.model_max_length
+
+ if task is not None:
+ processor = glue_processors[task]()
+ if label_list is None:
+ label_list = processor.get_labels()
+ logger.info(f"Using label list {label_list} for task {task}")
+ if output_mode is None:
+ output_mode = glue_output_modes[task]
+ logger.info(f"Using output mode {output_mode} for task {task}")
+
+ label_map = {label: i for i, label in enumerate(label_list)}
+
+ def label_from_example(example: InputExample) -> Union[int, float, None]:
+ if example.label is None:
+ return None
+ if output_mode == "classification":
+ return label_map[example.label]
+ elif output_mode == "regression":
+ return float(example.label)
+ raise KeyError(output_mode)
+
+ labels = [label_from_example(example) for example in examples]
+
+ batch_encoding = tokenizer(
+ [(example.text_a, example.text_b) for example in examples],
+ max_length=max_length,
+ padding="max_length",
+ truncation=True,
+ )
+
+ features = []
+ for i in range(len(examples)):
+ inputs = {k: batch_encoding[k][i] for k in batch_encoding}
+
+ feature = InputFeatures(**inputs, label=labels[i])
+ features.append(feature)
+
+ for i, example in enumerate(examples[:5]):
+ logger.info("*** Example ***")
+ logger.info(f"guid: {example.guid}")
+ logger.info(f"features: {features[i]}")
+
+ return features
+
+
+class OutputMode(Enum):
+ classification = "classification"
+ regression = "regression"
+
+
+class MrpcProcessor(DataProcessor):
+ """Processor for the MRPC data set (GLUE version)."""
+
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ warnings.warn(DEPRECATION_WARNING.format("processor"), FutureWarning)
+
+ def get_example_from_tensor_dict(self, tensor_dict):
+ """See base class."""
+ return InputExample(
+ tensor_dict["idx"].numpy(),
+ tensor_dict["sentence1"].numpy().decode("utf-8"),
+ tensor_dict["sentence2"].numpy().decode("utf-8"),
+ str(tensor_dict["label"].numpy()),
+ )
+
+ def get_train_examples(self, data_dir):
+ """See base class."""
+ logger.info(f"LOOKING AT {os.path.join(data_dir, 'train.tsv')}")
+ return self._create_examples(self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
+
+ def get_dev_examples(self, data_dir):
+ """See base class."""
+ return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")
+
+ def get_test_examples(self, data_dir):
+ """See base class."""
+ return self._create_examples(self._read_tsv(os.path.join(data_dir, "test.tsv")), "test")
+
+ def get_labels(self):
+ """See base class."""
+ return ["0", "1"]
+
+ def _create_examples(self, lines, set_type):
+ """Creates examples for the training, dev and test sets."""
+ examples = []
+ for i, line in enumerate(lines):
+ if i == 0:
+ continue
+ guid = f"{set_type}-{i}"
+ text_a = line[3]
+ text_b = line[4]
+ label = None if set_type == "test" else line[0]
+ examples.append(InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
+ return examples
+
+
+class MnliProcessor(DataProcessor):
+ """Processor for the MultiNLI data set (GLUE version)."""
+
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ warnings.warn(DEPRECATION_WARNING.format("processor"), FutureWarning)
+
+ def get_example_from_tensor_dict(self, tensor_dict):
+ """See base class."""
+ return InputExample(
+ tensor_dict["idx"].numpy(),
+ tensor_dict["premise"].numpy().decode("utf-8"),
+ tensor_dict["hypothesis"].numpy().decode("utf-8"),
+ str(tensor_dict["label"].numpy()),
+ )
+
+ def get_train_examples(self, data_dir):
+ """See base class."""
+ return self._create_examples(self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
+
+ def get_dev_examples(self, data_dir):
+ """See base class."""
+ return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev_matched.tsv")), "dev_matched")
+
+ def get_test_examples(self, data_dir):
+ """See base class."""
+ return self._create_examples(self._read_tsv(os.path.join(data_dir, "test_matched.tsv")), "test_matched")
+
+ def get_labels(self):
+ """See base class."""
+ return ["contradiction", "entailment", "neutral"]
+
+ def _create_examples(self, lines, set_type):
+ """Creates examples for the training, dev and test sets."""
+ examples = []
+ for i, line in enumerate(lines):
+ if i == 0:
+ continue
+ guid = f"{set_type}-{line[0]}"
+ text_a = line[8]
+ text_b = line[9]
+ label = None if set_type.startswith("test") else line[-1]
+ examples.append(InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
+ return examples
+
+
+class MnliMismatchedProcessor(MnliProcessor):
+ """Processor for the MultiNLI Mismatched data set (GLUE version)."""
+
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ warnings.warn(DEPRECATION_WARNING.format("processor"), FutureWarning)
+
+ def get_dev_examples(self, data_dir):
+ """See base class."""
+ return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev_mismatched.tsv")), "dev_mismatched")
+
+ def get_test_examples(self, data_dir):
+ """See base class."""
+ return self._create_examples(self._read_tsv(os.path.join(data_dir, "test_mismatched.tsv")), "test_mismatched")
+
+
+class ColaProcessor(DataProcessor):
+ """Processor for the CoLA data set (GLUE version)."""
+
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ warnings.warn(DEPRECATION_WARNING.format("processor"), FutureWarning)
+
+ def get_example_from_tensor_dict(self, tensor_dict):
+ """See base class."""
+ return InputExample(
+ tensor_dict["idx"].numpy(),
+ tensor_dict["sentence"].numpy().decode("utf-8"),
+ None,
+ str(tensor_dict["label"].numpy()),
+ )
+
+ def get_train_examples(self, data_dir):
+ """See base class."""
+ return self._create_examples(self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
+
+ def get_dev_examples(self, data_dir):
+ """See base class."""
+ return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")
+
+ def get_test_examples(self, data_dir):
+ """See base class."""
+ return self._create_examples(self._read_tsv(os.path.join(data_dir, "test.tsv")), "test")
+
+ def get_labels(self):
+ """See base class."""
+ return ["0", "1"]
+
+ def _create_examples(self, lines, set_type):
+ """Creates examples for the training, dev and test sets."""
+ test_mode = set_type == "test"
+ if test_mode:
+ lines = lines[1:]
+ text_index = 1 if test_mode else 3
+ examples = []
+ for i, line in enumerate(lines):
+ guid = f"{set_type}-{i}"
+ text_a = line[text_index]
+ label = None if test_mode else line[1]
+ examples.append(InputExample(guid=guid, text_a=text_a, text_b=None, label=label))
+ return examples
+
+
+class Sst2Processor(DataProcessor):
+ """Processor for the SST-2 data set (GLUE version)."""
+
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ warnings.warn(DEPRECATION_WARNING.format("processor"), FutureWarning)
+
+ def get_example_from_tensor_dict(self, tensor_dict):
+ """See base class."""
+ return InputExample(
+ tensor_dict["idx"].numpy(),
+ tensor_dict["sentence"].numpy().decode("utf-8"),
+ None,
+ str(tensor_dict["label"].numpy()),
+ )
+
+ def get_train_examples(self, data_dir):
+ """See base class."""
+ return self._create_examples(self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
+
+ def get_dev_examples(self, data_dir):
+ """See base class."""
+ return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")
+
+ def get_test_examples(self, data_dir):
+ """See base class."""
+ return self._create_examples(self._read_tsv(os.path.join(data_dir, "test.tsv")), "test")
+
+ def get_labels(self):
+ """See base class."""
+ return ["0", "1"]
+
+ def _create_examples(self, lines, set_type):
+ """Creates examples for the training, dev and test sets."""
+ examples = []
+ text_index = 1 if set_type == "test" else 0
+ for i, line in enumerate(lines):
+ if i == 0:
+ continue
+ guid = f"{set_type}-{i}"
+ text_a = line[text_index]
+ label = None if set_type == "test" else line[1]
+ examples.append(InputExample(guid=guid, text_a=text_a, text_b=None, label=label))
+ return examples
+
+
+class StsbProcessor(DataProcessor):
+ """Processor for the STS-B data set (GLUE version)."""
+
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ warnings.warn(DEPRECATION_WARNING.format("processor"), FutureWarning)
+
+ def get_example_from_tensor_dict(self, tensor_dict):
+ """See base class."""
+ return InputExample(
+ tensor_dict["idx"].numpy(),
+ tensor_dict["sentence1"].numpy().decode("utf-8"),
+ tensor_dict["sentence2"].numpy().decode("utf-8"),
+ str(tensor_dict["label"].numpy()),
+ )
+
+ def get_train_examples(self, data_dir):
+ """See base class."""
+ return self._create_examples(self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
+
+ def get_dev_examples(self, data_dir):
+ """See base class."""
+ return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")
+
+ def get_test_examples(self, data_dir):
+ """See base class."""
+ return self._create_examples(self._read_tsv(os.path.join(data_dir, "test.tsv")), "test")
+
+ def get_labels(self):
+ """See base class."""
+ return [None]
+
+ def _create_examples(self, lines, set_type):
+ """Creates examples for the training, dev and test sets."""
+ examples = []
+ for i, line in enumerate(lines):
+ if i == 0:
+ continue
+ guid = f"{set_type}-{line[0]}"
+ text_a = line[7]
+ text_b = line[8]
+ label = None if set_type == "test" else line[-1]
+ examples.append(InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
+ return examples
+
+
+class QqpProcessor(DataProcessor):
+ """Processor for the QQP data set (GLUE version)."""
+
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ warnings.warn(DEPRECATION_WARNING.format("processor"), FutureWarning)
+
+ def get_example_from_tensor_dict(self, tensor_dict):
+ """See base class."""
+ return InputExample(
+ tensor_dict["idx"].numpy(),
+ tensor_dict["question1"].numpy().decode("utf-8"),
+ tensor_dict["question2"].numpy().decode("utf-8"),
+ str(tensor_dict["label"].numpy()),
+ )
+
+ def get_train_examples(self, data_dir):
+ """See base class."""
+ return self._create_examples(self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
+
+ def get_dev_examples(self, data_dir):
+ """See base class."""
+ return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")
+
+ def get_test_examples(self, data_dir):
+ """See base class."""
+ return self._create_examples(self._read_tsv(os.path.join(data_dir, "test.tsv")), "test")
+
+ def get_labels(self):
+ """See base class."""
+ return ["0", "1"]
+
+ def _create_examples(self, lines, set_type):
+ """Creates examples for the training, dev and test sets."""
+ test_mode = set_type == "test"
+ q1_index = 1 if test_mode else 3
+ q2_index = 2 if test_mode else 4
+ examples = []
+ for i, line in enumerate(lines):
+ if i == 0:
+ continue
+ guid = f"{set_type}-{line[0]}"
+ try:
+ text_a = line[q1_index]
+ text_b = line[q2_index]
+ label = None if test_mode else line[5]
+ except IndexError:
+ continue
+ examples.append(InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
+ return examples
+
+
+class QnliProcessor(DataProcessor):
+ """Processor for the QNLI data set (GLUE version)."""
+
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ warnings.warn(DEPRECATION_WARNING.format("processor"), FutureWarning)
+
+ def get_example_from_tensor_dict(self, tensor_dict):
+ """See base class."""
+ return InputExample(
+ tensor_dict["idx"].numpy(),
+ tensor_dict["question"].numpy().decode("utf-8"),
+ tensor_dict["sentence"].numpy().decode("utf-8"),
+ str(tensor_dict["label"].numpy()),
+ )
+
+ def get_train_examples(self, data_dir):
+ """See base class."""
+ return self._create_examples(self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
+
+ def get_dev_examples(self, data_dir):
+ """See base class."""
+ return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")
+
+ def get_test_examples(self, data_dir):
+ """See base class."""
+ return self._create_examples(self._read_tsv(os.path.join(data_dir, "test.tsv")), "test")
+
+ def get_labels(self):
+ """See base class."""
+ return ["entailment", "not_entailment"]
+
+ def _create_examples(self, lines, set_type):
+ """Creates examples for the training, dev and test sets."""
+ examples = []
+ for i, line in enumerate(lines):
+ if i == 0:
+ continue
+ guid = f"{set_type}-{line[0]}"
+ text_a = line[1]
+ text_b = line[2]
+ label = None if set_type == "test" else line[-1]
+ examples.append(InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
+ return examples
+
+
+class RteProcessor(DataProcessor):
+ """Processor for the RTE data set (GLUE version)."""
+
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ warnings.warn(DEPRECATION_WARNING.format("processor"), FutureWarning)
+
+ def get_example_from_tensor_dict(self, tensor_dict):
+ """See base class."""
+ return InputExample(
+ tensor_dict["idx"].numpy(),
+ tensor_dict["sentence1"].numpy().decode("utf-8"),
+ tensor_dict["sentence2"].numpy().decode("utf-8"),
+ str(tensor_dict["label"].numpy()),
+ )
+
+ def get_train_examples(self, data_dir):
+ """See base class."""
+ return self._create_examples(self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
+
+ def get_dev_examples(self, data_dir):
+ """See base class."""
+ return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")
+
+ def get_test_examples(self, data_dir):
+ """See base class."""
+ return self._create_examples(self._read_tsv(os.path.join(data_dir, "test.tsv")), "test")
+
+ def get_labels(self):
+ """See base class."""
+ return ["entailment", "not_entailment"]
+
+ def _create_examples(self, lines, set_type):
+ """Creates examples for the training, dev and test sets."""
+ examples = []
+ for i, line in enumerate(lines):
+ if i == 0:
+ continue
+ guid = f"{set_type}-{line[0]}"
+ text_a = line[1]
+ text_b = line[2]
+ label = None if set_type == "test" else line[-1]
+ examples.append(InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
+ return examples
+
+
+class WnliProcessor(DataProcessor):
+ """Processor for the WNLI data set (GLUE version)."""
+
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ warnings.warn(DEPRECATION_WARNING.format("processor"), FutureWarning)
+
+ def get_example_from_tensor_dict(self, tensor_dict):
+ """See base class."""
+ return InputExample(
+ tensor_dict["idx"].numpy(),
+ tensor_dict["sentence1"].numpy().decode("utf-8"),
+ tensor_dict["sentence2"].numpy().decode("utf-8"),
+ str(tensor_dict["label"].numpy()),
+ )
+
+ def get_train_examples(self, data_dir):
+ """See base class."""
+ return self._create_examples(self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
+
+ def get_dev_examples(self, data_dir):
+ """See base class."""
+ return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")
+
+ def get_test_examples(self, data_dir):
+ """See base class."""
+ return self._create_examples(self._read_tsv(os.path.join(data_dir, "test.tsv")), "test")
+
+ def get_labels(self):
+ """See base class."""
+ return ["0", "1"]
+
+ def _create_examples(self, lines, set_type):
+ """Creates examples for the training, dev and test sets."""
+ examples = []
+ for i, line in enumerate(lines):
+ if i == 0:
+ continue
+ guid = f"{set_type}-{line[0]}"
+ text_a = line[1]
+ text_b = line[2]
+ label = None if set_type == "test" else line[-1]
+ examples.append(InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
+ return examples
+
+
+glue_tasks_num_labels = {
+ "cola": 2,
+ "mnli": 3,
+ "mrpc": 2,
+ "sst-2": 2,
+ "sts-b": 1,
+ "qqp": 2,
+ "qnli": 2,
+ "rte": 2,
+ "wnli": 2,
+}
+
+glue_processors = {
+ "cola": ColaProcessor,
+ "mnli": MnliProcessor,
+ "mnli-mm": MnliMismatchedProcessor,
+ "mrpc": MrpcProcessor,
+ "sst-2": Sst2Processor,
+ "sts-b": StsbProcessor,
+ "qqp": QqpProcessor,
+ "qnli": QnliProcessor,
+ "rte": RteProcessor,
+ "wnli": WnliProcessor,
+}
+
+glue_output_modes = {
+ "cola": "classification",
+ "mnli": "classification",
+ "mnli-mm": "classification",
+ "mrpc": "classification",
+ "sst-2": "classification",
+ "sts-b": "regression",
+ "qqp": "classification",
+ "qnli": "classification",
+ "rte": "classification",
+ "wnli": "classification",
+}
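The three registries above are the usual entry point: given a task name they yield the processor class, the number of labels, and whether the head is a classifier or a regressor:

```python
from transformers.data.processors.glue import glue_output_modes, glue_processors, glue_tasks_num_labels

task = "sts-b"
processor = glue_processors[task]()                           # StsbProcessor (emits the FutureWarning)
print(glue_tasks_num_labels[task], glue_output_modes[task])   # 1 regression
print(processor.get_labels())                                 # [None] -- STS-B is a regression task
```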
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/data/processors/squad.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/data/processors/squad.py
new file mode 100644
index 0000000000000000000000000000000000000000..5f37eb01813308a0c850e55ac283f13ccf231f68
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/data/processors/squad.py
@@ -0,0 +1,845 @@
+# Copyright 2020 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import json
+import os
+from functools import partial
+from multiprocessing import Pool, cpu_count
+from multiprocessing.pool import ThreadPool
+from typing import Optional
+
+import numpy as np
+from tqdm import tqdm
+
+from ...models.bert.tokenization_bert import whitespace_tokenize
+from ...tokenization_utils_base import BatchEncoding, PreTrainedTokenizerBase, TruncationStrategy
+from ...utils import is_tf_available, is_torch_available, is_torch_hpu_available, logging
+from .utils import DataProcessor
+
+
+# Store the tokenizers which insert 2 separator tokens
+MULTI_SEP_TOKENS_TOKENIZERS_SET = {"roberta", "camembert", "bart", "mpnet"}
+
+
+if is_torch_available():
+ import torch
+ from torch.utils.data import TensorDataset
+
+if is_tf_available():
+ import tensorflow as tf
+
+logger = logging.get_logger(__name__)
+
+
+def _improve_answer_span(doc_tokens, input_start, input_end, tokenizer, orig_answer_text):
+ """Returns tokenized answer spans that better match the annotated answer."""
+ tok_answer_text = " ".join(tokenizer.tokenize(orig_answer_text))
+
+ for new_start in range(input_start, input_end + 1):
+ for new_end in range(input_end, new_start - 1, -1):
+ text_span = " ".join(doc_tokens[new_start : (new_end + 1)])
+ if text_span == tok_answer_text:
+ return (new_start, new_end)
+
+ return (input_start, input_end)
+
+
+def _check_is_max_context(doc_spans, cur_span_index, position):
+ """Check if this is the 'max context' doc span for the token."""
+ best_score = None
+ best_span_index = None
+ for span_index, doc_span in enumerate(doc_spans):
+ end = doc_span.start + doc_span.length - 1
+ if position < doc_span.start:
+ continue
+ if position > end:
+ continue
+ num_left_context = position - doc_span.start
+ num_right_context = end - position
+ score = min(num_left_context, num_right_context) + 0.01 * doc_span.length
+ if best_score is None or score > best_score:
+ best_score = score
+ best_span_index = span_index
+
+ return cur_span_index == best_span_index
+
+
+def _new_check_is_max_context(doc_spans, cur_span_index, position):
+ """Check if this is the 'max context' doc span for the token."""
+ # if len(doc_spans) == 1:
+ # return True
+ best_score = None
+ best_span_index = None
+ for span_index, doc_span in enumerate(doc_spans):
+ end = doc_span["start"] + doc_span["length"] - 1
+ if position < doc_span["start"]:
+ continue
+ if position > end:
+ continue
+ num_left_context = position - doc_span["start"]
+ num_right_context = end - position
+ score = min(num_left_context, num_right_context) + 0.01 * doc_span["length"]
+ if best_score is None or score > best_score:
+ best_score = score
+ best_span_index = span_index
+
+ return cur_span_index == best_span_index
+
+
+def _is_whitespace(c):
+ if c == " " or c == "\t" or c == "\r" or c == "\n" or ord(c) == 0x202F:
+ return True
+ return False
+
+
+def squad_convert_example_to_features(
+ example, max_seq_length, doc_stride, max_query_length, padding_strategy, is_training
+):
+ features = []
+ if is_training and not example.is_impossible:
+ # Get start and end position
+ start_position = example.start_position
+ end_position = example.end_position
+
+ # If the answer cannot be found in the text, then skip this example.
+ actual_text = " ".join(example.doc_tokens[start_position : (end_position + 1)])
+ cleaned_answer_text = " ".join(whitespace_tokenize(example.answer_text))
+ if actual_text.find(cleaned_answer_text) == -1:
+ logger.warning(f"Could not find answer: '{actual_text}' vs. '{cleaned_answer_text}'")
+ return []
+
+ tok_to_orig_index = []
+ orig_to_tok_index = []
+ all_doc_tokens = []
+ for i, token in enumerate(example.doc_tokens):
+ orig_to_tok_index.append(len(all_doc_tokens))
+ if tokenizer.__class__.__name__ in [
+ "RobertaTokenizer",
+ "LongformerTokenizer",
+ "BartTokenizer",
+ "RobertaTokenizerFast",
+ "LongformerTokenizerFast",
+ "BartTokenizerFast",
+ ]:
+ sub_tokens = tokenizer.tokenize(token, add_prefix_space=True)
+ else:
+ sub_tokens = tokenizer.tokenize(token)
+ for sub_token in sub_tokens:
+ tok_to_orig_index.append(i)
+ all_doc_tokens.append(sub_token)
+
+ if is_training and not example.is_impossible:
+ tok_start_position = orig_to_tok_index[example.start_position]
+ if example.end_position < len(example.doc_tokens) - 1:
+ tok_end_position = orig_to_tok_index[example.end_position + 1] - 1
+ else:
+ tok_end_position = len(all_doc_tokens) - 1
+
+ (tok_start_position, tok_end_position) = _improve_answer_span(
+ all_doc_tokens, tok_start_position, tok_end_position, tokenizer, example.answer_text
+ )
+
+ spans = []
+
+ truncated_query = tokenizer.encode(
+ example.question_text, add_special_tokens=False, truncation=True, max_length=max_query_length
+ )
+
+ # Tokenizers that insert 2 SEP tokens in between sequences need special handling
+ # in the way they compute the mask of added tokens.
+ tokenizer_type = type(tokenizer).__name__.replace("Tokenizer", "").lower()
+ sequence_added_tokens = (
+ tokenizer.model_max_length - tokenizer.max_len_single_sentence + 1
+ if tokenizer_type in MULTI_SEP_TOKENS_TOKENIZERS_SET
+ else tokenizer.model_max_length - tokenizer.max_len_single_sentence
+ )
+ sequence_pair_added_tokens = tokenizer.model_max_length - tokenizer.max_len_sentences_pair
+
+ span_doc_tokens = all_doc_tokens
+ while len(spans) * doc_stride < len(all_doc_tokens):
+ # Define the side we want to truncate / pad and the text/pair sorting
+ if tokenizer.padding_side == "right":
+ texts = truncated_query
+ pairs = span_doc_tokens
+ truncation = TruncationStrategy.ONLY_SECOND.value
+ else:
+ texts = span_doc_tokens
+ pairs = truncated_query
+ truncation = TruncationStrategy.ONLY_FIRST.value
+
+ encoded_dict = tokenizer.encode_plus( # TODO(thom) update this logic
+ texts,
+ pairs,
+ truncation=truncation,
+ padding=padding_strategy,
+ max_length=max_seq_length,
+ return_overflowing_tokens=True,
+ stride=max_seq_length - doc_stride - len(truncated_query) - sequence_pair_added_tokens,
+ return_token_type_ids=True,
+ )
+
+ paragraph_len = min(
+ len(all_doc_tokens) - len(spans) * doc_stride,
+ max_seq_length - len(truncated_query) - sequence_pair_added_tokens,
+ )
+
+ if tokenizer.pad_token_id in encoded_dict["input_ids"]:
+ if tokenizer.padding_side == "right":
+ non_padded_ids = encoded_dict["input_ids"][: encoded_dict["input_ids"].index(tokenizer.pad_token_id)]
+ else:
+ last_padding_id_position = (
+ len(encoded_dict["input_ids"]) - 1 - encoded_dict["input_ids"][::-1].index(tokenizer.pad_token_id)
+ )
+ non_padded_ids = encoded_dict["input_ids"][last_padding_id_position + 1 :]
+
+ else:
+ non_padded_ids = encoded_dict["input_ids"]
+
+ tokens = tokenizer.convert_ids_to_tokens(non_padded_ids)
+
+ token_to_orig_map = {}
+ for i in range(paragraph_len):
+ index = len(truncated_query) + sequence_added_tokens + i if tokenizer.padding_side == "right" else i
+ token_to_orig_map[index] = tok_to_orig_index[len(spans) * doc_stride + i]
+
+ encoded_dict["paragraph_len"] = paragraph_len
+ encoded_dict["tokens"] = tokens
+ encoded_dict["token_to_orig_map"] = token_to_orig_map
+ encoded_dict["truncated_query_with_special_tokens_length"] = len(truncated_query) + sequence_added_tokens
+ encoded_dict["token_is_max_context"] = {}
+ encoded_dict["start"] = len(spans) * doc_stride
+ encoded_dict["length"] = paragraph_len
+
+ spans.append(encoded_dict)
+
+ if "overflowing_tokens" not in encoded_dict or (
+ "overflowing_tokens" in encoded_dict and len(encoded_dict["overflowing_tokens"]) == 0
+ ):
+ break
+ span_doc_tokens = encoded_dict["overflowing_tokens"]
+
+ for doc_span_index in range(len(spans)):
+ for j in range(spans[doc_span_index]["paragraph_len"]):
+ is_max_context = _new_check_is_max_context(spans, doc_span_index, doc_span_index * doc_stride + j)
+ index = (
+ j
+ if tokenizer.padding_side == "left"
+ else spans[doc_span_index]["truncated_query_with_special_tokens_length"] + j
+ )
+ spans[doc_span_index]["token_is_max_context"][index] = is_max_context
+
+ for span in spans:
+ # Identify the position of the CLS token
+ cls_index = span["input_ids"].index(tokenizer.cls_token_id)
+
+ # p_mask: mask with 1 for tokens that cannot be in the answer (0 for tokens which can be in an answer)
+ # The original TF implementation also keeps the classification token (set to 0)
+ p_mask = np.ones_like(span["token_type_ids"])
+ if tokenizer.padding_side == "right":
+ p_mask[len(truncated_query) + sequence_added_tokens :] = 0
+ else:
+ p_mask[-len(span["tokens"]) : -(len(truncated_query) + sequence_added_tokens)] = 0
+
+ pad_token_indices = np.where(np.atleast_1d(span["input_ids"] == tokenizer.pad_token_id))
+ special_token_indices = np.asarray(
+ tokenizer.get_special_tokens_mask(span["input_ids"], already_has_special_tokens=True)
+ ).nonzero()
+
+ p_mask[pad_token_indices] = 1
+ p_mask[special_token_indices] = 1
+
+ # Set the cls index to 0: the CLS index can be used for impossible answers
+ p_mask[cls_index] = 0
+
+ span_is_impossible = example.is_impossible
+ start_position = 0
+ end_position = 0
+ if is_training and not span_is_impossible:
+ # For training, if our document chunk does not contain an annotation
+ # we throw it out, since there is nothing to predict.
+ doc_start = span["start"]
+ doc_end = span["start"] + span["length"] - 1
+ out_of_span = False
+
+ if not (tok_start_position >= doc_start and tok_end_position <= doc_end):
+ out_of_span = True
+
+ if out_of_span:
+ start_position = cls_index
+ end_position = cls_index
+ span_is_impossible = True
+ else:
+ if tokenizer.padding_side == "left":
+ doc_offset = 0
+ else:
+ doc_offset = len(truncated_query) + sequence_added_tokens
+
+ start_position = tok_start_position - doc_start + doc_offset
+ end_position = tok_end_position - doc_start + doc_offset
+ features.append(
+ SquadFeatures(
+ span["input_ids"],
+ span["attention_mask"],
+ span["token_type_ids"],
+ cls_index,
+ p_mask.tolist(),
+ example_index=0, # Cannot set unique_id and example_index here; they are set after the multiprocessing step.
+ unique_id=0,
+ paragraph_len=span["paragraph_len"],
+ token_is_max_context=span["token_is_max_context"],
+ tokens=span["tokens"],
+ token_to_orig_map=span["token_to_orig_map"],
+ start_position=start_position,
+ end_position=end_position,
+ is_impossible=span_is_impossible,
+ qas_id=example.qas_id,
+ )
+ )
+ return features
+
+
+def squad_convert_example_to_features_init(tokenizer_for_convert: PreTrainedTokenizerBase):
+ global tokenizer
+ tokenizer = tokenizer_for_convert
+
+
+def squad_convert_examples_to_features(
+ examples,
+ tokenizer,
+ max_seq_length,
+ doc_stride,
+ max_query_length,
+ is_training,
+ padding_strategy="max_length",
+ return_dataset=False,
+ threads=1,
+ tqdm_enabled=True,
+):
+ """
+ Converts a list of examples into a list of features that can be directly given as input to a model. It is
+ model-dependent and takes advantage of many of the tokenizer's features to create the model's inputs.
+
+ Args:
+ examples: list of [`~data.processors.squad.SquadExample`]
+ tokenizer: an instance of a child of [`PreTrainedTokenizer`]
+ max_seq_length: The maximum sequence length of the inputs.
+ doc_stride: The stride used when the context is too large and is split across several features.
+ max_query_length: The maximum length of the query.
+ is_training: whether to create features for model evaluation or model training.
+ padding_strategy: Defaults to "max_length". Which padding strategy to use.
+ return_dataset: Defaults to False. Either 'pt' or 'tf'.
+ if 'pt': returns a torch.utils.data.TensorDataset, if 'tf': returns a tf.data.Dataset
+ threads: number of worker processes to use for converting the examples in parallel.
+
+
+ Returns:
+ list of [`~data.processors.squad.SquadFeatures`]
+
+ Example:
+
+ ```python
+ processor = SquadV2Processor()
+ examples = processor.get_dev_examples(data_dir)
+
+ features = squad_convert_examples_to_features(
+ examples=examples,
+ tokenizer=tokenizer,
+ max_seq_length=args.max_seq_length,
+ doc_stride=args.doc_stride,
+ max_query_length=args.max_query_length,
+ is_training=not evaluate,
+ )
+ ```"""
+
+ threads = min(threads, cpu_count())
+ pool_cls = ThreadPool if is_torch_hpu_available() else Pool
+ with pool_cls(threads, initializer=squad_convert_example_to_features_init, initargs=(tokenizer,)) as p:
+ annotate_ = partial(
+ squad_convert_example_to_features,
+ max_seq_length=max_seq_length,
+ doc_stride=doc_stride,
+ max_query_length=max_query_length,
+ padding_strategy=padding_strategy,
+ is_training=is_training,
+ )
+ features = list(
+ tqdm(
+ p.imap(annotate_, examples, chunksize=32),
+ total=len(examples),
+ desc="convert squad examples to features",
+ disable=not tqdm_enabled,
+ )
+ )
+
+ new_features = []
+ unique_id = 1000000000
+ example_index = 0
+ for example_features in tqdm(
+ features, total=len(features), desc="add example index and unique id", disable=not tqdm_enabled
+ ):
+ if not example_features:
+ continue
+ for example_feature in example_features:
+ example_feature.example_index = example_index
+ example_feature.unique_id = unique_id
+ new_features.append(example_feature)
+ unique_id += 1
+ example_index += 1
+ features = new_features
+ del new_features
+ if return_dataset == "pt":
+ if not is_torch_available():
+ raise RuntimeError("PyTorch must be installed to return a PyTorch dataset.")
+
+ # Convert to Tensors and build dataset
+ all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long)
+ all_attention_masks = torch.tensor([f.attention_mask for f in features], dtype=torch.long)
+ all_token_type_ids = torch.tensor([f.token_type_ids for f in features], dtype=torch.long)
+ all_cls_index = torch.tensor([f.cls_index for f in features], dtype=torch.long)
+ all_p_mask = torch.tensor([f.p_mask for f in features], dtype=torch.float)
+ all_is_impossible = torch.tensor([f.is_impossible for f in features], dtype=torch.float)
+
+ if not is_training:
+ all_feature_index = torch.arange(all_input_ids.size(0), dtype=torch.long)
+ dataset = TensorDataset(
+ all_input_ids, all_attention_masks, all_token_type_ids, all_feature_index, all_cls_index, all_p_mask
+ )
+ else:
+ all_start_positions = torch.tensor([f.start_position for f in features], dtype=torch.long)
+ all_end_positions = torch.tensor([f.end_position for f in features], dtype=torch.long)
+ dataset = TensorDataset(
+ all_input_ids,
+ all_attention_masks,
+ all_token_type_ids,
+ all_start_positions,
+ all_end_positions,
+ all_cls_index,
+ all_p_mask,
+ all_is_impossible,
+ )
+
+ return features, dataset
+ elif return_dataset == "tf":
+ if not is_tf_available():
+ raise RuntimeError("TensorFlow must be installed to return a TensorFlow dataset.")
+
+ def gen():
+ for i, ex in enumerate(features):
+ if ex.token_type_ids is None:
+ yield (
+ {
+ "input_ids": ex.input_ids,
+ "attention_mask": ex.attention_mask,
+ "feature_index": i,
+ "qas_id": ex.qas_id,
+ },
+ {
+ "start_positions": ex.start_position,
+ "end_positions": ex.end_position,
+ "cls_index": ex.cls_index,
+ "p_mask": ex.p_mask,
+ "is_impossible": ex.is_impossible,
+ },
+ )
+ else:
+ yield (
+ {
+ "input_ids": ex.input_ids,
+ "attention_mask": ex.attention_mask,
+ "token_type_ids": ex.token_type_ids,
+ "feature_index": i,
+ "qas_id": ex.qas_id,
+ },
+ {
+ "start_positions": ex.start_position,
+ "end_positions": ex.end_position,
+ "cls_index": ex.cls_index,
+ "p_mask": ex.p_mask,
+ "is_impossible": ex.is_impossible,
+ },
+ )
+
+ # Why have we split the batch into a tuple? PyTorch just has a list of tensors.
+ if "token_type_ids" in tokenizer.model_input_names:
+ train_types = (
+ {
+ "input_ids": tf.int32,
+ "attention_mask": tf.int32,
+ "token_type_ids": tf.int32,
+ "feature_index": tf.int64,
+ "qas_id": tf.string,
+ },
+ {
+ "start_positions": tf.int64,
+ "end_positions": tf.int64,
+ "cls_index": tf.int64,
+ "p_mask": tf.int32,
+ "is_impossible": tf.int32,
+ },
+ )
+
+ train_shapes = (
+ {
+ "input_ids": tf.TensorShape([None]),
+ "attention_mask": tf.TensorShape([None]),
+ "token_type_ids": tf.TensorShape([None]),
+ "feature_index": tf.TensorShape([]),
+ "qas_id": tf.TensorShape([]),
+ },
+ {
+ "start_positions": tf.TensorShape([]),
+ "end_positions": tf.TensorShape([]),
+ "cls_index": tf.TensorShape([]),
+ "p_mask": tf.TensorShape([None]),
+ "is_impossible": tf.TensorShape([]),
+ },
+ )
+ else:
+ train_types = (
+ {"input_ids": tf.int32, "attention_mask": tf.int32, "feature_index": tf.int64, "qas_id": tf.string},
+ {
+ "start_positions": tf.int64,
+ "end_positions": tf.int64,
+ "cls_index": tf.int64,
+ "p_mask": tf.int32,
+ "is_impossible": tf.int32,
+ },
+ )
+
+ train_shapes = (
+ {
+ "input_ids": tf.TensorShape([None]),
+ "attention_mask": tf.TensorShape([None]),
+ "feature_index": tf.TensorShape([]),
+ "qas_id": tf.TensorShape([]),
+ },
+ {
+ "start_positions": tf.TensorShape([]),
+ "end_positions": tf.TensorShape([]),
+ "cls_index": tf.TensorShape([]),
+ "p_mask": tf.TensorShape([None]),
+ "is_impossible": tf.TensorShape([]),
+ },
+ )
+
+ return tf.data.Dataset.from_generator(gen, train_types, train_shapes)
+ else:
+ return features
+
+
+class SquadProcessor(DataProcessor):
+ """
+ Processor for the SQuAD data set. Overridden by SquadV1Processor and SquadV2Processor, used for version 1.1 and
+ version 2.0 of SQuAD, respectively.
+ """
+
+ train_file = None
+ dev_file = None
+
+ def _get_example_from_tensor_dict(self, tensor_dict, evaluate=False):
+ if not evaluate:
+ answer = tensor_dict["answers"]["text"][0].numpy().decode("utf-8")
+ answer_start = tensor_dict["answers"]["answer_start"][0].numpy()
+ answers = []
+ else:
+ answers = [
+ {"answer_start": start.numpy(), "text": text.numpy().decode("utf-8")}
+ for start, text in zip(tensor_dict["answers"]["answer_start"], tensor_dict["answers"]["text"])
+ ]
+
+ answer = None
+ answer_start = None
+
+ return SquadExample(
+ qas_id=tensor_dict["id"].numpy().decode("utf-8"),
+ question_text=tensor_dict["question"].numpy().decode("utf-8"),
+ context_text=tensor_dict["context"].numpy().decode("utf-8"),
+ answer_text=answer,
+ start_position_character=answer_start,
+ title=tensor_dict["title"].numpy().decode("utf-8"),
+ answers=answers,
+ )
+
+ def get_examples_from_dataset(self, dataset, evaluate=False):
+ """
+ Creates a list of [`~data.processors.squad.SquadExample`] using a TFDS dataset.
+
+ Args:
+ dataset: The tfds dataset loaded from *tensorflow_datasets.load("squad")*
+ evaluate: Boolean specifying if in evaluation mode or in training mode
+
+ Returns:
+ List of SquadExample
+
+ Examples:
+
+ ```python
+ >>> import tensorflow_datasets as tfds
+
+ >>> dataset = tfds.load("squad")
+
+ >>> training_examples = get_examples_from_dataset(dataset, evaluate=False)
+ >>> evaluation_examples = get_examples_from_dataset(dataset, evaluate=True)
+ ```"""
+
+ if evaluate:
+ dataset = dataset["validation"]
+ else:
+ dataset = dataset["train"]
+
+ examples = []
+ for tensor_dict in tqdm(dataset):
+ examples.append(self._get_example_from_tensor_dict(tensor_dict, evaluate=evaluate))
+
+ return examples
+
+ def get_train_examples(self, data_dir, filename=None):
+ """
+ Returns the training examples from the data directory.
+
+ Args:
+ data_dir: Directory containing the data files used for training and evaluating.
+ filename: None by default, specify this if the training file has a different name than the original one
+ which is `train-v1.1.json` and `train-v2.0.json` for squad versions 1.1 and 2.0 respectively.
+
+ """
+ if data_dir is None:
+ data_dir = ""
+
+ if self.train_file is None:
+ raise ValueError("SquadProcessor should be instantiated via SquadV1Processor or SquadV2Processor")
+
+ with open(
+ os.path.join(data_dir, self.train_file if filename is None else filename), "r", encoding="utf-8"
+ ) as reader:
+ input_data = json.load(reader)["data"]
+ return self._create_examples(input_data, "train")
+
+ def get_dev_examples(self, data_dir, filename=None):
+ """
+ Returns the evaluation examples from the data directory.
+
+ Args:
+ data_dir: Directory containing the data files used for training and evaluating.
+ filename: None by default, specify this if the evaluation file has a different name than the original one
+ which is `dev-v1.1.json` and `dev-v2.0.json` for squad versions 1.1 and 2.0 respectively.
+ """
+ if data_dir is None:
+ data_dir = ""
+
+ if self.dev_file is None:
+ raise ValueError("SquadProcessor should be instantiated via SquadV1Processor or SquadV2Processor")
+
+ with open(
+ os.path.join(data_dir, self.dev_file if filename is None else filename), "r", encoding="utf-8"
+ ) as reader:
+ input_data = json.load(reader)["data"]
+ return self._create_examples(input_data, "dev")
+
+ def _create_examples(self, input_data, set_type):
+ is_training = set_type == "train"
+ examples = []
+ for entry in tqdm(input_data):
+ title = entry["title"]
+ for paragraph in entry["paragraphs"]:
+ context_text = paragraph["context"]
+ for qa in paragraph["qas"]:
+ qas_id = qa["id"]
+ question_text = qa["question"]
+ start_position_character = None
+ answer_text = None
+ answers = []
+
+ is_impossible = qa.get("is_impossible", False)
+ if not is_impossible:
+ if is_training:
+ answer = qa["answers"][0]
+ answer_text = answer["text"]
+ start_position_character = answer["answer_start"]
+ else:
+ answers = qa["answers"]
+
+ example = SquadExample(
+ qas_id=qas_id,
+ question_text=question_text,
+ context_text=context_text,
+ answer_text=answer_text,
+ start_position_character=start_position_character,
+ title=title,
+ is_impossible=is_impossible,
+ answers=answers,
+ )
+ examples.append(example)
+ return examples
+
+
+class SquadV1Processor(SquadProcessor):
+ train_file = "train-v1.1.json"
+ dev_file = "dev-v1.1.json"
+
+
+class SquadV2Processor(SquadProcessor):
+ train_file = "train-v2.0.json"
+ dev_file = "dev-v2.0.json"
+
+
+class SquadExample:
+ """
+ A single training/test example for the Squad dataset, as loaded from disk.
+
+ Args:
+ qas_id: The example's unique identifier
+ question_text: The question string
+ context_text: The context string
+ answer_text: The answer string
+ start_position_character: The character position of the start of the answer
+ title: The title of the example
+ answers: None by default, this is used during evaluation. Holds answers as well as their start positions.
+ is_impossible: False by default, set to True if the example has no possible answer.
+ """
+
+ def __init__(
+ self,
+ qas_id,
+ question_text,
+ context_text,
+ answer_text,
+ start_position_character,
+ title,
+ answers=[],
+ is_impossible=False,
+ ):
+ self.qas_id = qas_id
+ self.question_text = question_text
+ self.context_text = context_text
+ self.answer_text = answer_text
+ self.title = title
+ self.is_impossible = is_impossible
+ self.answers = answers
+
+ self.start_position, self.end_position = 0, 0
+
+ doc_tokens = []
+ char_to_word_offset = []
+ prev_is_whitespace = True
+
+ # Split on whitespace so that different tokens may be attributed to their original position.
+ for c in self.context_text:
+ if _is_whitespace(c):
+ prev_is_whitespace = True
+ else:
+ if prev_is_whitespace:
+ doc_tokens.append(c)
+ else:
+ doc_tokens[-1] += c
+ prev_is_whitespace = False
+ char_to_word_offset.append(len(doc_tokens) - 1)
+
+ self.doc_tokens = doc_tokens
+ self.char_to_word_offset = char_to_word_offset
+
+ # Start and end positions only have a value during evaluation.
+ if start_position_character is not None and not is_impossible:
+ self.start_position = char_to_word_offset[start_position_character]
+ self.end_position = char_to_word_offset[
+ min(start_position_character + len(answer_text) - 1, len(char_to_word_offset) - 1)
+ ]
+
+
+class SquadFeatures:
+ """
+ Single squad example features to be fed to a model. Those features are model-specific and can be crafted from
+ [`~data.processors.squad.SquadExample`] using the
+ [`~data.processors.squad.squad_convert_examples_to_features`] method.
+
+ Args:
+ input_ids: Indices of input sequence tokens in the vocabulary.
+ attention_mask: Mask to avoid performing attention on padding token indices.
+ token_type_ids: Segment token indices to indicate first and second portions of the inputs.
+ cls_index: the index of the CLS token.
+ p_mask: Mask identifying tokens that can be answers vs. tokens that cannot.
+ Mask with 1 for tokens that cannot be in the answer and 0 for tokens that can be in an answer.
+ example_index: the index of the example
+ unique_id: The unique Feature identifier
+ paragraph_len: The length of the context
+ token_is_max_context:
+ List of booleans identifying which tokens have their maximum context in this feature object. If a token
+ does not have its maximum context in this feature object, it means that another feature object has more
+ information related to that token and should be prioritized over this feature for that token.
+ tokens: list of tokens corresponding to the input ids
+ token_to_orig_map: mapping between the tokens and the original text, needed in order to identify the answer.
+ start_position: start of the answer token index
+ end_position: end of the answer token index
+ encoding: optionally store the BatchEncoding with the fast-tokenizer alignment methods.
+ """
+
+ def __init__(
+ self,
+ input_ids,
+ attention_mask,
+ token_type_ids,
+ cls_index,
+ p_mask,
+ example_index,
+ unique_id,
+ paragraph_len,
+ token_is_max_context,
+ tokens,
+ token_to_orig_map,
+ start_position,
+ end_position,
+ is_impossible,
+ qas_id: Optional[str] = None,
+ encoding: Optional[BatchEncoding] = None,
+ ):
+ self.input_ids = input_ids
+ self.attention_mask = attention_mask
+ self.token_type_ids = token_type_ids
+ self.cls_index = cls_index
+ self.p_mask = p_mask
+
+ self.example_index = example_index
+ self.unique_id = unique_id
+ self.paragraph_len = paragraph_len
+ self.token_is_max_context = token_is_max_context
+ self.tokens = tokens
+ self.token_to_orig_map = token_to_orig_map
+
+ self.start_position = start_position
+ self.end_position = end_position
+ self.is_impossible = is_impossible
+ self.qas_id = qas_id
+
+ self.encoding = encoding
+
+
+class SquadResult:
+ """
+ Constructs a SquadResult which can be used to evaluate a model's output on the SQuAD dataset.
+
+ Args:
+ unique_id: The unique identifier corresponding to that example.
+ start_logits: The logits corresponding to the start of the answer
+ end_logits: The logits corresponding to the end of the answer
+ """
+
+ def __init__(self, unique_id, start_logits, end_logits, start_top_index=None, end_top_index=None, cls_logits=None):
+ self.start_logits = start_logits
+ self.end_logits = end_logits
+ self.unique_id = unique_id
+
+ if start_top_index:
+ self.start_top_index = start_top_index
+ self.end_top_index = end_top_index
+ self.cls_logits = cls_logits
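
The processor and conversion function defined in this file are normally chained: a `SquadProcessor` subclass loads `SquadExample`s from the JSON files, and `squad_convert_examples_to_features` turns them into padded spans plus, optionally, a dataset. A minimal sketch under assumed inputs (the checkpoint name, data path, and hyperparameters below are illustrative):

```python
# Hedged sketch: SQuAD v2 examples -> features -> PyTorch TensorDataset.
from transformers import AutoTokenizer
from transformers.data.processors.squad import (
    SquadV2Processor,
    squad_convert_examples_to_features,
)

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased", use_fast=False)

processor = SquadV2Processor()
# "/path/to/squad" is illustrative; it should contain dev-v2.0.json.
examples = processor.get_dev_examples("/path/to/squad")

features, dataset = squad_convert_examples_to_features(
    examples=examples,
    tokenizer=tokenizer,
    max_seq_length=384,   # typical values; tune as needed
    doc_stride=128,
    max_query_length=64,
    is_training=False,
    return_dataset="pt",  # features plus a torch.utils.data.TensorDataset
)
```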
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/data/processors/utils.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/data/processors/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..462156ebac384e08d78a7b42ea06f35a457e5feb
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/data/processors/utils.py
@@ -0,0 +1,349 @@
+# coding=utf-8
+# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
+# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import csv
+import dataclasses
+import json
+from dataclasses import dataclass
+from typing import Optional, Union
+
+from ...utils import is_tf_available, is_torch_available, logging
+
+
+logger = logging.get_logger(__name__)
+
+
+@dataclass
+class InputExample:
+ """
+ A single training/test example for simple sequence classification.
+
+ Args:
+ guid: Unique id for the example.
+ text_a: string. The untokenized text of the first sequence. For single
+ sequence tasks, only this sequence must be specified.
+ text_b: (Optional) string. The untokenized text of the second sequence.
+ Must only be specified for sequence pair tasks.
+ label: (Optional) string. The label of the example. This should be
+ specified for train and dev examples, but not for test examples.
+ """
+
+ guid: str
+ text_a: str
+ text_b: Optional[str] = None
+ label: Optional[str] = None
+
+ def to_json_string(self):
+ """Serializes this instance to a JSON string."""
+ return json.dumps(dataclasses.asdict(self), indent=2) + "\n"
+
+
+@dataclass(frozen=True)
+class InputFeatures:
+ """
+ A single set of features of data. Property names are the same names as the corresponding inputs to a model.
+
+ Args:
+ input_ids: Indices of input sequence tokens in the vocabulary.
+ attention_mask: Mask to avoid performing attention on padding token indices.
+ Mask values selected in `[0, 1]`: Usually `1` for tokens that are NOT MASKED, `0` for MASKED (padded)
+ tokens.
+ token_type_ids: (Optional) Segment token indices to indicate first and second
+ portions of the inputs. Only some models use them.
+ label: (Optional) Label corresponding to the input. Int for classification problems,
+ float for regression problems.
+ """
+
+ input_ids: list[int]
+ attention_mask: Optional[list[int]] = None
+ token_type_ids: Optional[list[int]] = None
+ label: Optional[Union[int, float]] = None
+
+ def to_json_string(self):
+ """Serializes this instance to a JSON string."""
+ return json.dumps(dataclasses.asdict(self)) + "\n"
+
+
+class DataProcessor:
+ """Base class for data converters for sequence classification data sets."""
+
+ def get_example_from_tensor_dict(self, tensor_dict):
+ """
+ Gets an example from a dict with tensorflow tensors.
+
+ Args:
+ tensor_dict: Keys and values should match the corresponding Glue
+ tensorflow_dataset examples.
+ """
+ raise NotImplementedError()
+
+ def get_train_examples(self, data_dir):
+ """Gets a collection of [`InputExample`] for the train set."""
+ raise NotImplementedError()
+
+ def get_dev_examples(self, data_dir):
+ """Gets a collection of [`InputExample`] for the dev set."""
+ raise NotImplementedError()
+
+ def get_test_examples(self, data_dir):
+ """Gets a collection of [`InputExample`] for the test set."""
+ raise NotImplementedError()
+
+ def get_labels(self):
+ """Gets the list of labels for this data set."""
+ raise NotImplementedError()
+
+ def tfds_map(self, example):
+ """
+ Some tensorflow_datasets datasets are not formatted the same way the GLUE datasets are. This method converts
+ examples to the correct format.
+ """
+ if len(self.get_labels()) > 1:
+ example.label = self.get_labels()[int(example.label)]
+ return example
+
+ @classmethod
+ def _read_tsv(cls, input_file, quotechar=None):
+ """Reads a tab separated value file."""
+ with open(input_file, "r", encoding="utf-8-sig") as f:
+ return list(csv.reader(f, delimiter="\t", quotechar=quotechar))
+
+
+class SingleSentenceClassificationProcessor(DataProcessor):
+ """Generic processor for a single sentence classification data set."""
+
+ def __init__(self, labels=None, examples=None, mode="classification", verbose=False):
+ self.labels = [] if labels is None else labels
+ self.examples = [] if examples is None else examples
+ self.mode = mode
+ self.verbose = verbose
+
+ def __len__(self):
+ return len(self.examples)
+
+ def __getitem__(self, idx):
+ if isinstance(idx, slice):
+ return SingleSentenceClassificationProcessor(labels=self.labels, examples=self.examples[idx])
+ return self.examples[idx]
+
+ @classmethod
+ def create_from_csv(
+ cls, file_name, split_name="", column_label=0, column_text=1, column_id=None, skip_first_row=False, **kwargs
+ ):
+ processor = cls(**kwargs)
+ processor.add_examples_from_csv(
+ file_name,
+ split_name=split_name,
+ column_label=column_label,
+ column_text=column_text,
+ column_id=column_id,
+ skip_first_row=skip_first_row,
+ overwrite_labels=True,
+ overwrite_examples=True,
+ )
+ return processor
+
+ @classmethod
+ def create_from_examples(cls, texts_or_text_and_labels, labels=None, **kwargs):
+ processor = cls(**kwargs)
+ processor.add_examples(texts_or_text_and_labels, labels=labels)
+ return processor
+
+ def add_examples_from_csv(
+ self,
+ file_name,
+ split_name="",
+ column_label=0,
+ column_text=1,
+ column_id=None,
+ skip_first_row=False,
+ overwrite_labels=False,
+ overwrite_examples=False,
+ ):
+ lines = self._read_tsv(file_name)
+ if skip_first_row:
+ lines = lines[1:]
+ texts = []
+ labels = []
+ ids = []
+ for i, line in enumerate(lines):
+ texts.append(line[column_text])
+ labels.append(line[column_label])
+ if column_id is not None:
+ ids.append(line[column_id])
+ else:
+ guid = f"{split_name}-{i}" if split_name else str(i)
+ ids.append(guid)
+
+ return self.add_examples(
+ texts, labels, ids, overwrite_labels=overwrite_labels, overwrite_examples=overwrite_examples
+ )
+
+ def add_examples(
+ self, texts_or_text_and_labels, labels=None, ids=None, overwrite_labels=False, overwrite_examples=False
+ ):
+ if labels is not None and len(texts_or_text_and_labels) != len(labels):
+ raise ValueError(
+ f"Text and labels have mismatched lengths {len(texts_or_text_and_labels)} and {len(labels)}"
+ )
+ if ids is not None and len(texts_or_text_and_labels) != len(ids):
+ raise ValueError(f"Text and ids have mismatched lengths {len(texts_or_text_and_labels)} and {len(ids)}")
+ if ids is None:
+ ids = [None] * len(texts_or_text_and_labels)
+ if labels is None:
+ labels = [None] * len(texts_or_text_and_labels)
+ examples = []
+ added_labels = set()
+ for text_or_text_and_label, label, guid in zip(texts_or_text_and_labels, labels, ids):
+ if isinstance(text_or_text_and_label, (tuple, list)) and label is None:
+ text, label = text_or_text_and_label
+ else:
+ text = text_or_text_and_label
+ added_labels.add(label)
+ examples.append(InputExample(guid=guid, text_a=text, text_b=None, label=label))
+
+ # Update examples
+ if overwrite_examples:
+ self.examples = examples
+ else:
+ self.examples.extend(examples)
+
+ # Update labels
+ if overwrite_labels:
+ self.labels = list(added_labels)
+ else:
+ self.labels = list(set(self.labels).union(added_labels))
+
+ return self.examples
+
+ def get_features(
+ self,
+ tokenizer,
+ max_length=None,
+ pad_on_left=False,
+ pad_token=0,
+ mask_padding_with_zero=True,
+ return_tensors=None,
+ ):
+ """
+ Converts the examples into a list of `InputFeatures`.
+
+ Args:
+ tokenizer: Instance of a tokenizer that will tokenize the examples
+ max_length: Maximum example length
+ pad_on_left: If set to `True`, the examples will be padded on the left rather than on the right (default)
+ pad_token: Padding token
+ mask_padding_with_zero: If set to `True`, the attention mask will be filled by `1` for actual values
+ and by `0` for padded values. If set to `False`, inverts it (`1` for padded values, `0` for actual
+ values)
+
+ Returns:
+ If the `examples` input is a `tf.data.Dataset`, will return a `tf.data.Dataset` containing the
+ task-specific features. If the input is a list of `InputExamples`, will return a list of task-specific
+ `InputFeatures` which can be fed to the model.
+
+ """
+ if max_length is None:
+ max_length = tokenizer.model_max_length
+
+ label_map = {label: i for i, label in enumerate(self.labels)}
+
+ all_input_ids = []
+ for ex_index, example in enumerate(self.examples):
+ if ex_index % 10000 == 0:
+ logger.info(f"Tokenizing example {ex_index}")
+
+ input_ids = tokenizer.encode(
+ example.text_a,
+ add_special_tokens=True,
+ max_length=min(max_length, tokenizer.model_max_length),
+ )
+ all_input_ids.append(input_ids)
+
+ batch_length = max(len(input_ids) for input_ids in all_input_ids)
+
+ features = []
+ for ex_index, (input_ids, example) in enumerate(zip(all_input_ids, self.examples)):
+ if ex_index % 10000 == 0:
+ logger.info(f"Writing example {ex_index}/{len(self.examples)}")
+ # The mask has 1 for real tokens and 0 for padding tokens. Only real
+ # tokens are attended to.
+ attention_mask = [1 if mask_padding_with_zero else 0] * len(input_ids)
+
+ # Zero-pad up to the sequence length.
+ padding_length = batch_length - len(input_ids)
+ if pad_on_left:
+ input_ids = ([pad_token] * padding_length) + input_ids
+ attention_mask = ([0 if mask_padding_with_zero else 1] * padding_length) + attention_mask
+ else:
+ input_ids = input_ids + ([pad_token] * padding_length)
+ attention_mask = attention_mask + ([0 if mask_padding_with_zero else 1] * padding_length)
+
+ if len(input_ids) != batch_length:
+ raise ValueError(f"Error with input length {len(input_ids)} vs {batch_length}")
+ if len(attention_mask) != batch_length:
+ raise ValueError(f"Error with input length {len(attention_mask)} vs {batch_length}")
+
+ if self.mode == "classification":
+ label = label_map[example.label]
+ elif self.mode == "regression":
+ label = float(example.label)
+ else:
+ raise ValueError(self.mode)
+
+ if ex_index < 5 and self.verbose:
+ logger.info("*** Example ***")
+ logger.info(f"guid: {example.guid}")
+ logger.info(f"input_ids: {' '.join([str(x) for x in input_ids])}")
+ logger.info(f"attention_mask: {' '.join([str(x) for x in attention_mask])}")
+ logger.info(f"label: {example.label} (id = {label})")
+
+ features.append(InputFeatures(input_ids=input_ids, attention_mask=attention_mask, label=label))
+
+ if return_tensors is None:
+ return features
+ elif return_tensors == "tf":
+ if not is_tf_available():
+ raise RuntimeError("return_tensors set to 'tf' but TensorFlow 2.0 can't be imported")
+ import tensorflow as tf
+
+ def gen():
+ for ex in features:
+ yield ({"input_ids": ex.input_ids, "attention_mask": ex.attention_mask}, ex.label)
+
+ dataset = tf.data.Dataset.from_generator(
+ gen,
+ ({"input_ids": tf.int32, "attention_mask": tf.int32}, tf.int64),
+ ({"input_ids": tf.TensorShape([None]), "attention_mask": tf.TensorShape([None])}, tf.TensorShape([])),
+ )
+ return dataset
+ elif return_tensors == "pt":
+ if not is_torch_available():
+ raise RuntimeError("return_tensors set to 'pt' but PyTorch can't be imported")
+ import torch
+ from torch.utils.data import TensorDataset
+
+ all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long)
+ all_attention_mask = torch.tensor([f.attention_mask for f in features], dtype=torch.long)
+ if self.mode == "classification":
+ all_labels = torch.tensor([f.label for f in features], dtype=torch.long)
+ elif self.mode == "regression":
+ all_labels = torch.tensor([f.label for f in features], dtype=torch.float)
+
+ dataset = TensorDataset(all_input_ids, all_attention_mask, all_labels)
+ return dataset
+ else:
+ raise ValueError("return_tensors should be one of 'tf' or 'pt'")
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/data/processors/xnli.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/data/processors/xnli.py
new file mode 100644
index 0000000000000000000000000000000000000000..4d8ec17a8345db5bf08325a334a4c6eb8af29157
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/data/processors/xnli.py
@@ -0,0 +1,96 @@
+# coding=utf-8
+# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
+# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""XNLI utils (dataset loading and evaluation)"""
+
+import os
+
+from ...utils import logging
+from .utils import DataProcessor, InputExample
+
+
+logger = logging.get_logger(__name__)
+
+
+class XnliProcessor(DataProcessor):
+ """
+ Processor for the XNLI dataset. Adapted from
+ https://github.com/google-research/bert/blob/f39e881b169b9d53bea03d2d341b31707a6c052b/run_classifier.py#L207
+ """
+
+ def __init__(self, language, train_language=None):
+ self.language = language
+ self.train_language = train_language
+
+ def get_train_examples(self, data_dir):
+ """See base class."""
+ lg = self.language if self.train_language is None else self.train_language
+ lines = self._read_tsv(os.path.join(data_dir, f"XNLI-MT-1.0/multinli/multinli.train.{lg}.tsv"))
+ examples = []
+ for i, line in enumerate(lines):
+ if i == 0:
+ continue
+ guid = f"train-{i}"
+ text_a = line[0]
+ text_b = line[1]
+ label = "contradiction" if line[2] == "contradictory" else line[2]
+ if not isinstance(text_a, str):
+ raise TypeError(f"Training input {text_a} is not a string")
+ if not isinstance(text_b, str):
+ raise TypeError(f"Training input {text_b} is not a string")
+ if not isinstance(label, str):
+ raise TypeError(f"Training label {label} is not a string")
+ examples.append(InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
+ return examples
+
+ def get_test_examples(self, data_dir):
+ """See base class."""
+ lines = self._read_tsv(os.path.join(data_dir, "XNLI-1.0/xnli.test.tsv"))
+ examples = []
+ for i, line in enumerate(lines):
+ if i == 0:
+ continue
+ language = line[0]
+ if language != self.language:
+ continue
+ guid = f"test-{i}"
+ text_a = line[6]
+ text_b = line[7]
+ label = line[1]
+ if not isinstance(text_a, str):
+ raise TypeError(f"Training input {text_a} is not a string")
+ if not isinstance(text_b, str):
+ raise TypeError(f"Training input {text_b} is not a string")
+ if not isinstance(label, str):
+ raise TypeError(f"Training label {label} is not a string")
+ examples.append(InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
+ return examples
+
+ def get_labels(self):
+ """See base class."""
+ return ["contradiction", "entailment", "neutral"]
+
+
+xnli_processors = {
+ "xnli": XnliProcessor,
+}
+
+xnli_output_modes = {
+ "xnli": "classification",
+}
+
+xnli_tasks_num_labels = {
+ "xnli": 3,
+}
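
As with the GLUE mappings above, these dictionaries resolve the task name to its processor and output mode. A minimal sketch, assuming the XNLI archives have been unpacked under a local directory (the path is illustrative):

```python
# Hedged sketch: loading XNLI examples through the task registry.
from transformers.data.processors.xnli import xnli_output_modes, xnli_processors

processor = xnli_processors["xnli"](language="de", train_language="en")
output_mode = xnli_output_modes["xnli"]  # "classification"

# "/path/to/XNLI" is illustrative; it should contain the
# XNLI-MT-1.0/ and XNLI-1.0/ folders from the official archives.
train_examples = processor.get_train_examples("/path/to/XNLI")
test_examples = processor.get_test_examples("/path/to/XNLI")
print(output_mode, processor.get_labels(), len(train_examples), len(test_examples))
```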
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/distributed/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/distributed/__pycache__/__init__.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b14bbf97440645d190cc0aacd87a32415ead1cc5
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/distributed/__pycache__/__init__.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/distributed/__pycache__/configuration_utils.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/distributed/__pycache__/configuration_utils.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..73924c5ccb6610cd95fe9d46d8e87a1160a162fc
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/distributed/__pycache__/configuration_utils.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/kernels/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/kernels/__pycache__/__init__.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..4bb6f7033657b77331d144093c1cf24dab9e444b
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/kernels/__pycache__/__init__.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/kernels/deta/cpu/ms_deform_attn_cpu.cpp b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/kernels/deta/cpu/ms_deform_attn_cpu.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..388a73d22d4c9b561e2a887b50a1897b8cf2def9
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/kernels/deta/cpu/ms_deform_attn_cpu.cpp
@@ -0,0 +1,40 @@
+/*!
+**************************************************************************************************
+* Deformable DETR
+* Copyright (c) 2020 SenseTime. All Rights Reserved.
+* Licensed under the Apache License, Version 2.0 [see LICENSE for details]
+**************************************************************************************************
+* Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
+**************************************************************************************************
+*/
+
+#include <vector>
+
+#include <ATen/ATen.h>
+#include <ATen/cuda/CUDAContext.h>
+
+
+at::Tensor
+ms_deform_attn_cpu_forward(
+ const at::Tensor &value,
+ const at::Tensor &spatial_shapes,
+ const at::Tensor &level_start_index,
+ const at::Tensor &sampling_loc,
+ const at::Tensor &attn_weight,
+ const int im2col_step)
+{
+ AT_ERROR("Not implement on cpu");
+}
+
+std::vector<at::Tensor>
+ms_deform_attn_cpu_backward(
+ const at::Tensor &value,
+ const at::Tensor &spatial_shapes,
+ const at::Tensor &level_start_index,
+ const at::Tensor &sampling_loc,
+ const at::Tensor &attn_weight,
+ const at::Tensor &grad_output,
+ const int im2col_step)
+{
+ AT_ERROR("Not implement on cpu");
+}
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/kernels/deta/cpu/ms_deform_attn_cpu.h b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/kernels/deta/cpu/ms_deform_attn_cpu.h
new file mode 100644
index 0000000000000000000000000000000000000000..7eac8c8bcd1bf529bb9c13d54d2d4215c9e4c89f
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/kernels/deta/cpu/ms_deform_attn_cpu.h
@@ -0,0 +1,32 @@
+/*!
+**************************************************************************************************
+* Deformable DETR
+* Copyright (c) 2020 SenseTime. All Rights Reserved.
+* Licensed under the Apache License, Version 2.0 [see LICENSE for details]
+**************************************************************************************************
+* Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
+**************************************************************************************************
+*/
+
+#pragma once
+#include <torch/extension.h>
+
+at::Tensor
+ms_deform_attn_cpu_forward(
+ const at::Tensor &value,
+ const at::Tensor &spatial_shapes,
+ const at::Tensor &level_start_index,
+ const at::Tensor &sampling_loc,
+ const at::Tensor &attn_weight,
+ const int im2col_step);
+
+std::vector<at::Tensor>
+ms_deform_attn_cpu_backward(
+ const at::Tensor &value,
+ const at::Tensor &spatial_shapes,
+ const at::Tensor &level_start_index,
+ const at::Tensor &sampling_loc,
+ const at::Tensor &attn_weight,
+ const at::Tensor &grad_output,
+ const int im2col_step);
+
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/kernels/deta/cuda/ms_deform_attn_cuda.cu b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/kernels/deta/cuda/ms_deform_attn_cuda.cu
new file mode 100644
index 0000000000000000000000000000000000000000..8ea1d7fabe2684dbb85f00fae2c47b469687cb2c
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/kernels/deta/cuda/ms_deform_attn_cuda.cu
@@ -0,0 +1,156 @@
+/*!
+**************************************************************************************************
+* Deformable DETR
+* Copyright (c) 2020 SenseTime. All Rights Reserved.
+* Licensed under the Apache License, Version 2.0 [see LICENSE for details]
+**************************************************************************************************
+* Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
+**************************************************************************************************
+*/
+
+#include <vector>
+#include "cuda/ms_deform_im2col_cuda.cuh"
+
+#include <ATen/ATen.h>
+#include <ATen/cuda/CUDAContext.h>
+#include <cuda.h>
+#include <cuda_runtime.h>
+
+#pragma once
+#include <torch/extension.h>
+
+
+at::Tensor ms_deform_attn_cuda_forward(
+ const at::Tensor &value,
+ const at::Tensor &spatial_shapes,
+ const at::Tensor &level_start_index,
+ const at::Tensor &sampling_loc,
+ const at::Tensor &attn_weight,
+ const int im2col_step)
+{
+ AT_ASSERTM(value.is_contiguous(), "value tensor has to be contiguous");
+ AT_ASSERTM(spatial_shapes.is_contiguous(), "spatial_shapes tensor has to be contiguous");
+ AT_ASSERTM(level_start_index.is_contiguous(), "level_start_index tensor has to be contiguous");
+ AT_ASSERTM(sampling_loc.is_contiguous(), "sampling_loc tensor has to be contiguous");
+ AT_ASSERTM(attn_weight.is_contiguous(), "attn_weight tensor has to be contiguous");
+
+ AT_ASSERTM(value.type().is_cuda(), "value must be a CUDA tensor");
+ AT_ASSERTM(spatial_shapes.type().is_cuda(), "spatial_shapes must be a CUDA tensor");
+ AT_ASSERTM(level_start_index.type().is_cuda(), "level_start_index must be a CUDA tensor");
+ AT_ASSERTM(sampling_loc.type().is_cuda(), "sampling_loc must be a CUDA tensor");
+ AT_ASSERTM(attn_weight.type().is_cuda(), "attn_weight must be a CUDA tensor");
+
+ const int batch = value.size(0);
+ const int spatial_size = value.size(1);
+ const int num_heads = value.size(2);
+ const int channels = value.size(3);
+
+ const int num_levels = spatial_shapes.size(0);
+
+ const int num_query = sampling_loc.size(1);
+ const int num_point = sampling_loc.size(4);
+
+ const int im2col_step_ = std::min(batch, im2col_step);
+
+ AT_ASSERTM(batch % im2col_step_ == 0, "batch(%d) must divide im2col_step(%d)", batch, im2col_step_);
+
+ auto output = at::zeros({batch, num_query, num_heads, channels}, value.options());
+
+ const int batch_n = im2col_step_;
+ auto output_n = output.view({batch/im2col_step_, batch_n, num_query, num_heads, channels});
+ auto per_value_size = spatial_size * num_heads * channels;
+ auto per_sample_loc_size = num_query * num_heads * num_levels * num_point * 2;
+ auto per_attn_weight_size = num_query * num_heads * num_levels * num_point;
+ for (int n = 0; n < batch/im2col_step_; ++n)
+ {
+ auto columns = output_n.select(0, n);
+ AT_DISPATCH_FLOATING_TYPES(value.type(), "ms_deform_attn_forward_cuda", ([&] {
+ ms_deformable_im2col_cuda(at::cuda::getCurrentCUDAStream(),
+ value.data<scalar_t>() + n * im2col_step_ * per_value_size,
+ spatial_shapes.data<int64_t>(),
+ level_start_index.data<int64_t>(),
+ sampling_loc.data<scalar_t>() + n * im2col_step_ * per_sample_loc_size,
+ attn_weight.data<scalar_t>() + n * im2col_step_ * per_attn_weight_size,
+ batch_n, spatial_size, num_heads, channels, num_levels, num_query, num_point,
+ columns.data<scalar_t>());
+
+ }));
+ }
+
+ output = output.view({batch, num_query, num_heads*channels});
+
+ return output;
+}
+
+
+std::vector<at::Tensor> ms_deform_attn_cuda_backward(
+ const at::Tensor &value,
+ const at::Tensor &spatial_shapes,
+ const at::Tensor &level_start_index,
+ const at::Tensor &sampling_loc,
+ const at::Tensor &attn_weight,
+ const at::Tensor &grad_output,
+ const int im2col_step)
+{
+
+ AT_ASSERTM(value.is_contiguous(), "value tensor has to be contiguous");
+ AT_ASSERTM(spatial_shapes.is_contiguous(), "spatial_shapes tensor has to be contiguous");
+ AT_ASSERTM(level_start_index.is_contiguous(), "level_start_index tensor has to be contiguous");
+ AT_ASSERTM(sampling_loc.is_contiguous(), "sampling_loc tensor has to be contiguous");
+ AT_ASSERTM(attn_weight.is_contiguous(), "attn_weight tensor has to be contiguous");
+ AT_ASSERTM(grad_output.is_contiguous(), "grad_output tensor has to be contiguous");
+
+ AT_ASSERTM(value.type().is_cuda(), "value must be a CUDA tensor");
+ AT_ASSERTM(spatial_shapes.type().is_cuda(), "spatial_shapes must be a CUDA tensor");
+ AT_ASSERTM(level_start_index.type().is_cuda(), "level_start_index must be a CUDA tensor");
+ AT_ASSERTM(sampling_loc.type().is_cuda(), "sampling_loc must be a CUDA tensor");
+ AT_ASSERTM(attn_weight.type().is_cuda(), "attn_weight must be a CUDA tensor");
+ AT_ASSERTM(grad_output.type().is_cuda(), "grad_output must be a CUDA tensor");
+
+ const int batch = value.size(0);
+ const int spatial_size = value.size(1);
+ const int num_heads = value.size(2);
+ const int channels = value.size(3);
+
+ const int num_levels = spatial_shapes.size(0);
+
+ const int num_query = sampling_loc.size(1);
+ const int num_point = sampling_loc.size(4);
+
+ const int im2col_step_ = std::min(batch, im2col_step);
+
+ AT_ASSERTM(batch % im2col_step_ == 0, "batch(%d) must divide im2col_step(%d)", batch, im2col_step_);
+
+ auto grad_value = at::zeros_like(value);
+ auto grad_sampling_loc = at::zeros_like(sampling_loc);
+ auto grad_attn_weight = at::zeros_like(attn_weight);
+
+ const int batch_n = im2col_step_;
+ auto per_value_size = spatial_size * num_heads * channels;
+ auto per_sample_loc_size = num_query * num_heads * num_levels * num_point * 2;
+ auto per_attn_weight_size = num_query * num_heads * num_levels * num_point;
+ auto grad_output_n = grad_output.view({batch/im2col_step_, batch_n, num_query, num_heads, channels});
+
+ for (int n = 0; n < batch/im2col_step_; ++n)
+ {
+ auto grad_output_g = grad_output_n.select(0, n);
+ AT_DISPATCH_FLOATING_TYPES(value.type(), "ms_deform_attn_backward_cuda", ([&] {
+ ms_deformable_col2im_cuda(at::cuda::getCurrentCUDAStream(),
+ grad_output_g.data<scalar_t>(),
+ value.data<scalar_t>() + n * im2col_step_ * per_value_size,
+ spatial_shapes.data<int64_t>(),
+ level_start_index.data<int64_t>(),
+ sampling_loc.data<scalar_t>() + n * im2col_step_ * per_sample_loc_size,
+ attn_weight.data<scalar_t>() + n * im2col_step_ * per_attn_weight_size,
+ batch_n, spatial_size, num_heads, channels, num_levels, num_query, num_point,
+ grad_value.data<scalar_t>() + n * im2col_step_ * per_value_size,
+ grad_sampling_loc.data<scalar_t>() + n * im2col_step_ * per_sample_loc_size,
+ grad_attn_weight.data<scalar_t>() + n * im2col_step_ * per_attn_weight_size);
+
+ }));
+ }
+
+ return {
+ grad_value, grad_sampling_loc, grad_attn_weight
+ };
+}
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/kernels/deta/cuda/ms_deform_attn_cuda.cuh b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/kernels/deta/cuda/ms_deform_attn_cuda.cuh
new file mode 100644
index 0000000000000000000000000000000000000000..34f8ae9cb77bbaa8cb4dd25e0cb86632db9ad05d
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/kernels/deta/cuda/ms_deform_attn_cuda.cuh
@@ -0,0 +1,1467 @@
+/*!
+**************************************************************************************************
+* Deformable DETR
+* Copyright (c) 2020 SenseTime. All Rights Reserved.
+* Licensed under the Apache License, Version 2.0 [see LICENSE for details]
+**************************************************************************************************
+* Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
+**************************************************************************************************
+*/
+
+#include <vector>
+
+#include <ATen/ATen.h>
+#include <ATen/cuda/CUDAContext.h>
+
+#include <cstdio>
+#include <algorithm>
+#include <cstring>
+
+#include <cuda.h>
+#include <cuda_runtime.h>
+
+#include <THC/THCAtomics.cuh>
+
+#define CUDA_KERNEL_LOOP(i, n) \
+ for (int i = blockIdx.x * blockDim.x + threadIdx.x; \
+ i < (n); \
+ i += blockDim.x * gridDim.x)
+
+
+at::Tensor ms_deform_attn_cuda_forward(
+ const at::Tensor &value,
+ const at::Tensor &spatial_shapes,
+ const at::Tensor &level_start_index,
+ const at::Tensor &sampling_loc,
+ const at::Tensor &attn_weight,
+ const int im2col_step)
+{
+ AT_ASSERTM(value.is_contiguous(), "value tensor has to be contiguous");
+ AT_ASSERTM(spatial_shapes.is_contiguous(), "spatial_shapes tensor has to be contiguous");
+ AT_ASSERTM(level_start_index.is_contiguous(), "level_start_index tensor has to be contiguous");
+ AT_ASSERTM(sampling_loc.is_contiguous(), "sampling_loc tensor has to be contiguous");
+ AT_ASSERTM(attn_weight.is_contiguous(), "attn_weight tensor has to be contiguous");
+
+ AT_ASSERTM(value.type().is_cuda(), "value must be a CUDA tensor");
+ AT_ASSERTM(spatial_shapes.type().is_cuda(), "spatial_shapes must be a CUDA tensor");
+ AT_ASSERTM(level_start_index.type().is_cuda(), "level_start_index must be a CUDA tensor");
+ AT_ASSERTM(sampling_loc.type().is_cuda(), "sampling_loc must be a CUDA tensor");
+ AT_ASSERTM(attn_weight.type().is_cuda(), "attn_weight must be a CUDA tensor");
+
+ const int batch = value.size(0);
+ const int spatial_size = value.size(1);
+ const int num_heads = value.size(2);
+ const int channels = value.size(3);
+
+ const int num_levels = spatial_shapes.size(0);
+
+ const int num_query = sampling_loc.size(1);
+ const int num_point = sampling_loc.size(4);
+
+ const int im2col_step_ = std::min(batch, im2col_step);
+
+ AT_ASSERTM(batch % im2col_step_ == 0, "batch(%d) must divide im2col_step(%d)", batch, im2col_step_);
+
+ auto output = at::zeros({batch, num_query, num_heads, channels}, value.options());
+
+ const int batch_n = im2col_step_;
+ auto output_n = output.view({batch/im2col_step_, batch_n, num_query, num_heads, channels});
+ auto per_value_size = spatial_size * num_heads * channels;
+ auto per_sample_loc_size = num_query * num_heads * num_levels * num_point * 2;
+ auto per_attn_weight_size = num_query * num_heads * num_levels * num_point;
+ for (int n = 0; n < batch/im2col_step_; ++n)
+ {
+ auto columns = output_n.select(0, n);
+ AT_DISPATCH_FLOATING_TYPES(value.type(), "ms_deform_attn_forward_cuda", ([&] {
+ ms_deformable_im2col_cuda(at::cuda::getCurrentCUDAStream(),
+ value.data<scalar_t>() + n * im2col_step_ * per_value_size,
+ spatial_shapes.data<int64_t>(),
+ level_start_index.data<int64_t>(),
+ sampling_loc.data<scalar_t>() + n * im2col_step_ * per_sample_loc_size,
+ attn_weight.data<scalar_t>() + n * im2col_step_ * per_attn_weight_size,
+ batch_n, spatial_size, num_heads, channels, num_levels, num_query, num_point,
+ columns.data<scalar_t>());
+
+ }));
+ }
+
+ output = output.view({batch, num_query, num_heads*channels});
+
+ return output;
+}
+
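+// The forward pass above splits the batch into chunks of im2col_step_ samples; each chunk
+// launches one im2col kernel over batch_n * num_query * num_heads * channels elements and
+// writes its result into the matching slice ("columns") of the preallocated output tensor.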
+
+std::vector<at::Tensor> ms_deform_attn_cuda_backward(
+ const at::Tensor &value,
+ const at::Tensor &spatial_shapes,
+ const at::Tensor &level_start_index,
+ const at::Tensor &sampling_loc,
+ const at::Tensor &attn_weight,
+ const at::Tensor &grad_output,
+ const int im2col_step)
+{
+
+ AT_ASSERTM(value.is_contiguous(), "value tensor has to be contiguous");
+ AT_ASSERTM(spatial_shapes.is_contiguous(), "spatial_shapes tensor has to be contiguous");
+ AT_ASSERTM(level_start_index.is_contiguous(), "level_start_index tensor has to be contiguous");
+ AT_ASSERTM(sampling_loc.is_contiguous(), "sampling_loc tensor has to be contiguous");
+ AT_ASSERTM(attn_weight.is_contiguous(), "attn_weight tensor has to be contiguous");
+ AT_ASSERTM(grad_output.is_contiguous(), "grad_output tensor has to be contiguous");
+
+ AT_ASSERTM(value.type().is_cuda(), "value must be a CUDA tensor");
+ AT_ASSERTM(spatial_shapes.type().is_cuda(), "spatial_shapes must be a CUDA tensor");
+ AT_ASSERTM(level_start_index.type().is_cuda(), "level_start_index must be a CUDA tensor");
+ AT_ASSERTM(sampling_loc.type().is_cuda(), "sampling_loc must be a CUDA tensor");
+ AT_ASSERTM(attn_weight.type().is_cuda(), "attn_weight must be a CUDA tensor");
+ AT_ASSERTM(grad_output.type().is_cuda(), "grad_output must be a CUDA tensor");
+
+ const int batch = value.size(0);
+ const int spatial_size = value.size(1);
+ const int num_heads = value.size(2);
+ const int channels = value.size(3);
+
+ const int num_levels = spatial_shapes.size(0);
+
+ const int num_query = sampling_loc.size(1);
+ const int num_point = sampling_loc.size(4);
+
+ const int im2col_step_ = std::min(batch, im2col_step);
+
+ AT_ASSERTM(batch % im2col_step_ == 0, "batch(%d) must divide im2col_step(%d)", batch, im2col_step_);
+
+ auto grad_value = at::zeros_like(value);
+ auto grad_sampling_loc = at::zeros_like(sampling_loc);
+ auto grad_attn_weight = at::zeros_like(attn_weight);
+
+ const int batch_n = im2col_step_;
+ auto per_value_size = spatial_size * num_heads * channels;
+ auto per_sample_loc_size = num_query * num_heads * num_levels * num_point * 2;
+ auto per_attn_weight_size = num_query * num_heads * num_levels * num_point;
+ auto grad_output_n = grad_output.view({batch/im2col_step_, batch_n, num_query, num_heads, channels});
+
+ for (int n = 0; n < batch/im2col_step_; ++n)
+ {
+ auto grad_output_g = grad_output_n.select(0, n);
+ AT_DISPATCH_FLOATING_TYPES(value.type(), "ms_deform_attn_backward_cuda", ([&] {
+ ms_deformable_col2im_cuda(at::cuda::getCurrentCUDAStream(),
+ grad_output_g.data<scalar_t>(),
+ value.data<scalar_t>() + n * im2col_step_ * per_value_size,
+ spatial_shapes.data<int64_t>(),
+ level_start_index.data<int64_t>(),
+ sampling_loc.data<scalar_t>() + n * im2col_step_ * per_sample_loc_size,
+ attn_weight.data<scalar_t>() + n * im2col_step_ * per_attn_weight_size,
+ batch_n, spatial_size, num_heads, channels, num_levels, num_query, num_point,
+ grad_value.data<scalar_t>() + n * im2col_step_ * per_value_size,
+ grad_sampling_loc.data<scalar_t>() + n * im2col_step_ * per_sample_loc_size,
+ grad_attn_weight.data<scalar_t>() + n * im2col_step_ * per_attn_weight_size);
+
+ }));
+ }
+
+ return {
+ grad_value, grad_sampling_loc, grad_attn_weight
+ };
+}
+
+const int CUDA_NUM_THREADS = 1024;
+inline int GET_BLOCKS(const int N, const int num_threads)
+{
+ return (N + num_threads - 1) / num_threads;
+}
+
+
+template <typename scalar_t>
+__device__ scalar_t ms_deform_attn_im2col_bilinear(const scalar_t* &bottom_data,
+ const int &height, const int &width, const int &nheads, const int &channels,
+ const scalar_t &h, const scalar_t &w, const int &m, const int &c)
+{
+ const int h_low = floor(h);
+ const int w_low = floor(w);
+ const int h_high = h_low + 1;
+ const int w_high = w_low + 1;
+
+ const scalar_t lh = h - h_low;
+ const scalar_t lw = w - w_low;
+ const scalar_t hh = 1 - lh, hw = 1 - lw;
+
+ const int w_stride = nheads * channels;
+ const int h_stride = width * w_stride;
+ const int h_low_ptr_offset = h_low * h_stride;
+ const int h_high_ptr_offset = h_low_ptr_offset + h_stride;
+ const int w_low_ptr_offset = w_low * w_stride;
+ const int w_high_ptr_offset = w_low_ptr_offset + w_stride;
+ const int base_ptr = m * channels + c;
+
+ scalar_t v1 = 0;
+ if (h_low >= 0 && w_low >= 0)
+ {
+ const int ptr1 = h_low_ptr_offset + w_low_ptr_offset + base_ptr;
+ v1 = bottom_data[ptr1];
+ }
+ scalar_t v2 = 0;
+ if (h_low >= 0 && w_high <= width - 1)
+ {
+ const int ptr2 = h_low_ptr_offset + w_high_ptr_offset + base_ptr;
+ v2 = bottom_data[ptr2];
+ }
+ scalar_t v3 = 0;
+ if (h_high <= height - 1 && w_low >= 0)
+ {
+ const int ptr3 = h_high_ptr_offset + w_low_ptr_offset + base_ptr;
+ v3 = bottom_data[ptr3];
+ }
+ scalar_t v4 = 0;
+ if (h_high <= height - 1 && w_high <= width - 1)
+ {
+ const int ptr4 = h_high_ptr_offset + w_high_ptr_offset + base_ptr;
+ v4 = bottom_data[ptr4];
+ }
+
+ const scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;
+
+ const scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
+ return val;
+}
+
+
+template <typename scalar_t>
+__device__ void ms_deform_attn_col2im_bilinear(const scalar_t* &bottom_data,
+ const int &height, const int &width, const int &nheads, const int &channels,
+ const scalar_t &h, const scalar_t &w, const int &m, const int &c,
+ const scalar_t &top_grad,
+ const scalar_t &attn_weight,
+ scalar_t* &grad_value,
+ scalar_t* grad_sampling_loc,
+ scalar_t* grad_attn_weight)
+{
+ const int h_low = floor(h);
+ const int w_low = floor(w);
+ const int h_high = h_low + 1;
+ const int w_high = w_low + 1;
+
+ const scalar_t lh = h - h_low;
+ const scalar_t lw = w - w_low;
+ const scalar_t hh = 1 - lh, hw = 1 - lw;
+
+ const int w_stride = nheads * channels;
+ const int h_stride = width * w_stride;
+ const int h_low_ptr_offset = h_low * h_stride;
+ const int h_high_ptr_offset = h_low_ptr_offset + h_stride;
+ const int w_low_ptr_offset = w_low * w_stride;
+ const int w_high_ptr_offset = w_low_ptr_offset + w_stride;
+ const int base_ptr = m * channels + c;
+
+ const scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;
+ const scalar_t top_grad_value = top_grad * attn_weight;
+ scalar_t grad_h_weight = 0, grad_w_weight = 0;
+
+ scalar_t v1 = 0;
+ if (h_low >= 0 && w_low >= 0)
+ {
+ const int ptr1 = h_low_ptr_offset + w_low_ptr_offset + base_ptr;
+ v1 = bottom_data[ptr1];
+ grad_h_weight -= hw * v1;
+ grad_w_weight -= hh * v1;
+ atomicAdd(grad_value+ptr1, w1*top_grad_value);
+ }
+ scalar_t v2 = 0;
+ if (h_low >= 0 && w_high <= width - 1)
+ {
+ const int ptr2 = h_low_ptr_offset + w_high_ptr_offset + base_ptr;
+ v2 = bottom_data[ptr2];
+ grad_h_weight -= lw * v2;
+ grad_w_weight += hh * v2;
+ atomicAdd(grad_value+ptr2, w2*top_grad_value);
+ }
+ scalar_t v3 = 0;
+ if (h_high <= height - 1 && w_low >= 0)
+ {
+ const int ptr3 = h_high_ptr_offset + w_low_ptr_offset + base_ptr;
+ v3 = bottom_data[ptr3];
+ grad_h_weight += hw * v3;
+ grad_w_weight -= lh * v3;
+ atomicAdd(grad_value+ptr3, w3*top_grad_value);
+ }
+ scalar_t v4 = 0;
+ if (h_high <= height - 1 && w_high <= width - 1)
+ {
+ const int ptr4 = h_high_ptr_offset + w_high_ptr_offset + base_ptr;
+ v4 = bottom_data[ptr4];
+ grad_h_weight += lw * v4;
+ grad_w_weight += lh * v4;
+ atomicAdd(grad_value+ptr4, w4*top_grad_value);
+ }
+
+ const scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
+ *grad_attn_weight = top_grad * val;
+ *grad_sampling_loc = width * grad_w_weight * top_grad_value;
+ *(grad_sampling_loc + 1) = height * grad_h_weight * top_grad_value;
+}
+
+
+template <typename scalar_t>
+__device__ void ms_deform_attn_col2im_bilinear_gm(const scalar_t* &bottom_data,
+ const int &height, const int &width, const int &nheads, const int &channels,
+ const scalar_t &h, const scalar_t &w, const int &m, const int &c,
+ const scalar_t &top_grad,
+ const scalar_t &attn_weight,
+ scalar_t* &grad_value,
+ scalar_t* grad_sampling_loc,
+ scalar_t* grad_attn_weight)
+{
+ const int h_low = floor(h);
+ const int w_low = floor(w);
+ const int h_high = h_low + 1;
+ const int w_high = w_low + 1;
+
+ const scalar_t lh = h - h_low;
+ const scalar_t lw = w - w_low;
+ const scalar_t hh = 1 - lh, hw = 1 - lw;
+
+ const int w_stride = nheads * channels;
+ const int h_stride = width * w_stride;
+ const int h_low_ptr_offset = h_low * h_stride;
+ const int h_high_ptr_offset = h_low_ptr_offset + h_stride;
+ const int w_low_ptr_offset = w_low * w_stride;
+ const int w_high_ptr_offset = w_low_ptr_offset + w_stride;
+ const int base_ptr = m * channels + c;
+
+ const scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;
+ const scalar_t top_grad_value = top_grad * attn_weight;
+ scalar_t grad_h_weight = 0, grad_w_weight = 0;
+
+ scalar_t v1 = 0;
+ if (h_low >= 0 && w_low >= 0)
+ {
+ const int ptr1 = h_low_ptr_offset + w_low_ptr_offset + base_ptr;
+ v1 = bottom_data[ptr1];
+ grad_h_weight -= hw * v1;
+ grad_w_weight -= hh * v1;
+ atomicAdd(grad_value+ptr1, w1*top_grad_value);
+ }
+ scalar_t v2 = 0;
+ if (h_low >= 0 && w_high <= width - 1)
+ {
+ const int ptr2 = h_low_ptr_offset + w_high_ptr_offset + base_ptr;
+ v2 = bottom_data[ptr2];
+ grad_h_weight -= lw * v2;
+ grad_w_weight += hh * v2;
+ atomicAdd(grad_value+ptr2, w2*top_grad_value);
+ }
+ scalar_t v3 = 0;
+ if (h_high <= height - 1 && w_low >= 0)
+ {
+ const int ptr3 = h_high_ptr_offset + w_low_ptr_offset + base_ptr;
+ v3 = bottom_data[ptr3];
+ grad_h_weight += hw * v3;
+ grad_w_weight -= lh * v3;
+ atomicAdd(grad_value+ptr3, w3*top_grad_value);
+ }
+ scalar_t v4 = 0;
+ if (h_high <= height - 1 && w_high <= width - 1)
+ {
+ const int ptr4 = h_high_ptr_offset + w_high_ptr_offset + base_ptr;
+ v4 = bottom_data[ptr4];
+ grad_h_weight += lw * v4;
+ grad_w_weight += lh * v4;
+ atomicAdd(grad_value+ptr4, w4*top_grad_value);
+ }
+
+ const scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
+ atomicAdd(grad_attn_weight, top_grad * val);
+ atomicAdd(grad_sampling_loc, width * grad_w_weight * top_grad_value);
+ atomicAdd(grad_sampling_loc + 1, height * grad_h_weight * top_grad_value);
+}
+
+
+template <typename scalar_t>
+__global__ void ms_deformable_im2col_gpu_kernel(const int n,
+ const scalar_t *data_value,
+ const int64_t *data_spatial_shapes,
+ const int64_t *data_level_start_index,
+ const scalar_t *data_sampling_loc,
+ const scalar_t *data_attn_weight,
+ const int batch_size,
+ const int spatial_size,
+ const int num_heads,
+ const int channels,
+ const int num_levels,
+ const int num_query,
+ const int num_point,
+ scalar_t *data_col)
+{
+ CUDA_KERNEL_LOOP(index, n)
+ {
+ int _temp = index;
+ const int c_col = _temp % channels;
+ _temp /= channels;
+ const int sampling_index = _temp;
+ const int m_col = _temp % num_heads;
+ _temp /= num_heads;
+ const int q_col = _temp % num_query;
+ _temp /= num_query;
+ const int b_col = _temp;
+
+ scalar_t *data_col_ptr = data_col + index;
+ int data_weight_ptr = sampling_index * num_levels * num_point;
+ int data_loc_w_ptr = data_weight_ptr << 1;
+ const int qid_stride = num_heads * channels;
+ const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
+ scalar_t col = 0;
+
+ for (int l_col=0; l_col < num_levels; ++l_col)
+ {
+ const int level_start_id = data_level_start_index[l_col];
+ const int spatial_h_ptr = l_col << 1;
+ const int spatial_h = data_spatial_shapes[spatial_h_ptr];
+ const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
+ const scalar_t *data_value_ptr = data_value + (data_value_ptr_init_offset + level_start_id * qid_stride);
+ for (int p_col=0; p_col < num_point; ++p_col)
+ {
+ const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
+ const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
+ const scalar_t weight = data_attn_weight[data_weight_ptr];
+
+ const scalar_t h_im = loc_h * spatial_h - 0.5;
+ const scalar_t w_im = loc_w * spatial_w - 0.5;
+
+ if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
+ {
+ col += ms_deform_attn_im2col_bilinear(data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col) * weight;
+ }
+
+ data_weight_ptr += 1;
+ data_loc_w_ptr += 2;
+ }
+ }
+ *data_col_ptr = col;
+ }
+}
+
+template <typename scalar_t, unsigned int blockSize>
+__global__ void ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1(const int n,
+ const scalar_t *grad_col,
+ const scalar_t *data_value,
+ const int64_t *data_spatial_shapes,
+ const int64_t *data_level_start_index,
+ const scalar_t *data_sampling_loc,
+ const scalar_t *data_attn_weight,
+ const int batch_size,
+ const int spatial_size,
+ const int num_heads,
+ const int channels,
+ const int num_levels,
+ const int num_query,
+ const int num_point,
+ scalar_t *grad_value,
+ scalar_t *grad_sampling_loc,
+ scalar_t *grad_attn_weight)
+{
+ CUDA_KERNEL_LOOP(index, n)
+ {
+ __shared__ scalar_t cache_grad_sampling_loc[blockSize * 2];
+ __shared__ scalar_t cache_grad_attn_weight[blockSize];
+ unsigned int tid = threadIdx.x;
+ int _temp = index;
+ const int c_col = _temp % channels;
+ _temp /= channels;
+ const int sampling_index = _temp;
+ const int m_col = _temp % num_heads;
+ _temp /= num_heads;
+ const int q_col = _temp % num_query;
+ _temp /= num_query;
+ const int b_col = _temp;
+
+ const scalar_t top_grad = grad_col[index];
+
+ int data_weight_ptr = sampling_index * num_levels * num_point;
+ int data_loc_w_ptr = data_weight_ptr << 1;
+ const int grad_sampling_ptr = data_weight_ptr;
+ grad_sampling_loc += grad_sampling_ptr << 1;
+ grad_attn_weight += grad_sampling_ptr;
+ const int grad_weight_stride = 1;
+ const int grad_loc_stride = 2;
+ const int qid_stride = num_heads * channels;
+ const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
+
+ for (int l_col=0; l_col < num_levels; ++l_col)
+ {
+ const int level_start_id = data_level_start_index[l_col];
+ const int spatial_h_ptr = l_col << 1;
+ const int spatial_h = data_spatial_shapes[spatial_h_ptr];
+ const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
+ const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
+ const scalar_t *data_value_ptr = data_value + value_ptr_offset;
+ scalar_t *grad_value_ptr = grad_value + value_ptr_offset;
+
+ for (int p_col=0; p_col < num_point; ++p_col)
+ {
+ const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
+ const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
+ const scalar_t weight = data_attn_weight[data_weight_ptr];
+
+ const scalar_t h_im = loc_h * spatial_h - 0.5;
+ const scalar_t w_im = loc_w * spatial_w - 0.5;
+ *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0;
+ *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0;
+ *(cache_grad_attn_weight+threadIdx.x)=0;
+ if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
+ {
+ ms_deform_attn_col2im_bilinear(
+ data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
+ top_grad, weight, grad_value_ptr,
+ cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x);
+ }
+
+ __syncthreads();
+ if (tid == 0)
+ {
+ scalar_t _grad_w=cache_grad_sampling_loc[0], _grad_h=cache_grad_sampling_loc[1], _grad_a=cache_grad_attn_weight[0];
+ int sid=2;
+ for (unsigned int tid = 1; tid < blockSize; ++tid)
+ {
+ _grad_w += cache_grad_sampling_loc[sid];
+ _grad_h += cache_grad_sampling_loc[sid + 1];
+ _grad_a += cache_grad_attn_weight[tid];
+ sid += 2;
+ }
+
+
+ *grad_sampling_loc = _grad_w;
+ *(grad_sampling_loc + 1) = _grad_h;
+ *grad_attn_weight = _grad_a;
+ }
+ __syncthreads();
+
+ data_weight_ptr += 1;
+ data_loc_w_ptr += 2;
+ grad_attn_weight += grad_weight_stride;
+ grad_sampling_loc += grad_loc_stride;
+ }
+ }
+ }
+}
+
+
+template <typename scalar_t, unsigned int blockSize>
+__global__ void ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2(const int n,
+ const scalar_t *grad_col,
+ const scalar_t *data_value,
+ const int64_t *data_spatial_shapes,
+ const int64_t *data_level_start_index,
+ const scalar_t *data_sampling_loc,
+ const scalar_t *data_attn_weight,
+ const int batch_size,
+ const int spatial_size,
+ const int num_heads,
+ const int channels,
+ const int num_levels,
+ const int num_query,
+ const int num_point,
+ scalar_t *grad_value,
+ scalar_t *grad_sampling_loc,
+ scalar_t *grad_attn_weight)
+{
+ CUDA_KERNEL_LOOP(index, n)
+ {
+ __shared__ scalar_t cache_grad_sampling_loc[blockSize * 2];
+ __shared__ scalar_t cache_grad_attn_weight[blockSize];
+ unsigned int tid = threadIdx.x;
+ int _temp = index;
+ const int c_col = _temp % channels;
+ _temp /= channels;
+ const int sampling_index = _temp;
+ const int m_col = _temp % num_heads;
+ _temp /= num_heads;
+ const int q_col = _temp % num_query;
+ _temp /= num_query;
+ const int b_col = _temp;
+
+ const scalar_t top_grad = grad_col[index];
+
+ int data_weight_ptr = sampling_index * num_levels * num_point;
+ int data_loc_w_ptr = data_weight_ptr << 1;
+ const int grad_sampling_ptr = data_weight_ptr;
+ grad_sampling_loc += grad_sampling_ptr << 1;
+ grad_attn_weight += grad_sampling_ptr;
+ const int grad_weight_stride = 1;
+ const int grad_loc_stride = 2;
+ const int qid_stride = num_heads * channels;
+ const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
+
+ for (int l_col=0; l_col < num_levels; ++l_col)
+ {
+ const int level_start_id = data_level_start_index[l_col];
+ const int spatial_h_ptr = l_col << 1;
+ const int spatial_h = data_spatial_shapes[spatial_h_ptr];
+ const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
+ const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
+ const scalar_t *data_value_ptr = data_value + value_ptr_offset;
+ scalar_t *grad_value_ptr = grad_value + value_ptr_offset;
+
+ for (int p_col=0; p_col < num_point; ++p_col)
+ {
+ const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
+ const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
+ const scalar_t weight = data_attn_weight[data_weight_ptr];
+
+ const scalar_t h_im = loc_h * spatial_h - 0.5;
+ const scalar_t w_im = loc_w * spatial_w - 0.5;
+ *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0;
+ *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0;
+ *(cache_grad_attn_weight+threadIdx.x)=0;
+ if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
+ {
+ ms_deform_attn_col2im_bilinear(
+ data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
+ top_grad, weight, grad_value_ptr,
+ cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x);
+ }
+
+ __syncthreads();
+
+ for (unsigned int s=blockSize/2; s>0; s>>=1)
+ {
+ if (tid < s) {
+ const unsigned int xid1 = tid << 1;
+ const unsigned int xid2 = (tid + s) << 1;
+ cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + s];
+ cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2];
+ cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1];
+ }
+ __syncthreads();
+ }
+
+ if (tid == 0)
+ {
+ *grad_sampling_loc = cache_grad_sampling_loc[0];
+ *(grad_sampling_loc + 1) = cache_grad_sampling_loc[1];
+ *grad_attn_weight = cache_grad_attn_weight[0];
+ }
+ __syncthreads();
+
+ data_weight_ptr += 1;
+ data_loc_w_ptr += 2;
+ grad_attn_weight += grad_weight_stride;
+ grad_sampling_loc += grad_loc_stride;
+ }
+ }
+ }
+}
+
+
+template <typename scalar_t>
+__global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v1(const int n,
+ const scalar_t *grad_col,
+ const scalar_t *data_value,
+ const int64_t *data_spatial_shapes,
+ const int64_t *data_level_start_index,
+ const scalar_t *data_sampling_loc,
+ const scalar_t *data_attn_weight,
+ const int batch_size,
+ const int spatial_size,
+ const int num_heads,
+ const int channels,
+ const int num_levels,
+ const int num_query,
+ const int num_point,
+ scalar_t *grad_value,
+ scalar_t *grad_sampling_loc,
+ scalar_t *grad_attn_weight)
+{
+ CUDA_KERNEL_LOOP(index, n)
+ {
+ extern __shared__ int _s[];
+ scalar_t* cache_grad_sampling_loc = (scalar_t*)_s;
+ scalar_t* cache_grad_attn_weight = cache_grad_sampling_loc + 2 * blockDim.x;
+ unsigned int tid = threadIdx.x;
+ int _temp = index;
+ const int c_col = _temp % channels;
+ _temp /= channels;
+ const int sampling_index = _temp;
+ const int m_col = _temp % num_heads;
+ _temp /= num_heads;
+ const int q_col = _temp % num_query;
+ _temp /= num_query;
+ const int b_col = _temp;
+
+ const scalar_t top_grad = grad_col[index];
+
+ int data_weight_ptr = sampling_index * num_levels * num_point;
+ int data_loc_w_ptr = data_weight_ptr << 1;
+ const int grad_sampling_ptr = data_weight_ptr;
+ grad_sampling_loc += grad_sampling_ptr << 1;
+ grad_attn_weight += grad_sampling_ptr;
+ const int grad_weight_stride = 1;
+ const int grad_loc_stride = 2;
+ const int qid_stride = num_heads * channels;
+ const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
+
+ for (int l_col=0; l_col < num_levels; ++l_col)
+ {
+ const int level_start_id = data_level_start_index[l_col];
+ const int spatial_h_ptr = l_col << 1;
+ const int spatial_h = data_spatial_shapes[spatial_h_ptr];
+ const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
+ const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
+ const scalar_t *data_value_ptr = data_value + value_ptr_offset;
+ scalar_t *grad_value_ptr = grad_value + value_ptr_offset;
+
+ for (int p_col=0; p_col < num_point; ++p_col)
+ {
+ const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
+ const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
+ const scalar_t weight = data_attn_weight[data_weight_ptr];
+
+ const scalar_t h_im = loc_h * spatial_h - 0.5;
+ const scalar_t w_im = loc_w * spatial_w - 0.5;
+ *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0;
+ *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0;
+ *(cache_grad_attn_weight+threadIdx.x)=0;
+ if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
+ {
+ ms_deform_attn_col2im_bilinear(
+ data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
+ top_grad, weight, grad_value_ptr,
+ cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x);
+ }
+
+ __syncthreads();
+ if (tid == 0)
+ {
+ scalar_t _grad_w=cache_grad_sampling_loc[0], _grad_h=cache_grad_sampling_loc[1], _grad_a=cache_grad_attn_weight[0];
+ int sid=2;
+ for (unsigned int tid = 1; tid < blockDim.x; ++tid)
+ {
+ _grad_w += cache_grad_sampling_loc[sid];
+ _grad_h += cache_grad_sampling_loc[sid + 1];
+ _grad_a += cache_grad_attn_weight[tid];
+ sid += 2;
+ }
+
+
+ *grad_sampling_loc = _grad_w;
+ *(grad_sampling_loc + 1) = _grad_h;
+ *grad_attn_weight = _grad_a;
+ }
+ __syncthreads();
+
+ data_weight_ptr += 1;
+ data_loc_w_ptr += 2;
+ grad_attn_weight += grad_weight_stride;
+ grad_sampling_loc += grad_loc_stride;
+ }
+ }
+ }
+}
+
+template <typename scalar_t>
+__global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v2(const int n,
+ const scalar_t *grad_col,
+ const scalar_t *data_value,
+ const int64_t *data_spatial_shapes,
+ const int64_t *data_level_start_index,
+ const scalar_t *data_sampling_loc,
+ const scalar_t *data_attn_weight,
+ const int batch_size,
+ const int spatial_size,
+ const int num_heads,
+ const int channels,
+ const int num_levels,
+ const int num_query,
+ const int num_point,
+ scalar_t *grad_value,
+ scalar_t *grad_sampling_loc,
+ scalar_t *grad_attn_weight)
+{
+ CUDA_KERNEL_LOOP(index, n)
+ {
+ extern __shared__ int _s[];
+ scalar_t* cache_grad_sampling_loc = (scalar_t*)_s;
+ scalar_t* cache_grad_attn_weight = cache_grad_sampling_loc + 2 * blockDim.x;
+ unsigned int tid = threadIdx.x;
+ int _temp = index;
+ const int c_col = _temp % channels;
+ _temp /= channels;
+ const int sampling_index = _temp;
+ const int m_col = _temp % num_heads;
+ _temp /= num_heads;
+ const int q_col = _temp % num_query;
+ _temp /= num_query;
+ const int b_col = _temp;
+
+ const scalar_t top_grad = grad_col[index];
+
+ int data_weight_ptr = sampling_index * num_levels * num_point;
+ int data_loc_w_ptr = data_weight_ptr << 1;
+ const int grad_sampling_ptr = data_weight_ptr;
+ grad_sampling_loc += grad_sampling_ptr << 1;
+ grad_attn_weight += grad_sampling_ptr;
+ const int grad_weight_stride = 1;
+ const int grad_loc_stride = 2;
+ const int qid_stride = num_heads * channels;
+ const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
+
+ for (int l_col=0; l_col < num_levels; ++l_col)
+ {
+ const int level_start_id = data_level_start_index[l_col];
+ const int spatial_h_ptr = l_col << 1;
+ const int spatial_h = data_spatial_shapes[spatial_h_ptr];
+ const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
+ const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
+ const scalar_t *data_value_ptr = data_value + value_ptr_offset;
+ scalar_t *grad_value_ptr = grad_value + value_ptr_offset;
+
+ for (int p_col=0; p_col < num_point; ++p_col)
+ {
+ const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
+ const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
+ const scalar_t weight = data_attn_weight[data_weight_ptr];
+
+ const scalar_t h_im = loc_h * spatial_h - 0.5;
+ const scalar_t w_im = loc_w * spatial_w - 0.5;
+ *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0;
+ *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0;
+ *(cache_grad_attn_weight+threadIdx.x)=0;
+ if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
+ {
+ ms_deform_attn_col2im_bilinear(
+ data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
+ top_grad, weight, grad_value_ptr,
+ cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x);
+ }
+
+ __syncthreads();
+
+ for (unsigned int s=blockDim.x/2, spre=blockDim.x; s>0; s>>=1, spre>>=1)
+ {
+ if (tid < s) {
+ const unsigned int xid1 = tid << 1;
+ const unsigned int xid2 = (tid + s) << 1;
+ cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + s];
+ cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2];
+ cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1];
+ if (tid + (s << 1) < spre)
+ {
+ cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + (s << 1)];
+ cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2 + (s << 1)];
+ cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1 + (s << 1)];
+ }
+ }
+ __syncthreads();
+ }
+
+ if (tid == 0)
+ {
+ *grad_sampling_loc = cache_grad_sampling_loc[0];
+ *(grad_sampling_loc + 1) = cache_grad_sampling_loc[1];
+ *grad_attn_weight = cache_grad_attn_weight[0];
+ }
+ __syncthreads();
+
+ data_weight_ptr += 1;
+ data_loc_w_ptr += 2;
+ grad_attn_weight += grad_weight_stride;
+ grad_sampling_loc += grad_loc_stride;
+ }
+ }
+ }
+}
+
+template <typename scalar_t>
+__global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v2_multi_blocks(const int n,
+ const scalar_t *grad_col,
+ const scalar_t *data_value,
+ const int64_t *data_spatial_shapes,
+ const int64_t *data_level_start_index,
+ const scalar_t *data_sampling_loc,
+ const scalar_t *data_attn_weight,
+ const int batch_size,
+ const int spatial_size,
+ const int num_heads,
+ const int channels,
+ const int num_levels,
+ const int num_query,
+ const int num_point,
+ scalar_t *grad_value,
+ scalar_t *grad_sampling_loc,
+ scalar_t *grad_attn_weight)
+{
+ CUDA_KERNEL_LOOP(index, n)
+ {
+ extern __shared__ int _s[];
+ scalar_t* cache_grad_sampling_loc = (scalar_t*)_s;
+ scalar_t* cache_grad_attn_weight = cache_grad_sampling_loc + 2 * blockDim.x;
+ unsigned int tid = threadIdx.x;
+ int _temp = index;
+ const int c_col = _temp % channels;
+ _temp /= channels;
+ const int sampling_index = _temp;
+ const int m_col = _temp % num_heads;
+ _temp /= num_heads;
+ const int q_col = _temp % num_query;
+ _temp /= num_query;
+ const int b_col = _temp;
+
+ const scalar_t top_grad = grad_col[index];
+
+ int data_weight_ptr = sampling_index * num_levels * num_point;
+ int data_loc_w_ptr = data_weight_ptr << 1;
+ const int grad_sampling_ptr = data_weight_ptr;
+ grad_sampling_loc += grad_sampling_ptr << 1;
+ grad_attn_weight += grad_sampling_ptr;
+ const int grad_weight_stride = 1;
+ const int grad_loc_stride = 2;
+ const int qid_stride = num_heads * channels;
+ const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
+
+ for (int l_col=0; l_col < num_levels; ++l_col)
+ {
+ const int level_start_id = data_level_start_index[l_col];
+ const int spatial_h_ptr = l_col << 1;
+ const int spatial_h = data_spatial_shapes[spatial_h_ptr];
+ const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
+ const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
+ const scalar_t *data_value_ptr = data_value + value_ptr_offset;
+ scalar_t *grad_value_ptr = grad_value + value_ptr_offset;
+
+ for (int p_col=0; p_col < num_point; ++p_col)
+ {
+ const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
+ const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
+ const scalar_t weight = data_attn_weight[data_weight_ptr];
+
+ const scalar_t h_im = loc_h * spatial_h - 0.5;
+ const scalar_t w_im = loc_w * spatial_w - 0.5;
+ *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0;
+ *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0;
+ *(cache_grad_attn_weight+threadIdx.x)=0;
+ if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
+ {
+ ms_deform_attn_col2im_bilinear(
+ data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
+ top_grad, weight, grad_value_ptr,
+ cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x);
+ }
+
+ __syncthreads();
+
+ for (unsigned int s=blockDim.x/2, spre=blockDim.x; s>0; s>>=1, spre>>=1)
+ {
+ if (tid < s) {
+ const unsigned int xid1 = tid << 1;
+ const unsigned int xid2 = (tid + s) << 1;
+ cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + s];
+ cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2];
+ cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1];
+ if (tid + (s << 1) < spre)
+ {
+ cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + (s << 1)];
+ cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2 + (s << 1)];
+ cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1 + (s << 1)];
+ }
+ }
+ __syncthreads();
+ }
+
+ if (tid == 0)
+ {
+ atomicAdd(grad_sampling_loc, cache_grad_sampling_loc[0]);
+ atomicAdd(grad_sampling_loc + 1, cache_grad_sampling_loc[1]);
+ atomicAdd(grad_attn_weight, cache_grad_attn_weight[0]);
+ }
+ __syncthreads();
+
+ data_weight_ptr += 1;
+ data_loc_w_ptr += 2;
+ grad_attn_weight += grad_weight_stride;
+ grad_sampling_loc += grad_loc_stride;
+ }
+ }
+ }
+}
+
+
+template <typename scalar_t>
+__global__ void ms_deformable_col2im_gpu_kernel_gm(const int n,
+ const scalar_t *grad_col,
+ const scalar_t *data_value,
+ const int64_t *data_spatial_shapes,
+ const int64_t *data_level_start_index,
+ const scalar_t *data_sampling_loc,
+ const scalar_t *data_attn_weight,
+ const int batch_size,
+ const int spatial_size,
+ const int num_heads,
+ const int channels,
+ const int num_levels,
+ const int num_query,
+ const int num_point,
+ scalar_t *grad_value,
+ scalar_t *grad_sampling_loc,
+ scalar_t *grad_attn_weight)
+{
+ CUDA_KERNEL_LOOP(index, n)
+ {
+ int _temp = index;
+ const int c_col = _temp % channels;
+ _temp /= channels;
+ const int sampling_index = _temp;
+ const int m_col = _temp % num_heads;
+ _temp /= num_heads;
+ const int q_col = _temp % num_query;
+ _temp /= num_query;
+ const int b_col = _temp;
+
+ const scalar_t top_grad = grad_col[index];
+
+ int data_weight_ptr = sampling_index * num_levels * num_point;
+ int data_loc_w_ptr = data_weight_ptr << 1;
+ const int grad_sampling_ptr = data_weight_ptr;
+ grad_sampling_loc += grad_sampling_ptr << 1;
+ grad_attn_weight += grad_sampling_ptr;
+ const int grad_weight_stride = 1;
+ const int grad_loc_stride = 2;
+ const int qid_stride = num_heads * channels;
+ const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
+
+ for (int l_col=0; l_col < num_levels; ++l_col)
+ {
+ const int level_start_id = data_level_start_index[l_col];
+ const int spatial_h_ptr = l_col << 1;
+ const int spatial_h = data_spatial_shapes[spatial_h_ptr];
+ const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
+ const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
+ const scalar_t *data_value_ptr = data_value + value_ptr_offset;
+ scalar_t *grad_value_ptr = grad_value + value_ptr_offset;
+
+ for (int p_col=0; p_col < num_point; ++p_col)
+ {
+ const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
+ const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
+ const scalar_t weight = data_attn_weight[data_weight_ptr];
+
+ const scalar_t h_im = loc_h * spatial_h - 0.5;
+ const scalar_t w_im = loc_w * spatial_w - 0.5;
+ if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
+ {
+ ms_deform_attn_col2im_bilinear_gm(
+ data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
+ top_grad, weight, grad_value_ptr,
+ grad_sampling_loc, grad_attn_weight);
+ }
+ data_weight_ptr += 1;
+ data_loc_w_ptr += 2;
+ grad_attn_weight += grad_weight_stride;
+ grad_sampling_loc += grad_loc_stride;
+ }
+ }
+ }
+}
+
+
+template <typename scalar_t>
+void ms_deformable_im2col_cuda(cudaStream_t stream,
+ const scalar_t* data_value,
+ const int64_t* data_spatial_shapes,
+ const int64_t* data_level_start_index,
+ const scalar_t* data_sampling_loc,
+ const scalar_t* data_attn_weight,
+ const int batch_size,
+ const int spatial_size,
+ const int num_heads,
+ const int channels,
+ const int num_levels,
+ const int num_query,
+ const int num_point,
+ scalar_t* data_col)
+{
+ const int num_kernels = batch_size * num_query * num_heads * channels;
+ const int num_actual_kernels = batch_size * num_query * num_heads * channels;
+ const int num_threads = CUDA_NUM_THREADS;
+ ms_deformable_im2col_gpu_kernel<scalar_t>
+ <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads, 0, stream>>>(
+ num_kernels, data_value, data_spatial_shapes, data_level_start_index, data_sampling_loc, data_attn_weight,
+ batch_size, spatial_size, num_heads, channels, num_levels, num_query, num_point, data_col);
+
+ cudaError_t err = cudaGetLastError();
+ if (err != cudaSuccess)
+ {
+ printf("error in ms_deformable_im2col_cuda: %s\n", cudaGetErrorString(err));
+ }
+
+}
+
+template <typename scalar_t>
+void ms_deformable_col2im_cuda(cudaStream_t stream,
+ const scalar_t* grad_col,
+ const scalar_t* data_value,
+ const int64_t * data_spatial_shapes,
+ const int64_t * data_level_start_index,
+ const scalar_t * data_sampling_loc,
+ const scalar_t * data_attn_weight,
+ const int batch_size,
+ const int spatial_size,
+ const int num_heads,
+ const int channels,
+ const int num_levels,
+ const int num_query,
+ const int num_point,
+ scalar_t* grad_value,
+ scalar_t* grad_sampling_loc,
+ scalar_t* grad_attn_weight)
+{
+ const int num_threads = (channels > CUDA_NUM_THREADS)?CUDA_NUM_THREADS:channels;
+ const int num_kernels = batch_size * num_query * num_heads * channels;
+ const int num_actual_kernels = batch_size * num_query * num_heads * channels;
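+ // Kernel selection for the backward pass: power-of-two channel counts up to 1024 take the
+ // blockSize-templated shared-memory reductions (v1 for 1..32 channels, v2 for 64..1024);
+ // other channel counts up to 1024 use the dynamic shared-memory reductions (v1 below 64
+ // channels, v2 otherwise); above 1024 channels, multiples of 1024 use the multi-block
+ // shared-memory kernel and everything else falls back to global-memory atomics.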
+ if (channels > 1024)
+ {
+ if ((channels & 1023) == 0)
+ {
+ ms_deformable_col2im_gpu_kernel_shm_reduce_v2_multi_blocks<scalar_t>
+ <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads, num_threads*3*sizeof(scalar_t), stream>>>(
+ num_kernels,
+ grad_col,
+ data_value,
+ data_spatial_shapes,
+ data_level_start_index,
+ data_sampling_loc,
+ data_attn_weight,
+ batch_size,
+ spatial_size,
+ num_heads,
+ channels,
+ num_levels,
+ num_query,
+ num_point,
+ grad_value,
+ grad_sampling_loc,
+ grad_attn_weight);
+ }
+ else
+ {
+ ms_deformable_col2im_gpu_kernel_gm<scalar_t>
+ <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads, 0, stream>>>(
+ num_kernels,
+ grad_col,
+ data_value,
+ data_spatial_shapes,
+ data_level_start_index,
+ data_sampling_loc,
+ data_attn_weight,
+ batch_size,
+ spatial_size,
+ num_heads,
+ channels,
+ num_levels,
+ num_query,
+ num_point,
+ grad_value,
+ grad_sampling_loc,
+ grad_attn_weight);
+ }
+ }
+ else{
+ switch(channels)
+ {
+ case 1:
+ ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1<scalar_t, 1>
+ <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads, 0, stream>>>(
+ num_kernels,
+ grad_col,
+ data_value,
+ data_spatial_shapes,
+ data_level_start_index,
+ data_sampling_loc,
+ data_attn_weight,
+ batch_size,
+ spatial_size,
+ num_heads,
+ channels,
+ num_levels,
+ num_query,
+ num_point,
+ grad_value,
+ grad_sampling_loc,
+ grad_attn_weight);
+ break;
+ case 2:
+ ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1<scalar_t, 2>
+ <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads, 0, stream>>>(
+ num_kernels,
+ grad_col,
+ data_value,
+ data_spatial_shapes,
+ data_level_start_index,
+ data_sampling_loc,
+ data_attn_weight,
+ batch_size,
+ spatial_size,
+ num_heads,
+ channels,
+ num_levels,
+ num_query,
+ num_point,
+ grad_value,
+ grad_sampling_loc,
+ grad_attn_weight);
+ break;
+ case 4:
+ ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1<scalar_t, 4>
+ <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads, 0, stream>>>(
+ num_kernels,
+ grad_col,
+ data_value,
+ data_spatial_shapes,
+ data_level_start_index,
+ data_sampling_loc,
+ data_attn_weight,
+ batch_size,
+ spatial_size,
+ num_heads,
+ channels,
+ num_levels,
+ num_query,
+ num_point,
+ grad_value,
+ grad_sampling_loc,
+ grad_attn_weight);
+ break;
+ case 8:
+ ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1<scalar_t, 8>
+ <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads, 0, stream>>>(
+ num_kernels,
+ grad_col,
+ data_value,
+ data_spatial_shapes,
+ data_level_start_index,
+ data_sampling_loc,
+ data_attn_weight,
+ batch_size,
+ spatial_size,
+ num_heads,
+ channels,
+ num_levels,
+ num_query,
+ num_point,
+ grad_value,
+ grad_sampling_loc,
+ grad_attn_weight);
+ break;
+ case 16:
+ ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1<scalar_t, 16>
+ <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads, 0, stream>>>(
+ num_kernels,
+ grad_col,
+ data_value,
+ data_spatial_shapes,
+ data_level_start_index,
+ data_sampling_loc,
+ data_attn_weight,
+ batch_size,
+ spatial_size,
+ num_heads,
+ channels,
+ num_levels,
+ num_query,
+ num_point,
+ grad_value,
+ grad_sampling_loc,
+ grad_attn_weight);
+ break;
+ case 32:
+ ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1<scalar_t, 32>
+ <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads, 0, stream>>>(
+ num_kernels,
+ grad_col,
+ data_value,
+ data_spatial_shapes,
+ data_level_start_index,
+ data_sampling_loc,
+ data_attn_weight,
+ batch_size,
+ spatial_size,
+ num_heads,
+ channels,
+ num_levels,
+ num_query,
+ num_point,
+ grad_value,
+ grad_sampling_loc,
+ grad_attn_weight);
+ break;
+ case 64:
+ ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2<scalar_t, 64>
+ <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads, 0, stream>>>(
+ num_kernels,
+ grad_col,
+ data_value,
+ data_spatial_shapes,
+ data_level_start_index,
+ data_sampling_loc,
+ data_attn_weight,
+ batch_size,
+ spatial_size,
+ num_heads,
+ channels,
+ num_levels,
+ num_query,
+ num_point,
+ grad_value,
+ grad_sampling_loc,
+ grad_attn_weight);
+ break;
+ case 128:
+ ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2<scalar_t, 128>
+ <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads, 0, stream>>>(
+ num_kernels,
+ grad_col,
+ data_value,
+ data_spatial_shapes,
+ data_level_start_index,
+ data_sampling_loc,
+ data_attn_weight,
+ batch_size,
+ spatial_size,
+ num_heads,
+ channels,
+ num_levels,
+ num_query,
+ num_point,
+ grad_value,
+ grad_sampling_loc,
+ grad_attn_weight);
+ break;
+ case 256:
+ ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2<scalar_t, 256>
+ <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads, 0, stream>>>(
+ num_kernels,
+ grad_col,
+ data_value,
+ data_spatial_shapes,
+ data_level_start_index,
+ data_sampling_loc,
+ data_attn_weight,
+ batch_size,
+ spatial_size,
+ num_heads,
+ channels,
+ num_levels,
+ num_query,
+ num_point,
+ grad_value,
+ grad_sampling_loc,
+ grad_attn_weight);
+ break;
+ case 512:
+ ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2<scalar_t, 512>
+ <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads, 0, stream>>>(
+ num_kernels,
+ grad_col,
+ data_value,
+ data_spatial_shapes,
+ data_level_start_index,
+ data_sampling_loc,
+ data_attn_weight,
+ batch_size,
+ spatial_size,
+ num_heads,
+ channels,
+ num_levels,
+ num_query,
+ num_point,
+ grad_value,
+ grad_sampling_loc,
+ grad_attn_weight);
+ break;
+ case 1024:
+ ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2<scalar_t, 1024>
+ <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads, 0, stream>>>(
+ num_kernels,
+ grad_col,
+ data_value,
+ data_spatial_shapes,
+ data_level_start_index,
+ data_sampling_loc,
+ data_attn_weight,
+ batch_size,
+ spatial_size,
+ num_heads,
+ channels,
+ num_levels,
+ num_query,
+ num_point,
+ grad_value,
+ grad_sampling_loc,
+ grad_attn_weight);
+ break;
+ default:
+ if (channels < 64)
+ {
+ ms_deformable_col2im_gpu_kernel_shm_reduce_v1<scalar_t>
+ <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads, num_threads*3*sizeof(scalar_t), stream>>>(
+ num_kernels,
+ grad_col,
+ data_value,
+ data_spatial_shapes,
+ data_level_start_index,
+ data_sampling_loc,
+ data_attn_weight,
+ batch_size,
+ spatial_size,
+ num_heads,
+ channels,
+ num_levels,
+ num_query,
+ num_point,
+ grad_value,
+ grad_sampling_loc,
+ grad_attn_weight);
+ }
+ else
+ {
+ ms_deformable_col2im_gpu_kernel_shm_reduce_v2<scalar_t>
+ <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads, num_threads*3*sizeof(scalar_t), stream>>>(
+ num_kernels,
+ grad_col,
+ data_value,
+ data_spatial_shapes,
+ data_level_start_index,
+ data_sampling_loc,
+ data_attn_weight,
+ batch_size,
+ spatial_size,
+ num_heads,
+ channels,
+ num_levels,
+ num_query,
+ num_point,
+ grad_value,
+ grad_sampling_loc,
+ grad_attn_weight);
+ }
+ }
+ }
+ cudaError_t err = cudaGetLastError();
+ if (err != cudaSuccess)
+ {
+ printf("error in ms_deformable_col2im_cuda: %s\n", cudaGetErrorString(err));
+ }
+
+}
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/kernels/deta/cuda/ms_deform_attn_cuda.h b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/kernels/deta/cuda/ms_deform_attn_cuda.h
new file mode 100644
index 0000000000000000000000000000000000000000..fbcf4543e66bb1162f42ce2ae57e1bac92243cb4
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/kernels/deta/cuda/ms_deform_attn_cuda.h
@@ -0,0 +1,29 @@
+/*!
+**************************************************************************************************
+* Deformable DETR
+* Copyright (c) 2020 SenseTime. All Rights Reserved.
+* Licensed under the Apache License, Version 2.0 [see LICENSE for details]
+**************************************************************************************************
+* Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
+**************************************************************************************************
+*/
+
+#pragma once
+#include <torch/extension.h>
+
+at::Tensor ms_deform_attn_cuda_forward(
+ const at::Tensor &value,
+ const at::Tensor &spatial_shapes,
+ const at::Tensor &level_start_index,
+ const at::Tensor &sampling_loc,
+ const at::Tensor &attn_weight,
+ const int im2col_step);
+
+std::vector<at::Tensor> ms_deform_attn_cuda_backward(
+ const at::Tensor &value,
+ const at::Tensor &spatial_shapes,
+ const at::Tensor &level_start_index,
+ const at::Tensor &sampling_loc,
+ const at::Tensor &attn_weight,
+ const at::Tensor &grad_output,
+ const int im2col_step);
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/kernels/deta/cuda/ms_deform_im2col_cuda.cuh b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/kernels/deta/cuda/ms_deform_im2col_cuda.cuh
new file mode 100644
index 0000000000000000000000000000000000000000..c0db0c88c9db2c09d7f601937ea0f6ac480913bf
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/kernels/deta/cuda/ms_deform_im2col_cuda.cuh
@@ -0,0 +1,1327 @@
+/*!
+**************************************************************************
+* Deformable DETR
+* Copyright (c) 2020 SenseTime. All Rights Reserved.
+* Licensed under the Apache License, Version 2.0 [see LICENSE for details]
+**************************************************************************
+* Modified from DCN (https://github.com/msracver/Deformable-ConvNets)
+* Copyright (c) 2018 Microsoft
+**************************************************************************
+*/
+
+#include <cstdio>
+#include <algorithm>
+#include <cstring>
+
+#include <ATen/ATen.h>
+#include <ATen/cuda/CUDAContext.h>
+
+#include <THC/THCAtomics.cuh>
+
+#define CUDA_KERNEL_LOOP(i, n) \
+ for (int i = blockIdx.x * blockDim.x + threadIdx.x; \
+ i < (n); \
+ i += blockDim.x * gridDim.x)
+
+const int CUDA_NUM_THREADS = 1024;
+inline int GET_BLOCKS(const int N, const int num_threads)
+{
+ return (N + num_threads - 1) / num_threads;
+}
+
+
+template <typename scalar_t>
+__device__ scalar_t ms_deform_attn_im2col_bilinear(const scalar_t* &bottom_data,
+ const int &height, const int &width, const int &nheads, const int &channels,
+ const scalar_t &h, const scalar_t &w, const int &m, const int &c)
+{
+ const int h_low = floor(h);
+ const int w_low = floor(w);
+ const int h_high = h_low + 1;
+ const int w_high = w_low + 1;
+
+ const scalar_t lh = h - h_low;
+ const scalar_t lw = w - w_low;
+ const scalar_t hh = 1 - lh, hw = 1 - lw;
+
+ const int w_stride = nheads * channels;
+ const int h_stride = width * w_stride;
+ const int h_low_ptr_offset = h_low * h_stride;
+ const int h_high_ptr_offset = h_low_ptr_offset + h_stride;
+ const int w_low_ptr_offset = w_low * w_stride;
+ const int w_high_ptr_offset = w_low_ptr_offset + w_stride;
+ const int base_ptr = m * channels + c;
+
+ scalar_t v1 = 0;
+ if (h_low >= 0 && w_low >= 0)
+ {
+ const int ptr1 = h_low_ptr_offset + w_low_ptr_offset + base_ptr;
+ v1 = bottom_data[ptr1];
+ }
+ scalar_t v2 = 0;
+ if (h_low >= 0 && w_high <= width - 1)
+ {
+ const int ptr2 = h_low_ptr_offset + w_high_ptr_offset + base_ptr;
+ v2 = bottom_data[ptr2];
+ }
+ scalar_t v3 = 0;
+ if (h_high <= height - 1 && w_low >= 0)
+ {
+ const int ptr3 = h_high_ptr_offset + w_low_ptr_offset + base_ptr;
+ v3 = bottom_data[ptr3];
+ }
+ scalar_t v4 = 0;
+ if (h_high <= height - 1 && w_high <= width - 1)
+ {
+ const int ptr4 = h_high_ptr_offset + w_high_ptr_offset + base_ptr;
+ v4 = bottom_data[ptr4];
+ }
+
+ const scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;
+
+ const scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
+ return val;
+}
+
+
+template <typename scalar_t>
+__device__ void ms_deform_attn_col2im_bilinear(const scalar_t* &bottom_data,
+ const int &height, const int &width, const int &nheads, const int &channels,
+ const scalar_t &h, const scalar_t &w, const int &m, const int &c,
+ const scalar_t &top_grad,
+ const scalar_t &attn_weight,
+ scalar_t* &grad_value,
+ scalar_t* grad_sampling_loc,
+ scalar_t* grad_attn_weight)
+{
+ const int h_low = floor(h);
+ const int w_low = floor(w);
+ const int h_high = h_low + 1;
+ const int w_high = w_low + 1;
+
+ const scalar_t lh = h - h_low;
+ const scalar_t lw = w - w_low;
+ const scalar_t hh = 1 - lh, hw = 1 - lw;
+
+ const int w_stride = nheads * channels;
+ const int h_stride = width * w_stride;
+ const int h_low_ptr_offset = h_low * h_stride;
+ const int h_high_ptr_offset = h_low_ptr_offset + h_stride;
+ const int w_low_ptr_offset = w_low * w_stride;
+ const int w_high_ptr_offset = w_low_ptr_offset + w_stride;
+ const int base_ptr = m * channels + c;
+
+ const scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;
+ const scalar_t top_grad_value = top_grad * attn_weight;
+ scalar_t grad_h_weight = 0, grad_w_weight = 0;
+
+ scalar_t v1 = 0;
+ if (h_low >= 0 && w_low >= 0)
+ {
+ const int ptr1 = h_low_ptr_offset + w_low_ptr_offset + base_ptr;
+ v1 = bottom_data[ptr1];
+ grad_h_weight -= hw * v1;
+ grad_w_weight -= hh * v1;
+ atomicAdd(grad_value+ptr1, w1*top_grad_value);
+ }
+ scalar_t v2 = 0;
+ if (h_low >= 0 && w_high <= width - 1)
+ {
+ const int ptr2 = h_low_ptr_offset + w_high_ptr_offset + base_ptr;
+ v2 = bottom_data[ptr2];
+ grad_h_weight -= lw * v2;
+ grad_w_weight += hh * v2;
+ atomicAdd(grad_value+ptr2, w2*top_grad_value);
+ }
+ scalar_t v3 = 0;
+ if (h_high <= height - 1 && w_low >= 0)
+ {
+ const int ptr3 = h_high_ptr_offset + w_low_ptr_offset + base_ptr;
+ v3 = bottom_data[ptr3];
+ grad_h_weight += hw * v3;
+ grad_w_weight -= lh * v3;
+ atomicAdd(grad_value+ptr3, w3*top_grad_value);
+ }
+ scalar_t v4 = 0;
+ if (h_high <= height - 1 && w_high <= width - 1)
+ {
+ const int ptr4 = h_high_ptr_offset + w_high_ptr_offset + base_ptr;
+ v4 = bottom_data[ptr4];
+ grad_h_weight += lw * v4;
+ grad_w_weight += lh * v4;
+ atomicAdd(grad_value+ptr4, w4*top_grad_value);
+ }
+
+ const scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
+ *grad_attn_weight = top_grad * val;
+ *grad_sampling_loc = width * grad_w_weight * top_grad_value;
+ *(grad_sampling_loc + 1) = height * grad_h_weight * top_grad_value;
+}
+
+
+template <typename scalar_t>
+__device__ void ms_deform_attn_col2im_bilinear_gm(const scalar_t* &bottom_data,
+ const int &height, const int &width, const int &nheads, const int &channels,
+ const scalar_t &h, const scalar_t &w, const int &m, const int &c,
+ const scalar_t &top_grad,
+ const scalar_t &attn_weight,
+ scalar_t* &grad_value,
+ scalar_t* grad_sampling_loc,
+ scalar_t* grad_attn_weight)
+{
+ const int h_low = floor(h);
+ const int w_low = floor(w);
+ const int h_high = h_low + 1;
+ const int w_high = w_low + 1;
+
+ const scalar_t lh = h - h_low;
+ const scalar_t lw = w - w_low;
+ const scalar_t hh = 1 - lh, hw = 1 - lw;
+
+ const int w_stride = nheads * channels;
+ const int h_stride = width * w_stride;
+ const int h_low_ptr_offset = h_low * h_stride;
+ const int h_high_ptr_offset = h_low_ptr_offset + h_stride;
+ const int w_low_ptr_offset = w_low * w_stride;
+ const int w_high_ptr_offset = w_low_ptr_offset + w_stride;
+ const int base_ptr = m * channels + c;
+
+ const scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;
+ const scalar_t top_grad_value = top_grad * attn_weight;
+ scalar_t grad_h_weight = 0, grad_w_weight = 0;
+
+ scalar_t v1 = 0;
+ if (h_low >= 0 && w_low >= 0)
+ {
+ const int ptr1 = h_low_ptr_offset + w_low_ptr_offset + base_ptr;
+ v1 = bottom_data[ptr1];
+ grad_h_weight -= hw * v1;
+ grad_w_weight -= hh * v1;
+ atomicAdd(grad_value+ptr1, w1*top_grad_value);
+ }
+ scalar_t v2 = 0;
+ if (h_low >= 0 && w_high <= width - 1)
+ {
+ const int ptr2 = h_low_ptr_offset + w_high_ptr_offset + base_ptr;
+ v2 = bottom_data[ptr2];
+ grad_h_weight -= lw * v2;
+ grad_w_weight += hh * v2;
+ atomicAdd(grad_value+ptr2, w2*top_grad_value);
+ }
+ scalar_t v3 = 0;
+ if (h_high <= height - 1 && w_low >= 0)
+ {
+ const int ptr3 = h_high_ptr_offset + w_low_ptr_offset + base_ptr;
+ v3 = bottom_data[ptr3];
+ grad_h_weight += hw * v3;
+ grad_w_weight -= lh * v3;
+ atomicAdd(grad_value+ptr3, w3*top_grad_value);
+ }
+ scalar_t v4 = 0;
+ if (h_high <= height - 1 && w_high <= width - 1)
+ {
+ const int ptr4 = h_high_ptr_offset + w_high_ptr_offset + base_ptr;
+ v4 = bottom_data[ptr4];
+ grad_h_weight += lw * v4;
+ grad_w_weight += lh * v4;
+ atomicAdd(grad_value+ptr4, w4*top_grad_value);
+ }
+
+ const scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
+ atomicAdd(grad_attn_weight, top_grad * val);
+ atomicAdd(grad_sampling_loc, width * grad_w_weight * top_grad_value);
+ atomicAdd(grad_sampling_loc + 1, height * grad_h_weight * top_grad_value);
+}
+
+
+template <typename scalar_t>
+__global__ void ms_deformable_im2col_gpu_kernel(const int n,
+ const scalar_t *data_value,
+ const int64_t *data_spatial_shapes,
+ const int64_t *data_level_start_index,
+ const scalar_t *data_sampling_loc,
+ const scalar_t *data_attn_weight,
+ const int batch_size,
+ const int spatial_size,
+ const int num_heads,
+ const int channels,
+ const int num_levels,
+ const int num_query,
+ const int num_point,
+ scalar_t *data_col)
+{
+ CUDA_KERNEL_LOOP(index, n)
+ {
+ int _temp = index;
+ const int c_col = _temp % channels;
+ _temp /= channels;
+ const int sampling_index = _temp;
+ const int m_col = _temp % num_heads;
+ _temp /= num_heads;
+ const int q_col = _temp % num_query;
+ _temp /= num_query;
+ const int b_col = _temp;
+
+ scalar_t *data_col_ptr = data_col + index;
+ int data_weight_ptr = sampling_index * num_levels * num_point;
+ int data_loc_w_ptr = data_weight_ptr << 1;
+ const int qid_stride = num_heads * channels;
+ const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
+ scalar_t col = 0;
+
+ for (int l_col=0; l_col < num_levels; ++l_col)
+ {
+ const int level_start_id = data_level_start_index[l_col];
+ const int spatial_h_ptr = l_col << 1;
+ const int spatial_h = data_spatial_shapes[spatial_h_ptr];
+ const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
+ const scalar_t *data_value_ptr = data_value + (data_value_ptr_init_offset + level_start_id * qid_stride);
+ for (int p_col=0; p_col < num_point; ++p_col)
+ {
+ const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
+ const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
+ const scalar_t weight = data_attn_weight[data_weight_ptr];
+
+ const scalar_t h_im = loc_h * spatial_h - 0.5;
+ const scalar_t w_im = loc_w * spatial_w - 0.5;
+
+ if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
+ {
+ col += ms_deform_attn_im2col_bilinear(data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col) * weight;
+ }
+
+ data_weight_ptr += 1;
+ data_loc_w_ptr += 2;
+ }
+ }
+ *data_col_ptr = col;
+ }
+}
+
+template <typename scalar_t, unsigned int blockSize>
+__global__ void ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1(const int n,
+ const scalar_t *grad_col,
+ const scalar_t *data_value,
+ const int64_t *data_spatial_shapes,
+ const int64_t *data_level_start_index,
+ const scalar_t *data_sampling_loc,
+ const scalar_t *data_attn_weight,
+ const int batch_size,
+ const int spatial_size,
+ const int num_heads,
+ const int channels,
+ const int num_levels,
+ const int num_query,
+ const int num_point,
+ scalar_t *grad_value,
+ scalar_t *grad_sampling_loc,
+ scalar_t *grad_attn_weight)
+{
+ CUDA_KERNEL_LOOP(index, n)
+ {
+ __shared__ scalar_t cache_grad_sampling_loc[blockSize * 2];
+ __shared__ scalar_t cache_grad_attn_weight[blockSize];
+ unsigned int tid = threadIdx.x;
+ int _temp = index;
+ const int c_col = _temp % channels;
+ _temp /= channels;
+ const int sampling_index = _temp;
+ const int m_col = _temp % num_heads;
+ _temp /= num_heads;
+ const int q_col = _temp % num_query;
+ _temp /= num_query;
+ const int b_col = _temp;
+
+ const scalar_t top_grad = grad_col[index];
+
+ int data_weight_ptr = sampling_index * num_levels * num_point;
+ int data_loc_w_ptr = data_weight_ptr << 1;
+ const int grad_sampling_ptr = data_weight_ptr;
+ grad_sampling_loc += grad_sampling_ptr << 1;
+ grad_attn_weight += grad_sampling_ptr;
+ const int grad_weight_stride = 1;
+ const int grad_loc_stride = 2;
+ const int qid_stride = num_heads * channels;
+ const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
+
+ for (int l_col=0; l_col < num_levels; ++l_col)
+ {
+ const int level_start_id = data_level_start_index[l_col];
+ const int spatial_h_ptr = l_col << 1;
+ const int spatial_h = data_spatial_shapes[spatial_h_ptr];
+ const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
+ const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
+ const scalar_t *data_value_ptr = data_value + value_ptr_offset;
+ scalar_t *grad_value_ptr = grad_value + value_ptr_offset;
+
+ for (int p_col=0; p_col < num_point; ++p_col)
+ {
+ const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
+ const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
+ const scalar_t weight = data_attn_weight[data_weight_ptr];
+
+ const scalar_t h_im = loc_h * spatial_h - 0.5;
+ const scalar_t w_im = loc_w * spatial_w - 0.5;
+ *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0;
+ *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0;
+ *(cache_grad_attn_weight+threadIdx.x)=0;
+ if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
+ {
+ ms_deform_attn_col2im_bilinear(
+ data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
+ top_grad, weight, grad_value_ptr,
+ cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x);
+ }
+
+ __syncthreads();
+ if (tid == 0)
+ {
+ scalar_t _grad_w=cache_grad_sampling_loc[0], _grad_h=cache_grad_sampling_loc[1], _grad_a=cache_grad_attn_weight[0];
+ int sid=2;
+ for (unsigned int tid = 1; tid < blockSize; ++tid)
+ {
+ _grad_w += cache_grad_sampling_loc[sid];
+ _grad_h += cache_grad_sampling_loc[sid + 1];
+ _grad_a += cache_grad_attn_weight[tid];
+ sid += 2;
+ }
+
+
+ *grad_sampling_loc = _grad_w;
+ *(grad_sampling_loc + 1) = _grad_h;
+ *grad_attn_weight = _grad_a;
+ }
+ __syncthreads();
+
+ data_weight_ptr += 1;
+ data_loc_w_ptr += 2;
+ grad_attn_weight += grad_weight_stride;
+ grad_sampling_loc += grad_loc_stride;
+ }
+ }
+ }
+}
+
+
+template <typename scalar_t, unsigned int blockSize>
+__global__ void ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2(const int n,
+ const scalar_t *grad_col,
+ const scalar_t *data_value,
+ const int64_t *data_spatial_shapes,
+ const int64_t *data_level_start_index,
+ const scalar_t *data_sampling_loc,
+ const scalar_t *data_attn_weight,
+ const int batch_size,
+ const int spatial_size,
+ const int num_heads,
+ const int channels,
+ const int num_levels,
+ const int num_query,
+ const int num_point,
+ scalar_t *grad_value,
+ scalar_t *grad_sampling_loc,
+ scalar_t *grad_attn_weight)
+{
+ CUDA_KERNEL_LOOP(index, n)
+ {
+ __shared__ scalar_t cache_grad_sampling_loc[blockSize * 2];
+ __shared__ scalar_t cache_grad_attn_weight[blockSize];
+ unsigned int tid = threadIdx.x;
+ int _temp = index;
+ const int c_col = _temp % channels;
+ _temp /= channels;
+ const int sampling_index = _temp;
+ const int m_col = _temp % num_heads;
+ _temp /= num_heads;
+ const int q_col = _temp % num_query;
+ _temp /= num_query;
+ const int b_col = _temp;
+
+ const scalar_t top_grad = grad_col[index];
+
+ int data_weight_ptr = sampling_index * num_levels * num_point;
+ int data_loc_w_ptr = data_weight_ptr << 1;
+ const int grad_sampling_ptr = data_weight_ptr;
+ grad_sampling_loc += grad_sampling_ptr << 1;
+ grad_attn_weight += grad_sampling_ptr;
+ const int grad_weight_stride = 1;
+ const int grad_loc_stride = 2;
+ const int qid_stride = num_heads * channels;
+ const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
+
+ for (int l_col=0; l_col < num_levels; ++l_col)
+ {
+ const int level_start_id = data_level_start_index[l_col];
+ const int spatial_h_ptr = l_col << 1;
+ const int spatial_h = data_spatial_shapes[spatial_h_ptr];
+ const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
+ const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
+ const scalar_t *data_value_ptr = data_value + value_ptr_offset;
+ scalar_t *grad_value_ptr = grad_value + value_ptr_offset;
+
+ for (int p_col=0; p_col < num_point; ++p_col)
+ {
+ const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
+ const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
+ const scalar_t weight = data_attn_weight[data_weight_ptr];
+
+ const scalar_t h_im = loc_h * spatial_h - 0.5;
+ const scalar_t w_im = loc_w * spatial_w - 0.5;
+ *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0;
+ *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0;
+ *(cache_grad_attn_weight+threadIdx.x)=0;
+ if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
+ {
+ ms_deform_attn_col2im_bilinear(
+ data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
+ top_grad, weight, grad_value_ptr,
+ cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x);
+ }
+
+ __syncthreads();
+
+ for (unsigned int s=blockSize/2; s>0; s>>=1)
+ {
+ if (tid < s) {
+ const unsigned int xid1 = tid << 1;
+ const unsigned int xid2 = (tid + s) << 1;
+ cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + s];
+ cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2];
+ cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1];
+ }
+ __syncthreads();
+ }
+
+ if (tid == 0)
+ {
+ *grad_sampling_loc = cache_grad_sampling_loc[0];
+ *(grad_sampling_loc + 1) = cache_grad_sampling_loc[1];
+ *grad_attn_weight = cache_grad_attn_weight[0];
+ }
+ __syncthreads();
+
+ data_weight_ptr += 1;
+ data_loc_w_ptr += 2;
+ grad_attn_weight += grad_weight_stride;
+ grad_sampling_loc += grad_loc_stride;
+ }
+ }
+ }
+}
+
+
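+// The _shm_reduce_* kernels below use extern shared memory sized at launch time rather
+// than a compile-time blockSize array, so they can serve the channel counts that are not
+// covered by the power-of-two cases in ms_deformable_col2im_cuda.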
+template <typename scalar_t>
+__global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v1(const int n,
+ const scalar_t *grad_col,
+ const scalar_t *data_value,
+ const int64_t *data_spatial_shapes,
+ const int64_t *data_level_start_index,
+ const scalar_t *data_sampling_loc,
+ const scalar_t *data_attn_weight,
+ const int batch_size,
+ const int spatial_size,
+ const int num_heads,
+ const int channels,
+ const int num_levels,
+ const int num_query,
+ const int num_point,
+ scalar_t *grad_value,
+ scalar_t *grad_sampling_loc,
+ scalar_t *grad_attn_weight)
+{
+ CUDA_KERNEL_LOOP(index, n)
+ {
+ extern __shared__ int _s[];
+ scalar_t* cache_grad_sampling_loc = (scalar_t*)_s;
+ scalar_t* cache_grad_attn_weight = cache_grad_sampling_loc + 2 * blockDim.x;
+ unsigned int tid = threadIdx.x;
+ int _temp = index;
+ const int c_col = _temp % channels;
+ _temp /= channels;
+ const int sampling_index = _temp;
+ const int m_col = _temp % num_heads;
+ _temp /= num_heads;
+ const int q_col = _temp % num_query;
+ _temp /= num_query;
+ const int b_col = _temp;
+
+ const scalar_t top_grad = grad_col[index];
+
+ int data_weight_ptr = sampling_index * num_levels * num_point;
+ int data_loc_w_ptr = data_weight_ptr << 1;
+ const int grad_sampling_ptr = data_weight_ptr;
+ grad_sampling_loc += grad_sampling_ptr << 1;
+ grad_attn_weight += grad_sampling_ptr;
+ const int grad_weight_stride = 1;
+ const int grad_loc_stride = 2;
+ const int qid_stride = num_heads * channels;
+ const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
+
+ for (int l_col=0; l_col < num_levels; ++l_col)
+ {
+ const int level_start_id = data_level_start_index[l_col];
+ const int spatial_h_ptr = l_col << 1;
+ const int spatial_h = data_spatial_shapes[spatial_h_ptr];
+ const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
+ const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
+ const scalar_t *data_value_ptr = data_value + value_ptr_offset;
+ scalar_t *grad_value_ptr = grad_value + value_ptr_offset;
+
+ for (int p_col=0; p_col < num_point; ++p_col)
+ {
+ const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
+ const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
+ const scalar_t weight = data_attn_weight[data_weight_ptr];
+
+ const scalar_t h_im = loc_h * spatial_h - 0.5;
+ const scalar_t w_im = loc_w * spatial_w - 0.5;
+ *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0;
+ *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0;
+ *(cache_grad_attn_weight+threadIdx.x)=0;
+ if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
+ {
+ ms_deform_attn_col2im_bilinear(
+ data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
+ top_grad, weight, grad_value_ptr,
+ cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x);
+ }
+
+ __syncthreads();
+ if (tid == 0)
+ {
+ scalar_t _grad_w=cache_grad_sampling_loc[0], _grad_h=cache_grad_sampling_loc[1], _grad_a=cache_grad_attn_weight[0];
+ int sid=2;
+ for (unsigned int tid = 1; tid < blockDim.x; ++tid)
+ {
+ _grad_w += cache_grad_sampling_loc[sid];
+ _grad_h += cache_grad_sampling_loc[sid + 1];
+ _grad_a += cache_grad_attn_weight[tid];
+ sid += 2;
+ }
+
+
+ *grad_sampling_loc = _grad_w;
+ *(grad_sampling_loc + 1) = _grad_h;
+ *grad_attn_weight = _grad_a;
+ }
+ __syncthreads();
+
+ data_weight_ptr += 1;
+ data_loc_w_ptr += 2;
+ grad_attn_weight += grad_weight_stride;
+ grad_sampling_loc += grad_loc_stride;
+ }
+ }
+ }
+}
+
+template <typename scalar_t>
+__global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v2(const int n,
+ const scalar_t *grad_col,
+ const scalar_t *data_value,
+ const int64_t *data_spatial_shapes,
+ const int64_t *data_level_start_index,
+ const scalar_t *data_sampling_loc,
+ const scalar_t *data_attn_weight,
+ const int batch_size,
+ const int spatial_size,
+ const int num_heads,
+ const int channels,
+ const int num_levels,
+ const int num_query,
+ const int num_point,
+ scalar_t *grad_value,
+ scalar_t *grad_sampling_loc,
+ scalar_t *grad_attn_weight)
+{
+ CUDA_KERNEL_LOOP(index, n)
+ {
+ extern __shared__ int _s[];
+ scalar_t* cache_grad_sampling_loc = (scalar_t*)_s;
+ scalar_t* cache_grad_attn_weight = cache_grad_sampling_loc + 2 * blockDim.x;
+ unsigned int tid = threadIdx.x;
+ int _temp = index;
+ const int c_col = _temp % channels;
+ _temp /= channels;
+ const int sampling_index = _temp;
+ const int m_col = _temp % num_heads;
+ _temp /= num_heads;
+ const int q_col = _temp % num_query;
+ _temp /= num_query;
+ const int b_col = _temp;
+
+ const scalar_t top_grad = grad_col[index];
+
+ int data_weight_ptr = sampling_index * num_levels * num_point;
+ int data_loc_w_ptr = data_weight_ptr << 1;
+ const int grad_sampling_ptr = data_weight_ptr;
+ grad_sampling_loc += grad_sampling_ptr << 1;
+ grad_attn_weight += grad_sampling_ptr;
+ const int grad_weight_stride = 1;
+ const int grad_loc_stride = 2;
+ const int qid_stride = num_heads * channels;
+ const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
+
+ for (int l_col=0; l_col < num_levels; ++l_col)
+ {
+ const int level_start_id = data_level_start_index[l_col];
+ const int spatial_h_ptr = l_col << 1;
+ const int spatial_h = data_spatial_shapes[spatial_h_ptr];
+ const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
+ const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
+ const scalar_t *data_value_ptr = data_value + value_ptr_offset;
+ scalar_t *grad_value_ptr = grad_value + value_ptr_offset;
+
+ for (int p_col=0; p_col < num_point; ++p_col)
+ {
+ const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
+ const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
+ const scalar_t weight = data_attn_weight[data_weight_ptr];
+
+ const scalar_t h_im = loc_h * spatial_h - 0.5;
+ const scalar_t w_im = loc_w * spatial_w - 0.5;
+ *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0;
+ *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0;
+ *(cache_grad_attn_weight+threadIdx.x)=0;
+ if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
+ {
+ ms_deform_attn_col2im_bilinear(
+ data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
+ top_grad, weight, grad_value_ptr,
+ cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x);
+ }
+
+ __syncthreads();
+
+ for (unsigned int s=blockDim.x/2, spre=blockDim.x; s>0; s>>=1, spre>>=1)
+ {
+ if (tid < s) {
+ const unsigned int xid1 = tid << 1;
+ const unsigned int xid2 = (tid + s) << 1;
+ cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + s];
+ cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2];
+ cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1];
+ if (tid + (s << 1) < spre)
+ {
+ cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + (s << 1)];
+ cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2 + (s << 1)];
+ cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1 + (s << 1)];
+ }
+ }
+ __syncthreads();
+ }
+
+ if (tid == 0)
+ {
+ *grad_sampling_loc = cache_grad_sampling_loc[0];
+ *(grad_sampling_loc + 1) = cache_grad_sampling_loc[1];
+ *grad_attn_weight = cache_grad_attn_weight[0];
+ }
+ __syncthreads();
+
+ data_weight_ptr += 1;
+ data_loc_w_ptr += 2;
+ grad_attn_weight += grad_weight_stride;
+ grad_sampling_loc += grad_loc_stride;
+ }
+ }
+ }
+}
+
+template <typename scalar_t>
+__global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v2_multi_blocks(const int n,
+ const scalar_t *grad_col,
+ const scalar_t *data_value,
+ const int64_t *data_spatial_shapes,
+ const int64_t *data_level_start_index,
+ const scalar_t *data_sampling_loc,
+ const scalar_t *data_attn_weight,
+ const int batch_size,
+ const int spatial_size,
+ const int num_heads,
+ const int channels,
+ const int num_levels,
+ const int num_query,
+ const int num_point,
+ scalar_t *grad_value,
+ scalar_t *grad_sampling_loc,
+ scalar_t *grad_attn_weight)
+{
+ CUDA_KERNEL_LOOP(index, n)
+ {
+ extern __shared__ int _s[];
+ scalar_t* cache_grad_sampling_loc = (scalar_t*)_s;
+ scalar_t* cache_grad_attn_weight = cache_grad_sampling_loc + 2 * blockDim.x;
+ unsigned int tid = threadIdx.x;
+ int _temp = index;
+ const int c_col = _temp % channels;
+ _temp /= channels;
+ const int sampling_index = _temp;
+ const int m_col = _temp % num_heads;
+ _temp /= num_heads;
+ const int q_col = _temp % num_query;
+ _temp /= num_query;
+ const int b_col = _temp;
+
+ const scalar_t top_grad = grad_col[index];
+
+ int data_weight_ptr = sampling_index * num_levels * num_point;
+ int data_loc_w_ptr = data_weight_ptr << 1;
+ const int grad_sampling_ptr = data_weight_ptr;
+ grad_sampling_loc += grad_sampling_ptr << 1;
+ grad_attn_weight += grad_sampling_ptr;
+ const int grad_weight_stride = 1;
+ const int grad_loc_stride = 2;
+ const int qid_stride = num_heads * channels;
+ const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
+
+ for (int l_col=0; l_col < num_levels; ++l_col)
+ {
+ const int level_start_id = data_level_start_index[l_col];
+ const int spatial_h_ptr = l_col << 1;
+ const int spatial_h = data_spatial_shapes[spatial_h_ptr];
+ const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
+ const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
+ const scalar_t *data_value_ptr = data_value + value_ptr_offset;
+ scalar_t *grad_value_ptr = grad_value + value_ptr_offset;
+
+ for (int p_col=0; p_col < num_point; ++p_col)
+ {
+ const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
+ const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
+ const scalar_t weight = data_attn_weight[data_weight_ptr];
+
+ const scalar_t h_im = loc_h * spatial_h - 0.5;
+ const scalar_t w_im = loc_w * spatial_w - 0.5;
+ *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0;
+ *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0;
+ *(cache_grad_attn_weight+threadIdx.x)=0;
+ if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
+ {
+ ms_deform_attn_col2im_bilinear(
+ data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
+ top_grad, weight, grad_value_ptr,
+ cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x);
+ }
+
+ __syncthreads();
+
+ for (unsigned int s=blockDim.x/2, spre=blockDim.x; s>0; s>>=1, spre>>=1)
+ {
+ if (tid < s) {
+ const unsigned int xid1 = tid << 1;
+ const unsigned int xid2 = (tid + s) << 1;
+ cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + s];
+ cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2];
+ cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1];
+ if (tid + (s << 1) < spre)
+ {
+ cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + (s << 1)];
+ cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2 + (s << 1)];
+ cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1 + (s << 1)];
+ }
+ }
+ __syncthreads();
+ }
+
+ if (tid == 0)
+ {
+ atomicAdd(grad_sampling_loc, cache_grad_sampling_loc[0]);
+ atomicAdd(grad_sampling_loc + 1, cache_grad_sampling_loc[1]);
+ atomicAdd(grad_attn_weight, cache_grad_attn_weight[0]);
+ }
+ __syncthreads();
+
+ data_weight_ptr += 1;
+ data_loc_w_ptr += 2;
+ grad_attn_weight += grad_weight_stride;
+ grad_sampling_loc += grad_loc_stride;
+ }
+ }
+ }
+}
+
+
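+// Fallback kernel: no shared-memory staging at all. Each thread accumulates its
+// sampling-location and attention-weight gradients directly in global memory through
+// ms_deform_attn_col2im_bilinear_gm; the dispatcher below selects it when channels > 1024
+// and not a multiple of 1024.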
+template <typename scalar_t>
+__global__ void ms_deformable_col2im_gpu_kernel_gm(const int n,
+ const scalar_t *grad_col,
+ const scalar_t *data_value,
+ const int64_t *data_spatial_shapes,
+ const int64_t *data_level_start_index,
+ const scalar_t *data_sampling_loc,
+ const scalar_t *data_attn_weight,
+ const int batch_size,
+ const int spatial_size,
+ const int num_heads,
+ const int channels,
+ const int num_levels,
+ const int num_query,
+ const int num_point,
+ scalar_t *grad_value,
+ scalar_t *grad_sampling_loc,
+ scalar_t *grad_attn_weight)
+{
+ CUDA_KERNEL_LOOP(index, n)
+ {
+ int _temp = index;
+ const int c_col = _temp % channels;
+ _temp /= channels;
+ const int sampling_index = _temp;
+ const int m_col = _temp % num_heads;
+ _temp /= num_heads;
+ const int q_col = _temp % num_query;
+ _temp /= num_query;
+ const int b_col = _temp;
+
+ const scalar_t top_grad = grad_col[index];
+
+ int data_weight_ptr = sampling_index * num_levels * num_point;
+ int data_loc_w_ptr = data_weight_ptr << 1;
+ const int grad_sampling_ptr = data_weight_ptr;
+ grad_sampling_loc += grad_sampling_ptr << 1;
+ grad_attn_weight += grad_sampling_ptr;
+ const int grad_weight_stride = 1;
+ const int grad_loc_stride = 2;
+ const int qid_stride = num_heads * channels;
+ const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
+
+ for (int l_col=0; l_col < num_levels; ++l_col)
+ {
+ const int level_start_id = data_level_start_index[l_col];
+ const int spatial_h_ptr = l_col << 1;
+ const int spatial_h = data_spatial_shapes[spatial_h_ptr];
+ const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
+ const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
+ const scalar_t *data_value_ptr = data_value + value_ptr_offset;
+ scalar_t *grad_value_ptr = grad_value + value_ptr_offset;
+
+ for (int p_col=0; p_col < num_point; ++p_col)
+ {
+ const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
+ const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
+ const scalar_t weight = data_attn_weight[data_weight_ptr];
+
+ const scalar_t h_im = loc_h * spatial_h - 0.5;
+ const scalar_t w_im = loc_w * spatial_w - 0.5;
+ if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
+ {
+ ms_deform_attn_col2im_bilinear_gm(
+ data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
+ top_grad, weight, grad_value_ptr,
+ grad_sampling_loc, grad_attn_weight);
+ }
+ data_weight_ptr += 1;
+ data_loc_w_ptr += 2;
+ grad_attn_weight += grad_weight_stride;
+ grad_sampling_loc += grad_loc_stride;
+ }
+ }
+ }
+}
+
+
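+// Host-side launcher for the forward pass: one CUDA thread per output element of data_col,
+// i.e. batch_size * num_query * num_heads * channels threads in total.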
+template <typename scalar_t>
+void ms_deformable_im2col_cuda(cudaStream_t stream,
+ const scalar_t* data_value,
+ const int64_t* data_spatial_shapes,
+ const int64_t* data_level_start_index,
+ const scalar_t* data_sampling_loc,
+ const scalar_t* data_attn_weight,
+ const int batch_size,
+ const int spatial_size,
+ const int num_heads,
+ const int channels,
+ const int num_levels,
+ const int num_query,
+ const int num_point,
+ scalar_t* data_col)
+{
+ const int num_kernels = batch_size * num_query * num_heads * channels;
+ const int num_actual_kernels = batch_size * num_query * num_heads * channels;
+ const int num_threads = CUDA_NUM_THREADS;
+ ms_deformable_im2col_gpu_kernel<scalar_t>
+ <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads, 0, stream>>>(
+ num_kernels, data_value, data_spatial_shapes, data_level_start_index, data_sampling_loc, data_attn_weight,
+ batch_size, spatial_size, num_heads, channels, num_levels, num_query, num_point, data_col);
+
+ cudaError_t err = cudaGetLastError();
+ if (err != cudaSuccess)
+ {
+ printf("error in ms_deformable_im2col_cuda: %s\n", cudaGetErrorString(err));
+ }
+
+}
+
+template <typename scalar_t>
+void ms_deformable_col2im_cuda(cudaStream_t stream,
+ const scalar_t* grad_col,
+ const scalar_t* data_value,
+ const int64_t * data_spatial_shapes,
+ const int64_t * data_level_start_index,
+ const scalar_t * data_sampling_loc,
+ const scalar_t * data_attn_weight,
+ const int batch_size,
+ const int spatial_size,
+ const int num_heads,
+ const int channels,
+ const int num_levels,
+ const int num_query,
+ const int num_point,
+ scalar_t* grad_value,
+ scalar_t* grad_sampling_loc,
+ scalar_t* grad_attn_weight)
+{
+ const int num_threads = (channels > CUDA_NUM_THREADS)?CUDA_NUM_THREADS:channels;
+ const int num_kernels = batch_size * num_query * num_heads * channels;
+ const int num_actual_kernels = batch_size * num_query * num_heads * channels;
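+ // Dispatch on the channel count: power-of-two channels up to 1024 select a kernel with a
+ // matching compile-time blockSize; other channels <= 1024 fall back to the dynamic
+ // shared-memory reductions; channels > 1024 use the multi-block reduction when divisible
+ // by 1024 and the plain global-memory kernel otherwise.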
+ if (channels > 1024)
+ {
+ if ((channels & 1023) == 0)
+ {
+ ms_deformable_col2im_gpu_kernel_shm_reduce_v2_multi_blocks<scalar_t>
+ <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads, num_threads*3*sizeof(scalar_t), stream>>>(
+ num_kernels,
+ grad_col,
+ data_value,
+ data_spatial_shapes,
+ data_level_start_index,
+ data_sampling_loc,
+ data_attn_weight,
+ batch_size,
+ spatial_size,
+ num_heads,
+ channels,
+ num_levels,
+ num_query,
+ num_point,
+ grad_value,
+ grad_sampling_loc,
+ grad_attn_weight);
+ }
+ else
+ {
+ ms_deformable_col2im_gpu_kernel_gm<scalar_t>
+ <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads, 0, stream>>>(
+ num_kernels,
+ grad_col,
+ data_value,
+ data_spatial_shapes,
+ data_level_start_index,
+ data_sampling_loc,
+ data_attn_weight,
+ batch_size,
+ spatial_size,
+ num_heads,
+ channels,
+ num_levels,
+ num_query,
+ num_point,
+ grad_value,
+ grad_sampling_loc,
+ grad_attn_weight);
+ }
+ }
+ else{
+ switch(channels)
+ {
+ case 1:
+ ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1<scalar_t, 1>
+ <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads, 0, stream>>>(
+ num_kernels,
+ grad_col,
+ data_value,
+ data_spatial_shapes,
+ data_level_start_index,
+ data_sampling_loc,
+ data_attn_weight,
+ batch_size,
+ spatial_size,
+ num_heads,
+ channels,
+ num_levels,
+ num_query,
+ num_point,
+ grad_value,
+ grad_sampling_loc,
+ grad_attn_weight);
+ break;
+ case 2:
+ ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1<scalar_t, 2>
+ <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads, 0, stream>>>(
+ num_kernels,
+ grad_col,
+ data_value,
+ data_spatial_shapes,
+ data_level_start_index,
+ data_sampling_loc,
+ data_attn_weight,
+ batch_size,
+ spatial_size,
+ num_heads,
+ channels,
+ num_levels,
+ num_query,
+ num_point,
+ grad_value,
+ grad_sampling_loc,
+ grad_attn_weight);
+ break;
+ case 4:
+ ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1<scalar_t, 4>
+ <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads, 0, stream>>>(
+ num_kernels,
+ grad_col,
+ data_value,
+ data_spatial_shapes,
+ data_level_start_index,
+ data_sampling_loc,
+ data_attn_weight,
+ batch_size,
+ spatial_size,
+ num_heads,
+ channels,
+ num_levels,
+ num_query,
+ num_point,
+ grad_value,
+ grad_sampling_loc,
+ grad_attn_weight);
+ break;
+ case 8:
+ ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1<scalar_t, 8>
+ <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads, 0, stream>>>(
+ num_kernels,
+ grad_col,
+ data_value,
+ data_spatial_shapes,
+ data_level_start_index,
+ data_sampling_loc,
+ data_attn_weight,
+ batch_size,
+ spatial_size,
+ num_heads,
+ channels,
+ num_levels,
+ num_query,
+ num_point,
+ grad_value,
+ grad_sampling_loc,
+ grad_attn_weight);
+ break;
+ case 16:
+ ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1<scalar_t, 16>
+ <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads, 0, stream>>>(
+ num_kernels,
+ grad_col,
+ data_value,
+ data_spatial_shapes,
+ data_level_start_index,
+ data_sampling_loc,
+ data_attn_weight,
+ batch_size,
+ spatial_size,
+ num_heads,
+ channels,
+ num_levels,
+ num_query,
+ num_point,
+ grad_value,
+ grad_sampling_loc,
+ grad_attn_weight);
+ break;
+ case 32:
+ ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1<scalar_t, 32>
+ <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads, 0, stream>>>(
+ num_kernels,
+ grad_col,
+ data_value,
+ data_spatial_shapes,
+ data_level_start_index,
+ data_sampling_loc,
+ data_attn_weight,
+ batch_size,
+ spatial_size,
+ num_heads,
+ channels,
+ num_levels,
+ num_query,
+ num_point,
+ grad_value,
+ grad_sampling_loc,
+ grad_attn_weight);
+ break;
+ case 64:
+ ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2<scalar_t, 64>
+ <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads, 0, stream>>>(
+ num_kernels,
+ grad_col,
+ data_value,
+ data_spatial_shapes,
+ data_level_start_index,
+ data_sampling_loc,
+ data_attn_weight,
+ batch_size,
+ spatial_size,
+ num_heads,
+ channels,
+ num_levels,
+ num_query,
+ num_point,
+ grad_value,
+ grad_sampling_loc,
+ grad_attn_weight);
+ break;
+ case 128:
+ ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2<scalar_t, 128>
+ <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads, 0, stream>>>(
+ num_kernels,
+ grad_col,
+ data_value,
+ data_spatial_shapes,
+ data_level_start_index,
+ data_sampling_loc,
+ data_attn_weight,
+ batch_size,
+ spatial_size,
+ num_heads,
+ channels,
+ num_levels,
+ num_query,
+ num_point,
+ grad_value,
+ grad_sampling_loc,
+ grad_attn_weight);
+ break;
+ case 256:
+ ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2<scalar_t, 256>
+ <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads, 0, stream>>>(
+ num_kernels,
+ grad_col,
+ data_value,
+ data_spatial_shapes,
+ data_level_start_index,
+ data_sampling_loc,
+ data_attn_weight,
+ batch_size,
+ spatial_size,
+ num_heads,
+ channels,
+ num_levels,
+ num_query,
+ num_point,
+ grad_value,
+ grad_sampling_loc,
+ grad_attn_weight);
+ break;
+ case 512:
+ ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2<scalar_t, 512>
+ <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads, 0, stream>>>(
+ num_kernels,
+ grad_col,
+ data_value,
+ data_spatial_shapes,
+ data_level_start_index,
+ data_sampling_loc,
+ data_attn_weight,
+ batch_size,
+ spatial_size,
+ num_heads,
+ channels,
+ num_levels,
+ num_query,
+ num_point,
+ grad_value,
+ grad_sampling_loc,
+ grad_attn_weight);
+ break;
+ case 1024:
+ ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2<scalar_t, 1024>
+ <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads, 0, stream>>>(
+ num_kernels,
+ grad_col,
+ data_value,
+ data_spatial_shapes,
+ data_level_start_index,
+ data_sampling_loc,
+ data_attn_weight,
+ batch_size,
+ spatial_size,
+ num_heads,
+ channels,
+ num_levels,
+ num_query,
+ num_point,
+ grad_value,
+ grad_sampling_loc,
+ grad_attn_weight);
+ break;
+ default:
+ if (channels < 64)
+ {
+ ms_deformable_col2im_gpu_kernel_shm_reduce_v1<scalar_t>
+ <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads, num_threads*3*sizeof(scalar_t), stream>>>(
+ num_kernels,
+ grad_col,
+ data_value,
+ data_spatial_shapes,
+ data_level_start_index,
+ data_sampling_loc,
+ data_attn_weight,
+ batch_size,
+ spatial_size,
+ num_heads,
+ channels,
+ num_levels,
+ num_query,
+ num_point,
+ grad_value,
+ grad_sampling_loc,
+ grad_attn_weight);
+ }
+ else
+ {
+ ms_deformable_col2im_gpu_kernel_shm_reduce_v2<scalar_t>
+ <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads, num_threads*3*sizeof(scalar_t), stream>>>(
+ num_kernels,
+ grad_col,
+ data_value,
+ data_spatial_shapes,
+ data_level_start_index,
+ data_sampling_loc,
+ data_attn_weight,
+ batch_size,
+ spatial_size,
+ num_heads,
+ channels,
+ num_levels,
+ num_query,
+ num_point,
+ grad_value,
+ grad_sampling_loc,
+ grad_attn_weight);
+ }
+ }
+ }
+ cudaError_t err = cudaGetLastError();
+ if (err != cudaSuccess)
+ {
+ printf("error in ms_deformable_col2im_cuda: %s\n", cudaGetErrorString(err));
+ }
+
+}
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/kernels/deta/ms_deform_attn.h b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/kernels/deta/ms_deform_attn.h
new file mode 100644
index 0000000000000000000000000000000000000000..119b1fa317d1e5fcfb61a4837e560e9248db05f3
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/kernels/deta/ms_deform_attn.h
@@ -0,0 +1,61 @@
+/*!
+**************************************************************************************************
+* Deformable DETR
+* Copyright (c) 2020 SenseTime. All Rights Reserved.
+* Licensed under the Apache License, Version 2.0 [see LICENSE for details]
+**************************************************************************************************
+* Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
+**************************************************************************************************
+*/
+
+#pragma once
+
+#include "cpu/ms_deform_attn_cpu.h"
+
+#ifdef WITH_CUDA
+#include "cuda/ms_deform_attn_cuda.h"
+#endif
+
+
+at::Tensor
+ms_deform_attn_forward(
+ const at::Tensor &value,
+ const at::Tensor &spatial_shapes,
+ const at::Tensor &level_start_index,
+ const at::Tensor &sampling_loc,
+ const at::Tensor &attn_weight,
+ const int im2col_step)
+{
+ if (value.type().is_cuda())
+ {
+#ifdef WITH_CUDA
+ return ms_deform_attn_cuda_forward(
+ value, spatial_shapes, level_start_index, sampling_loc, attn_weight, im2col_step);
+#else
+ AT_ERROR("Not compiled with GPU support");
+#endif
+ }
+ AT_ERROR("Not implemented on the CPU");
+}
+
+std::vector<at::Tensor>
+ms_deform_attn_backward(
+ const at::Tensor &value,
+ const at::Tensor &spatial_shapes,
+ const at::Tensor &level_start_index,
+ const at::Tensor &sampling_loc,
+ const at::Tensor &attn_weight,
+ const at::Tensor &grad_output,
+ const int im2col_step)
+{
+ if (value.type().is_cuda())
+ {
+#ifdef WITH_CUDA
+ return ms_deform_attn_cuda_backward(
+ value, spatial_shapes, level_start_index, sampling_loc, attn_weight, grad_output, im2col_step);
+#else
+ AT_ERROR("Not compiled with GPU support");
+#endif
+ }
+ AT_ERROR("Not implemented on the CPU");
+}
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/kernels/deta/vision.cpp b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/kernels/deta/vision.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..6ce3875568b9ba8d660c90acc805077cca98f891
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/kernels/deta/vision.cpp
@@ -0,0 +1,16 @@
+/*!
+**************************************************************************************************
+* Deformable DETR
+* Copyright (c) 2020 SenseTime. All Rights Reserved.
+* Licensed under the Apache License, Version 2.0 [see LICENSE for details]
+**************************************************************************************************
+* Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
+**************************************************************************************************
+*/
+
+#include "ms_deform_attn.h"
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+ m.def("ms_deform_attn_forward", &ms_deform_attn_forward, "ms_deform_attn_forward");
+ m.def("ms_deform_attn_backward", &ms_deform_attn_backward, "ms_deform_attn_backward");
+}
\ No newline at end of file
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/kernels/falcon_mamba/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/kernels/falcon_mamba/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..da88e3394f653369a7443245c67dcbe57f2ed23e
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/kernels/falcon_mamba/__init__.py
@@ -0,0 +1,15 @@
+# coding=utf-8
+# Copyright 2024 Tri Dao, Albert Gu, Technological Innovation Institute and HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from .selective_scan_with_ln_interface import mamba_inner_fn
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/kernels/falcon_mamba/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/kernels/falcon_mamba/__pycache__/__init__.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..acdad26358d001b4d32326f70fec3696bf909fe9
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/kernels/falcon_mamba/__pycache__/__init__.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/kernels/falcon_mamba/__pycache__/selective_scan_with_ln_interface.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/kernels/falcon_mamba/__pycache__/selective_scan_with_ln_interface.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..6ab10aea9391675bf7c9acb13984e46d1f71c38f
Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/kernels/falcon_mamba/__pycache__/selective_scan_with_ln_interface.cpython-312.pyc differ
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/kernels/falcon_mamba/selective_scan_with_ln_interface.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/kernels/falcon_mamba/selective_scan_with_ln_interface.py
new file mode 100644
index 0000000000000000000000000000000000000000..4a74986a81a13f9428eab353de5b61a4d101972d
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/kernels/falcon_mamba/selective_scan_with_ln_interface.py
@@ -0,0 +1,525 @@
+# coding=utf-8
+# Copyright 2024 Tri Dao, Albert Gu, Technological Innovation Institute and HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# Original code from: https://github.com/state-spaces/mamba/blob/main/mamba_ssm/ops/selective_scan_interface.py
+
+import torch
+import torch.nn.functional as F
+from einops import rearrange, repeat
+from torch.cuda.amp import custom_bwd, custom_fwd
+
+
+try:
+ import causal_conv1d_cuda
+except ImportError:
+ causal_conv1d_cuda = None
+
+import mamba_ssm
+import selective_scan_cuda
+
+
+# For BC for old mamba-ssm versions: https://github.com/huggingface/transformers/pull/33195#discussion_r1736401127
+if hasattr(mamba_ssm.ops.triton, "layernorm"):
+ from mamba_ssm.ops.triton.layernorm import _layer_norm_fwd
+else:
+ from mamba_ssm.ops.triton.layer_norm import _layer_norm_fwd
+
+
+class SelectiveScanFn(torch.autograd.Function):
+ @staticmethod
+ def forward(
+ ctx, u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_softplus=False, return_last_state=False
+ ):
+ if u.stride(-1) != 1:
+ u = u.contiguous()
+ if delta.stride(-1) != 1:
+ delta = delta.contiguous()
+ if D is not None:
+ D = D.contiguous()
+ if B.stride(-1) != 1:
+ B = B.contiguous()
+ if C.stride(-1) != 1:
+ C = C.contiguous()
+ if z is not None and z.stride(-1) != 1:
+ z = z.contiguous()
+ if B.dim() == 3:
+ B = rearrange(B, "b dstate l -> b 1 dstate l")
+ ctx.squeeze_B = True
+ if C.dim() == 3:
+ C = rearrange(C, "b dstate l -> b 1 dstate l")
+ ctx.squeeze_C = True
+ out, x, *rest = selective_scan_cuda.fwd(u, delta, A, B, C, D, z, delta_bias, delta_softplus)
+ ctx.delta_softplus = delta_softplus
+ ctx.has_z = z is not None
+ last_state = x[:, :, -1, 1::2] # (batch, dim, dstate)
+ if not ctx.has_z:
+ ctx.save_for_backward(u, delta, A, B, C, D, delta_bias, x)
+ return out if not return_last_state else (out, last_state)
+ else:
+ ctx.save_for_backward(u, delta, A, B, C, D, z, delta_bias, x, out)
+ out_z = rest[0]
+ return out_z if not return_last_state else (out_z, last_state)
+
+ @staticmethod
+ def backward(ctx, dout, *args):
+ if not ctx.has_z:
+ u, delta, A, B, C, D, delta_bias, x = ctx.saved_tensors
+ z = None
+ out = None
+ else:
+ u, delta, A, B, C, D, z, delta_bias, x, out = ctx.saved_tensors
+ if dout.stride(-1) != 1:
+ dout = dout.contiguous()
+ # The kernel supports passing in a pre-allocated dz (e.g., in case we want to fuse the
+ # backward of selective_scan_cuda with the backward of chunk).
+ # Here we just pass in None and dz will be allocated in the C++ code.
+ du, ddelta, dA, dB, dC, dD, ddelta_bias, *rest = selective_scan_cuda.bwd(
+ u,
+ delta,
+ A,
+ B,
+ C,
+ D,
+ z,
+ delta_bias,
+ dout,
+ x,
+ out,
+ None,
+ ctx.delta_softplus,
+ False, # option to recompute out_z, not used here
+ )
+ dz = rest[0] if ctx.has_z else None
+ dB = dB.squeeze(1) if getattr(ctx, "squeeze_B", False) else dB
+ dC = dC.squeeze(1) if getattr(ctx, "squeeze_C", False) else dC
+ return (
+ du,
+ ddelta,
+ dA,
+ dB,
+ dC,
+ dD if D is not None else None,
+ dz,
+ ddelta_bias if delta_bias is not None else None,
+ None,
+ None,
+ )
+
+
+def rms_norm_forward(
+ x,
+ weight,
+ bias,
+ eps=1e-6,
+ is_rms_norm=True,
+):
+ # x (b l) d
+ if x.stride(-1) != 1:
+ x = x.contiguous()
+ weight = weight.contiguous()
+ if bias is not None:
+ bias = bias.contiguous()
+ y = _layer_norm_fwd(x, weight, bias, eps, None, residual_dtype=None, is_rms_norm=is_rms_norm)[0]
+ # y (b l) d
+ return y
+
+
+def selective_scan_fn(
+ u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_softplus=False, return_last_state=False
+):
+ """if return_last_state is True, returns (out, last_state)
+ last_state has shape (batch, dim, dstate). Note that the gradient of the last state is
+ not considered in the backward pass.
+ """
+ return SelectiveScanFn.apply(u, delta, A, B, C, D, z, delta_bias, delta_softplus, return_last_state)
+
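+# A hypothetical usage sketch (names and shapes are illustrative only; they follow the
+# selective_scan_ref docstring below and assume CUDA tensors plus the compiled
+# selective_scan_cuda extension):
+#   u, delta: (batch, dim, seqlen); A: (dim, dstate); B, C: (batch, dstate, seqlen); D: (dim,)
+#   out, last_state = selective_scan_fn(u, delta, A, B, C, D=D, delta_softplus=True,
+#                                       return_last_state=True)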
+
+def selective_scan_ref(
+ u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_softplus=False, return_last_state=False
+):
+ """
+ u: r(B D L)
+ delta: r(B D L)
+ A: c(D N) or r(D N)
+ B: c(D N) or r(B N L) or r(B N 2L) or r(B G N L) or (B G N L)
+ C: c(D N) or r(B N L) or r(B N 2L) or r(B G N L) or (B G N L)
+ D: r(D)
+ z: r(B D L)
+ delta_bias: r(D), fp32
+
+ out: r(B D L)
+ last_state (optional): r(B D dstate) or c(B D dstate)
+ """
+ dtype_in = u.dtype
+ u = u.float()
+ delta = delta.float()
+ if delta_bias is not None:
+ delta = delta + delta_bias[..., None].float()
+ if delta_softplus:
+ delta = F.softplus(delta)
+ batch, dim, dstate = u.shape[0], A.shape[0], A.shape[1]
+ is_variable_B = B.dim() >= 3
+ is_variable_C = C.dim() >= 3
+ if A.is_complex():
+ if is_variable_B:
+ B = torch.view_as_complex(rearrange(B.float(), "... (L two) -> ... L two", two=2))
+ if is_variable_C:
+ C = torch.view_as_complex(rearrange(C.float(), "... (L two) -> ... L two", two=2))
+ else:
+ B = B.float()
+ C = C.float()
+ x = A.new_zeros((batch, dim, dstate))
+ ys = []
+ deltaA = torch.exp(torch.einsum("bdl,dn->bdln", delta, A))
+ if not is_variable_B:
+ deltaB_u = torch.einsum("bdl,dn,bdl->bdln", delta, B, u)
+ else:
+ if B.dim() == 3:
+ deltaB_u = torch.einsum("bdl,bnl,bdl->bdln", delta, B, u)
+ else:
+ B = repeat(B, "B G N L -> B (G H) N L", H=dim // B.shape[1])
+ deltaB_u = torch.einsum("bdl,bdnl,bdl->bdln", delta, B, u)
+ if is_variable_C and C.dim() == 4:
+ C = repeat(C, "B G N L -> B (G H) N L", H=dim // C.shape[1])
+ last_state = None
+ for i in range(u.shape[2]):
+ x = deltaA[:, :, i] * x + deltaB_u[:, :, i]
+ if not is_variable_C:
+ y = torch.einsum("bdn,dn->bd", x, C)
+ else:
+ if C.dim() == 3:
+ y = torch.einsum("bdn,bn->bd", x, C[:, :, i])
+ else:
+ y = torch.einsum("bdn,bdn->bd", x, C[:, :, :, i])
+ if i == u.shape[2] - 1:
+ last_state = x
+ if y.is_complex():
+ y = y.real * 2
+ ys.append(y)
+ y = torch.stack(ys, dim=2) # (batch dim L)
+ out = y if D is None else y + u * rearrange(D, "d -> d 1")
+ if z is not None:
+ out = out * F.silu(z)
+ out = out.to(dtype=dtype_in)
+ return out if not return_last_state else (out, last_state)
+
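+# The loop in selective_scan_ref above is the reference (non-fused) form of the discretized
+# state-space recurrence computed by selective_scan_cuda.fwd:
+#   x_i = exp(delta_i * A) * x_{i-1} + (delta_i * B_i) * u_i,   y_i = <C_i, x_i> (+ D * u_i),
+# optionally gated by silu(z) at the end.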
+
+class MambaInnerFn(torch.autograd.Function):
+ @staticmethod
+ @custom_fwd
+ def forward(
+ ctx,
+ xz,
+ conv1d_weight,
+ conv1d_bias,
+ x_proj_weight,
+ delta_proj_weight,
+ out_proj_weight,
+ out_proj_bias,
+ A,
+ B=None,
+ C=None,
+ D=None,
+ delta_bias=None,
+ B_proj_bias=None,
+ C_proj_bias=None,
+ delta_softplus=True,
+ checkpoint_lvl=1,
+ b_rms_weight=None,
+ c_rms_weight=None,
+ dt_rms_weight=None,
+ b_c_dt_rms_eps=1e-6,
+ ):
+ """
+ xz: (batch, dim, seqlen)
+ """
+ assert causal_conv1d_cuda is not None, "causal_conv1d_cuda is not available. Please install causal-conv1d."
+ assert checkpoint_lvl in [0, 1]
+ L = xz.shape[-1]
+ delta_rank = delta_proj_weight.shape[1]
+ d_state = A.shape[-1] * (1 if not A.is_complex() else 2)
+ if torch.is_autocast_enabled():
+ x_proj_weight = x_proj_weight.to(dtype=torch.get_autocast_gpu_dtype())
+ delta_proj_weight = delta_proj_weight.to(dtype=torch.get_autocast_gpu_dtype())
+ out_proj_weight = out_proj_weight.to(dtype=torch.get_autocast_gpu_dtype())
+ out_proj_bias = (
+ out_proj_bias.to(dtype=torch.get_autocast_gpu_dtype()) if out_proj_bias is not None else None
+ )
+ if xz.stride(-1) != 1:
+ xz = xz.contiguous()
+ conv1d_weight = rearrange(conv1d_weight, "d 1 w -> d w")
+ x, z = xz.chunk(2, dim=1)
+ conv1d_bias = conv1d_bias.contiguous() if conv1d_bias is not None else None
+ conv1d_out = causal_conv1d_cuda.causal_conv1d_fwd(x, conv1d_weight, conv1d_bias, None, None, None, True)
+ # We're being very careful here about the layout, to avoid extra transposes.
+ # We want delta to have d as the slowest moving dimension
+ # and L as the fastest moving dimension, since those are what the ssm_scan kernel expects.
+ x_dbl = F.linear(rearrange(conv1d_out, "b d l -> (b l) d"), x_proj_weight) # (bl d)
+ delta = rearrange(delta_proj_weight @ x_dbl[:, :delta_rank].t(), "d (b l) -> b d l", l=L)
+ ctx.is_variable_B = B is None
+ ctx.is_variable_C = C is None
+ ctx.B_proj_bias_is_None = B_proj_bias is None
+ ctx.C_proj_bias_is_None = C_proj_bias is None
+ if B is None: # variable B
+ B = x_dbl[:, delta_rank : delta_rank + d_state] # (bl dstate)
+ if B_proj_bias is not None:
+ B = B + B_proj_bias.to(dtype=B.dtype)
+ if not A.is_complex():
+ # B = rearrange(B, "(b l) dstate -> b dstate l", l=L).contiguous()
+ B = rearrange(B, "(b l) dstate -> b 1 dstate l", l=L).contiguous()
+ else:
+ B = rearrange(B, "(b l) (dstate two) -> b 1 dstate (l two)", l=L, two=2).contiguous()
+ else:
+ if B.stride(-1) != 1:
+ B = B.contiguous()
+ if C is None: # variable C
+ C = x_dbl[:, -d_state:] # (bl dstate)
+ if C_proj_bias is not None:
+ C = C + C_proj_bias.to(dtype=C.dtype)
+ if not A.is_complex():
+ # C = rearrange(C, "(b l) dstate -> b dstate l", l=L).contiguous()
+ C = rearrange(C, "(b l) dstate -> b 1 dstate l", l=L).contiguous()
+ else:
+ C = rearrange(C, "(b l) (dstate two) -> b 1 dstate (l two)", l=L, two=2).contiguous()
+ else:
+ if C.stride(-1) != 1:
+ C = C.contiguous()
+ if D is not None:
+ D = D.contiguous()
+
+ if b_rms_weight is not None:
+ B = rearrange(B, "b 1 dstate l -> (b l) dstate", l=L).contiguous()
+ B = rms_norm_forward(B, b_rms_weight, bias=None, eps=b_c_dt_rms_eps)
+ B = rearrange(B, "(b l) dstate -> b 1 dstate l", l=L).contiguous()
+ if c_rms_weight is not None:
+ C = rearrange(C, "b 1 dstate l -> (b l) dstate", l=L).contiguous()
+ C = rms_norm_forward(C, c_rms_weight, bias=None, eps=b_c_dt_rms_eps)
+ C = rearrange(C, "(b l) dstate -> b 1 dstate l", l=L).contiguous()
+ if dt_rms_weight is not None:
+ delta = rearrange(delta, "b d l -> (b l) d", l=L).contiguous()
+ delta = rms_norm_forward(delta, dt_rms_weight, bias=None, eps=b_c_dt_rms_eps)
+ delta = rearrange(delta, "(b l) d -> b d l", l=L).contiguous()
+
+ out, scan_intermediates, out_z = selective_scan_cuda.fwd(
+ conv1d_out, delta, A, B, C, D, z, delta_bias, delta_softplus
+ )
+ ctx.delta_softplus = delta_softplus
+ ctx.out_proj_bias_is_None = out_proj_bias is None
+ ctx.checkpoint_lvl = checkpoint_lvl
+ ctx.b_rms_weight = b_rms_weight
+ ctx.c_rms_weight = c_rms_weight
+ ctx.dt_rms_weight = dt_rms_weight
+ ctx.b_c_dt_rms_eps = b_c_dt_rms_eps
+ if checkpoint_lvl >= 1: # Will recompute conv1d_out and delta in the backward pass
+ conv1d_out, delta = None, None
+ ctx.save_for_backward(
+ xz,
+ conv1d_weight,
+ conv1d_bias,
+ x_dbl,
+ x_proj_weight,
+ delta_proj_weight,
+ out_proj_weight,
+ conv1d_out,
+ delta,
+ A,
+ B,
+ C,
+ D,
+ delta_bias,
+ scan_intermediates,
+ b_rms_weight,
+ c_rms_weight,
+ dt_rms_weight,
+ out,
+ )
+ return F.linear(rearrange(out_z, "b d l -> b l d"), out_proj_weight, out_proj_bias)
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx, dout):
+ # dout: (batch, seqlen, dim)
+ assert causal_conv1d_cuda is not None, "causal_conv1d_cuda is not available. Please install causal-conv1d."
+ (
+ xz,
+ conv1d_weight,
+ conv1d_bias,
+ x_dbl,
+ x_proj_weight,
+ delta_proj_weight,
+ out_proj_weight,
+ conv1d_out,
+ delta,
+ A,
+ B,
+ C,
+ D,
+ delta_bias,
+ scan_intermediates,
+ b_rms_weight,
+ c_rms_weight,
+ dt_rms_weight,
+ out,
+ ) = ctx.saved_tensors
+ L = xz.shape[-1]
+ delta_rank = delta_proj_weight.shape[1]
+ d_state = A.shape[-1] * (1 if not A.is_complex() else 2)
+ x, z = xz.chunk(2, dim=1)
+ if dout.stride(-1) != 1:
+ dout = dout.contiguous()
+ if ctx.checkpoint_lvl == 1:
+ conv1d_out = causal_conv1d_cuda.causal_conv1d_fwd(x, conv1d_weight, conv1d_bias, None, None, None, True)
+ delta = rearrange(delta_proj_weight @ x_dbl[:, :delta_rank].t(), "d (b l) -> b d l", l=L)
+ if dt_rms_weight is not None:
+ delta = rearrange(delta, "b d l -> (b l) d", l=L).contiguous()
+ delta = rms_norm_forward(delta, ctx.dt_rms_weight, None, ctx.b_c_dt_rms_eps)
+ delta = rearrange(delta, "(b l) d -> b d l", l=L).contiguous()
+ if b_rms_weight is not None:
+ # Recompute & RMSNorm B
+ B = rearrange(B, "b 1 dstate l -> (b l) dstate", l=L).contiguous()
+ B = rms_norm_forward(B, ctx.b_rms_weight, None, ctx.b_c_dt_rms_eps)
+ B = rearrange(B, "(b l) dstate -> b 1 dstate l", l=L).contiguous()
+ if c_rms_weight is not None:
+ # Recompute & RMSNorm C
+ C = rearrange(C, "b 1 dstate l -> (b l) dstate", l=L).contiguous()
+ C = rms_norm_forward(C, ctx.c_rms_weight, None, ctx.b_c_dt_rms_eps)
+ C = rearrange(C, "(b l) dstate -> b 1 dstate l", l=L).contiguous()
+
+ # The kernel supports passing in a pre-allocated dz (e.g., in case we want to fuse the
+ # backward of selective_scan_cuda with the backward of chunk).
+ dxz = torch.empty_like(xz) # (batch, dim, seqlen)
+ dx, dz = dxz.chunk(2, dim=1)
+ dout = rearrange(dout, "b l e -> e (b l)")
+ dout_y = rearrange(out_proj_weight.t() @ dout, "d (b l) -> b d l", l=L)
+ dconv1d_out, ddelta, dA, dB, dC, dD, ddelta_bias, dz, out_z = selective_scan_cuda.bwd(
+ conv1d_out,
+ delta,
+ A,
+ B,
+ C,
+ D,
+ z,
+ delta_bias,
+ dout_y,
+ scan_intermediates,
+ out,
+ dz,
+ ctx.delta_softplus,
+ True, # option to recompute out_z
+ )
+ dout_proj_weight = torch.einsum("eB,dB->ed", dout, rearrange(out_z, "b d l -> d (b l)"))
+ dout_proj_bias = dout.sum(dim=(0, 1)) if not ctx.out_proj_bias_is_None else None
+ dD = dD if D is not None else None
+ dx_dbl = torch.empty_like(x_dbl)
+ dB_proj_bias = None
+ if ctx.is_variable_B:
+ if not A.is_complex():
+ dB = rearrange(dB, "b 1 dstate l -> (b l) dstate").contiguous()
+ else:
+ dB = rearrange(dB, "b 1 dstate (l two) -> (b l) (dstate two)", two=2).contiguous()
+ dB_proj_bias = dB.sum(0) if not ctx.B_proj_bias_is_None else None
+ dx_dbl[:, delta_rank : delta_rank + d_state] = dB # (bl d)
+ dB = None
+ dC_proj_bias = None
+ if ctx.is_variable_C:
+ if not A.is_complex():
+ dC = rearrange(dC, "b 1 dstate l -> (b l) dstate").contiguous()
+ else:
+ dC = rearrange(dC, "b 1 dstate (l two) -> (b l) (dstate two)", two=2).contiguous()
+ dC_proj_bias = dC.sum(0) if not ctx.C_proj_bias_is_None else None
+ dx_dbl[:, -d_state:] = dC # (bl d)
+ dC = None
+ ddelta = rearrange(ddelta, "b d l -> d (b l)")
+ ddelta_proj_weight = torch.einsum("dB,Br->dr", ddelta, x_dbl[:, :delta_rank])
+ dx_dbl[:, :delta_rank] = torch.einsum("dB,dr->Br", ddelta, delta_proj_weight)
+ dconv1d_out = rearrange(dconv1d_out, "b d l -> d (b l)")
+ dx_proj_weight = torch.einsum("Br,Bd->rd", dx_dbl, rearrange(conv1d_out, "b d l -> (b l) d"))
+ dconv1d_out = torch.addmm(dconv1d_out, x_proj_weight.t(), dx_dbl.t(), out=dconv1d_out)
+ dconv1d_out = rearrange(dconv1d_out, "d (b l) -> b d l", b=x.shape[0], l=x.shape[-1])
+ # The kernel supports passing in a pre-allocated dx (e.g., in case we want to fuse the
+ # backward of conv1d with the backward of chunk).
+ dx, dconv1d_weight, dconv1d_bias, *_ = causal_conv1d_cuda.causal_conv1d_bwd(
+ x, conv1d_weight, conv1d_bias, dconv1d_out, None, None, None, dx, False, True
+ )
+ dconv1d_bias = dconv1d_bias if conv1d_bias is not None else None
+ dconv1d_weight = rearrange(dconv1d_weight, "d w -> d 1 w")
+ return (
+ dxz,
+ dconv1d_weight,
+ dconv1d_bias,
+ dx_proj_weight,
+ ddelta_proj_weight,
+ dout_proj_weight,
+ dout_proj_bias,
+ dA,
+ dB,
+ dC,
+ dD,
+ ddelta_bias if delta_bias is not None else None,
+ # the trailing 6 Nones correspond to delta_softplus, checkpoint_lvl, b_rms_weight, c_rms_weight, dt_rms_weight, b_c_dt_rms_eps
+ dB_proj_bias,
+ dC_proj_bias,
+ None,
+ None,
+ None,
+ None,
+ None,
+ None,
+ )
+
+
+def mamba_inner_fn(
+ xz,
+ conv1d_weight,
+ conv1d_bias,
+ x_proj_weight,
+ delta_proj_weight,
+ out_proj_weight,
+ out_proj_bias,
+ A,
+ B=None,
+ C=None,
+ D=None,
+ delta_bias=None,
+ B_proj_bias=None,
+ C_proj_bias=None,
+ delta_softplus=True,
+ checkpoint_lvl=1,
+ b_rms_weight=None,
+ c_rms_weight=None,
+ dt_rms_weight=None,
+ b_c_dt_rms_eps=1e-6,
+):
+ return MambaInnerFn.apply(
+ xz,
+ conv1d_weight,
+ conv1d_bias,
+ x_proj_weight,
+ delta_proj_weight,
+ out_proj_weight,
+ out_proj_bias,
+ A,
+ B,
+ C,
+ D,
+ delta_bias,
+ B_proj_bias,
+ C_proj_bias,
+ delta_softplus,
+ checkpoint_lvl,
+ b_rms_weight,
+ c_rms_weight,
+ dt_rms_weight,
+ b_c_dt_rms_eps,
+ )
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/kernels/mra/cuda_kernel.cu b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/kernels/mra/cuda_kernel.cu
new file mode 100644
index 0000000000000000000000000000000000000000..87ed89052873813153786bd416a981d3e5279af9
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/kernels/mra/cuda_kernel.cu
@@ -0,0 +1,383 @@
+#include "cuda_kernel.h"
+
+//////////////////////////////////////////////////////////////////////////////////////////////////
+//////////////////////////////////////////////////////////////////////////////////////////////////
+
+__global__ void index_max_cuda_kernel(
+ float *index_vals, // [batch_size, 32, num_block]
+ int *indices, // [batch_size, num_block]
+ float *max_vals, // [batch_size, A_num_block * 32]
+ float *max_vals_scatter, // [batch_size, 32, num_block]
+ long batch_size,
+ long A_num_block,
+ long B_num_block,
+ long num_block
+) {
+
+ long batch_idx = blockIdx.x;
+
+ long thread_idx = threadIdx.x;
+ long num_thread = blockDim.x;
+
+ extern __shared__ float buffer[];
+ int *max_buffer = (int*)buffer;
+
+ for (int i = 0; i < A_num_block * 32; i = i + num_thread) {
+ int idx = i + thread_idx;
+ if (idx < A_num_block * 32) {
+ max_buffer[idx] = -1e8;
+ }
+ }
+ __syncthreads();
+
+ int *indices_pt = &indices[batch_idx * num_block];
+ float *index_vals_pt = &index_vals[batch_idx * num_block * 32];
+
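+ // atomicMax operates on integers, so the scores are compared in fixed point: each value is
+ // scaled by 1000 and truncated to int for the atomic max, then divided back by 1000 when
+ // the results are written out below.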
+ for (int idx_start = 0; idx_start < 32 * num_block; idx_start = idx_start + num_thread) {
+ int idx = idx_start + thread_idx;
+ int A_block_idx = indices_pt[idx % num_block] / B_num_block;
+ atomicMax(&max_buffer[A_block_idx * 32 + idx / num_block], (int)(index_vals_pt[idx] * 1000));
+ }
+ __syncthreads();
+
+ float *max_vals_pt = &max_vals[batch_idx * A_num_block * 32];
+ for (int i = 0; i < A_num_block * 32; i = i + num_thread) {
+ int idx = i + thread_idx;
+ if (idx < A_num_block * 32) {
+ max_vals_pt[idx] = (float)max_buffer[idx] / 1000.;
+ }
+ }
+
+ float *max_vals_scatter_pt = &max_vals_scatter[batch_idx * num_block * 32];
+ for (int idx_start = 0; idx_start < 32 * num_block; idx_start = idx_start + num_thread) {
+ int idx = idx_start + thread_idx;
+ int A_block_idx = indices_pt[idx % num_block] / B_num_block;
+ max_vals_scatter_pt[idx] = (float)max_buffer[A_block_idx * 32 + idx / num_block] / 1000.;
+ }
+
+}
+
+__global__ void mm_to_sparse_cuda_kernel(
+ float *dense_A, // [batch_size, A_num_block, dim, 32]
+ float *dense_B, // [batch_size, B_num_block, dim, 32]
+ int *indices, // [batch_size, num_block]
+ float *sparse_C, // [batch_size, num_block, 32, 32]
+ long batch_size,
+ long A_num_block,
+ long B_num_block,
+ long dim,
+ long num_block
+) {
+
+ long batch_idx = blockIdx.y;
+ long block_idx = blockIdx.x * blockDim.y + threadIdx.y;
+
+ long thread_idx = threadIdx.x;
+
+ __shared__ float buffer[4096];
+ float *A_buffer = &buffer[threadIdx.y * 1024]; // [2, 8, 32]
+ float *B_buffer = &buffer[threadIdx.y * 1024 + 512]; // [2, 8, 32]
+
+ long batch_idx__block_idx = batch_idx * num_block + block_idx;
+
+ long AB_block_idx = indices[batch_idx__block_idx];
+ float *dense_A_pt = &dense_A[(batch_idx * A_num_block + AB_block_idx / B_num_block) * dim * 32];
+ float *dense_B_pt = &dense_B[(batch_idx * B_num_block + AB_block_idx % B_num_block) * dim * 32];
+
+ int reg_1_idx = thread_idx / 8; // [0000000011111111222222223333333344444444555555556666666677777777]
+ int reg_2_idx = thread_idx % 8; // [0123456701234567012345670123456701234567012345670123456701234567]
+
+ float reg_1[8];
+ float reg_2[8];
+
+ float reg_array[16] = {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0};
+
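+ // Each thread accumulates a 4x4 tile of the 32x32 block product in reg_array. The 8x32
+ // slices of dense_A and dense_B are double-buffered in shared memory (and the 4-element
+ // fragments double-buffered in reg_1/reg_2), so loads for the next dim slice overlap with
+ // the multiply-accumulates of the current one.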
+ #pragma unroll
+ for (int i = 0; i < 4; i++) {
+ A_buffer[i * 64 + thread_idx] = dense_A_pt[i * 64 + thread_idx];
+ B_buffer[i * 64 + thread_idx] = dense_B_pt[i * 64 + thread_idx];
+ }
+
+ __syncthreads();
+
+ #pragma unroll
+ for (int i = 0; i < 4; i++) {
+ reg_1[i] = A_buffer[reg_1_idx * 4 + i];
+ reg_2[i] = B_buffer[reg_2_idx * 4 + i];
+ }
+
+ for (int dim_stride = 1; dim_stride < (dim / 8); dim_stride++) {
+
+ #pragma unroll
+ for (int i = 0; i < 4; i++) {
+ A_buffer[(dim_stride % 2) * 256 + i * 64 + thread_idx] = dense_A_pt[dim_stride * 256 + i * 64 + thread_idx];
+ B_buffer[(dim_stride % 2) * 256 + i * 64 + thread_idx] = dense_B_pt[dim_stride * 256 + i * 64 + thread_idx];
+ }
+
+ #pragma unroll
+ for (int mini_dim_idx = 1; mini_dim_idx < 8; mini_dim_idx++) {
+ #pragma unroll
+ for (int i = 0; i < 4; i++) {
+ reg_1[(mini_dim_idx % 2) * 4 + i] = A_buffer[((dim_stride - 1) % 2) * 256 + mini_dim_idx * 32 + reg_1_idx * 4 + i];
+ reg_2[(mini_dim_idx % 2) * 4 + i] = B_buffer[((dim_stride - 1) % 2) * 256 + mini_dim_idx * 32 + reg_2_idx * 4 + i];
+ }
+ #pragma unroll
+ for (int i = 0; i < 4; i++) {
+ #pragma unroll
+ for (int j = 0; j < 4; j++) {
+ reg_array[i * 4 + j] += reg_1[((mini_dim_idx - 1) % 2) * 4 + i] * reg_2[((mini_dim_idx - 1) % 2) * 4 + j];
+ }
+ }
+ }
+
+ __syncthreads();
+
+ #pragma unroll
+ for (int i = 0; i < 4; i++) {
+ reg_1[i] = A_buffer[(dim_stride % 2) * 256 + reg_1_idx * 4 + i];
+ reg_2[i] = B_buffer[(dim_stride % 2) * 256 + reg_2_idx * 4 + i];
+ }
+
+ #pragma unroll
+ for (int i = 0; i < 4; i++) {
+ #pragma unroll
+ for (int j = 0; j < 4; j++) {
+ reg_array[i * 4 + j] += reg_1[4 + i] * reg_2[4 + j];
+ }
+ }
+
+ }
+
+ #pragma unroll
+ for (int mini_dim_idx = 1; mini_dim_idx < 8; mini_dim_idx++) {
+ #pragma unroll
+ for (int i = 0; i < 4; i++) {
+ reg_1[(mini_dim_idx % 2) * 4 + i] = A_buffer[256 + mini_dim_idx * 32 + reg_1_idx * 4 + i];
+ reg_2[(mini_dim_idx % 2) * 4 + i] = B_buffer[256 + mini_dim_idx * 32 + reg_2_idx * 4 + i];
+ }
+ #pragma unroll
+ for (int i = 0; i < 4; i++) {
+ #pragma unroll
+ for (int j = 0; j < 4; j++) {
+ reg_array[i * 4 + j] += reg_1[((mini_dim_idx - 1) % 2) * 4 + i] * reg_2[((mini_dim_idx - 1) % 2) * 4 + j];
+ }
+ }
+ }
+ #pragma unroll
+ for (int i = 0; i < 4; i++) {
+ #pragma unroll
+ for (int j = 0; j < 4; j++) {
+ reg_array[i * 4 + j] += reg_1[4 + i] * reg_2[4 + j];
+ }
+ }
+ __syncthreads();
+
+ float *C_buffer = &buffer[threadIdx.y * 1024]; // [32, 32]
+
+ #pragma unroll
+ for (int i = 0; i < 4; i++) {
+ #pragma unroll
+ for (int j = 0; j < 4; j++) {
+ C_buffer[(reg_2_idx * 4 + j) * 32 + reg_1_idx * 4 + i] = reg_array[i * 4 + j];
+ }
+ }
+ __syncthreads();
+
+ float *sparse_C_pt = &sparse_C[batch_idx__block_idx * 1024];
+
+ #pragma unroll
+ for (int i = 0; i < 16; i++) {
+ sparse_C_pt[i * 64 + thread_idx] = C_buffer[i * 64 + thread_idx];
+ }
+
+}
+
+__global__ void sparse_dense_mm_cuda_kernel(
+ float *sparse_A, // [batch_size, num_block, 32, 32]
+ int *indices, // [batch_size, num_block]
+ float *dense_B, // [batch_size, B_num_block, dim, 32]
+ float *dense_C, // [batch_size, A_num_block, dim, 32]
+ long batch_size,
+ long A_num_block,
+ long B_num_block,
+ long dim,
+ long num_block
+) {
+
+ long batch_idx = blockIdx.y;
+ long block_idx = blockIdx.x * blockDim.y + threadIdx.y;
+
+ long thread_idx = threadIdx.x;
+
+ __shared__ float buffer[6144];
+ float *A_buffer = &buffer[threadIdx.y * 3072]; // [32, 32]
+ float *B_buffer = &buffer[threadIdx.y * 3072 + 1024]; // [32, 64]
+
+ long batch_idx__block_idx = batch_idx * num_block + block_idx;
+
+ float *sparse_A_pt = &sparse_A[batch_idx__block_idx * 1024];
+ #pragma unroll
+ for (int i = 0; i < 8; i++) {
+ A_buffer[i * 128 + thread_idx] = sparse_A_pt[i * 128 + thread_idx];
+ }
+
+ long AB_block_idx = indices[batch_idx__block_idx];
+ float *dense_B_pt = &dense_B[(batch_idx * B_num_block + AB_block_idx % B_num_block) * 32 * dim];
+ float *dense_C_pt = &dense_C[(batch_idx * A_num_block + AB_block_idx / B_num_block) * 32 * dim];
+
+ // [0000000011111111222222223333333344444444555555556666666677777777]
+ // [0123456701234567012345670123456701234567012345670123456701234567]
+ int reg_1_idx = thread_idx / 8;
+ int reg_2_idx = thread_idx % 8;
+
+ float reg_1[8];
+ float reg_2[8];
+
+ float reg_array[16];
+
+ for (int dim_stride = 0; dim_stride < dim; dim_stride = dim_stride + 64) {
+
+ #pragma unroll
+ for (int i = 0; i < 16; i++) {
+ B_buffer[i * 128 + thread_idx] = dense_B_pt[dim_stride * 32 + i * 128 + thread_idx];
+ }
+
+ #pragma unroll
+ for (int i = 0; i < 16; i++) {
+ reg_array[i] = 0;
+ }
+
+ __syncthreads();
+
+ #pragma unroll
+ for (int i = 0; i < 4; i++) {
+ reg_1[i] = B_buffer[(reg_1_idx * 4 + i) * 32];
+ reg_2[i] = A_buffer[reg_2_idx * 4 + i];
+ }
+
+ #pragma unroll
+ for (int mini_dim_idx = 1; mini_dim_idx < 32; mini_dim_idx++) {
+ #pragma unroll
+ for (int i = 0; i < 4; i++) {
+ reg_1[(mini_dim_idx % 2) * 4 + i] = B_buffer[(reg_1_idx * 4 + i) * 32 + mini_dim_idx];
+ reg_2[(mini_dim_idx % 2) * 4 + i] = A_buffer[mini_dim_idx * 32 + reg_2_idx * 4 + i];
+ }
+ #pragma unroll
+ for (int i = 0; i < 4; i++) {
+ #pragma unroll
+ for (int j = 0; j < 4; j++) {
+ reg_array[i * 4 + j] += reg_1[((mini_dim_idx - 1) % 2) * 4 + i] * reg_2[((mini_dim_idx - 1) % 2) * 4 + j];
+ }
+ }
+ }
+
+ #pragma unroll
+ for (int i = 0; i < 4; i++) {
+ #pragma unroll
+ for (int j = 0; j < 4; j++) {
+ reg_array[i * 4 + j] += reg_1[4 + i] * reg_2[4 + j];
+ }
+ }
+
+ __syncthreads();
+
+ float *C_buffer = &buffer[threadIdx.y * 3072 + 1024]; // [64, 32]
+
+ #pragma unroll
+ for (int i = 0; i < 4; i++) {
+ #pragma unroll
+ for (int j = 0; j < 4; j++) {
+ C_buffer[(reg_1_idx * 4 + i) * 32 + reg_2_idx * 4 + j] = reg_array[i * 4 + j];
+ }
+ }
+ __syncthreads();
+
+ #pragma unroll
+ for (int i = 0; i < 16; i++) {
+ atomicAdd(&dense_C_pt[dim_stride * 32 + i * 128 + thread_idx], C_buffer[i * 128 + thread_idx]);
+ }
+ __syncthreads();
+
+ }
+
+}
+
+
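+// Sums each 32x32 sparse block over its rows (one column per thread) and atomically
+// adds the resulting 32 values into the dense_C row block selected by indices / B_num_block.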
+__global__ void reduce_sum_cuda_kernel(
+ float *sparse_A, // [batch_size, num_block, 32, 32]
+ int *indices, // [batch_size, num_block]
+ float *dense_C, // [batch_size, A_num_block, 32]
+ long batch_size,
+ long A_num_block,
+ long B_num_block,
+ long num_block
+) {
+
+ long batch_idx = blockIdx.y;
+ long block_idx = blockIdx.x * blockDim.y + threadIdx.y;
+
+ long thread_idx = threadIdx.x;
+
+ long batch_idx__block_idx = batch_idx * num_block + block_idx;
+
+ long AB_block_idx = indices[batch_idx__block_idx];
+ float *sparse_A_pt = &sparse_A[batch_idx__block_idx * 1024];
+
+ float reg_array[16];
+ float value = 0;
+
+ #pragma unroll
+ for (int i = 0; i < 8; i++) {
+ reg_array[i] = sparse_A_pt[i * 32 + thread_idx];
+ }
+ #pragma unroll
+ for (int stride = 8; stride < 32; stride = stride + 8) {
+ #pragma unroll
+ for (int i = 0; i < 8; i++) {
+ reg_array[(stride + i) % 16] = sparse_A_pt[(stride + i) * 32 + thread_idx];
+ }
+ #pragma unroll
+ for (int i = 0; i < 8; i++) {
+ value = value + reg_array[(stride - 8 + i) % 16];
+ }
+ }
+ #pragma unroll
+ for (int i = 0; i < 8; i++) {
+ value = value + reg_array[8 + i];
+ }
+
+ float *dense_C_pt = &dense_C[(batch_idx * A_num_block + AB_block_idx / B_num_block) * 32];
+
+ atomicAdd(&dense_C_pt[thread_idx], value);
+
+}
+
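+// Broadcasts the 32-element dense_A row selected by indices / B_num_block into every
+// row of the corresponding 32x32 block of sparse_C.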
+__global__ void scatter_cuda_kernel(
+ float *dense_A, // [batch_size, A_num_block, 32]
+ int *indices, // [batch_size, num_block]
+ float *sparse_C, // [batch_size, num_block, 32, 32]
+ long batch_size,
+ long A_num_block,
+ long B_num_block,
+ long num_block
+) {
+
+ long batch_idx = blockIdx.y;
+ long block_idx = blockIdx.x * blockDim.y + threadIdx.y;
+
+ long thread_idx = threadIdx.x;
+
+ long batch_idx__block_idx = batch_idx * num_block + block_idx;
+
+ long AB_block_idx = indices[batch_idx__block_idx];
+ float *dense_A_pt = &dense_A[(batch_idx * A_num_block + AB_block_idx / B_num_block) * 32];
+ float *sparse_C_pt = &sparse_C[(batch_idx * num_block + block_idx) * 1024];
+
+ float value = dense_A_pt[thread_idx];
+
+ #pragma unroll
+ for (int i = 0; i < 32; i++) {
+ sparse_C_pt[i * 32 + thread_idx] = value;
+ }
+
+}
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/kernels/mra/cuda_kernel.h b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/kernels/mra/cuda_kernel.h
new file mode 100644
index 0000000000000000000000000000000000000000..a95b46f7d159b11851143710034cf80c20aa6bf8
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/kernels/mra/cuda_kernel.h
@@ -0,0 +1,59 @@
+
+#define WARP_SIZE 32
+#define FULL_MASK 0xffffffff
+#define OPTIMAL_THREADS 256
+
+__global__ void index_max_cuda_kernel(
+ float *index_vals, // [batch_size, 32, num_block]
+ int *indices, // [batch_size, num_block]
+ float *max_vals, // [batch_size, A_num_block * 32]
+ float *max_vals_scatter, // [batch_size, 32, num_block]
+ long batch_size,
+ long A_num_block,
+ long B_num_block,
+ long num_block
+);
+
+__global__ void mm_to_sparse_cuda_kernel(
+ float *dense_A, // [batch_size, A_num_block, dim, 32]
+ float *dense_B, // [batch_size, B_num_block, dim, 32]
+ int *indices, // [batch_size, num_block]
+ float *sparse_C, // [batch_size, num_block, 32, 32]
+ long batch_size,
+ long A_num_block,
+ long B_num_block,
+ long dim,
+ long num_block
+);
+
+__global__ void sparse_dense_mm_cuda_kernel(
+ float *sparse_A, // [batch_size, num_block, 32, 32]
+ int *indices, // [batch_size, num_block]
+ float *dense_B, // [batch_size, B_num_block, dim, 32]
+ float *dense_C, // [batch_size, A_num_block, dim, 32]
+ long batch_size,
+ long A_num_block,
+ long B_num_block,
+ long dim,
+ long num_block
+);
+
+__global__ void reduce_sum_cuda_kernel(
+ float *sparse_A, // [batch_size, num_block, 32, 32]
+ int *indices, // [batch_size, num_block]
+ float *dense_C, // [batch_size, A_num_block, 32]
+ long batch_size,
+ long A_num_block,
+ long B_num_block,
+ long num_block
+);
+
+__global__ void scatter_cuda_kernel(
+ float *dense_A, // [batch_size, A_num_block, 32]
+ int *indices, // [batch_size, num_block]
+ float *sparse_C, // [batch_size, num_block, 32, 32]
+ long batch_size,
+ long A_num_block,
+ long B_num_block,
+ long num_block
+);
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/kernels/mra/cuda_launch.cu b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/kernels/mra/cuda_launch.cu
new file mode 100644
index 0000000000000000000000000000000000000000..ba2a0cacfe614e75e06d2dde80dc77a6e8a4ec1a
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/kernels/mra/cuda_launch.cu
@@ -0,0 +1,154 @@
+#include <torch/extension.h>
+#include <ATen/ATen.h>
+#include "cuda_launch.h"
+#include "cuda_kernel.h"
+#include <vector>
+
+//////////////////////////////////////////////////////////////////////////////////////////////////
+//////////////////////////////////////////////////////////////////////////////////////////////////
+
+std::vector<at::Tensor> index_max_kernel(
+ at::Tensor index_vals, // [batch_size, 32, num_block]
+ at::Tensor indices, // [batch_size, num_block],
+ int A_num_block,
+ int B_num_block
+) {
+ int batch_size = indices.size(0);
+ int num_block = indices.size(1);
+
+ at::Tensor max_vals = at::zeros({batch_size, A_num_block * 32}, index_vals.options());
+ at::Tensor max_vals_scatter = at::zeros({batch_size, 32, num_block}, index_vals.options());
+
+ dim3 threads(256);
+ dim3 blocks(batch_size);
+ int shared_mem = A_num_block * 32 * sizeof(float);
+
+ index_max_cuda_kernel<<<blocks, threads, shared_mem>>>(
+ index_vals.data_ptr<float>(),
+ indices.data_ptr<int>(),
+ max_vals.data_ptr<float>(),
+ max_vals_scatter.data_ptr<float>(),
+ batch_size,
+ A_num_block,
+ B_num_block,
+ num_block
+ );
+
+ return {max_vals, max_vals_scatter};
+}
+
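+// Each block of 64x4 threads computes four 32x32 output blocks (one per threadIdx.y),
+// so the grid is (num_block / 4, batch_size); num_block is assumed to be a multiple of 4.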
+at::Tensor mm_to_sparse_kernel(
+ at::Tensor dense_A, // [batch_size, A_num_block, dim, 32]
+ at::Tensor dense_B, // [batch_size, B_num_block, dim, 32]
+ at::Tensor indices // [batch_size, num_block]
+) {
+ int batch_size = dense_A.size(0);
+ int A_num_block = dense_A.size(1);
+ int B_num_block = dense_B.size(1);
+ int dim = dense_A.size(2);
+ int num_block = indices.size(1);
+
+ at::Tensor sparse_C = at::zeros({batch_size, num_block, 32, 32}, dense_A.options());
+
+ dim3 threads(64, 4);
+ dim3 blocks(num_block / 4, batch_size);
+
+ mm_to_sparse_cuda_kernel<<<blocks, threads>>>(
+ dense_A.data_ptr<float>(),
+ dense_B.data_ptr<float>(),
+ indices.data_ptr<int>(),
+ sparse_C.data_ptr<float>(),
+ batch_size,
+ A_num_block,
+ B_num_block,
+ dim,
+ num_block
+ );
+
+ return sparse_C;
+}
+
+at::Tensor sparse_dense_mm_kernel(
+ at::Tensor sparse_A, // [batch_size, num_block, 32, 32]
+ at::Tensor indices, // [batch_size, num_block]
+ at::Tensor dense_B, // [batch_size, B_num_block, dim, 32]
+ int A_num_block
+) {
+ int batch_size = sparse_A.size(0);
+ int num_block = sparse_A.size(1);
+ int B_num_block = dense_B.size(1);
+ int dim = dense_B.size(2);
+
+ at::Tensor dense_C = at::zeros({batch_size, A_num_block, dim, 32}, dense_B.options());
+
+ dim3 threads(128, 2);
+ dim3 blocks(num_block / 2, batch_size);
+
+ sparse_dense_mm_cuda_kernel<<<blocks, threads>>>(
+ sparse_A.data_ptr<float>(),
+ indices.data_ptr<int>(),
+ dense_B.data_ptr<float>(),
+ dense_C.data_ptr<float>(),
+ batch_size,
+ A_num_block,
+ B_num_block,
+ dim,
+ num_block
+ );
+
+ return dense_C;
+}
+
+at::Tensor reduce_sum_kernel(
+ at::Tensor sparse_A, // [batch_size, num_block, 32, 32]
+ at::Tensor indices, // [batch_size, num_block]
+ int A_num_block,
+ int B_num_block
+) {
+ int batch_size = sparse_A.size(0);
+ int num_block = sparse_A.size(1);
+
+ at::Tensor dense_C = at::zeros({batch_size, A_num_block, 32}, sparse_A.options());
+
+ dim3 threads(32, 4);
+ dim3 blocks(num_block / 4, batch_size);
+
+ reduce_sum_cuda_kernel<<<blocks, threads>>>(
+ sparse_A.data_ptr<float>(),
+ indices.data_ptr<int>(),
+ dense_C.data_ptr<float>(),
+ batch_size,
+ A_num_block,
+ B_num_block,
+ num_block
+ );
+
+ return dense_C;
+}
+
+at::Tensor scatter_kernel(
+ at::Tensor dense_A, // [batch_size, A_num_block, 32]
+ at::Tensor indices, // [batch_size, num_block]
+ int B_num_block
+) {
+ int batch_size = dense_A.size(0);
+ int A_num_block = dense_A.size(1);
+ int num_block = indices.size(1);
+
+ at::Tensor sparse_C = at::zeros({batch_size, num_block, 32, 32}, dense_A.options());
+
+ dim3 threads(32, 4);
+ dim3 blocks(num_block / 4, batch_size);
+
+ scatter_cuda_kernel<<<blocks, threads>>>(
+ dense_A.data_ptr<float>(),
+ indices.data_ptr<int>(),
+ sparse_C.data_ptr<float>(),
+ batch_size,
+ A_num_block,
+ B_num_block,
+ num_block
+ );
+
+ return sparse_C;
+}
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/kernels/mra/cuda_launch.h b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/kernels/mra/cuda_launch.h
new file mode 100644
index 0000000000000000000000000000000000000000..0200140ee337b8c5d9583767bbad1e842e9d4677
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/kernels/mra/cuda_launch.h
@@ -0,0 +1,39 @@
+#include <torch/extension.h>
+#include <ATen/ATen.h>
+#include <vector>
+
+#define min(a, b) ((a)<(b)?(a):(b))
+#define max(a, b) ((a)>(b)?(a):(b))
+
+std::vector<at::Tensor> index_max_kernel(
+ at::Tensor index_vals,
+ at::Tensor indices,
+ int A_num_block,
+ int B_num_block
+);
+
+at::Tensor mm_to_sparse_kernel(
+ at::Tensor dense_A,
+ at::Tensor dense_B,
+ at::Tensor indices
+);
+
+at::Tensor sparse_dense_mm_kernel(
+ at::Tensor sparse_A,
+ at::Tensor indices,
+ at::Tensor dense_B,
+ int A_num_block
+);
+
+at::Tensor reduce_sum_kernel(
+ at::Tensor sparse_A,
+ at::Tensor indices,
+ int A_num_block,
+ int B_num_block
+);
+
+at::Tensor scatter_kernel(
+ at::Tensor dense_A,
+ at::Tensor indices,
+ int B_num_block
+);
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/kernels/mra/torch_extension.cpp b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/kernels/mra/torch_extension.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..60c9262b779270a6e95ae54f53a67daa6d740a9e
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/kernels/mra/torch_extension.cpp
@@ -0,0 +1,78 @@
+#include <torch/extension.h>
+#include <ATen/ATen.h>
+#include "cuda_launch.h"
+#include <vector>
+
+std::vector<at::Tensor> index_max(
+ at::Tensor index_vals,
+ at::Tensor indices,
+ int A_num_block,
+ int B_num_block
+) {
+ return index_max_kernel(
+ index_vals,
+ indices,
+ A_num_block,
+ B_num_block
+ );
+}
+
+at::Tensor mm_to_sparse(
+ at::Tensor dense_A,
+ at::Tensor dense_B,
+ at::Tensor indices
+) {
+ return mm_to_sparse_kernel(
+ dense_A,
+ dense_B,
+ indices
+ );
+}
+
+at::Tensor sparse_dense_mm(
+ at::Tensor sparse_A,
+ at::Tensor indices,
+ at::Tensor dense_B,
+ int A_num_block
+) {
+ return sparse_dense_mm_kernel(
+ sparse_A,
+ indices,
+ dense_B,
+ A_num_block
+ );
+}
+
+at::Tensor reduce_sum(
+ at::Tensor sparse_A,
+ at::Tensor indices,
+ int A_num_block,
+ int B_num_block
+) {
+ return reduce_sum_kernel(
+ sparse_A,
+ indices,
+ A_num_block,
+ B_num_block
+ );
+}
+
+at::Tensor scatter(
+ at::Tensor dense_A,
+ at::Tensor indices,
+ int B_num_block
+) {
+ return scatter_kernel(
+ dense_A,
+ indices,
+ B_num_block
+ );
+}
+
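+// Python bindings for the host-side launch wrappers defined above.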
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+ m.def("index_max", &index_max, "index_max (CUDA)");
+ m.def("mm_to_sparse", &mm_to_sparse, "mm_to_sparse (CUDA)");
+ m.def("sparse_dense_mm", &sparse_dense_mm, "sparse_dense_mm (CUDA)");
+ m.def("reduce_sum", &reduce_sum, "reduce_sum (CUDA)");
+ m.def("scatter", &scatter, "scatter (CUDA)");
+}
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/kernels/rwkv/wkv_cuda.cu b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/kernels/rwkv/wkv_cuda.cu
new file mode 100644
index 0000000000000000000000000000000000000000..571d5a8a8307e95aac689eb3c9333d1ad350c7de
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/kernels/rwkv/wkv_cuda.cu
@@ -0,0 +1,187 @@
+#include <stdio.h>
+#include <assert.h>
+
+#define MIN_VALUE (-1e38)
+
+template <typename F>
+__global__ void kernel_forward(
+ const int B, const int T, const int C, const F *__restrict__ const _w, const F *__restrict__ const _u,
+ const F *__restrict__ const _k, const F *__restrict__ const _v, F *__restrict__ const _y
+) {
+ const int idx = blockIdx.x * blockDim.x + threadIdx.x;
+ const int _b = idx / C;
+ const int _c = idx % C;
+ const int _offset = _b * T * C + _c;
+
+ F u = _u[_c];
+ F w = _w[_c];
+ const F *__restrict__ const k = _k + _offset;
+ const F *__restrict__ const v = _v + _offset;
+ F *__restrict__ const y = _y + _offset;
+
+ // aa and bb are running sums divided by exp(pp) (to avoid overflow)
+ F aa = 0, bb = 0, pp = MIN_VALUE;
+ for (int i = 0; i < T; i++) {
+ const int ii = i * C;
+ const F kk = k[ii];
+ const F vv = v[ii];
+
+ F ww = u + kk;
+ F p = max(pp, ww);
+ F e1 = exp(pp - p);
+ F e2 = exp(ww - p);
+ y[ii] = (e1 * aa + e2 * vv) / (e1 * bb + e2);
+
+ ww = w + pp;
+ p = max(ww, kk);
+ e1 = exp(ww - p);
+ e2 = exp(kk - p);
+ aa = e1 * aa + e2 * vv;
+ bb = e1 * bb + e2;
+ pp = p;
+ }
+}
+
+template <typename F>
+__global__ void kernel_forward_with_state(
+ const int B, const int T, const int C, const F *__restrict__ const _w, const F *__restrict__ const _u,
+ const F *__restrict__ const _k, const F *__restrict__ const _v, F *__restrict__ const _y, F *__restrict__ const _s
+) {
+ const int idx = blockIdx.x * blockDim.x + threadIdx.x;
+ const int _b = idx / C;
+ const int _c = idx % C;
+ const int _offset_s = _b * C * 3 + _c * 3;
+ const int _offset = _b * T * C + _c;
+
+ F u = _u[_c];
+ F w = _w[_c];
+ const F *__restrict__ const k = _k + _offset;
+ const F *__restrict__ const v = _v + _offset;
+ F *__restrict__ const y = _y + _offset;
+ F *__restrict__ const s = _s + _offset_s;
+
+ // aa and bb are running sums divided by exp(pp) (to avoid overflow)
+ F aa = s[0], bb = s[1], pp = s[2];
+ for (int i = 0; i < T; i++) {
+ const int ii = i * C;
+ const F kk = k[ii];
+ const F vv = v[ii];
+
+ F ww = u + kk;
+ F p = max(pp, ww);
+ F e1 = exp(pp - p);
+ F e2 = exp(ww - p);
+ y[ii] = (e1 * aa + e2 * vv) / (e1 * bb + e2);
+
+ ww = w + pp;
+ p = max(ww, kk);
+ e1 = exp(ww - p);
+ e2 = exp(kk - p);
+ aa = e1 * aa + e2 * vv;
+ bb = e1 * bb + e2;
+ pp = p;
+ }
+ s[0] = aa;
+ s[1] = bb;
+ s[2] = pp;
+}
+
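+// Backward pass: the forward-order loop accumulates the gradients of w and u and caches
+// the per-step quantities q/r; the reverse-order loop then produces the gradients of k and v.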
+template <typename F>
+__global__ void kernel_backward(
+ const int B, const int T, const int C, const F *__restrict__ const _w, const F *__restrict__ const _u,
+ const F *__restrict__ const _k, const F *__restrict__ const _v, const F *__restrict__ const _y,
+ const F *__restrict__ const _gy, F *__restrict__ const _gw, F *__restrict__ const _gu, F *__restrict__ const _gk,
+ F *__restrict__ const _gv
+) {
+ const int idx = blockIdx.x * blockDim.x + threadIdx.x;
+ const int _b = idx / C;
+ const int _c = idx % C;
+ const int _offset = _b * T * C + _c;
+
+ F u = _u[_c];
+ F w = _w[_c];
+ const F *__restrict__ const k = _k + _offset;
+ const F *__restrict__ const v = _v + _offset;
+ const F *__restrict__ const y = _y + _offset;
+ const F *__restrict__ const gy = _gy + _offset;
+ F *__restrict__ const gk = _gk + _offset;
+ F *__restrict__ const gv = _gv + _offset;
+
+ F q[Tmax], r[Tmax];
+
+ F gw = 0, gu = 0, aa = 0, bb = 0, ga = 0, gb = 0, pp = MIN_VALUE;
+ for (int i = 0; i < T; i++) {
+ const int ii = i * C;
+ const F kk = k[ii];
+ const F vv = v[ii];
+ const F yy = y[ii];
+
+ F ww = u + kk;
+ F p = max(pp, ww);
+ F e1 = exp(pp - p);
+ F e2 = exp(ww - p);
+ const F qq = gy[ii] / (e1 * bb + e2);
+ gw += (ga - gb * yy) * e1 * qq;
+ gu += (vv - yy) * e2 * qq;
+ q[i] = qq;
+ r[i] = ww - p;
+
+ ww = w + pp;
+ p = max(ww, kk);
+ e1 = exp(ww - p);
+ e2 = exp(kk - p);
+ ga = e1 * (aa + ga);
+ gb = e1 * (bb + gb);
+ aa = e1 * aa + e2 * vv;
+ bb = e1 * bb + e2;
+ pp = p;
+ }
+ const int _offsetBC = _b * C + _c;
+ _gw[_offsetBC] = gw * _w[_c]; // multiply by w because of w -> -exp(w) in python forward()
+ _gu[_offsetBC] = gu;
+
+ aa = 0, bb = 0, pp = MIN_VALUE;
+ for (int i = T - 1; i >= 0; i--) {
+ const int ii = i * C;
+ const F kk = k[ii];
+ const F vv = v[ii];
+ const F yy = y[ii];
+ const F qq = q[i];
+ const F rr = r[i];
+
+ F e1 = qq * exp(rr);
+ F e2 = exp(kk + pp);
+ gk[ii] = e1 * (vv - yy) + e2 * (aa * vv + bb);
+ gv[ii] = e1 + e2 * aa;
+
+ const F ww = w + pp;
+ const F www = rr - u - kk;
+ const F p = max(ww, www);
+ e1 = exp(ww - p);
+ e2 = qq * exp(www - p);
+ aa = e1 * aa + e2;
+ bb = e1 * bb - e2 * yy;
+ pp = p;
+ }
+}
+
+void cuda_forward(int B, int T, int C, float *w, float *u, float *k, float *v, float *y) {
+ dim3 threadsPerBlock( min(C, 32) ); // requires --maxrregcount 60 for optimal performance
+ assert(B * C % threadsPerBlock.x == 0);
+ dim3 numBlocks(B * C / threadsPerBlock.x);
+ kernel_forward<<<numBlocks, threadsPerBlock>>>(B, T, C, w, u, k, v, y);
+}
+
+void cuda_forward_with_state(int B, int T, int C, float *w, float *u, float *k, float *v, float *y, float *s) {
+ dim3 threadsPerBlock( min(C, 32) ); // requires --maxrregcount 60 for optimal performance
+ assert(B * C % threadsPerBlock.x == 0);
+ dim3 numBlocks(B * C / threadsPerBlock.x);
+ kernel_forward_with_state<<<numBlocks, threadsPerBlock>>>(B, T, C, w, u, k, v, y, s);
+}
+
+void cuda_backward(int B, int T, int C, float *w, float *u, float *k, float *v, float *y, float *gy, float *gw, float *gu, float *gk, float *gv) {
+ dim3 threadsPerBlock( min(C, 32) ); // requires --maxrregcount 60 for optimal performance
+ assert(B * C % threadsPerBlock.x == 0);
+ dim3 numBlocks(B * C / threadsPerBlock.x);
+ kernel_backward<<<numBlocks, threadsPerBlock>>>(B, T, C, w, u, k, v, y, gy, gw, gu, gk, gv);
+}
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/kernels/rwkv/wkv_cuda_bf16.cu b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/kernels/rwkv/wkv_cuda_bf16.cu
new file mode 100644
index 0000000000000000000000000000000000000000..042cb4aba1db98be5916aea1de86a7fed0b6510d
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/kernels/rwkv/wkv_cuda_bf16.cu
@@ -0,0 +1,186 @@
+#include <stdio.h>
+#include <assert.h>
+#include "ATen/ATen.h"
+#define MIN_VALUE (-1e38)
+typedef at::BFloat16 bf16;
+
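+// bfloat16 variants of the WKV kernels: k, v, u and y are bf16, while the recurrence state
+// (aa, bb, pp) and the decay w stay in float for numerical stability.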
+__global__ void kernel_forward_bf16(
+ const int B, const int T, const int C, const float *__restrict__ const _w, const bf16 *__restrict__ const _u,
+ const bf16 *__restrict__ const _k, const bf16 *__restrict__ const _v, bf16 *__restrict__ const _y
+) {
+ const int idx = blockIdx.x * blockDim.x + threadIdx.x;
+ const int _b = idx / C;
+ const int _c = idx % C;
+ const int _offset = _b * T * C + _c;
+
+ float u = float(_u[_c]);
+ float w = _w[_c];
+ const bf16 *__restrict__ const k = _k + _offset;
+ const bf16 *__restrict__ const v = _v + _offset;
+ bf16 *__restrict__ const y = _y + _offset;
+
+ // aa and bb are running sums divided by exp(pp) (to avoid overflow)
+ float aa = 0, bb = 0, pp = MIN_VALUE;
+ for (int i = 0; i < T; i++) {
+ const int ii = i * C;
+ const float kk = float(k[ii]);
+ const float vv = float(v[ii]);
+
+ float ww = u + kk;
+ float p = max(pp, ww);
+ float e1 = exp(pp - p);
+ float e2 = exp(ww - p);
+ y[ii] = bf16((e1 * aa + e2 * vv) / (e1 * bb + e2));
+
+ ww = w + pp;
+ p = max(ww, kk);
+ e1 = exp(ww - p);
+ e2 = exp(kk - p);
+ aa = e1 * aa + e2 * vv;
+ bb = e1 * bb + e2;
+ pp = p;
+ }
+}
+
+__global__ void kernel_forward_with_state_bf16(
+ const int B, const int T, const int C, const float *__restrict__ const _w, const bf16 *__restrict__ const _u,
+ const bf16 *__restrict__ const _k, const bf16 *__restrict__ const _v, bf16 *__restrict__ const _y,
+ float *__restrict__ const _s
+) {
+ const int idx = blockIdx.x * blockDim.x + threadIdx.x;
+ const int _b = idx / C;
+ const int _c = idx % C;
+ const int _offset_s = _b * C * 3 + _c * 3;
+ const int _offset = _b * T * C + _c;
+
+ float u = float(_u[_c]);
+ float w = _w[_c];
+ const bf16 *__restrict__ const k = _k + _offset;
+ const bf16 *__restrict__ const v = _v + _offset;
+ bf16 *__restrict__ const y = _y + _offset;
+ float *__restrict__ const s = _s + _offset_s;
+
+ // aa and bb are running sums divided by exp(pp) (to avoid overflow)
+ float aa = s[0], bb = s[1], pp = s[2];
+ for (int i = 0; i < T; i++) {
+ const int ii = i * C;
+ const float kk = float(k[ii]);
+ const float vv = float(v[ii]);
+
+ float ww = u + kk;
+ float p = max(pp, ww);
+ float e1 = exp(pp - p);
+ float e2 = exp(ww - p);
+ y[ii] = bf16(e1 * aa + e2 * vv) / (e1 * bb + e2);
+
+ ww = w + pp;
+ p = max(ww, kk);
+ e1 = exp(ww - p);
+ e2 = exp(kk - p);
+ aa = e1 * aa + e2 * vv;
+ bb = e1 * bb + e2;
+ pp = p;
+ }
+ s[0] = aa;
+ s[1] = bb;
+ s[2] = pp;
+}
+
+__global__ void kernel_backward_bf16(
+ const int B, const int T, const int C, const float *__restrict__ const _w, const bf16 *__restrict__ const _u,
+ const bf16 *__restrict__ const _k, const bf16 *__restrict__ const _v, const bf16 *__restrict__ const _y,
+ const bf16 *__restrict__ const _gy, bf16 *__restrict__ const _gw, bf16 *__restrict__ const _gu,
+ bf16 *__restrict__ const _gk, bf16 *__restrict__ const _gv
+) {
+ const int idx = blockIdx.x * blockDim.x + threadIdx.x;
+ const int _b = idx / C;
+ const int _c = idx % C;
+ const int _offset = _b * T * C + _c;
+
+ float u = float(_u[_c]);
+ float w = _w[_c];
+ const bf16 *__restrict__ const k = _k + _offset;
+ const bf16 *__restrict__ const v = _v + _offset;
+ const bf16 *__restrict__ const y = _y + _offset;
+ const bf16 *__restrict__ const gy = _gy + _offset;
+ bf16 *__restrict__ const gk = _gk + _offset;
+ bf16 *__restrict__ const gv = _gv + _offset;
+
+ float q[Tmax], r[Tmax];
+
+ float gw = 0, gu = 0, aa = 0, bb = 0, ga = 0, gb = 0, pp = MIN_VALUE;
+ for (int i = 0; i < T; i++) {
+ const int ii = i * C;
+ const float kk = float(k[ii]);
+ const float vv = float(v[ii]);
+ const float yy = float(y[ii]);
+
+ float ww = u + kk;
+ float p = max(pp, ww);
+ float e1 = exp(pp - p);
+ float e2 = exp(ww - p);
+ const float qq = float(gy[ii]) / (e1 * bb + e2);
+ gw += (ga - gb * yy) * e1 * qq;
+ gu += (vv - yy) * e2 * qq;
+ q[i] = qq;
+ r[i] = ww - p;
+
+ ww = w + pp;
+ p = max(ww, kk);
+ e1 = exp(ww - p);
+ e2 = exp(kk - p);
+ ga = e1 * (aa + ga);
+ gb = e1 * (bb + gb);
+ aa = e1 * aa + e2 * vv;
+ bb = e1 * bb + e2;
+ pp = p;
+ }
+ const int _offsetBC = _b * C + _c;
+ _gw[_offsetBC] = bf16(gw * _w[_c]); // multiply by w because of w -> -exp(w) in python forward()
+ _gu[_offsetBC] = bf16(gu);
+
+ aa = 0, bb = 0, pp = MIN_VALUE;
+ for (int i = T - 1; i >= 0; i--) {
+ const int ii = i * C;
+ const float kk = float(k[ii]);
+ const float vv = float(v[ii]);
+ const float yy = float(y[ii]);
+ const float qq = q[i];
+ const float rr = r[i];
+
+ float e1 = qq * exp(rr);
+ float e2 = exp(kk + pp);
+ gk[ii] = bf16(e1 * (vv - yy) + e2 * (aa * vv + bb));
+ gv[ii] = bf16(e1 + e2 * aa);
+
+ const float ww = w + pp;
+ const float www = rr - u - kk;
+ const float p = max(ww, www);
+ e1 = exp(ww - p);
+ e2 = qq * exp(www - p);
+ aa = e1 * aa + e2;
+ bb = e1 * bb - e2 * yy;
+ pp = p;
+ }
+}
+
+void cuda_forward_bf16(int B, int T, int C, float *w, bf16 *u, bf16 *k, bf16 *v, bf16 *y) {
+ dim3 threadsPerBlock( min(C, 32) ); // requires --maxrregcount 60 for optimal performance
+ assert(B * C % threadsPerBlock.x == 0);
+ dim3 numBlocks(B * C / threadsPerBlock.x);
+ kernel_forward_bf16<<<numBlocks, threadsPerBlock>>>(B, T, C, w, u, k, v, y);
+}
+
+void cuda_forward_with_state_bf16(int B, int T, int C, float *w, bf16 *u, bf16 *k, bf16 *v, bf16 *y, float *s) {
+ dim3 threadsPerBlock( min(C, 32) ); // requires --maxrregcount 60 for optimal performance
+ assert(B * C % threadsPerBlock.x == 0);
+ dim3 numBlocks(B * C / threadsPerBlock.x);
+ kernel_forward_with_state_bf16<<<numBlocks, threadsPerBlock>>>(B, T, C, w, u, k, v, y, s);