# LTX2_distill / app2.py — LTX-2 distilled text-to-video Gradio app (Hugging Face Space)
import gc
import os
import random

import gradio as gr
import torch
from diffusers import (
    AutoencoderKLLTX2Video,
    LTX2LatentUpsamplePipeline,
    LTX2LatentUpsamplerModel,
    LTX2Pipeline,
    LTX2VideoTransformer3DModel,
)
from diffusers.pipelines.ltx2.export_utils import encode_video
from diffusers.pipelines.ltx2.utils import DISTILLED_SIGMA_VALUES, STAGE_2_DISTILLED_SIGMA_VALUES
from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
from transformers import Gemma3ForConditionalGeneration

# `spaces` is provided by the Hugging Face Spaces (ZeroGPU) runtime and is
# referenced below via @spaces.GPU(); fall back to a no-op decorator locally
# so importing this module does not raise NameError outside of Spaces.
try:
    import spaces
except ImportError:
    class _SpacesFallback:
        @staticmethod
        def GPU(*_args, **_kwargs):
            return lambda fn: fn

    spaces = _SpacesFallback()
# Shared dtype/device configuration used by every pipeline below.
torch_dtype = torch.bfloat16
DEVICE = torch.device("cuda")
OFFLOAD_DEVICE = torch.device("cpu")
MODEL_PATH = "Lightricks/LTX-2"
OUTPUT_DIR = "./outputs/ltx2"
# exist_ok=True avoids the check-then-create race of `os.path.exists` + `makedirs`.
os.makedirs(OUTPUT_DIR, exist_ok=True)
# -----------------------------
# 1. Preload all models once
# -----------------------------
# All heavy weights are loaded at module import time so each request pays
# only for inference, not model construction.
print("Loading text encoder...")
text_encoder = Gemma3ForConditionalGeneration.from_pretrained(
    "OzzyGT/LTX-2-bnb-8bit-text-encoder", dtype=torch_dtype
)
print("Loading first stage transformer...")
transformer1 = LTX2VideoTransformer3DModel.from_pretrained(
    "OzzyGT/LTX-2-bnb-4bit-transformer-distilled", torch_dtype=torch_dtype, device_map="cpu"
)
first_stage_pipe = LTX2Pipeline.from_pretrained(
    MODEL_PATH, transformer=transformer1, text_encoder=text_encoder, vocoder=None, torch_dtype=torch_dtype
)
# Leaf-level group offload streams sub-modules onto the GPU on demand so the
# full transformer never has to fit in VRAM at once.
first_stage_pipe.enable_group_offload(
    onload_device=DEVICE, offload_device=OFFLOAD_DEVICE, offload_type="leaf_level", low_cpu_mem_usage=True
)
print("Loading latent upsampler and VAE...")
latent_upsampler = LTX2LatentUpsamplerModel.from_pretrained(
    "rootonchair/LTX-2-19b-distilled", subfolder="latent_upsampler", torch_dtype=torch_dtype
)
vae = AutoencoderKLLTX2Video.from_pretrained(MODEL_PATH, subfolder="vae", torch_dtype=torch_dtype)
upsample_pipe = LTX2LatentUpsamplePipeline(vae=vae, latent_upsampler=latent_upsampler)
upsample_pipe.enable_model_cpu_offload(device=DEVICE)
print("Loading second stage transformer...")
transformer2 = LTX2VideoTransformer3DModel.from_pretrained(
    "OzzyGT/LTX-2-bnb-4bit-transformer-distilled", torch_dtype=torch_dtype, device_map="cpu"
)
second_stage_pipe = LTX2Pipeline.from_pretrained(
    MODEL_PATH, transformer=transformer2, text_encoder=None, torch_dtype=torch_dtype
)
# The distilled second stage uses fixed sigmas, so dynamic shifting is disabled.
stage2_scheduler = FlowMatchEulerDiscreteScheduler.from_config(
    second_stage_pipe.scheduler.config, use_dynamic_shifting=False
)
second_stage_pipe.scheduler = stage2_scheduler
second_stage_pipe.enable_group_offload(
    onload_device=DEVICE, offload_device=OFFLOAD_DEVICE, offload_type="leaf_level", low_cpu_mem_usage=True
)
print("Loading decoder pipeline...")
# VAE-only pipeline used purely for decoding latents to pixels/audio.
decode_pipe = LTX2Pipeline.from_pretrained(
    MODEL_PATH, text_encoder=None, transformer=None, scheduler=None, connectors=None, torch_dtype=torch_dtype
)
# NOTE(review): the original eagerly called decode_pipe.to(DEVICE) here and
# then enable_model_cpu_offload() below. CPU offload manages device placement
# itself, so the up-front full-pipeline move to GPU only wasted VRAM at
# startup and has been dropped.
decode_pipe.vae.enable_tiling(
    tile_sample_min_height=256,
    tile_sample_min_width=256,
    tile_sample_min_num_frames=16,
    tile_sample_stride_height=192,
    tile_sample_stride_width=192,
    tile_sample_stride_num_frames=8,
)
# Frame-wise encode/decode bounds peak memory when processing 481 frames.
decode_pipe.vae.use_framewise_encoding = True
decode_pipe.vae.use_framewise_decoding = True
decode_pipe.enable_model_cpu_offload(device=DEVICE)
print("All models loaded ✅")
# -----------------------------
# 2. Video generation function
# -----------------------------
@spaces.GPU()
def generate_video(prompt: str, negative_prompt: str, seed: int = -1):
    """Run the full LTX-2 distilled text-to-video pipeline.

    Stages: prompt encoding -> 768x512 first-stage latents -> latent
    upsampling -> 1536x1024 second-stage refinement -> VAE/vocoder decode
    and mp4 muxing.

    Args:
        prompt: Text description of the video to generate.
        negative_prompt: Concepts to steer generation away from.
        seed: RNG seed; -1 picks a random seed.

    Returns:
        Tuple of (path to the encoded .mp4, seed actually used).
    """
    if seed == -1:
        seed = random.randint(0, 999999)
    generator = torch.Generator(DEVICE).manual_seed(seed)
    # Encode the prompt once; the embeddings are reused by both stages.
    with torch.inference_mode():
        prompt_embeds, prompt_attention_mask, _, _ = first_stage_pipe.encode_prompt(
            prompt, negative_prompt, do_classifier_free_guidance=False
        )
    prompt_embeds = prompt_embeds.to(DEVICE)
    prompt_attention_mask = prompt_attention_mask.to(DEVICE)
    # First stage: low-resolution generation in latent space.
    video_latent, audio_latent = first_stage_pipe(
        prompt_embeds=prompt_embeds,
        prompt_attention_mask=prompt_attention_mask,
        width=768,
        height=512,
        num_frames=481,
        frame_rate=24.0,
        num_inference_steps=8,
        sigmas=DISTILLED_SIGMA_VALUES,
        guidance_scale=1.0,
        generator=generator,
        output_type="latent",
        return_dict=False,
    )
    # Park latents on the CPU and reclaim VRAM before the next stage.
    video_latent = video_latent.detach().cpu()
    audio_latent = audio_latent.detach().cpu()
    gc.collect()
    torch.cuda.empty_cache()
    # Upsample the video latents 2x in latent space.
    upscaled_video_latent = upsample_pipe(latents=video_latent, output_type="latent", return_dict=False)[0].detach().cpu()
    del video_latent
    gc.collect()
    torch.cuda.empty_cache()
    # Second stage: refine the upscaled latents at 1536x1024.
    upscaled_video_latent = upscaled_video_latent.to(DEVICE)
    audio_latent = audio_latent.to(DEVICE)
    prompt_embeds = prompt_embeds.to(DEVICE)
    prompt_attention_mask = prompt_attention_mask.to(DEVICE)
    video_latent_stage2, audio_latent_stage2 = second_stage_pipe(
        latents=upscaled_video_latent,
        audio_latents=audio_latent,
        prompt_embeds=prompt_embeds,
        prompt_attention_mask=prompt_attention_mask,
        width=768 * 2,
        height=512 * 2,
        num_frames=481,
        num_inference_steps=3,
        noise_scale=STAGE_2_DISTILLED_SIGMA_VALUES[0],
        sigmas=STAGE_2_DISTILLED_SIGMA_VALUES,
        generator=generator,
        guidance_scale=1.0,
        output_type="latent",
        return_dict=False,
    )
    del upscaled_video_latent, audio_latent
    gc.collect()
    torch.cuda.empty_cache()
    # Decode final video and audio.
    with torch.inference_mode():
        video_tensor = decode_pipe.vae.decode(
            video_latent_stage2.to(DEVICE, dtype=decode_pipe.vae.dtype), None, return_dict=False
        )[0]
        video_np = decode_pipe.video_processor.postprocess_video(video_tensor, output_type="np")
        audio_tensor = decode_pipe.audio_vae.decode(audio_latent_stage2.to(DEVICE, dtype=decode_pipe.audio_vae.dtype), return_dict=False)[0]
        audio_out = decode_pipe.vocoder(audio_tensor)
    # Convert floats in [0, 1] to uint8 frames and mux video+audio into an mp4.
    video_np = (video_np * 255).round().astype("uint8")
    video_tensor = torch.from_numpy(video_np)
    output_path = os.path.join(OUTPUT_DIR, f"t2v_{seed}.mp4")
    encode_video(video_tensor[0], fps=24.0, audio=audio_out[0].float().cpu(),
                 audio_sample_rate=decode_pipe.vocoder.config.output_sampling_rate,
                 output_path=output_path)
    return output_path, seed
# -----------------------------
# 3. Gradio UI
# -----------------------------
with gr.Blocks() as demo:
    gr.Markdown("## ⚡ LTX-2 Distilled Video Generator (Preloaded Models)")
    prompt_input = gr.Textbox(label="Prompt", lines=4, value="A person speaking in a backyard")
    negative_input = gr.Textbox(label="Negative Prompt", lines=2, value="worst quality, jittery, distorted")
    seed_input = gr.Number(label="Seed (-1 random)", value=-1, precision=0)
    output_video = gr.Video(label="Generated Video")
    # Read-only display of the seed actually used. The original wired the
    # returned seed back into seed_input, which silently replaced "-1 =
    # random" with a fixed seed after the first run.
    seed_used = gr.Number(label="Seed used", precision=0, interactive=False)
    run_btn = gr.Button("Generate Video")
    run_btn.click(
        fn=generate_video,
        inputs=[prompt_input, negative_input, seed_input],
        outputs=[output_video, seed_used],
    )

if __name__ == "__main__":
    demo.queue().launch()