Spaces:
Running on Zero
Running on Zero
FeiElysia commited on
Commit Β·
161b19e
1
Parent(s): 5087c07
π Initial deploy
Browse files- .gitattributes +2 -0
- app.py +384 -0
- examples/cover_videomme_FjS2LzrHEO8.png +3 -0
- examples/cover_videomme_fFjv93ACGo8.png +3 -0
- examples/demo.mp4 +3 -0
- examples/demo_cases.json +40 -0
- examples/description_honkai3_becauseofyou.png +3 -0
- examples/honkai3_becauseofyou.mp4 +3 -0
- examples/hsr_helloworld.mp4 +3 -0
- examples/lvbench_gXnhqF0TqqI.mp4 +3 -0
- examples/meme_hsr_helloworld.png +3 -0
- examples/ocr_honkai3_becauseofyou.png +3 -0
- examples/performance_hsr_helloworld.png +3 -0
- examples/tempo.png +3 -0
- examples/tempo.svg +0 -0
- examples/videomme_FjS2LzrHEO8.mp4 +3 -0
- examples/videomme_FsLaTZmP6Uw.mp4 +3 -0
- examples/videomme_Sp2nxlrQ89w.mp4 +3 -0
- examples/videomme_fFjv93ACGo8.mp4 +3 -0
- packages.txt +1 -0
- requirements.txt +11 -0
- tempo/__init__.py +6 -0
- tempo/__pycache__/__init__.cpython-312.pyc +0 -0
- tempo/__pycache__/builder.cpython-312.pyc +0 -0
- tempo/__pycache__/constants.cpython-312.pyc +0 -0
- tempo/__pycache__/conversation.cpython-312.pyc +0 -0
- tempo/__pycache__/mm_datautils.cpython-312.pyc +0 -0
- tempo/__pycache__/mm_utils.cpython-312.pyc +0 -0
- tempo/__pycache__/tempo_arch.cpython-312.pyc +0 -0
- tempo/__pycache__/vlm_multimodal_processor.cpython-312.pyc +0 -0
- tempo/builder.py +62 -0
- tempo/constants.py +13 -0
- tempo/conversation.py +545 -0
- tempo/language_model/__pycache__/modeling_tempo_qwen.cpython-312.pyc +0 -0
- tempo/language_model/modeling_tempo_qwen.py +231 -0
- tempo/mm_datautils.py +1607 -0
- tempo/multimodal_encoder/__pycache__/base_encoder.cpython-312.pyc +0 -0
- tempo/multimodal_encoder/__pycache__/builder.cpython-312.pyc +0 -0
- tempo/multimodal_encoder/__pycache__/qwen3vl_encoder.cpython-312.pyc +0 -0
- tempo/multimodal_encoder/__pycache__/siglip_encoder.cpython-312.pyc +0 -0
- tempo/multimodal_encoder/base_encoder.py +135 -0
- tempo/multimodal_encoder/builder.py +21 -0
- tempo/multimodal_encoder/qwen3vl_encoder.py +336 -0
- tempo/multimodal_encoder/siglip_encoder.py +75 -0
- tempo/multimodal_projector/__pycache__/builder.cpython-312.pyc +0 -0
- tempo/multimodal_projector/builder.py +51 -0
- tempo/tempo_arch.py +464 -0
- tempo/vlm_multimodal_processor.py +332 -0
.gitattributes
CHANGED
|
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
*.mp4 filter=lfs diff=lfs merge=lfs -text
|
| 37 |
+
*.png filter=lfs diff=lfs merge=lfs -text
|
app.py
ADDED
|
@@ -0,0 +1,384 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import io
|
| 2 |
+
import os
|
| 3 |
+
import time
|
| 4 |
+
import torch
|
| 5 |
+
import numpy as np
|
| 6 |
+
import gradio as gr
|
| 7 |
+
import multiprocessing
|
| 8 |
+
from decord import cpu, VideoReader
|
| 9 |
+
|
| 10 |
+
import matplotlib.pyplot as plt
|
| 11 |
+
import matplotlib.ticker as ticker
|
| 12 |
+
import matplotlib.colors as mcolors
|
| 13 |
+
from scipy.interpolate import make_interp_spline
|
| 14 |
+
from PIL import Image
|
| 15 |
+
|
| 16 |
+
from tempo.builder import load_pretrained_model
|
| 17 |
+
from tempo.conversation import conv_templates, SeparatorStyle
|
| 18 |
+
from tempo.constants import (
|
| 19 |
+
DEFAULT_IM_END_TOKEN,
|
| 20 |
+
DEFAULT_IM_START_TOKEN,
|
| 21 |
+
DEFAULT_IMAGE_TOKEN,
|
| 22 |
+
IMAGE_TOKEN_INDEX,
|
| 23 |
+
)
|
| 24 |
+
from tempo.mm_datautils import (
|
| 25 |
+
compute_segment_timestamp,
|
| 26 |
+
KeywordsStoppingCriteria,
|
| 27 |
+
process_qwen_content,
|
| 28 |
+
tokenizer_image_token,
|
| 29 |
+
)
|
| 30 |
+
|
| 31 |
+
import spaces
|
| 32 |
+
from huggingface_hub import snapshot_download
|
| 33 |
+
|
| 34 |
+
def get_real_cpu_cores():
|
| 35 |
+
"""use multiple threads for video decoding"""
|
| 36 |
+
try:
|
| 37 |
+
# HF Spaces
|
| 38 |
+
cores = len(os.sched_getaffinity(0))
|
| 39 |
+
except AttributeError:
|
| 40 |
+
# Local environments
|
| 41 |
+
cores = multiprocessing.cpu_count()
|
| 42 |
+
return cores
|
| 43 |
+
|
| 44 |
+
def compute_sample_indices(
|
| 45 |
+
total_frames: int,
|
| 46 |
+
original_fps: float,
|
| 47 |
+
video_fps: float = 2.0,
|
| 48 |
+
min_frames_num: int = 4,
|
| 49 |
+
max_frames_num: int = 1024
|
| 50 |
+
) -> list[int]:
|
| 51 |
+
|
| 52 |
+
start_frame, end_frame = 0, total_frames - 1
|
| 53 |
+
clip_frames = end_frame - start_frame + 1
|
| 54 |
+
if clip_frames <= 1:
|
| 55 |
+
return [start_frame]
|
| 56 |
+
|
| 57 |
+
if original_fps is None or original_fps <= 0:
|
| 58 |
+
original_fps = video_fps
|
| 59 |
+
|
| 60 |
+
clip_duration = clip_frames / original_fps
|
| 61 |
+
target_num_frames = max(1, round(clip_duration * video_fps))
|
| 62 |
+
final_num_frames = min(max(target_num_frames, min_frames_num), max_frames_num)
|
| 63 |
+
|
| 64 |
+
if final_num_frames == 1:
|
| 65 |
+
return [end_frame]
|
| 66 |
+
|
| 67 |
+
indices = np.round(np.linspace(start_frame, end_frame, final_num_frames)).astype(int)
|
| 68 |
+
indices = np.clip(indices, start_frame, end_frame)
|
| 69 |
+
|
| 70 |
+
return indices.tolist()
|
| 71 |
+
|
| 72 |
+
def load_video(video_path: str, video_fps: float = 2.0, max_frames: int = 1024) -> tuple:
|
| 73 |
+
|
| 74 |
+
available_cores = get_real_cpu_cores()
|
| 75 |
+
optimal_threads = min(max(1, available_cores - 1), 16)
|
| 76 |
+
print(f"[Profiling] Detected {available_cores} CPU cores. Decord using {optimal_threads} threads.")
|
| 77 |
+
|
| 78 |
+
vr = VideoReader(video_path, ctx=cpu(0), num_threads=optimal_threads)
|
| 79 |
+
total_frames = len(vr)
|
| 80 |
+
original_fps = vr.get_avg_fps()
|
| 81 |
+
frame_idx = compute_sample_indices(total_frames, original_fps, video_fps, max_frames_num=max_frames)
|
| 82 |
+
images = vr.get_batch(frame_idx).asnumpy()
|
| 83 |
+
clip_duration = total_frames / original_fps
|
| 84 |
+
|
| 85 |
+
real_fps = len(images) / clip_duration if clip_duration > 0 else video_fps
|
| 86 |
+
|
| 87 |
+
return images, real_fps
|
| 88 |
+
|
| 89 |
+
def generate_allocation_plot(allocations):
|
| 90 |
+
"""
|
| 91 |
+
Token allocation visualization function
|
| 92 |
+
"""
|
| 93 |
+
if allocations is None or len(allocations) == 0:
|
| 94 |
+
# if disable_dynamic_compress is True, we return a blank image
|
| 95 |
+
return Image.new('RGB', (1600, 350), color='white')
|
| 96 |
+
|
| 97 |
+
allocations = np.array(allocations)
|
| 98 |
+
num_segments = len(allocations)
|
| 99 |
+
|
| 100 |
+
plt.rcParams.update({'font.size': 14, 'font.family': 'serif'})
|
| 101 |
+
fig = plt.figure(figsize=(16, 3.5), layout='constrained')
|
| 102 |
+
gs = fig.add_gridspec(2, 1, height_ratios=[0.15, 1.0], hspace=0.05)
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
ax_heat = fig.add_subplot(gs[0])
|
| 106 |
+
ax_heat.set_title(" ", pad=50)
|
| 107 |
+
|
| 108 |
+
colors = ["#EBF5FB", "#85C1E9", "#F2D7D5", "#E74C3C", "#641E16"]
|
| 109 |
+
cmap_custom = mcolors.LinearSegmentedColormap.from_list("custom_heat", colors)
|
| 110 |
+
|
| 111 |
+
vmax_val = max(128, allocations.max())
|
| 112 |
+
ax_heat.imshow([allocations], cmap=cmap_custom, aspect='auto', extent=[0.5, num_segments + 0.5, 0, 1], vmin=4, vmax=vmax_val)
|
| 113 |
+
ax_heat.set_yticks([])
|
| 114 |
+
ax_heat.tick_params(axis='x', which='both', bottom=False, top=False, labelbottom=False)
|
| 115 |
+
for spine in ax_heat.spines.values():
|
| 116 |
+
spine.set_linewidth(1.2)
|
| 117 |
+
|
| 118 |
+
ax_line = fig.add_subplot(gs[1], sharex=ax_heat)
|
| 119 |
+
|
| 120 |
+
x = np.arange(1, num_segments + 1)
|
| 121 |
+
|
| 122 |
+
if num_segments > 3:
|
| 123 |
+
spl = make_interp_spline(x, allocations, k=3)
|
| 124 |
+
x_smooth = np.linspace(1, num_segments, 800)
|
| 125 |
+
y_smooth = spl(x_smooth)
|
| 126 |
+
y_smooth = np.clip(y_smooth, 4, vmax_val)
|
| 127 |
+
else:
|
| 128 |
+
x_smooth = x
|
| 129 |
+
y_smooth = allocations
|
| 130 |
+
|
| 131 |
+
line_color = '#1A252C'
|
| 132 |
+
fill_color = '#D5D8DC'
|
| 133 |
+
|
| 134 |
+
ax_line.plot(x_smooth, y_smooth, color=line_color, linewidth=2.0)
|
| 135 |
+
ax_line.fill_between(x_smooth, y_smooth, color=fill_color, alpha=0.4)
|
| 136 |
+
|
| 137 |
+
ax_line.axhline(vmax_val, color='#C0392B', linestyle='--', linewidth=1.2, alpha=0.8)
|
| 138 |
+
ax_line.axhline(4, color='#2980B9', linestyle='--', linewidth=1.2, alpha=0.8)
|
| 139 |
+
|
| 140 |
+
ax_line.set_xlim(0.5, num_segments + 0.5)
|
| 141 |
+
ax_line.set_ylim(0, vmax_val + 12)
|
| 142 |
+
ax_line.set_ylabel("Tokens / Seg", fontsize=14, fontweight='bold')
|
| 143 |
+
ax_line.set_xlabel("Temporal Segments", fontsize=14, fontweight='bold')
|
| 144 |
+
ax_line.xaxis.set_major_locator(ticker.MaxNLocator(integer=True))
|
| 145 |
+
|
| 146 |
+
ax_line.spines['top'].set_visible(False)
|
| 147 |
+
ax_line.spines['right'].set_visible(False)
|
| 148 |
+
ax_line.spines['bottom'].set_linewidth(1.2)
|
| 149 |
+
ax_line.spines['left'].set_linewidth(1.2)
|
| 150 |
+
ax_line.grid(axis='y', linestyle=':', color='gray', alpha=0.5)
|
| 151 |
+
|
| 152 |
+
buf = io.BytesIO()
|
| 153 |
+
plt.savefig(buf, format='png', bbox_inches='tight', dpi=100, transparent=True)
|
| 154 |
+
plt.close(fig)
|
| 155 |
+
buf.seek(0)
|
| 156 |
+
return Image.open(buf)
|
| 157 |
+
|
| 158 |
+
model_id = "Vision-CAIR/Tempo-6B"
|
| 159 |
+
print(f"[Init] Downloading/Loading weights from {model_id}...")
|
| 160 |
+
MODEL_PATH = snapshot_download(repo_id=model_id)
|
| 161 |
+
print(f"[Init] Loading Tempo model from {MODEL_PATH}...")
|
| 162 |
+
tokenizer, model, image_processor = load_pretrained_model(
|
| 163 |
+
MODEL_PATH,
|
| 164 |
+
device_map="cuda",
|
| 165 |
+
use_flash_attn=True
|
| 166 |
+
)
|
| 167 |
+
|
| 168 |
+
FIXED_MAX_LENGTH = 16384
|
| 169 |
+
model.config.tokenizer_model_max_length = FIXED_MAX_LENGTH
|
| 170 |
+
tokenizer.model_max_length = FIXED_MAX_LENGTH
|
| 171 |
+
model.eval()
|
| 172 |
+
model.to(torch.bfloat16)
|
| 173 |
+
print(f"[Init] Model loaded! Max context length set to {FIXED_MAX_LENGTH}.")
|
| 174 |
+
|
| 175 |
+
|
| 176 |
+
# ==========================================
|
| 177 |
+
# inference
|
| 178 |
+
# ==========================================
|
| 179 |
+
@spaces.GPU
|
| 180 |
+
def predict(video_path, query, max_frames, visual_token_budget, temperature, max_new_tokens, disable_dynamic_compress):
|
| 181 |
+
if not video_path:
|
| 182 |
+
return "β οΈ Error: Please upload a video first."
|
| 183 |
+
if not query:
|
| 184 |
+
return "β οΈ Error: Please enter a question."
|
| 185 |
+
|
| 186 |
+
print(f"\n[Request] Video: {video_path} | Query: {query}")
|
| 187 |
+
|
| 188 |
+
model.config.visual_token_budget = int(visual_token_budget)
|
| 189 |
+
model.get_vision_tower_aux_list()[0].dynamic_compress = not disable_dynamic_compress
|
| 190 |
+
|
| 191 |
+
# video process
|
| 192 |
+
start_prep_time = time.perf_counter()
|
| 193 |
+
try:
|
| 194 |
+
video_frames, real_fps = load_video(video_path, video_fps=2.0, max_frames=int(max_frames))
|
| 195 |
+
except Exception as e:
|
| 196 |
+
return f"β οΈ Error loading video: {str(e)}"
|
| 197 |
+
|
| 198 |
+
# process local compressor inputs
|
| 199 |
+
frame_windows, frame_stride = 8, 8
|
| 200 |
+
vlm_inputs = process_qwen_content(
|
| 201 |
+
video_frames, "video", query, image_processor[0], real_fps, frame_windows, frame_stride, is_eval=True
|
| 202 |
+
)
|
| 203 |
+
vlm_inputs = {key: v.cuda() for key, v in vlm_inputs.items()}
|
| 204 |
+
|
| 205 |
+
# compute timestamp for each segment
|
| 206 |
+
seg_timestamps = compute_segment_timestamp(
|
| 207 |
+
len(vlm_inputs["video_grid_thw"]), tokenizer, real_fps, frame_stride, frame_windows
|
| 208 |
+
)
|
| 209 |
+
|
| 210 |
+
# stat info
|
| 211 |
+
num_segments = len(vlm_inputs["video_grid_thw"])
|
| 212 |
+
segment_duration = frame_windows / real_fps
|
| 213 |
+
stats_info = f"π¬ Video Stats: Total Segments: {num_segments} | Segment Duration: {segment_duration:.2f}s | Real FPS: {real_fps:.2f}"
|
| 214 |
+
|
| 215 |
+
# prompt
|
| 216 |
+
if getattr(model.config, "mm_use_im_start_end", False):
|
| 217 |
+
qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + "\n" + query
|
| 218 |
+
else:
|
| 219 |
+
qs = DEFAULT_IMAGE_TOKEN + "\n" + query
|
| 220 |
+
|
| 221 |
+
conv_version = "qwen"
|
| 222 |
+
conv = conv_templates[conv_version].copy()
|
| 223 |
+
conv.append_message(conv.roles[0], qs)
|
| 224 |
+
conv.append_message(conv.roles[1], None)
|
| 225 |
+
prompt = conv.get_prompt()
|
| 226 |
+
stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
|
| 227 |
+
|
| 228 |
+
# tokenization
|
| 229 |
+
input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt").unsqueeze(0).cuda()
|
| 230 |
+
stopping_criteria = KeywordsStoppingCriteria([stop_str], tokenizer, input_ids)
|
| 231 |
+
|
| 232 |
+
model._demo_count_allocations = []
|
| 233 |
+
|
| 234 |
+
start_infer_time = time.perf_counter()
|
| 235 |
+
|
| 236 |
+
# generating
|
| 237 |
+
with torch.inference_mode():
|
| 238 |
+
output_ids = model.generate(
|
| 239 |
+
input_ids,
|
| 240 |
+
images=None, # Qwen-VL architecture usually uses vlm_inputs instead of raw images in kwargs if projector is vlm
|
| 241 |
+
image_sizes=None,
|
| 242 |
+
do_sample=(temperature > 0),
|
| 243 |
+
temperature=temperature if temperature > 0 else None,
|
| 244 |
+
max_new_tokens=int(max_new_tokens),
|
| 245 |
+
use_cache=True,
|
| 246 |
+
stopping_criteria=[stopping_criteria],
|
| 247 |
+
vlm_inputs=vlm_inputs,
|
| 248 |
+
seg_timestamps=seg_timestamps,
|
| 249 |
+
)
|
| 250 |
+
|
| 251 |
+
end_infer_time = time.perf_counter()
|
| 252 |
+
|
| 253 |
+
if isinstance(output_ids, tuple):
|
| 254 |
+
output_ids = output_ids[0]
|
| 255 |
+
|
| 256 |
+
prep_duration = start_infer_time - start_prep_time
|
| 257 |
+
infer_duration = end_infer_time - start_infer_time
|
| 258 |
+
total_duration = end_infer_time - start_prep_time
|
| 259 |
+
stats_info += f"\nβ‘ Profiling : Prep Time: {prep_duration:.2f}s | Inference Time: {infer_duration:.2f}s | Total: {total_duration:.2f}s"
|
| 260 |
+
|
| 261 |
+
pred = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0].strip()
|
| 262 |
+
if pred.endswith(stop_str):
|
| 263 |
+
pred = pred[: -len(stop_str)].strip()
|
| 264 |
+
|
| 265 |
+
# token allocation plot
|
| 266 |
+
allocations_data = model._demo_count_allocations
|
| 267 |
+
plot_img = generate_allocation_plot(allocations_data)
|
| 268 |
+
|
| 269 |
+
return pred, plot_img, stats_info
|
| 270 |
+
|
| 271 |
+
# ==========================================
|
| 272 |
+
# UI
|
| 273 |
+
# ==========================================
|
| 274 |
+
with gr.Blocks(title="Tempo Video Understanding", theme=gr.themes.Soft()) as demo:
|
| 275 |
+
gr.Markdown(
|
| 276 |
+
"""
|
| 277 |
+
# β±οΈ Tempo: Small Vision-Language Models are Smart Compressors for Long Video Understanding
|
| 278 |
+
Upload a video and ask any question! Tempo dynamically compresses visual tokens based on your query to achieve SOTA performance.
|
| 279 |
+
**[π Project Page](https://feielysia.github.io/)** | **[π» GitHub](https://github.com/FeiElysia)** | **[π Paper](https://arxiv.org/abs/xxxx)** | **[π¨βπ» @Junjie Fei](https://feielysia.github.io/)**
|
| 280 |
+
|
| 281 |
+
*β³ **Slow preprocessing?** Try Examples 4 & 5 below, decrease `Max Sampled Frames` in Advanced Settings, or check our [GitHub](https://github.com/FeiElysia) for full-speed local deployment.*
|
| 282 |
+
"""
|
| 283 |
+
)
|
| 284 |
+
|
| 285 |
+
with gr.Row():
|
| 286 |
+
# left column: inputs
|
| 287 |
+
with gr.Column(scale=1):
|
| 288 |
+
video_input = gr.Video(label="Upload Video")
|
| 289 |
+
example_poster = gr.Image(label="Video Poster", interactive=False, height=150, visible=False)
|
| 290 |
+
query_input = gr.Textbox(label="Your Question", placeholder="e.g., What is the person doing in the video?", lines=3)
|
| 291 |
+
with gr.Row():
|
| 292 |
+
clear_btn = gr.Button("π§Ή Clear", variant="secondary")
|
| 293 |
+
submit_btn = gr.Button("π Generate Response", variant="primary")
|
| 294 |
+
|
| 295 |
+
# hyperparameters
|
| 296 |
+
with gr.Accordion("Advanced Settings", open=False):
|
| 297 |
+
max_frames_slider = gr.Slider(minimum=16, maximum=2048, value=1024, step=16, label="Max Sampled Frames")
|
| 298 |
+
budget_slider = gr.Slider(minimum=64, maximum=16384, value=8192, step=64, label="Visual Token Budget")
|
| 299 |
+
temp_slider = gr.Slider(minimum=0.0, maximum=1.0, value=0.0, step=0.1, label="Temperature (0 = Greedy)")
|
| 300 |
+
max_tokens_slider = gr.Slider(minimum=64, maximum=4096, value=1024, step=64, label="Max New Tokens")
|
| 301 |
+
disable_compress_chk = gr.Checkbox(label="Disable Dynamic Compression (Baseline)", value=False)
|
| 302 |
+
|
| 303 |
+
# right column: outputs
|
| 304 |
+
with gr.Column(scale=1):
|
| 305 |
+
output_text = gr.Textbox(label="Tempo Response", lines=12, interactive=False)
|
| 306 |
+
stats_text = gr.Textbox(label="π Video Segment Stats", lines=1, interactive=False)
|
| 307 |
+
output_plot = gr.Image(label="Query-Aware Visual Feature Intensity (Visual Token Allocation)", interactive=False, height=180)
|
| 308 |
+
|
| 309 |
+
# clicking submit_btn or pressing enter in query_input will trigger prediction
|
| 310 |
+
submit_btn.click(
|
| 311 |
+
fn=predict,
|
| 312 |
+
inputs=[video_input, query_input, max_frames_slider, budget_slider, temp_slider, max_tokens_slider, disable_compress_chk],
|
| 313 |
+
outputs=[output_text, output_plot, stats_text]
|
| 314 |
+
)
|
| 315 |
+
|
| 316 |
+
query_input.submit(
|
| 317 |
+
fn=predict,
|
| 318 |
+
inputs=[video_input, query_input, max_frames_slider, budget_slider, temp_slider, max_tokens_slider, disable_compress_chk],
|
| 319 |
+
outputs=[output_text, output_plot, stats_text]
|
| 320 |
+
)
|
| 321 |
+
|
| 322 |
+
clear_btn.click(
|
| 323 |
+
fn=lambda: (None, None, None, None, None, None),
|
| 324 |
+
inputs=None,
|
| 325 |
+
outputs=[video_input, example_poster, query_input, output_text, stats_text, output_plot]
|
| 326 |
+
)
|
| 327 |
+
|
| 328 |
+
# Examples
|
| 329 |
+
gr.Markdown("---")
|
| 330 |
+
gr.Markdown("### π‘ Try an Example")
|
| 331 |
+
gr.Examples(
|
| 332 |
+
examples=[
|
| 333 |
+
[
|
| 334 |
+
"examples/hsr_helloworld.mp4",
|
| 335 |
+
"Task: Please examine the provided media and answer the following three questions regarding the specific puppy in the scene:\n"
|
| 336 |
+
"Q1: What is the primary fur color of the puppy positioned on the swing?\n"
|
| 337 |
+
"Q2: Specify the exact time interval (in seconds, e.g., XX-XXs) during which the puppy is seen sitting on the swing.\n"
|
| 338 |
+
"Q3: Provide a brief description of the puppy's appearance and its surroundings.",
|
| 339 |
+
"examples/meme_hsr_helloworld.png"
|
| 340 |
+
],
|
| 341 |
+
[
|
| 342 |
+
"examples/hsr_helloworld.mp4",
|
| 343 |
+
"Task: Please analyze the provided video and answer the following 7 questions precisely.\n"
|
| 344 |
+
"Q1: How many performers are visible on the stage?\n"
|
| 345 |
+
"Q2: Describe the architectural elements in the background. What historical civilization do they remind you of?\n"
|
| 346 |
+
"Q3: What is happening in the night sky above the performers, and what does this suggest about the event?\n"
|
| 347 |
+
"Q4: List the hair colors of the performers in order from left to right.\n"
|
| 348 |
+
"Q5: Identify the specific musical instrument being played by the performer located on the far left of the stage.\n"
|
| 349 |
+
"Q6: What is the specific time interval (in seconds, e.g., XX-XXs) during which this fireworks performance scene occurs in the video?\n"
|
| 350 |
+
"Q7: Look at the audience in the foreground. How does their silhouette-like depiction affect the viewer's perspective of the stage?",
|
| 351 |
+
"examples/performance_hsr_helloworld.png"
|
| 352 |
+
],
|
| 353 |
+
[
|
| 354 |
+
"examples/honkai3_becauseofyou.mp4",
|
| 355 |
+
"What text appears in the center of the video behind a sea of pink flowers?",
|
| 356 |
+
"examples/ocr_honkai3_becauseofyou.png"
|
| 357 |
+
],
|
| 358 |
+
[
|
| 359 |
+
"examples/videomme_fFjv93ACGo8.mp4",
|
| 360 |
+
"How many red socks are above the fireplace at the end of this video?",
|
| 361 |
+
"examples/cover_videomme_fFjv93ACGo8.png"
|
| 362 |
+
],
|
| 363 |
+
[
|
| 364 |
+
"examples/videomme_FjS2LzrHEO8.mp4",
|
| 365 |
+
"What was the purpose of using a hammer to hit the car in the video?\n"
|
| 366 |
+
"A. To show the hammer works well.\n"
|
| 367 |
+
"B. To show the solidity of the car.\n"
|
| 368 |
+
"C. To warn people not to hit cars with hammers.\n"
|
| 369 |
+
"D. To illustrate that a hammer is harder than a bullet.",
|
| 370 |
+
"examples/cover_videomme_FjS2LzrHEO8.png"
|
| 371 |
+
],
|
| 372 |
+
[
|
| 373 |
+
"examples/honkai3_becauseofyou.mp4",
|
| 374 |
+
"Describe the video in detail.",
|
| 375 |
+
"examples/description_honkai3_becauseofyou.png"
|
| 376 |
+
]
|
| 377 |
+
],
|
| 378 |
+
inputs=[video_input, query_input, example_poster],
|
| 379 |
+
cache_examples=False,
|
| 380 |
+
)
|
| 381 |
+
|
| 382 |
+
if __name__ == "__main__":
|
| 383 |
+
demo.queue().launch(share=True)
|
| 384 |
+
|
examples/cover_videomme_FjS2LzrHEO8.png
ADDED
|
Git LFS Details
|
examples/cover_videomme_fFjv93ACGo8.png
ADDED
|
Git LFS Details
|
examples/demo.mp4
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:ec316c8e5fe7f2a62137060c0a35123de75802522816c74478c69dbc329d45da
|
| 3 |
+
size 85447789
|
examples/demo_cases.json
ADDED
|
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[
|
| 2 |
+
{
|
| 3 |
+
"video_path": "./examples/hsr_helloworld.mp4",
|
| 4 |
+
"query": "Task: Please examine the provided media and answer the following three questions regarding the specific puppy in the scene:\nQ1: What is the primary fur color of the puppy positioned on the swing?\nQ2: Specify the exact time interval (in seconds, e.g., XX-XXs) during which the puppy is seen sitting on the swing.\nQ3: Provide a brief description of the puppy's appearance and its surroundings."
|
| 5 |
+
},
|
| 6 |
+
{
|
| 7 |
+
"video_path": "./examples/hsr_helloworld.mp4",
|
| 8 |
+
"query": "Task: Please analyze the provided video and answer the following 7 questions precisely.\nQ1: How many performers are visible on the stage?\nQ2: Describe the architectural elements in the background. What historical civilization do they remind you of?\nQ3: What is happening in the night sky above the performers, and what does this suggest about the event?\nQ4: List the hair colors of the performers in order from left to right.\nQ5: Identify the specific musical instrument being played by the performer located on the far left of the stage.\nQ6: What is the specific time interval (in seconds, e.g., XX-XXs) during which this fireworks performance scene occurs in the video?\nQ7: Look at the audience in the foreground. How does their silhouette-like depiction affect the viewer's perspective of the stage?"
|
| 9 |
+
},
|
| 10 |
+
{
|
| 11 |
+
"video_path": "./examples/honkai3_becauseofyou.mp4",
|
| 12 |
+
"query": "What text appears in the center of the video behind a sea of pink flowers?"
|
| 13 |
+
},
|
| 14 |
+
{
|
| 15 |
+
"video_path": "./examples/honkai3_becauseofyou.mp4",
|
| 16 |
+
"query": "Describe the video in detail."
|
| 17 |
+
},
|
| 18 |
+
{
|
| 19 |
+
"video_path": "./examples/videomme_fFjv93ACGo8.mp4",
|
| 20 |
+
"query": "What colors are the clothes worn by the two announcers in the studio?"
|
| 21 |
+
},
|
| 22 |
+
{
|
| 23 |
+
"video_path": "./examples/videomme_FjS2LzrHEO8.mp4",
|
| 24 |
+
"query": "What was the purpose of using a hammer to hit the car in the video?\nA. To show the hammer works well.\nB. To show the solidity of the car.\nC. To warn people not to hit cars with hammers.\nD. To illustrate that a hammer is harder than a bullet."
|
| 25 |
+
},
|
| 26 |
+
{
|
| 27 |
+
"video_path": "./examples/videomme_FsLaTZmP6Uw.mp4",
|
| 28 |
+
"query": "Which year was the game held?"
|
| 29 |
+
},
|
| 30 |
+
{
|
| 31 |
+
"video_path": "./examples/videomme_Sp2nxlrQ89w.mp4",
|
| 32 |
+
"query": "In line with the video evidence, why does the orange stickman want to destroy the Minecraft world?\nA. He wants to save his son.\nB. He is too sad.\nC. He loses his son.\nD. He does like the world."
|
| 33 |
+
},
|
| 34 |
+
{
|
| 35 |
+
"video_path": "./examples/lvbench_gXnhqF0TqqI.mp4",
|
| 36 |
+
"query": "Where are the woman and children when they first appear in the video?"
|
| 37 |
+
}
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
]
|
examples/description_honkai3_becauseofyou.png
ADDED
|
Git LFS Details
|
examples/honkai3_becauseofyou.mp4
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:12f943f66683a6a5a49888d4651dd868d75ce4be4fc35ea5af0853b5921de62f
|
| 3 |
+
size 43919755
|
examples/hsr_helloworld.mp4
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:fd2fb2c3a84719bb586925e68acb70de282020268a42bfef53635b88c8e6afcd
|
| 3 |
+
size 17860778
|
examples/lvbench_gXnhqF0TqqI.mp4
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:7135161f9b45bf968ec320cc27b3c381593f47dbd8f17013a11a59da50980315
|
| 3 |
+
size 276007868
|
examples/meme_hsr_helloworld.png
ADDED
|
Git LFS Details
|
examples/ocr_honkai3_becauseofyou.png
ADDED
|
Git LFS Details
|
examples/performance_hsr_helloworld.png
ADDED
|
Git LFS Details
|
examples/tempo.png
ADDED
|
Git LFS Details
|
examples/tempo.svg
ADDED
|
|
examples/videomme_FjS2LzrHEO8.mp4
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:43b9cb496d7e3c56b9de35e777df206f84c99a907bdbd569a788ba098ffd46e8
|
| 3 |
+
size 6328020
|
examples/videomme_FsLaTZmP6Uw.mp4
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:c68ee0eab9c6ef905de336222e127e7b393e0a5d862c2cdc5997f50b3b130d48
|
| 3 |
+
size 8838644
|
examples/videomme_Sp2nxlrQ89w.mp4
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:67436aa59d44434e2b115bc5d6e07ea7e06a6de1bdfaea7c91ca320fbdd787b3
|
| 3 |
+
size 136659511
|
examples/videomme_fFjv93ACGo8.mp4
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:ba82f02cf6acc6a25c6efc4783fc2252c65afbedb8dbc4c7148af08834fdf999
|
| 3 |
+
size 17126035
|
packages.txt
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
ffmpeg
|
requirements.txt
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
spaces
|
| 2 |
+
qwen-vl-utils==0.0.14
|
| 3 |
+
transformers==4.57.1
|
| 4 |
+
accelerate==1.13.0
|
| 5 |
+
gradio
|
| 6 |
+
matplotlib
|
| 7 |
+
scipy
|
| 8 |
+
huggingface_hub
|
| 9 |
+
decord
|
| 10 |
+
av
|
| 11 |
+
flash-attn==2.7.4.post1
|
tempo/__init__.py
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
from tempo.language_model.modeling_tempo_qwen import TempoConfig, TempoQwenForCausalLM
|
| 3 |
+
__all__ = [
|
| 4 |
+
"TempoConfig",
|
| 5 |
+
"TempoQwenForCausalLM",
|
| 6 |
+
]
|
tempo/__pycache__/__init__.cpython-312.pyc
ADDED
|
Binary file (305 Bytes). View file
|
|
|
tempo/__pycache__/builder.cpython-312.pyc
ADDED
|
Binary file (2.14 kB). View file
|
|
|
tempo/__pycache__/constants.cpython-312.pyc
ADDED
|
Binary file (556 Bytes). View file
|
|
|
tempo/__pycache__/conversation.cpython-312.pyc
ADDED
|
Binary file (20.3 kB). View file
|
|
|
tempo/__pycache__/mm_datautils.cpython-312.pyc
ADDED
|
Binary file (57.1 kB). View file
|
|
|
tempo/__pycache__/mm_utils.cpython-312.pyc
ADDED
|
Binary file (3.07 kB). View file
|
|
|
tempo/__pycache__/tempo_arch.cpython-312.pyc
ADDED
|
Binary file (17.5 kB). View file
|
|
|
tempo/__pycache__/vlm_multimodal_processor.cpython-312.pyc
ADDED
|
Binary file (14.1 kB). View file
|
|
|
tempo/builder.py
ADDED
|
@@ -0,0 +1,62 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2023 Haotian Liu
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
import torch
|
| 16 |
+
from tempo.constants import (
|
| 17 |
+
DEFAULT_IM_END_TOKEN,
|
| 18 |
+
DEFAULT_IM_START_TOKEN,
|
| 19 |
+
DEFAULT_IMAGE_PATCH_TOKEN,
|
| 20 |
+
)
|
| 21 |
+
from transformers import AutoTokenizer
|
| 22 |
+
|
| 23 |
+
from tempo.language_model.modeling_tempo_qwen import TempoQwenForCausalLM
|
| 24 |
+
|
| 25 |
+
def load_pretrained_model(
|
| 26 |
+
model_path,
|
| 27 |
+
device_map="auto",
|
| 28 |
+
device="cuda",
|
| 29 |
+
use_flash_attn=False,
|
| 30 |
+
**kwargs,
|
| 31 |
+
):
|
| 32 |
+
kwargs = {"device_map": device_map, **kwargs}
|
| 33 |
+
|
| 34 |
+
if device != "cuda":
|
| 35 |
+
kwargs["device_map"] = {"": device}
|
| 36 |
+
|
| 37 |
+
kwargs["dtype"] = torch.float16
|
| 38 |
+
if use_flash_attn:
|
| 39 |
+
kwargs["attn_implementation"] = "flash_attention_2"
|
| 40 |
+
|
| 41 |
+
tokenizer = AutoTokenizer.from_pretrained(model_path)
|
| 42 |
+
model = TempoQwenForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs)
|
| 43 |
+
add_tokens_flag = False
|
| 44 |
+
if getattr(model.config, "mm_use_im_patch_token", False):
|
| 45 |
+
add_tokens_flag = True
|
| 46 |
+
tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
|
| 47 |
+
if getattr(model.config, "mm_use_im_start_end", False):
|
| 48 |
+
add_tokens_flag = True
|
| 49 |
+
tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True)
|
| 50 |
+
if add_tokens_flag:
|
| 51 |
+
model.resize_token_embeddings(len(tokenizer))
|
| 52 |
+
|
| 53 |
+
vision_tower_aux_list = model.get_vision_tower_aux_list()
|
| 54 |
+
for vision_tower_aux in vision_tower_aux_list:
|
| 55 |
+
if not vision_tower_aux.is_loaded:
|
| 56 |
+
vision_tower_aux.load_model(device_map=device_map)
|
| 57 |
+
vision_tower_aux.to(device=device, dtype=torch.float16)
|
| 58 |
+
|
| 59 |
+
image_processor = None
|
| 60 |
+
image_processor = [vision_tower_aux.image_processor for vision_tower_aux in vision_tower_aux_list]
|
| 61 |
+
|
| 62 |
+
return tokenizer, model, image_processor
|
tempo/constants.py
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
CONTROLLER_HEART_BEAT_EXPIRATION = 30
|
| 2 |
+
WORKER_HEART_BEAT_INTERVAL = 15
|
| 3 |
+
|
| 4 |
+
LOGDIR = "."
|
| 5 |
+
|
| 6 |
+
# Model Constants
|
| 7 |
+
IGNORE_INDEX = -100
|
| 8 |
+
IMAGE_TOKEN_INDEX = -200
|
| 9 |
+
DEFAULT_IMAGE_TOKEN = "<image>"
|
| 10 |
+
DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>"
|
| 11 |
+
DEFAULT_IM_START_TOKEN = "<im_start>"
|
| 12 |
+
DEFAULT_IM_END_TOKEN = "<im_end>"
|
| 13 |
+
IMAGE_PLACEHOLDER = "<image-placeholder>"
|
tempo/conversation.py
ADDED
|
@@ -0,0 +1,545 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import base64
|
| 2 |
+
import dataclasses
|
| 3 |
+
from io import BytesIO
|
| 4 |
+
from enum import auto, Enum
|
| 5 |
+
from typing import Any, Union
|
| 6 |
+
|
| 7 |
+
from PIL import Image
|
| 8 |
+
from transformers import AutoTokenizer
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class SeparatorStyle(Enum):
|
| 12 |
+
"""Different separator style."""
|
| 13 |
+
|
| 14 |
+
SINGLE = auto()
|
| 15 |
+
TWO = auto()
|
| 16 |
+
MPT = auto()
|
| 17 |
+
PLAIN = auto()
|
| 18 |
+
LLAMA_2 = auto()
|
| 19 |
+
LLAMA_3 = auto()
|
| 20 |
+
LLAMA_3_1 = auto()
|
| 21 |
+
LLAMA_3_2 = auto()
|
| 22 |
+
QWEN = auto()
|
| 23 |
+
CHATML = auto()
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
@dataclasses.dataclass
|
| 27 |
+
class Conversation:
|
| 28 |
+
"""A class that keeps all conversation history."""
|
| 29 |
+
|
| 30 |
+
system: str
|
| 31 |
+
roles: list[str]
|
| 32 |
+
messages: list[list[str]]
|
| 33 |
+
offset: int
|
| 34 |
+
sep_style: SeparatorStyle = SeparatorStyle.SINGLE
|
| 35 |
+
sep: str = "###"
|
| 36 |
+
sep2: str = None
|
| 37 |
+
version: str = "Unknown"
|
| 38 |
+
|
| 39 |
+
tokenizer: Any = None
|
| 40 |
+
# Stop criteria (the default one is EOS token)
|
| 41 |
+
stop_str: Union[str, list[str]] = None
|
| 42 |
+
# Stops generation if meeting any token in this list
|
| 43 |
+
stop_token_ids: list[int] = None
|
| 44 |
+
|
| 45 |
+
skip_next: bool = False
|
| 46 |
+
|
| 47 |
+
def get_prompt(self):
|
| 48 |
+
messages = self.messages
|
| 49 |
+
if len(messages) > 0 and type(messages[0][1]) is tuple:
|
| 50 |
+
messages = self.messages.copy()
|
| 51 |
+
init_role, init_msg = messages[0].copy()
|
| 52 |
+
init_msg = init_msg[0].replace("<image>", "").strip()
|
| 53 |
+
if "mmtag" in self.version:
|
| 54 |
+
messages[0] = (init_role, init_msg)
|
| 55 |
+
messages.insert(0, (self.roles[0], "<Image><image></Image>"))
|
| 56 |
+
messages.insert(1, (self.roles[1], "Received."))
|
| 57 |
+
else:
|
| 58 |
+
messages[0] = (init_role, "<image>\n" + init_msg)
|
| 59 |
+
|
| 60 |
+
if self.sep_style == SeparatorStyle.SINGLE:
|
| 61 |
+
ret = self.system + self.sep
|
| 62 |
+
for role, message in messages:
|
| 63 |
+
if message:
|
| 64 |
+
if type(message) is tuple:
|
| 65 |
+
message, _, _ = message
|
| 66 |
+
ret += role + ": " + message + self.sep
|
| 67 |
+
else:
|
| 68 |
+
ret += role + ":"
|
| 69 |
+
|
| 70 |
+
elif self.sep_style == SeparatorStyle.TWO:
|
| 71 |
+
seps = [self.sep, self.sep2]
|
| 72 |
+
ret = self.system + seps[0]
|
| 73 |
+
for i, (role, message) in enumerate(messages):
|
| 74 |
+
if message:
|
| 75 |
+
if type(message) is tuple:
|
| 76 |
+
message, _, _ = message
|
| 77 |
+
ret += role + ": " + message + seps[i % 2]
|
| 78 |
+
else:
|
| 79 |
+
ret += role + ":"
|
| 80 |
+
|
| 81 |
+
elif self.sep_style == SeparatorStyle.CHATML:
|
| 82 |
+
ret = "" if self.system == "" else self.system + self.sep + "\n"
|
| 83 |
+
for role, message in messages:
|
| 84 |
+
if message:
|
| 85 |
+
if type(message) is tuple:
|
| 86 |
+
message, images, _ = message
|
| 87 |
+
message = "<image>" * len(images) + message
|
| 88 |
+
ret += role + "\n" + message + self.sep + "\n"
|
| 89 |
+
else:
|
| 90 |
+
ret += role + "\n"
|
| 91 |
+
return ret
|
| 92 |
+
|
| 93 |
+
elif self.sep_style == SeparatorStyle.MPT:
|
| 94 |
+
ret = self.system + self.sep
|
| 95 |
+
for role, message in messages:
|
| 96 |
+
if message:
|
| 97 |
+
if type(message) is tuple:
|
| 98 |
+
message, _, _ = message
|
| 99 |
+
ret += role + message + self.sep
|
| 100 |
+
else:
|
| 101 |
+
ret += role
|
| 102 |
+
|
| 103 |
+
elif self.sep_style == SeparatorStyle.LLAMA_2:
|
| 104 |
+
wrap_sys = lambda msg: (
|
| 105 |
+
f"<<SYS>>\n{msg}\n<</SYS>>\n\n" if len(msg) > 0 else msg
|
| 106 |
+
)
|
| 107 |
+
wrap_inst = lambda msg: f"[INST] {msg} [/INST]"
|
| 108 |
+
ret = ""
|
| 109 |
+
|
| 110 |
+
for i, (role, message) in enumerate(messages):
|
| 111 |
+
if i == 0:
|
| 112 |
+
assert message, "first message should not be none"
|
| 113 |
+
assert role == self.roles[0], "first message should come from user"
|
| 114 |
+
if message:
|
| 115 |
+
if type(message) is tuple:
|
| 116 |
+
message, _, _ = message
|
| 117 |
+
if i == 0:
|
| 118 |
+
message = wrap_sys(self.system) + message
|
| 119 |
+
if i % 2 == 0:
|
| 120 |
+
message = wrap_inst(message)
|
| 121 |
+
ret += self.sep + message
|
| 122 |
+
else:
|
| 123 |
+
ret += " " + message + " " + self.sep2
|
| 124 |
+
else:
|
| 125 |
+
ret += ""
|
| 126 |
+
ret = ret.lstrip(self.sep)
|
| 127 |
+
|
| 128 |
+
elif self.sep_style == SeparatorStyle.LLAMA_3:
|
| 129 |
+
if self.tokenizer is None:
|
| 130 |
+
self.tokenizer = AutoTokenizer.from_pretrained("//path/to/llama3/tokenizer")
|
| 131 |
+
chat_template_messages = [{"role": "system", "content": self.system}]
|
| 132 |
+
for role, message in messages:
|
| 133 |
+
if message:
|
| 134 |
+
if type(message) is tuple:
|
| 135 |
+
message, images = message
|
| 136 |
+
message = "<image>" * len(images) + message
|
| 137 |
+
chat_template_messages.append({"role": role, "content": message})
|
| 138 |
+
|
| 139 |
+
return self.tokenizer.apply_chat_template(
|
| 140 |
+
chat_template_messages, tokenize=False, add_generation_prompt=True
|
| 141 |
+
)
|
| 142 |
+
|
| 143 |
+
elif self.sep_style == SeparatorStyle.LLAMA_3_1:
|
| 144 |
+
if self.tokenizer is None:
|
| 145 |
+
self.tokenizer = AutoTokenizer.from_pretrained("//path/to/llama3.1/tokenizer")
|
| 146 |
+
chat_template_messages = [{"role": "system", "content": self.system}]
|
| 147 |
+
for role, message in messages:
|
| 148 |
+
if message:
|
| 149 |
+
if type(message) is tuple:
|
| 150 |
+
message, images = message
|
| 151 |
+
message = "<image>" * len(images) + message
|
| 152 |
+
chat_template_messages.append({"role": role, "content": message})
|
| 153 |
+
|
| 154 |
+
return self.tokenizer.apply_chat_template(
|
| 155 |
+
chat_template_messages, tokenize=False, add_generation_prompt=False
|
| 156 |
+
)
|
| 157 |
+
|
| 158 |
+
elif self.sep_style == SeparatorStyle.LLAMA_3_2:
|
| 159 |
+
wrap_sys = lambda msg: f"<|begin_of_text|><|start_header_id|>system<|end_header_id|>{msg}<|eot_id|>" if len(msg) > 0 else msg
|
| 160 |
+
wrap_inst_user = lambda msg: f"<|start_header_id|>user<|end_header_id|>{msg}<|eot_id|>"
|
| 161 |
+
wrap_inst_assistant = lambda msg: f"<|start_header_id|>assistant<|end_header_id|>{msg}<|eot_id|>"
|
| 162 |
+
ret = ""
|
| 163 |
+
|
| 164 |
+
for i, (role, message) in enumerate(messages):
|
| 165 |
+
if i == 0:
|
| 166 |
+
assert message, "first message should not be none"
|
| 167 |
+
assert role == self.roles[0], "first message should come from user"
|
| 168 |
+
if message:
|
| 169 |
+
if type(message) is tuple:
|
| 170 |
+
message, _, _ = message
|
| 171 |
+
if i == 0:
|
| 172 |
+
ret += wrap_sys(self.system)
|
| 173 |
+
|
| 174 |
+
if i % 2 == 0:
|
| 175 |
+
message = wrap_inst_user(message)
|
| 176 |
+
ret += message
|
| 177 |
+
else:
|
| 178 |
+
message = wrap_inst_assistant(message)
|
| 179 |
+
ret += message
|
| 180 |
+
else:
|
| 181 |
+
ret += ""
|
| 182 |
+
ret += "<|start_header_id|>assistant<|end_header_id|>"
|
| 183 |
+
|
| 184 |
+
elif self.sep_style == SeparatorStyle.PLAIN:
|
| 185 |
+
seps = [self.sep, self.sep2]
|
| 186 |
+
ret = self.system
|
| 187 |
+
for i, (role, message) in enumerate(messages):
|
| 188 |
+
if message:
|
| 189 |
+
if type(message) is tuple:
|
| 190 |
+
message, _, _ = message
|
| 191 |
+
ret += message + seps[i % 2]
|
| 192 |
+
else:
|
| 193 |
+
ret += ""
|
| 194 |
+
else:
|
| 195 |
+
raise ValueError(f"Invalid style: {self.sep_style}")
|
| 196 |
+
|
| 197 |
+
return ret
|
| 198 |
+
|
| 199 |
+
def append_message(self, role, message):
|
| 200 |
+
self.messages.append([role, message])
|
| 201 |
+
|
| 202 |
+
def process_image(
|
| 203 |
+
self,
|
| 204 |
+
image,
|
| 205 |
+
image_process_mode,
|
| 206 |
+
return_pil=False,
|
| 207 |
+
image_format="PNG",
|
| 208 |
+
max_len=1344,
|
| 209 |
+
min_len=672,
|
| 210 |
+
):
|
| 211 |
+
if image_process_mode == "Pad":
|
| 212 |
+
|
| 213 |
+
def expand2square(pil_img, background_color=(122, 116, 104)):
|
| 214 |
+
width, height = pil_img.size
|
| 215 |
+
if width == height:
|
| 216 |
+
return pil_img
|
| 217 |
+
elif width > height:
|
| 218 |
+
result = Image.new(pil_img.mode, (width, width), background_color)
|
| 219 |
+
result.paste(pil_img, (0, (width - height) // 2))
|
| 220 |
+
return result
|
| 221 |
+
else:
|
| 222 |
+
result = Image.new(pil_img.mode, (height, height), background_color)
|
| 223 |
+
result.paste(pil_img, ((height - width) // 2, 0))
|
| 224 |
+
return result
|
| 225 |
+
|
| 226 |
+
image = expand2square(image)
|
| 227 |
+
elif image_process_mode in ["Default", "Crop"]:
|
| 228 |
+
pass
|
| 229 |
+
elif image_process_mode == "Resize":
|
| 230 |
+
image = image.resize((336, 336))
|
| 231 |
+
else:
|
| 232 |
+
raise ValueError(f"Invalid image_process_mode: {image_process_mode}")
|
| 233 |
+
if max(image.size) > max_len:
|
| 234 |
+
max_hw, min_hw = max(image.size), min(image.size)
|
| 235 |
+
aspect_ratio = max_hw / min_hw
|
| 236 |
+
shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw))
|
| 237 |
+
longest_edge = int(shortest_edge * aspect_ratio)
|
| 238 |
+
W, H = image.size
|
| 239 |
+
if H > W:
|
| 240 |
+
H, W = longest_edge, shortest_edge
|
| 241 |
+
else:
|
| 242 |
+
H, W = shortest_edge, longest_edge
|
| 243 |
+
image = image.resize((W, H))
|
| 244 |
+
if return_pil:
|
| 245 |
+
return image
|
| 246 |
+
else:
|
| 247 |
+
buffered = BytesIO()
|
| 248 |
+
image.save(buffered, format=image_format)
|
| 249 |
+
img_b64_str = base64.b64encode(buffered.getvalue()).decode()
|
| 250 |
+
return img_b64_str
|
| 251 |
+
|
| 252 |
+
def get_images(self, return_pil=False):
|
| 253 |
+
images = []
|
| 254 |
+
for i, (role, msg) in enumerate(self.messages[self.offset :]):
|
| 255 |
+
if i % 2 == 0:
|
| 256 |
+
if type(msg) is tuple:
|
| 257 |
+
msg, image, image_process_mode = msg
|
| 258 |
+
image = self.process_image(
|
| 259 |
+
image, image_process_mode, return_pil=return_pil
|
| 260 |
+
)
|
| 261 |
+
images.append(image)
|
| 262 |
+
return images
|
| 263 |
+
|
| 264 |
+
def to_gradio_chatbot(self):
|
| 265 |
+
ret = []
|
| 266 |
+
for i, (role, msg) in enumerate(self.messages[self.offset :]):
|
| 267 |
+
if i % 2 == 0:
|
| 268 |
+
if type(msg) is tuple:
|
| 269 |
+
msg, image, image_process_mode = msg
|
| 270 |
+
img_b64_str = self.process_image(
|
| 271 |
+
image, "Default", return_pil=False, image_format="JPEG"
|
| 272 |
+
)
|
| 273 |
+
img_str = f'<img src="data:image/jpeg;base64,{img_b64_str}" alt="user upload image" />'
|
| 274 |
+
msg = img_str + msg.replace("<image>", "").strip()
|
| 275 |
+
ret.append([msg, None])
|
| 276 |
+
else:
|
| 277 |
+
ret.append([msg, None])
|
| 278 |
+
else:
|
| 279 |
+
ret[-1][-1] = msg
|
| 280 |
+
return ret
|
| 281 |
+
|
| 282 |
+
def copy(self):
|
| 283 |
+
return Conversation(
|
| 284 |
+
system=self.system,
|
| 285 |
+
roles=self.roles,
|
| 286 |
+
messages=[[x, y] for x, y in self.messages],
|
| 287 |
+
offset=self.offset,
|
| 288 |
+
sep_style=self.sep_style,
|
| 289 |
+
sep=self.sep,
|
| 290 |
+
sep2=self.sep2,
|
| 291 |
+
version=self.version,
|
| 292 |
+
)
|
| 293 |
+
|
| 294 |
+
def dict(self):
|
| 295 |
+
if len(self.get_images()) > 0:
|
| 296 |
+
return {
|
| 297 |
+
"system": self.system,
|
| 298 |
+
"roles": self.roles,
|
| 299 |
+
"messages": [
|
| 300 |
+
[x, y[0] if type(y) is tuple else y] for x, y in self.messages
|
| 301 |
+
],
|
| 302 |
+
"offset": self.offset,
|
| 303 |
+
"sep": self.sep,
|
| 304 |
+
"sep2": self.sep2,
|
| 305 |
+
}
|
| 306 |
+
return {
|
| 307 |
+
"system": self.system,
|
| 308 |
+
"roles": self.roles,
|
| 309 |
+
"messages": self.messages,
|
| 310 |
+
"offset": self.offset,
|
| 311 |
+
"sep": self.sep,
|
| 312 |
+
"sep2": self.sep2,
|
| 313 |
+
}
|
| 314 |
+
|
| 315 |
+
|
| 316 |
+
conv_vicuna_v0 = Conversation(
|
| 317 |
+
system="A chat between a curious human and an artificial intelligence assistant. "
|
| 318 |
+
"The assistant gives helpful, detailed, and polite answers to the human's questions.",
|
| 319 |
+
roles=("Human", "Assistant"),
|
| 320 |
+
messages=(
|
| 321 |
+
(
|
| 322 |
+
"Human",
|
| 323 |
+
"What are the key differences between renewable and non-renewable energy sources?",
|
| 324 |
+
),
|
| 325 |
+
(
|
| 326 |
+
"Assistant",
|
| 327 |
+
"Renewable energy sources are those that can be replenished naturally in a relatively "
|
| 328 |
+
"short amount of time, such as solar, wind, hydro, geothermal, and biomass. "
|
| 329 |
+
"Non-renewable energy sources, on the other hand, are finite and will eventually be "
|
| 330 |
+
"depleted, such as coal, oil, and natural gas. Here are some key differences between "
|
| 331 |
+
"renewable and non-renewable energy sources:\n"
|
| 332 |
+
"1. Availability: Renewable energy sources are virtually inexhaustible, while non-renewable "
|
| 333 |
+
"energy sources are finite and will eventually run out.\n"
|
| 334 |
+
"2. Environmental impact: Renewable energy sources have a much lower environmental impact "
|
| 335 |
+
"than non-renewable sources, which can lead to air and water pollution, greenhouse gas emissions, "
|
| 336 |
+
"and other negative effects.\n"
|
| 337 |
+
"3. Cost: Renewable energy sources can be more expensive to initially set up, but they typically "
|
| 338 |
+
"have lower operational costs than non-renewable sources.\n"
|
| 339 |
+
"4. Reliability: Renewable energy sources are often more reliable and can be used in more remote "
|
| 340 |
+
"locations than non-renewable sources.\n"
|
| 341 |
+
"5. Flexibility: Renewable energy sources are often more flexible and can be adapted to different "
|
| 342 |
+
"situations and needs, while non-renewable sources are more rigid and inflexible.\n"
|
| 343 |
+
"6. Sustainability: Renewable energy sources are more sustainable over the long term, while "
|
| 344 |
+
"non-renewable sources are not, and their depletion can lead to economic and social instability.\n",
|
| 345 |
+
),
|
| 346 |
+
),
|
| 347 |
+
offset=2,
|
| 348 |
+
sep_style=SeparatorStyle.SINGLE,
|
| 349 |
+
sep="###",
|
| 350 |
+
)
|
| 351 |
+
|
| 352 |
+
conv_vicuna_v1 = Conversation(
|
| 353 |
+
system="A chat between a curious user and an artificial intelligence assistant. "
|
| 354 |
+
"The assistant gives helpful, detailed, and polite answers to the user's questions.",
|
| 355 |
+
roles=("USER", "ASSISTANT"),
|
| 356 |
+
version="v1",
|
| 357 |
+
messages=(),
|
| 358 |
+
offset=0,
|
| 359 |
+
sep_style=SeparatorStyle.TWO,
|
| 360 |
+
sep=" ",
|
| 361 |
+
sep2="</s>",
|
| 362 |
+
)
|
| 363 |
+
|
| 364 |
+
conv_llama_2 = Conversation(
|
| 365 |
+
system="""You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.
|
| 366 |
+
|
| 367 |
+
If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.""",
|
| 368 |
+
roles=("USER", "ASSISTANT"),
|
| 369 |
+
version="llama_v2",
|
| 370 |
+
messages=(),
|
| 371 |
+
offset=0,
|
| 372 |
+
sep_style=SeparatorStyle.LLAMA_2,
|
| 373 |
+
sep="<s>",
|
| 374 |
+
sep2="</s>",
|
| 375 |
+
)
|
| 376 |
+
|
| 377 |
+
conv_llava_llama_2 = Conversation(
|
| 378 |
+
system="You are a helpful language and vision assistant. "
|
| 379 |
+
"You are able to understand the visual content that the user provides, "
|
| 380 |
+
"and assist the user with a variety of tasks using natural language.",
|
| 381 |
+
roles=("USER", "ASSISTANT"),
|
| 382 |
+
version="llama_v2",
|
| 383 |
+
messages=(),
|
| 384 |
+
offset=0,
|
| 385 |
+
sep_style=SeparatorStyle.LLAMA_2,
|
| 386 |
+
sep="<s>",
|
| 387 |
+
sep2="</s>",
|
| 388 |
+
)
|
| 389 |
+
|
| 390 |
+
conv_mpt = Conversation(
|
| 391 |
+
system="""<|im_start|>system
|
| 392 |
+
A conversation between a user and an LLM-based AI assistant. The assistant gives helpful and honest answers.""",
|
| 393 |
+
roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
|
| 394 |
+
version="mpt",
|
| 395 |
+
messages=(),
|
| 396 |
+
offset=0,
|
| 397 |
+
sep_style=SeparatorStyle.MPT,
|
| 398 |
+
sep="<|im_end|>",
|
| 399 |
+
)
|
| 400 |
+
|
| 401 |
+
conv_llava_plain = Conversation(
|
| 402 |
+
system="",
|
| 403 |
+
roles=("", ""),
|
| 404 |
+
messages=(),
|
| 405 |
+
offset=0,
|
| 406 |
+
sep_style=SeparatorStyle.PLAIN,
|
| 407 |
+
sep="\n",
|
| 408 |
+
version="plain",
|
| 409 |
+
)
|
| 410 |
+
|
| 411 |
+
conv_llava_v0 = Conversation(
|
| 412 |
+
system="A chat between a curious human and an artificial intelligence assistant. "
|
| 413 |
+
"The assistant gives helpful, detailed, and polite answers to the human's questions.",
|
| 414 |
+
roles=("Human", "Assistant"),
|
| 415 |
+
messages=(),
|
| 416 |
+
offset=0,
|
| 417 |
+
sep_style=SeparatorStyle.SINGLE,
|
| 418 |
+
sep="###",
|
| 419 |
+
)
|
| 420 |
+
|
| 421 |
+
conv_llava_v0_mmtag = Conversation(
|
| 422 |
+
system="A chat between a curious user and an artificial intelligence assistant. "
|
| 423 |
+
"The assistant is able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language."
|
| 424 |
+
"The visual content will be provided with the following format: <Image>visual content</Image>.",
|
| 425 |
+
roles=("Human", "Assistant"),
|
| 426 |
+
messages=(),
|
| 427 |
+
offset=0,
|
| 428 |
+
sep_style=SeparatorStyle.SINGLE,
|
| 429 |
+
sep="###",
|
| 430 |
+
version="v0_mmtag",
|
| 431 |
+
)
|
| 432 |
+
|
| 433 |
+
conv_llava_v1 = Conversation(
|
| 434 |
+
system="A chat between a curious human and an artificial intelligence assistant. "
|
| 435 |
+
"The assistant gives helpful, detailed, and polite answers to the human's questions.",
|
| 436 |
+
roles=("USER", "ASSISTANT"),
|
| 437 |
+
version="v1",
|
| 438 |
+
messages=(),
|
| 439 |
+
offset=0,
|
| 440 |
+
sep_style=SeparatorStyle.TWO,
|
| 441 |
+
sep=" ",
|
| 442 |
+
sep2="</s>",
|
| 443 |
+
)
|
| 444 |
+
|
| 445 |
+
conv_llava_v1_mmtag = Conversation(
|
| 446 |
+
system="A chat between a curious user and an artificial intelligence assistant. "
|
| 447 |
+
"The assistant is able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language."
|
| 448 |
+
"The visual content will be provided with the following format: <Image>visual content</Image>.",
|
| 449 |
+
roles=("USER", "ASSISTANT"),
|
| 450 |
+
messages=(),
|
| 451 |
+
offset=0,
|
| 452 |
+
sep_style=SeparatorStyle.TWO,
|
| 453 |
+
sep=" ",
|
| 454 |
+
sep2="</s>",
|
| 455 |
+
version="v1_mmtag",
|
| 456 |
+
)
|
| 457 |
+
|
| 458 |
+
conv_mistral_instruct = Conversation(
|
| 459 |
+
system="",
|
| 460 |
+
roles=("USER", "ASSISTANT"),
|
| 461 |
+
version="llama_v2",
|
| 462 |
+
messages=(),
|
| 463 |
+
offset=0,
|
| 464 |
+
sep_style=SeparatorStyle.LLAMA_2,
|
| 465 |
+
sep="",
|
| 466 |
+
sep2="</s>",
|
| 467 |
+
)
|
| 468 |
+
|
| 469 |
+
conv_chatml_direct = Conversation(
|
| 470 |
+
system="""<|im_start|>system
|
| 471 |
+
Answer the questions.""",
|
| 472 |
+
roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
|
| 473 |
+
version="mpt",
|
| 474 |
+
messages=(),
|
| 475 |
+
offset=0,
|
| 476 |
+
sep_style=SeparatorStyle.MPT,
|
| 477 |
+
sep="<|im_end|>",
|
| 478 |
+
)
|
| 479 |
+
|
| 480 |
+
conv_llama3 = Conversation(
|
| 481 |
+
system="""You are a helpful assistant.""",
|
| 482 |
+
roles=("user", "assistant"),
|
| 483 |
+
version="llama3",
|
| 484 |
+
messages=(),
|
| 485 |
+
offset=0,
|
| 486 |
+
sep_style=SeparatorStyle.LLAMA_3,
|
| 487 |
+
sep="<|eot_id|>",
|
| 488 |
+
)
|
| 489 |
+
conv_llama3_2 = Conversation(
|
| 490 |
+
system="""You are a helpful assistant.""",
|
| 491 |
+
roles=("user", "assistant"),
|
| 492 |
+
version="llama3_2",
|
| 493 |
+
messages=(),
|
| 494 |
+
offset=0,
|
| 495 |
+
sep_style=SeparatorStyle.LLAMA_3_2,
|
| 496 |
+
sep="<|eot_id|>",
|
| 497 |
+
)
|
| 498 |
+
|
| 499 |
+
conv_phi3_instruct = Conversation(
|
| 500 |
+
system="""<|system|>\nYou are a helpful AI assistant.""",
|
| 501 |
+
roles=("\n<|user|>\n", "\n<|assistant|>\n"),
|
| 502 |
+
version="phi3",
|
| 503 |
+
messages=(),
|
| 504 |
+
offset=0,
|
| 505 |
+
sep_style=SeparatorStyle.MPT,
|
| 506 |
+
sep="<|end|>",
|
| 507 |
+
)
|
| 508 |
+
|
| 509 |
+
conv_qwen = Conversation(
|
| 510 |
+
system="""<|im_start|>system
|
| 511 |
+
You are a helpful assistant.""",
|
| 512 |
+
roles=("<|im_start|>user", "<|im_start|>assistant"),
|
| 513 |
+
version="qwen",
|
| 514 |
+
messages=[],
|
| 515 |
+
offset=0,
|
| 516 |
+
sep_style=SeparatorStyle.CHATML,
|
| 517 |
+
sep="<|im_end|>",
|
| 518 |
+
)
|
| 519 |
+
|
| 520 |
+
default_conversation = conv_qwen
|
| 521 |
+
conv_templates = {
|
| 522 |
+
"default": conv_vicuna_v0,
|
| 523 |
+
"v0": conv_vicuna_v0,
|
| 524 |
+
"v1": conv_vicuna_v1,
|
| 525 |
+
"vicuna_v1": conv_vicuna_v1,
|
| 526 |
+
"llama_2": conv_llama_2,
|
| 527 |
+
"mistral_instruct": conv_mistral_instruct,
|
| 528 |
+
"chatml_direct": conv_chatml_direct,
|
| 529 |
+
"mistral_direct": conv_chatml_direct,
|
| 530 |
+
"plain": conv_llava_plain,
|
| 531 |
+
"v0_plain": conv_llava_plain,
|
| 532 |
+
"llava_v0": conv_llava_v0,
|
| 533 |
+
"v0_mmtag": conv_llava_v0_mmtag,
|
| 534 |
+
"llava_v1": conv_llava_v1,
|
| 535 |
+
"v1_mmtag": conv_llava_v1_mmtag,
|
| 536 |
+
"llava_llama_2": conv_llava_llama_2,
|
| 537 |
+
"mpt": conv_mpt,
|
| 538 |
+
"llama3": conv_llama3,
|
| 539 |
+
"llama3_2": conv_llama3_2,
|
| 540 |
+
"phi3": conv_phi3_instruct,
|
| 541 |
+
"qwen": conv_qwen,
|
| 542 |
+
}
|
| 543 |
+
|
| 544 |
+
if __name__ == "__main__":
|
| 545 |
+
print(default_conversation.get_prompt())
|
tempo/language_model/__pycache__/modeling_tempo_qwen.cpython-312.pyc
ADDED
|
Binary file (7.46 kB). View file
|
|
|
tempo/language_model/modeling_tempo_qwen.py
ADDED
|
@@ -0,0 +1,231 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2023 Haotian Liu
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
import torch
|
| 17 |
+
import torch.nn as nn
|
| 18 |
+
from typing import Optional, Union, Callable
|
| 19 |
+
from transformers.utils import logging
|
| 20 |
+
from transformers.cache_utils import Cache
|
| 21 |
+
from transformers import AutoConfig, AutoModelForCausalLM
|
| 22 |
+
from transformers.modeling_outputs import CausalLMOutputWithPast
|
| 23 |
+
from transformers import Qwen3Config, Qwen3ForCausalLM, Qwen3Model
|
| 24 |
+
from transformers.modeling_utils import PreTrainedModel
|
| 25 |
+
from transformers.generation.streamers import BaseStreamer
|
| 26 |
+
from transformers.generation.utils import (
|
| 27 |
+
GenerateOutput,
|
| 28 |
+
GenerationConfig,
|
| 29 |
+
LogitsProcessorList,
|
| 30 |
+
StoppingCriteriaList,
|
| 31 |
+
)
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
from tempo.tempo_arch import TempoMetaForCausalLM, TempoMetaModel
|
| 35 |
+
|
| 36 |
+
logger = logging.get_logger(__name__)
|
| 37 |
+
|
| 38 |
+
class TempoConfig(Qwen3Config):
|
| 39 |
+
model_type = "tempo_qwen"
|
| 40 |
+
debug = "debug"
|
| 41 |
+
|
| 42 |
+
class TempoQwenModel(TempoMetaModel, Qwen3Model):
|
| 43 |
+
config_class = TempoConfig
|
| 44 |
+
|
| 45 |
+
def __init__(self, config: Qwen3Config):
|
| 46 |
+
super(TempoQwenModel, self).__init__(config)
|
| 47 |
+
|
| 48 |
+
class TempoQwenForCausalLM(Qwen3ForCausalLM, TempoMetaForCausalLM):
|
| 49 |
+
config_class = TempoConfig
|
| 50 |
+
|
| 51 |
+
def __init__(self, config):
|
| 52 |
+
super(Qwen3ForCausalLM, self).__init__(config)
|
| 53 |
+
config.model_type = "tempo_qwen"
|
| 54 |
+
config.rope_scaling = None
|
| 55 |
+
|
| 56 |
+
self.model = TempoQwenModel(config)
|
| 57 |
+
self.vocab_size = config.vocab_size
|
| 58 |
+
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
| 59 |
+
|
| 60 |
+
# Initialize weights and apply final processing
|
| 61 |
+
self.post_init()
|
| 62 |
+
|
| 63 |
+
def get_model(self):
|
| 64 |
+
return self.model
|
| 65 |
+
|
| 66 |
+
def forward(
|
| 67 |
+
self,
|
| 68 |
+
input_ids: Optional[torch.LongTensor] = None,
|
| 69 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 70 |
+
position_ids: Optional[torch.LongTensor] = None,
|
| 71 |
+
past_key_values: Optional[Cache] = None,
|
| 72 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
| 73 |
+
labels: Optional[torch.LongTensor] = None,
|
| 74 |
+
use_cache: Optional[bool] = None,
|
| 75 |
+
cache_position: Optional[torch.LongTensor] = None,
|
| 76 |
+
logits_to_keep: Union[int, torch.Tensor] = 0,
|
| 77 |
+
images: Optional[torch.FloatTensor] = None,
|
| 78 |
+
image_sizes: Optional[list[list[int]]] = None,
|
| 79 |
+
vlm_inputs: Optional[dict] = None,
|
| 80 |
+
seg_timestamps: Optional[torch.LongTensor] = None,
|
| 81 |
+
batch_split_size: Optional[list[int]] = None,
|
| 82 |
+
**kwargs,
|
| 83 |
+
) -> CausalLMOutputWithPast:
|
| 84 |
+
if inputs_embeds is None:
|
| 85 |
+
(
|
| 86 |
+
input_ids,
|
| 87 |
+
position_ids,
|
| 88 |
+
attention_mask,
|
| 89 |
+
past_key_values,
|
| 90 |
+
inputs_embeds,
|
| 91 |
+
labels,
|
| 92 |
+
) = self.prepare_inputs_labels_for_multimodal(
|
| 93 |
+
input_ids,
|
| 94 |
+
position_ids,
|
| 95 |
+
attention_mask,
|
| 96 |
+
past_key_values,
|
| 97 |
+
labels,
|
| 98 |
+
images,
|
| 99 |
+
image_sizes,
|
| 100 |
+
vlm_inputs,
|
| 101 |
+
seg_timestamps,
|
| 102 |
+
batch_split_size,
|
| 103 |
+
)
|
| 104 |
+
|
| 105 |
+
return super().forward(
|
| 106 |
+
input_ids=input_ids,
|
| 107 |
+
attention_mask=attention_mask,
|
| 108 |
+
position_ids=position_ids,
|
| 109 |
+
past_key_values=past_key_values,
|
| 110 |
+
inputs_embeds=inputs_embeds,
|
| 111 |
+
labels=labels,
|
| 112 |
+
use_cache=use_cache,
|
| 113 |
+
cache_position=cache_position,
|
| 114 |
+
logits_to_keep=logits_to_keep,
|
| 115 |
+
**kwargs,
|
| 116 |
+
)
|
| 117 |
+
|
| 118 |
+
@torch.no_grad()
|
| 119 |
+
def generate(
|
| 120 |
+
self,
|
| 121 |
+
inputs: Optional[torch.Tensor] = None,
|
| 122 |
+
generation_config: Optional[GenerationConfig] = None,
|
| 123 |
+
logits_processor: Optional[LogitsProcessorList] = None,
|
| 124 |
+
stopping_criteria: Optional[StoppingCriteriaList] = None,
|
| 125 |
+
prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], list[int]]] = None,
|
| 126 |
+
synced_gpus: Optional[bool] = None,
|
| 127 |
+
assistant_model: Optional["PreTrainedModel"] = None,
|
| 128 |
+
streamer: Optional["BaseStreamer"] = None,
|
| 129 |
+
negative_prompt_ids: Optional[torch.Tensor] = None,
|
| 130 |
+
negative_prompt_attention_mask: Optional[torch.Tensor] = None,
|
| 131 |
+
use_model_defaults: Optional[bool] = None,
|
| 132 |
+
custom_generate: Optional[Union[str, Callable]] = None,
|
| 133 |
+
images: Optional[torch.Tensor] = None,
|
| 134 |
+
image_sizes: Optional[torch.Tensor] = None,
|
| 135 |
+
**kwargs,
|
| 136 |
+
) -> Union[GenerateOutput, torch.LongTensor]:
|
| 137 |
+
position_ids = kwargs.pop("position_ids", None)
|
| 138 |
+
attention_mask = kwargs.pop("attention_mask", None)
|
| 139 |
+
vlm_inputs = kwargs.pop("vlm_inputs", None)
|
| 140 |
+
seg_timestamps = kwargs.pop("seg_timestamps", None)
|
| 141 |
+
relevance = kwargs.pop("relevance", None) # when using external retriever
|
| 142 |
+
|
| 143 |
+
if "inputs_embeds" in kwargs:
|
| 144 |
+
raise NotImplementedError("`inputs_embeds` is not supported")
|
| 145 |
+
|
| 146 |
+
if vlm_inputs is not None:
|
| 147 |
+
(
|
| 148 |
+
inputs,
|
| 149 |
+
position_ids,
|
| 150 |
+
attention_mask,
|
| 151 |
+
_,
|
| 152 |
+
inputs_embeds,
|
| 153 |
+
_,
|
| 154 |
+
) = self.prepare_inputs_labels_for_multimodal(
|
| 155 |
+
inputs,
|
| 156 |
+
position_ids,
|
| 157 |
+
attention_mask,
|
| 158 |
+
None,
|
| 159 |
+
None,
|
| 160 |
+
images,
|
| 161 |
+
image_sizes=image_sizes,
|
| 162 |
+
vlm_inputs=vlm_inputs,
|
| 163 |
+
seg_timestamps=seg_timestamps,
|
| 164 |
+
relevance=relevance,
|
| 165 |
+
)
|
| 166 |
+
elif images is not None:
|
| 167 |
+
(
|
| 168 |
+
inputs,
|
| 169 |
+
position_ids,
|
| 170 |
+
attention_mask,
|
| 171 |
+
_,
|
| 172 |
+
inputs_embeds,
|
| 173 |
+
_,
|
| 174 |
+
) = self.prepare_inputs_labels_for_multimodal(
|
| 175 |
+
inputs,
|
| 176 |
+
position_ids,
|
| 177 |
+
attention_mask,
|
| 178 |
+
None,
|
| 179 |
+
None,
|
| 180 |
+
images,
|
| 181 |
+
image_sizes=image_sizes,
|
| 182 |
+
)
|
| 183 |
+
else:
|
| 184 |
+
inputs_embeds = self.get_model().embed_tokens(inputs)
|
| 185 |
+
|
| 186 |
+
# if attention_mask is None:
|
| 187 |
+
# # avoid warning
|
| 188 |
+
# attention_mask = torch.ones(
|
| 189 |
+
# inputs_embeds.shape[:2],
|
| 190 |
+
# dtype=torch.long,
|
| 191 |
+
# device=inputs_embeds.device
|
| 192 |
+
# )
|
| 193 |
+
|
| 194 |
+
return super().generate(
|
| 195 |
+
inputs=None,
|
| 196 |
+
generation_config=generation_config,
|
| 197 |
+
logits_processor=logits_processor,
|
| 198 |
+
stopping_criteria=stopping_criteria,
|
| 199 |
+
prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
|
| 200 |
+
synced_gpus=synced_gpus,
|
| 201 |
+
assistant_model=assistant_model,
|
| 202 |
+
streamer=streamer,
|
| 203 |
+
negative_prompt_ids=negative_prompt_ids,
|
| 204 |
+
negative_prompt_attention_mask=negative_prompt_attention_mask,
|
| 205 |
+
use_model_defaults=use_model_defaults,
|
| 206 |
+
custom_generate=custom_generate,
|
| 207 |
+
position_ids=position_ids,
|
| 208 |
+
attention_mask=attention_mask,
|
| 209 |
+
inputs_embeds=inputs_embeds,
|
| 210 |
+
**kwargs,
|
| 211 |
+
)
|
| 212 |
+
|
| 213 |
+
def prepare_inputs_for_generation(
|
| 214 |
+
self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs
|
| 215 |
+
):
|
| 216 |
+
images = kwargs.pop("images", None)
|
| 217 |
+
image_sizes = kwargs.pop("image_sizes", None)
|
| 218 |
+
inputs = super().prepare_inputs_for_generation(
|
| 219 |
+
input_ids,
|
| 220 |
+
past_key_values=past_key_values,
|
| 221 |
+
inputs_embeds=inputs_embeds,
|
| 222 |
+
**kwargs,
|
| 223 |
+
)
|
| 224 |
+
if images is not None:
|
| 225 |
+
inputs["images"] = images
|
| 226 |
+
if image_sizes is not None:
|
| 227 |
+
inputs["image_sizes"] = image_sizes
|
| 228 |
+
return inputs
|
| 229 |
+
|
| 230 |
+
AutoConfig.register("tempo_qwen", TempoConfig)
|
| 231 |
+
AutoModelForCausalLM.register(TempoConfig, TempoQwenForCausalLM)
|
tempo/mm_datautils.py
ADDED
|
@@ -0,0 +1,1607 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import copy
|
| 3 |
+
from typing import List
|
| 4 |
+
from packaging import version
|
| 5 |
+
from collections.abc import Sequence
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
import tokenizers
|
| 9 |
+
import transformers
|
| 10 |
+
import numpy as np
|
| 11 |
+
from PIL import Image
|
| 12 |
+
from transformers import StoppingCriteria
|
| 13 |
+
from qwen_vl_utils import process_vision_info
|
| 14 |
+
from torch import distributed as dist
|
| 15 |
+
from torch.distributed.fsdp import (
|
| 16 |
+
FullStateDictConfig,
|
| 17 |
+
FullyShardedDataParallel as FSDP,
|
| 18 |
+
StateDictType,
|
| 19 |
+
)
|
| 20 |
+
|
| 21 |
+
from tempo import conversation as conversation_lib
|
| 22 |
+
from tempo.constants import (
|
| 23 |
+
DEFAULT_IM_END_TOKEN,
|
| 24 |
+
DEFAULT_IM_START_TOKEN,
|
| 25 |
+
DEFAULT_IMAGE_TOKEN,
|
| 26 |
+
IGNORE_INDEX,
|
| 27 |
+
IMAGE_TOKEN_INDEX,
|
| 28 |
+
)
|
| 29 |
+
|
| 30 |
+
IS_TOKENIZER_GREATER_THAN_0_14 = version.parse(tokenizers.__version__) >= version.parse(
|
| 31 |
+
"0.14"
|
| 32 |
+
)
|
| 33 |
+
|
| 34 |
+
class KeywordsStoppingCriteria(StoppingCriteria):
|
| 35 |
+
def __init__(self, keywords, tokenizer, input_ids):
|
| 36 |
+
self.keywords = keywords
|
| 37 |
+
self.keyword_ids = []
|
| 38 |
+
self.max_keyword_len = 0
|
| 39 |
+
for keyword in keywords:
|
| 40 |
+
cur_keyword_ids = tokenizer(keyword).input_ids
|
| 41 |
+
if (
|
| 42 |
+
len(cur_keyword_ids) > 1
|
| 43 |
+
and cur_keyword_ids[0] == tokenizer.bos_token_id
|
| 44 |
+
):
|
| 45 |
+
cur_keyword_ids = cur_keyword_ids[1:]
|
| 46 |
+
if len(cur_keyword_ids) > self.max_keyword_len:
|
| 47 |
+
self.max_keyword_len = len(cur_keyword_ids)
|
| 48 |
+
self.keyword_ids.append(torch.tensor(cur_keyword_ids))
|
| 49 |
+
self.tokenizer = tokenizer
|
| 50 |
+
self.start_len = input_ids.shape[1]
|
| 51 |
+
|
| 52 |
+
def call_for_batch(
|
| 53 |
+
self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs
|
| 54 |
+
) -> bool:
|
| 55 |
+
offset = min(output_ids.shape[1] - self.start_len, self.max_keyword_len)
|
| 56 |
+
self.keyword_ids = [
|
| 57 |
+
keyword_id.to(output_ids.device) for keyword_id in self.keyword_ids
|
| 58 |
+
]
|
| 59 |
+
for keyword_id in self.keyword_ids:
|
| 60 |
+
truncated_output_ids = output_ids[0, -keyword_id.shape[0] :]
|
| 61 |
+
if torch.equal(truncated_output_ids, keyword_id):
|
| 62 |
+
return True
|
| 63 |
+
outputs = self.tokenizer.batch_decode(
|
| 64 |
+
output_ids[:, -offset:], skip_special_tokens=True
|
| 65 |
+
)[0]
|
| 66 |
+
for keyword in self.keywords:
|
| 67 |
+
if keyword in outputs:
|
| 68 |
+
return True
|
| 69 |
+
return False
|
| 70 |
+
|
| 71 |
+
def __call__(
|
| 72 |
+
self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs
|
| 73 |
+
) -> bool:
|
| 74 |
+
outputs = []
|
| 75 |
+
for i in range(output_ids.shape[0]):
|
| 76 |
+
outputs.append(self.call_for_batch(output_ids[i].unsqueeze(0), scores))
|
| 77 |
+
return all(outputs)
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
def safe_save_model_for_hf_trainer(
|
| 81 |
+
trainer: transformers.Trainer, output_dir: str
|
| 82 |
+
) -> None:
|
| 83 |
+
"""Collects the state dict and dump to disk."""
|
| 84 |
+
global_rank = dist.get_rank()
|
| 85 |
+
save_policy = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
|
| 86 |
+
if len(trainer.args.fsdp) == 0:
|
| 87 |
+
cpu_state_dict = trainer.model.state_dict()
|
| 88 |
+
else:
|
| 89 |
+
with FSDP.state_dict_type(
|
| 90 |
+
trainer.model, StateDictType.FULL_STATE_DICT, save_policy
|
| 91 |
+
):
|
| 92 |
+
cpu_state_dict = trainer.model.state_dict()
|
| 93 |
+
|
| 94 |
+
for key in cpu_state_dict.keys():
|
| 95 |
+
cpu_state_dict[key] = cpu_state_dict[key].to(torch.bfloat16)
|
| 96 |
+
|
| 97 |
+
if global_rank == 0:
|
| 98 |
+
trainer.model.config.save_pretrained(output_dir)
|
| 99 |
+
current_folder = output_dir.split("/")[-1]
|
| 100 |
+
parent_folder = os.path.dirname(output_dir)
|
| 101 |
+
save_path = os.path.join(output_dir, "pytorch_model.bin")
|
| 102 |
+
if getattr(trainer.args, "tune_mm_mlp_adapter", False) and not getattr(
|
| 103 |
+
trainer.args, "tune_text_decoder", False
|
| 104 |
+
):
|
| 105 |
+
# Only save Adapter
|
| 106 |
+
keys_to_match = ["mm_projector"]
|
| 107 |
+
if getattr(trainer.args, "use_im_start_end", False):
|
| 108 |
+
keys_to_match.extend(["embed_tokens", "embed_in"])
|
| 109 |
+
|
| 110 |
+
freeze_layer_remove = []
|
| 111 |
+
for key in cpu_state_dict.keys():
|
| 112 |
+
remove = True
|
| 113 |
+
for key_match in keys_to_match:
|
| 114 |
+
if key_match in key:
|
| 115 |
+
remove = False
|
| 116 |
+
break
|
| 117 |
+
if remove:
|
| 118 |
+
freeze_layer_remove.append(key)
|
| 119 |
+
for key in freeze_layer_remove:
|
| 120 |
+
del cpu_state_dict[key]
|
| 121 |
+
|
| 122 |
+
if current_folder.startswith("checkpoint-"):
|
| 123 |
+
mm_projector_folder = os.path.join(parent_folder, "mm_projector")
|
| 124 |
+
os.makedirs(mm_projector_folder, exist_ok=True)
|
| 125 |
+
save_path = os.path.join(mm_projector_folder, f"{current_folder}.bin")
|
| 126 |
+
else:
|
| 127 |
+
save_path = os.path.join(output_dir, f"mm_projector.bin")
|
| 128 |
+
torch.save(cpu_state_dict, save_path)
|
| 129 |
+
|
| 130 |
+
|
| 131 |
+
def _tokenize_fn(
|
| 132 |
+
strings: Sequence[str],
|
| 133 |
+
tokenizer: transformers.PreTrainedTokenizer,
|
| 134 |
+
) -> dict:
|
| 135 |
+
"""Tokenize a list of strings."""
|
| 136 |
+
tokenized_list = [
|
| 137 |
+
tokenizer(
|
| 138 |
+
text,
|
| 139 |
+
return_tensors="pt",
|
| 140 |
+
padding="longest",
|
| 141 |
+
max_length=tokenizer.model_max_length,
|
| 142 |
+
truncation=True,
|
| 143 |
+
)
|
| 144 |
+
for text in strings
|
| 145 |
+
]
|
| 146 |
+
input_ids = labels = [tokenized.input_ids[0] for tokenized in tokenized_list]
|
| 147 |
+
input_ids_lens = labels_lens = [
|
| 148 |
+
tokenized.input_ids.ne(tokenizer.pad_token_id).sum().item()
|
| 149 |
+
for tokenized in tokenized_list
|
| 150 |
+
]
|
| 151 |
+
return dict(
|
| 152 |
+
input_ids=input_ids,
|
| 153 |
+
labels=labels,
|
| 154 |
+
input_ids_lens=input_ids_lens,
|
| 155 |
+
labels_lens=labels_lens,
|
| 156 |
+
)
|
| 157 |
+
|
| 158 |
+
|
| 159 |
+
def _mask_targets(target, tokenized_lens, speakers) -> None:
|
| 160 |
+
# cur_idx = 0
|
| 161 |
+
cur_idx = tokenized_lens[0]
|
| 162 |
+
tokenized_lens = tokenized_lens[1:]
|
| 163 |
+
target[:cur_idx] = IGNORE_INDEX
|
| 164 |
+
for tokenized_len, speaker in zip(tokenized_lens, speakers):
|
| 165 |
+
if speaker == "human":
|
| 166 |
+
target[cur_idx + 2 : cur_idx + tokenized_len] = IGNORE_INDEX
|
| 167 |
+
cur_idx += tokenized_len
|
| 168 |
+
|
| 169 |
+
|
| 170 |
+
def _add_speaker_and_signal(header, source, get_conversation: bool = True):
|
| 171 |
+
"""Add speaker and start/end signal on each round."""
|
| 172 |
+
BEGIN_SIGNAL = "### "
|
| 173 |
+
END_SIGNAL = "\n"
|
| 174 |
+
conversation = header
|
| 175 |
+
for sentence in source:
|
| 176 |
+
from_str = sentence["from"]
|
| 177 |
+
if from_str.lower() == "human":
|
| 178 |
+
from_str = conversation_lib.default_conversation.roles[0]
|
| 179 |
+
elif from_str.lower() == "gpt":
|
| 180 |
+
from_str = conversation_lib.default_conversation.roles[1]
|
| 181 |
+
else:
|
| 182 |
+
from_str = "unknown"
|
| 183 |
+
sentence["value"] = (
|
| 184 |
+
BEGIN_SIGNAL + from_str + ": " + sentence["value"] + END_SIGNAL
|
| 185 |
+
)
|
| 186 |
+
if get_conversation:
|
| 187 |
+
conversation += sentence["value"]
|
| 188 |
+
conversation += BEGIN_SIGNAL
|
| 189 |
+
return conversation
|
| 190 |
+
|
| 191 |
+
|
| 192 |
+
def expand2square(pil_img, background_color):
|
| 193 |
+
width, height = pil_img.size
|
| 194 |
+
if width == height:
|
| 195 |
+
return pil_img
|
| 196 |
+
elif width > height:
|
| 197 |
+
result = Image.new(pil_img.mode, (width, width), background_color)
|
| 198 |
+
result.paste(pil_img, (0, (width - height) // 2))
|
| 199 |
+
return result
|
| 200 |
+
else:
|
| 201 |
+
result = Image.new(pil_img.mode, (height, height), background_color)
|
| 202 |
+
result.paste(pil_img, ((height - width) // 2, 0))
|
| 203 |
+
return result
|
| 204 |
+
|
| 205 |
+
|
| 206 |
+
def crop2square(pil_img):
|
| 207 |
+
width, height = pil_img.size
|
| 208 |
+
if width == height:
|
| 209 |
+
return pil_img
|
| 210 |
+
elif width > height:
|
| 211 |
+
left = (width - height) // 2
|
| 212 |
+
right = left + height
|
| 213 |
+
top = 0
|
| 214 |
+
bottom = height
|
| 215 |
+
return pil_img.crop((left, top, right, bottom))
|
| 216 |
+
else:
|
| 217 |
+
top = (height - width) // 2
|
| 218 |
+
bottom = top + width
|
| 219 |
+
left = 0
|
| 220 |
+
right = width
|
| 221 |
+
return pil_img.crop((left, top, right, bottom))
|
| 222 |
+
|
| 223 |
+
|
| 224 |
+
def perpare_input_for_qwen_input(chunk_dict, pad_token_ids):
|
| 225 |
+
"""Currently, only batch size = 1 is supported for evaluation."""
|
| 226 |
+
|
| 227 |
+
qwenvl_input_dict = {}
|
| 228 |
+
has_video = any(["video" in key for key in list(chunk_dict.keys())])
|
| 229 |
+
|
| 230 |
+
qwenvl_input_dict["input_ids"] = torch.nn.utils.rnn.pad_sequence(
|
| 231 |
+
chunk_dict["vlm_input_ids"],
|
| 232 |
+
batch_first=True,
|
| 233 |
+
padding_value=pad_token_ids,
|
| 234 |
+
)
|
| 235 |
+
qwenvl_input_dict["attention_mask"] = torch.nn.utils.rnn.pad_sequence(
|
| 236 |
+
chunk_dict["vlm_attention_mask"],
|
| 237 |
+
batch_first=True,
|
| 238 |
+
padding_value=0,
|
| 239 |
+
)
|
| 240 |
+
|
| 241 |
+
if has_video:
|
| 242 |
+
qwenvl_input_dict["pixel_values_videos"] = torch.cat(chunk_dict["pixel_values_videos"], dim=0)
|
| 243 |
+
qwenvl_input_dict["video_grid_thw"] = torch.cat(chunk_dict["video_grid_thw"], dim=0)
|
| 244 |
+
else:
|
| 245 |
+
qwenvl_input_dict["pixel_values"] = torch.cat(chunk_dict["pixel_values"], dim=0)
|
| 246 |
+
qwenvl_input_dict["image_grid_thw"] = torch.cat(chunk_dict["image_grid_thw"], dim=0)
|
| 247 |
+
|
| 248 |
+
return qwenvl_input_dict
|
| 249 |
+
|
| 250 |
+
|
| 251 |
+
def construct_message(
|
| 252 |
+
content_data,
|
| 253 |
+
data_type,
|
| 254 |
+
query,
|
| 255 |
+
multimodal_processor,
|
| 256 |
+
sample_fps=1,
|
| 257 |
+
return_text=False,
|
| 258 |
+
):
|
| 259 |
+
# # prompt 0 (using during training)
|
| 260 |
+
# system_message = (
|
| 261 |
+
# "You are a query-conditioned visual compressor. "
|
| 262 |
+
# "Store in the provided memory tokens the minimal visual information needed to answer the Query. "
|
| 263 |
+
# "Ignore irrelevant details."
|
| 264 |
+
# )
|
| 265 |
+
# prompt 1 (using during inference)
|
| 266 |
+
system_message = (
|
| 267 |
+
"You are a query-conditioned visual compressor. "
|
| 268 |
+
"Store in the provided memory tokens the minimal visual information needed to answer the Query. "
|
| 269 |
+
"Ignore irrelevant details. "
|
| 270 |
+
"Now, before compressing, answer exactly 'Yes' or 'No': is this segment relevant to the Query?"
|
| 271 |
+
)
|
| 272 |
+
|
| 273 |
+
user_message = f"\nQuery:\n{query}"
|
| 274 |
+
# assistant_message = "Scanning for target features... The visual confidence representation is:"
|
| 275 |
+
assistant_message = None
|
| 276 |
+
|
| 277 |
+
if data_type == "image":
|
| 278 |
+
messages = [
|
| 279 |
+
{
|
| 280 |
+
"role": "system",
|
| 281 |
+
"content": [{"type": "text", "text": system_message}],
|
| 282 |
+
},
|
| 283 |
+
{
|
| 284 |
+
"role": "user",
|
| 285 |
+
"content": [
|
| 286 |
+
{"type": "image", "image": content_data},
|
| 287 |
+
{"type": "text", "text": user_message},
|
| 288 |
+
],
|
| 289 |
+
},
|
| 290 |
+
]
|
| 291 |
+
elif data_type == "video":
|
| 292 |
+
messages = [
|
| 293 |
+
{
|
| 294 |
+
"role": "system",
|
| 295 |
+
"content": [{"type": "text", "text": system_message}],
|
| 296 |
+
},
|
| 297 |
+
{
|
| 298 |
+
"role": "user",
|
| 299 |
+
"content": [
|
| 300 |
+
{
|
| 301 |
+
"type": "video",
|
| 302 |
+
"video": content_data,
|
| 303 |
+
"sample_fps": sample_fps,
|
| 304 |
+
},
|
| 305 |
+
{"type": "text", "text": user_message},
|
| 306 |
+
],
|
| 307 |
+
},
|
| 308 |
+
]
|
| 309 |
+
else:
|
| 310 |
+
raise ValueError(f"Unknown data type: {data_type}")
|
| 311 |
+
|
| 312 |
+
if return_text:
|
| 313 |
+
messages = multimodal_processor.apply_chat_template(
|
| 314 |
+
messages,
|
| 315 |
+
tokenize=False,
|
| 316 |
+
add_generation_prompt=True,
|
| 317 |
+
)
|
| 318 |
+
|
| 319 |
+
if assistant_message is not None:
|
| 320 |
+
messages = messages + assistant_message
|
| 321 |
+
|
| 322 |
+
return messages
|
| 323 |
+
|
| 324 |
+
|
| 325 |
+
def video_process_with_frame_idx(
|
| 326 |
+
chunk_frames, multimodal_processor, query, sample_fps=2, frame_offset=0
|
| 327 |
+
):
|
| 328 |
+
messages = construct_message(
|
| 329 |
+
chunk_frames, "video", query, multimodal_processor, sample_fps
|
| 330 |
+
)
|
| 331 |
+
text = multimodal_processor.apply_chat_template(
|
| 332 |
+
messages, tokenize=False, add_generation_prompt=True
|
| 333 |
+
)
|
| 334 |
+
image_inputs, video_inputs, video_kwargs = process_vision_info(
|
| 335 |
+
[messages],
|
| 336 |
+
return_video_kwargs=True,
|
| 337 |
+
image_patch_size=16,
|
| 338 |
+
return_video_metadata=True,
|
| 339 |
+
)
|
| 340 |
+
if video_inputs is not None:
|
| 341 |
+
video_inputs, video_metadatas = zip(*video_inputs)
|
| 342 |
+
video_inputs, video_metadatas = (
|
| 343 |
+
list(video_inputs),
|
| 344 |
+
list(video_metadatas),
|
| 345 |
+
)
|
| 346 |
+
else:
|
| 347 |
+
video_metadatas = None
|
| 348 |
+
|
| 349 |
+
if video_metadatas is not None:
|
| 350 |
+
video_metadatas[0]["frames_indices"] = [
|
| 351 |
+
f + frame_offset for f in video_metadatas[0]["frames_indices"]
|
| 352 |
+
]
|
| 353 |
+
|
| 354 |
+
inputs = multimodal_processor(
|
| 355 |
+
text=[text],
|
| 356 |
+
images=image_inputs,
|
| 357 |
+
videos=video_inputs,
|
| 358 |
+
video_metadata=video_metadatas,
|
| 359 |
+
**video_kwargs,
|
| 360 |
+
do_resize=False,
|
| 361 |
+
return_tensors="pt",
|
| 362 |
+
)
|
| 363 |
+
|
| 364 |
+
return inputs
|
| 365 |
+
|
| 366 |
+
|
| 367 |
+
def process_qwen_content(
|
| 368 |
+
content_data,
|
| 369 |
+
data_type,
|
| 370 |
+
sources,
|
| 371 |
+
multimodal_processor,
|
| 372 |
+
real_fps=None,
|
| 373 |
+
frame_windows=8,
|
| 374 |
+
frame_stride=8,
|
| 375 |
+
is_eval=False,
|
| 376 |
+
):
|
| 377 |
+
"""
|
| 378 |
+
content_data:
|
| 379 |
+
- 'image': PIL.Image
|
| 380 |
+
- 'video': List[PIL.Image] or np.ndarray (T, H, W, C)
|
| 381 |
+
data_type: 'text', 'image' or 'video'
|
| 382 |
+
"""
|
| 383 |
+
|
| 384 |
+
# query process
|
| 385 |
+
if is_eval:
|
| 386 |
+
# for evaluation, please input only text string
|
| 387 |
+
assert isinstance(sources, str), "During evaluation, sources should be a single query string."
|
| 388 |
+
query = sources
|
| 389 |
+
else:
|
| 390 |
+
conversation = sources[0]
|
| 391 |
+
if isinstance(conversation, list):
|
| 392 |
+
# This is acceptable during training to learn better representations,
|
| 393 |
+
# but cannot be used during inference as it may lead to data leakage.
|
| 394 |
+
# Currently, only single-turn dialogue is supported during inference.
|
| 395 |
+
# Set the maximum number of dialogue turns to 8.
|
| 396 |
+
human_queries = []
|
| 397 |
+
for turn in conversation[:16]:
|
| 398 |
+
if turn.get("from") == "human":
|
| 399 |
+
clean_text = turn["value"].replace("<image>", "").strip()
|
| 400 |
+
if clean_text:
|
| 401 |
+
human_queries.append(clean_text)
|
| 402 |
+
|
| 403 |
+
query = "\n".join(
|
| 404 |
+
[f"Context turn {i + 1}: {q}" for i, q in enumerate(human_queries)]
|
| 405 |
+
)
|
| 406 |
+
else:
|
| 407 |
+
query = "Describe this content."
|
| 408 |
+
if not query.strip():
|
| 409 |
+
query = "Describe this content."
|
| 410 |
+
|
| 411 |
+
chunk_results = []
|
| 412 |
+
|
| 413 |
+
def resize_images(frames, resolution=512):
|
| 414 |
+
resized_frames = []
|
| 415 |
+
for f in frames:
|
| 416 |
+
w, h = f.size
|
| 417 |
+
max_edge = max(w, h)
|
| 418 |
+
if max_edge > resolution:
|
| 419 |
+
ratio = resolution / max_edge
|
| 420 |
+
new_w = int(round((w * ratio) / 16) * 16)
|
| 421 |
+
new_h = int(round((h * ratio) / 16) * 16)
|
| 422 |
+
new_w = max(16, new_w)
|
| 423 |
+
new_h = max(16, new_h)
|
| 424 |
+
f = f.resize((new_w, new_h), resample=Image.Resampling.BICUBIC)
|
| 425 |
+
resized_frames.append(f)
|
| 426 |
+
return resized_frames
|
| 427 |
+
|
| 428 |
+
# === Text ===
|
| 429 |
+
if data_type == "text":
|
| 430 |
+
content_data = Image.new("RGB", (336, 336), color=(255, 255, 255)) # dummy image
|
| 431 |
+
messages = construct_message(
|
| 432 |
+
content_data, "image", query, multimodal_processor, return_text=True
|
| 433 |
+
) # str
|
| 434 |
+
inputs = multimodal_processor(
|
| 435 |
+
text=[messages],
|
| 436 |
+
images=[content_data],
|
| 437 |
+
padding=False,
|
| 438 |
+
return_tensors="pt",
|
| 439 |
+
)
|
| 440 |
+
chunk_results.append(inputs)
|
| 441 |
+
|
| 442 |
+
# === Image ===
|
| 443 |
+
elif data_type == "image":
|
| 444 |
+
if isinstance(content_data, list):
|
| 445 |
+
# multi-image
|
| 446 |
+
content_data = resize_images(content_data, resolution=512)
|
| 447 |
+
messages = construct_message(
|
| 448 |
+
content_data[0], data_type, query, multimodal_processor, return_text=True,
|
| 449 |
+
) # str
|
| 450 |
+
inputs = multimodal_processor(
|
| 451 |
+
text=[messages] * len(content_data),
|
| 452 |
+
images=content_data,
|
| 453 |
+
padding=True,
|
| 454 |
+
return_tensors="pt",
|
| 455 |
+
)
|
| 456 |
+
else:
|
| 457 |
+
messages = construct_message(
|
| 458 |
+
content_data, data_type, query, multimodal_processor, return_text=True
|
| 459 |
+
) # str
|
| 460 |
+
inputs = multimodal_processor(
|
| 461 |
+
text=[messages],
|
| 462 |
+
images=[content_data],
|
| 463 |
+
padding=True,
|
| 464 |
+
return_tensors="pt",
|
| 465 |
+
)
|
| 466 |
+
chunk_results.append(inputs)
|
| 467 |
+
|
| 468 |
+
# === Video ===
|
| 469 |
+
elif data_type == "video":
|
| 470 |
+
if isinstance(content_data, np.ndarray):
|
| 471 |
+
frames = [Image.fromarray(f) for f in content_data]
|
| 472 |
+
else:
|
| 473 |
+
frames = content_data
|
| 474 |
+
|
| 475 |
+
if frames:
|
| 476 |
+
frames = resize_images(frames, resolution=512)
|
| 477 |
+
|
| 478 |
+
total_frames = len(frames)
|
| 479 |
+
window_size = frame_windows
|
| 480 |
+
stride = frame_stride
|
| 481 |
+
|
| 482 |
+
for i in range(0, total_frames, stride):
|
| 483 |
+
start_idx = i
|
| 484 |
+
frame_offset = i # use to compute timestamp
|
| 485 |
+
end_idx = min(start_idx + window_size, total_frames)
|
| 486 |
+
|
| 487 |
+
chunk_frames = frames[start_idx:end_idx]
|
| 488 |
+
# if len(chunk_frames) < window_size:
|
| 489 |
+
# chunk_frames.extend(
|
| 490 |
+
# [chunk_frames[-1]] * (window_size - len(chunk_frames))
|
| 491 |
+
# )
|
| 492 |
+
|
| 493 |
+
if len(chunk_frames) < window_size and len(chunk_frames) == 1:
|
| 494 |
+
chunk_frames.append(chunk_frames[-1])
|
| 495 |
+
print(f"Qwen processor requires at least 2 frames as video input, copy last frame to {len(chunk_frames)}")
|
| 496 |
+
|
| 497 |
+
inputs = video_process_with_frame_idx(
|
| 498 |
+
chunk_frames, multimodal_processor, query, real_fps, frame_offset,
|
| 499 |
+
)
|
| 500 |
+
chunk_results.append(inputs)
|
| 501 |
+
|
| 502 |
+
else:
|
| 503 |
+
raise ValueError(f"Unknown data type: {data_type}")
|
| 504 |
+
|
| 505 |
+
# ============Group for batch process===================
|
| 506 |
+
chunk_dict = {}
|
| 507 |
+
for key in chunk_results[0]:
|
| 508 |
+
if key in ["input_ids", "attention_mask"]:
|
| 509 |
+
chunk_dict[f"vlm_{key}"] = [r[key].squeeze(dim=0) for r in chunk_results]
|
| 510 |
+
else:
|
| 511 |
+
chunk_dict[key] = [r[key] for r in chunk_results]
|
| 512 |
+
|
| 513 |
+
if is_eval:
|
| 514 |
+
return perpare_input_for_qwen_input(
|
| 515 |
+
chunk_dict, multimodal_processor.tokenizer.pad_token_id
|
| 516 |
+
)
|
| 517 |
+
|
| 518 |
+
return chunk_dict
|
| 519 |
+
|
| 520 |
+
|
| 521 |
+
def compute_segment_timestamp(
|
| 522 |
+
num_segments,
|
| 523 |
+
tokenizer,
|
| 524 |
+
real_fps,
|
| 525 |
+
stride=None,
|
| 526 |
+
window_size=None,
|
| 527 |
+
use_center_timestamp=True,
|
| 528 |
+
):
|
| 529 |
+
"""
|
| 530 |
+
The current version only supports non-overlapping segments.
|
| 531 |
+
You need to modify the timestamp computation to support overlapping segments.
|
| 532 |
+
"""
|
| 533 |
+
step = stride if stride is not None else window_size
|
| 534 |
+
fps = real_fps if real_fps and real_fps > 0 else 1.0
|
| 535 |
+
|
| 536 |
+
seg_timestamps_ids = []
|
| 537 |
+
for i in range(num_segments):
|
| 538 |
+
start_frame_idx = i * step
|
| 539 |
+
if use_center_timestamp:
|
| 540 |
+
frame_idx = start_frame_idx + (window_size / 2)
|
| 541 |
+
else:
|
| 542 |
+
frame_idx = start_frame_idx
|
| 543 |
+
cur_timestamp_sec = frame_idx / fps
|
| 544 |
+
text = f"<{cur_timestamp_sec:.1f} seconds>"
|
| 545 |
+
|
| 546 |
+
ids = tokenizer.encode(text, add_special_tokens=False)
|
| 547 |
+
seg_timestamps_ids.append(ids)
|
| 548 |
+
|
| 549 |
+
return seg_timestamps_ids
|
| 550 |
+
|
| 551 |
+
|
| 552 |
+
def compute_sample_indices(
|
| 553 |
+
total_frames: int,
|
| 554 |
+
original_fps: float,
|
| 555 |
+
target_fps: float,
|
| 556 |
+
min_frames: int,
|
| 557 |
+
max_frames: int,
|
| 558 |
+
) -> List[int]:
|
| 559 |
+
if total_frames <= 1:
|
| 560 |
+
return [0]
|
| 561 |
+
|
| 562 |
+
if original_fps is None or original_fps <= 0:
|
| 563 |
+
original_fps = target_fps
|
| 564 |
+
|
| 565 |
+
video_duration = total_frames / original_fps
|
| 566 |
+
target_num_frames = max(1, round(video_duration * target_fps))
|
| 567 |
+
|
| 568 |
+
final_num_frames = target_num_frames
|
| 569 |
+
if final_num_frames < min_frames:
|
| 570 |
+
print(
|
| 571 |
+
f"Upsampling video from {target_num_frames} to {min_frames} frames (min_frames limit)."
|
| 572 |
+
)
|
| 573 |
+
final_num_frames = min_frames
|
| 574 |
+
elif final_num_frames > max_frames:
|
| 575 |
+
print(
|
| 576 |
+
f"Downsampling video from {target_num_frames} to {max_frames} frames (max_frames limit)."
|
| 577 |
+
)
|
| 578 |
+
final_num_frames = max_frames
|
| 579 |
+
|
| 580 |
+
if final_num_frames == 1:
|
| 581 |
+
return [total_frames - 1]
|
| 582 |
+
|
| 583 |
+
indices = np.linspace(0, total_frames - 1, final_num_frames).astype(int)
|
| 584 |
+
indices = np.clip(indices, 0, total_frames - 1)
|
| 585 |
+
|
| 586 |
+
return indices.tolist()
|
| 587 |
+
|
| 588 |
+
|
| 589 |
+
def process_images(images, image_processor, model_cfg):
|
| 590 |
+
# if image_processor is None:
|
| 591 |
+
# raise ValueError("image_processor cannot be None")
|
| 592 |
+
if isinstance(image_processor, list):
|
| 593 |
+
image_aspect_ratio = getattr(model_cfg, "image_aspect_ratio", None)
|
| 594 |
+
processor_aux_list = image_processor
|
| 595 |
+
new_images_aux_list = []
|
| 596 |
+
for image in images:
|
| 597 |
+
if isinstance(image, np.ndarray):
|
| 598 |
+
image = Image.fromarray(image)
|
| 599 |
+
image_aux_list = []
|
| 600 |
+
for processor_aux in processor_aux_list:
|
| 601 |
+
image_aux = image
|
| 602 |
+
if hasattr(processor_aux, "image_mean"):
|
| 603 |
+
try:
|
| 604 |
+
target_resolution = processor_aux.crop_size["height"]
|
| 605 |
+
except:
|
| 606 |
+
target_resolution = processor_aux.size["height"]
|
| 607 |
+
# image_aux = expand2square(
|
| 608 |
+
# image_aux, tuple(int(x * 255) for x in processor_aux.image_mean)
|
| 609 |
+
# ).resize((target_resolution, target_resolution))
|
| 610 |
+
if image_aspect_ratio == "pad":
|
| 611 |
+
image_aux = expand2square(
|
| 612 |
+
image_aux,
|
| 613 |
+
tuple(int(x * 255) for x in processor_aux.image_mean),
|
| 614 |
+
)
|
| 615 |
+
elif image_aspect_ratio == "crop":
|
| 616 |
+
image_aux = crop2square(image_aux)
|
| 617 |
+
image_aux = image_aux.resize((target_resolution, target_resolution))
|
| 618 |
+
|
| 619 |
+
image_aux = processor_aux.preprocess(image_aux, return_tensors="pt")[
|
| 620 |
+
"pixel_values"
|
| 621 |
+
][0]
|
| 622 |
+
image_aux_list.append(image_aux)
|
| 623 |
+
new_images_aux_list.append(image_aux_list)
|
| 624 |
+
new_images_aux_list = [
|
| 625 |
+
list(batch_image_aux) for batch_image_aux in zip(*new_images_aux_list)
|
| 626 |
+
]
|
| 627 |
+
new_images_aux_list = [
|
| 628 |
+
torch.stack(image_aux).half().cuda() for image_aux in new_images_aux_list
|
| 629 |
+
]
|
| 630 |
+
return new_images_aux_list
|
| 631 |
+
else:
|
| 632 |
+
image_aspect_ratio = getattr(model_cfg, "image_aspect_ratio", None)
|
| 633 |
+
new_images = []
|
| 634 |
+
if image_aspect_ratio == "pad":
|
| 635 |
+
for image in images:
|
| 636 |
+
image = expand2square(
|
| 637 |
+
image, tuple(int(x * 255) for x in image_processor.image_mean)
|
| 638 |
+
)
|
| 639 |
+
image = image_processor.preprocess(image, return_tensors="pt")[
|
| 640 |
+
"pixel_values"
|
| 641 |
+
][0]
|
| 642 |
+
new_images.append(image)
|
| 643 |
+
elif image_aspect_ratio == "crop":
|
| 644 |
+
for image in images:
|
| 645 |
+
image = crop2square(image)
|
| 646 |
+
image = image_processor.preprocess(image, return_tensors="pt")[
|
| 647 |
+
"pixel_values"
|
| 648 |
+
][0]
|
| 649 |
+
new_images.append(image)
|
| 650 |
+
else:
|
| 651 |
+
return image_processor(images, return_tensors="pt")["pixel_values"]
|
| 652 |
+
if all(x.shape == new_images[0].shape for x in new_images):
|
| 653 |
+
new_images = torch.stack(new_images, dim=0)
|
| 654 |
+
return new_images
|
| 655 |
+
|
| 656 |
+
|
| 657 |
+
def preprocess_multimodal(sources: Sequence[str], data_args) -> dict:
|
| 658 |
+
is_multimodal = data_args.is_multimodal
|
| 659 |
+
if not is_multimodal:
|
| 660 |
+
return sources
|
| 661 |
+
|
| 662 |
+
for source in sources:
|
| 663 |
+
for sentence in source:
|
| 664 |
+
num_im = sentence["value"].count(DEFAULT_IMAGE_TOKEN)
|
| 665 |
+
if num_im == 1 or "<video>" in sentence["value"]:
|
| 666 |
+
# process only when the vision info is not multi-images
|
| 667 |
+
sentence["value"] = (
|
| 668 |
+
sentence["value"]
|
| 669 |
+
.replace(DEFAULT_IMAGE_TOKEN, "")
|
| 670 |
+
.replace("<video>", "")
|
| 671 |
+
.strip()
|
| 672 |
+
)
|
| 673 |
+
sentence["value"] = DEFAULT_IMAGE_TOKEN + "\n" + sentence["value"]
|
| 674 |
+
sentence["value"] = sentence["value"].strip()
|
| 675 |
+
if "mmtag" in conversation_lib.default_conversation.version:
|
| 676 |
+
sentence["value"] = sentence["value"].replace(
|
| 677 |
+
DEFAULT_IMAGE_TOKEN,
|
| 678 |
+
"<Image>" + DEFAULT_IMAGE_TOKEN + "</Image>",
|
| 679 |
+
)
|
| 680 |
+
replace_token = DEFAULT_IMAGE_TOKEN
|
| 681 |
+
if data_args.mm_use_im_start_end:
|
| 682 |
+
replace_token = (
|
| 683 |
+
DEFAULT_IM_START_TOKEN + replace_token + DEFAULT_IM_END_TOKEN
|
| 684 |
+
)
|
| 685 |
+
sentence["value"] = sentence["value"].replace(
|
| 686 |
+
DEFAULT_IMAGE_TOKEN, replace_token
|
| 687 |
+
)
|
| 688 |
+
|
| 689 |
+
return sources
|
| 690 |
+
|
| 691 |
+
|
| 692 |
+
def preprocess_llama_2(
|
| 693 |
+
sources,
|
| 694 |
+
tokenizer: transformers.PreTrainedTokenizer,
|
| 695 |
+
has_image: bool = False,
|
| 696 |
+
) -> dict:
|
| 697 |
+
conv = conversation_lib.default_conversation.copy()
|
| 698 |
+
roles = {"human": conv.roles[0], "gpt": conv.roles[1]}
|
| 699 |
+
|
| 700 |
+
# Apply prompt templates
|
| 701 |
+
conversations = []
|
| 702 |
+
for i, source in enumerate(sources):
|
| 703 |
+
if roles[source[0]["from"]] != conv.roles[0]:
|
| 704 |
+
# Skip the first one if it is not from human
|
| 705 |
+
source = source[1:]
|
| 706 |
+
|
| 707 |
+
conv.messages = []
|
| 708 |
+
for j, sentence in enumerate(source):
|
| 709 |
+
role = roles[sentence["from"]]
|
| 710 |
+
assert role == conv.roles[j % 2], f"{i}"
|
| 711 |
+
conv.append_message(role, sentence["value"])
|
| 712 |
+
conversations.append(conv.get_prompt())
|
| 713 |
+
|
| 714 |
+
# Tokenize conversations
|
| 715 |
+
|
| 716 |
+
if has_image:
|
| 717 |
+
input_ids = torch.stack(
|
| 718 |
+
[
|
| 719 |
+
tokenizer_image_token(prompt, tokenizer, return_tensors="pt")
|
| 720 |
+
for prompt in conversations
|
| 721 |
+
],
|
| 722 |
+
dim=0,
|
| 723 |
+
)
|
| 724 |
+
else:
|
| 725 |
+
input_ids = tokenizer(
|
| 726 |
+
conversations,
|
| 727 |
+
return_tensors="pt",
|
| 728 |
+
padding="longest",
|
| 729 |
+
max_length=tokenizer.model_max_length,
|
| 730 |
+
truncation=True,
|
| 731 |
+
).input_ids
|
| 732 |
+
|
| 733 |
+
targets = input_ids.clone()
|
| 734 |
+
|
| 735 |
+
assert conv.sep_style == conversation_lib.SeparatorStyle.LLAMA_2
|
| 736 |
+
|
| 737 |
+
# Mask targets
|
| 738 |
+
sep = "[/INST] "
|
| 739 |
+
for conversation, target in zip(conversations, targets):
|
| 740 |
+
total_len = int(target.ne(tokenizer.pad_token_id).sum())
|
| 741 |
+
|
| 742 |
+
rounds = conversation.split(conv.sep2)
|
| 743 |
+
cur_len = 1
|
| 744 |
+
target[:cur_len] = IGNORE_INDEX
|
| 745 |
+
for rou in rounds:
|
| 746 |
+
if rou == "":
|
| 747 |
+
break
|
| 748 |
+
|
| 749 |
+
parts = rou.split(sep)
|
| 750 |
+
if len(parts) != 2:
|
| 751 |
+
break
|
| 752 |
+
parts[0] += sep
|
| 753 |
+
|
| 754 |
+
if has_image:
|
| 755 |
+
round_len = len(tokenizer_image_token(rou, tokenizer))
|
| 756 |
+
instruction_len = len(tokenizer_image_token(parts[0], tokenizer)) - 2
|
| 757 |
+
else:
|
| 758 |
+
round_len = len(tokenizer(rou).input_ids)
|
| 759 |
+
instruction_len = len(tokenizer(parts[0]).input_ids) - 2
|
| 760 |
+
|
| 761 |
+
target[cur_len : cur_len + instruction_len] = IGNORE_INDEX
|
| 762 |
+
|
| 763 |
+
cur_len += round_len
|
| 764 |
+
target[cur_len:] = IGNORE_INDEX
|
| 765 |
+
|
| 766 |
+
if cur_len < tokenizer.model_max_length:
|
| 767 |
+
if cur_len != total_len:
|
| 768 |
+
target[:] = IGNORE_INDEX
|
| 769 |
+
print(
|
| 770 |
+
f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}."
|
| 771 |
+
f" (ignored)"
|
| 772 |
+
)
|
| 773 |
+
|
| 774 |
+
return dict(
|
| 775 |
+
input_ids=input_ids,
|
| 776 |
+
labels=targets,
|
| 777 |
+
)
|
| 778 |
+
|
| 779 |
+
|
| 780 |
+
def preprocess_v1(
|
| 781 |
+
sources,
|
| 782 |
+
tokenizer: transformers.PreTrainedTokenizer,
|
| 783 |
+
has_image: bool = False,
|
| 784 |
+
) -> dict:
|
| 785 |
+
conv = conversation_lib.default_conversation.copy()
|
| 786 |
+
roles = {"human": conv.roles[0], "gpt": conv.roles[1]}
|
| 787 |
+
|
| 788 |
+
# Apply prompt templates
|
| 789 |
+
conversations = []
|
| 790 |
+
for i, source in enumerate(sources):
|
| 791 |
+
if roles[source[0]["from"]] != conv.roles[0]:
|
| 792 |
+
# Skip the first one if it is not from human
|
| 793 |
+
source = source[1:]
|
| 794 |
+
|
| 795 |
+
conv.messages = []
|
| 796 |
+
for j, sentence in enumerate(source):
|
| 797 |
+
role = roles[sentence["from"]]
|
| 798 |
+
assert role == conv.roles[j % 2], f"{i}"
|
| 799 |
+
conv.append_message(role, sentence["value"])
|
| 800 |
+
conversations.append(conv.get_prompt())
|
| 801 |
+
|
| 802 |
+
# Tokenize conversations
|
| 803 |
+
|
| 804 |
+
if has_image:
|
| 805 |
+
input_ids = torch.stack(
|
| 806 |
+
[
|
| 807 |
+
tokenizer_image_token(prompt, tokenizer, return_tensors="pt")
|
| 808 |
+
for prompt in conversations
|
| 809 |
+
],
|
| 810 |
+
dim=0,
|
| 811 |
+
)
|
| 812 |
+
else:
|
| 813 |
+
input_ids = tokenizer(
|
| 814 |
+
conversations,
|
| 815 |
+
return_tensors="pt",
|
| 816 |
+
padding="longest",
|
| 817 |
+
max_length=tokenizer.model_max_length,
|
| 818 |
+
truncation=True,
|
| 819 |
+
).input_ids
|
| 820 |
+
|
| 821 |
+
targets = input_ids.clone()
|
| 822 |
+
|
| 823 |
+
assert conv.sep_style == conversation_lib.SeparatorStyle.TWO
|
| 824 |
+
|
| 825 |
+
# Mask targets
|
| 826 |
+
sep = conv.sep + conv.roles[1] + ": "
|
| 827 |
+
for conversation, target in zip(conversations, targets):
|
| 828 |
+
total_len = int(target.ne(tokenizer.pad_token_id).sum())
|
| 829 |
+
|
| 830 |
+
rounds = conversation.split(conv.sep2)
|
| 831 |
+
cur_len = 1
|
| 832 |
+
target[:cur_len] = IGNORE_INDEX
|
| 833 |
+
for i, rou in enumerate(rounds):
|
| 834 |
+
if rou == "":
|
| 835 |
+
break
|
| 836 |
+
|
| 837 |
+
parts = rou.split(sep)
|
| 838 |
+
if len(parts) != 2:
|
| 839 |
+
break
|
| 840 |
+
parts[0] += sep
|
| 841 |
+
|
| 842 |
+
if has_image:
|
| 843 |
+
round_len = len(tokenizer_image_token(rou, tokenizer))
|
| 844 |
+
instruction_len = len(tokenizer_image_token(parts[0], tokenizer)) - 2
|
| 845 |
+
else:
|
| 846 |
+
round_len = len(tokenizer(rou).input_ids)
|
| 847 |
+
instruction_len = len(tokenizer(parts[0]).input_ids) - 2
|
| 848 |
+
if i != 0 and not tokenizer.legacy and IS_TOKENIZER_GREATER_THAN_0_14:
|
| 849 |
+
round_len -= 1
|
| 850 |
+
instruction_len -= 1
|
| 851 |
+
|
| 852 |
+
target[cur_len : cur_len + instruction_len] = IGNORE_INDEX
|
| 853 |
+
|
| 854 |
+
cur_len += round_len
|
| 855 |
+
target[cur_len:] = IGNORE_INDEX
|
| 856 |
+
|
| 857 |
+
if cur_len < tokenizer.model_max_length:
|
| 858 |
+
if cur_len != total_len:
|
| 859 |
+
target[:] = IGNORE_INDEX
|
| 860 |
+
print(
|
| 861 |
+
f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}."
|
| 862 |
+
f" (ignored)"
|
| 863 |
+
)
|
| 864 |
+
|
| 865 |
+
return dict(
|
| 866 |
+
input_ids=input_ids,
|
| 867 |
+
labels=targets,
|
| 868 |
+
)
|
| 869 |
+
|
| 870 |
+
|
| 871 |
+
def tokenizer_image_token(
|
| 872 |
+
prompt,
|
| 873 |
+
tokenizer,
|
| 874 |
+
image_token_index=IMAGE_TOKEN_INDEX,
|
| 875 |
+
return_tensors=None,
|
| 876 |
+
):
|
| 877 |
+
prompt_chunks = [tokenizer(chunk).input_ids for chunk in prompt.split("<image>")]
|
| 878 |
+
|
| 879 |
+
def insert_separator(X, sep):
|
| 880 |
+
return [ele for sublist in zip(X, [sep] * len(X)) for ele in sublist][:-1]
|
| 881 |
+
|
| 882 |
+
input_ids = []
|
| 883 |
+
offset = 0
|
| 884 |
+
if (
|
| 885 |
+
len(prompt_chunks) > 0
|
| 886 |
+
and len(prompt_chunks[0]) > 0
|
| 887 |
+
and prompt_chunks[0][0] == tokenizer.bos_token_id
|
| 888 |
+
):
|
| 889 |
+
offset = 1
|
| 890 |
+
input_ids.append(prompt_chunks[0][0])
|
| 891 |
+
|
| 892 |
+
for x in insert_separator(prompt_chunks, [image_token_index] * (offset + 1)):
|
| 893 |
+
input_ids.extend(x[offset:])
|
| 894 |
+
|
| 895 |
+
if return_tensors is not None:
|
| 896 |
+
if return_tensors == "pt":
|
| 897 |
+
return torch.tensor(input_ids, dtype=torch.long)
|
| 898 |
+
raise ValueError(f"Unsupported tensor type: {return_tensors}")
|
| 899 |
+
return input_ids
|
| 900 |
+
|
| 901 |
+
|
| 902 |
+
def tokenizer_image_token_llama3(
|
| 903 |
+
prompt,
|
| 904 |
+
tokenizer,
|
| 905 |
+
image_token_index=IMAGE_TOKEN_INDEX,
|
| 906 |
+
return_tensors=None,
|
| 907 |
+
):
|
| 908 |
+
prompt_chunks = [tokenizer(chunk).input_ids for chunk in prompt.split("<image>")]
|
| 909 |
+
|
| 910 |
+
def insert_separator(X, sep):
|
| 911 |
+
return [ele for sublist in zip(X, [sep] * len(X)) for ele in sublist][:-1]
|
| 912 |
+
|
| 913 |
+
input_ids = []
|
| 914 |
+
for x in insert_separator(prompt_chunks, [image_token_index]):
|
| 915 |
+
input_ids.extend(x)
|
| 916 |
+
|
| 917 |
+
if return_tensors is not None:
|
| 918 |
+
if return_tensors == "pt":
|
| 919 |
+
return torch.tensor(input_ids, dtype=torch.long)
|
| 920 |
+
raise ValueError(f"Unsupported tensor type: {return_tensors}")
|
| 921 |
+
return input_ids
|
| 922 |
+
|
| 923 |
+
|
| 924 |
+
def preprocess_qwen(
|
| 925 |
+
sources,
|
| 926 |
+
tokenizer: transformers.PreTrainedTokenizer,
|
| 927 |
+
has_image: bool = False,
|
| 928 |
+
system_message: str = "You are a helpful assistant.",
|
| 929 |
+
) -> dict:
|
| 930 |
+
roles = {"human": "user", "gpt": "assistant"}
|
| 931 |
+
|
| 932 |
+
# Add image tokens to tokenizer as a special tokens
|
| 933 |
+
# Use a deepcopy of tokenizer so that we don't modify on the tokenizer
|
| 934 |
+
tokenizer = copy.deepcopy(tokenizer)
|
| 935 |
+
# When there is actually an image, we add the image tokens as a special token
|
| 936 |
+
if has_image:
|
| 937 |
+
tokenizer.add_tokens(["<image>"], special_tokens=True)
|
| 938 |
+
|
| 939 |
+
image_token_index = tokenizer.convert_tokens_to_ids("<image>")
|
| 940 |
+
im_start = tokenizer.convert_tokens_to_ids("<|im_start|>")
|
| 941 |
+
im_end = tokenizer.convert_tokens_to_ids("<|im_end|>")
|
| 942 |
+
|
| 943 |
+
unmask_tokens_idx = [198, im_start, im_end]
|
| 944 |
+
# nl_tokens = tokenizer("\n").input_ids
|
| 945 |
+
|
| 946 |
+
# Reset Qwen chat templates so that it won't include system message every time we apply
|
| 947 |
+
chat_template = "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}"
|
| 948 |
+
tokenizer.chat_template = chat_template
|
| 949 |
+
|
| 950 |
+
# _system = tokenizer("system").input_ids + nl_tokens
|
| 951 |
+
# _user = tokenizer("user").input_ids + nl_tokens
|
| 952 |
+
# _assistant = tokenizer("assistant").input_ids + nl_tokens
|
| 953 |
+
|
| 954 |
+
# Apply prompt templates
|
| 955 |
+
input_ids, targets = [], []
|
| 956 |
+
for source in sources:
|
| 957 |
+
if roles[source[0]["from"]] != roles["human"]:
|
| 958 |
+
source = source[1:]
|
| 959 |
+
|
| 960 |
+
input_id, target = [], []
|
| 961 |
+
|
| 962 |
+
# New version, use apply chat template
|
| 963 |
+
# Build system message for each sentence
|
| 964 |
+
input_id += tokenizer.apply_chat_template(
|
| 965 |
+
[{"role": "system", "content": system_message}]
|
| 966 |
+
)
|
| 967 |
+
target += [IGNORE_INDEX] * len(input_id)
|
| 968 |
+
|
| 969 |
+
for conv in source:
|
| 970 |
+
# Make sure llava data can load
|
| 971 |
+
try:
|
| 972 |
+
role = conv["role"]
|
| 973 |
+
content = conv["content"]
|
| 974 |
+
except:
|
| 975 |
+
role = conv["from"]
|
| 976 |
+
content = conv["value"]
|
| 977 |
+
|
| 978 |
+
role = roles.get(role, role)
|
| 979 |
+
|
| 980 |
+
conv = [{"role": role, "content": content}]
|
| 981 |
+
encode_id = tokenizer.apply_chat_template(conv)
|
| 982 |
+
input_id += encode_id
|
| 983 |
+
if role in ["user", "system"]:
|
| 984 |
+
target += [IGNORE_INDEX] * len(encode_id)
|
| 985 |
+
else:
|
| 986 |
+
target += encode_id
|
| 987 |
+
|
| 988 |
+
assert len(input_id) == len(target), f"{len(input_id)} != {len(target)}"
|
| 989 |
+
for idx, encode_id in enumerate(input_id):
|
| 990 |
+
if encode_id in unmask_tokens_idx:
|
| 991 |
+
target[idx] = encode_id
|
| 992 |
+
if encode_id == image_token_index:
|
| 993 |
+
input_id[idx] = IMAGE_TOKEN_INDEX
|
| 994 |
+
input_ids.append(input_id)
|
| 995 |
+
targets.append(target)
|
| 996 |
+
input_ids = torch.tensor(input_ids, dtype=torch.long)
|
| 997 |
+
targets = torch.tensor(targets, dtype=torch.long)
|
| 998 |
+
|
| 999 |
+
return dict(
|
| 1000 |
+
input_ids=input_ids, # tensor(bs x seq_len)
|
| 1001 |
+
labels=targets, # tensor(bs x seq_len)
|
| 1002 |
+
)
|
| 1003 |
+
|
| 1004 |
+
|
| 1005 |
+
def preprocess_llama3(
|
| 1006 |
+
sources,
|
| 1007 |
+
tokenizer: transformers.PreTrainedTokenizer,
|
| 1008 |
+
has_image: bool = False,
|
| 1009 |
+
system_message: str = "You are a helpful assistant.",
|
| 1010 |
+
) -> dict:
|
| 1011 |
+
# roles = {"human": "<|start_header_id|>user<|end_header_id|>", "gpt": "<|start_header_id|>assistant<|end_header_id|>"}
|
| 1012 |
+
roles = {"human": "user", "gpt": "assistant"}
|
| 1013 |
+
|
| 1014 |
+
# Add image tokens to tokenizer as a special tokens
|
| 1015 |
+
# Use a deepcopy of tokenizer so that we don't modify on the tokenizer
|
| 1016 |
+
tokenizer = copy.deepcopy(tokenizer)
|
| 1017 |
+
# When there is actually an image, we add the image tokens as a special token
|
| 1018 |
+
if has_image:
|
| 1019 |
+
tokenizer.add_tokens(["<image>"], special_tokens=True)
|
| 1020 |
+
image_token_index = tokenizer.convert_tokens_to_ids("<image>")
|
| 1021 |
+
bos_token_id = tokenizer.convert_tokens_to_ids("<|begin_of_text|>")
|
| 1022 |
+
start_header_id = tokenizer.convert_tokens_to_ids("<|start_header_id|>")
|
| 1023 |
+
end_header_id = tokenizer.convert_tokens_to_ids("<|end_header_id|>")
|
| 1024 |
+
eot_id = tokenizer.convert_tokens_to_ids("<|eot_id|>")
|
| 1025 |
+
|
| 1026 |
+
unmask_tokens = [
|
| 1027 |
+
"<|begin_of_text|>",
|
| 1028 |
+
"<|start_header_id|>",
|
| 1029 |
+
"<|end_header_id|>",
|
| 1030 |
+
"<|eot_id|>",
|
| 1031 |
+
"\n\n",
|
| 1032 |
+
]
|
| 1033 |
+
unmask_tokens_idx = [tokenizer.convert_tokens_to_ids(tok) for tok in unmask_tokens]
|
| 1034 |
+
|
| 1035 |
+
# After update, calling tokenizer of llama3 will
|
| 1036 |
+
# auto add bos id for the tokens. γ½(ο½βΒ΄)οΎ
|
| 1037 |
+
def safe_tokenizer_llama3(text):
|
| 1038 |
+
input_ids = tokenizer(text).input_ids
|
| 1039 |
+
if input_ids[0] == bos_token_id:
|
| 1040 |
+
input_ids = input_ids[1:]
|
| 1041 |
+
return input_ids
|
| 1042 |
+
|
| 1043 |
+
nl_tokens = tokenizer.convert_tokens_to_ids("\n\n")
|
| 1044 |
+
|
| 1045 |
+
# chat_template = "{% set loop_messages = messages %}{% for message in loop_messages %}{% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>\\n\\n'+ message['content'] | trim + '<|eot_id|>' %}{% if loop.index0 == 0 %}{% set content = bos_token + content %}{% endif %}{{ content }}{% endfor %}{%- if add_generation_prompt %}{{ '<|start_header_id|>assistant<|end_header_id|>\\n\\n' }}{%- endif %}"
|
| 1046 |
+
chat_template = "{% set loop_messages = messages %}{% for message in loop_messages %}{% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] | trim + '<|eot_id|>' %}{% if loop.index0 == 0 %}{% set content = bos_token + content %}{% endif %}{{ content }}{% endfor %}{{ '<|start_header_id|>assistant<|end_header_id|>\n\n' }}"
|
| 1047 |
+
tokenizer.chat_template = chat_template
|
| 1048 |
+
|
| 1049 |
+
# Apply prompt templates
|
| 1050 |
+
input_ids, targets = [], []
|
| 1051 |
+
for source in sources:
|
| 1052 |
+
if roles[source[0]["from"]] != roles["human"]:
|
| 1053 |
+
source = source[1:]
|
| 1054 |
+
|
| 1055 |
+
input_id, target = [], []
|
| 1056 |
+
|
| 1057 |
+
# New version, use apply chat template
|
| 1058 |
+
# Build system message for each sentence
|
| 1059 |
+
input_id += tokenizer.apply_chat_template(
|
| 1060 |
+
[{"role": "system", "content": system_message}]
|
| 1061 |
+
# pyre-fixme[6]: For 1st argument expected `Union[int, str]` but got `slice`.
|
| 1062 |
+
)[:-4]
|
| 1063 |
+
|
| 1064 |
+
target += [IGNORE_INDEX] * len(input_id)
|
| 1065 |
+
|
| 1066 |
+
for conv in source:
|
| 1067 |
+
# Make sure llava data can load
|
| 1068 |
+
try:
|
| 1069 |
+
role = conv["role"]
|
| 1070 |
+
content = conv["content"]
|
| 1071 |
+
except:
|
| 1072 |
+
role = conv["from"]
|
| 1073 |
+
content = conv["value"]
|
| 1074 |
+
|
| 1075 |
+
role = roles.get(role, role)
|
| 1076 |
+
|
| 1077 |
+
conv = [{"role": role, "content": content}]
|
| 1078 |
+
# First is bos token we don't need here
|
| 1079 |
+
encode_id = tokenizer.apply_chat_template(conv)[1:-4]
|
| 1080 |
+
input_id += encode_id
|
| 1081 |
+
if role in ["user", "system"]:
|
| 1082 |
+
target += [IGNORE_INDEX] * len(encode_id)
|
| 1083 |
+
else:
|
| 1084 |
+
target += encode_id
|
| 1085 |
+
|
| 1086 |
+
assert len(input_id) == len(target), f"{len(input_id)} != {len(target)}"
|
| 1087 |
+
for idx, encode_id in enumerate(input_id):
|
| 1088 |
+
if encode_id in unmask_tokens_idx:
|
| 1089 |
+
target[idx] = encode_id
|
| 1090 |
+
if encode_id == image_token_index:
|
| 1091 |
+
input_id[idx] = IMAGE_TOKEN_INDEX
|
| 1092 |
+
input_ids.append(input_id)
|
| 1093 |
+
targets.append(target)
|
| 1094 |
+
input_ids = torch.tensor(input_ids, dtype=torch.long)
|
| 1095 |
+
targets = torch.tensor(targets, dtype=torch.long)
|
| 1096 |
+
|
| 1097 |
+
# print("input_ids", input_ids, flush=True)
|
| 1098 |
+
# print("targets", targets, flush=True)
|
| 1099 |
+
return dict(
|
| 1100 |
+
input_ids=input_ids, # tensor(bs x seq_len)
|
| 1101 |
+
labels=targets, # tensor(bs x seq_len)
|
| 1102 |
+
)
|
| 1103 |
+
|
| 1104 |
+
|
| 1105 |
+
def preprocess_llama_3_1(
|
| 1106 |
+
sources,
|
| 1107 |
+
tokenizer: transformers.PreTrainedTokenizer,
|
| 1108 |
+
has_image: bool = False,
|
| 1109 |
+
) -> dict:
|
| 1110 |
+
conv = conversation_lib.default_conversation.copy()
|
| 1111 |
+
roles = {"human": conv.roles[0], "gpt": conv.roles[1]}
|
| 1112 |
+
|
| 1113 |
+
# Apply prompt templates
|
| 1114 |
+
conversations = []
|
| 1115 |
+
for source in sources:
|
| 1116 |
+
if roles[source[0]["from"]] != conv.roles[0]:
|
| 1117 |
+
# Skip the first one if it is not from human
|
| 1118 |
+
source = source[1:]
|
| 1119 |
+
|
| 1120 |
+
conv.messages = []
|
| 1121 |
+
for sentence in source:
|
| 1122 |
+
if sentence["from"] == "Answer":
|
| 1123 |
+
sentence["from"] = "gpt" # data bug
|
| 1124 |
+
role = roles[sentence["from"]]
|
| 1125 |
+
# assert role == conv.roles[j % 2], f"{i}"
|
| 1126 |
+
conv.append_message(role, sentence["value"])
|
| 1127 |
+
conversations.append(conv.get_prompt())
|
| 1128 |
+
|
| 1129 |
+
# Tokenize conversations
|
| 1130 |
+
|
| 1131 |
+
if has_image:
|
| 1132 |
+
input_ids = torch.stack(
|
| 1133 |
+
[
|
| 1134 |
+
tokenizer_image_token(prompt, tokenizer, return_tensors="pt")
|
| 1135 |
+
for prompt in conversations
|
| 1136 |
+
],
|
| 1137 |
+
dim=0,
|
| 1138 |
+
)
|
| 1139 |
+
else:
|
| 1140 |
+
input_ids = tokenizer(
|
| 1141 |
+
conversations,
|
| 1142 |
+
return_tensors="pt",
|
| 1143 |
+
padding="longest",
|
| 1144 |
+
max_length=tokenizer.model_max_length,
|
| 1145 |
+
truncation=True,
|
| 1146 |
+
).input_ids
|
| 1147 |
+
|
| 1148 |
+
# remove the first bos token
|
| 1149 |
+
if input_ids[0][0] == input_ids[0][1] == tokenizer.bos_token_id:
|
| 1150 |
+
input_ids = input_ids[:, 1:]
|
| 1151 |
+
targets = input_ids.clone()
|
| 1152 |
+
|
| 1153 |
+
assert conv.sep_style == conversation_lib.SeparatorStyle.LLAMA_3_1
|
| 1154 |
+
|
| 1155 |
+
# Mask targets
|
| 1156 |
+
sep = "<|start_header_id|>" + conv.roles[1] + "<|end_header_id|>" + "\n\n"
|
| 1157 |
+
# sep = conv.sep + conv.roles[1] + ": "
|
| 1158 |
+
for conversation, target in zip(conversations, targets):
|
| 1159 |
+
total_len = int(target.shape[0])
|
| 1160 |
+
|
| 1161 |
+
rounds = conversation.split(conv.tokenizer.eos_token)
|
| 1162 |
+
rounds = [rounds[0]] + [
|
| 1163 |
+
rounds[idx] + rounds[idx + 1] for idx in range(1, len(rounds) - 1, 2)
|
| 1164 |
+
]
|
| 1165 |
+
|
| 1166 |
+
cur_len = 1
|
| 1167 |
+
target[:cur_len] = IGNORE_INDEX
|
| 1168 |
+
for i, rou in enumerate(rounds):
|
| 1169 |
+
if rou == "":
|
| 1170 |
+
break
|
| 1171 |
+
|
| 1172 |
+
parts = rou.split(sep)
|
| 1173 |
+
if len(parts) != 2 and i != 0:
|
| 1174 |
+
break
|
| 1175 |
+
|
| 1176 |
+
if i == 0:
|
| 1177 |
+
round_len = len(tokenizer(rou, add_special_tokens=False).input_ids)
|
| 1178 |
+
instruction_len = len(
|
| 1179 |
+
tokenizer(rou, add_special_tokens=False).input_ids
|
| 1180 |
+
)
|
| 1181 |
+
|
| 1182 |
+
else:
|
| 1183 |
+
parts[0] += sep
|
| 1184 |
+
if has_image:
|
| 1185 |
+
round_len = len(tokenizer_image_token(rou, tokenizer)) + 1
|
| 1186 |
+
instruction_len = len(tokenizer_image_token(parts[0], tokenizer))
|
| 1187 |
+
else:
|
| 1188 |
+
round_len = len(tokenizer(rou).input_ids) + 1
|
| 1189 |
+
instruction_len = len(tokenizer(parts[0]).input_ids)
|
| 1190 |
+
|
| 1191 |
+
# if i > 0: round_len += 1
|
| 1192 |
+
target[cur_len : cur_len + instruction_len] = IGNORE_INDEX
|
| 1193 |
+
cur_len += round_len
|
| 1194 |
+
|
| 1195 |
+
target[cur_len:] = IGNORE_INDEX
|
| 1196 |
+
cur_len = cur_len + len(tokenizer(sep, add_special_tokens=False).input_ids)
|
| 1197 |
+
|
| 1198 |
+
# if cur_len > tokenizer.model_max_length: print(f"WARNING: max length context")
|
| 1199 |
+
if cur_len < tokenizer.model_max_length:
|
| 1200 |
+
if cur_len != total_len:
|
| 1201 |
+
target[:] = IGNORE_INDEX
|
| 1202 |
+
print(
|
| 1203 |
+
f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}."
|
| 1204 |
+
f" (ignored)"
|
| 1205 |
+
)
|
| 1206 |
+
|
| 1207 |
+
return dict(
|
| 1208 |
+
input_ids=input_ids,
|
| 1209 |
+
labels=targets,
|
| 1210 |
+
)
|
| 1211 |
+
|
| 1212 |
+
|
| 1213 |
+
def preprocess_llama_3_2(
|
| 1214 |
+
sources,
|
| 1215 |
+
tokenizer: transformers.PreTrainedTokenizer,
|
| 1216 |
+
has_image: bool = False,
|
| 1217 |
+
) -> dict:
|
| 1218 |
+
conv = conversation_lib.default_conversation.copy()
|
| 1219 |
+
roles = {"human": conv.roles[0], "gpt": conv.roles[1]}
|
| 1220 |
+
|
| 1221 |
+
# Apply prompt templates
|
| 1222 |
+
conversations = []
|
| 1223 |
+
for i, source in enumerate(sources):
|
| 1224 |
+
if roles[source[0]["from"]] != conv.roles[0]:
|
| 1225 |
+
# Skip the first one if it is not from human
|
| 1226 |
+
source = source[1:]
|
| 1227 |
+
|
| 1228 |
+
conv.messages = []
|
| 1229 |
+
for j, sentence in enumerate(source):
|
| 1230 |
+
role = roles[sentence["from"]]
|
| 1231 |
+
assert role == conv.roles[j % 2], f"{i}"
|
| 1232 |
+
conv.append_message(role, sentence["value"])
|
| 1233 |
+
conversations.append(conv.get_prompt())
|
| 1234 |
+
|
| 1235 |
+
# Tokenize conversations
|
| 1236 |
+
|
| 1237 |
+
if has_image:
|
| 1238 |
+
input_ids = torch.stack(
|
| 1239 |
+
[
|
| 1240 |
+
tokenizer_image_token(prompt, tokenizer, return_tensors="pt")
|
| 1241 |
+
for prompt in conversations
|
| 1242 |
+
],
|
| 1243 |
+
dim=0,
|
| 1244 |
+
)
|
| 1245 |
+
else:
|
| 1246 |
+
input_ids = tokenizer(
|
| 1247 |
+
conversations,
|
| 1248 |
+
return_tensors="pt",
|
| 1249 |
+
padding="longest",
|
| 1250 |
+
max_length=tokenizer.model_max_length,
|
| 1251 |
+
truncation=True,
|
| 1252 |
+
).input_ids
|
| 1253 |
+
|
| 1254 |
+
# remove the first bos token
|
| 1255 |
+
if input_ids[0][0] == input_ids[0][1] == tokenizer.bos_token_id:
|
| 1256 |
+
input_ids = input_ids[:, 1:]
|
| 1257 |
+
targets = input_ids.clone()
|
| 1258 |
+
|
| 1259 |
+
assert conv.sep_style == conversation_lib.SeparatorStyle.LLAMA_3_2
|
| 1260 |
+
|
| 1261 |
+
# Mask targets
|
| 1262 |
+
sep = "<|start_header_id|>" + conv.roles[1] + "<|end_header_id|>" + "\n\n"
|
| 1263 |
+
# sep = conv.sep + conv.roles[1] + ": "
|
| 1264 |
+
for conversation, target in zip(conversations, targets):
|
| 1265 |
+
total_len = int(target.shape[0])
|
| 1266 |
+
|
| 1267 |
+
rounds = conversation.split(conv.tokenizer.eos_token)
|
| 1268 |
+
rounds = [rounds[0]] + [
|
| 1269 |
+
rounds[idx] + rounds[idx + 1] for idx in range(1, len(rounds) - 1, 2)
|
| 1270 |
+
]
|
| 1271 |
+
|
| 1272 |
+
cur_len = 1
|
| 1273 |
+
target[:cur_len] = IGNORE_INDEX
|
| 1274 |
+
for i, rou in enumerate(rounds):
|
| 1275 |
+
if rou == "":
|
| 1276 |
+
break
|
| 1277 |
+
|
| 1278 |
+
parts = rou.split(sep)
|
| 1279 |
+
if len(parts) != 2 and i != 0:
|
| 1280 |
+
break
|
| 1281 |
+
|
| 1282 |
+
if i == 0:
|
| 1283 |
+
round_len = len(tokenizer(rou, add_special_tokens=False).input_ids)
|
| 1284 |
+
instruction_len = len(
|
| 1285 |
+
tokenizer(rou, add_special_tokens=False).input_ids
|
| 1286 |
+
)
|
| 1287 |
+
|
| 1288 |
+
else:
|
| 1289 |
+
parts[0] += sep
|
| 1290 |
+
if has_image:
|
| 1291 |
+
round_len = len(tokenizer_image_token(rou, tokenizer)) + 1
|
| 1292 |
+
instruction_len = len(tokenizer_image_token(parts[0], tokenizer))
|
| 1293 |
+
else:
|
| 1294 |
+
round_len = len(tokenizer(rou).input_ids) + 1
|
| 1295 |
+
instruction_len = len(tokenizer(parts[0]).input_ids)
|
| 1296 |
+
|
| 1297 |
+
# if i > 0: round_len += 1
|
| 1298 |
+
target[cur_len : cur_len + instruction_len] = IGNORE_INDEX
|
| 1299 |
+
cur_len += round_len
|
| 1300 |
+
|
| 1301 |
+
target[cur_len:] = IGNORE_INDEX
|
| 1302 |
+
cur_len = cur_len + len(tokenizer(sep, add_special_tokens=False).input_ids)
|
| 1303 |
+
|
| 1304 |
+
# if cur_len > tokenizer.model_max_length: print(f"WARNING: max length context")
|
| 1305 |
+
if cur_len < tokenizer.model_max_length:
|
| 1306 |
+
if cur_len != total_len:
|
| 1307 |
+
target[:] = IGNORE_INDEX
|
| 1308 |
+
print(
|
| 1309 |
+
f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}."
|
| 1310 |
+
f" (ignored)"
|
| 1311 |
+
)
|
| 1312 |
+
|
| 1313 |
+
return dict(
|
| 1314 |
+
input_ids=input_ids,
|
| 1315 |
+
labels=targets,
|
| 1316 |
+
)
|
| 1317 |
+
|
| 1318 |
+
|
| 1319 |
+
def preprocess_phi3(
|
| 1320 |
+
sources,
|
| 1321 |
+
tokenizer: transformers.PreTrainedTokenizer,
|
| 1322 |
+
has_image: bool = False,
|
| 1323 |
+
) -> dict:
|
| 1324 |
+
conv = conversation_lib.conv_templates["phi3"].copy()
|
| 1325 |
+
roles = {"human": conv.roles[0], "gpt": conv.roles[1]}
|
| 1326 |
+
|
| 1327 |
+
# Apply prompt templates
|
| 1328 |
+
conversations = []
|
| 1329 |
+
for i, source in enumerate(sources):
|
| 1330 |
+
if roles[source[0]["from"]] != conv.roles[0]:
|
| 1331 |
+
# Skip the first one if it is not from human
|
| 1332 |
+
source = source[1:]
|
| 1333 |
+
|
| 1334 |
+
conv.messages = []
|
| 1335 |
+
for j, sentence in enumerate(source):
|
| 1336 |
+
role = roles[sentence["from"]]
|
| 1337 |
+
assert role == conv.roles[j % 2], f"{i}"
|
| 1338 |
+
conv.append_message(role, sentence["value"])
|
| 1339 |
+
conversations.append(conv.get_prompt())
|
| 1340 |
+
|
| 1341 |
+
# Tokenize conversations
|
| 1342 |
+
if has_image:
|
| 1343 |
+
input_ids = torch.stack(
|
| 1344 |
+
[
|
| 1345 |
+
tokenizer_image_token(prompt, tokenizer, return_tensors="pt")
|
| 1346 |
+
for prompt in conversations
|
| 1347 |
+
],
|
| 1348 |
+
dim=0,
|
| 1349 |
+
)
|
| 1350 |
+
else:
|
| 1351 |
+
input_ids = tokenizer(
|
| 1352 |
+
conversations,
|
| 1353 |
+
return_tensors="pt",
|
| 1354 |
+
padding="longest",
|
| 1355 |
+
max_length=tokenizer.model_max_length,
|
| 1356 |
+
truncation=True,
|
| 1357 |
+
).input_ids
|
| 1358 |
+
|
| 1359 |
+
targets = input_ids.clone()
|
| 1360 |
+
assert conv.sep_style == conversation_lib.SeparatorStyle.MPT
|
| 1361 |
+
|
| 1362 |
+
# Mask targets
|
| 1363 |
+
sep = conv.sep + conv.roles[1]
|
| 1364 |
+
for conversation, target in zip(conversations, targets):
|
| 1365 |
+
total_len = int(target.ne(tokenizer.pad_token_id).sum())
|
| 1366 |
+
|
| 1367 |
+
rounds = conversation.split(conv.sep)
|
| 1368 |
+
re_rounds = [conv.sep.join(rounds[:3])] # system + user + gpt
|
| 1369 |
+
for conv_idx in range(3, len(rounds), 2):
|
| 1370 |
+
re_rounds.append(
|
| 1371 |
+
conv.sep.join(rounds[conv_idx : conv_idx + 2])
|
| 1372 |
+
) # user + gpt
|
| 1373 |
+
cur_len = 0
|
| 1374 |
+
target[:cur_len] = IGNORE_INDEX
|
| 1375 |
+
for i, rou in enumerate(re_rounds):
|
| 1376 |
+
if rou == "":
|
| 1377 |
+
break
|
| 1378 |
+
|
| 1379 |
+
parts = rou.split(sep)
|
| 1380 |
+
if len(parts) != 2:
|
| 1381 |
+
break
|
| 1382 |
+
parts[0] += sep
|
| 1383 |
+
|
| 1384 |
+
if has_image:
|
| 1385 |
+
round_len = len(tokenizer_image_token(rou, tokenizer))
|
| 1386 |
+
instruction_len = len(tokenizer_image_token(parts[0], tokenizer)) - 1
|
| 1387 |
+
else:
|
| 1388 |
+
round_len = len(tokenizer(rou).input_ids)
|
| 1389 |
+
instruction_len = len(tokenizer(parts[0]).input_ids) - 1
|
| 1390 |
+
|
| 1391 |
+
if i == 0:
|
| 1392 |
+
round_len += 1
|
| 1393 |
+
instruction_len += 1
|
| 1394 |
+
else:
|
| 1395 |
+
round_len -= 2
|
| 1396 |
+
instruction_len -= 2
|
| 1397 |
+
|
| 1398 |
+
if (
|
| 1399 |
+
i != 0
|
| 1400 |
+
and getattr(tokenizer, "legacy", False)
|
| 1401 |
+
and IS_TOKENIZER_GREATER_THAN_0_14
|
| 1402 |
+
):
|
| 1403 |
+
round_len += 1
|
| 1404 |
+
instruction_len += 1
|
| 1405 |
+
|
| 1406 |
+
target[cur_len : cur_len + instruction_len] = IGNORE_INDEX
|
| 1407 |
+
|
| 1408 |
+
cur_len += round_len
|
| 1409 |
+
target[cur_len:] = IGNORE_INDEX
|
| 1410 |
+
|
| 1411 |
+
if cur_len < tokenizer.model_max_length:
|
| 1412 |
+
if cur_len != total_len:
|
| 1413 |
+
target[:] = IGNORE_INDEX
|
| 1414 |
+
print(
|
| 1415 |
+
f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}."
|
| 1416 |
+
f" (ignored)"
|
| 1417 |
+
)
|
| 1418 |
+
|
| 1419 |
+
return dict(
|
| 1420 |
+
input_ids=input_ids,
|
| 1421 |
+
labels=targets,
|
| 1422 |
+
)
|
| 1423 |
+
|
| 1424 |
+
|
| 1425 |
+
def preprocess_mpt(
|
| 1426 |
+
sources,
|
| 1427 |
+
tokenizer: transformers.PreTrainedTokenizer,
|
| 1428 |
+
has_image: bool = False,
|
| 1429 |
+
) -> dict:
|
| 1430 |
+
conv = conversation_lib.default_conversation.copy()
|
| 1431 |
+
roles = {"human": conv.roles[0], "gpt": conv.roles[1]}
|
| 1432 |
+
|
| 1433 |
+
# Apply prompt templates
|
| 1434 |
+
conversations = []
|
| 1435 |
+
for i, source in enumerate(sources):
|
| 1436 |
+
if roles[source[0]["from"]] != conv.roles[0]:
|
| 1437 |
+
# Skip the first one if it is not from human
|
| 1438 |
+
source = source[1:]
|
| 1439 |
+
|
| 1440 |
+
conv.messages = []
|
| 1441 |
+
for j, sentence in enumerate(source):
|
| 1442 |
+
role = roles[sentence["from"]]
|
| 1443 |
+
assert role == conv.roles[j % 2], f"{i}"
|
| 1444 |
+
conv.append_message(role, sentence["value"])
|
| 1445 |
+
conversations.append(conv.get_prompt())
|
| 1446 |
+
|
| 1447 |
+
# Tokenize conversations
|
| 1448 |
+
if has_image:
|
| 1449 |
+
input_ids = torch.stack(
|
| 1450 |
+
[
|
| 1451 |
+
tokenizer_image_token(prompt, tokenizer, return_tensors="pt")
|
| 1452 |
+
for prompt in conversations
|
| 1453 |
+
],
|
| 1454 |
+
dim=0,
|
| 1455 |
+
)
|
| 1456 |
+
else:
|
| 1457 |
+
input_ids = tokenizer(
|
| 1458 |
+
conversations,
|
| 1459 |
+
return_tensors="pt",
|
| 1460 |
+
padding="longest",
|
| 1461 |
+
max_length=tokenizer.model_max_length,
|
| 1462 |
+
truncation=True,
|
| 1463 |
+
).input_ids
|
| 1464 |
+
|
| 1465 |
+
targets = input_ids.clone()
|
| 1466 |
+
assert conv.sep_style == conversation_lib.SeparatorStyle.MPT
|
| 1467 |
+
|
| 1468 |
+
# Mask targets
|
| 1469 |
+
sep = conv.sep + conv.roles[1]
|
| 1470 |
+
for conversation, target in zip(conversations, targets):
|
| 1471 |
+
total_len = int(target.ne(tokenizer.pad_token_id).sum())
|
| 1472 |
+
|
| 1473 |
+
rounds = conversation.split(conv.sep)
|
| 1474 |
+
re_rounds = [conv.sep.join(rounds[:3])] # system + user + gpt
|
| 1475 |
+
for conv_idx in range(3, len(rounds), 2):
|
| 1476 |
+
re_rounds.append(
|
| 1477 |
+
conv.sep.join(rounds[conv_idx : conv_idx + 2])
|
| 1478 |
+
) # user + gpt
|
| 1479 |
+
cur_len = 0
|
| 1480 |
+
target[:cur_len] = IGNORE_INDEX
|
| 1481 |
+
for i, rou in enumerate(re_rounds):
|
| 1482 |
+
if rou == "":
|
| 1483 |
+
break
|
| 1484 |
+
|
| 1485 |
+
parts = rou.split(sep)
|
| 1486 |
+
if len(parts) != 2:
|
| 1487 |
+
break
|
| 1488 |
+
parts[0] += sep
|
| 1489 |
+
if has_image:
|
| 1490 |
+
round_len = len(tokenizer_image_token(rou, tokenizer))
|
| 1491 |
+
instruction_len = len(tokenizer_image_token(parts[0], tokenizer)) - 1
|
| 1492 |
+
else:
|
| 1493 |
+
round_len = len(tokenizer(rou).input_ids)
|
| 1494 |
+
instruction_len = len(tokenizer(parts[0]).input_ids) - 1
|
| 1495 |
+
|
| 1496 |
+
if (
|
| 1497 |
+
i != 0
|
| 1498 |
+
and getattr(tokenizer, "legacy", False)
|
| 1499 |
+
and IS_TOKENIZER_GREATER_THAN_0_14
|
| 1500 |
+
):
|
| 1501 |
+
round_len += 1
|
| 1502 |
+
instruction_len += 1
|
| 1503 |
+
target[cur_len : cur_len + instruction_len] = IGNORE_INDEX
|
| 1504 |
+
|
| 1505 |
+
cur_len += round_len
|
| 1506 |
+
target[cur_len:] = IGNORE_INDEX
|
| 1507 |
+
|
| 1508 |
+
if cur_len < tokenizer.model_max_length:
|
| 1509 |
+
if cur_len != total_len:
|
| 1510 |
+
target[:] = IGNORE_INDEX
|
| 1511 |
+
print(
|
| 1512 |
+
f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}."
|
| 1513 |
+
f" (ignored)"
|
| 1514 |
+
)
|
| 1515 |
+
|
| 1516 |
+
return dict(
|
| 1517 |
+
input_ids=input_ids,
|
| 1518 |
+
labels=targets,
|
| 1519 |
+
)
|
| 1520 |
+
|
| 1521 |
+
|
| 1522 |
+
def preprocess_plain(
|
| 1523 |
+
sources: Sequence[str],
|
| 1524 |
+
tokenizer: transformers.PreTrainedTokenizer,
|
| 1525 |
+
) -> dict:
|
| 1526 |
+
# add end signal and concatenate together
|
| 1527 |
+
conversations = []
|
| 1528 |
+
for source in sources:
|
| 1529 |
+
assert len(source) == 2
|
| 1530 |
+
assert DEFAULT_IMAGE_TOKEN in source[0]["value"]
|
| 1531 |
+
source[0]["value"] = DEFAULT_IMAGE_TOKEN
|
| 1532 |
+
conversation = (
|
| 1533 |
+
source[0]["value"]
|
| 1534 |
+
+ source[1]["value"]
|
| 1535 |
+
+ conversation_lib.default_conversation.sep
|
| 1536 |
+
)
|
| 1537 |
+
conversations.append(conversation)
|
| 1538 |
+
# tokenize conversations
|
| 1539 |
+
input_ids = [
|
| 1540 |
+
tokenizer_image_token(prompt, tokenizer, return_tensors="pt")
|
| 1541 |
+
for prompt in conversations
|
| 1542 |
+
]
|
| 1543 |
+
targets = copy.deepcopy(input_ids)
|
| 1544 |
+
for target, source in zip(targets, sources):
|
| 1545 |
+
tokenized_len = len(tokenizer_image_token(source[0]["value"], tokenizer))
|
| 1546 |
+
target[:tokenized_len] = IGNORE_INDEX
|
| 1547 |
+
|
| 1548 |
+
return dict(input_ids=input_ids, labels=targets)
|
| 1549 |
+
|
| 1550 |
+
|
| 1551 |
+
def preprocess(
|
| 1552 |
+
sources: Sequence[str],
|
| 1553 |
+
tokenizer: transformers.PreTrainedTokenizer,
|
| 1554 |
+
has_image: bool = False,
|
| 1555 |
+
) -> dict:
|
| 1556 |
+
"""
|
| 1557 |
+
Given a list of sources, each is a conversation list. This transform:
|
| 1558 |
+
1. Add signal '### ' at the beginning each sentence, with end signal '\n';
|
| 1559 |
+
2. Concatenate conversations together;
|
| 1560 |
+
3. Tokenize the concatenated conversation;
|
| 1561 |
+
4. Make a deepcopy as the target. Mask human words with IGNORE_INDEX.
|
| 1562 |
+
"""
|
| 1563 |
+
if conversation_lib.default_conversation.sep_style == conversation_lib.SeparatorStyle.PLAIN:
|
| 1564 |
+
return preprocess_plain(sources, tokenizer)
|
| 1565 |
+
if conversation_lib.default_conversation.sep_style == conversation_lib.SeparatorStyle.LLAMA_2:
|
| 1566 |
+
return preprocess_llama_2(sources, tokenizer, has_image=has_image)
|
| 1567 |
+
if conversation_lib.default_conversation.version.startswith("v1"):
|
| 1568 |
+
return preprocess_v1(sources, tokenizer, has_image=has_image)
|
| 1569 |
+
if conversation_lib.default_conversation.version == "mpt":
|
| 1570 |
+
return preprocess_mpt(sources, tokenizer, has_image=has_image)
|
| 1571 |
+
if conversation_lib.default_conversation.version == "phi3":
|
| 1572 |
+
return preprocess_phi3(sources, tokenizer, has_image=has_image)
|
| 1573 |
+
if conversation_lib.default_conversation.version == "qwen":
|
| 1574 |
+
return preprocess_qwen(sources, tokenizer, has_image=has_image)
|
| 1575 |
+
# add end signal and concatenate together
|
| 1576 |
+
conversations = []
|
| 1577 |
+
for source in sources:
|
| 1578 |
+
header = f"{conversation_lib.default_conversation.system}\n\n"
|
| 1579 |
+
conversation = _add_speaker_and_signal(header, source)
|
| 1580 |
+
conversations.append(conversation)
|
| 1581 |
+
|
| 1582 |
+
# tokenize conversations
|
| 1583 |
+
def get_tokenize_len(prompts):
|
| 1584 |
+
return [len(tokenizer_image_token(prompt, tokenizer)) for prompt in prompts]
|
| 1585 |
+
|
| 1586 |
+
if has_image:
|
| 1587 |
+
input_ids = [
|
| 1588 |
+
tokenizer_image_token(prompt, tokenizer, return_tensors="pt")
|
| 1589 |
+
for prompt in conversations
|
| 1590 |
+
]
|
| 1591 |
+
else:
|
| 1592 |
+
conversations_tokenized = _tokenize_fn(conversations, tokenizer)
|
| 1593 |
+
input_ids = conversations_tokenized["input_ids"]
|
| 1594 |
+
|
| 1595 |
+
targets = copy.deepcopy(input_ids)
|
| 1596 |
+
for target, source in zip(targets, sources):
|
| 1597 |
+
if has_image:
|
| 1598 |
+
tokenized_lens = get_tokenize_len([header] + [s["value"] for s in source])
|
| 1599 |
+
else:
|
| 1600 |
+
tokenized_lens = _tokenize_fn(
|
| 1601 |
+
[header] + [s["value"] for s in source],
|
| 1602 |
+
tokenizer,
|
| 1603 |
+
)["input_ids_lens"]
|
| 1604 |
+
speakers = [sentence["from"] for sentence in source]
|
| 1605 |
+
_mask_targets(target, tokenized_lens, speakers)
|
| 1606 |
+
|
| 1607 |
+
return dict(input_ids=input_ids, labels=targets)
|
tempo/multimodal_encoder/__pycache__/base_encoder.cpython-312.pyc
ADDED
|
Binary file (6.52 kB). View file
|
|
|
tempo/multimodal_encoder/__pycache__/builder.cpython-312.pyc
ADDED
|
Binary file (1.28 kB). View file
|
|
|
tempo/multimodal_encoder/__pycache__/qwen3vl_encoder.cpython-312.pyc
ADDED
|
Binary file (14.4 kB). View file
|
|
|
tempo/multimodal_encoder/__pycache__/siglip_encoder.cpython-312.pyc
ADDED
|
Binary file (4.29 kB). View file
|
|
|
tempo/multimodal_encoder/base_encoder.py
ADDED
|
@@ -0,0 +1,135 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from abc import ABC, abstractmethod
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
import torch.nn as nn
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
class ProcessorWrapper:
|
| 8 |
+
def __init__(
|
| 9 |
+
self,
|
| 10 |
+
transform,
|
| 11 |
+
height=378,
|
| 12 |
+
width=378,
|
| 13 |
+
image_mean=[0.48145466, 0.4578275, 0.40821073],
|
| 14 |
+
):
|
| 15 |
+
self._crop_size = {
|
| 16 |
+
"height": height,
|
| 17 |
+
"width": width,
|
| 18 |
+
}
|
| 19 |
+
self._transforms = transform
|
| 20 |
+
# print(transform)
|
| 21 |
+
self.image_mean = image_mean
|
| 22 |
+
|
| 23 |
+
@property
|
| 24 |
+
def crop_size(self):
|
| 25 |
+
return self._crop_size
|
| 26 |
+
|
| 27 |
+
def preprocess(self, image, return_tensors="pt"):
|
| 28 |
+
# Ensure image is a PIL Image
|
| 29 |
+
output = {}
|
| 30 |
+
output["pixel_values"] = [self._transforms(image)]
|
| 31 |
+
return output
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
class BaseVisionTower(nn.Module):
|
| 35 |
+
def __init__(self, vision_tower_name, args, delay_load=False):
|
| 36 |
+
super().__init__()
|
| 37 |
+
|
| 38 |
+
self.is_loaded = False
|
| 39 |
+
self.args = args
|
| 40 |
+
|
| 41 |
+
self.vision_tower_name = vision_tower_name
|
| 42 |
+
self.select_layer = args.mm_vision_select_layer
|
| 43 |
+
self.select_feature = getattr(args, "mm_vision_select_feature", "patch")
|
| 44 |
+
self.unfreeze_mm_vision_tower = getattr(args, "unfreeze_mm_vision_tower", False)
|
| 45 |
+
self.delay_load = delay_load
|
| 46 |
+
|
| 47 |
+
@abstractmethod
|
| 48 |
+
def load_model(self, device_map=None):
|
| 49 |
+
raise NotImplementedError("Subclasses must implement load_model")
|
| 50 |
+
|
| 51 |
+
@abstractmethod
|
| 52 |
+
def _forward(self, images):
|
| 53 |
+
raise NotImplementedError("Subclasses must implement forward")
|
| 54 |
+
|
| 55 |
+
def forward(self, images):
|
| 56 |
+
if type(images) is list:
|
| 57 |
+
image_features = [self._forward(image.unsqueeze(0)) for image in images]
|
| 58 |
+
else:
|
| 59 |
+
image_features = self._forward(images)
|
| 60 |
+
|
| 61 |
+
return image_features
|
| 62 |
+
|
| 63 |
+
@property
|
| 64 |
+
def dummy_feature(self):
|
| 65 |
+
return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)
|
| 66 |
+
|
| 67 |
+
@property
|
| 68 |
+
def dtype(self):
|
| 69 |
+
# Dynamically infer the dtype from the first parameter, if not explicitly specified
|
| 70 |
+
if hasattr(self.vision_tower, "dtype"):
|
| 71 |
+
return self.vision_tower.dtype
|
| 72 |
+
else:
|
| 73 |
+
params = list(self.vision_tower.parameters())
|
| 74 |
+
return (
|
| 75 |
+
params[0].dtype if len(params) > 0 else torch.float32
|
| 76 |
+
) # Default to torch.float32 if no parameters
|
| 77 |
+
|
| 78 |
+
@property
|
| 79 |
+
def device(self):
|
| 80 |
+
# Dynamically infer the device from the first parameter, if not explicitly specified
|
| 81 |
+
if hasattr(self.vision_tower, "device"):
|
| 82 |
+
return self.vision_tower.device
|
| 83 |
+
else:
|
| 84 |
+
params = list(self.vision_tower.parameters())
|
| 85 |
+
return (
|
| 86 |
+
params[0].device if len(params) > 0 else torch.device("cpu")
|
| 87 |
+
) # Default to CPU if no parameters
|
| 88 |
+
|
| 89 |
+
@property
|
| 90 |
+
def config(self):
|
| 91 |
+
if self.is_loaded:
|
| 92 |
+
return self.vision_tower.config
|
| 93 |
+
else:
|
| 94 |
+
return self.cfg_only
|
| 95 |
+
|
| 96 |
+
@property
|
| 97 |
+
def hidden_size(self):
|
| 98 |
+
try:
|
| 99 |
+
return self.config.hidden_size
|
| 100 |
+
except:
|
| 101 |
+
return self._hidden_size
|
| 102 |
+
|
| 103 |
+
@property
|
| 104 |
+
def image_size(self): # resolution
|
| 105 |
+
# return self.config.image_size
|
| 106 |
+
try:
|
| 107 |
+
return self.config.image_size
|
| 108 |
+
except:
|
| 109 |
+
return self._image_size
|
| 110 |
+
|
| 111 |
+
@property
|
| 112 |
+
def patch_size(self):
|
| 113 |
+
# return self.config.patch_size
|
| 114 |
+
try:
|
| 115 |
+
return self.config.patch_size
|
| 116 |
+
except:
|
| 117 |
+
return self._patch_size
|
| 118 |
+
|
| 119 |
+
@property
|
| 120 |
+
def num_patches_per_side(self):
|
| 121 |
+
if self._interp_size is not None:
|
| 122 |
+
return int(self._interp_size**0.5)
|
| 123 |
+
try:
|
| 124 |
+
return self.image_size // self.patch_size
|
| 125 |
+
except:
|
| 126 |
+
return self._num_patches_per_side
|
| 127 |
+
|
| 128 |
+
@property
|
| 129 |
+
def num_patches(self):
|
| 130 |
+
if self._interp_size is not None:
|
| 131 |
+
return self._interp_size
|
| 132 |
+
try:
|
| 133 |
+
return self.num_patches_per_side**2
|
| 134 |
+
except:
|
| 135 |
+
return self._num_patches
|
tempo/multimodal_encoder/builder.py
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import copy
|
| 2 |
+
from pathlib import Path
|
| 3 |
+
from .qwen3vl_encoder import Qwen3VLTower
|
| 4 |
+
from .siglip_encoder import SiglipVisionTower
|
| 5 |
+
|
| 6 |
+
def build_vision_tower_aux_list(vision_tower_cfg, **kwargs):
|
| 7 |
+
|
| 8 |
+
vision_tower_aux_name_list = getattr(vision_tower_cfg, "mm_vision_tower_aux_list", ["Qwen/Qwen3-VL-2B-Instruct"])
|
| 9 |
+
|
| 10 |
+
vision_tower_aux_list = []
|
| 11 |
+
for vision_tower_aux_name in vision_tower_aux_name_list:
|
| 12 |
+
config = copy.deepcopy(vision_tower_cfg)
|
| 13 |
+
vision_tower_basename = Path(vision_tower_aux_name).name.lower()
|
| 14 |
+
if "siglip" in vision_tower_basename:
|
| 15 |
+
vision_tower_aux_list.append(SiglipVisionTower(vision_tower_aux_name, args=config, **kwargs))
|
| 16 |
+
elif "qwen3-vl" in vision_tower_basename:
|
| 17 |
+
vision_tower_aux_list.append(Qwen3VLTower(vision_tower_aux_name, args=config, **kwargs))
|
| 18 |
+
else:
|
| 19 |
+
raise ValueError(f"Unknown vision tower: {vision_tower_basename}")
|
| 20 |
+
|
| 21 |
+
return vision_tower_aux_list
|
tempo/multimodal_encoder/qwen3vl_encoder.py
ADDED
|
@@ -0,0 +1,336 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import gc
|
| 2 |
+
import random
|
| 3 |
+
random.seed(42)
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
import torch.nn as nn
|
| 7 |
+
from accelerate import init_empty_weights
|
| 8 |
+
from transformers.utils import is_torchdynamo_compiling
|
| 9 |
+
from transformers import AutoConfig, Qwen3VLForConditionalGeneration, Qwen3VLProcessor
|
| 10 |
+
|
| 11 |
+
class Qwen3VLTower(nn.Module):
|
| 12 |
+
def __init__(self, vision_tower_aux_name, args, **kwargs):
|
| 13 |
+
super(Qwen3VLTower, self).__init__()
|
| 14 |
+
|
| 15 |
+
self.is_loaded = True # for compatibility
|
| 16 |
+
self.model_path = vision_tower_aux_name
|
| 17 |
+
self.dynamic_compress = getattr(args, "dynamic_compress", False)
|
| 18 |
+
|
| 19 |
+
# load processor
|
| 20 |
+
self.image_processor = Qwen3VLProcessor.from_pretrained(self.model_path)
|
| 21 |
+
|
| 22 |
+
# load config
|
| 23 |
+
self.config = AutoConfig.from_pretrained(self.model_path)
|
| 24 |
+
self.config._attn_implementation = "flash_attention_2"
|
| 25 |
+
self.config.dtype = torch.bfloat16
|
| 26 |
+
|
| 27 |
+
# load model
|
| 28 |
+
with init_empty_weights():
|
| 29 |
+
self.vlm = Qwen3VLForConditionalGeneration(self.config)
|
| 30 |
+
if hasattr(self.vlm, "lm_head"):
|
| 31 |
+
del self.vlm.lm_head
|
| 32 |
+
|
| 33 |
+
self.vlm.requires_grad_(False)
|
| 34 |
+
|
| 35 |
+
self.hidden_size = self.config.text_config.hidden_size
|
| 36 |
+
self.num_compression_tokens = args.num_compression_tokens
|
| 37 |
+
|
| 38 |
+
self.compression_tokens = nn.Parameter(
|
| 39 |
+
torch.empty(1, self.num_compression_tokens, self.hidden_size)
|
| 40 |
+
)
|
| 41 |
+
|
| 42 |
+
def smart_init_vision_tower(self):
|
| 43 |
+
"""Load only during Stage 0"""
|
| 44 |
+
|
| 45 |
+
temp_model = Qwen3VLForConditionalGeneration.from_pretrained(
|
| 46 |
+
self.model_path,
|
| 47 |
+
dtype=self.vlm.dtype,
|
| 48 |
+
device_map="cpu", # avoid multiple nodes and gpu conflicit
|
| 49 |
+
)
|
| 50 |
+
|
| 51 |
+
missing_keys, unexpected_keys = self.vlm.load_state_dict(temp_model.state_dict(), strict=False)
|
| 52 |
+
|
| 53 |
+
if len(missing_keys) > 0:
|
| 54 |
+
print(f"[Warning] Missing keys in Qwen3-VL loading: {missing_keys}")
|
| 55 |
+
if len(unexpected_keys) > 0:
|
| 56 |
+
print(f"[Warning] Unexpected keys keys in Qwen3-VL loading: {unexpected_keys}")
|
| 57 |
+
|
| 58 |
+
del temp_model
|
| 59 |
+
gc.collect()
|
| 60 |
+
torch.cuda.empty_cache()
|
| 61 |
+
|
| 62 |
+
self.vlm.requires_grad_(False)
|
| 63 |
+
|
| 64 |
+
with torch.no_grad():
|
| 65 |
+
embed_weights = self.vlm.model.language_model.embed_tokens.weight
|
| 66 |
+
mean = embed_weights.mean(dim=0)
|
| 67 |
+
std = embed_weights.std(dim=0)
|
| 68 |
+
|
| 69 |
+
self.compression_tokens.data = torch.normal(mean=mean.repeat(self.num_compression_tokens, 1), std=std.repeat(self.num_compression_tokens, 1)).unsqueeze(0)
|
| 70 |
+
print(f"[Smart Init] Done. Shape: {self.compression_tokens.shape}")
|
| 71 |
+
|
| 72 |
+
def smart_init_dynamic_compress(self):
|
| 73 |
+
if getattr(self, "vlm_head", None) is None:
|
| 74 |
+
temp_model = Qwen3VLForConditionalGeneration.from_pretrained(
|
| 75 |
+
self.model_path,
|
| 76 |
+
dtype=self.vlm.dtype,
|
| 77 |
+
device_map="cpu", # avoid multiple nodes and gpu conflicit
|
| 78 |
+
)
|
| 79 |
+
self.vlm_head = temp_model.lm_head
|
| 80 |
+
|
| 81 |
+
del temp_model
|
| 82 |
+
gc.collect()
|
| 83 |
+
torch.cuda.empty_cache()
|
| 84 |
+
|
| 85 |
+
token_true_id = self.image_processor.tokenizer.get_vocab()["Yes"]
|
| 86 |
+
token_false_id = self.image_processor.tokenizer.get_vocab()["No"]
|
| 87 |
+
|
| 88 |
+
lm_head_weights = self.vlm_head.weight.data
|
| 89 |
+
weight_yes = lm_head_weights[token_true_id]
|
| 90 |
+
weight_no = lm_head_weights[token_false_id]
|
| 91 |
+
|
| 92 |
+
D = weight_yes.size()[0]
|
| 93 |
+
self.linear_layer = nn.Linear(D, 1, bias=False)
|
| 94 |
+
with torch.no_grad():
|
| 95 |
+
self.linear_layer.weight[0] = weight_yes - weight_no
|
| 96 |
+
|
| 97 |
+
del self.vlm_head
|
| 98 |
+
self.linear_layer.to("cuda")
|
| 99 |
+
|
| 100 |
+
print(f"[Smart Init Router] Done!")
|
| 101 |
+
|
| 102 |
+
def compute_relevance(self, mask_compression_token, batch_size, last_hidden_state):
|
| 103 |
+
first_compression_idx = mask_compression_token.int().argmax(
|
| 104 |
+
dim=1
|
| 105 |
+
) # (batch_size,)
|
| 106 |
+
prev_idx = first_compression_idx - 1 # (batch_size,)
|
| 107 |
+
batch_indices = torch.arange(batch_size, device=last_hidden_state.device)
|
| 108 |
+
prev_token_features = last_hidden_state[
|
| 109 |
+
batch_indices, prev_idx
|
| 110 |
+
] # (batch_size, hidden_dim)
|
| 111 |
+
|
| 112 |
+
scores = self.linear_layer(prev_token_features.float())
|
| 113 |
+
scores = torch.sigmoid(scores).squeeze(-1).cpu().detach().tolist()
|
| 114 |
+
|
| 115 |
+
return scores
|
| 116 |
+
|
| 117 |
+
def load_model(self):
|
| 118 |
+
# for compatible with other encoder
|
| 119 |
+
pass
|
| 120 |
+
|
| 121 |
+
def forward(
|
| 122 |
+
self,
|
| 123 |
+
input_ids=None,
|
| 124 |
+
attention_mask=None,
|
| 125 |
+
position_ids=None,
|
| 126 |
+
past_key_values=None,
|
| 127 |
+
inputs_embeds=None,
|
| 128 |
+
pixel_values=None,
|
| 129 |
+
pixel_values_videos=None,
|
| 130 |
+
image_grid_thw=None,
|
| 131 |
+
video_grid_thw=None,
|
| 132 |
+
cache_position=None,
|
| 133 |
+
**kwargs,
|
| 134 |
+
):
|
| 135 |
+
if self.dynamic_compress and not hasattr(self, "linear_layer"):
|
| 136 |
+
self.smart_init_dynamic_compress()
|
| 137 |
+
|
| 138 |
+
if (input_ids is None) ^ (inputs_embeds is not None):
|
| 139 |
+
raise ValueError(
|
| 140 |
+
"You must specify exactly one of input_ids or inputs_embeds"
|
| 141 |
+
)
|
| 142 |
+
|
| 143 |
+
used_compression_tokens = self.num_compression_tokens
|
| 144 |
+
|
| 145 |
+
if inputs_embeds is None:
|
| 146 |
+
# process input_ids to insert learnable token
|
| 147 |
+
batch_size, n_seq = input_ids.shape
|
| 148 |
+
valid_lengths = (
|
| 149 |
+
input_ids != self.image_processor.tokenizer.pad_token_id
|
| 150 |
+
).sum(dim=1)
|
| 151 |
+
input_ids = torch.cat(
|
| 152 |
+
[
|
| 153 |
+
input_ids,
|
| 154 |
+
torch.full(
|
| 155 |
+
(batch_size, used_compression_tokens),
|
| 156 |
+
# self.config.pad_token_id,
|
| 157 |
+
self.image_processor.tokenizer.pad_token_id,
|
| 158 |
+
dtype=input_ids.dtype,
|
| 159 |
+
device=input_ids.device,
|
| 160 |
+
),
|
| 161 |
+
],
|
| 162 |
+
dim=1,
|
| 163 |
+
)
|
| 164 |
+
attention_mask = torch.cat(
|
| 165 |
+
[
|
| 166 |
+
attention_mask,
|
| 167 |
+
torch.zeros(
|
| 168 |
+
(batch_size, used_compression_tokens),
|
| 169 |
+
dtype=attention_mask.dtype,
|
| 170 |
+
device=attention_mask.device,
|
| 171 |
+
),
|
| 172 |
+
],
|
| 173 |
+
dim=1,
|
| 174 |
+
)
|
| 175 |
+
inputs_embeds = self.vlm.get_input_embeddings()(input_ids)
|
| 176 |
+
else:
|
| 177 |
+
raise NotImplementedError(
|
| 178 |
+
"Current only support input_ids as vlm compressor inputs"
|
| 179 |
+
)
|
| 180 |
+
|
| 181 |
+
image_mask = None
|
| 182 |
+
video_mask = None
|
| 183 |
+
|
| 184 |
+
if pixel_values is not None:
|
| 185 |
+
image_embeds, deepstack_image_embeds = self.vlm.get_image_features(
|
| 186 |
+
pixel_values, image_grid_thw
|
| 187 |
+
)
|
| 188 |
+
image_embeds = torch.cat(image_embeds, dim=0).to(
|
| 189 |
+
inputs_embeds.device, inputs_embeds.dtype
|
| 190 |
+
)
|
| 191 |
+
image_mask, _ = self.vlm.model.get_placeholder_mask(
|
| 192 |
+
input_ids, inputs_embeds=inputs_embeds, image_features=image_embeds
|
| 193 |
+
)
|
| 194 |
+
inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)
|
| 195 |
+
|
| 196 |
+
if pixel_values_videos is not None:
|
| 197 |
+
video_embeds, deepstack_video_embeds = self.vlm.get_video_features(
|
| 198 |
+
pixel_values_videos, video_grid_thw
|
| 199 |
+
)
|
| 200 |
+
video_embeds = torch.cat(video_embeds, dim=0).to(
|
| 201 |
+
inputs_embeds.device, inputs_embeds.dtype
|
| 202 |
+
)
|
| 203 |
+
_, video_mask = self.vlm.model.get_placeholder_mask(
|
| 204 |
+
input_ids, inputs_embeds=inputs_embeds, video_features=video_embeds
|
| 205 |
+
)
|
| 206 |
+
inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds)
|
| 207 |
+
|
| 208 |
+
visual_pos_masks = None
|
| 209 |
+
deepstack_visual_embeds = None
|
| 210 |
+
if image_mask is not None and video_mask is not None:
|
| 211 |
+
# aggregate visual_pos_masks and deepstack_visual_embeds
|
| 212 |
+
image_mask = image_mask[..., 0]
|
| 213 |
+
video_mask = video_mask[..., 0]
|
| 214 |
+
visual_pos_masks = image_mask | video_mask
|
| 215 |
+
deepstack_visual_embeds = []
|
| 216 |
+
image_mask_joint = image_mask[visual_pos_masks]
|
| 217 |
+
video_mask_joint = video_mask[visual_pos_masks]
|
| 218 |
+
for img_embed, vid_embed in zip(
|
| 219 |
+
deepstack_image_embeds, deepstack_video_embeds
|
| 220 |
+
):
|
| 221 |
+
embed_joint = img_embed.new_zeros(
|
| 222 |
+
visual_pos_masks.sum(), img_embed.shape[-1]
|
| 223 |
+
).to(img_embed.device)
|
| 224 |
+
embed_joint[image_mask_joint, :] = img_embed
|
| 225 |
+
embed_joint[video_mask_joint, :] = vid_embed
|
| 226 |
+
deepstack_visual_embeds.append(embed_joint)
|
| 227 |
+
elif image_mask is not None:
|
| 228 |
+
image_mask = image_mask[..., 0]
|
| 229 |
+
visual_pos_masks = image_mask
|
| 230 |
+
deepstack_visual_embeds = deepstack_image_embeds
|
| 231 |
+
elif video_mask is not None:
|
| 232 |
+
video_mask = video_mask[..., 0]
|
| 233 |
+
visual_pos_masks = video_mask
|
| 234 |
+
deepstack_visual_embeds = deepstack_video_embeds
|
| 235 |
+
|
| 236 |
+
# ------------------------------------------------------------------
|
| 237 |
+
# inputs_embeds, [Text + Image + Video]οΌshape: (B, L, D)
|
| 238 |
+
# concat Learnable Tokens
|
| 239 |
+
position_compression_token = (
|
| 240 |
+
torch.arange(n_seq + used_compression_tokens, device=input_ids.device)
|
| 241 |
+
.unsqueeze(0)
|
| 242 |
+
.expand(batch_size, -1)
|
| 243 |
+
)
|
| 244 |
+
mask_compression_token = (
|
| 245 |
+
position_compression_token >= valid_lengths.unsqueeze(1)
|
| 246 |
+
) & (
|
| 247 |
+
position_compression_token
|
| 248 |
+
< (valid_lengths + used_compression_tokens).unsqueeze(1)
|
| 249 |
+
)
|
| 250 |
+
compression_tokens_expanded = self.compression_tokens[
|
| 251 |
+
:, :used_compression_tokens, :
|
| 252 |
+
].expand(batch_size, -1, -1)
|
| 253 |
+
inputs_embeds[mask_compression_token] = compression_tokens_expanded.reshape(
|
| 254 |
+
-1, self.hidden_size
|
| 255 |
+
).to(inputs_embeds.dtype)
|
| 256 |
+
attention_mask.masked_fill_(mask_compression_token, 1)
|
| 257 |
+
# ------------------------------------------------------------------
|
| 258 |
+
|
| 259 |
+
if position_ids is None:
|
| 260 |
+
attention_mask_tensor = (
|
| 261 |
+
attention_mask
|
| 262 |
+
if not isinstance(attention_mask, dict)
|
| 263 |
+
else attention_mask["full_attention"]
|
| 264 |
+
)
|
| 265 |
+
if attention_mask_tensor is not None and attention_mask_tensor.ndim == 4:
|
| 266 |
+
attention_mask_tensor = torch.diagonal(
|
| 267 |
+
attention_mask_tensor[:, 0], dim1=1, dim2=2
|
| 268 |
+
)
|
| 269 |
+
# Only apply conversion for floating point tensors (inverted masks)
|
| 270 |
+
if attention_mask_tensor.dtype.is_floating_point:
|
| 271 |
+
attention_mask_tensor = (
|
| 272 |
+
attention_mask_tensor
|
| 273 |
+
/ torch.finfo(attention_mask_tensor.dtype).min
|
| 274 |
+
)
|
| 275 |
+
attention_mask_tensor = (1.0 - attention_mask_tensor).int()
|
| 276 |
+
|
| 277 |
+
# Calculate RoPE index once per generation in the pre-fill stage only.
|
| 278 |
+
# When compiling, we can't check tensor values thus we check only input length
|
| 279 |
+
# It is safe to assume that `length!=1` means we're in pre-fill because compiled
|
| 280 |
+
# models currently cannot do asssisted decoding
|
| 281 |
+
prefill_compiled_stage = is_torchdynamo_compiling() and (
|
| 282 |
+
(input_ids is not None and input_ids.shape[1] != 1)
|
| 283 |
+
or (inputs_embeds is not None and inputs_embeds.shape[1] != 1)
|
| 284 |
+
)
|
| 285 |
+
prefill_noncompiled_stage = not is_torchdynamo_compiling() and (
|
| 286 |
+
(cache_position is not None and cache_position[0] == 0)
|
| 287 |
+
or (past_key_values is None or past_key_values.get_seq_length() == 0)
|
| 288 |
+
)
|
| 289 |
+
if (
|
| 290 |
+
prefill_compiled_stage or prefill_noncompiled_stage
|
| 291 |
+
) or self.rope_deltas is None:
|
| 292 |
+
position_ids, rope_deltas = self.vlm.model.get_rope_index(
|
| 293 |
+
input_ids,
|
| 294 |
+
image_grid_thw,
|
| 295 |
+
video_grid_thw,
|
| 296 |
+
attention_mask=attention_mask_tensor,
|
| 297 |
+
)
|
| 298 |
+
self.rope_deltas = rope_deltas
|
| 299 |
+
# then use the prev pre-calculated rope-deltas to get the correct position ids
|
| 300 |
+
else:
|
| 301 |
+
batch_size, seq_length, _ = inputs_embeds.shape
|
| 302 |
+
delta = (
|
| 303 |
+
(cache_position[0] + self.rope_deltas).to(inputs_embeds.device)
|
| 304 |
+
if cache_position is not None
|
| 305 |
+
else 0
|
| 306 |
+
)
|
| 307 |
+
position_ids = torch.arange(seq_length, device=inputs_embeds.device)
|
| 308 |
+
position_ids = position_ids.view(1, -1).expand(batch_size, -1)
|
| 309 |
+
if cache_position is not None: # otherwise `deltas` is an int `0`
|
| 310 |
+
delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=0)
|
| 311 |
+
position_ids = position_ids.add(delta)
|
| 312 |
+
position_ids = position_ids.unsqueeze(0).expand(3, -1, -1)
|
| 313 |
+
|
| 314 |
+
outputs = self.vlm.model.language_model(
|
| 315 |
+
input_ids=None,
|
| 316 |
+
position_ids=position_ids,
|
| 317 |
+
attention_mask=attention_mask,
|
| 318 |
+
past_key_values=past_key_values,
|
| 319 |
+
inputs_embeds=inputs_embeds,
|
| 320 |
+
cache_position=cache_position,
|
| 321 |
+
visual_pos_masks=visual_pos_masks,
|
| 322 |
+
deepstack_visual_embeds=deepstack_visual_embeds,
|
| 323 |
+
**kwargs,
|
| 324 |
+
)
|
| 325 |
+
|
| 326 |
+
last_hidden_state = outputs.last_hidden_state
|
| 327 |
+
compression_features_flat = last_hidden_state[mask_compression_token]
|
| 328 |
+
compression_features = compression_features_flat.reshape(
|
| 329 |
+
batch_size, used_compression_tokens, -1
|
| 330 |
+
)
|
| 331 |
+
|
| 332 |
+
relevance_scores = None
|
| 333 |
+
if self.dynamic_compress:
|
| 334 |
+
relevance_scores = self.compute_relevance(mask_compression_token, batch_size, last_hidden_state)
|
| 335 |
+
|
| 336 |
+
return compression_features, relevance_scores
|
tempo/multimodal_encoder/siglip_encoder.py
ADDED
|
@@ -0,0 +1,75 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn.functional as F
|
| 3 |
+
from transformers import SiglipImageProcessor, SiglipVisionModel
|
| 4 |
+
|
| 5 |
+
from .base_encoder import BaseVisionTower
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class SiglipVisionTower(BaseVisionTower):
|
| 9 |
+
def __init__(self, vision_tower_name, args, delay_load=False):
|
| 10 |
+
super(SiglipVisionTower, self).__init__(vision_tower_name, args, delay_load)
|
| 11 |
+
model_path, res, interp = vision_tower_name, 384, 576
|
| 12 |
+
self.vision_tower_name = model_path
|
| 13 |
+
self._image_size = res if res is not None else 512
|
| 14 |
+
self._interp_size = interp
|
| 15 |
+
if not self.delay_load:
|
| 16 |
+
self.load_model()
|
| 17 |
+
elif self.unfreeze_mm_vision_tower:
|
| 18 |
+
self.load_model()
|
| 19 |
+
else:
|
| 20 |
+
self._hidden_size = 1152
|
| 21 |
+
|
| 22 |
+
def load_model(self, device_map=None):
|
| 23 |
+
self.vision_model = "siglip"
|
| 24 |
+
self.vision_tower = SiglipVisionModel.from_pretrained(self.vision_tower_name)
|
| 25 |
+
|
| 26 |
+
# self.vision_tower = clip_model.visual.trunk
|
| 27 |
+
self.vision_tower.output_tokens = True
|
| 28 |
+
|
| 29 |
+
self._hidden_size = self.vision_tower.config.hidden_size
|
| 30 |
+
self._image_size = self.vision_tower.config.image_size
|
| 31 |
+
self._patch_size = self.vision_tower.config.patch_size
|
| 32 |
+
self.image_processor = SiglipImageProcessor.from_pretrained(
|
| 33 |
+
self.vision_tower_name
|
| 34 |
+
)
|
| 35 |
+
|
| 36 |
+
self.vision_tower.requires_grad_(self.unfreeze_mm_vision_tower)
|
| 37 |
+
self.is_loaded = True
|
| 38 |
+
|
| 39 |
+
def interpolate(self, image_features):
|
| 40 |
+
if self._interp_size is None:
|
| 41 |
+
return image_features
|
| 42 |
+
|
| 43 |
+
b, num_tokens, dim = image_features.shape
|
| 44 |
+
|
| 45 |
+
if num_tokens != self.num_patches:
|
| 46 |
+
target_h = target_w = int(self._interp_size**0.5)
|
| 47 |
+
h = w = int(num_tokens**0.5)
|
| 48 |
+
|
| 49 |
+
image_features = image_features.view(b, h, w, dim)
|
| 50 |
+
image_features = image_features.permute(0, 3, 1, 2).contiguous()
|
| 51 |
+
|
| 52 |
+
image_features = F.interpolate(
|
| 53 |
+
image_features.to(torch.float32),
|
| 54 |
+
size=(target_h, target_w),
|
| 55 |
+
mode="bilinear",
|
| 56 |
+
align_corners=False,
|
| 57 |
+
).to(image_features.dtype)
|
| 58 |
+
|
| 59 |
+
# Permute the dimensions back to (b, target_h, target_w, dim)
|
| 60 |
+
image_features = image_features.permute(0, 2, 3, 1).contiguous()
|
| 61 |
+
|
| 62 |
+
# Flatten the spatial dimensions (target_h, target_w) into a single dimension
|
| 63 |
+
image_features = image_features.flatten(1, 2)
|
| 64 |
+
|
| 65 |
+
return image_features
|
| 66 |
+
|
| 67 |
+
def _forward(self, images, interpolate_token=576):
|
| 68 |
+
with torch.set_grad_enabled(self.unfreeze_mm_vision_tower):
|
| 69 |
+
embeddings = self.vision_tower.vision_model.embeddings(images)
|
| 70 |
+
encoder_outputs = self.vision_tower.vision_model.encoder(
|
| 71 |
+
inputs_embeds=embeddings
|
| 72 |
+
)
|
| 73 |
+
image_features = encoder_outputs.last_hidden_state
|
| 74 |
+
interp_features = self.interpolate(image_features)
|
| 75 |
+
return interp_features
|
tempo/multimodal_projector/__pycache__/builder.cpython-312.pyc
ADDED
|
Binary file (3.18 kB). View file
|
|
|
tempo/multimodal_projector/builder.py
ADDED
|
@@ -0,0 +1,51 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import re
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
|
| 4 |
+
class IdentityMap(nn.Module):
|
| 5 |
+
def __init__(self):
|
| 6 |
+
super().__init__()
|
| 7 |
+
|
| 8 |
+
def forward(self, x, *args, **kwargs):
|
| 9 |
+
return x
|
| 10 |
+
|
| 11 |
+
@property
|
| 12 |
+
def config(self):
|
| 13 |
+
return {"mm_projector_type": "identity"}
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
class SimpleResBlock(nn.Module):
|
| 17 |
+
def __init__(self, channels):
|
| 18 |
+
super().__init__()
|
| 19 |
+
self.pre_norm = nn.LayerNorm(channels)
|
| 20 |
+
|
| 21 |
+
self.proj = nn.Sequential(
|
| 22 |
+
nn.Linear(channels, channels),
|
| 23 |
+
nn.GELU(),
|
| 24 |
+
nn.Linear(channels, channels)
|
| 25 |
+
)
|
| 26 |
+
|
| 27 |
+
def forward(self, x):
|
| 28 |
+
x = self.pre_norm(x)
|
| 29 |
+
return x + self.proj(x)
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def build_vision_projector(config):
|
| 33 |
+
|
| 34 |
+
projector_type = getattr(config, "mm_projector_type", "linear")
|
| 35 |
+
|
| 36 |
+
if projector_type == "linear":
|
| 37 |
+
return nn.Linear(config.mm_hidden_size, config.hidden_size)
|
| 38 |
+
|
| 39 |
+
mlp_gelu_match = re.match(r"^mlp(\d+)x_gelu$", projector_type)
|
| 40 |
+
if mlp_gelu_match:
|
| 41 |
+
mlp_depth = int(mlp_gelu_match.group(1))
|
| 42 |
+
modules = [nn.Linear(config.mm_hidden_size, config.hidden_size)]
|
| 43 |
+
for _ in range(1, mlp_depth):
|
| 44 |
+
modules.append(nn.GELU())
|
| 45 |
+
modules.append(nn.Linear(config.hidden_size, config.hidden_size))
|
| 46 |
+
return nn.Sequential(*modules)
|
| 47 |
+
|
| 48 |
+
if projector_type == "identity":
|
| 49 |
+
return IdentityMap()
|
| 50 |
+
|
| 51 |
+
raise ValueError(f"Unknown projector type: {projector_type}")
|
tempo/tempo_arch.py
ADDED
|
@@ -0,0 +1,464 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2023 Haotian Liu
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
from abc import ABC, abstractmethod
|
| 16 |
+
|
| 17 |
+
import torch
|
| 18 |
+
import torch.nn as nn
|
| 19 |
+
|
| 20 |
+
from tempo.constants import (
|
| 21 |
+
DEFAULT_IM_END_TOKEN,
|
| 22 |
+
DEFAULT_IM_START_TOKEN,
|
| 23 |
+
DEFAULT_IMAGE_PATCH_TOKEN,
|
| 24 |
+
IGNORE_INDEX,
|
| 25 |
+
IMAGE_TOKEN_INDEX,
|
| 26 |
+
)
|
| 27 |
+
|
| 28 |
+
from tempo.multimodal_encoder.builder import build_vision_tower_aux_list
|
| 29 |
+
from tempo.multimodal_projector.builder import build_vision_projector
|
| 30 |
+
from tempo.vlm_multimodal_processor import VLMMultimodalProcessor
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
class TempoMetaModel:
|
| 34 |
+
def __init__(self, config):
|
| 35 |
+
super(TempoMetaModel, self).__init__(config)
|
| 36 |
+
|
| 37 |
+
if hasattr(config, "mm_vision_tower_aux_list"):
|
| 38 |
+
self.vision_tower_aux_list = nn.ModuleList(
|
| 39 |
+
build_vision_tower_aux_list(config, delay_load=True)
|
| 40 |
+
)
|
| 41 |
+
config.mm_hidden_size = sum(
|
| 42 |
+
[
|
| 43 |
+
vision_tower_aux.hidden_size for vision_tower_aux in self.vision_tower_aux_list
|
| 44 |
+
]
|
| 45 |
+
)
|
| 46 |
+
self.mm_projector = build_vision_projector(config)
|
| 47 |
+
else:
|
| 48 |
+
raise NotImplementedError(
|
| 49 |
+
"mm_vision_tower_aux_list is not found in config. Please initialize vision modules in the subclass of TempoMetaModel."
|
| 50 |
+
)
|
| 51 |
+
|
| 52 |
+
def get_vision_tower_aux_list(self):
|
| 53 |
+
vision_tower_aux_list = getattr(self, "vision_tower_aux_list", None)
|
| 54 |
+
return vision_tower_aux_list
|
| 55 |
+
|
| 56 |
+
def initialize_vision_modules(self, model_args, fsdp=None):
|
| 57 |
+
# vision_hidden_size = model_args.vision_hidden_size
|
| 58 |
+
vision_tower_aux_list = model_args.vision_tower_aux_list
|
| 59 |
+
# vision_tower_aux_token_len_list = model_args.vision_tower_aux_token_len_list
|
| 60 |
+
pretrain_mm_mlp_adapter = model_args.pretrain_mm_mlp_adapter
|
| 61 |
+
self.config.mm_vision_tower_aux_list = vision_tower_aux_list
|
| 62 |
+
# self.config.mm_vision_tower_aux_token_len_list = vision_tower_aux_token_len_list
|
| 63 |
+
|
| 64 |
+
if self.get_vision_tower_aux_list() is None:
|
| 65 |
+
vision_tower_aux_list = build_vision_tower_aux_list(model_args)
|
| 66 |
+
if model_args.unfreeze_mm_vision_tower:
|
| 67 |
+
self.vision_tower_aux_list = nn.ModuleList(vision_tower_aux_list)
|
| 68 |
+
else:
|
| 69 |
+
self.vision_tower_aux_list = vision_tower_aux_list
|
| 70 |
+
else:
|
| 71 |
+
vision_tower_aux_list = self.vision_tower_aux_list
|
| 72 |
+
for vision_tower_aux in vision_tower_aux_list:
|
| 73 |
+
vision_tower_aux.load_model()
|
| 74 |
+
|
| 75 |
+
if model_args.unfreeze_mm_vision_tower and not isinstance(self.vision_tower_aux_list, nn.ModuleList):
|
| 76 |
+
self.vision_tower_aux_list = nn.ModuleList(self.vision_tower_aux_list)
|
| 77 |
+
|
| 78 |
+
self.config.mm_projector_type = getattr(model_args, "mm_projector_type", "linear")
|
| 79 |
+
# self.config.vision_hidden_size = vision_hidden_size
|
| 80 |
+
|
| 81 |
+
if getattr(self, "mm_projector", None) is None:
|
| 82 |
+
self.config.mm_hidden_size = sum(
|
| 83 |
+
[
|
| 84 |
+
vision_tower_aux.hidden_size for vision_tower_aux in vision_tower_aux_list
|
| 85 |
+
]
|
| 86 |
+
)
|
| 87 |
+
self.mm_projector = build_vision_projector(self.config)
|
| 88 |
+
else:
|
| 89 |
+
for p in self.mm_projector.parameters():
|
| 90 |
+
p.requires_grad = True
|
| 91 |
+
|
| 92 |
+
if pretrain_mm_mlp_adapter is not None:
|
| 93 |
+
mm_projector_weights = torch.load(pretrain_mm_mlp_adapter, map_location="cpu")
|
| 94 |
+
|
| 95 |
+
def get_w(weights, keyword):
|
| 96 |
+
return {
|
| 97 |
+
k.split(keyword + ".")[1]: v
|
| 98 |
+
for k, v in weights.items()
|
| 99 |
+
if keyword + "." in k
|
| 100 |
+
}
|
| 101 |
+
|
| 102 |
+
self.mm_projector.load_state_dict(
|
| 103 |
+
get_w(mm_projector_weights, "mm_projector"), strict=True
|
| 104 |
+
)
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
class TempoMetaForCausalLM(ABC):
|
| 108 |
+
@abstractmethod
|
| 109 |
+
def get_model(self):
|
| 110 |
+
pass
|
| 111 |
+
|
| 112 |
+
def get_vision_tower_aux_list(self):
|
| 113 |
+
return self.get_model().get_vision_tower_aux_list()
|
| 114 |
+
|
| 115 |
+
def prepare_inputs_labels_for_multimodal(
|
| 116 |
+
self,
|
| 117 |
+
input_ids,
|
| 118 |
+
position_ids,
|
| 119 |
+
attention_mask,
|
| 120 |
+
past_key_values,
|
| 121 |
+
labels,
|
| 122 |
+
images=None,
|
| 123 |
+
image_sizes=None,
|
| 124 |
+
vlm_inputs=None,
|
| 125 |
+
seg_timestamps=None,
|
| 126 |
+
batch_split_size=None,
|
| 127 |
+
relevance=None,
|
| 128 |
+
):
|
| 129 |
+
if input_ids.shape[1] == 1: # inference
|
| 130 |
+
return (
|
| 131 |
+
input_ids,
|
| 132 |
+
position_ids,
|
| 133 |
+
attention_mask,
|
| 134 |
+
past_key_values,
|
| 135 |
+
None,
|
| 136 |
+
labels,
|
| 137 |
+
)
|
| 138 |
+
|
| 139 |
+
is_video = "pixel_values_videos" in vlm_inputs
|
| 140 |
+
|
| 141 |
+
compressed_features, relevance_scores = VLMMultimodalProcessor.tokenize_vision_inputs(self.get_vision_tower_aux_list()[0], vlm_inputs, is_video)
|
| 142 |
+
compressed_features, count_allocations = (
|
| 143 |
+
VLMMultimodalProcessor.adaptive_linear_budget_allocation(
|
| 144 |
+
compressed_features,
|
| 145 |
+
relevance_scores,
|
| 146 |
+
is_video,
|
| 147 |
+
max_budget=self.config.visual_token_budget if hasattr(self.config, "visual_token_budget") else 8192,
|
| 148 |
+
min_tokens=4,
|
| 149 |
+
strategy="head",
|
| 150 |
+
)
|
| 151 |
+
)
|
| 152 |
+
|
| 153 |
+
# for visulization of the allocation results, can be removed
|
| 154 |
+
self._demo_count_allocations = count_allocations
|
| 155 |
+
|
| 156 |
+
if isinstance(compressed_features, list):
|
| 157 |
+
seg_lens = [feat.shape[0] for feat in compressed_features]
|
| 158 |
+
compressed_features = torch.cat(compressed_features, dim=0)
|
| 159 |
+
image_features = self.get_model().mm_projector(compressed_features)
|
| 160 |
+
print(f"[Total Segments: {len(seg_lens)}], compression features after dynamic compress", image_features.shape)
|
| 161 |
+
image_features = list(torch.split(image_features, seg_lens, dim=0))
|
| 162 |
+
else:
|
| 163 |
+
image_features = self.get_model().mm_projector(compressed_features)
|
| 164 |
+
print("final compression features for the whole batch:", image_features.shape)
|
| 165 |
+
|
| 166 |
+
if is_video:
|
| 167 |
+
# add timestamp embeddings for video inputs
|
| 168 |
+
if batch_split_size is not None and len(batch_split_size) > 1: # batch training
|
| 169 |
+
print(f"Number of segments for each video is: {batch_split_size}")
|
| 170 |
+
start_idx = 0
|
| 171 |
+
final_image_features_list = []
|
| 172 |
+
for b_split_size in batch_split_size:
|
| 173 |
+
current_image_features = image_features[start_idx : start_idx + b_split_size]
|
| 174 |
+
current_seg_timestamp = seg_timestamps[start_idx : start_idx + b_split_size]
|
| 175 |
+
final_image_features_list.append(
|
| 176 |
+
VLMMultimodalProcessor.add_seg_timestamp(
|
| 177 |
+
current_image_features,
|
| 178 |
+
self.get_model(),
|
| 179 |
+
current_seg_timestamp,
|
| 180 |
+
is_video,
|
| 181 |
+
)
|
| 182 |
+
)
|
| 183 |
+
start_idx += b_split_size
|
| 184 |
+
else:
|
| 185 |
+
final_image_features_list = [
|
| 186 |
+
VLMMultimodalProcessor.add_seg_timestamp(
|
| 187 |
+
image_features, self.get_model(), seg_timestamps, is_video
|
| 188 |
+
)
|
| 189 |
+
]
|
| 190 |
+
else:
|
| 191 |
+
final_image_features_list = [img for img in image_features]
|
| 192 |
+
|
| 193 |
+
_labels = labels
|
| 194 |
+
_position_ids = position_ids
|
| 195 |
+
_attention_mask = attention_mask
|
| 196 |
+
|
| 197 |
+
if attention_mask is None:
|
| 198 |
+
attention_mask = torch.ones_like(input_ids, dtype=torch.bool)
|
| 199 |
+
else:
|
| 200 |
+
attention_mask = attention_mask.bool()
|
| 201 |
+
|
| 202 |
+
if position_ids is None:
|
| 203 |
+
position_ids = torch.arange(
|
| 204 |
+
0, input_ids.shape[1], dtype=torch.long, device=input_ids.device
|
| 205 |
+
)
|
| 206 |
+
|
| 207 |
+
if labels is None:
|
| 208 |
+
labels = torch.full_like(input_ids, IGNORE_INDEX)
|
| 209 |
+
|
| 210 |
+
attention_mask = attention_mask | (input_ids == IMAGE_TOKEN_INDEX)
|
| 211 |
+
|
| 212 |
+
input_ids = [
|
| 213 |
+
cur_input_ids[cur_attention_mask]
|
| 214 |
+
for cur_input_ids, cur_attention_mask in zip(input_ids, attention_mask)
|
| 215 |
+
]
|
| 216 |
+
labels = [
|
| 217 |
+
cur_labels[cur_attention_mask]
|
| 218 |
+
for cur_labels, cur_attention_mask in zip(labels, attention_mask)
|
| 219 |
+
]
|
| 220 |
+
|
| 221 |
+
new_input_embeds = []
|
| 222 |
+
new_labels = []
|
| 223 |
+
cur_image_idx = 0
|
| 224 |
+
|
| 225 |
+
for batch_idx, cur_input_ids in enumerate(input_ids):
|
| 226 |
+
num_images = (cur_input_ids == IMAGE_TOKEN_INDEX).sum()
|
| 227 |
+
|
| 228 |
+
if num_images == 0:
|
| 229 |
+
cur_image_features = final_image_features_list[cur_image_idx]
|
| 230 |
+
cur_input_embeds_1 = self.get_model().embed_tokens(cur_input_ids)
|
| 231 |
+
cur_input_embeds = torch.cat(
|
| 232 |
+
[cur_input_embeds_1, cur_image_features[0:0]], dim=0
|
| 233 |
+
)
|
| 234 |
+
new_input_embeds.append(cur_input_embeds)
|
| 235 |
+
new_labels.append(labels[batch_idx])
|
| 236 |
+
cur_image_idx += 1
|
| 237 |
+
continue
|
| 238 |
+
|
| 239 |
+
image_token_indices = (
|
| 240 |
+
[-1]
|
| 241 |
+
+ torch.where(cur_input_ids == IMAGE_TOKEN_INDEX)[0].tolist()
|
| 242 |
+
+ [cur_input_ids.shape[0]]
|
| 243 |
+
)
|
| 244 |
+
|
| 245 |
+
cur_input_ids_noim = []
|
| 246 |
+
cur_labels = labels[batch_idx]
|
| 247 |
+
cur_labels_noim = []
|
| 248 |
+
|
| 249 |
+
for i in range(len(image_token_indices) - 1):
|
| 250 |
+
cur_input_ids_noim.append(
|
| 251 |
+
cur_input_ids[
|
| 252 |
+
image_token_indices[i] + 1 : image_token_indices[i + 1]
|
| 253 |
+
]
|
| 254 |
+
)
|
| 255 |
+
cur_labels_noim.append(
|
| 256 |
+
cur_labels[image_token_indices[i] + 1 : image_token_indices[i + 1]]
|
| 257 |
+
)
|
| 258 |
+
|
| 259 |
+
split_sizes_text = [x.shape[0] for x in cur_labels_noim]
|
| 260 |
+
cur_input_embeds = self.get_model().embed_tokens(
|
| 261 |
+
torch.cat(cur_input_ids_noim)
|
| 262 |
+
)
|
| 263 |
+
cur_input_embeds_no_im = torch.split(
|
| 264 |
+
cur_input_embeds, split_sizes_text, dim=0
|
| 265 |
+
)
|
| 266 |
+
|
| 267 |
+
# for multi-image inputs, there is a bug.
|
| 268 |
+
cur_new_input_embeds = []
|
| 269 |
+
cur_new_labels = []
|
| 270 |
+
text_len = sum([x.shape[0] for x in cur_input_embeds_no_im])
|
| 271 |
+
visual_len = len(final_image_features_list[cur_image_idx])
|
| 272 |
+
max_visual_len = (
|
| 273 |
+
self.get_model().config.tokenizer_model_max_length
|
| 274 |
+
- getattr(self.get_model().config, "inference_max_length", 16)
|
| 275 |
+
- text_len
|
| 276 |
+
)
|
| 277 |
+
|
| 278 |
+
if max_visual_len < visual_len:
|
| 279 |
+
final_image_features_list[cur_image_idx] = final_image_features_list[cur_image_idx][:max_visual_len]
|
| 280 |
+
|
| 281 |
+
for i in range(num_images + 1):
|
| 282 |
+
cur_new_input_embeds.append(cur_input_embeds_no_im[i])
|
| 283 |
+
cur_new_labels.append(cur_labels_noim[i])
|
| 284 |
+
|
| 285 |
+
if i < num_images:
|
| 286 |
+
try:
|
| 287 |
+
cur_image_features = final_image_features_list[cur_image_idx]
|
| 288 |
+
except IndexError:
|
| 289 |
+
print(f"cur_image_idx={cur_image_idx} is not ok, get {num_images} images!!!")
|
| 290 |
+
cur_image_features = final_image_features_list[cur_image_idx - 1]
|
| 291 |
+
|
| 292 |
+
cur_image_idx += 1
|
| 293 |
+
cur_new_input_embeds.append(cur_image_features)
|
| 294 |
+
cur_new_labels.append(
|
| 295 |
+
torch.full(
|
| 296 |
+
(cur_image_features.shape[0],),
|
| 297 |
+
IGNORE_INDEX,
|
| 298 |
+
device=cur_labels.device,
|
| 299 |
+
dtype=cur_labels.dtype,
|
| 300 |
+
)
|
| 301 |
+
)
|
| 302 |
+
|
| 303 |
+
cur_new_input_embeds = [x.to(self.device) for x in cur_new_input_embeds]
|
| 304 |
+
|
| 305 |
+
cur_new_input_embeds = torch.cat(cur_new_input_embeds)
|
| 306 |
+
cur_new_labels = torch.cat(cur_new_labels)
|
| 307 |
+
new_input_embeds.append(cur_new_input_embeds)
|
| 308 |
+
new_labels.append(cur_new_labels)
|
| 309 |
+
|
| 310 |
+
tokenizer_model_max_length = getattr(self.config, "tokenizer_model_max_length", None)
|
| 311 |
+
if tokenizer_model_max_length is not None:
|
| 312 |
+
new_input_embeds = [x[:tokenizer_model_max_length] for x in new_input_embeds]
|
| 313 |
+
new_labels = [x[:tokenizer_model_max_length] for x in new_labels]
|
| 314 |
+
|
| 315 |
+
max_len = max(x.shape[0] for x in new_input_embeds)
|
| 316 |
+
batch_size = len(new_input_embeds)
|
| 317 |
+
|
| 318 |
+
new_input_embeds_padded = []
|
| 319 |
+
new_labels_padded = torch.full(
|
| 320 |
+
(batch_size, max_len),
|
| 321 |
+
IGNORE_INDEX,
|
| 322 |
+
dtype=new_labels[0].dtype,
|
| 323 |
+
device=new_labels[0].device,
|
| 324 |
+
)
|
| 325 |
+
attention_mask = torch.zeros(
|
| 326 |
+
(batch_size, max_len),
|
| 327 |
+
dtype=attention_mask.dtype,
|
| 328 |
+
device=attention_mask.device,
|
| 329 |
+
)
|
| 330 |
+
position_ids = torch.zeros(
|
| 331 |
+
(batch_size, max_len),
|
| 332 |
+
dtype=position_ids.dtype,
|
| 333 |
+
device=position_ids.device,
|
| 334 |
+
)
|
| 335 |
+
|
| 336 |
+
for i, (cur_new_embed, cur_new_labels) in enumerate(
|
| 337 |
+
zip(new_input_embeds, new_labels)
|
| 338 |
+
):
|
| 339 |
+
cur_len = cur_new_embed.shape[0]
|
| 340 |
+
|
| 341 |
+
if getattr(self.config, "tokenizer_padding_side", "right") == "left":
|
| 342 |
+
new_input_embeds_padded.append(
|
| 343 |
+
torch.cat(
|
| 344 |
+
(
|
| 345 |
+
torch.zeros(
|
| 346 |
+
(max_len - cur_len, cur_new_embed.shape[1]),
|
| 347 |
+
dtype=cur_new_embed.dtype,
|
| 348 |
+
device=cur_new_embed.device,
|
| 349 |
+
),
|
| 350 |
+
cur_new_embed,
|
| 351 |
+
),
|
| 352 |
+
dim=0,
|
| 353 |
+
)
|
| 354 |
+
)
|
| 355 |
+
if cur_len > 0:
|
| 356 |
+
new_labels_padded[i, -cur_len:] = cur_new_labels
|
| 357 |
+
attention_mask[i, -cur_len:] = True
|
| 358 |
+
position_ids[i, -cur_len:] = torch.arange(
|
| 359 |
+
0,
|
| 360 |
+
cur_len,
|
| 361 |
+
dtype=position_ids.dtype,
|
| 362 |
+
device=position_ids.device,
|
| 363 |
+
)
|
| 364 |
+
else:
|
| 365 |
+
new_input_embeds_padded.append(
|
| 366 |
+
torch.cat(
|
| 367 |
+
(
|
| 368 |
+
cur_new_embed,
|
| 369 |
+
torch.zeros(
|
| 370 |
+
(max_len - cur_len, cur_new_embed.shape[1]),
|
| 371 |
+
dtype=cur_new_embed.dtype,
|
| 372 |
+
device=cur_new_embed.device,
|
| 373 |
+
),
|
| 374 |
+
),
|
| 375 |
+
dim=0,
|
| 376 |
+
)
|
| 377 |
+
)
|
| 378 |
+
if cur_len > 0:
|
| 379 |
+
new_labels_padded[i, :cur_len] = cur_new_labels
|
| 380 |
+
attention_mask[i, :cur_len] = True
|
| 381 |
+
position_ids[i, :cur_len] = torch.arange(
|
| 382 |
+
0,
|
| 383 |
+
cur_len,
|
| 384 |
+
dtype=position_ids.dtype,
|
| 385 |
+
device=position_ids.device,
|
| 386 |
+
)
|
| 387 |
+
|
| 388 |
+
new_input_embeds = torch.stack(new_input_embeds_padded, dim=0)
|
| 389 |
+
|
| 390 |
+
if _labels is None:
|
| 391 |
+
new_labels = None
|
| 392 |
+
else:
|
| 393 |
+
new_labels = new_labels_padded
|
| 394 |
+
|
| 395 |
+
if _attention_mask is None:
|
| 396 |
+
attention_mask = None
|
| 397 |
+
else:
|
| 398 |
+
attention_mask = attention_mask.to(dtype=_attention_mask.dtype)
|
| 399 |
+
|
| 400 |
+
if _position_ids is None:
|
| 401 |
+
position_ids = None
|
| 402 |
+
|
| 403 |
+
return (
|
| 404 |
+
None,
|
| 405 |
+
position_ids,
|
| 406 |
+
attention_mask,
|
| 407 |
+
past_key_values,
|
| 408 |
+
new_input_embeds,
|
| 409 |
+
new_labels,
|
| 410 |
+
)
|
| 411 |
+
|
| 412 |
+
def initialize_vision_tokenizer(self, model_args, tokenizer):
|
| 413 |
+
if model_args.mm_use_im_patch_token:
|
| 414 |
+
tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
|
| 415 |
+
self.resize_token_embeddings(len(tokenizer))
|
| 416 |
+
|
| 417 |
+
if model_args.mm_use_im_start_end:
|
| 418 |
+
num_new_tokens = tokenizer.add_tokens(
|
| 419 |
+
[DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True
|
| 420 |
+
)
|
| 421 |
+
self.resize_token_embeddings(len(tokenizer))
|
| 422 |
+
|
| 423 |
+
if num_new_tokens > 0:
|
| 424 |
+
input_embeddings = self.get_input_embeddings().weight.data
|
| 425 |
+
output_embeddings = self.get_output_embeddings().weight.data
|
| 426 |
+
|
| 427 |
+
input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(
|
| 428 |
+
dim=0, keepdim=True
|
| 429 |
+
)
|
| 430 |
+
output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(
|
| 431 |
+
dim=0, keepdim=True
|
| 432 |
+
)
|
| 433 |
+
|
| 434 |
+
input_embeddings[-num_new_tokens:] = input_embeddings_avg
|
| 435 |
+
output_embeddings[-num_new_tokens:] = output_embeddings_avg
|
| 436 |
+
|
| 437 |
+
if model_args.tune_mm_mlp_adapter:
|
| 438 |
+
for p in self.get_input_embeddings().parameters():
|
| 439 |
+
p.requires_grad = True
|
| 440 |
+
for p in self.get_output_embeddings().parameters():
|
| 441 |
+
p.requires_grad = False
|
| 442 |
+
|
| 443 |
+
if model_args.pretrain_mm_mlp_adapter:
|
| 444 |
+
mm_projector_weights = torch.load(
|
| 445 |
+
model_args.pretrain_mm_mlp_adapter, map_location="cpu"
|
| 446 |
+
)
|
| 447 |
+
embed_tokens_weight = mm_projector_weights["model.embed_tokens.weight"]
|
| 448 |
+
assert num_new_tokens == 2
|
| 449 |
+
if input_embeddings.shape == embed_tokens_weight.shape:
|
| 450 |
+
input_embeddings[-num_new_tokens:] = embed_tokens_weight[
|
| 451 |
+
-num_new_tokens:
|
| 452 |
+
]
|
| 453 |
+
elif embed_tokens_weight.shape[0] == num_new_tokens:
|
| 454 |
+
input_embeddings[-num_new_tokens:] = embed_tokens_weight
|
| 455 |
+
else:
|
| 456 |
+
raise ValueError(
|
| 457 |
+
f"Unexpected embed_tokens_weight shape. Pretrained: {embed_tokens_weight.shape}. Current: {input_embeddings.shape}. Numer of new tokens: {num_new_tokens}."
|
| 458 |
+
)
|
| 459 |
+
elif model_args.mm_use_im_patch_token:
|
| 460 |
+
if model_args.tune_mm_mlp_adapter:
|
| 461 |
+
for p in self.get_input_embeddings().parameters():
|
| 462 |
+
p.requires_grad = False
|
| 463 |
+
for p in self.get_output_embeddings().parameters():
|
| 464 |
+
p.requires_grad = False
|
tempo/vlm_multimodal_processor.py
ADDED
|
@@ -0,0 +1,332 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import random
|
| 2 |
+
import numpy as np
|
| 3 |
+
import torch
|
| 4 |
+
|
| 5 |
+
class VLMMultimodalProcessor:
|
| 6 |
+
"""SVLM-based vision compression."""
|
| 7 |
+
|
| 8 |
+
@staticmethod
|
| 9 |
+
def tokenize_vision_inputs(vision_language_model, vlm_inputs, is_video, chunk_size=4):
|
| 10 |
+
return vision_language_model(**vlm_inputs)
|
| 11 |
+
|
| 12 |
+
@staticmethod
|
| 13 |
+
def add_seg_timestamp(vision_features, model, seg_timestamps, is_video):
|
| 14 |
+
if not is_video:
|
| 15 |
+
return vision_features
|
| 16 |
+
|
| 17 |
+
device = vision_features[0].device
|
| 18 |
+
dtype = vision_features[0].dtype
|
| 19 |
+
|
| 20 |
+
max_len = max(len(ts) for ts in seg_timestamps)
|
| 21 |
+
num_segments = len(seg_timestamps)
|
| 22 |
+
# pad_token_id = getattr(model.config, "pad_token_id", 151643)
|
| 23 |
+
pad_token_id = 151643
|
| 24 |
+
timestamp_ids_tensor = torch.full(
|
| 25 |
+
(num_segments, max_len), pad_token_id, dtype=torch.long, device=device
|
| 26 |
+
)
|
| 27 |
+
|
| 28 |
+
for i, ts in enumerate(seg_timestamps):
|
| 29 |
+
length = len(ts)
|
| 30 |
+
timestamp_ids_tensor[i, :length] = torch.tensor(ts, device=device)
|
| 31 |
+
|
| 32 |
+
timestamp_embeds = model.get_input_embeddings()(timestamp_ids_tensor).to(dtype)
|
| 33 |
+
|
| 34 |
+
final_vision_features = []
|
| 35 |
+
for i in range(num_segments):
|
| 36 |
+
if vision_features[i].shape[0] == 0:
|
| 37 |
+
print("drop this segment directly.")
|
| 38 |
+
continue
|
| 39 |
+
final_vision_features.append(
|
| 40 |
+
torch.cat(
|
| 41 |
+
[
|
| 42 |
+
timestamp_embeds[i][: len(seg_timestamps[i])],
|
| 43 |
+
vision_features[i],
|
| 44 |
+
],
|
| 45 |
+
dim=0,
|
| 46 |
+
)
|
| 47 |
+
)
|
| 48 |
+
|
| 49 |
+
# return torch.cat(final_vision_features, dim=0).unsqueeze(0)
|
| 50 |
+
return torch.cat(final_vision_features, dim=0) # (comp_frame1+comp_frame2+comp_frame3+..., d)
|
| 51 |
+
|
| 52 |
+
@staticmethod
|
| 53 |
+
def tome_merge(x: torch.Tensor, target_num: int) -> torch.Tensor:
|
| 54 |
+
"""
|
| 55 |
+
Token Merging using bipartite soft matching.
|
| 56 |
+
Reference: "Token Merging: Your ViT But Faster" (Bolya et al.)
|
| 57 |
+
Args:
|
| 58 |
+
x: (n, d) tensor of token features
|
| 59 |
+
target_num: number of tokens to keep after merging
|
| 60 |
+
Returns:
|
| 61 |
+
merged tokens: (target_num, d) tensor
|
| 62 |
+
"""
|
| 63 |
+
if target_num <= 0:
|
| 64 |
+
raise ValueError("target_num must be positive")
|
| 65 |
+
|
| 66 |
+
n, d = x.shape
|
| 67 |
+
|
| 68 |
+
if target_num >= n:
|
| 69 |
+
return x
|
| 70 |
+
|
| 71 |
+
while x.shape[0] > target_num:
|
| 72 |
+
current_n = x.shape[0]
|
| 73 |
+
|
| 74 |
+
if current_n < 2:
|
| 75 |
+
break
|
| 76 |
+
|
| 77 |
+
t1 = (current_n + 1) // 2 # ceil(n/2) - source token
|
| 78 |
+
t2 = current_n // 2 # floor(n/2) - target token
|
| 79 |
+
|
| 80 |
+
if t2 == 0:
|
| 81 |
+
break
|
| 82 |
+
|
| 83 |
+
tokens_to_remove = current_n - target_num
|
| 84 |
+
r = min(tokens_to_remove, t1)
|
| 85 |
+
|
| 86 |
+
if r <= 0:
|
| 87 |
+
break
|
| 88 |
+
|
| 89 |
+
x_batch = x.unsqueeze(0) # (1, n, d)
|
| 90 |
+
k = x_batch / x_batch.norm(dim=-1, keepdim=True)
|
| 91 |
+
a, b = k[..., ::2, :], k[..., 1::2, :]
|
| 92 |
+
scores = a @ b.transpose(-1, -2)
|
| 93 |
+
node_max, node_idx = scores.max(dim=-1)
|
| 94 |
+
edge_idx = node_max.argsort(dim=-1, descending=True)[..., None]
|
| 95 |
+
unm_idx = edge_idx[..., r:, :]
|
| 96 |
+
src_idx = edge_idx[..., :r, :]
|
| 97 |
+
dst_idx = node_idx[..., None].gather(dim=-2, index=src_idx)
|
| 98 |
+
unm_idx = unm_idx.sort(dim=-2)[0]
|
| 99 |
+
# merge
|
| 100 |
+
src, dst = x_batch[..., ::2, :], x_batch[..., 1::2, :]
|
| 101 |
+
batch, _, c = src.shape
|
| 102 |
+
unm = src.gather(dim=-2, index=unm_idx.expand(batch, t1 - r, c))
|
| 103 |
+
src_to_merge = src.gather(dim=-2, index=src_idx.expand(batch, r, c))
|
| 104 |
+
dst = dst.scatter_add(-2, dst_idx.expand(batch, r, c), src_to_merge)
|
| 105 |
+
# pooling
|
| 106 |
+
ones = torch.ones(batch, r, 1, device=x.device, dtype=x.dtype)
|
| 107 |
+
dst_counts = torch.ones(
|
| 108 |
+
batch, dst.shape[1], 1, device=x.device, dtype=x.dtype
|
| 109 |
+
)
|
| 110 |
+
dst_counts = dst_counts.scatter_add(-2, dst_idx.expand(batch, r, 1), ones)
|
| 111 |
+
dst = dst / dst_counts
|
| 112 |
+
x = torch.cat([unm, dst], dim=-2).squeeze(0) # (new_n, d)
|
| 113 |
+
return x
|
| 114 |
+
|
| 115 |
+
@staticmethod
|
| 116 |
+
def topk_compress(
|
| 117 |
+
vision_features,
|
| 118 |
+
relevance_scores,
|
| 119 |
+
is_video,
|
| 120 |
+
k=0,
|
| 121 |
+
drop_ratio=0.5,
|
| 122 |
+
strategy="topk",
|
| 123 |
+
):
|
| 124 |
+
"""
|
| 125 |
+
Drop/Truncate low-scoring segments to keep only first k tokens based on relevance scores
|
| 126 |
+
Args:
|
| 127 |
+
vision_features: (n_segment, n, d)
|
| 128 |
+
relevance_scores: n_segment (list of scores between 0 and 1)
|
| 129 |
+
k: number of tokens to keep, k=0 means drop directly
|
| 130 |
+
drop_ratio: ratio of segments to drop/truncate
|
| 131 |
+
strategy: "topk" (keep highest), "lastk" (keep lowest), "random" (random drop)
|
| 132 |
+
Return a list of tensor with length of n_segment. Each tensor is (k, d) or (num_compression_token, d)
|
| 133 |
+
"""
|
| 134 |
+
|
| 135 |
+
if not is_video or relevance_scores is None:
|
| 136 |
+
return vision_features
|
| 137 |
+
|
| 138 |
+
n_segment = vision_features.shape[0]
|
| 139 |
+
if n_segment <= 1:
|
| 140 |
+
print("video segment is equal/less than 1, not compressing")
|
| 141 |
+
return vision_features
|
| 142 |
+
|
| 143 |
+
if strategy == "topk":
|
| 144 |
+
# in ascending order
|
| 145 |
+
sorted_indices = sorted(range(n_segment), key=lambda i: relevance_scores[i])
|
| 146 |
+
elif strategy == "lastk":
|
| 147 |
+
# in descending order
|
| 148 |
+
sorted_indices = sorted(
|
| 149 |
+
range(n_segment), key=lambda i: relevance_scores[i], reverse=True
|
| 150 |
+
)
|
| 151 |
+
elif strategy == "random":
|
| 152 |
+
# shuffle index, random
|
| 153 |
+
sorted_indices = list(range(n_segment))
|
| 154 |
+
random.shuffle(sorted_indices)
|
| 155 |
+
else:
|
| 156 |
+
raise ValueError(f"Unknown strategy: {strategy}")
|
| 157 |
+
|
| 158 |
+
# truncate or drop ratio
|
| 159 |
+
num_to_prune = int(n_segment * drop_ratio)
|
| 160 |
+
print(
|
| 161 |
+
f"[topk_compress] total segment: {n_segment}, drop/truncate segment: {num_to_prune}, keep first {k} token"
|
| 162 |
+
)
|
| 163 |
+
|
| 164 |
+
low_score_indices = set(sorted_indices[:num_to_prune])
|
| 165 |
+
|
| 166 |
+
result = []
|
| 167 |
+
for i in range(n_segment):
|
| 168 |
+
if i in low_score_indices:
|
| 169 |
+
result.append(vision_features[i, :k, :]) # shape: (k, d)
|
| 170 |
+
else:
|
| 171 |
+
result.append(vision_features[i]) # shape: (n, d)
|
| 172 |
+
|
| 173 |
+
print(
|
| 174 |
+
f"[topk_compress] segments: {n_segment} (pruned={num_to_prune}, kept={n_segment - num_to_prune}), "
|
| 175 |
+
f"kept tokens for pruned segment={k}"
|
| 176 |
+
)
|
| 177 |
+
|
| 178 |
+
return result
|
| 179 |
+
|
| 180 |
+
@staticmethod
|
| 181 |
+
def adaptive_linear_budget_allocation(
|
| 182 |
+
vision_features,
|
| 183 |
+
relevance_scores,
|
| 184 |
+
is_video,
|
| 185 |
+
max_budget=8192,
|
| 186 |
+
min_tokens=4,
|
| 187 |
+
max_tokens=None,
|
| 188 |
+
strategy="head",
|
| 189 |
+
):
|
| 190 |
+
"""
|
| 191 |
+
soft token allocation based on min-max normalized relevance scores, Uses linear mapping instead of softmax, providing more aggressive sparsity
|
| 192 |
+
Args:
|
| 193 |
+
vision_features: (n_segment, n_tokens, d)
|
| 194 |
+
relevance_scores: list/array of scores
|
| 195 |
+
is_video: bool
|
| 196 |
+
max_budget: the largest budget for a video
|
| 197 |
+
min_tokens: minimum number of tokens to allocate to each segment
|
| 198 |
+
max_tokens: maximum number of tokens to allocate to each segment
|
| 199 |
+
strategy: "head", "tail", "random", "tome"
|
| 200 |
+
Return a list of tensor with length of n_segment. Each tensor is (k, d), where k in [min_tokens, max_tokens]
|
| 201 |
+
"""
|
| 202 |
+
if not is_video or relevance_scores is None:
|
| 203 |
+
return vision_features, None
|
| 204 |
+
|
| 205 |
+
n_segments = vision_features.shape[0]
|
| 206 |
+
n_tokens_per_segment = vision_features.shape[1]
|
| 207 |
+
max_tokens = min(max_tokens or n_tokens_per_segment, n_tokens_per_segment)
|
| 208 |
+
|
| 209 |
+
base_budget = n_segments * min_tokens
|
| 210 |
+
if base_budget > max_budget:
|
| 211 |
+
actual_min = max(1, max_budget // n_segments)
|
| 212 |
+
print(
|
| 213 |
+
f"[adaptive_linear_budget_allocation] Warning: budget insufficient, "
|
| 214 |
+
f"min_tokens: {min_tokens} -> {actual_min}"
|
| 215 |
+
)
|
| 216 |
+
min_tokens = actual_min
|
| 217 |
+
base_budget = n_segments * min_tokens
|
| 218 |
+
|
| 219 |
+
# Convert to tensor
|
| 220 |
+
scores = torch.tensor(relevance_scores, dtype=torch.float32)
|
| 221 |
+
score_min = scores.min()
|
| 222 |
+
score_max = scores.max()
|
| 223 |
+
score_range = score_max - score_min
|
| 224 |
+
|
| 225 |
+
if score_range < 1e-8:
|
| 226 |
+
k = min(max_tokens, max_budget // n_segments)
|
| 227 |
+
return [vision_features[i, :k, :] for i in range(n_segments)], torch.full(
|
| 228 |
+
(n_segments,), k, dtype=torch.long
|
| 229 |
+
)
|
| 230 |
+
|
| 231 |
+
# Normalize scores to [0, 1] range
|
| 232 |
+
normalized_scores = (scores - score_min) / score_range
|
| 233 |
+
|
| 234 |
+
# Linear mapping: [0, 1] -> [min_tokens, max_tokens]
|
| 235 |
+
token_range = max_tokens - min_tokens
|
| 236 |
+
ideal_allocations = (
|
| 237 |
+
min_tokens + (normalized_scores * token_range).floor().long()
|
| 238 |
+
)
|
| 239 |
+
|
| 240 |
+
# ========== Budget Protection ==========
|
| 241 |
+
total_desired = ideal_allocations.sum().item()
|
| 242 |
+
|
| 243 |
+
if total_desired > max_budget:
|
| 244 |
+
# Only scale down when total demand exceeds budget
|
| 245 |
+
print(
|
| 246 |
+
f"[adaptive_linear_budget_allocation] Warning: Desired budget "
|
| 247 |
+
f"({total_desired}) exceeds max ({max_budget}). Scaling down..."
|
| 248 |
+
)
|
| 249 |
+
|
| 250 |
+
extra_budget = max_budget - base_budget
|
| 251 |
+
|
| 252 |
+
# Prevent division by zero
|
| 253 |
+
score_sum = normalized_scores.sum().item()
|
| 254 |
+
if score_sum < 1e-8:
|
| 255 |
+
# Fallback to uniform distribution if all normalized scores are ~0
|
| 256 |
+
weights = torch.ones_like(normalized_scores) / n_segments
|
| 257 |
+
else:
|
| 258 |
+
weights = normalized_scores / score_sum
|
| 259 |
+
|
| 260 |
+
# Distribute extra budget proportionally
|
| 261 |
+
extra_allocation = (weights * extra_budget).floor().long()
|
| 262 |
+
remainder = int(extra_budget - extra_allocation.sum().item())
|
| 263 |
+
if remainder > 0:
|
| 264 |
+
top_indices = torch.argsort(weights, descending=True)[:remainder]
|
| 265 |
+
for idx in top_indices:
|
| 266 |
+
extra_allocation[idx] += 1
|
| 267 |
+
|
| 268 |
+
allocations = min_tokens + extra_allocation
|
| 269 |
+
else:
|
| 270 |
+
allocations = ideal_allocations
|
| 271 |
+
|
| 272 |
+
# Final clamp to ensure bounds
|
| 273 |
+
allocations = allocations.clamp(min=min_tokens, max=max_tokens)
|
| 274 |
+
|
| 275 |
+
# default, head truncation
|
| 276 |
+
if strategy == "head":
|
| 277 |
+
result = []
|
| 278 |
+
for i in range(n_segments):
|
| 279 |
+
k = min(allocations[i].item(), n_tokens_per_segment)
|
| 280 |
+
result.append(vision_features[i, :k, :])
|
| 281 |
+
|
| 282 |
+
elif strategy == "tail":
|
| 283 |
+
result = []
|
| 284 |
+
for i in range(n_segments):
|
| 285 |
+
k = min(allocations[i].item(), n_tokens_per_segment)
|
| 286 |
+
if k == 0:
|
| 287 |
+
result.append(vision_features[i, :0, :])
|
| 288 |
+
else:
|
| 289 |
+
result.append(vision_features[i, -k:, :])
|
| 290 |
+
|
| 291 |
+
elif strategy == "random":
|
| 292 |
+
result = []
|
| 293 |
+
for i in range(n_segments):
|
| 294 |
+
k = min(allocations[i].item(), n_tokens_per_segment)
|
| 295 |
+
if k == 0:
|
| 296 |
+
result.append(vision_features[i, :0, :])
|
| 297 |
+
else:
|
| 298 |
+
indices = torch.randperm(
|
| 299 |
+
n_tokens_per_segment, device=vision_features.device
|
| 300 |
+
)[:k]
|
| 301 |
+
indices = indices.sort().values
|
| 302 |
+
result.append(vision_features[i, indices, :])
|
| 303 |
+
|
| 304 |
+
elif strategy == "tome":
|
| 305 |
+
result = []
|
| 306 |
+
for i in range(n_segments):
|
| 307 |
+
k = min(allocations[i].item(), n_tokens_per_segment)
|
| 308 |
+
if k == 0:
|
| 309 |
+
result.append(vision_features[i, :0, :])
|
| 310 |
+
elif k == n_tokens_per_segment:
|
| 311 |
+
result.append(vision_features[i])
|
| 312 |
+
else:
|
| 313 |
+
# Token Merging
|
| 314 |
+
result.append(
|
| 315 |
+
VLMMultimodalProcessor.tome_merge(
|
| 316 |
+
vision_features[i], target_num=k
|
| 317 |
+
)
|
| 318 |
+
)
|
| 319 |
+
else:
|
| 320 |
+
raise ValueError(f"Unknown strategy: {strategy}")
|
| 321 |
+
|
| 322 |
+
total_used = allocations.sum().item()
|
| 323 |
+
|
| 324 |
+
print(
|
| 325 |
+
f"[adaptive_linear_budget_allocation] segments={n_segments}, "
|
| 326 |
+
f"budget_used={total_used}/{max_budget}, "
|
| 327 |
+
f"theoretical_range=[{min_tokens}, {max_tokens}], "
|
| 328 |
+
f"actual_range=[{allocations.min().item():.0f}, {allocations.max().item():.0f}]",
|
| 329 |
+
flush=True,
|
| 330 |
+
)
|
| 331 |
+
|
| 332 |
+
return result, allocations
|