FeiElysia committed on
Commit 161b19e · 1 Parent(s): 5087c07

🚀 Initial deploy

Files changed (48)
  1. .gitattributes +2 -0
  2. app.py +384 -0
  3. examples/cover_videomme_FjS2LzrHEO8.png +3 -0
  4. examples/cover_videomme_fFjv93ACGo8.png +3 -0
  5. examples/demo.mp4 +3 -0
  6. examples/demo_cases.json +40 -0
  7. examples/description_honkai3_becauseofyou.png +3 -0
  8. examples/honkai3_becauseofyou.mp4 +3 -0
  9. examples/hsr_helloworld.mp4 +3 -0
  10. examples/lvbench_gXnhqF0TqqI.mp4 +3 -0
  11. examples/meme_hsr_helloworld.png +3 -0
  12. examples/ocr_honkai3_becauseofyou.png +3 -0
  13. examples/performance_hsr_helloworld.png +3 -0
  14. examples/tempo.png +3 -0
  15. examples/tempo.svg +0 -0
  16. examples/videomme_FjS2LzrHEO8.mp4 +3 -0
  17. examples/videomme_FsLaTZmP6Uw.mp4 +3 -0
  18. examples/videomme_Sp2nxlrQ89w.mp4 +3 -0
  19. examples/videomme_fFjv93ACGo8.mp4 +3 -0
  20. packages.txt +1 -0
  21. requirements.txt +11 -0
  22. tempo/__init__.py +6 -0
  23. tempo/__pycache__/__init__.cpython-312.pyc +0 -0
  24. tempo/__pycache__/builder.cpython-312.pyc +0 -0
  25. tempo/__pycache__/constants.cpython-312.pyc +0 -0
  26. tempo/__pycache__/conversation.cpython-312.pyc +0 -0
  27. tempo/__pycache__/mm_datautils.cpython-312.pyc +0 -0
  28. tempo/__pycache__/mm_utils.cpython-312.pyc +0 -0
  29. tempo/__pycache__/tempo_arch.cpython-312.pyc +0 -0
  30. tempo/__pycache__/vlm_multimodal_processor.cpython-312.pyc +0 -0
  31. tempo/builder.py +62 -0
  32. tempo/constants.py +13 -0
  33. tempo/conversation.py +545 -0
  34. tempo/language_model/__pycache__/modeling_tempo_qwen.cpython-312.pyc +0 -0
  35. tempo/language_model/modeling_tempo_qwen.py +231 -0
  36. tempo/mm_datautils.py +1607 -0
  37. tempo/multimodal_encoder/__pycache__/base_encoder.cpython-312.pyc +0 -0
  38. tempo/multimodal_encoder/__pycache__/builder.cpython-312.pyc +0 -0
  39. tempo/multimodal_encoder/__pycache__/qwen3vl_encoder.cpython-312.pyc +0 -0
  40. tempo/multimodal_encoder/__pycache__/siglip_encoder.cpython-312.pyc +0 -0
  41. tempo/multimodal_encoder/base_encoder.py +135 -0
  42. tempo/multimodal_encoder/builder.py +21 -0
  43. tempo/multimodal_encoder/qwen3vl_encoder.py +336 -0
  44. tempo/multimodal_encoder/siglip_encoder.py +75 -0
  45. tempo/multimodal_projector/__pycache__/builder.cpython-312.pyc +0 -0
  46. tempo/multimodal_projector/builder.py +51 -0
  47. tempo/tempo_arch.py +464 -0
  48. tempo/vlm_multimodal_processor.py +332 -0
.gitattributes CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ *.mp4 filter=lfs diff=lfs merge=lfs -text
37
+ *.png filter=lfs diff=lfs merge=lfs -text
app.py ADDED
@@ -0,0 +1,384 @@
1
+ import io
2
+ import os
3
+ import time
4
+ import torch
5
+ import numpy as np
6
+ import gradio as gr
7
+ import multiprocessing
8
+ from decord import cpu, VideoReader
9
+
10
+ import matplotlib.pyplot as plt
11
+ import matplotlib.ticker as ticker
12
+ import matplotlib.colors as mcolors
13
+ from scipy.interpolate import make_interp_spline
14
+ from PIL import Image
15
+
16
+ from tempo.builder import load_pretrained_model
17
+ from tempo.conversation import conv_templates, SeparatorStyle
18
+ from tempo.constants import (
19
+ DEFAULT_IM_END_TOKEN,
20
+ DEFAULT_IM_START_TOKEN,
21
+ DEFAULT_IMAGE_TOKEN,
22
+ IMAGE_TOKEN_INDEX,
23
+ )
24
+ from tempo.mm_datautils import (
25
+ compute_segment_timestamp,
26
+ KeywordsStoppingCriteria,
27
+ process_qwen_content,
28
+ tokenizer_image_token,
29
+ )
30
+
31
+ import spaces
32
+ from huggingface_hub import snapshot_download
33
+
34
+ def get_real_cpu_cores():
35
+ """use multiple threads for video decoding"""
36
+ try:
37
+ # HF Spaces
38
+ cores = len(os.sched_getaffinity(0))
39
+ except AttributeError:
40
+ # Local environments
41
+ cores = multiprocessing.cpu_count()
42
+ return cores
43
+
44
+ def compute_sample_indices(
45
+ total_frames: int,
46
+ original_fps: float,
47
+ video_fps: float = 2.0,
48
+ min_frames_num: int = 4,
49
+ max_frames_num: int = 1024
50
+ ) -> list[int]:
51
+
52
+ start_frame, end_frame = 0, total_frames - 1
53
+ clip_frames = end_frame - start_frame + 1
54
+ if clip_frames <= 1:
55
+ return [start_frame]
56
+
57
+ if original_fps is None or original_fps <= 0:
58
+ original_fps = video_fps
59
+
60
+ clip_duration = clip_frames / original_fps
61
+ target_num_frames = max(1, round(clip_duration * video_fps))
62
+ final_num_frames = min(max(target_num_frames, min_frames_num), max_frames_num)
63
+
64
+ if final_num_frames == 1:
65
+ return [end_frame]
66
+
67
+ indices = np.round(np.linspace(start_frame, end_frame, final_num_frames)).astype(int)
68
+ indices = np.clip(indices, start_frame, end_frame)
69
+
70
+ return indices.tolist()
71
+
72
+ def load_video(video_path: str, video_fps: float = 2.0, max_frames: int = 1024) -> tuple:
73
+
74
+ available_cores = get_real_cpu_cores()
75
+ optimal_threads = min(max(1, available_cores - 1), 16)
76
+ print(f"[Profiling] Detected {available_cores} CPU cores. Decord using {optimal_threads} threads.")
77
+
78
+ vr = VideoReader(video_path, ctx=cpu(0), num_threads=optimal_threads)
79
+ total_frames = len(vr)
80
+ original_fps = vr.get_avg_fps()
81
+ frame_idx = compute_sample_indices(total_frames, original_fps, video_fps, max_frames_num=max_frames)
82
+ images = vr.get_batch(frame_idx).asnumpy()
83
+ clip_duration = total_frames / original_fps
84
+
85
+ real_fps = len(images) / clip_duration if clip_duration > 0 else video_fps
86
+
87
+ return images, real_fps
88
+
89
+ def generate_allocation_plot(allocations):
90
+ """
91
+ Token allocation visualization function
92
+ """
93
+ if allocations is None or len(allocations) == 0:
94
+ # if disable_dynamic_compress is True, we return a blank image
95
+ return Image.new('RGB', (1600, 350), color='white')
96
+
97
+ allocations = np.array(allocations)
98
+ num_segments = len(allocations)
99
+
100
+ plt.rcParams.update({'font.size': 14, 'font.family': 'serif'})
101
+ fig = plt.figure(figsize=(16, 3.5), layout='constrained')
102
+ gs = fig.add_gridspec(2, 1, height_ratios=[0.15, 1.0], hspace=0.05)
103
+
104
+
105
+ ax_heat = fig.add_subplot(gs[0])
106
+ ax_heat.set_title(" ", pad=50)
107
+
108
+ colors = ["#EBF5FB", "#85C1E9", "#F2D7D5", "#E74C3C", "#641E16"]
109
+ cmap_custom = mcolors.LinearSegmentedColormap.from_list("custom_heat", colors)
110
+
111
+ vmax_val = max(128, allocations.max())
112
+ ax_heat.imshow([allocations], cmap=cmap_custom, aspect='auto', extent=[0.5, num_segments + 0.5, 0, 1], vmin=4, vmax=vmax_val)
113
+ ax_heat.set_yticks([])
114
+ ax_heat.tick_params(axis='x', which='both', bottom=False, top=False, labelbottom=False)
115
+ for spine in ax_heat.spines.values():
116
+ spine.set_linewidth(1.2)
117
+
118
+ ax_line = fig.add_subplot(gs[1], sharex=ax_heat)
119
+
120
+ x = np.arange(1, num_segments + 1)
121
+
122
+ if num_segments > 3:
123
+ spl = make_interp_spline(x, allocations, k=3)
124
+ x_smooth = np.linspace(1, num_segments, 800)
125
+ y_smooth = spl(x_smooth)
126
+ y_smooth = np.clip(y_smooth, 4, vmax_val)
127
+ else:
128
+ x_smooth = x
129
+ y_smooth = allocations
130
+
131
+ line_color = '#1A252C'
132
+ fill_color = '#D5D8DC'
133
+
134
+ ax_line.plot(x_smooth, y_smooth, color=line_color, linewidth=2.0)
135
+ ax_line.fill_between(x_smooth, y_smooth, color=fill_color, alpha=0.4)
136
+
137
+ ax_line.axhline(vmax_val, color='#C0392B', linestyle='--', linewidth=1.2, alpha=0.8)
138
+ ax_line.axhline(4, color='#2980B9', linestyle='--', linewidth=1.2, alpha=0.8)
139
+
140
+ ax_line.set_xlim(0.5, num_segments + 0.5)
141
+ ax_line.set_ylim(0, vmax_val + 12)
142
+ ax_line.set_ylabel("Tokens / Seg", fontsize=14, fontweight='bold')
143
+ ax_line.set_xlabel("Temporal Segments", fontsize=14, fontweight='bold')
144
+ ax_line.xaxis.set_major_locator(ticker.MaxNLocator(integer=True))
145
+
146
+ ax_line.spines['top'].set_visible(False)
147
+ ax_line.spines['right'].set_visible(False)
148
+ ax_line.spines['bottom'].set_linewidth(1.2)
149
+ ax_line.spines['left'].set_linewidth(1.2)
150
+ ax_line.grid(axis='y', linestyle=':', color='gray', alpha=0.5)
151
+
152
+ buf = io.BytesIO()
153
+ plt.savefig(buf, format='png', bbox_inches='tight', dpi=100, transparent=True)
154
+ plt.close(fig)
155
+ buf.seek(0)
156
+ return Image.open(buf)
157
+
158
+ model_id = "Vision-CAIR/Tempo-6B"
159
+ print(f"[Init] Downloading/Loading weights from {model_id}...")
160
+ MODEL_PATH = snapshot_download(repo_id=model_id)
161
+ print(f"[Init] Loading Tempo model from {MODEL_PATH}...")
162
+ tokenizer, model, image_processor = load_pretrained_model(
163
+ MODEL_PATH,
164
+ device_map="cuda",
165
+ use_flash_attn=True
166
+ )
167
+
168
+ FIXED_MAX_LENGTH = 16384
169
+ model.config.tokenizer_model_max_length = FIXED_MAX_LENGTH
170
+ tokenizer.model_max_length = FIXED_MAX_LENGTH
171
+ model.eval()
172
+ model.to(torch.bfloat16)
173
+ print(f"[Init] Model loaded! Max context length set to {FIXED_MAX_LENGTH}.")
174
+
175
+
176
+ # ==========================================
177
+ # inference
178
+ # ==========================================
179
+ @spaces.GPU
180
+ def predict(video_path, query, max_frames, visual_token_budget, temperature, max_new_tokens, disable_dynamic_compress):
181
+ if not video_path:
182
+ return "⚠️ Error: Please upload a video first."
183
+ if not query:
184
+ return "⚠️ Error: Please enter a question."
185
+
186
+ print(f"\n[Request] Video: {video_path} | Query: {query}")
187
+
188
+ model.config.visual_token_budget = int(visual_token_budget)
189
+ model.get_vision_tower_aux_list()[0].dynamic_compress = not disable_dynamic_compress
190
+
191
+ # video process
192
+ start_prep_time = time.perf_counter()
193
+ try:
194
+ video_frames, real_fps = load_video(video_path, video_fps=2.0, max_frames=int(max_frames))
195
+ except Exception as e:
196
+ return f"⚠️ Error loading video: {str(e)}"
197
+
198
+ # process local compressor inputs
199
+ frame_windows, frame_stride = 8, 8
200
+ vlm_inputs = process_qwen_content(
201
+ video_frames, "video", query, image_processor[0], real_fps, frame_windows, frame_stride, is_eval=True
202
+ )
203
+ vlm_inputs = {key: v.cuda() for key, v in vlm_inputs.items()}
204
+
205
+ # compute timestamp for each segment
206
+ seg_timestamps = compute_segment_timestamp(
207
+ len(vlm_inputs["video_grid_thw"]), tokenizer, real_fps, frame_stride, frame_windows
208
+ )
209
+
210
+ # stat info
211
+ num_segments = len(vlm_inputs["video_grid_thw"])
212
+ segment_duration = frame_windows / real_fps
213
+ stats_info = f"🎬 Video Stats: Total Segments: {num_segments} | Segment Duration: {segment_duration:.2f}s | Real FPS: {real_fps:.2f}"
214
+
215
+ # prompt
216
+ if getattr(model.config, "mm_use_im_start_end", False):
217
+ qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + "\n" + query
218
+ else:
219
+ qs = DEFAULT_IMAGE_TOKEN + "\n" + query
220
+
221
+ conv_version = "qwen"
222
+ conv = conv_templates[conv_version].copy()
223
+ conv.append_message(conv.roles[0], qs)
224
+ conv.append_message(conv.roles[1], None)
225
+ prompt = conv.get_prompt()
226
+ stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
227
+
228
+ # tokenization
229
+ input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt").unsqueeze(0).cuda()
230
+ stopping_criteria = KeywordsStoppingCriteria([stop_str], tokenizer, input_ids)
231
+
232
+ model._demo_count_allocations = []
233
+
234
+ start_infer_time = time.perf_counter()
235
+
236
+ # generating
237
+ with torch.inference_mode():
238
+ output_ids = model.generate(
239
+ input_ids,
240
+ images=None,  # this Qwen-VL-style path passes visual features via vlm_inputs rather than raw image tensors
241
+ image_sizes=None,
242
+ do_sample=(temperature > 0),
243
+ temperature=temperature if temperature > 0 else None,
244
+ max_new_tokens=int(max_new_tokens),
245
+ use_cache=True,
246
+ stopping_criteria=[stopping_criteria],
247
+ vlm_inputs=vlm_inputs,
248
+ seg_timestamps=seg_timestamps,
249
+ )
250
+
251
+ end_infer_time = time.perf_counter()
252
+
253
+ if isinstance(output_ids, tuple):
254
+ output_ids = output_ids[0]
255
+
256
+ prep_duration = start_infer_time - start_prep_time
257
+ infer_duration = end_infer_time - start_infer_time
258
+ total_duration = end_infer_time - start_prep_time
259
+ stats_info += f"\n⚑ Profiling : Prep Time: {prep_duration:.2f}s | Inference Time: {infer_duration:.2f}s | Total: {total_duration:.2f}s"
260
+
261
+ pred = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0].strip()
262
+ if pred.endswith(stop_str):
263
+ pred = pred[: -len(stop_str)].strip()
264
+
265
+ # token allocation plot
266
+ allocations_data = model._demo_count_allocations
267
+ plot_img = generate_allocation_plot(allocations_data)
268
+
269
+ return pred, plot_img, stats_info
270
+
271
+ # ==========================================
272
+ # UI
273
+ # ==========================================
274
+ with gr.Blocks(title="Tempo Video Understanding", theme=gr.themes.Soft()) as demo:
275
+ gr.Markdown(
276
+ """
277
+ # ⏱️ Tempo: Small Vision-Language Models are Smart Compressors for Long Video Understanding
278
+ Upload a video and ask any question! Tempo dynamically compresses visual tokens based on your query to achieve SOTA performance.
279
+ **[🏠 Project Page](https://feielysia.github.io/)** | **[💻 GitHub](https://github.com/FeiElysia)** | **[📄 Paper](https://arxiv.org/abs/xxxx)** | **[👨‍💻 @Junjie Fei](https://feielysia.github.io/)**
280
+
281
+ *⏳ **Slow preprocessing?** Try Examples 4 & 5 below, decrease `Max Sampled Frames` in Advanced Settings, or check our [GitHub](https://github.com/FeiElysia) for full-speed local deployment.*
282
+ """
283
+ )
284
+
285
+ with gr.Row():
286
+ # left column: inputs
287
+ with gr.Column(scale=1):
288
+ video_input = gr.Video(label="Upload Video")
289
+ example_poster = gr.Image(label="Video Poster", interactive=False, height=150, visible=False)
290
+ query_input = gr.Textbox(label="Your Question", placeholder="e.g., What is the person doing in the video?", lines=3)
291
+ with gr.Row():
292
+ clear_btn = gr.Button("🧹 Clear", variant="secondary")
293
+ submit_btn = gr.Button("🚀 Generate Response", variant="primary")
294
+
295
+ # hyperparameters
296
+ with gr.Accordion("Advanced Settings", open=False):
297
+ max_frames_slider = gr.Slider(minimum=16, maximum=2048, value=1024, step=16, label="Max Sampled Frames")
298
+ budget_slider = gr.Slider(minimum=64, maximum=16384, value=8192, step=64, label="Visual Token Budget")
299
+ temp_slider = gr.Slider(minimum=0.0, maximum=1.0, value=0.0, step=0.1, label="Temperature (0 = Greedy)")
300
+ max_tokens_slider = gr.Slider(minimum=64, maximum=4096, value=1024, step=64, label="Max New Tokens")
301
+ disable_compress_chk = gr.Checkbox(label="Disable Dynamic Compression (Baseline)", value=False)
302
+
303
+ # right column: outputs
304
+ with gr.Column(scale=1):
305
+ output_text = gr.Textbox(label="Tempo Response", lines=12, interactive=False)
306
+ stats_text = gr.Textbox(label="📊 Video Segment Stats", lines=1, interactive=False)
307
+ output_plot = gr.Image(label="Query-Aware Visual Feature Intensity (Visual Token Allocation)", interactive=False, height=180)
308
+
309
+ # clicking submit_btn or pressing enter in query_input will trigger prediction
310
+ submit_btn.click(
311
+ fn=predict,
312
+ inputs=[video_input, query_input, max_frames_slider, budget_slider, temp_slider, max_tokens_slider, disable_compress_chk],
313
+ outputs=[output_text, output_plot, stats_text]
314
+ )
315
+
316
+ query_input.submit(
317
+ fn=predict,
318
+ inputs=[video_input, query_input, max_frames_slider, budget_slider, temp_slider, max_tokens_slider, disable_compress_chk],
319
+ outputs=[output_text, output_plot, stats_text]
320
+ )
321
+
322
+ clear_btn.click(
323
+ fn=lambda: (None, None, None, None, None, None),
324
+ inputs=None,
325
+ outputs=[video_input, example_poster, query_input, output_text, stats_text, output_plot]
326
+ )
327
+
328
+ # Examples
329
+ gr.Markdown("---")
330
+ gr.Markdown("### πŸ’‘ Try an Example")
331
+ gr.Examples(
332
+ examples=[
333
+ [
334
+ "examples/hsr_helloworld.mp4",
335
+ "Task: Please examine the provided media and answer the following three questions regarding the specific puppy in the scene:\n"
336
+ "Q1: What is the primary fur color of the puppy positioned on the swing?\n"
337
+ "Q2: Specify the exact time interval (in seconds, e.g., XX-XXs) during which the puppy is seen sitting on the swing.\n"
338
+ "Q3: Provide a brief description of the puppy's appearance and its surroundings.",
339
+ "examples/meme_hsr_helloworld.png"
340
+ ],
341
+ [
342
+ "examples/hsr_helloworld.mp4",
343
+ "Task: Please analyze the provided video and answer the following 7 questions precisely.\n"
344
+ "Q1: How many performers are visible on the stage?\n"
345
+ "Q2: Describe the architectural elements in the background. What historical civilization do they remind you of?\n"
346
+ "Q3: What is happening in the night sky above the performers, and what does this suggest about the event?\n"
347
+ "Q4: List the hair colors of the performers in order from left to right.\n"
348
+ "Q5: Identify the specific musical instrument being played by the performer located on the far left of the stage.\n"
349
+ "Q6: What is the specific time interval (in seconds, e.g., XX-XXs) during which this fireworks performance scene occurs in the video?\n"
350
+ "Q7: Look at the audience in the foreground. How does their silhouette-like depiction affect the viewer's perspective of the stage?",
351
+ "examples/performance_hsr_helloworld.png"
352
+ ],
353
+ [
354
+ "examples/honkai3_becauseofyou.mp4",
355
+ "What text appears in the center of the video behind a sea of pink flowers?",
356
+ "examples/ocr_honkai3_becauseofyou.png"
357
+ ],
358
+ [
359
+ "examples/videomme_fFjv93ACGo8.mp4",
360
+ "How many red socks are above the fireplace at the end of this video?",
361
+ "examples/cover_videomme_fFjv93ACGo8.png"
362
+ ],
363
+ [
364
+ "examples/videomme_FjS2LzrHEO8.mp4",
365
+ "What was the purpose of using a hammer to hit the car in the video?\n"
366
+ "A. To show the hammer works well.\n"
367
+ "B. To show the solidity of the car.\n"
368
+ "C. To warn people not to hit cars with hammers.\n"
369
+ "D. To illustrate that a hammer is harder than a bullet.",
370
+ "examples/cover_videomme_FjS2LzrHEO8.png"
371
+ ],
372
+ [
373
+ "examples/honkai3_becauseofyou.mp4",
374
+ "Describe the video in detail.",
375
+ "examples/description_honkai3_becauseofyou.png"
376
+ ]
377
+ ],
378
+ inputs=[video_input, query_input, example_poster],
379
+ cache_examples=False,
380
+ )
381
+
382
+ if __name__ == "__main__":
383
+ demo.queue().launch(share=True)
384
+
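For orientation, below is a minimal non-Gradio sketch of the same inference path that app.py wires up above. It is a hypothetical local smoke test, not part of this commit: it assumes the repo root is importable, a CUDA GPU with the pinned flash-attn build is available, and examples/demo.mp4 has been fetched via Git LFS; the small load_video helper is a simplified stand-in for the decord-based sampler defined in app.py.

```python
# Hypothetical local smoke test (not part of this commit): runs one query
# through the same Tempo calls that app.py uses, without the Gradio UI.
import numpy as np
import torch
from decord import VideoReader, cpu
from huggingface_hub import snapshot_download

from tempo.builder import load_pretrained_model
from tempo.constants import DEFAULT_IMAGE_TOKEN, IMAGE_TOKEN_INDEX
from tempo.conversation import conv_templates
from tempo.mm_datautils import (
    compute_segment_timestamp,
    process_qwen_content,
    tokenizer_image_token,
)

def load_video(path, video_fps=2.0, max_frames=256):
    # Simplified stand-in for app.load_video: uniform sampling at ~2 fps.
    vr = VideoReader(path, ctx=cpu(0))
    total, fps = len(vr), vr.get_avg_fps()
    duration = total / fps
    num = int(np.clip(round(duration * video_fps), 4, max_frames))
    idx = np.round(np.linspace(0, total - 1, num)).astype(int)
    frames = vr.get_batch(idx).asnumpy()
    return frames, len(frames) / duration

model_path = snapshot_download("Vision-CAIR/Tempo-6B")
tokenizer, model, image_processor = load_pretrained_model(
    model_path, device_map="cuda", use_flash_attn=True
)
model.eval().to(torch.bfloat16)

frames, real_fps = load_video("examples/demo.mp4")
query = "Describe the video in detail."
frame_windows, frame_stride = 8, 8
vlm_inputs = process_qwen_content(
    frames, "video", query, image_processor[0], real_fps, frame_windows, frame_stride, is_eval=True
)
vlm_inputs = {k: v.cuda() for k, v in vlm_inputs.items()}
seg_timestamps = compute_segment_timestamp(
    len(vlm_inputs["video_grid_thw"]), tokenizer, real_fps, frame_stride, frame_windows
)

conv = conv_templates["qwen"].copy()
conv.append_message(conv.roles[0], DEFAULT_IMAGE_TOKEN + "\n" + query)
conv.append_message(conv.roles[1], None)
input_ids = tokenizer_image_token(
    conv.get_prompt(), tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt"
).unsqueeze(0).cuda()

with torch.inference_mode():
    output_ids = model.generate(
        input_ids, vlm_inputs=vlm_inputs, seg_timestamps=seg_timestamps,
        max_new_tokens=256, use_cache=True,
    )
if isinstance(output_ids, tuple):
    output_ids = output_ids[0]
print(tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0].strip())
```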
examples/cover_videomme_FjS2LzrHEO8.png ADDED

Git LFS Details

  • SHA256: 4dd8043baab3af724d2115de6a90d8015b356a545a617490091b7f98592662f4
  • Pointer size: 132 Bytes
  • Size of remote file: 4.01 MB
examples/cover_videomme_fFjv93ACGo8.png ADDED

Git LFS Details

  • SHA256: 43d8ae2c4416c09fdda0b305ad3a0264d5c61308d6740067818078e8a64ef79d
  • Pointer size: 133 Bytes
  • Size of remote file: 17.4 MB
examples/demo.mp4 ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ec316c8e5fe7f2a62137060c0a35123de75802522816c74478c69dbc329d45da
3
+ size 85447789
examples/demo_cases.json ADDED
@@ -0,0 +1,40 @@
1
+ [
2
+ {
3
+ "video_path": "./examples/hsr_helloworld.mp4",
4
+ "query": "Task: Please examine the provided media and answer the following three questions regarding the specific puppy in the scene:\nQ1: What is the primary fur color of the puppy positioned on the swing?\nQ2: Specify the exact time interval (in seconds, e.g., XX-XXs) during which the puppy is seen sitting on the swing.\nQ3: Provide a brief description of the puppy's appearance and its surroundings."
5
+ },
6
+ {
7
+ "video_path": "./examples/hsr_helloworld.mp4",
8
+ "query": "Task: Please analyze the provided video and answer the following 7 questions precisely.\nQ1: How many performers are visible on the stage?\nQ2: Describe the architectural elements in the background. What historical civilization do they remind you of?\nQ3: What is happening in the night sky above the performers, and what does this suggest about the event?\nQ4: List the hair colors of the performers in order from left to right.\nQ5: Identify the specific musical instrument being played by the performer located on the far left of the stage.\nQ6: What is the specific time interval (in seconds, e.g., XX-XXs) during which this fireworks performance scene occurs in the video?\nQ7: Look at the audience in the foreground. How does their silhouette-like depiction affect the viewer's perspective of the stage?"
9
+ },
10
+ {
11
+ "video_path": "./examples/honkai3_becauseofyou.mp4",
12
+ "query": "What text appears in the center of the video behind a sea of pink flowers?"
13
+ },
14
+ {
15
+ "video_path": "./examples/honkai3_becauseofyou.mp4",
16
+ "query": "Describe the video in detail."
17
+ },
18
+ {
19
+ "video_path": "./examples/videomme_fFjv93ACGo8.mp4",
20
+ "query": "What colors are the clothes worn by the two announcers in the studio?"
21
+ },
22
+ {
23
+ "video_path": "./examples/videomme_FjS2LzrHEO8.mp4",
24
+ "query": "What was the purpose of using a hammer to hit the car in the video?\nA. To show the hammer works well.\nB. To show the solidity of the car.\nC. To warn people not to hit cars with hammers.\nD. To illustrate that a hammer is harder than a bullet."
25
+ },
26
+ {
27
+ "video_path": "./examples/videomme_FsLaTZmP6Uw.mp4",
28
+ "query": "Which year was the game held?"
29
+ },
30
+ {
31
+ "video_path": "./examples/videomme_Sp2nxlrQ89w.mp4",
32
+ "query": "In line with the video evidence, why does the orange stickman want to destroy the Minecraft world?\nA. He wants to save his son.\nB. He is too sad.\nC. He loses his son.\nD. He does like the world."
33
+ },
34
+ {
35
+ "video_path": "./examples/lvbench_gXnhqF0TqqI.mp4",
36
+ "query": "Where are the woman and children when they first appear in the video?"
37
+ }
38
+
39
+
40
+ ]
examples/description_honkai3_becauseofyou.png ADDED

Git LFS Details

  • SHA256: 29dff9fea96a3f57f539da37c40aba481adee56082bd2ab0cc695f319ade3a76
  • Pointer size: 132 Bytes
  • Size of remote file: 5.9 MB
examples/honkai3_becauseofyou.mp4 ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:12f943f66683a6a5a49888d4651dd868d75ce4be4fc35ea5af0853b5921de62f
3
+ size 43919755
examples/hsr_helloworld.mp4 ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:fd2fb2c3a84719bb586925e68acb70de282020268a42bfef53635b88c8e6afcd
3
+ size 17860778
examples/lvbench_gXnhqF0TqqI.mp4 ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7135161f9b45bf968ec320cc27b3c381593f47dbd8f17013a11a59da50980315
3
+ size 276007868
examples/meme_hsr_helloworld.png ADDED

Git LFS Details

  • SHA256: 4adcaa6d15600f1cb128e84c3524fff4ab333d5239f4a38c2a576663e2472a82
  • Pointer size: 132 Bytes
  • Size of remote file: 8.77 MB
examples/ocr_honkai3_becauseofyou.png ADDED

Git LFS Details

  • SHA256: c7c6b756d1ec0b6cf2797a0b229c99c19feb436b0cf0e27378f84baf7853cdca
  • Pointer size: 131 Bytes
  • Size of remote file: 456 kB
examples/performance_hsr_helloworld.png ADDED

Git LFS Details

  • SHA256: 42cbe2fd725ba17dc380748680b706cb0d070716c165953e89848e0d540d999f
  • Pointer size: 133 Bytes
  • Size of remote file: 12.4 MB
examples/tempo.png ADDED

Git LFS Details

  • SHA256: da600c7bca0b107f1d7696cb2b8876d6ecfdda8742bf966751e99035c1cfa0fa
  • Pointer size: 131 Bytes
  • Size of remote file: 863 kB
examples/tempo.svg ADDED
examples/videomme_FjS2LzrHEO8.mp4 ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:43b9cb496d7e3c56b9de35e777df206f84c99a907bdbd569a788ba098ffd46e8
3
+ size 6328020
examples/videomme_FsLaTZmP6Uw.mp4 ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c68ee0eab9c6ef905de336222e127e7b393e0a5d862c2cdc5997f50b3b130d48
3
+ size 8838644
examples/videomme_Sp2nxlrQ89w.mp4 ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:67436aa59d44434e2b115bc5d6e07ea7e06a6de1bdfaea7c91ca320fbdd787b3
3
+ size 136659511
examples/videomme_fFjv93ACGo8.mp4 ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ba82f02cf6acc6a25c6efc4783fc2252c65afbedb8dbc4c7148af08834fdf999
3
+ size 17126035
packages.txt ADDED
@@ -0,0 +1 @@
1
+ ffmpeg
requirements.txt ADDED
@@ -0,0 +1,11 @@
1
+ spaces
2
+ qwen-vl-utils==0.0.14
3
+ transformers==4.57.1
4
+ accelerate==1.13.0
5
+ gradio
6
+ matplotlib
7
+ scipy
8
+ huggingface_hub
9
+ decord
10
+ av
11
+ flash-attn==2.7.4.post1
tempo/__init__.py ADDED
@@ -0,0 +1,6 @@
1
+
2
+ from tempo.language_model.modeling_tempo_qwen import TempoConfig, TempoQwenForCausalLM
3
+ __all__ = [
4
+ "TempoConfig",
5
+ "TempoQwenForCausalLM",
6
+ ]
tempo/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (305 Bytes).
 
tempo/__pycache__/builder.cpython-312.pyc ADDED
Binary file (2.14 kB).
 
tempo/__pycache__/constants.cpython-312.pyc ADDED
Binary file (556 Bytes).
 
tempo/__pycache__/conversation.cpython-312.pyc ADDED
Binary file (20.3 kB).
 
tempo/__pycache__/mm_datautils.cpython-312.pyc ADDED
Binary file (57.1 kB).
 
tempo/__pycache__/mm_utils.cpython-312.pyc ADDED
Binary file (3.07 kB).
 
tempo/__pycache__/tempo_arch.cpython-312.pyc ADDED
Binary file (17.5 kB).
 
tempo/__pycache__/vlm_multimodal_processor.cpython-312.pyc ADDED
Binary file (14.1 kB).
 
tempo/builder.py ADDED
@@ -0,0 +1,62 @@
1
+ # Copyright 2023 Haotian Liu
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import torch
16
+ from tempo.constants import (
17
+ DEFAULT_IM_END_TOKEN,
18
+ DEFAULT_IM_START_TOKEN,
19
+ DEFAULT_IMAGE_PATCH_TOKEN,
20
+ )
21
+ from transformers import AutoTokenizer
22
+
23
+ from tempo.language_model.modeling_tempo_qwen import TempoQwenForCausalLM
24
+
25
+ def load_pretrained_model(
26
+ model_path,
27
+ device_map="auto",
28
+ device="cuda",
29
+ use_flash_attn=False,
30
+ **kwargs,
31
+ ):
32
+ kwargs = {"device_map": device_map, **kwargs}
33
+
34
+ if device != "cuda":
35
+ kwargs["device_map"] = {"": device}
36
+
37
+ kwargs["dtype"] = torch.float16
38
+ if use_flash_attn:
39
+ kwargs["attn_implementation"] = "flash_attention_2"
40
+
41
+ tokenizer = AutoTokenizer.from_pretrained(model_path)
42
+ model = TempoQwenForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs)
43
+ add_tokens_flag = False
44
+ if getattr(model.config, "mm_use_im_patch_token", False):
45
+ add_tokens_flag = True
46
+ tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
47
+ if getattr(model.config, "mm_use_im_start_end", False):
48
+ add_tokens_flag = True
49
+ tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True)
50
+ if add_tokens_flag:
51
+ model.resize_token_embeddings(len(tokenizer))
52
+
53
+ vision_tower_aux_list = model.get_vision_tower_aux_list()
54
+ for vision_tower_aux in vision_tower_aux_list:
55
+ if not vision_tower_aux.is_loaded:
56
+ vision_tower_aux.load_model(device_map=device_map)
57
+ vision_tower_aux.to(device=device, dtype=torch.float16)
58
+
59
+ image_processor = None
60
+ image_processor = [vision_tower_aux.image_processor for vision_tower_aux in vision_tower_aux_list]
61
+
62
+ return tokenizer, model, image_processor
tempo/constants.py ADDED
@@ -0,0 +1,13 @@
1
+ CONTROLLER_HEART_BEAT_EXPIRATION = 30
2
+ WORKER_HEART_BEAT_INTERVAL = 15
3
+
4
+ LOGDIR = "."
5
+
6
+ # Model Constants
7
+ IGNORE_INDEX = -100
8
+ IMAGE_TOKEN_INDEX = -200
9
+ DEFAULT_IMAGE_TOKEN = "<image>"
10
+ DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>"
11
+ DEFAULT_IM_START_TOKEN = "<im_start>"
12
+ DEFAULT_IM_END_TOKEN = "<im_end>"
13
+ IMAGE_PLACEHOLDER = "<image-placeholder>"
tempo/conversation.py ADDED
@@ -0,0 +1,545 @@
1
+ import base64
2
+ import dataclasses
3
+ from io import BytesIO
4
+ from enum import auto, Enum
5
+ from typing import Any, Union
6
+
7
+ from PIL import Image
8
+ from transformers import AutoTokenizer
9
+
10
+
11
+ class SeparatorStyle(Enum):
12
+ """Different separator style."""
13
+
14
+ SINGLE = auto()
15
+ TWO = auto()
16
+ MPT = auto()
17
+ PLAIN = auto()
18
+ LLAMA_2 = auto()
19
+ LLAMA_3 = auto()
20
+ LLAMA_3_1 = auto()
21
+ LLAMA_3_2 = auto()
22
+ QWEN = auto()
23
+ CHATML = auto()
24
+
25
+
26
+ @dataclasses.dataclass
27
+ class Conversation:
28
+ """A class that keeps all conversation history."""
29
+
30
+ system: str
31
+ roles: list[str]
32
+ messages: list[list[str]]
33
+ offset: int
34
+ sep_style: SeparatorStyle = SeparatorStyle.SINGLE
35
+ sep: str = "###"
36
+ sep2: str = None
37
+ version: str = "Unknown"
38
+
39
+ tokenizer: Any = None
40
+ # Stop criteria (the default one is EOS token)
41
+ stop_str: Union[str, list[str]] = None
42
+ # Stops generation if meeting any token in this list
43
+ stop_token_ids: list[int] = None
44
+
45
+ skip_next: bool = False
46
+
47
+ def get_prompt(self):
48
+ messages = self.messages
49
+ if len(messages) > 0 and type(messages[0][1]) is tuple:
50
+ messages = self.messages.copy()
51
+ init_role, init_msg = messages[0].copy()
52
+ init_msg = init_msg[0].replace("<image>", "").strip()
53
+ if "mmtag" in self.version:
54
+ messages[0] = (init_role, init_msg)
55
+ messages.insert(0, (self.roles[0], "<Image><image></Image>"))
56
+ messages.insert(1, (self.roles[1], "Received."))
57
+ else:
58
+ messages[0] = (init_role, "<image>\n" + init_msg)
59
+
60
+ if self.sep_style == SeparatorStyle.SINGLE:
61
+ ret = self.system + self.sep
62
+ for role, message in messages:
63
+ if message:
64
+ if type(message) is tuple:
65
+ message, _, _ = message
66
+ ret += role + ": " + message + self.sep
67
+ else:
68
+ ret += role + ":"
69
+
70
+ elif self.sep_style == SeparatorStyle.TWO:
71
+ seps = [self.sep, self.sep2]
72
+ ret = self.system + seps[0]
73
+ for i, (role, message) in enumerate(messages):
74
+ if message:
75
+ if type(message) is tuple:
76
+ message, _, _ = message
77
+ ret += role + ": " + message + seps[i % 2]
78
+ else:
79
+ ret += role + ":"
80
+
81
+ elif self.sep_style == SeparatorStyle.CHATML:
82
+ ret = "" if self.system == "" else self.system + self.sep + "\n"
83
+ for role, message in messages:
84
+ if message:
85
+ if type(message) is tuple:
86
+ message, images, _ = message
87
+ message = "<image>" * len(images) + message
88
+ ret += role + "\n" + message + self.sep + "\n"
89
+ else:
90
+ ret += role + "\n"
91
+ return ret
92
+
93
+ elif self.sep_style == SeparatorStyle.MPT:
94
+ ret = self.system + self.sep
95
+ for role, message in messages:
96
+ if message:
97
+ if type(message) is tuple:
98
+ message, _, _ = message
99
+ ret += role + message + self.sep
100
+ else:
101
+ ret += role
102
+
103
+ elif self.sep_style == SeparatorStyle.LLAMA_2:
104
+ wrap_sys = lambda msg: (
105
+ f"<<SYS>>\n{msg}\n<</SYS>>\n\n" if len(msg) > 0 else msg
106
+ )
107
+ wrap_inst = lambda msg: f"[INST] {msg} [/INST]"
108
+ ret = ""
109
+
110
+ for i, (role, message) in enumerate(messages):
111
+ if i == 0:
112
+ assert message, "first message should not be none"
113
+ assert role == self.roles[0], "first message should come from user"
114
+ if message:
115
+ if type(message) is tuple:
116
+ message, _, _ = message
117
+ if i == 0:
118
+ message = wrap_sys(self.system) + message
119
+ if i % 2 == 0:
120
+ message = wrap_inst(message)
121
+ ret += self.sep + message
122
+ else:
123
+ ret += " " + message + " " + self.sep2
124
+ else:
125
+ ret += ""
126
+ ret = ret.lstrip(self.sep)
127
+
128
+ elif self.sep_style == SeparatorStyle.LLAMA_3:
129
+ if self.tokenizer is None:
130
+ self.tokenizer = AutoTokenizer.from_pretrained("//path/to/llama3/tokenizer")
131
+ chat_template_messages = [{"role": "system", "content": self.system}]
132
+ for role, message in messages:
133
+ if message:
134
+ if type(message) is tuple:
135
+ message, images = message
136
+ message = "<image>" * len(images) + message
137
+ chat_template_messages.append({"role": role, "content": message})
138
+
139
+ return self.tokenizer.apply_chat_template(
140
+ chat_template_messages, tokenize=False, add_generation_prompt=True
141
+ )
142
+
143
+ elif self.sep_style == SeparatorStyle.LLAMA_3_1:
144
+ if self.tokenizer is None:
145
+ self.tokenizer = AutoTokenizer.from_pretrained("//path/to/llama3.1/tokenizer")
146
+ chat_template_messages = [{"role": "system", "content": self.system}]
147
+ for role, message in messages:
148
+ if message:
149
+ if type(message) is tuple:
150
+ message, images = message
151
+ message = "<image>" * len(images) + message
152
+ chat_template_messages.append({"role": role, "content": message})
153
+
154
+ return self.tokenizer.apply_chat_template(
155
+ chat_template_messages, tokenize=False, add_generation_prompt=False
156
+ )
157
+
158
+ elif self.sep_style == SeparatorStyle.LLAMA_3_2:
159
+ wrap_sys = lambda msg: f"<|begin_of_text|><|start_header_id|>system<|end_header_id|>{msg}<|eot_id|>" if len(msg) > 0 else msg
160
+ wrap_inst_user = lambda msg: f"<|start_header_id|>user<|end_header_id|>{msg}<|eot_id|>"
161
+ wrap_inst_assistant = lambda msg: f"<|start_header_id|>assistant<|end_header_id|>{msg}<|eot_id|>"
162
+ ret = ""
163
+
164
+ for i, (role, message) in enumerate(messages):
165
+ if i == 0:
166
+ assert message, "first message should not be none"
167
+ assert role == self.roles[0], "first message should come from user"
168
+ if message:
169
+ if type(message) is tuple:
170
+ message, _, _ = message
171
+ if i == 0:
172
+ ret += wrap_sys(self.system)
173
+
174
+ if i % 2 == 0:
175
+ message = wrap_inst_user(message)
176
+ ret += message
177
+ else:
178
+ message = wrap_inst_assistant(message)
179
+ ret += message
180
+ else:
181
+ ret += ""
182
+ ret += "<|start_header_id|>assistant<|end_header_id|>"
183
+
184
+ elif self.sep_style == SeparatorStyle.PLAIN:
185
+ seps = [self.sep, self.sep2]
186
+ ret = self.system
187
+ for i, (role, message) in enumerate(messages):
188
+ if message:
189
+ if type(message) is tuple:
190
+ message, _, _ = message
191
+ ret += message + seps[i % 2]
192
+ else:
193
+ ret += ""
194
+ else:
195
+ raise ValueError(f"Invalid style: {self.sep_style}")
196
+
197
+ return ret
198
+
199
+ def append_message(self, role, message):
200
+ self.messages.append([role, message])
201
+
202
+ def process_image(
203
+ self,
204
+ image,
205
+ image_process_mode,
206
+ return_pil=False,
207
+ image_format="PNG",
208
+ max_len=1344,
209
+ min_len=672,
210
+ ):
211
+ if image_process_mode == "Pad":
212
+
213
+ def expand2square(pil_img, background_color=(122, 116, 104)):
214
+ width, height = pil_img.size
215
+ if width == height:
216
+ return pil_img
217
+ elif width > height:
218
+ result = Image.new(pil_img.mode, (width, width), background_color)
219
+ result.paste(pil_img, (0, (width - height) // 2))
220
+ return result
221
+ else:
222
+ result = Image.new(pil_img.mode, (height, height), background_color)
223
+ result.paste(pil_img, ((height - width) // 2, 0))
224
+ return result
225
+
226
+ image = expand2square(image)
227
+ elif image_process_mode in ["Default", "Crop"]:
228
+ pass
229
+ elif image_process_mode == "Resize":
230
+ image = image.resize((336, 336))
231
+ else:
232
+ raise ValueError(f"Invalid image_process_mode: {image_process_mode}")
233
+ if max(image.size) > max_len:
234
+ max_hw, min_hw = max(image.size), min(image.size)
235
+ aspect_ratio = max_hw / min_hw
236
+ shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw))
237
+ longest_edge = int(shortest_edge * aspect_ratio)
238
+ W, H = image.size
239
+ if H > W:
240
+ H, W = longest_edge, shortest_edge
241
+ else:
242
+ H, W = shortest_edge, longest_edge
243
+ image = image.resize((W, H))
244
+ if return_pil:
245
+ return image
246
+ else:
247
+ buffered = BytesIO()
248
+ image.save(buffered, format=image_format)
249
+ img_b64_str = base64.b64encode(buffered.getvalue()).decode()
250
+ return img_b64_str
251
+
252
+ def get_images(self, return_pil=False):
253
+ images = []
254
+ for i, (role, msg) in enumerate(self.messages[self.offset :]):
255
+ if i % 2 == 0:
256
+ if type(msg) is tuple:
257
+ msg, image, image_process_mode = msg
258
+ image = self.process_image(
259
+ image, image_process_mode, return_pil=return_pil
260
+ )
261
+ images.append(image)
262
+ return images
263
+
264
+ def to_gradio_chatbot(self):
265
+ ret = []
266
+ for i, (role, msg) in enumerate(self.messages[self.offset :]):
267
+ if i % 2 == 0:
268
+ if type(msg) is tuple:
269
+ msg, image, image_process_mode = msg
270
+ img_b64_str = self.process_image(
271
+ image, "Default", return_pil=False, image_format="JPEG"
272
+ )
273
+ img_str = f'<img src="data:image/jpeg;base64,{img_b64_str}" alt="user upload image" />'
274
+ msg = img_str + msg.replace("<image>", "").strip()
275
+ ret.append([msg, None])
276
+ else:
277
+ ret.append([msg, None])
278
+ else:
279
+ ret[-1][-1] = msg
280
+ return ret
281
+
282
+ def copy(self):
283
+ return Conversation(
284
+ system=self.system,
285
+ roles=self.roles,
286
+ messages=[[x, y] for x, y in self.messages],
287
+ offset=self.offset,
288
+ sep_style=self.sep_style,
289
+ sep=self.sep,
290
+ sep2=self.sep2,
291
+ version=self.version,
292
+ )
293
+
294
+ def dict(self):
295
+ if len(self.get_images()) > 0:
296
+ return {
297
+ "system": self.system,
298
+ "roles": self.roles,
299
+ "messages": [
300
+ [x, y[0] if type(y) is tuple else y] for x, y in self.messages
301
+ ],
302
+ "offset": self.offset,
303
+ "sep": self.sep,
304
+ "sep2": self.sep2,
305
+ }
306
+ return {
307
+ "system": self.system,
308
+ "roles": self.roles,
309
+ "messages": self.messages,
310
+ "offset": self.offset,
311
+ "sep": self.sep,
312
+ "sep2": self.sep2,
313
+ }
314
+
315
+
316
+ conv_vicuna_v0 = Conversation(
317
+ system="A chat between a curious human and an artificial intelligence assistant. "
318
+ "The assistant gives helpful, detailed, and polite answers to the human's questions.",
319
+ roles=("Human", "Assistant"),
320
+ messages=(
321
+ (
322
+ "Human",
323
+ "What are the key differences between renewable and non-renewable energy sources?",
324
+ ),
325
+ (
326
+ "Assistant",
327
+ "Renewable energy sources are those that can be replenished naturally in a relatively "
328
+ "short amount of time, such as solar, wind, hydro, geothermal, and biomass. "
329
+ "Non-renewable energy sources, on the other hand, are finite and will eventually be "
330
+ "depleted, such as coal, oil, and natural gas. Here are some key differences between "
331
+ "renewable and non-renewable energy sources:\n"
332
+ "1. Availability: Renewable energy sources are virtually inexhaustible, while non-renewable "
333
+ "energy sources are finite and will eventually run out.\n"
334
+ "2. Environmental impact: Renewable energy sources have a much lower environmental impact "
335
+ "than non-renewable sources, which can lead to air and water pollution, greenhouse gas emissions, "
336
+ "and other negative effects.\n"
337
+ "3. Cost: Renewable energy sources can be more expensive to initially set up, but they typically "
338
+ "have lower operational costs than non-renewable sources.\n"
339
+ "4. Reliability: Renewable energy sources are often more reliable and can be used in more remote "
340
+ "locations than non-renewable sources.\n"
341
+ "5. Flexibility: Renewable energy sources are often more flexible and can be adapted to different "
342
+ "situations and needs, while non-renewable sources are more rigid and inflexible.\n"
343
+ "6. Sustainability: Renewable energy sources are more sustainable over the long term, while "
344
+ "non-renewable sources are not, and their depletion can lead to economic and social instability.\n",
345
+ ),
346
+ ),
347
+ offset=2,
348
+ sep_style=SeparatorStyle.SINGLE,
349
+ sep="###",
350
+ )
351
+
352
+ conv_vicuna_v1 = Conversation(
353
+ system="A chat between a curious user and an artificial intelligence assistant. "
354
+ "The assistant gives helpful, detailed, and polite answers to the user's questions.",
355
+ roles=("USER", "ASSISTANT"),
356
+ version="v1",
357
+ messages=(),
358
+ offset=0,
359
+ sep_style=SeparatorStyle.TWO,
360
+ sep=" ",
361
+ sep2="</s>",
362
+ )
363
+
364
+ conv_llama_2 = Conversation(
365
+ system="""You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.
366
+
367
+ If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.""",
368
+ roles=("USER", "ASSISTANT"),
369
+ version="llama_v2",
370
+ messages=(),
371
+ offset=0,
372
+ sep_style=SeparatorStyle.LLAMA_2,
373
+ sep="<s>",
374
+ sep2="</s>",
375
+ )
376
+
377
+ conv_llava_llama_2 = Conversation(
378
+ system="You are a helpful language and vision assistant. "
379
+ "You are able to understand the visual content that the user provides, "
380
+ "and assist the user with a variety of tasks using natural language.",
381
+ roles=("USER", "ASSISTANT"),
382
+ version="llama_v2",
383
+ messages=(),
384
+ offset=0,
385
+ sep_style=SeparatorStyle.LLAMA_2,
386
+ sep="<s>",
387
+ sep2="</s>",
388
+ )
389
+
390
+ conv_mpt = Conversation(
391
+ system="""<|im_start|>system
392
+ A conversation between a user and an LLM-based AI assistant. The assistant gives helpful and honest answers.""",
393
+ roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
394
+ version="mpt",
395
+ messages=(),
396
+ offset=0,
397
+ sep_style=SeparatorStyle.MPT,
398
+ sep="<|im_end|>",
399
+ )
400
+
401
+ conv_llava_plain = Conversation(
402
+ system="",
403
+ roles=("", ""),
404
+ messages=(),
405
+ offset=0,
406
+ sep_style=SeparatorStyle.PLAIN,
407
+ sep="\n",
408
+ version="plain",
409
+ )
410
+
411
+ conv_llava_v0 = Conversation(
412
+ system="A chat between a curious human and an artificial intelligence assistant. "
413
+ "The assistant gives helpful, detailed, and polite answers to the human's questions.",
414
+ roles=("Human", "Assistant"),
415
+ messages=(),
416
+ offset=0,
417
+ sep_style=SeparatorStyle.SINGLE,
418
+ sep="###",
419
+ )
420
+
421
+ conv_llava_v0_mmtag = Conversation(
422
+ system="A chat between a curious user and an artificial intelligence assistant. "
423
+ "The assistant is able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language."
424
+ "The visual content will be provided with the following format: <Image>visual content</Image>.",
425
+ roles=("Human", "Assistant"),
426
+ messages=(),
427
+ offset=0,
428
+ sep_style=SeparatorStyle.SINGLE,
429
+ sep="###",
430
+ version="v0_mmtag",
431
+ )
432
+
433
+ conv_llava_v1 = Conversation(
434
+ system="A chat between a curious human and an artificial intelligence assistant. "
435
+ "The assistant gives helpful, detailed, and polite answers to the human's questions.",
436
+ roles=("USER", "ASSISTANT"),
437
+ version="v1",
438
+ messages=(),
439
+ offset=0,
440
+ sep_style=SeparatorStyle.TWO,
441
+ sep=" ",
442
+ sep2="</s>",
443
+ )
444
+
445
+ conv_llava_v1_mmtag = Conversation(
446
+ system="A chat between a curious user and an artificial intelligence assistant. "
447
+ "The assistant is able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language."
448
+ "The visual content will be provided with the following format: <Image>visual content</Image>.",
449
+ roles=("USER", "ASSISTANT"),
450
+ messages=(),
451
+ offset=0,
452
+ sep_style=SeparatorStyle.TWO,
453
+ sep=" ",
454
+ sep2="</s>",
455
+ version="v1_mmtag",
456
+ )
457
+
458
+ conv_mistral_instruct = Conversation(
459
+ system="",
460
+ roles=("USER", "ASSISTANT"),
461
+ version="llama_v2",
462
+ messages=(),
463
+ offset=0,
464
+ sep_style=SeparatorStyle.LLAMA_2,
465
+ sep="",
466
+ sep2="</s>",
467
+ )
468
+
469
+ conv_chatml_direct = Conversation(
470
+ system="""<|im_start|>system
471
+ Answer the questions.""",
472
+ roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
473
+ version="mpt",
474
+ messages=(),
475
+ offset=0,
476
+ sep_style=SeparatorStyle.MPT,
477
+ sep="<|im_end|>",
478
+ )
479
+
480
+ conv_llama3 = Conversation(
481
+ system="""You are a helpful assistant.""",
482
+ roles=("user", "assistant"),
483
+ version="llama3",
484
+ messages=(),
485
+ offset=0,
486
+ sep_style=SeparatorStyle.LLAMA_3,
487
+ sep="<|eot_id|>",
488
+ )
489
+ conv_llama3_2 = Conversation(
490
+ system="""You are a helpful assistant.""",
491
+ roles=("user", "assistant"),
492
+ version="llama3_2",
493
+ messages=(),
494
+ offset=0,
495
+ sep_style=SeparatorStyle.LLAMA_3_2,
496
+ sep="<|eot_id|>",
497
+ )
498
+
499
+ conv_phi3_instruct = Conversation(
500
+ system="""<|system|>\nYou are a helpful AI assistant.""",
501
+ roles=("\n<|user|>\n", "\n<|assistant|>\n"),
502
+ version="phi3",
503
+ messages=(),
504
+ offset=0,
505
+ sep_style=SeparatorStyle.MPT,
506
+ sep="<|end|>",
507
+ )
508
+
509
+ conv_qwen = Conversation(
510
+ system="""<|im_start|>system
511
+ You are a helpful assistant.""",
512
+ roles=("<|im_start|>user", "<|im_start|>assistant"),
513
+ version="qwen",
514
+ messages=[],
515
+ offset=0,
516
+ sep_style=SeparatorStyle.CHATML,
517
+ sep="<|im_end|>",
518
+ )
519
+
520
+ default_conversation = conv_qwen
521
+ conv_templates = {
522
+ "default": conv_vicuna_v0,
523
+ "v0": conv_vicuna_v0,
524
+ "v1": conv_vicuna_v1,
525
+ "vicuna_v1": conv_vicuna_v1,
526
+ "llama_2": conv_llama_2,
527
+ "mistral_instruct": conv_mistral_instruct,
528
+ "chatml_direct": conv_chatml_direct,
529
+ "mistral_direct": conv_chatml_direct,
530
+ "plain": conv_llava_plain,
531
+ "v0_plain": conv_llava_plain,
532
+ "llava_v0": conv_llava_v0,
533
+ "v0_mmtag": conv_llava_v0_mmtag,
534
+ "llava_v1": conv_llava_v1,
535
+ "v1_mmtag": conv_llava_v1_mmtag,
536
+ "llava_llama_2": conv_llava_llama_2,
537
+ "mpt": conv_mpt,
538
+ "llama3": conv_llama3,
539
+ "llama3_2": conv_llama3_2,
540
+ "phi3": conv_phi3_instruct,
541
+ "qwen": conv_qwen,
542
+ }
543
+
544
+ if __name__ == "__main__":
545
+ print(default_conversation.get_prompt())
tempo/language_model/__pycache__/modeling_tempo_qwen.cpython-312.pyc ADDED
Binary file (7.46 kB).
 
tempo/language_model/modeling_tempo_qwen.py ADDED
@@ -0,0 +1,231 @@
1
+ # Copyright 2023 Haotian Liu
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+
16
+ import torch
17
+ import torch.nn as nn
18
+ from typing import Optional, Union, Callable
19
+ from transformers.utils import logging
20
+ from transformers.cache_utils import Cache
21
+ from transformers import AutoConfig, AutoModelForCausalLM
22
+ from transformers.modeling_outputs import CausalLMOutputWithPast
23
+ from transformers import Qwen3Config, Qwen3ForCausalLM, Qwen3Model
24
+ from transformers.modeling_utils import PreTrainedModel
25
+ from transformers.generation.streamers import BaseStreamer
26
+ from transformers.generation.utils import (
27
+ GenerateOutput,
28
+ GenerationConfig,
29
+ LogitsProcessorList,
30
+ StoppingCriteriaList,
31
+ )
32
+
33
+
34
+ from tempo.tempo_arch import TempoMetaForCausalLM, TempoMetaModel
35
+
36
+ logger = logging.get_logger(__name__)
37
+
38
+ class TempoConfig(Qwen3Config):
39
+ model_type = "tempo_qwen"
40
+ debug = "debug"
41
+
42
+ class TempoQwenModel(TempoMetaModel, Qwen3Model):
43
+ config_class = TempoConfig
44
+
45
+ def __init__(self, config: Qwen3Config):
46
+ super(TempoQwenModel, self).__init__(config)
47
+
48
+ class TempoQwenForCausalLM(Qwen3ForCausalLM, TempoMetaForCausalLM):
49
+ config_class = TempoConfig
50
+
51
+ def __init__(self, config):
52
+ super(Qwen3ForCausalLM, self).__init__(config)
53
+ config.model_type = "tempo_qwen"
54
+ config.rope_scaling = None
55
+
56
+ self.model = TempoQwenModel(config)
57
+ self.vocab_size = config.vocab_size
58
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
59
+
60
+ # Initialize weights and apply final processing
61
+ self.post_init()
62
+
63
+ def get_model(self):
64
+ return self.model
65
+
66
+ def forward(
67
+ self,
68
+ input_ids: Optional[torch.LongTensor] = None,
69
+ attention_mask: Optional[torch.Tensor] = None,
70
+ position_ids: Optional[torch.LongTensor] = None,
71
+ past_key_values: Optional[Cache] = None,
72
+ inputs_embeds: Optional[torch.FloatTensor] = None,
73
+ labels: Optional[torch.LongTensor] = None,
74
+ use_cache: Optional[bool] = None,
75
+ cache_position: Optional[torch.LongTensor] = None,
76
+ logits_to_keep: Union[int, torch.Tensor] = 0,
77
+ images: Optional[torch.FloatTensor] = None,
78
+ image_sizes: Optional[list[list[int]]] = None,
79
+ vlm_inputs: Optional[dict] = None,
80
+ seg_timestamps: Optional[torch.LongTensor] = None,
81
+ batch_split_size: Optional[list[int]] = None,
82
+ **kwargs,
83
+ ) -> CausalLMOutputWithPast:
84
+ if inputs_embeds is None:
85
+ (
86
+ input_ids,
87
+ position_ids,
88
+ attention_mask,
89
+ past_key_values,
90
+ inputs_embeds,
91
+ labels,
92
+ ) = self.prepare_inputs_labels_for_multimodal(
93
+ input_ids,
94
+ position_ids,
95
+ attention_mask,
96
+ past_key_values,
97
+ labels,
98
+ images,
99
+ image_sizes,
100
+ vlm_inputs,
101
+ seg_timestamps,
102
+ batch_split_size,
103
+ )
104
+
105
+ return super().forward(
106
+ input_ids=input_ids,
107
+ attention_mask=attention_mask,
108
+ position_ids=position_ids,
109
+ past_key_values=past_key_values,
110
+ inputs_embeds=inputs_embeds,
111
+ labels=labels,
112
+ use_cache=use_cache,
113
+ cache_position=cache_position,
114
+ logits_to_keep=logits_to_keep,
115
+ **kwargs,
116
+ )
117
+
118
+ @torch.no_grad()
119
+ def generate(
120
+ self,
121
+ inputs: Optional[torch.Tensor] = None,
122
+ generation_config: Optional[GenerationConfig] = None,
123
+ logits_processor: Optional[LogitsProcessorList] = None,
124
+ stopping_criteria: Optional[StoppingCriteriaList] = None,
125
+ prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], list[int]]] = None,
126
+ synced_gpus: Optional[bool] = None,
127
+ assistant_model: Optional["PreTrainedModel"] = None,
128
+ streamer: Optional["BaseStreamer"] = None,
129
+ negative_prompt_ids: Optional[torch.Tensor] = None,
130
+ negative_prompt_attention_mask: Optional[torch.Tensor] = None,
131
+ use_model_defaults: Optional[bool] = None,
132
+ custom_generate: Optional[Union[str, Callable]] = None,
133
+ images: Optional[torch.Tensor] = None,
134
+ image_sizes: Optional[torch.Tensor] = None,
135
+ **kwargs,
136
+ ) -> Union[GenerateOutput, torch.LongTensor]:
137
+ position_ids = kwargs.pop("position_ids", None)
138
+ attention_mask = kwargs.pop("attention_mask", None)
139
+ vlm_inputs = kwargs.pop("vlm_inputs", None)
140
+ seg_timestamps = kwargs.pop("seg_timestamps", None)
141
+ relevance = kwargs.pop("relevance", None) # when using external retriever
142
+
143
+ if "inputs_embeds" in kwargs:
144
+ raise NotImplementedError("`inputs_embeds` is not supported")
145
+
146
+ if vlm_inputs is not None:
147
+ (
148
+ inputs,
149
+ position_ids,
150
+ attention_mask,
151
+ _,
152
+ inputs_embeds,
153
+ _,
154
+ ) = self.prepare_inputs_labels_for_multimodal(
155
+ inputs,
156
+ position_ids,
157
+ attention_mask,
158
+ None,
159
+ None,
160
+ images,
161
+ image_sizes=image_sizes,
162
+ vlm_inputs=vlm_inputs,
163
+ seg_timestamps=seg_timestamps,
164
+ relevance=relevance,
165
+ )
166
+ elif images is not None:
167
+ (
168
+ inputs,
169
+ position_ids,
170
+ attention_mask,
171
+ _,
172
+ inputs_embeds,
173
+ _,
174
+ ) = self.prepare_inputs_labels_for_multimodal(
175
+ inputs,
176
+ position_ids,
177
+ attention_mask,
178
+ None,
179
+ None,
180
+ images,
181
+ image_sizes=image_sizes,
182
+ )
183
+ else:
184
+ inputs_embeds = self.get_model().embed_tokens(inputs)
185
+
186
+ # if attention_mask is None:
187
+ # # avoid warning
188
+ # attention_mask = torch.ones(
189
+ # inputs_embeds.shape[:2],
190
+ # dtype=torch.long,
191
+ # device=inputs_embeds.device
192
+ # )
193
+
194
+ return super().generate(
195
+ inputs=None,
196
+ generation_config=generation_config,
197
+ logits_processor=logits_processor,
198
+ stopping_criteria=stopping_criteria,
199
+ prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
200
+ synced_gpus=synced_gpus,
201
+ assistant_model=assistant_model,
202
+ streamer=streamer,
203
+ negative_prompt_ids=negative_prompt_ids,
204
+ negative_prompt_attention_mask=negative_prompt_attention_mask,
205
+ use_model_defaults=use_model_defaults,
206
+ custom_generate=custom_generate,
207
+ position_ids=position_ids,
208
+ attention_mask=attention_mask,
209
+ inputs_embeds=inputs_embeds,
210
+ **kwargs,
211
+ )
212
+
213
+ def prepare_inputs_for_generation(
214
+ self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs
215
+ ):
216
+ images = kwargs.pop("images", None)
217
+ image_sizes = kwargs.pop("image_sizes", None)
218
+ inputs = super().prepare_inputs_for_generation(
219
+ input_ids,
220
+ past_key_values=past_key_values,
221
+ inputs_embeds=inputs_embeds,
222
+ **kwargs,
223
+ )
224
+ if images is not None:
225
+ inputs["images"] = images
226
+ if image_sizes is not None:
227
+ inputs["image_sizes"] = image_sizes
228
+ return inputs
229
+
230
+ AutoConfig.register("tempo_qwen", TempoConfig)
231
+ AutoModelForCausalLM.register(TempoConfig, TempoQwenForCausalLM)
tempo/mm_datautils.py ADDED
@@ -0,0 +1,1607 @@
1
+ import os
2
+ import copy
3
+ from typing import List
4
+ from packaging import version
5
+ from collections.abc import Sequence
6
+
7
+ import torch
8
+ import tokenizers
9
+ import transformers
10
+ import numpy as np
11
+ from PIL import Image
12
+ from transformers import StoppingCriteria
13
+ from qwen_vl_utils import process_vision_info
14
+ from torch import distributed as dist
15
+ from torch.distributed.fsdp import (
16
+ FullStateDictConfig,
17
+ FullyShardedDataParallel as FSDP,
18
+ StateDictType,
19
+ )
20
+
21
+ from tempo import conversation as conversation_lib
22
+ from tempo.constants import (
23
+ DEFAULT_IM_END_TOKEN,
24
+ DEFAULT_IM_START_TOKEN,
25
+ DEFAULT_IMAGE_TOKEN,
26
+ IGNORE_INDEX,
27
+ IMAGE_TOKEN_INDEX,
28
+ )
29
+
30
+ IS_TOKENIZER_GREATER_THAN_0_14 = version.parse(tokenizers.__version__) >= version.parse(
31
+ "0.14"
32
+ )
33
+
34
+ class KeywordsStoppingCriteria(StoppingCriteria):
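+ # Stopping criterion that fires once every sequence in the batch has produced one of the given keywords,
+ # checked both on the raw token ids and on the decoded tail of the output.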
35
+ def __init__(self, keywords, tokenizer, input_ids):
36
+ self.keywords = keywords
37
+ self.keyword_ids = []
38
+ self.max_keyword_len = 0
39
+ for keyword in keywords:
40
+ cur_keyword_ids = tokenizer(keyword).input_ids
41
+ if (
42
+ len(cur_keyword_ids) > 1
43
+ and cur_keyword_ids[0] == tokenizer.bos_token_id
44
+ ):
45
+ cur_keyword_ids = cur_keyword_ids[1:]
46
+ if len(cur_keyword_ids) > self.max_keyword_len:
47
+ self.max_keyword_len = len(cur_keyword_ids)
48
+ self.keyword_ids.append(torch.tensor(cur_keyword_ids))
49
+ self.tokenizer = tokenizer
50
+ self.start_len = input_ids.shape[1]
51
+
52
+ def call_for_batch(
53
+ self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs
54
+ ) -> bool:
55
+ offset = min(output_ids.shape[1] - self.start_len, self.max_keyword_len)
56
+ self.keyword_ids = [
57
+ keyword_id.to(output_ids.device) for keyword_id in self.keyword_ids
58
+ ]
59
+ for keyword_id in self.keyword_ids:
60
+ truncated_output_ids = output_ids[0, -keyword_id.shape[0] :]
61
+ if torch.equal(truncated_output_ids, keyword_id):
62
+ return True
63
+ outputs = self.tokenizer.batch_decode(
64
+ output_ids[:, -offset:], skip_special_tokens=True
65
+ )[0]
66
+ for keyword in self.keywords:
67
+ if keyword in outputs:
68
+ return True
69
+ return False
70
+
71
+ def __call__(
72
+ self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs
73
+ ) -> bool:
74
+ outputs = []
75
+ for i in range(output_ids.shape[0]):
76
+ outputs.append(self.call_for_batch(output_ids[i].unsqueeze(0), scores))
77
+ return all(outputs)
78
+
79
+
80
+ def safe_save_model_for_hf_trainer(
81
+ trainer: transformers.Trainer, output_dir: str
82
+ ) -> None:
83
+ """Collects the state dict and dump to disk."""
84
+ global_rank = dist.get_rank()
85
+ save_policy = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
86
+ if len(trainer.args.fsdp) == 0:
87
+ cpu_state_dict = trainer.model.state_dict()
88
+ else:
89
+ with FSDP.state_dict_type(
90
+ trainer.model, StateDictType.FULL_STATE_DICT, save_policy
91
+ ):
92
+ cpu_state_dict = trainer.model.state_dict()
93
+
94
+ for key in cpu_state_dict.keys():
95
+ cpu_state_dict[key] = cpu_state_dict[key].to(torch.bfloat16)
96
+
97
+ if global_rank == 0:
98
+ trainer.model.config.save_pretrained(output_dir)
99
+ current_folder = output_dir.split("/")[-1]
100
+ parent_folder = os.path.dirname(output_dir)
101
+ save_path = os.path.join(output_dir, "pytorch_model.bin")
102
+ if getattr(trainer.args, "tune_mm_mlp_adapter", False) and not getattr(
103
+ trainer.args, "tune_text_decoder", False
104
+ ):
105
+ # Only save Adapter
106
+ keys_to_match = ["mm_projector"]
107
+ if getattr(trainer.args, "use_im_start_end", False):
108
+ keys_to_match.extend(["embed_tokens", "embed_in"])
109
+
110
+ freeze_layer_remove = []
111
+ for key in cpu_state_dict.keys():
112
+ remove = True
113
+ for key_match in keys_to_match:
114
+ if key_match in key:
115
+ remove = False
116
+ break
117
+ if remove:
118
+ freeze_layer_remove.append(key)
119
+ for key in freeze_layer_remove:
120
+ del cpu_state_dict[key]
121
+
122
+ if current_folder.startswith("checkpoint-"):
123
+ mm_projector_folder = os.path.join(parent_folder, "mm_projector")
124
+ os.makedirs(mm_projector_folder, exist_ok=True)
125
+ save_path = os.path.join(mm_projector_folder, f"{current_folder}.bin")
126
+ else:
127
+ save_path = os.path.join(output_dir, "mm_projector.bin")
128
+ torch.save(cpu_state_dict, save_path)
129
+
130
+
131
+ def _tokenize_fn(
132
+ strings: Sequence[str],
133
+ tokenizer: transformers.PreTrainedTokenizer,
134
+ ) -> dict:
135
+ """Tokenize a list of strings."""
136
+ tokenized_list = [
137
+ tokenizer(
138
+ text,
139
+ return_tensors="pt",
140
+ padding="longest",
141
+ max_length=tokenizer.model_max_length,
142
+ truncation=True,
143
+ )
144
+ for text in strings
145
+ ]
146
+ input_ids = labels = [tokenized.input_ids[0] for tokenized in tokenized_list]
147
+ input_ids_lens = labels_lens = [
148
+ tokenized.input_ids.ne(tokenizer.pad_token_id).sum().item()
149
+ for tokenized in tokenized_list
150
+ ]
151
+ return dict(
152
+ input_ids=input_ids,
153
+ labels=labels,
154
+ input_ids_lens=input_ids_lens,
155
+ labels_lens=labels_lens,
156
+ )
157
+
158
+
159
+ def _mask_targets(target, tokenized_lens, speakers) -> None:
160
+ # cur_idx = 0
161
+ cur_idx = tokenized_lens[0]
162
+ tokenized_lens = tokenized_lens[1:]
163
+ target[:cur_idx] = IGNORE_INDEX
164
+ for tokenized_len, speaker in zip(tokenized_lens, speakers):
165
+ if speaker == "human":
166
+ target[cur_idx + 2 : cur_idx + tokenized_len] = IGNORE_INDEX
167
+ cur_idx += tokenized_len
168
+
169
+
170
+ def _add_speaker_and_signal(header, source, get_conversation: bool = True):
171
+ """Add speaker and start/end signal on each round."""
172
+ BEGIN_SIGNAL = "### "
173
+ END_SIGNAL = "\n"
174
+ conversation = header
175
+ for sentence in source:
176
+ from_str = sentence["from"]
177
+ if from_str.lower() == "human":
178
+ from_str = conversation_lib.default_conversation.roles[0]
179
+ elif from_str.lower() == "gpt":
180
+ from_str = conversation_lib.default_conversation.roles[1]
181
+ else:
182
+ from_str = "unknown"
183
+ sentence["value"] = (
184
+ BEGIN_SIGNAL + from_str + ": " + sentence["value"] + END_SIGNAL
185
+ )
186
+ if get_conversation:
187
+ conversation += sentence["value"]
188
+ conversation += BEGIN_SIGNAL
189
+ return conversation
190
+
191
+
192
+ def expand2square(pil_img, background_color):
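+ # Pad the shorter side with background_color so the image becomes square, keeping the original centered.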
193
+ width, height = pil_img.size
194
+ if width == height:
195
+ return pil_img
196
+ elif width > height:
197
+ result = Image.new(pil_img.mode, (width, width), background_color)
198
+ result.paste(pil_img, (0, (width - height) // 2))
199
+ return result
200
+ else:
201
+ result = Image.new(pil_img.mode, (height, height), background_color)
202
+ result.paste(pil_img, ((height - width) // 2, 0))
203
+ return result
204
+
205
+
206
+ def crop2square(pil_img):
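+ # Center-crop the longer side so the image becomes square.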
207
+ width, height = pil_img.size
208
+ if width == height:
209
+ return pil_img
210
+ elif width > height:
211
+ left = (width - height) // 2
212
+ right = left + height
213
+ top = 0
214
+ bottom = height
215
+ return pil_img.crop((left, top, right, bottom))
216
+ else:
217
+ top = (height - width) // 2
218
+ bottom = top + width
219
+ left = 0
220
+ right = width
221
+ return pil_img.crop((left, top, right, bottom))
222
+
223
+
224
+ def perpare_input_for_qwen_input(chunk_dict, pad_token_ids):
225
+ """Currently, only batch size = 1 is supported for evaluation."""
226
+
227
+ qwenvl_input_dict = {}
228
+ has_video = any(["video" in key for key in list(chunk_dict.keys())])
229
+
230
+ qwenvl_input_dict["input_ids"] = torch.nn.utils.rnn.pad_sequence(
231
+ chunk_dict["vlm_input_ids"],
232
+ batch_first=True,
233
+ padding_value=pad_token_ids,
234
+ )
235
+ qwenvl_input_dict["attention_mask"] = torch.nn.utils.rnn.pad_sequence(
236
+ chunk_dict["vlm_attention_mask"],
237
+ batch_first=True,
238
+ padding_value=0,
239
+ )
240
+
241
+ if has_video:
242
+ qwenvl_input_dict["pixel_values_videos"] = torch.cat(chunk_dict["pixel_values_videos"], dim=0)
243
+ qwenvl_input_dict["video_grid_thw"] = torch.cat(chunk_dict["video_grid_thw"], dim=0)
244
+ else:
245
+ qwenvl_input_dict["pixel_values"] = torch.cat(chunk_dict["pixel_values"], dim=0)
246
+ qwenvl_input_dict["image_grid_thw"] = torch.cat(chunk_dict["image_grid_thw"], dim=0)
247
+
248
+ return qwenvl_input_dict
249
+
250
+
251
+ def construct_message(
252
+ content_data,
253
+ data_type,
254
+ query,
255
+ multimodal_processor,
256
+ sample_fps=1,
257
+ return_text=False,
258
+ ):
259
+ # # prompt 0 (used during training)
260
+ # system_message = (
261
+ # "You are a query-conditioned visual compressor. "
262
+ # "Store in the provided memory tokens the minimal visual information needed to answer the Query. "
263
+ # "Ignore irrelevant details."
264
+ # )
265
+ # prompt 1 (used during inference)
266
+ system_message = (
267
+ "You are a query-conditioned visual compressor. "
268
+ "Store in the provided memory tokens the minimal visual information needed to answer the Query. "
269
+ "Ignore irrelevant details. "
270
+ "Now, before compressing, answer exactly 'Yes' or 'No': is this segment relevant to the Query?"
271
+ )
272
+
273
+ user_message = f"\nQuery:\n{query}"
274
+ # assistant_message = "Scanning for target features... The visual confidence representation is:"
275
+ assistant_message = None
276
+
277
+ if data_type == "image":
278
+ messages = [
279
+ {
280
+ "role": "system",
281
+ "content": [{"type": "text", "text": system_message}],
282
+ },
283
+ {
284
+ "role": "user",
285
+ "content": [
286
+ {"type": "image", "image": content_data},
287
+ {"type": "text", "text": user_message},
288
+ ],
289
+ },
290
+ ]
291
+ elif data_type == "video":
292
+ messages = [
293
+ {
294
+ "role": "system",
295
+ "content": [{"type": "text", "text": system_message}],
296
+ },
297
+ {
298
+ "role": "user",
299
+ "content": [
300
+ {
301
+ "type": "video",
302
+ "video": content_data,
303
+ "sample_fps": sample_fps,
304
+ },
305
+ {"type": "text", "text": user_message},
306
+ ],
307
+ },
308
+ ]
309
+ else:
310
+ raise ValueError(f"Unknown data type: {data_type}")
311
+
312
+ if return_text:
313
+ messages = multimodal_processor.apply_chat_template(
314
+ messages,
315
+ tokenize=False,
316
+ add_generation_prompt=True,
317
+ )
318
+
319
+ if assistant_message is not None:
320
+ messages = messages + assistant_message
321
+
322
+ return messages
323
+
324
+
325
+ def video_process_with_frame_idx(
326
+ chunk_frames, multimodal_processor, query, sample_fps=2, frame_offset=0
327
+ ):
328
+ messages = construct_message(
329
+ chunk_frames, "video", query, multimodal_processor, sample_fps
330
+ )
331
+ text = multimodal_processor.apply_chat_template(
332
+ messages, tokenize=False, add_generation_prompt=True
333
+ )
334
+ image_inputs, video_inputs, video_kwargs = process_vision_info(
335
+ [messages],
336
+ return_video_kwargs=True,
337
+ image_patch_size=16,
338
+ return_video_metadata=True,
339
+ )
340
+ if video_inputs is not None:
341
+ video_inputs, video_metadatas = zip(*video_inputs)
342
+ video_inputs, video_metadatas = (
343
+ list(video_inputs),
344
+ list(video_metadatas),
345
+ )
346
+ else:
347
+ video_metadatas = None
348
+
349
+ if video_metadatas is not None:
350
+ video_metadatas[0]["frames_indices"] = [
351
+ f + frame_offset for f in video_metadatas[0]["frames_indices"]
352
+ ]
353
+
354
+ inputs = multimodal_processor(
355
+ text=[text],
356
+ images=image_inputs,
357
+ videos=video_inputs,
358
+ video_metadata=video_metadatas,
359
+ **video_kwargs,
360
+ do_resize=False,
361
+ return_tensors="pt",
362
+ )
363
+
364
+ return inputs
365
+
366
+
367
+ def process_qwen_content(
368
+ content_data,
369
+ data_type,
370
+ sources,
371
+ multimodal_processor,
372
+ real_fps=None,
373
+ frame_windows=8,
374
+ frame_stride=8,
375
+ is_eval=False,
376
+ ):
377
+ """
378
+ content_data:
379
+ - 'image': PIL.Image
380
+ - 'video': List[PIL.Image] or np.ndarray (T, H, W, C)
381
+ data_type: 'text', 'image' or 'video'
382
+ """
383
+
384
+ # query process
385
+ if is_eval:
386
+ # for evaluation, please input only text string
387
+ assert isinstance(sources, str), "During evaluation, sources should be a single query string."
388
+ query = sources
389
+ else:
390
+ conversation = sources[0]
391
+ if isinstance(conversation, list):
392
+ # This is acceptable during training to learn better representations,
393
+ # but cannot be used during inference as it may lead to data leakage.
394
+ # Currently, only single-turn dialogue is supported during inference.
395
+ # Cap the context at the first 16 messages (at most 8 dialogue rounds).
396
+ human_queries = []
397
+ for turn in conversation[:16]:
398
+ if turn.get("from") == "human":
399
+ clean_text = turn["value"].replace("<image>", "").strip()
400
+ if clean_text:
401
+ human_queries.append(clean_text)
402
+
403
+ query = "\n".join(
404
+ [f"Context turn {i + 1}: {q}" for i, q in enumerate(human_queries)]
405
+ )
406
+ else:
407
+ query = "Describe this content."
408
+ if not query.strip():
409
+ query = "Describe this content."
410
+
411
+ chunk_results = []
412
+
413
+ def resize_images(frames, resolution=512):
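+ # Downscale frames whose longer edge exceeds `resolution`, rounding each side to a multiple of 16
+ # (the vision patch size assumed here); smaller frames pass through unchanged.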
414
+ resized_frames = []
415
+ for f in frames:
416
+ w, h = f.size
417
+ max_edge = max(w, h)
418
+ if max_edge > resolution:
419
+ ratio = resolution / max_edge
420
+ new_w = int(round((w * ratio) / 16) * 16)
421
+ new_h = int(round((h * ratio) / 16) * 16)
422
+ new_w = max(16, new_w)
423
+ new_h = max(16, new_h)
424
+ f = f.resize((new_w, new_h), resample=Image.Resampling.BICUBIC)
425
+ resized_frames.append(f)
426
+ return resized_frames
427
+
428
+ # === Text ===
429
+ if data_type == "text":
430
+ content_data = Image.new("RGB", (336, 336), color=(255, 255, 255)) # dummy image
431
+ messages = construct_message(
432
+ content_data, "image", query, multimodal_processor, return_text=True
433
+ ) # str
434
+ inputs = multimodal_processor(
435
+ text=[messages],
436
+ images=[content_data],
437
+ padding=False,
438
+ return_tensors="pt",
439
+ )
440
+ chunk_results.append(inputs)
441
+
442
+ # === Image ===
443
+ elif data_type == "image":
444
+ if isinstance(content_data, list):
445
+ # multi-image
446
+ content_data = resize_images(content_data, resolution=512)
447
+ messages = construct_message(
448
+ content_data[0], data_type, query, multimodal_processor, return_text=True,
449
+ ) # str
450
+ inputs = multimodal_processor(
451
+ text=[messages] * len(content_data),
452
+ images=content_data,
453
+ padding=True,
454
+ return_tensors="pt",
455
+ )
456
+ else:
457
+ messages = construct_message(
458
+ content_data, data_type, query, multimodal_processor, return_text=True
459
+ ) # str
460
+ inputs = multimodal_processor(
461
+ text=[messages],
462
+ images=[content_data],
463
+ padding=True,
464
+ return_tensors="pt",
465
+ )
466
+ chunk_results.append(inputs)
467
+
468
+ # === Video ===
469
+ elif data_type == "video":
470
+ if isinstance(content_data, np.ndarray):
471
+ frames = [Image.fromarray(f) for f in content_data]
472
+ else:
473
+ frames = content_data
474
+
475
+ if frames:
476
+ frames = resize_images(frames, resolution=512)
477
+
478
+ total_frames = len(frames)
479
+ window_size = frame_windows
480
+ stride = frame_stride
481
+
482
+ for i in range(0, total_frames, stride):
483
+ start_idx = i
484
+ frame_offset = i # used to compute timestamps
485
+ end_idx = min(start_idx + window_size, total_frames)
486
+
487
+ chunk_frames = frames[start_idx:end_idx]
488
+ # if len(chunk_frames) < window_size:
489
+ # chunk_frames.extend(
490
+ # [chunk_frames[-1]] * (window_size - len(chunk_frames))
491
+ # )
492
+
493
+ if len(chunk_frames) < window_size and len(chunk_frames) == 1:
494
+ chunk_frames.append(chunk_frames[-1])
495
+ print(f"Qwen processor requires at least 2 frames as video input, copy last frame to {len(chunk_frames)}")
496
+
497
+ inputs = video_process_with_frame_idx(
498
+ chunk_frames, multimodal_processor, query, real_fps, frame_offset,
499
+ )
500
+ chunk_results.append(inputs)
501
+
502
+ else:
503
+ raise ValueError(f"Unknown data type: {data_type}")
504
+
505
+ # ============Group for batch process===================
506
+ chunk_dict = {}
507
+ for key in chunk_results[0]:
508
+ if key in ["input_ids", "attention_mask"]:
509
+ chunk_dict[f"vlm_{key}"] = [r[key].squeeze(dim=0) for r in chunk_results]
510
+ else:
511
+ chunk_dict[key] = [r[key] for r in chunk_results]
512
+
513
+ if is_eval:
514
+ return perpare_input_for_qwen_input(
515
+ chunk_dict, multimodal_processor.tokenizer.pad_token_id
516
+ )
517
+
518
+ return chunk_dict
519
+
520
+
521
+ def compute_segment_timestamp(
522
+ num_segments,
523
+ tokenizer,
524
+ real_fps,
525
+ stride=None,
526
+ window_size=None,
527
+ use_center_timestamp=True,
528
+ ):
529
+ """
530
+ The current version only supports non-overlapping segments.
531
+ You need to modify the timestamp computation to support overlapping segments.
532
+ """
533
+ step = stride if stride is not None else window_size
534
+ fps = real_fps if real_fps and real_fps > 0 else 1.0
535
+
536
+ seg_timestamps_ids = []
537
+ for i in range(num_segments):
538
+ start_frame_idx = i * step
539
+ if use_center_timestamp:
540
+ frame_idx = start_frame_idx + (window_size / 2)
541
+ else:
542
+ frame_idx = start_frame_idx
543
+ cur_timestamp_sec = frame_idx / fps
544
+ text = f"<{cur_timestamp_sec:.1f} seconds>"
545
+
546
+ ids = tokenizer.encode(text, add_special_tokens=False)
547
+ seg_timestamps_ids.append(ids)
548
+
549
+ return seg_timestamps_ids
550
+
551
+
552
+ def compute_sample_indices(
553
+ total_frames: int,
554
+ original_fps: float,
555
+ target_fps: float,
556
+ min_frames: int,
557
+ max_frames: int,
558
+ ) -> List[int]:
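+ # Worked example with assumed values: total_frames=300, original_fps=30, target_fps=1, min_frames=4, max_frames=64
+ # -> a 10 s clip, 10 target frames (within the limits), indices spread evenly over [0, 299].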
559
+ if total_frames <= 1:
560
+ return [0]
561
+
562
+ if original_fps is None or original_fps <= 0:
563
+ original_fps = target_fps
564
+
565
+ video_duration = total_frames / original_fps
566
+ target_num_frames = max(1, round(video_duration * target_fps))
567
+
568
+ final_num_frames = target_num_frames
569
+ if final_num_frames < min_frames:
570
+ print(
571
+ f"Upsampling video from {target_num_frames} to {min_frames} frames (min_frames limit)."
572
+ )
573
+ final_num_frames = min_frames
574
+ elif final_num_frames > max_frames:
575
+ print(
576
+ f"Downsampling video from {target_num_frames} to {max_frames} frames (max_frames limit)."
577
+ )
578
+ final_num_frames = max_frames
579
+
580
+ if final_num_frames == 1:
581
+ return [total_frames - 1]
582
+
583
+ indices = np.linspace(0, total_frames - 1, final_num_frames).astype(int)
584
+ indices = np.clip(indices, 0, total_frames - 1)
585
+
586
+ return indices.tolist()
587
+
588
+
589
+ def process_images(images, image_processor, model_cfg):
590
+ # if image_processor is None:
591
+ # raise ValueError("image_processor cannot be None")
592
+ if isinstance(image_processor, list):
593
+ image_aspect_ratio = getattr(model_cfg, "image_aspect_ratio", None)
594
+ processor_aux_list = image_processor
595
+ new_images_aux_list = []
596
+ for image in images:
597
+ if isinstance(image, np.ndarray):
598
+ image = Image.fromarray(image)
599
+ image_aux_list = []
600
+ for processor_aux in processor_aux_list:
601
+ image_aux = image
602
+ if hasattr(processor_aux, "image_mean"):
603
+ try:
604
+ target_resolution = processor_aux.crop_size["height"]
605
+ except:
606
+ target_resolution = processor_aux.size["height"]
607
+ # image_aux = expand2square(
608
+ # image_aux, tuple(int(x * 255) for x in processor_aux.image_mean)
609
+ # ).resize((target_resolution, target_resolution))
610
+ if image_aspect_ratio == "pad":
611
+ image_aux = expand2square(
612
+ image_aux,
613
+ tuple(int(x * 255) for x in processor_aux.image_mean),
614
+ )
615
+ elif image_aspect_ratio == "crop":
616
+ image_aux = crop2square(image_aux)
617
+ image_aux = image_aux.resize((target_resolution, target_resolution))
618
+
619
+ image_aux = processor_aux.preprocess(image_aux, return_tensors="pt")[
620
+ "pixel_values"
621
+ ][0]
622
+ image_aux_list.append(image_aux)
623
+ new_images_aux_list.append(image_aux_list)
624
+ new_images_aux_list = [
625
+ list(batch_image_aux) for batch_image_aux in zip(*new_images_aux_list)
626
+ ]
627
+ new_images_aux_list = [
628
+ torch.stack(image_aux).half().cuda() for image_aux in new_images_aux_list
629
+ ]
630
+ return new_images_aux_list
631
+ else:
632
+ image_aspect_ratio = getattr(model_cfg, "image_aspect_ratio", None)
633
+ new_images = []
634
+ if image_aspect_ratio == "pad":
635
+ for image in images:
636
+ image = expand2square(
637
+ image, tuple(int(x * 255) for x in image_processor.image_mean)
638
+ )
639
+ image = image_processor.preprocess(image, return_tensors="pt")[
640
+ "pixel_values"
641
+ ][0]
642
+ new_images.append(image)
643
+ elif image_aspect_ratio == "crop":
644
+ for image in images:
645
+ image = crop2square(image)
646
+ image = image_processor.preprocess(image, return_tensors="pt")[
647
+ "pixel_values"
648
+ ][0]
649
+ new_images.append(image)
650
+ else:
651
+ return image_processor(images, return_tensors="pt")["pixel_values"]
652
+ if all(x.shape == new_images[0].shape for x in new_images):
653
+ new_images = torch.stack(new_images, dim=0)
654
+ return new_images
655
+
656
+
657
+ def preprocess_multimodal(sources: Sequence[str], data_args) -> dict:
658
+ is_multimodal = data_args.is_multimodal
659
+ if not is_multimodal:
660
+ return sources
661
+
662
+ for source in sources:
663
+ for sentence in source:
664
+ num_im = sentence["value"].count(DEFAULT_IMAGE_TOKEN)
665
+ if num_im == 1 or "<video>" in sentence["value"]:
666
+ # rewrite only when the visual input is a single image or a video (not multiple images)
667
+ sentence["value"] = (
668
+ sentence["value"]
669
+ .replace(DEFAULT_IMAGE_TOKEN, "")
670
+ .replace("<video>", "")
671
+ .strip()
672
+ )
673
+ sentence["value"] = DEFAULT_IMAGE_TOKEN + "\n" + sentence["value"]
674
+ sentence["value"] = sentence["value"].strip()
675
+ if "mmtag" in conversation_lib.default_conversation.version:
676
+ sentence["value"] = sentence["value"].replace(
677
+ DEFAULT_IMAGE_TOKEN,
678
+ "<Image>" + DEFAULT_IMAGE_TOKEN + "</Image>",
679
+ )
680
+ replace_token = DEFAULT_IMAGE_TOKEN
681
+ if data_args.mm_use_im_start_end:
682
+ replace_token = (
683
+ DEFAULT_IM_START_TOKEN + replace_token + DEFAULT_IM_END_TOKEN
684
+ )
685
+ sentence["value"] = sentence["value"].replace(
686
+ DEFAULT_IMAGE_TOKEN, replace_token
687
+ )
688
+
689
+ return sources
690
+
691
+
692
+ def preprocess_llama_2(
693
+ sources,
694
+ tokenizer: transformers.PreTrainedTokenizer,
695
+ has_image: bool = False,
696
+ ) -> dict:
697
+ conv = conversation_lib.default_conversation.copy()
698
+ roles = {"human": conv.roles[0], "gpt": conv.roles[1]}
699
+
700
+ # Apply prompt templates
701
+ conversations = []
702
+ for i, source in enumerate(sources):
703
+ if roles[source[0]["from"]] != conv.roles[0]:
704
+ # Skip the first one if it is not from human
705
+ source = source[1:]
706
+
707
+ conv.messages = []
708
+ for j, sentence in enumerate(source):
709
+ role = roles[sentence["from"]]
710
+ assert role == conv.roles[j % 2], f"{i}"
711
+ conv.append_message(role, sentence["value"])
712
+ conversations.append(conv.get_prompt())
713
+
714
+ # Tokenize conversations
715
+
716
+ if has_image:
717
+ input_ids = torch.stack(
718
+ [
719
+ tokenizer_image_token(prompt, tokenizer, return_tensors="pt")
720
+ for prompt in conversations
721
+ ],
722
+ dim=0,
723
+ )
724
+ else:
725
+ input_ids = tokenizer(
726
+ conversations,
727
+ return_tensors="pt",
728
+ padding="longest",
729
+ max_length=tokenizer.model_max_length,
730
+ truncation=True,
731
+ ).input_ids
732
+
733
+ targets = input_ids.clone()
734
+
735
+ assert conv.sep_style == conversation_lib.SeparatorStyle.LLAMA_2
736
+
737
+ # Mask targets
738
+ sep = "[/INST] "
739
+ for conversation, target in zip(conversations, targets):
740
+ total_len = int(target.ne(tokenizer.pad_token_id).sum())
741
+
742
+ rounds = conversation.split(conv.sep2)
743
+ cur_len = 1
744
+ target[:cur_len] = IGNORE_INDEX
745
+ for rou in rounds:
746
+ if rou == "":
747
+ break
748
+
749
+ parts = rou.split(sep)
750
+ if len(parts) != 2:
751
+ break
752
+ parts[0] += sep
753
+
754
+ if has_image:
755
+ round_len = len(tokenizer_image_token(rou, tokenizer))
756
+ instruction_len = len(tokenizer_image_token(parts[0], tokenizer)) - 2
757
+ else:
758
+ round_len = len(tokenizer(rou).input_ids)
759
+ instruction_len = len(tokenizer(parts[0]).input_ids) - 2
760
+
761
+ target[cur_len : cur_len + instruction_len] = IGNORE_INDEX
762
+
763
+ cur_len += round_len
764
+ target[cur_len:] = IGNORE_INDEX
765
+
766
+ if cur_len < tokenizer.model_max_length:
767
+ if cur_len != total_len:
768
+ target[:] = IGNORE_INDEX
769
+ print(
770
+ f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}."
771
+ f" (ignored)"
772
+ )
773
+
774
+ return dict(
775
+ input_ids=input_ids,
776
+ labels=targets,
777
+ )
778
+
779
+
780
+ def preprocess_v1(
781
+ sources,
782
+ tokenizer: transformers.PreTrainedTokenizer,
783
+ has_image: bool = False,
784
+ ) -> dict:
785
+ conv = conversation_lib.default_conversation.copy()
786
+ roles = {"human": conv.roles[0], "gpt": conv.roles[1]}
787
+
788
+ # Apply prompt templates
789
+ conversations = []
790
+ for i, source in enumerate(sources):
791
+ if roles[source[0]["from"]] != conv.roles[0]:
792
+ # Skip the first one if it is not from human
793
+ source = source[1:]
794
+
795
+ conv.messages = []
796
+ for j, sentence in enumerate(source):
797
+ role = roles[sentence["from"]]
798
+ assert role == conv.roles[j % 2], f"{i}"
799
+ conv.append_message(role, sentence["value"])
800
+ conversations.append(conv.get_prompt())
801
+
802
+ # Tokenize conversations
803
+
804
+ if has_image:
805
+ input_ids = torch.stack(
806
+ [
807
+ tokenizer_image_token(prompt, tokenizer, return_tensors="pt")
808
+ for prompt in conversations
809
+ ],
810
+ dim=0,
811
+ )
812
+ else:
813
+ input_ids = tokenizer(
814
+ conversations,
815
+ return_tensors="pt",
816
+ padding="longest",
817
+ max_length=tokenizer.model_max_length,
818
+ truncation=True,
819
+ ).input_ids
820
+
821
+ targets = input_ids.clone()
822
+
823
+ assert conv.sep_style == conversation_lib.SeparatorStyle.TWO
824
+
825
+ # Mask targets
826
+ sep = conv.sep + conv.roles[1] + ": "
827
+ for conversation, target in zip(conversations, targets):
828
+ total_len = int(target.ne(tokenizer.pad_token_id).sum())
829
+
830
+ rounds = conversation.split(conv.sep2)
831
+ cur_len = 1
832
+ target[:cur_len] = IGNORE_INDEX
833
+ for i, rou in enumerate(rounds):
834
+ if rou == "":
835
+ break
836
+
837
+ parts = rou.split(sep)
838
+ if len(parts) != 2:
839
+ break
840
+ parts[0] += sep
841
+
842
+ if has_image:
843
+ round_len = len(tokenizer_image_token(rou, tokenizer))
844
+ instruction_len = len(tokenizer_image_token(parts[0], tokenizer)) - 2
845
+ else:
846
+ round_len = len(tokenizer(rou).input_ids)
847
+ instruction_len = len(tokenizer(parts[0]).input_ids) - 2
848
+ if i != 0 and not tokenizer.legacy and IS_TOKENIZER_GREATER_THAN_0_14:
849
+ round_len -= 1
850
+ instruction_len -= 1
851
+
852
+ target[cur_len : cur_len + instruction_len] = IGNORE_INDEX
853
+
854
+ cur_len += round_len
855
+ target[cur_len:] = IGNORE_INDEX
856
+
857
+ if cur_len < tokenizer.model_max_length:
858
+ if cur_len != total_len:
859
+ target[:] = IGNORE_INDEX
860
+ print(
861
+ f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}."
862
+ f" (ignored)"
863
+ )
864
+
865
+ return dict(
866
+ input_ids=input_ids,
867
+ labels=targets,
868
+ )
869
+
870
+
871
+ def tokenizer_image_token(
872
+ prompt,
873
+ tokenizer,
874
+ image_token_index=IMAGE_TOKEN_INDEX,
875
+ return_tensors=None,
876
+ ):
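+ # Splits the prompt on "<image>" and splices image_token_index between the tokenized text chunks,
+ # e.g. "A <image> B" -> ids("A ") + [image_token_index] + ids(" B"), keeping a leading BOS token if present.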
877
+ prompt_chunks = [tokenizer(chunk).input_ids for chunk in prompt.split("<image>")]
878
+
879
+ def insert_separator(X, sep):
880
+ return [ele for sublist in zip(X, [sep] * len(X)) for ele in sublist][:-1]
881
+
882
+ input_ids = []
883
+ offset = 0
884
+ if (
885
+ len(prompt_chunks) > 0
886
+ and len(prompt_chunks[0]) > 0
887
+ and prompt_chunks[0][0] == tokenizer.bos_token_id
888
+ ):
889
+ offset = 1
890
+ input_ids.append(prompt_chunks[0][0])
891
+
892
+ for x in insert_separator(prompt_chunks, [image_token_index] * (offset + 1)):
893
+ input_ids.extend(x[offset:])
894
+
895
+ if return_tensors is not None:
896
+ if return_tensors == "pt":
897
+ return torch.tensor(input_ids, dtype=torch.long)
898
+ raise ValueError(f"Unsupported tensor type: {return_tensors}")
899
+ return input_ids
900
+
901
+
902
+ def tokenizer_image_token_llama3(
903
+ prompt,
904
+ tokenizer,
905
+ image_token_index=IMAGE_TOKEN_INDEX,
906
+ return_tensors=None,
907
+ ):
908
+ prompt_chunks = [tokenizer(chunk).input_ids for chunk in prompt.split("<image>")]
909
+
910
+ def insert_separator(X, sep):
911
+ return [ele for sublist in zip(X, [sep] * len(X)) for ele in sublist][:-1]
912
+
913
+ input_ids = []
914
+ for x in insert_separator(prompt_chunks, [image_token_index]):
915
+ input_ids.extend(x)
916
+
917
+ if return_tensors is not None:
918
+ if return_tensors == "pt":
919
+ return torch.tensor(input_ids, dtype=torch.long)
920
+ raise ValueError(f"Unsupported tensor type: {return_tensors}")
921
+ return input_ids
922
+
923
+
924
+ def preprocess_qwen(
925
+ sources,
926
+ tokenizer: transformers.PreTrainedTokenizer,
927
+ has_image: bool = False,
928
+ system_message: str = "You are a helpful assistant.",
929
+ ) -> dict:
930
+ roles = {"human": "user", "gpt": "assistant"}
931
+
932
+ # Add the image token to the tokenizer as a special token
933
+ # Use a deep copy of the tokenizer so that we don't modify the original tokenizer
934
+ tokenizer = copy.deepcopy(tokenizer)
935
+ # When there is actually an image, we add the image tokens as a special token
936
+ if has_image:
937
+ tokenizer.add_tokens(["<image>"], special_tokens=True)
938
+
939
+ image_token_index = tokenizer.convert_tokens_to_ids("<image>")
940
+ im_start = tokenizer.convert_tokens_to_ids("<|im_start|>")
941
+ im_end = tokenizer.convert_tokens_to_ids("<|im_end|>")
942
+
943
+ unmask_tokens_idx = [198, im_start, im_end]
944
+ # nl_tokens = tokenizer("\n").input_ids
945
+
946
+ # Reset the Qwen chat template so that it doesn't prepend the system message on every apply_chat_template call
947
+ chat_template = "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}"
948
+ tokenizer.chat_template = chat_template
949
+
950
+ # _system = tokenizer("system").input_ids + nl_tokens
951
+ # _user = tokenizer("user").input_ids + nl_tokens
952
+ # _assistant = tokenizer("assistant").input_ids + nl_tokens
953
+
954
+ # Apply prompt templates
955
+ input_ids, targets = [], []
956
+ for source in sources:
957
+ if roles[source[0]["from"]] != roles["human"]:
958
+ source = source[1:]
959
+
960
+ input_id, target = [], []
961
+
962
+ # New version, use apply chat template
963
+ # Build system message for each sentence
964
+ input_id += tokenizer.apply_chat_template(
965
+ [{"role": "system", "content": system_message}]
966
+ )
967
+ target += [IGNORE_INDEX] * len(input_id)
968
+
969
+ for conv in source:
970
+ # Make sure llava data can load
971
+ try:
972
+ role = conv["role"]
973
+ content = conv["content"]
974
+ except:
975
+ role = conv["from"]
976
+ content = conv["value"]
977
+
978
+ role = roles.get(role, role)
979
+
980
+ conv = [{"role": role, "content": content}]
981
+ encode_id = tokenizer.apply_chat_template(conv)
982
+ input_id += encode_id
983
+ if role in ["user", "system"]:
984
+ target += [IGNORE_INDEX] * len(encode_id)
985
+ else:
986
+ target += encode_id
987
+
988
+ assert len(input_id) == len(target), f"{len(input_id)} != {len(target)}"
989
+ for idx, encode_id in enumerate(input_id):
990
+ if encode_id in unmask_tokens_idx:
991
+ target[idx] = encode_id
992
+ if encode_id == image_token_index:
993
+ input_id[idx] = IMAGE_TOKEN_INDEX
994
+ input_ids.append(input_id)
995
+ targets.append(target)
996
+ input_ids = torch.tensor(input_ids, dtype=torch.long)
997
+ targets = torch.tensor(targets, dtype=torch.long)
998
+
999
+ return dict(
1000
+ input_ids=input_ids, # tensor(bs x seq_len)
1001
+ labels=targets, # tensor(bs x seq_len)
1002
+ )
1003
+
1004
+
1005
+ def preprocess_llama3(
1006
+ sources,
1007
+ tokenizer: transformers.PreTrainedTokenizer,
1008
+ has_image: bool = False,
1009
+ system_message: str = "You are a helpful assistant.",
1010
+ ) -> dict:
1011
+ # roles = {"human": "<|start_header_id|>user<|end_header_id|>", "gpt": "<|start_header_id|>assistant<|end_header_id|>"}
1012
+ roles = {"human": "user", "gpt": "assistant"}
1013
+
1014
+ # Add the image token to the tokenizer as a special token
1015
+ # Use a deep copy of the tokenizer so that we don't modify the original tokenizer
1016
+ tokenizer = copy.deepcopy(tokenizer)
1017
+ # When there is actually an image, we add the image tokens as a special token
1018
+ if has_image:
1019
+ tokenizer.add_tokens(["<image>"], special_tokens=True)
1020
+ image_token_index = tokenizer.convert_tokens_to_ids("<image>")
1021
+ bos_token_id = tokenizer.convert_tokens_to_ids("<|begin_of_text|>")
1022
+ start_header_id = tokenizer.convert_tokens_to_ids("<|start_header_id|>")
1023
+ end_header_id = tokenizer.convert_tokens_to_ids("<|end_header_id|>")
1024
+ eot_id = tokenizer.convert_tokens_to_ids("<|eot_id|>")
1025
+
1026
+ unmask_tokens = [
1027
+ "<|begin_of_text|>",
1028
+ "<|start_header_id|>",
1029
+ "<|end_header_id|>",
1030
+ "<|eot_id|>",
1031
+ "\n\n",
1032
+ ]
1033
+ unmask_tokens_idx = [tokenizer.convert_tokens_to_ids(tok) for tok in unmask_tokens]
1034
+
1035
+ # After the tokenizers update, calling the Llama 3 tokenizer
1036
+ # automatically prepends the BOS id, so it is stripped here.
1037
+ def safe_tokenizer_llama3(text):
1038
+ input_ids = tokenizer(text).input_ids
1039
+ if input_ids[0] == bos_token_id:
1040
+ input_ids = input_ids[1:]
1041
+ return input_ids
1042
+
1043
+ nl_tokens = tokenizer.convert_tokens_to_ids("\n\n")
1044
+
1045
+ # chat_template = "{% set loop_messages = messages %}{% for message in loop_messages %}{% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>\\n\\n'+ message['content'] | trim + '<|eot_id|>' %}{% if loop.index0 == 0 %}{% set content = bos_token + content %}{% endif %}{{ content }}{% endfor %}{%- if add_generation_prompt %}{{ '<|start_header_id|>assistant<|end_header_id|>\\n\\n' }}{%- endif %}"
1046
+ chat_template = "{% set loop_messages = messages %}{% for message in loop_messages %}{% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] | trim + '<|eot_id|>' %}{% if loop.index0 == 0 %}{% set content = bos_token + content %}{% endif %}{{ content }}{% endfor %}{{ '<|start_header_id|>assistant<|end_header_id|>\n\n' }}"
1047
+ tokenizer.chat_template = chat_template
1048
+
1049
+ # Apply prompt templates
1050
+ input_ids, targets = [], []
1051
+ for source in sources:
1052
+ if roles[source[0]["from"]] != roles["human"]:
1053
+ source = source[1:]
1054
+
1055
+ input_id, target = [], []
1056
+
1057
+ # New version, use apply chat template
1058
+ # Build system message for each sentence
1059
+ input_id += tokenizer.apply_chat_template(
1060
+ [{"role": "system", "content": system_message}]
1061
+ # pyre-fixme[6]: For 1st argument expected `Union[int, str]` but got `slice`.
1062
+ )[:-4]
1063
+
1064
+ target += [IGNORE_INDEX] * len(input_id)
1065
+
1066
+ for conv in source:
1067
+ # Make sure llava data can load
1068
+ try:
1069
+ role = conv["role"]
1070
+ content = conv["content"]
1071
+ except:
1072
+ role = conv["from"]
1073
+ content = conv["value"]
1074
+
1075
+ role = roles.get(role, role)
1076
+
1077
+ conv = [{"role": role, "content": content}]
1078
+ # First is bos token we don't need here
1079
+ encode_id = tokenizer.apply_chat_template(conv)[1:-4]
1080
+ input_id += encode_id
1081
+ if role in ["user", "system"]:
1082
+ target += [IGNORE_INDEX] * len(encode_id)
1083
+ else:
1084
+ target += encode_id
1085
+
1086
+ assert len(input_id) == len(target), f"{len(input_id)} != {len(target)}"
1087
+ for idx, encode_id in enumerate(input_id):
1088
+ if encode_id in unmask_tokens_idx:
1089
+ target[idx] = encode_id
1090
+ if encode_id == image_token_index:
1091
+ input_id[idx] = IMAGE_TOKEN_INDEX
1092
+ input_ids.append(input_id)
1093
+ targets.append(target)
1094
+ input_ids = torch.tensor(input_ids, dtype=torch.long)
1095
+ targets = torch.tensor(targets, dtype=torch.long)
1096
+
1097
+ # print("input_ids", input_ids, flush=True)
1098
+ # print("targets", targets, flush=True)
1099
+ return dict(
1100
+ input_ids=input_ids, # tensor(bs x seq_len)
1101
+ labels=targets, # tensor(bs x seq_len)
1102
+ )
1103
+
1104
+
1105
+ def preprocess_llama_3_1(
1106
+ sources,
1107
+ tokenizer: transformers.PreTrainedTokenizer,
1108
+ has_image: bool = False,
1109
+ ) -> dict:
1110
+ conv = conversation_lib.default_conversation.copy()
1111
+ roles = {"human": conv.roles[0], "gpt": conv.roles[1]}
1112
+
1113
+ # Apply prompt templates
1114
+ conversations = []
1115
+ for source in sources:
1116
+ if roles[source[0]["from"]] != conv.roles[0]:
1117
+ # Skip the first one if it is not from human
1118
+ source = source[1:]
1119
+
1120
+ conv.messages = []
1121
+ for sentence in source:
1122
+ if sentence["from"] == "Answer":
1123
+ sentence["from"] = "gpt" # data bug
1124
+ role = roles[sentence["from"]]
1125
+ # assert role == conv.roles[j % 2], f"{i}"
1126
+ conv.append_message(role, sentence["value"])
1127
+ conversations.append(conv.get_prompt())
1128
+
1129
+ # Tokenize conversations
1130
+
1131
+ if has_image:
1132
+ input_ids = torch.stack(
1133
+ [
1134
+ tokenizer_image_token(prompt, tokenizer, return_tensors="pt")
1135
+ for prompt in conversations
1136
+ ],
1137
+ dim=0,
1138
+ )
1139
+ else:
1140
+ input_ids = tokenizer(
1141
+ conversations,
1142
+ return_tensors="pt",
1143
+ padding="longest",
1144
+ max_length=tokenizer.model_max_length,
1145
+ truncation=True,
1146
+ ).input_ids
1147
+
1148
+ # remove the first bos token
1149
+ if input_ids[0][0] == input_ids[0][1] == tokenizer.bos_token_id:
1150
+ input_ids = input_ids[:, 1:]
1151
+ targets = input_ids.clone()
1152
+
1153
+ assert conv.sep_style == conversation_lib.SeparatorStyle.LLAMA_3_1
1154
+
1155
+ # Mask targets
1156
+ sep = "<|start_header_id|>" + conv.roles[1] + "<|end_header_id|>" + "\n\n"
1157
+ # sep = conv.sep + conv.roles[1] + ": "
1158
+ for conversation, target in zip(conversations, targets):
1159
+ total_len = int(target.shape[0])
1160
+
1161
+ rounds = conversation.split(conv.tokenizer.eos_token)
1162
+ rounds = [rounds[0]] + [
1163
+ rounds[idx] + rounds[idx + 1] for idx in range(1, len(rounds) - 1, 2)
1164
+ ]
1165
+
1166
+ cur_len = 1
1167
+ target[:cur_len] = IGNORE_INDEX
1168
+ for i, rou in enumerate(rounds):
1169
+ if rou == "":
1170
+ break
1171
+
1172
+ parts = rou.split(sep)
1173
+ if len(parts) != 2 and i != 0:
1174
+ break
1175
+
1176
+ if i == 0:
1177
+ round_len = len(tokenizer(rou, add_special_tokens=False).input_ids)
1178
+ instruction_len = len(
1179
+ tokenizer(rou, add_special_tokens=False).input_ids
1180
+ )
1181
+
1182
+ else:
1183
+ parts[0] += sep
1184
+ if has_image:
1185
+ round_len = len(tokenizer_image_token(rou, tokenizer)) + 1
1186
+ instruction_len = len(tokenizer_image_token(parts[0], tokenizer))
1187
+ else:
1188
+ round_len = len(tokenizer(rou).input_ids) + 1
1189
+ instruction_len = len(tokenizer(parts[0]).input_ids)
1190
+
1191
+ # if i > 0: round_len += 1
1192
+ target[cur_len : cur_len + instruction_len] = IGNORE_INDEX
1193
+ cur_len += round_len
1194
+
1195
+ target[cur_len:] = IGNORE_INDEX
1196
+ cur_len = cur_len + len(tokenizer(sep, add_special_tokens=False).input_ids)
1197
+
1198
+ # if cur_len > tokenizer.model_max_length: print(f"WARNING: max length context")
1199
+ if cur_len < tokenizer.model_max_length:
1200
+ if cur_len != total_len:
1201
+ target[:] = IGNORE_INDEX
1202
+ print(
1203
+ f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}."
1204
+ f" (ignored)"
1205
+ )
1206
+
1207
+ return dict(
1208
+ input_ids=input_ids,
1209
+ labels=targets,
1210
+ )
1211
+
1212
+
1213
+ def preprocess_llama_3_2(
1214
+ sources,
1215
+ tokenizer: transformers.PreTrainedTokenizer,
1216
+ has_image: bool = False,
1217
+ ) -> dict:
1218
+ conv = conversation_lib.default_conversation.copy()
1219
+ roles = {"human": conv.roles[0], "gpt": conv.roles[1]}
1220
+
1221
+ # Apply prompt templates
1222
+ conversations = []
1223
+ for i, source in enumerate(sources):
1224
+ if roles[source[0]["from"]] != conv.roles[0]:
1225
+ # Skip the first one if it is not from human
1226
+ source = source[1:]
1227
+
1228
+ conv.messages = []
1229
+ for j, sentence in enumerate(source):
1230
+ role = roles[sentence["from"]]
1231
+ assert role == conv.roles[j % 2], f"{i}"
1232
+ conv.append_message(role, sentence["value"])
1233
+ conversations.append(conv.get_prompt())
1234
+
1235
+ # Tokenize conversations
1236
+
1237
+ if has_image:
1238
+ input_ids = torch.stack(
1239
+ [
1240
+ tokenizer_image_token(prompt, tokenizer, return_tensors="pt")
1241
+ for prompt in conversations
1242
+ ],
1243
+ dim=0,
1244
+ )
1245
+ else:
1246
+ input_ids = tokenizer(
1247
+ conversations,
1248
+ return_tensors="pt",
1249
+ padding="longest",
1250
+ max_length=tokenizer.model_max_length,
1251
+ truncation=True,
1252
+ ).input_ids
1253
+
1254
+ # remove the first bos token
1255
+ if input_ids[0][0] == input_ids[0][1] == tokenizer.bos_token_id:
1256
+ input_ids = input_ids[:, 1:]
1257
+ targets = input_ids.clone()
1258
+
1259
+ assert conv.sep_style == conversation_lib.SeparatorStyle.LLAMA_3_2
1260
+
1261
+ # Mask targets
1262
+ sep = "<|start_header_id|>" + conv.roles[1] + "<|end_header_id|>" + "\n\n"
1263
+ # sep = conv.sep + conv.roles[1] + ": "
1264
+ for conversation, target in zip(conversations, targets):
1265
+ total_len = int(target.shape[0])
1266
+
1267
+ rounds = conversation.split(conv.tokenizer.eos_token)
1268
+ rounds = [rounds[0]] + [
1269
+ rounds[idx] + rounds[idx + 1] for idx in range(1, len(rounds) - 1, 2)
1270
+ ]
1271
+
1272
+ cur_len = 1
1273
+ target[:cur_len] = IGNORE_INDEX
1274
+ for i, rou in enumerate(rounds):
1275
+ if rou == "":
1276
+ break
1277
+
1278
+ parts = rou.split(sep)
1279
+ if len(parts) != 2 and i != 0:
1280
+ break
1281
+
1282
+ if i == 0:
1283
+ round_len = len(tokenizer(rou, add_special_tokens=False).input_ids)
1284
+ instruction_len = len(
1285
+ tokenizer(rou, add_special_tokens=False).input_ids
1286
+ )
1287
+
1288
+ else:
1289
+ parts[0] += sep
1290
+ if has_image:
1291
+ round_len = len(tokenizer_image_token(rou, tokenizer)) + 1
1292
+ instruction_len = len(tokenizer_image_token(parts[0], tokenizer))
1293
+ else:
1294
+ round_len = len(tokenizer(rou).input_ids) + 1
1295
+ instruction_len = len(tokenizer(parts[0]).input_ids)
1296
+
1297
+ # if i > 0: round_len += 1
1298
+ target[cur_len : cur_len + instruction_len] = IGNORE_INDEX
1299
+ cur_len += round_len
1300
+
1301
+ target[cur_len:] = IGNORE_INDEX
1302
+ cur_len = cur_len + len(tokenizer(sep, add_special_tokens=False).input_ids)
1303
+
1304
+ # if cur_len > tokenizer.model_max_length: print(f"WARNING: max length context")
1305
+ if cur_len < tokenizer.model_max_length:
1306
+ if cur_len != total_len:
1307
+ target[:] = IGNORE_INDEX
1308
+ print(
1309
+ f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}."
1310
+ f" (ignored)"
1311
+ )
1312
+
1313
+ return dict(
1314
+ input_ids=input_ids,
1315
+ labels=targets,
1316
+ )
1317
+
1318
+
1319
+ def preprocess_phi3(
1320
+ sources,
1321
+ tokenizer: transformers.PreTrainedTokenizer,
1322
+ has_image: bool = False,
1323
+ ) -> dict:
1324
+ conv = conversation_lib.conv_templates["phi3"].copy()
1325
+ roles = {"human": conv.roles[0], "gpt": conv.roles[1]}
1326
+
1327
+ # Apply prompt templates
1328
+ conversations = []
1329
+ for i, source in enumerate(sources):
1330
+ if roles[source[0]["from"]] != conv.roles[0]:
1331
+ # Skip the first one if it is not from human
1332
+ source = source[1:]
1333
+
1334
+ conv.messages = []
1335
+ for j, sentence in enumerate(source):
1336
+ role = roles[sentence["from"]]
1337
+ assert role == conv.roles[j % 2], f"{i}"
1338
+ conv.append_message(role, sentence["value"])
1339
+ conversations.append(conv.get_prompt())
1340
+
1341
+ # Tokenize conversations
1342
+ if has_image:
1343
+ input_ids = torch.stack(
1344
+ [
1345
+ tokenizer_image_token(prompt, tokenizer, return_tensors="pt")
1346
+ for prompt in conversations
1347
+ ],
1348
+ dim=0,
1349
+ )
1350
+ else:
1351
+ input_ids = tokenizer(
1352
+ conversations,
1353
+ return_tensors="pt",
1354
+ padding="longest",
1355
+ max_length=tokenizer.model_max_length,
1356
+ truncation=True,
1357
+ ).input_ids
1358
+
1359
+ targets = input_ids.clone()
1360
+ assert conv.sep_style == conversation_lib.SeparatorStyle.MPT
1361
+
1362
+ # Mask targets
1363
+ sep = conv.sep + conv.roles[1]
1364
+ for conversation, target in zip(conversations, targets):
1365
+ total_len = int(target.ne(tokenizer.pad_token_id).sum())
1366
+
1367
+ rounds = conversation.split(conv.sep)
1368
+ re_rounds = [conv.sep.join(rounds[:3])] # system + user + gpt
1369
+ for conv_idx in range(3, len(rounds), 2):
1370
+ re_rounds.append(
1371
+ conv.sep.join(rounds[conv_idx : conv_idx + 2])
1372
+ ) # user + gpt
1373
+ cur_len = 0
1374
+ target[:cur_len] = IGNORE_INDEX
1375
+ for i, rou in enumerate(re_rounds):
1376
+ if rou == "":
1377
+ break
1378
+
1379
+ parts = rou.split(sep)
1380
+ if len(parts) != 2:
1381
+ break
1382
+ parts[0] += sep
1383
+
1384
+ if has_image:
1385
+ round_len = len(tokenizer_image_token(rou, tokenizer))
1386
+ instruction_len = len(tokenizer_image_token(parts[0], tokenizer)) - 1
1387
+ else:
1388
+ round_len = len(tokenizer(rou).input_ids)
1389
+ instruction_len = len(tokenizer(parts[0]).input_ids) - 1
1390
+
1391
+ if i == 0:
1392
+ round_len += 1
1393
+ instruction_len += 1
1394
+ else:
1395
+ round_len -= 2
1396
+ instruction_len -= 2
1397
+
1398
+ if (
1399
+ i != 0
1400
+ and getattr(tokenizer, "legacy", False)
1401
+ and IS_TOKENIZER_GREATER_THAN_0_14
1402
+ ):
1403
+ round_len += 1
1404
+ instruction_len += 1
1405
+
1406
+ target[cur_len : cur_len + instruction_len] = IGNORE_INDEX
1407
+
1408
+ cur_len += round_len
1409
+ target[cur_len:] = IGNORE_INDEX
1410
+
1411
+ if cur_len < tokenizer.model_max_length:
1412
+ if cur_len != total_len:
1413
+ target[:] = IGNORE_INDEX
1414
+ print(
1415
+ f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}."
1416
+ f" (ignored)"
1417
+ )
1418
+
1419
+ return dict(
1420
+ input_ids=input_ids,
1421
+ labels=targets,
1422
+ )
1423
+
1424
+
1425
+ def preprocess_mpt(
1426
+ sources,
1427
+ tokenizer: transformers.PreTrainedTokenizer,
1428
+ has_image: bool = False,
1429
+ ) -> dict:
1430
+ conv = conversation_lib.default_conversation.copy()
1431
+ roles = {"human": conv.roles[0], "gpt": conv.roles[1]}
1432
+
1433
+ # Apply prompt templates
1434
+ conversations = []
1435
+ for i, source in enumerate(sources):
1436
+ if roles[source[0]["from"]] != conv.roles[0]:
1437
+ # Skip the first one if it is not from human
1438
+ source = source[1:]
1439
+
1440
+ conv.messages = []
1441
+ for j, sentence in enumerate(source):
1442
+ role = roles[sentence["from"]]
1443
+ assert role == conv.roles[j % 2], f"{i}"
1444
+ conv.append_message(role, sentence["value"])
1445
+ conversations.append(conv.get_prompt())
1446
+
1447
+ # Tokenize conversations
1448
+ if has_image:
1449
+ input_ids = torch.stack(
1450
+ [
1451
+ tokenizer_image_token(prompt, tokenizer, return_tensors="pt")
1452
+ for prompt in conversations
1453
+ ],
1454
+ dim=0,
1455
+ )
1456
+ else:
1457
+ input_ids = tokenizer(
1458
+ conversations,
1459
+ return_tensors="pt",
1460
+ padding="longest",
1461
+ max_length=tokenizer.model_max_length,
1462
+ truncation=True,
1463
+ ).input_ids
1464
+
1465
+ targets = input_ids.clone()
1466
+ assert conv.sep_style == conversation_lib.SeparatorStyle.MPT
1467
+
1468
+ # Mask targets
1469
+ sep = conv.sep + conv.roles[1]
1470
+ for conversation, target in zip(conversations, targets):
1471
+ total_len = int(target.ne(tokenizer.pad_token_id).sum())
1472
+
1473
+ rounds = conversation.split(conv.sep)
1474
+ re_rounds = [conv.sep.join(rounds[:3])] # system + user + gpt
1475
+ for conv_idx in range(3, len(rounds), 2):
1476
+ re_rounds.append(
1477
+ conv.sep.join(rounds[conv_idx : conv_idx + 2])
1478
+ ) # user + gpt
1479
+ cur_len = 0
1480
+ target[:cur_len] = IGNORE_INDEX
1481
+ for i, rou in enumerate(re_rounds):
1482
+ if rou == "":
1483
+ break
1484
+
1485
+ parts = rou.split(sep)
1486
+ if len(parts) != 2:
1487
+ break
1488
+ parts[0] += sep
1489
+ if has_image:
1490
+ round_len = len(tokenizer_image_token(rou, tokenizer))
1491
+ instruction_len = len(tokenizer_image_token(parts[0], tokenizer)) - 1
1492
+ else:
1493
+ round_len = len(tokenizer(rou).input_ids)
1494
+ instruction_len = len(tokenizer(parts[0]).input_ids) - 1
1495
+
1496
+ if (
1497
+ i != 0
1498
+ and getattr(tokenizer, "legacy", False)
1499
+ and IS_TOKENIZER_GREATER_THAN_0_14
1500
+ ):
1501
+ round_len += 1
1502
+ instruction_len += 1
1503
+ target[cur_len : cur_len + instruction_len] = IGNORE_INDEX
1504
+
1505
+ cur_len += round_len
1506
+ target[cur_len:] = IGNORE_INDEX
1507
+
1508
+ if cur_len < tokenizer.model_max_length:
1509
+ if cur_len != total_len:
1510
+ target[:] = IGNORE_INDEX
1511
+ print(
1512
+ f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}."
1513
+ f" (ignored)"
1514
+ )
1515
+
1516
+ return dict(
1517
+ input_ids=input_ids,
1518
+ labels=targets,
1519
+ )
1520
+
1521
+
1522
+ def preprocess_plain(
1523
+ sources: Sequence[str],
1524
+ tokenizer: transformers.PreTrainedTokenizer,
1525
+ ) -> dict:
1526
+ # add end signal and concatenate together
1527
+ conversations = []
1528
+ for source in sources:
1529
+ assert len(source) == 2
1530
+ assert DEFAULT_IMAGE_TOKEN in source[0]["value"]
1531
+ source[0]["value"] = DEFAULT_IMAGE_TOKEN
1532
+ conversation = (
1533
+ source[0]["value"]
1534
+ + source[1]["value"]
1535
+ + conversation_lib.default_conversation.sep
1536
+ )
1537
+ conversations.append(conversation)
1538
+ # tokenize conversations
1539
+ input_ids = [
1540
+ tokenizer_image_token(prompt, tokenizer, return_tensors="pt")
1541
+ for prompt in conversations
1542
+ ]
1543
+ targets = copy.deepcopy(input_ids)
1544
+ for target, source in zip(targets, sources):
1545
+ tokenized_len = len(tokenizer_image_token(source[0]["value"], tokenizer))
1546
+ target[:tokenized_len] = IGNORE_INDEX
1547
+
1548
+ return dict(input_ids=input_ids, labels=targets)
1549
+
1550
+
1551
+ def preprocess(
1552
+ sources: Sequence[str],
1553
+ tokenizer: transformers.PreTrainedTokenizer,
1554
+ has_image: bool = False,
1555
+ ) -> dict:
1556
+ """
1557
+ Given a list of sources, each of which is a conversation list, this transform:
1558
+ 1. Add signal '### ' at the beginning of each sentence, with end signal '\n';
1559
+ 2. Concatenate conversations together;
1560
+ 3. Tokenize the concatenated conversation;
1561
+ 4. Make a deepcopy as the target. Mask human words with IGNORE_INDEX.
1562
+ """
1563
+ if conversation_lib.default_conversation.sep_style == conversation_lib.SeparatorStyle.PLAIN:
1564
+ return preprocess_plain(sources, tokenizer)
1565
+ if conversation_lib.default_conversation.sep_style == conversation_lib.SeparatorStyle.LLAMA_2:
1566
+ return preprocess_llama_2(sources, tokenizer, has_image=has_image)
1567
+ if conversation_lib.default_conversation.version.startswith("v1"):
1568
+ return preprocess_v1(sources, tokenizer, has_image=has_image)
1569
+ if conversation_lib.default_conversation.version == "mpt":
1570
+ return preprocess_mpt(sources, tokenizer, has_image=has_image)
1571
+ if conversation_lib.default_conversation.version == "phi3":
1572
+ return preprocess_phi3(sources, tokenizer, has_image=has_image)
1573
+ if conversation_lib.default_conversation.version == "qwen":
1574
+ return preprocess_qwen(sources, tokenizer, has_image=has_image)
1575
+ # add end signal and concatenate together
1576
+ conversations = []
1577
+ for source in sources:
1578
+ header = f"{conversation_lib.default_conversation.system}\n\n"
1579
+ conversation = _add_speaker_and_signal(header, source)
1580
+ conversations.append(conversation)
1581
+
1582
+ # tokenize conversations
1583
+ def get_tokenize_len(prompts):
1584
+ return [len(tokenizer_image_token(prompt, tokenizer)) for prompt in prompts]
1585
+
1586
+ if has_image:
1587
+ input_ids = [
1588
+ tokenizer_image_token(prompt, tokenizer, return_tensors="pt")
1589
+ for prompt in conversations
1590
+ ]
1591
+ else:
1592
+ conversations_tokenized = _tokenize_fn(conversations, tokenizer)
1593
+ input_ids = conversations_tokenized["input_ids"]
1594
+
1595
+ targets = copy.deepcopy(input_ids)
1596
+ for target, source in zip(targets, sources):
1597
+ if has_image:
1598
+ tokenized_lens = get_tokenize_len([header] + [s["value"] for s in source])
1599
+ else:
1600
+ tokenized_lens = _tokenize_fn(
1601
+ [header] + [s["value"] for s in source],
1602
+ tokenizer,
1603
+ )["input_ids_lens"]
1604
+ speakers = [sentence["from"] for sentence in source]
1605
+ _mask_targets(target, tokenized_lens, speakers)
1606
+
1607
+ return dict(input_ids=input_ids, labels=targets)
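A minimal sketch of the label masking that preprocess() performs (illustrative only, not part of this commit; mask_human_turns and the toy lengths are hypothetical): the system header and human turns are set to IGNORE_INDEX so that only assistant replies contribute to the loss.

import torch

IGNORE_INDEX = -100

def mask_human_turns(input_ids, tokenized_lens, speakers):
    # input_ids: 1-D tensor of the concatenated, tokenized conversation
    labels = input_ids.clone()
    cur = tokenized_lens[0]              # mask the system header
    labels[:cur] = IGNORE_INDEX
    for length, speaker in zip(tokenized_lens[1:], speakers):
        if speaker == "human":
            labels[cur : cur + length] = IGNORE_INDEX
        cur += length
    return labels

ids = torch.arange(20)
labels = mask_human_turns(ids, tokenized_lens=[4, 6, 10], speakers=["human", "gpt"])
# positions 0-9 become IGNORE_INDEX, positions 10-19 keep their token ids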
tempo/multimodal_encoder/__pycache__/base_encoder.cpython-312.pyc ADDED
Binary file (6.52 kB). View file
 
tempo/multimodal_encoder/__pycache__/builder.cpython-312.pyc ADDED
Binary file (1.28 kB). View file
 
tempo/multimodal_encoder/__pycache__/qwen3vl_encoder.cpython-312.pyc ADDED
Binary file (14.4 kB). View file
 
tempo/multimodal_encoder/__pycache__/siglip_encoder.cpython-312.pyc ADDED
Binary file (4.29 kB). View file
 
tempo/multimodal_encoder/base_encoder.py ADDED
@@ -0,0 +1,135 @@
1
+ from abc import ABC, abstractmethod
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+
6
+
7
+ class ProcessorWrapper:
8
+ def __init__(
9
+ self,
10
+ transform,
11
+ height=378,
12
+ width=378,
13
+ image_mean=[0.48145466, 0.4578275, 0.40821073],
14
+ ):
15
+ self._crop_size = {
16
+ "height": height,
17
+ "width": width,
18
+ }
19
+ self._transforms = transform
20
+ # print(transform)
21
+ self.image_mean = image_mean
22
+
23
+ @property
24
+ def crop_size(self):
25
+ return self._crop_size
26
+
27
+ def preprocess(self, image, return_tensors="pt"):
28
+ # apply the transform directly; the input is expected to be a PIL Image
29
+ output = {}
30
+ output["pixel_values"] = [self._transforms(image)]
31
+ return output
32
+
33
+
34
+ class BaseVisionTower(nn.Module):
35
+ def __init__(self, vision_tower_name, args, delay_load=False):
36
+ super().__init__()
37
+
38
+ self.is_loaded = False
39
+ self.args = args
40
+
41
+ self.vision_tower_name = vision_tower_name
42
+ self.select_layer = args.mm_vision_select_layer
43
+ self.select_feature = getattr(args, "mm_vision_select_feature", "patch")
44
+ self.unfreeze_mm_vision_tower = getattr(args, "unfreeze_mm_vision_tower", False)
45
+ self.delay_load = delay_load
46
+
47
+ @abstractmethod
48
+ def load_model(self, device_map=None):
49
+ raise NotImplementedError("Subclasses must implement load_model")
50
+
51
+ @abstractmethod
52
+ def _forward(self, images):
53
+ raise NotImplementedError("Subclasses must implement forward")
54
+
55
+ def forward(self, images):
56
+ if type(images) is list:
57
+ image_features = [self._forward(image.unsqueeze(0)) for image in images]
58
+ else:
59
+ image_features = self._forward(images)
60
+
61
+ return image_features
62
+
63
+ @property
64
+ def dummy_feature(self):
65
+ return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)
66
+
67
+ @property
68
+ def dtype(self):
69
+ # Dynamically infer the dtype from the first parameter, if not explicitly specified
70
+ if hasattr(self.vision_tower, "dtype"):
71
+ return self.vision_tower.dtype
72
+ else:
73
+ params = list(self.vision_tower.parameters())
74
+ return (
75
+ params[0].dtype if len(params) > 0 else torch.float32
76
+ ) # Default to torch.float32 if no parameters
77
+
78
+ @property
79
+ def device(self):
80
+ # Dynamically infer the device from the first parameter, if not explicitly specified
81
+ if hasattr(self.vision_tower, "device"):
82
+ return self.vision_tower.device
83
+ else:
84
+ params = list(self.vision_tower.parameters())
85
+ return (
86
+ params[0].device if len(params) > 0 else torch.device("cpu")
87
+ ) # Default to CPU if no parameters
88
+
89
+ @property
90
+ def config(self):
91
+ if self.is_loaded:
92
+ return self.vision_tower.config
93
+ else:
94
+ return self.cfg_only
95
+
96
+ @property
97
+ def hidden_size(self):
98
+ try:
99
+ return self.config.hidden_size
100
+ except Exception:
101
+ return self._hidden_size
102
+
103
+ @property
104
+ def image_size(self): # resolution
105
+ # return self.config.image_size
106
+ try:
107
+ return self.config.image_size
108
+ except Exception:
109
+ return self._image_size
110
+
111
+ @property
112
+ def patch_size(self):
113
+ # return self.config.patch_size
114
+ try:
115
+ return self.config.patch_size
116
+ except Exception:
117
+ return self._patch_size
118
+
119
+ @property
120
+ def num_patches_per_side(self):
121
+ if self._interp_size is not None:
122
+ return int(self._interp_size**0.5)
123
+ try:
124
+ return self.image_size // self.patch_size
125
+ except Exception:
126
+ return self._num_patches_per_side
127
+
128
+ @property
129
+ def num_patches(self):
130
+ if self._interp_size is not None:
131
+ return self._interp_size
132
+ try:
133
+ return self.num_patches_per_side**2
134
+ except Exception:
135
+ return self._num_patches
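A minimal subclass sketch (hypothetical, not shipped in this commit): BaseVisionTower only requires load_model and _forward; forward, dtype, device and the patch-count properties are inherited. The toy Conv2d stands in for a real backbone, and instantiating the class still needs the usual args namespace with mm_vision_select_layer etc.

import torch.nn as nn
from tempo.multimodal_encoder.base_encoder import BaseVisionTower

class ToyVisionTower(BaseVisionTower):
    def load_model(self, device_map=None):
        self._hidden_size = 64
        self._interp_size = None
        # a 16x16 patchifier standing in for a real vision backbone
        self.vision_tower = nn.Conv2d(3, self._hidden_size, kernel_size=16, stride=16)
        self.is_loaded = True

    def _forward(self, images):
        feats = self.vision_tower(images)        # (B, C, H/16, W/16)
        return feats.flatten(2).transpose(1, 2)  # (B, num_patches, C)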
tempo/multimodal_encoder/builder.py ADDED
@@ -0,0 +1,21 @@
1
+ import copy
2
+ from pathlib import Path
3
+ from .qwen3vl_encoder import Qwen3VLTower
4
+ from .siglip_encoder import SiglipVisionTower
5
+
6
+ def build_vision_tower_aux_list(vision_tower_cfg, **kwargs):
7
+
8
+ vision_tower_aux_name_list = getattr(vision_tower_cfg, "mm_vision_tower_aux_list", ["Qwen/Qwen3-VL-2B-Instruct"])
9
+
10
+ vision_tower_aux_list = []
11
+ for vision_tower_aux_name in vision_tower_aux_name_list:
12
+ config = copy.deepcopy(vision_tower_cfg)
13
+ vision_tower_basename = Path(vision_tower_aux_name).name.lower()
14
+ if "siglip" in vision_tower_basename:
15
+ vision_tower_aux_list.append(SiglipVisionTower(vision_tower_aux_name, args=config, **kwargs))
16
+ elif "qwen3-vl" in vision_tower_basename:
17
+ vision_tower_aux_list.append(Qwen3VLTower(vision_tower_aux_name, args=config, **kwargs))
18
+ else:
19
+ raise ValueError(f"Unknown vision tower: {vision_tower_basename}")
20
+
21
+ return vision_tower_aux_list
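A usage sketch with a hypothetical config namespace (constructing the Qwen3-VL tower downloads its processor and config from the Hub): the builder dispatches on the checkpoint basename, so a name containing "qwen3-vl" yields a Qwen3VLTower and a "siglip" checkpoint yields a SiglipVisionTower.

from types import SimpleNamespace
from tempo.multimodal_encoder.builder import build_vision_tower_aux_list

cfg = SimpleNamespace(
    mm_vision_tower_aux_list=["Qwen/Qwen3-VL-2B-Instruct"],
    num_compression_tokens=64,   # assumed value; required by Qwen3VLTower
    dynamic_compress=False,
)
towers = build_vision_tower_aux_list(cfg)
print(type(towers[0]).__name__)  # Qwen3VLTower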
tempo/multimodal_encoder/qwen3vl_encoder.py ADDED
@@ -0,0 +1,336 @@
1
+ import gc
2
+ import random
3
+ random.seed(42)
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+ from accelerate import init_empty_weights
8
+ from transformers.utils import is_torchdynamo_compiling
9
+ from transformers import AutoConfig, Qwen3VLForConditionalGeneration, Qwen3VLProcessor
10
+
11
+ class Qwen3VLTower(nn.Module):
12
+ def __init__(self, vision_tower_aux_name, args, **kwargs):
13
+ super(Qwen3VLTower, self).__init__()
14
+
15
+ self.is_loaded = True # for compatibility
16
+ self.model_path = vision_tower_aux_name
17
+ self.dynamic_compress = getattr(args, "dynamic_compress", False)
18
+
19
+ # load processor
20
+ self.image_processor = Qwen3VLProcessor.from_pretrained(self.model_path)
21
+
22
+ # load config
23
+ self.config = AutoConfig.from_pretrained(self.model_path)
24
+ self.config._attn_implementation = "flash_attention_2"
25
+ self.config.dtype = torch.bfloat16
26
+
27
+ # load model
28
+ with init_empty_weights():
29
+ self.vlm = Qwen3VLForConditionalGeneration(self.config)
30
+ if hasattr(self.vlm, "lm_head"):
31
+ del self.vlm.lm_head
32
+
33
+ self.vlm.requires_grad_(False)
34
+
35
+ self.hidden_size = self.config.text_config.hidden_size
36
+ self.num_compression_tokens = args.num_compression_tokens
37
+
38
+ self.compression_tokens = nn.Parameter(
39
+ torch.empty(1, self.num_compression_tokens, self.hidden_size)
40
+ )
41
+
42
+ def smart_init_vision_tower(self):
43
+ """Load only during Stage 0"""
44
+
45
+ temp_model = Qwen3VLForConditionalGeneration.from_pretrained(
46
+ self.model_path,
47
+ dtype=self.vlm.dtype,
48
+ device_map="cpu", # avoid multi-node / GPU conflicts
49
+ )
50
+
51
+ missing_keys, unexpected_keys = self.vlm.load_state_dict(temp_model.state_dict(), strict=False)
52
+
53
+ if len(missing_keys) > 0:
54
+ print(f"[Warning] Missing keys in Qwen3-VL loading: {missing_keys}")
55
+ if len(unexpected_keys) > 0:
56
+ print(f"[Warning] Unexpected keys in Qwen3-VL loading: {unexpected_keys}")
57
+
58
+ del temp_model
59
+ gc.collect()
60
+ torch.cuda.empty_cache()
61
+
62
+ self.vlm.requires_grad_(False)
63
+
64
+ with torch.no_grad():
65
+ embed_weights = self.vlm.model.language_model.embed_tokens.weight
66
+ mean = embed_weights.mean(dim=0)
67
+ std = embed_weights.std(dim=0)
68
+
69
+ self.compression_tokens.data = torch.normal(mean=mean.repeat(self.num_compression_tokens, 1), std=std.repeat(self.num_compression_tokens, 1)).unsqueeze(0)
70
+ print(f"[Smart Init] Done. Shape: {self.compression_tokens.shape}")
71
+
72
+ def smart_init_dynamic_compress(self):
73
+ if getattr(self, "vlm_head", None) is None:
74
+ temp_model = Qwen3VLForConditionalGeneration.from_pretrained(
75
+ self.model_path,
76
+ dtype=self.vlm.dtype,
77
+ device_map="cpu", # avoid multi-node / GPU conflicts
78
+ )
79
+ self.vlm_head = temp_model.lm_head
80
+
81
+ del temp_model
82
+ gc.collect()
83
+ torch.cuda.empty_cache()
84
+
85
+ token_true_id = self.image_processor.tokenizer.get_vocab()["Yes"]
86
+ token_false_id = self.image_processor.tokenizer.get_vocab()["No"]
87
+
88
+ lm_head_weights = self.vlm_head.weight.data
89
+ weight_yes = lm_head_weights[token_true_id]
90
+ weight_no = lm_head_weights[token_false_id]
91
+
92
+ D = weight_yes.size()[0]
93
+ self.linear_layer = nn.Linear(D, 1, bias=False)
94
+ with torch.no_grad():
95
+ self.linear_layer.weight[0] = weight_yes - weight_no
96
+
97
+ del self.vlm_head
98
+ self.linear_layer.to("cuda")
99
+
100
+ print(f"[Smart Init Router] Done!")
101
+
102
+ def compute_relevance(self, mask_compression_token, batch_size, last_hidden_state):
103
+ first_compression_idx = mask_compression_token.int().argmax(
104
+ dim=1
105
+ ) # (batch_size,)
106
+ prev_idx = first_compression_idx - 1 # (batch_size,)
107
+ batch_indices = torch.arange(batch_size, device=last_hidden_state.device)
108
+ prev_token_features = last_hidden_state[
109
+ batch_indices, prev_idx
110
+ ] # (batch_size, hidden_dim)
111
+
112
+ scores = self.linear_layer(prev_token_features.float())
113
+ scores = torch.sigmoid(scores).squeeze(-1).cpu().detach().tolist()
114
+
115
+ return scores
116
+
117
+ def load_model(self):
118
+ # kept for compatibility with other encoders
119
+ pass
120
+
121
+ def forward(
122
+ self,
123
+ input_ids=None,
124
+ attention_mask=None,
125
+ position_ids=None,
126
+ past_key_values=None,
127
+ inputs_embeds=None,
128
+ pixel_values=None,
129
+ pixel_values_videos=None,
130
+ image_grid_thw=None,
131
+ video_grid_thw=None,
132
+ cache_position=None,
133
+ **kwargs,
134
+ ):
135
+ if self.dynamic_compress and not hasattr(self, "linear_layer"):
136
+ self.smart_init_dynamic_compress()
137
+
138
+ if (input_ids is None) ^ (inputs_embeds is not None):
139
+ raise ValueError(
140
+ "You must specify exactly one of input_ids or inputs_embeds"
141
+ )
142
+
143
+ used_compression_tokens = self.num_compression_tokens
144
+
145
+ if inputs_embeds is None:
146
+ # process input_ids to insert learnable token
147
+ batch_size, n_seq = input_ids.shape
148
+ valid_lengths = (
149
+ input_ids != self.image_processor.tokenizer.pad_token_id
150
+ ).sum(dim=1)
151
+ input_ids = torch.cat(
152
+ [
153
+ input_ids,
154
+ torch.full(
155
+ (batch_size, used_compression_tokens),
156
+ # self.config.pad_token_id,
157
+ self.image_processor.tokenizer.pad_token_id,
158
+ dtype=input_ids.dtype,
159
+ device=input_ids.device,
160
+ ),
161
+ ],
162
+ dim=1,
163
+ )
164
+ attention_mask = torch.cat(
165
+ [
166
+ attention_mask,
167
+ torch.zeros(
168
+ (batch_size, used_compression_tokens),
169
+ dtype=attention_mask.dtype,
170
+ device=attention_mask.device,
171
+ ),
172
+ ],
173
+ dim=1,
174
+ )
175
+ inputs_embeds = self.vlm.get_input_embeddings()(input_ids)
176
+ else:
177
+ raise NotImplementedError(
178
+ "Currently only input_ids are supported as VLM compressor inputs"
179
+ )
180
+
181
+ image_mask = None
182
+ video_mask = None
183
+
184
+ if pixel_values is not None:
185
+ image_embeds, deepstack_image_embeds = self.vlm.get_image_features(
186
+ pixel_values, image_grid_thw
187
+ )
188
+ image_embeds = torch.cat(image_embeds, dim=0).to(
189
+ inputs_embeds.device, inputs_embeds.dtype
190
+ )
191
+ image_mask, _ = self.vlm.model.get_placeholder_mask(
192
+ input_ids, inputs_embeds=inputs_embeds, image_features=image_embeds
193
+ )
194
+ inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)
195
+
196
+ if pixel_values_videos is not None:
197
+ video_embeds, deepstack_video_embeds = self.vlm.get_video_features(
198
+ pixel_values_videos, video_grid_thw
199
+ )
200
+ video_embeds = torch.cat(video_embeds, dim=0).to(
201
+ inputs_embeds.device, inputs_embeds.dtype
202
+ )
203
+ _, video_mask = self.vlm.model.get_placeholder_mask(
204
+ input_ids, inputs_embeds=inputs_embeds, video_features=video_embeds
205
+ )
206
+ inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds)
207
+
208
+ visual_pos_masks = None
209
+ deepstack_visual_embeds = None
210
+ if image_mask is not None and video_mask is not None:
211
+ # aggregate visual_pos_masks and deepstack_visual_embeds
212
+ image_mask = image_mask[..., 0]
213
+ video_mask = video_mask[..., 0]
214
+ visual_pos_masks = image_mask | video_mask
215
+ deepstack_visual_embeds = []
216
+ image_mask_joint = image_mask[visual_pos_masks]
217
+ video_mask_joint = video_mask[visual_pos_masks]
218
+ for img_embed, vid_embed in zip(
219
+ deepstack_image_embeds, deepstack_video_embeds
220
+ ):
221
+ embed_joint = img_embed.new_zeros(
222
+ visual_pos_masks.sum(), img_embed.shape[-1]
223
+ ).to(img_embed.device)
224
+ embed_joint[image_mask_joint, :] = img_embed
225
+ embed_joint[video_mask_joint, :] = vid_embed
226
+ deepstack_visual_embeds.append(embed_joint)
227
+ elif image_mask is not None:
228
+ image_mask = image_mask[..., 0]
229
+ visual_pos_masks = image_mask
230
+ deepstack_visual_embeds = deepstack_image_embeds
231
+ elif video_mask is not None:
232
+ video_mask = video_mask[..., 0]
233
+ visual_pos_masks = video_mask
234
+ deepstack_visual_embeds = deepstack_video_embeds
235
+
236
+ # ------------------------------------------------------------------
237
+ # inputs_embeds now holds [Text + Image + Video], shape: (B, L, D)
238
+ # concat Learnable Tokens
239
+ position_compression_token = (
240
+ torch.arange(n_seq + used_compression_tokens, device=input_ids.device)
241
+ .unsqueeze(0)
242
+ .expand(batch_size, -1)
243
+ )
244
+ mask_compression_token = (
245
+ position_compression_token >= valid_lengths.unsqueeze(1)
246
+ ) & (
247
+ position_compression_token
248
+ < (valid_lengths + used_compression_tokens).unsqueeze(1)
249
+ )
250
+ compression_tokens_expanded = self.compression_tokens[
251
+ :, :used_compression_tokens, :
252
+ ].expand(batch_size, -1, -1)
253
+ inputs_embeds[mask_compression_token] = compression_tokens_expanded.reshape(
254
+ -1, self.hidden_size
255
+ ).to(inputs_embeds.dtype)
256
+ attention_mask.masked_fill_(mask_compression_token, 1)
257
+ # ------------------------------------------------------------------
258
+
259
+ if position_ids is None:
260
+ attention_mask_tensor = (
261
+ attention_mask
262
+ if not isinstance(attention_mask, dict)
263
+ else attention_mask["full_attention"]
264
+ )
265
+ if attention_mask_tensor is not None and attention_mask_tensor.ndim == 4:
266
+ attention_mask_tensor = torch.diagonal(
267
+ attention_mask_tensor[:, 0], dim1=1, dim2=2
268
+ )
269
+ # Only apply conversion for floating point tensors (inverted masks)
270
+ if attention_mask_tensor.dtype.is_floating_point:
271
+ attention_mask_tensor = (
272
+ attention_mask_tensor
273
+ / torch.finfo(attention_mask_tensor.dtype).min
274
+ )
275
+ attention_mask_tensor = (1.0 - attention_mask_tensor).int()
276
+
277
+ # Calculate RoPE index once per generation in the pre-fill stage only.
278
+ # When compiling, we can't check tensor values thus we check only input length
279
+ # It is safe to assume that `length!=1` means we're in pre-fill because compiled
280
+ # models currently cannot do assisted decoding
281
+ prefill_compiled_stage = is_torchdynamo_compiling() and (
282
+ (input_ids is not None and input_ids.shape[1] != 1)
283
+ or (inputs_embeds is not None and inputs_embeds.shape[1] != 1)
284
+ )
285
+ prefill_noncompiled_stage = not is_torchdynamo_compiling() and (
286
+ (cache_position is not None and cache_position[0] == 0)
287
+ or (past_key_values is None or past_key_values.get_seq_length() == 0)
288
+ )
289
+ if (
290
+ prefill_compiled_stage or prefill_noncompiled_stage
291
+ ) or getattr(self, "rope_deltas", None) is None: # avoid AttributeError before the first prefill
292
+ position_ids, rope_deltas = self.vlm.model.get_rope_index(
293
+ input_ids,
294
+ image_grid_thw,
295
+ video_grid_thw,
296
+ attention_mask=attention_mask_tensor,
297
+ )
298
+ self.rope_deltas = rope_deltas
299
+ # then use the prev pre-calculated rope-deltas to get the correct position ids
300
+ else:
301
+ batch_size, seq_length, _ = inputs_embeds.shape
302
+ delta = (
303
+ (cache_position[0] + self.rope_deltas).to(inputs_embeds.device)
304
+ if cache_position is not None
305
+ else 0
306
+ )
307
+ position_ids = torch.arange(seq_length, device=inputs_embeds.device)
308
+ position_ids = position_ids.view(1, -1).expand(batch_size, -1)
309
+ if cache_position is not None: # otherwise `deltas` is an int `0`
310
+ delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=0)
311
+ position_ids = position_ids.add(delta)
312
+ position_ids = position_ids.unsqueeze(0).expand(3, -1, -1)
313
+
314
+ outputs = self.vlm.model.language_model(
315
+ input_ids=None,
316
+ position_ids=position_ids,
317
+ attention_mask=attention_mask,
318
+ past_key_values=past_key_values,
319
+ inputs_embeds=inputs_embeds,
320
+ cache_position=cache_position,
321
+ visual_pos_masks=visual_pos_masks,
322
+ deepstack_visual_embeds=deepstack_visual_embeds,
323
+ **kwargs,
324
+ )
325
+
326
+ last_hidden_state = outputs.last_hidden_state
327
+ compression_features_flat = last_hidden_state[mask_compression_token]
328
+ compression_features = compression_features_flat.reshape(
329
+ batch_size, used_compression_tokens, -1
330
+ )
331
+
332
+ relevance_scores = None
333
+ if self.dynamic_compress:
334
+ relevance_scores = self.compute_relevance(mask_compression_token, batch_size, last_hidden_state)
335
+
336
+ return compression_features, relevance_scores
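A self-contained sketch of the relevance router used for dynamic compression (hidden size, vocabulary size and token ids below are placeholders): a segment's score is the sigmoid of the hidden state just before the compression tokens, projected onto the difference between the "Yes" and "No" rows of the LM head.

import torch
import torch.nn as nn

hidden, vocab = 8, 32
lm_head = nn.Linear(hidden, vocab, bias=False)
yes_id, no_id = 3, 7                          # placeholder token ids

router = nn.Linear(hidden, 1, bias=False)
with torch.no_grad():
    router.weight[0] = lm_head.weight[yes_id] - lm_head.weight[no_id]

prev_token_features = torch.randn(2, hidden)  # one feature vector per segment
relevance = torch.sigmoid(router(prev_token_features)).squeeze(-1)
print(relevance)                              # values in (0, 1), one per segment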
tempo/multimodal_encoder/siglip_encoder.py ADDED
@@ -0,0 +1,75 @@
1
+ import torch
2
+ import torch.nn.functional as F
3
+ from transformers import SiglipImageProcessor, SiglipVisionModel
4
+
5
+ from .base_encoder import BaseVisionTower
6
+
7
+
8
+ class SiglipVisionTower(BaseVisionTower):
9
+ def __init__(self, vision_tower_name, args, delay_load=False):
10
+ super(SiglipVisionTower, self).__init__(vision_tower_name, args, delay_load)
11
+ model_path, res, interp = vision_tower_name, 384, 576
12
+ self.vision_tower_name = model_path
13
+ self._image_size = res if res is not None else 512
14
+ self._interp_size = interp
15
+ if not self.delay_load:
16
+ self.load_model()
17
+ elif self.unfreeze_mm_vision_tower:
18
+ self.load_model()
19
+ else:
20
+ self._hidden_size = 1152
21
+
22
+ def load_model(self, device_map=None):
23
+ self.vision_model = "siglip"
24
+ self.vision_tower = SiglipVisionModel.from_pretrained(self.vision_tower_name)
25
+
26
+ # self.vision_tower = clip_model.visual.trunk
27
+ self.vision_tower.output_tokens = True
28
+
29
+ self._hidden_size = self.vision_tower.config.hidden_size
30
+ self._image_size = self.vision_tower.config.image_size
31
+ self._patch_size = self.vision_tower.config.patch_size
32
+ self.image_processor = SiglipImageProcessor.from_pretrained(
33
+ self.vision_tower_name
34
+ )
35
+
36
+ self.vision_tower.requires_grad_(self.unfreeze_mm_vision_tower)
37
+ self.is_loaded = True
38
+
39
+ def interpolate(self, image_features):
40
+ if self._interp_size is None:
41
+ return image_features
42
+
43
+ b, num_tokens, dim = image_features.shape
44
+
45
+ if num_tokens != self.num_patches:
46
+ target_h = target_w = int(self._interp_size**0.5)
47
+ h = w = int(num_tokens**0.5)
48
+
49
+ image_features = image_features.view(b, h, w, dim)
50
+ image_features = image_features.permute(0, 3, 1, 2).contiguous()
51
+
52
+ image_features = F.interpolate(
53
+ image_features.to(torch.float32),
54
+ size=(target_h, target_w),
55
+ mode="bilinear",
56
+ align_corners=False,
57
+ ).to(image_features.dtype)
58
+
59
+ # Permute the dimensions back to (b, target_h, target_w, dim)
60
+ image_features = image_features.permute(0, 2, 3, 1).contiguous()
61
+
62
+ # Flatten the spatial dimensions (target_h, target_w) into a single dimension
63
+ image_features = image_features.flatten(1, 2)
64
+
65
+ return image_features
66
+
67
+ def _forward(self, images, interpolate_token=576):
68
+ with torch.set_grad_enabled(self.unfreeze_mm_vision_tower):
69
+ embeddings = self.vision_tower.vision_model.embeddings(images)
70
+ encoder_outputs = self.vision_tower.vision_model.encoder(
71
+ inputs_embeds=embeddings
72
+ )
73
+ image_features = encoder_outputs.last_hidden_state
74
+ interp_features = self.interpolate(image_features)
75
+ return interp_features
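A shape-level sketch of the interpolate() step (sizes are illustrative): the (B, N, D) patch tokens are reshaped to a square grid, bilinearly resized to the target grid, and flattened back to (B, target, D).

import torch
import torch.nn.functional as F

b, n, d, target = 1, 27 * 27, 16, 576          # 27x27 tokens -> 24x24 tokens
x = torch.randn(b, n, d)
h = w = int(n ** 0.5)
t = int(target ** 0.5)
grid = x.view(b, h, w, d).permute(0, 3, 1, 2)  # (B, D, H, W)
resized = F.interpolate(grid, size=(t, t), mode="bilinear", align_corners=False)
out = resized.permute(0, 2, 3, 1).flatten(1, 2)
print(out.shape)                               # torch.Size([1, 576, 16])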
tempo/multimodal_projector/__pycache__/builder.cpython-312.pyc ADDED
Binary file (3.18 kB). View file
 
tempo/multimodal_projector/builder.py ADDED
@@ -0,0 +1,51 @@
1
+ import re
2
+ import torch.nn as nn
3
+
4
+ class IdentityMap(nn.Module):
5
+ def __init__(self):
6
+ super().__init__()
7
+
8
+ def forward(self, x, *args, **kwargs):
9
+ return x
10
+
11
+ @property
12
+ def config(self):
13
+ return {"mm_projector_type": "identity"}
14
+
15
+
16
+ class SimpleResBlock(nn.Module):
17
+ def __init__(self, channels):
18
+ super().__init__()
19
+ self.pre_norm = nn.LayerNorm(channels)
20
+
21
+ self.proj = nn.Sequential(
22
+ nn.Linear(channels, channels),
23
+ nn.GELU(),
24
+ nn.Linear(channels, channels)
25
+ )
26
+
27
+ def forward(self, x):
28
+ x = self.pre_norm(x)
29
+ return x + self.proj(x)
30
+
31
+
32
+ def build_vision_projector(config):
33
+
34
+ projector_type = getattr(config, "mm_projector_type", "linear")
35
+
36
+ if projector_type == "linear":
37
+ return nn.Linear(config.mm_hidden_size, config.hidden_size)
38
+
39
+ mlp_gelu_match = re.match(r"^mlp(\d+)x_gelu$", projector_type)
40
+ if mlp_gelu_match:
41
+ mlp_depth = int(mlp_gelu_match.group(1))
42
+ modules = [nn.Linear(config.mm_hidden_size, config.hidden_size)]
43
+ for _ in range(1, mlp_depth):
44
+ modules.append(nn.GELU())
45
+ modules.append(nn.Linear(config.hidden_size, config.hidden_size))
46
+ return nn.Sequential(*modules)
47
+
48
+ if projector_type == "identity":
49
+ return IdentityMap()
50
+
51
+ raise ValueError(f"Unknown projector type: {projector_type}")
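A usage sketch with hypothetical sizes: the "mlp2x_gelu" projector type expands to Linear -> GELU -> Linear, mapping the concatenated vision hidden size onto the language-model hidden size.

from types import SimpleNamespace
from tempo.multimodal_projector.builder import build_vision_projector

cfg = SimpleNamespace(mm_projector_type="mlp2x_gelu", mm_hidden_size=2048, hidden_size=4096)
projector = build_vision_projector(cfg)
print(projector)  # Sequential(Linear(2048->4096), GELU(), Linear(4096->4096))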
tempo/tempo_arch.py ADDED
@@ -0,0 +1,464 @@
1
+ # Copyright 2023 Haotian Liu
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from abc import ABC, abstractmethod
16
+
17
+ import torch
18
+ import torch.nn as nn
19
+
20
+ from tempo.constants import (
21
+ DEFAULT_IM_END_TOKEN,
22
+ DEFAULT_IM_START_TOKEN,
23
+ DEFAULT_IMAGE_PATCH_TOKEN,
24
+ IGNORE_INDEX,
25
+ IMAGE_TOKEN_INDEX,
26
+ )
27
+
28
+ from tempo.multimodal_encoder.builder import build_vision_tower_aux_list
29
+ from tempo.multimodal_projector.builder import build_vision_projector
30
+ from tempo.vlm_multimodal_processor import VLMMultimodalProcessor
31
+
32
+
33
+ class TempoMetaModel:
34
+ def __init__(self, config):
35
+ super(TempoMetaModel, self).__init__(config)
36
+
37
+ if hasattr(config, "mm_vision_tower_aux_list"):
38
+ self.vision_tower_aux_list = nn.ModuleList(
39
+ build_vision_tower_aux_list(config, delay_load=True)
40
+ )
41
+ config.mm_hidden_size = sum(
42
+ [
43
+ vision_tower_aux.hidden_size for vision_tower_aux in self.vision_tower_aux_list
44
+ ]
45
+ )
46
+ self.mm_projector = build_vision_projector(config)
47
+ else:
48
+ raise NotImplementedError(
49
+ "mm_vision_tower_aux_list is not found in config. Please initialize vision modules in the subclass of TempoMetaModel."
50
+ )
51
+
52
+ def get_vision_tower_aux_list(self):
53
+ vision_tower_aux_list = getattr(self, "vision_tower_aux_list", None)
54
+ return vision_tower_aux_list
55
+
56
+ def initialize_vision_modules(self, model_args, fsdp=None):
57
+ # vision_hidden_size = model_args.vision_hidden_size
58
+ vision_tower_aux_list = model_args.vision_tower_aux_list
59
+ # vision_tower_aux_token_len_list = model_args.vision_tower_aux_token_len_list
60
+ pretrain_mm_mlp_adapter = model_args.pretrain_mm_mlp_adapter
61
+ self.config.mm_vision_tower_aux_list = vision_tower_aux_list
62
+ # self.config.mm_vision_tower_aux_token_len_list = vision_tower_aux_token_len_list
63
+
64
+ if self.get_vision_tower_aux_list() is None:
65
+ vision_tower_aux_list = build_vision_tower_aux_list(model_args)
66
+ if model_args.unfreeze_mm_vision_tower:
67
+ self.vision_tower_aux_list = nn.ModuleList(vision_tower_aux_list)
68
+ else:
69
+ self.vision_tower_aux_list = vision_tower_aux_list
70
+ else:
71
+ vision_tower_aux_list = self.vision_tower_aux_list
72
+ for vision_tower_aux in vision_tower_aux_list:
73
+ vision_tower_aux.load_model()
74
+
75
+ if model_args.unfreeze_mm_vision_tower and not isinstance(self.vision_tower_aux_list, nn.ModuleList):
76
+ self.vision_tower_aux_list = nn.ModuleList(self.vision_tower_aux_list)
77
+
78
+ self.config.mm_projector_type = getattr(model_args, "mm_projector_type", "linear")
79
+ # self.config.vision_hidden_size = vision_hidden_size
80
+
81
+ if getattr(self, "mm_projector", None) is None:
82
+ self.config.mm_hidden_size = sum(
83
+ [
84
+ vision_tower_aux.hidden_size for vision_tower_aux in vision_tower_aux_list
85
+ ]
86
+ )
87
+ self.mm_projector = build_vision_projector(self.config)
88
+ else:
89
+ for p in self.mm_projector.parameters():
90
+ p.requires_grad = True
91
+
92
+ if pretrain_mm_mlp_adapter is not None:
93
+ mm_projector_weights = torch.load(pretrain_mm_mlp_adapter, map_location="cpu")
94
+
95
+ def get_w(weights, keyword):
96
+ return {
97
+ k.split(keyword + ".")[1]: v
98
+ for k, v in weights.items()
99
+ if keyword + "." in k
100
+ }
101
+
102
+ self.mm_projector.load_state_dict(
103
+ get_w(mm_projector_weights, "mm_projector"), strict=True
104
+ )
105
+
106
+
107
+ class TempoMetaForCausalLM(ABC):
108
+ @abstractmethod
109
+ def get_model(self):
110
+ pass
111
+
112
+ def get_vision_tower_aux_list(self):
113
+ return self.get_model().get_vision_tower_aux_list()
114
+
115
+ def prepare_inputs_labels_for_multimodal(
116
+ self,
117
+ input_ids,
118
+ position_ids,
119
+ attention_mask,
120
+ past_key_values,
121
+ labels,
122
+ images=None,
123
+ image_sizes=None,
124
+ vlm_inputs=None,
125
+ seg_timestamps=None,
126
+ batch_split_size=None,
127
+ relevance=None,
128
+ ):
129
+ if input_ids.shape[1] == 1: # inference
130
+ return (
131
+ input_ids,
132
+ position_ids,
133
+ attention_mask,
134
+ past_key_values,
135
+ None,
136
+ labels,
137
+ )
138
+
139
+ is_video = "pixel_values_videos" in vlm_inputs
140
+
141
+ compressed_features, relevance_scores = VLMMultimodalProcessor.tokenize_vision_inputs(self.get_vision_tower_aux_list()[0], vlm_inputs, is_video)
142
+ compressed_features, count_allocations = (
143
+ VLMMultimodalProcessor.adaptive_linear_budget_allocation(
144
+ compressed_features,
145
+ relevance_scores,
146
+ is_video,
147
+ max_budget=self.config.visual_token_budget if hasattr(self.config, "visual_token_budget") else 8192,
148
+ min_tokens=4,
149
+ strategy="head",
150
+ )
151
+ )
152
+
153
+ # kept for visualization of the allocation results; can be removed
154
+ self._demo_count_allocations = count_allocations
155
+
156
+ if isinstance(compressed_features, list):
157
+ seg_lens = [feat.shape[0] for feat in compressed_features]
158
+ compressed_features = torch.cat(compressed_features, dim=0)
159
+ image_features = self.get_model().mm_projector(compressed_features)
160
+ print(f"[Total Segments: {len(seg_lens)}], compressed features after dynamic compression:", image_features.shape)
161
+ image_features = list(torch.split(image_features, seg_lens, dim=0))
162
+ else:
163
+ image_features = self.get_model().mm_projector(compressed_features)
164
+ print("final compression features for the whole batch:", image_features.shape)
165
+
166
+ if is_video:
167
+ # add timestamp embeddings for video inputs
168
+ if batch_split_size is not None and len(batch_split_size) > 1: # batch training
169
+ print(f"Number of segments for each video is: {batch_split_size}")
170
+ start_idx = 0
171
+ final_image_features_list = []
172
+ for b_split_size in batch_split_size:
173
+ current_image_features = image_features[start_idx : start_idx + b_split_size]
174
+ current_seg_timestamp = seg_timestamps[start_idx : start_idx + b_split_size]
175
+ final_image_features_list.append(
176
+ VLMMultimodalProcessor.add_seg_timestamp(
177
+ current_image_features,
178
+ self.get_model(),
179
+ current_seg_timestamp,
180
+ is_video,
181
+ )
182
+ )
183
+ start_idx += b_split_size
184
+ else:
185
+ final_image_features_list = [
186
+ VLMMultimodalProcessor.add_seg_timestamp(
187
+ image_features, self.get_model(), seg_timestamps, is_video
188
+ )
189
+ ]
190
+ else:
191
+ final_image_features_list = [img for img in image_features]
192
+
193
+ _labels = labels
194
+ _position_ids = position_ids
195
+ _attention_mask = attention_mask
196
+
197
+ if attention_mask is None:
198
+ attention_mask = torch.ones_like(input_ids, dtype=torch.bool)
199
+ else:
200
+ attention_mask = attention_mask.bool()
201
+
202
+ if position_ids is None:
203
+ position_ids = torch.arange(
204
+ 0, input_ids.shape[1], dtype=torch.long, device=input_ids.device
205
+ )
206
+
207
+ if labels is None:
208
+ labels = torch.full_like(input_ids, IGNORE_INDEX)
209
+
210
+ attention_mask = attention_mask | (input_ids == IMAGE_TOKEN_INDEX)
211
+
212
+ input_ids = [
213
+ cur_input_ids[cur_attention_mask]
214
+ for cur_input_ids, cur_attention_mask in zip(input_ids, attention_mask)
215
+ ]
216
+ labels = [
217
+ cur_labels[cur_attention_mask]
218
+ for cur_labels, cur_attention_mask in zip(labels, attention_mask)
219
+ ]
220
+
221
+ new_input_embeds = []
222
+ new_labels = []
223
+ cur_image_idx = 0
224
+
225
+ for batch_idx, cur_input_ids in enumerate(input_ids):
226
+ num_images = (cur_input_ids == IMAGE_TOKEN_INDEX).sum()
227
+
228
+ if num_images == 0:
229
+ cur_image_features = final_image_features_list[cur_image_idx]
230
+ cur_input_embeds_1 = self.get_model().embed_tokens(cur_input_ids)
231
+ cur_input_embeds = torch.cat(
232
+ [cur_input_embeds_1, cur_image_features[0:0]], dim=0
233
+ )
234
+ new_input_embeds.append(cur_input_embeds)
235
+ new_labels.append(labels[batch_idx])
236
+ cur_image_idx += 1
237
+ continue
238
+
239
+ image_token_indices = (
240
+ [-1]
241
+ + torch.where(cur_input_ids == IMAGE_TOKEN_INDEX)[0].tolist()
242
+ + [cur_input_ids.shape[0]]
243
+ )
244
+
245
+ cur_input_ids_noim = []
246
+ cur_labels = labels[batch_idx]
247
+ cur_labels_noim = []
248
+
249
+ for i in range(len(image_token_indices) - 1):
250
+ cur_input_ids_noim.append(
251
+ cur_input_ids[
252
+ image_token_indices[i] + 1 : image_token_indices[i + 1]
253
+ ]
254
+ )
255
+ cur_labels_noim.append(
256
+ cur_labels[image_token_indices[i] + 1 : image_token_indices[i + 1]]
257
+ )
258
+
259
+ split_sizes_text = [x.shape[0] for x in cur_labels_noim]
260
+ cur_input_embeds = self.get_model().embed_tokens(
261
+ torch.cat(cur_input_ids_noim)
262
+ )
263
+ cur_input_embeds_no_im = torch.split(
264
+ cur_input_embeds, split_sizes_text, dim=0
265
+ )
266
+
267
+ # known issue: multi-image inputs are not handled correctly here
268
+ cur_new_input_embeds = []
269
+ cur_new_labels = []
270
+ text_len = sum([x.shape[0] for x in cur_input_embeds_no_im])
271
+ visual_len = len(final_image_features_list[cur_image_idx])
272
+ max_visual_len = (
273
+ self.get_model().config.tokenizer_model_max_length
274
+ - getattr(self.get_model().config, "inference_max_length", 16)
275
+ - text_len
276
+ )
277
+
278
+ if max_visual_len < visual_len:
279
+ final_image_features_list[cur_image_idx] = final_image_features_list[cur_image_idx][:max_visual_len]
280
+
281
+ for i in range(num_images + 1):
282
+ cur_new_input_embeds.append(cur_input_embeds_no_im[i])
283
+ cur_new_labels.append(cur_labels_noim[i])
284
+
285
+ if i < num_images:
286
+ try:
287
+ cur_image_features = final_image_features_list[cur_image_idx]
288
+ except IndexError:
289
+ print(f"cur_image_idx={cur_image_idx} is out of range for {num_images} image(s); reusing the previous features")
290
+ cur_image_features = final_image_features_list[cur_image_idx - 1]
291
+
292
+ cur_image_idx += 1
293
+ cur_new_input_embeds.append(cur_image_features)
294
+ cur_new_labels.append(
295
+ torch.full(
296
+ (cur_image_features.shape[0],),
297
+ IGNORE_INDEX,
298
+ device=cur_labels.device,
299
+ dtype=cur_labels.dtype,
300
+ )
301
+ )
302
+
303
+ cur_new_input_embeds = [x.to(self.device) for x in cur_new_input_embeds]
304
+
305
+ cur_new_input_embeds = torch.cat(cur_new_input_embeds)
306
+ cur_new_labels = torch.cat(cur_new_labels)
307
+ new_input_embeds.append(cur_new_input_embeds)
308
+ new_labels.append(cur_new_labels)
309
+
310
+ tokenizer_model_max_length = getattr(self.config, "tokenizer_model_max_length", None)
311
+ if tokenizer_model_max_length is not None:
312
+ new_input_embeds = [x[:tokenizer_model_max_length] for x in new_input_embeds]
313
+ new_labels = [x[:tokenizer_model_max_length] for x in new_labels]
314
+
315
+ max_len = max(x.shape[0] for x in new_input_embeds)
316
+ batch_size = len(new_input_embeds)
317
+
318
+ new_input_embeds_padded = []
319
+ new_labels_padded = torch.full(
320
+ (batch_size, max_len),
321
+ IGNORE_INDEX,
322
+ dtype=new_labels[0].dtype,
323
+ device=new_labels[0].device,
324
+ )
325
+ attention_mask = torch.zeros(
326
+ (batch_size, max_len),
327
+ dtype=attention_mask.dtype,
328
+ device=attention_mask.device,
329
+ )
330
+ position_ids = torch.zeros(
331
+ (batch_size, max_len),
332
+ dtype=position_ids.dtype,
333
+ device=position_ids.device,
334
+ )
335
+
336
+ for i, (cur_new_embed, cur_new_labels) in enumerate(
337
+ zip(new_input_embeds, new_labels)
338
+ ):
339
+ cur_len = cur_new_embed.shape[0]
340
+
341
+ if getattr(self.config, "tokenizer_padding_side", "right") == "left":
342
+ new_input_embeds_padded.append(
343
+ torch.cat(
344
+ (
345
+ torch.zeros(
346
+ (max_len - cur_len, cur_new_embed.shape[1]),
347
+ dtype=cur_new_embed.dtype,
348
+ device=cur_new_embed.device,
349
+ ),
350
+ cur_new_embed,
351
+ ),
352
+ dim=0,
353
+ )
354
+ )
355
+ if cur_len > 0:
356
+ new_labels_padded[i, -cur_len:] = cur_new_labels
357
+ attention_mask[i, -cur_len:] = True
358
+ position_ids[i, -cur_len:] = torch.arange(
359
+ 0,
360
+ cur_len,
361
+ dtype=position_ids.dtype,
362
+ device=position_ids.device,
363
+ )
364
+ else:
365
+ new_input_embeds_padded.append(
366
+ torch.cat(
367
+ (
368
+ cur_new_embed,
369
+ torch.zeros(
370
+ (max_len - cur_len, cur_new_embed.shape[1]),
371
+ dtype=cur_new_embed.dtype,
372
+ device=cur_new_embed.device,
373
+ ),
374
+ ),
375
+ dim=0,
376
+ )
377
+ )
378
+ if cur_len > 0:
379
+ new_labels_padded[i, :cur_len] = cur_new_labels
380
+ attention_mask[i, :cur_len] = True
381
+ position_ids[i, :cur_len] = torch.arange(
382
+ 0,
383
+ cur_len,
384
+ dtype=position_ids.dtype,
385
+ device=position_ids.device,
386
+ )
387
+
388
+ new_input_embeds = torch.stack(new_input_embeds_padded, dim=0)
389
+
390
+ if _labels is None:
391
+ new_labels = None
392
+ else:
393
+ new_labels = new_labels_padded
394
+
395
+ if _attention_mask is None:
396
+ attention_mask = None
397
+ else:
398
+ attention_mask = attention_mask.to(dtype=_attention_mask.dtype)
399
+
400
+ if _position_ids is None:
401
+ position_ids = None
402
+
403
+ return (
404
+ None,
405
+ position_ids,
406
+ attention_mask,
407
+ past_key_values,
408
+ new_input_embeds,
409
+ new_labels,
410
+ )
411
+
412
+ def initialize_vision_tokenizer(self, model_args, tokenizer):
413
+ if model_args.mm_use_im_patch_token:
414
+ tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
415
+ self.resize_token_embeddings(len(tokenizer))
416
+
417
+ if model_args.mm_use_im_start_end:
418
+ num_new_tokens = tokenizer.add_tokens(
419
+ [DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True
420
+ )
421
+ self.resize_token_embeddings(len(tokenizer))
422
+
423
+ if num_new_tokens > 0:
424
+ input_embeddings = self.get_input_embeddings().weight.data
425
+ output_embeddings = self.get_output_embeddings().weight.data
426
+
427
+ input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(
428
+ dim=0, keepdim=True
429
+ )
430
+ output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(
431
+ dim=0, keepdim=True
432
+ )
433
+
434
+ input_embeddings[-num_new_tokens:] = input_embeddings_avg
435
+ output_embeddings[-num_new_tokens:] = output_embeddings_avg
436
+
437
+ if model_args.tune_mm_mlp_adapter:
438
+ for p in self.get_input_embeddings().parameters():
439
+ p.requires_grad = True
440
+ for p in self.get_output_embeddings().parameters():
441
+ p.requires_grad = False
442
+
443
+ if model_args.pretrain_mm_mlp_adapter:
444
+ mm_projector_weights = torch.load(
445
+ model_args.pretrain_mm_mlp_adapter, map_location="cpu"
446
+ )
447
+ embed_tokens_weight = mm_projector_weights["model.embed_tokens.weight"]
448
+ assert num_new_tokens == 2
449
+ if input_embeddings.shape == embed_tokens_weight.shape:
450
+ input_embeddings[-num_new_tokens:] = embed_tokens_weight[
451
+ -num_new_tokens:
452
+ ]
453
+ elif embed_tokens_weight.shape[0] == num_new_tokens:
454
+ input_embeddings[-num_new_tokens:] = embed_tokens_weight
455
+ else:
456
+ raise ValueError(
457
+ f"Unexpected embed_tokens_weight shape. Pretrained: {embed_tokens_weight.shape}. Current: {input_embeddings.shape}. Number of new tokens: {num_new_tokens}."
458
+ )
459
+ elif model_args.mm_use_im_patch_token:
460
+ if model_args.tune_mm_mlp_adapter:
461
+ for p in self.get_input_embeddings().parameters():
462
+ p.requires_grad = False
463
+ for p in self.get_output_embeddings().parameters():
464
+ p.requires_grad = False
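A toy sketch of the splicing done in prepare_inputs_labels_for_multimodal (constants and shapes below are placeholders; see tempo/constants.py for the real values): projected image features replace the IMAGE_TOKEN_INDEX placeholder, and the corresponding label positions are filled with IGNORE_INDEX so they carry no loss.

import torch

IMAGE_TOKEN_INDEX, IGNORE_INDEX = -200, -100   # placeholder values
d = 4
ids = torch.tensor([5, 6, IMAGE_TOKEN_INDEX, 7])
text_embed = torch.randn(ids.shape[0], d)      # stand-in for embed_tokens(ids)
image_feats = torch.randn(3, d)                # stand-in for mm_projector output

pos = (ids == IMAGE_TOKEN_INDEX).nonzero()[0].item()
embeds = torch.cat([text_embed[:pos], image_feats, text_embed[pos + 1 :]], dim=0)
labels = torch.cat([ids[:pos], torch.full((3,), IGNORE_INDEX), ids[pos + 1 :]])
print(embeds.shape, labels)                    # torch.Size([6, 4]); image span ignored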
tempo/vlm_multimodal_processor.py ADDED
@@ -0,0 +1,332 @@
1
+ import random
2
+ import numpy as np
3
+ import torch
4
+
5
+ class VLMMultimodalProcessor:
6
+ """SVLM-based vision compression."""
7
+
8
+ @staticmethod
9
+ def tokenize_vision_inputs(vision_language_model, vlm_inputs, is_video, chunk_size=4):
10
+ return vision_language_model(**vlm_inputs)
11
+
12
+ @staticmethod
13
+ def add_seg_timestamp(vision_features, model, seg_timestamps, is_video):
14
+ if not is_video:
15
+ return vision_features
16
+
17
+ device = vision_features[0].device
18
+ dtype = vision_features[0].dtype
19
+
20
+ max_len = max(len(ts) for ts in seg_timestamps)
21
+ num_segments = len(seg_timestamps)
22
+ # pad_token_id = getattr(model.config, "pad_token_id", 151643)
23
+ pad_token_id = 151643
24
+ timestamp_ids_tensor = torch.full(
25
+ (num_segments, max_len), pad_token_id, dtype=torch.long, device=device
26
+ )
27
+
28
+ for i, ts in enumerate(seg_timestamps):
29
+ length = len(ts)
30
+ timestamp_ids_tensor[i, :length] = torch.tensor(ts, device=device)
31
+
32
+ timestamp_embeds = model.get_input_embeddings()(timestamp_ids_tensor).to(dtype)
33
+
34
+ final_vision_features = []
35
+ for i in range(num_segments):
36
+ if vision_features[i].shape[0] == 0:
37
+ print("Empty segment features; dropping this segment.")
38
+ continue
39
+ final_vision_features.append(
40
+ torch.cat(
41
+ [
42
+ timestamp_embeds[i][: len(seg_timestamps[i])],
43
+ vision_features[i],
44
+ ],
45
+ dim=0,
46
+ )
47
+ )
48
+
49
+ # return torch.cat(final_vision_features, dim=0).unsqueeze(0)
50
+ return torch.cat(final_vision_features, dim=0) # (comp_frame1+comp_frame2+comp_frame3+..., d)
51
+
52
+ @staticmethod
53
+ def tome_merge(x: torch.Tensor, target_num: int) -> torch.Tensor:
54
+ """
55
+ Token Merging using bipartite soft matching.
56
+ Reference: "Token Merging: Your ViT But Faster" (Bolya et al.)
57
+ Args:
58
+ x: (n, d) tensor of token features
59
+ target_num: number of tokens to keep after merging
60
+ Returns:
61
+ merged tokens: (target_num, d) tensor
62
+ """
63
+ if target_num <= 0:
64
+ raise ValueError("target_num must be positive")
65
+
66
+ n, d = x.shape
67
+
68
+ if target_num >= n:
69
+ return x
70
+
71
+ while x.shape[0] > target_num:
72
+ current_n = x.shape[0]
73
+
74
+ if current_n < 2:
75
+ break
76
+
77
+ t1 = (current_n + 1) // 2 # ceil(n/2) - source token
78
+ t2 = current_n // 2 # floor(n/2) - target token
79
+
80
+ if t2 == 0:
81
+ break
82
+
83
+ tokens_to_remove = current_n - target_num
84
+ r = min(tokens_to_remove, t1)
85
+
86
+ if r <= 0:
87
+ break
88
+
89
+ x_batch = x.unsqueeze(0) # (1, n, d)
90
+ k = x_batch / x_batch.norm(dim=-1, keepdim=True)
91
+ a, b = k[..., ::2, :], k[..., 1::2, :]
92
+ scores = a @ b.transpose(-1, -2)
93
+ node_max, node_idx = scores.max(dim=-1)
94
+ edge_idx = node_max.argsort(dim=-1, descending=True)[..., None]
95
+ unm_idx = edge_idx[..., r:, :]
96
+ src_idx = edge_idx[..., :r, :]
97
+ dst_idx = node_idx[..., None].gather(dim=-2, index=src_idx)
98
+ unm_idx = unm_idx.sort(dim=-2)[0]
99
+ # merge
100
+ src, dst = x_batch[..., ::2, :], x_batch[..., 1::2, :]
101
+ batch, _, c = src.shape
102
+ unm = src.gather(dim=-2, index=unm_idx.expand(batch, t1 - r, c))
103
+ src_to_merge = src.gather(dim=-2, index=src_idx.expand(batch, r, c))
104
+ dst = dst.scatter_add(-2, dst_idx.expand(batch, r, c), src_to_merge)
105
+ # pooling
106
+ ones = torch.ones(batch, r, 1, device=x.device, dtype=x.dtype)
107
+ dst_counts = torch.ones(
108
+ batch, dst.shape[1], 1, device=x.device, dtype=x.dtype
109
+ )
110
+ dst_counts = dst_counts.scatter_add(-2, dst_idx.expand(batch, r, 1), ones)
111
+ dst = dst / dst_counts
112
+ x = torch.cat([unm, dst], dim=-2).squeeze(0) # (new_n, d)
113
+ return x
114
+
115
+ @staticmethod
116
+ def topk_compress(
117
+ vision_features,
118
+ relevance_scores,
119
+ is_video,
120
+ k=0,
121
+ drop_ratio=0.5,
122
+ strategy="topk",
123
+ ):
124
+ """
125
+ Drop or truncate low-scoring segments, keeping only the first k tokens, based on relevance scores.
126
+ Args:
127
+ vision_features: (n_segment, n, d)
128
+ relevance_scores: n_segment (list of scores between 0 and 1)
129
+ k: number of tokens to keep, k=0 means drop directly
130
+ drop_ratio: ratio of segments to drop/truncate
131
+ strategy: "topk" (keep highest), "lastk" (keep lowest), "random" (random drop)
132
+ Return a list of tensor with length of n_segment. Each tensor is (k, d) or (num_compression_token, d)
133
+ """
134
+
135
+ if not is_video or relevance_scores is None:
136
+ return vision_features
137
+
138
+ n_segment = vision_features.shape[0]
139
+ if n_segment <= 1:
140
+ print("video has at most one segment; skipping compression")
141
+ return vision_features
142
+
143
+ if strategy == "topk":
144
+ # in ascending order
145
+ sorted_indices = sorted(range(n_segment), key=lambda i: relevance_scores[i])
146
+ elif strategy == "lastk":
147
+ # in descending order
148
+ sorted_indices = sorted(
149
+ range(n_segment), key=lambda i: relevance_scores[i], reverse=True
150
+ )
151
+ elif strategy == "random":
152
+ # shuffle index, random
153
+ sorted_indices = list(range(n_segment))
154
+ random.shuffle(sorted_indices)
155
+ else:
156
+ raise ValueError(f"Unknown strategy: {strategy}")
157
+
158
+ # truncate or drop ratio
159
+ num_to_prune = int(n_segment * drop_ratio)
160
+ print(
161
+ f"[topk_compress] total segment: {n_segment}, drop/truncate segment: {num_to_prune}, keep first {k} token"
162
+ )
163
+
164
+ low_score_indices = set(sorted_indices[:num_to_prune])
165
+
166
+ result = []
167
+ for i in range(n_segment):
168
+ if i in low_score_indices:
169
+ result.append(vision_features[i, :k, :]) # shape: (k, d)
170
+ else:
171
+ result.append(vision_features[i]) # shape: (n, d)
172
+
173
+ print(
174
+ f"[topk_compress] segments: {n_segment} (pruned={num_to_prune}, kept={n_segment - num_to_prune}), "
175
+ f"kept tokens for pruned segment={k}"
176
+ )
177
+
178
+ return result
179
+
180
+ @staticmethod
181
+ def adaptive_linear_budget_allocation(
182
+ vision_features,
183
+ relevance_scores,
184
+ is_video,
185
+ max_budget=8192,
186
+ min_tokens=4,
187
+ max_tokens=None,
188
+ strategy="head",
189
+ ):
190
+ """
191
+ Soft token allocation based on min-max normalized relevance scores. Uses a linear mapping instead of softmax, which yields more aggressive sparsity.
192
+ Args:
193
+ vision_features: (n_segment, n_tokens, d)
194
+ relevance_scores: list/array of scores
195
+ is_video: bool
196
+ max_budget: the largest budget for a video
197
+ min_tokens: minimum number of tokens to allocate to each segment
198
+ max_tokens: maximum number of tokens to allocate to each segment
199
+ strategy: "head", "tail", "random", "tome"
200
+ Return a list of tensor with length of n_segment. Each tensor is (k, d), where k in [min_tokens, max_tokens]
201
+ """
202
+ if not is_video or relevance_scores is None:
203
+ return vision_features, None
204
+
205
+ n_segments = vision_features.shape[0]
206
+ n_tokens_per_segment = vision_features.shape[1]
207
+ max_tokens = min(max_tokens or n_tokens_per_segment, n_tokens_per_segment)
208
+
209
+ base_budget = n_segments * min_tokens
210
+ if base_budget > max_budget:
211
+ actual_min = max(1, max_budget // n_segments)
212
+ print(
213
+ f"[adaptive_linear_budget_allocation] Warning: budget insufficient, "
214
+ f"min_tokens: {min_tokens} -> {actual_min}"
215
+ )
216
+ min_tokens = actual_min
217
+ base_budget = n_segments * min_tokens
218
+
219
+ # Convert to tensor
220
+ scores = torch.tensor(relevance_scores, dtype=torch.float32)
221
+ score_min = scores.min()
222
+ score_max = scores.max()
223
+ score_range = score_max - score_min
224
+
225
+ if score_range < 1e-8:
226
+ k = min(max_tokens, max_budget // n_segments)
227
+ return [vision_features[i, :k, :] for i in range(n_segments)], torch.full(
228
+ (n_segments,), k, dtype=torch.long
229
+ )
230
+
231
+ # Normalize scores to [0, 1] range
232
+ normalized_scores = (scores - score_min) / score_range
233
+
234
+ # Linear mapping: [0, 1] -> [min_tokens, max_tokens]
235
+ token_range = max_tokens - min_tokens
236
+ ideal_allocations = (
237
+ min_tokens + (normalized_scores * token_range).floor().long()
238
+ )
239
+
240
+ # ========== Budget Protection ==========
241
+ total_desired = ideal_allocations.sum().item()
242
+
243
+ if total_desired > max_budget:
244
+ # Only scale down when total demand exceeds budget
245
+ print(
246
+ f"[adaptive_linear_budget_allocation] Warning: Desired budget "
247
+ f"({total_desired}) exceeds max ({max_budget}). Scaling down..."
248
+ )
249
+
250
+ extra_budget = max_budget - base_budget
251
+
252
+ # Prevent division by zero
253
+ score_sum = normalized_scores.sum().item()
254
+ if score_sum < 1e-8:
255
+ # Fallback to uniform distribution if all normalized scores are ~0
256
+ weights = torch.ones_like(normalized_scores) / n_segments
257
+ else:
258
+ weights = normalized_scores / score_sum
259
+
260
+ # Distribute extra budget proportionally
261
+ extra_allocation = (weights * extra_budget).floor().long()
262
+ remainder = int(extra_budget - extra_allocation.sum().item())
263
+ if remainder > 0:
264
+ top_indices = torch.argsort(weights, descending=True)[:remainder]
265
+ for idx in top_indices:
266
+ extra_allocation[idx] += 1
267
+
268
+ allocations = min_tokens + extra_allocation
269
+ else:
270
+ allocations = ideal_allocations
271
+
272
+ # Final clamp to ensure bounds
273
+ allocations = allocations.clamp(min=min_tokens, max=max_tokens)
274
+
275
+ # default, head truncation
276
+ if strategy == "head":
277
+ result = []
278
+ for i in range(n_segments):
279
+ k = min(allocations[i].item(), n_tokens_per_segment)
280
+ result.append(vision_features[i, :k, :])
281
+
282
+ elif strategy == "tail":
283
+ result = []
284
+ for i in range(n_segments):
285
+ k = min(allocations[i].item(), n_tokens_per_segment)
286
+ if k == 0:
287
+ result.append(vision_features[i, :0, :])
288
+ else:
289
+ result.append(vision_features[i, -k:, :])
290
+
291
+ elif strategy == "random":
292
+ result = []
293
+ for i in range(n_segments):
294
+ k = min(allocations[i].item(), n_tokens_per_segment)
295
+ if k == 0:
296
+ result.append(vision_features[i, :0, :])
297
+ else:
298
+ indices = torch.randperm(
299
+ n_tokens_per_segment, device=vision_features.device
300
+ )[:k]
301
+ indices = indices.sort().values
302
+ result.append(vision_features[i, indices, :])
303
+
304
+ elif strategy == "tome":
305
+ result = []
306
+ for i in range(n_segments):
307
+ k = min(allocations[i].item(), n_tokens_per_segment)
308
+ if k == 0:
309
+ result.append(vision_features[i, :0, :])
310
+ elif k == n_tokens_per_segment:
311
+ result.append(vision_features[i])
312
+ else:
313
+ # Token Merging
314
+ result.append(
315
+ VLMMultimodalProcessor.tome_merge(
316
+ vision_features[i], target_num=k
317
+ )
318
+ )
319
+ else:
320
+ raise ValueError(f"Unknown strategy: {strategy}")
321
+
322
+ total_used = allocations.sum().item()
323
+
324
+ print(
325
+ f"[adaptive_linear_budget_allocation] segments={n_segments}, "
326
+ f"budget_used={total_used}/{max_budget}, "
327
+ f"theoretical_range=[{min_tokens}, {max_tokens}], "
328
+ f"actual_range=[{allocations.min().item():.0f}, {allocations.max().item():.0f}]",
329
+ flush=True,
330
+ )
331
+
332
+ return result, allocations
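A worked example of the allocation arithmetic with toy numbers (not taken from a real run): relevance scores are min-max normalized and mapped linearly onto [min_tokens, max_tokens]; the result is only rescaled when the total would exceed max_budget.

import torch

scores = torch.tensor([0.2, 0.5, 0.9])
min_tokens, max_tokens = 4, 16
norm = (scores - scores.min()) / (scores.max() - scores.min())   # [0.00, 0.43, 1.00]
alloc = min_tokens + (norm * (max_tokens - min_tokens)).floor().long()
print(alloc)   # tensor([ 4,  9, 16]) -> 29 tokens in total for this toy video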