import os
import gc
import random

import torch
import gradio as gr
from diffusers import (
    LTX2Pipeline,
    LTX2VideoTransformer3DModel,
    LTX2LatentUpsamplePipeline,
    LTX2LatentUpsamplerModel,
    AutoencoderKLLTX2Video,
)
from diffusers.pipelines.ltx2.export_utils import encode_video
from diffusers.pipelines.ltx2.utils import DISTILLED_SIGMA_VALUES, STAGE_2_DISTILLED_SIGMA_VALUES
from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
from transformers import Gemma3ForConditionalGeneration

# Global configuration: all models run in bfloat16 on the GPU and are
# offloaded to CPU RAM when idle to keep VRAM usage bounded.
torch_dtype = torch.bfloat16
DEVICE = torch.device("cuda")
OFFLOAD_DEVICE = torch.device("cpu")

MODEL_PATH = "Lightricks/LTX-2"
OUTPUT_DIR = "./outputs/ltx2"

# exist_ok=True avoids the check-then-create race of exists() + makedirs().
os.makedirs(OUTPUT_DIR, exist_ok=True)

# -----------------------------
# 1. Preload all models once
# -----------------------------
print("Loading text encoder...")
# 8-bit quantized Gemma 3 text encoder, shared with the first-stage pipeline.
text_encoder = Gemma3ForConditionalGeneration.from_pretrained(
    "OzzyGT/LTX-2-bnb-8bit-text-encoder", dtype=torch_dtype
)

print("Loading first stage transformer...")
# 4-bit quantized distilled transformer; loaded on CPU so group offload can
# stream leaf modules onto the GPU on demand.
transformer1 = LTX2VideoTransformer3DModel.from_pretrained(
    "OzzyGT/LTX-2-bnb-4bit-transformer-distilled", torch_dtype=torch_dtype, device_map="cpu"
)
first_stage_pipe = LTX2Pipeline.from_pretrained(
    MODEL_PATH,
    transformer=transformer1,
    text_encoder=text_encoder,
    vocoder=None,  # audio is rendered later by decode_pipe's vocoder
    torch_dtype=torch_dtype,
)
first_stage_pipe.enable_group_offload(
    onload_device=DEVICE,
    offload_device=OFFLOAD_DEVICE,
    offload_type="leaf_level",
    low_cpu_mem_usage=True,
)

print("Loading latent upsampler and VAE...")
latent_upsampler = LTX2LatentUpsamplerModel.from_pretrained(
    "rootonchair/LTX-2-19b-distilled", subfolder="latent_upsampler", torch_dtype=torch_dtype
)
vae = AutoencoderKLLTX2Video.from_pretrained(MODEL_PATH, subfolder="vae", torch_dtype=torch_dtype)
upsample_pipe = LTX2LatentUpsamplePipeline(vae=vae, latent_upsampler=latent_upsampler)
upsample_pipe.enable_model_cpu_offload(device=DEVICE)

print("Loading second stage transformer...")
# Same distilled checkpoint as stage 1, but held in a separate pipeline whose
# scheduler is configured for stage-2 refinement.
transformer2 = LTX2VideoTransformer3DModel.from_pretrained(
    "OzzyGT/LTX-2-bnb-4bit-transformer-distilled", torch_dtype=torch_dtype, device_map="cpu"
)
second_stage_pipe = LTX2Pipeline.from_pretrained(
    MODEL_PATH, transformer=transformer2, text_encoder=None, torch_dtype=torch_dtype
)
# Stage 2 runs with fixed distilled sigmas, so dynamic shifting is disabled.
stage2_scheduler = FlowMatchEulerDiscreteScheduler.from_config(
    second_stage_pipe.scheduler.config, use_dynamic_shifting=False
)
second_stage_pipe.scheduler = stage2_scheduler
second_stage_pipe.enable_group_offload(
    onload_device=DEVICE,
    offload_device=OFFLOAD_DEVICE,
    offload_type="leaf_level",
    low_cpu_mem_usage=True,
)

print("Loading decoder pipeline...")
# Decode-only pipeline: only the VAEs + vocoder are kept, no transformer or
# text encoder.
decode_pipe = LTX2Pipeline.from_pretrained(
    MODEL_PATH,
    text_encoder=None,
    transformer=None,
    scheduler=None,
    connectors=None,
    torch_dtype=torch_dtype,
)
decode_pipe.to(DEVICE)
# Tiled + framewise decoding keeps peak VRAM bounded when decoding the
# 2x-upscaled, 481-frame video.
decode_pipe.vae.enable_tiling(
    tile_sample_min_height=256,
    tile_sample_min_width=256,
    tile_sample_min_num_frames=16,
    tile_sample_stride_height=192,
    tile_sample_stride_width=192,
    tile_sample_stride_num_frames=8,
)
decode_pipe.vae.use_framewise_encoding = True
decode_pipe.vae.use_framewise_decoding = True
decode_pipe.enable_model_cpu_offload()

print("All models loaded ✅")
# -----------------------------
# 2. Video generation function
# -----------------------------

# BUG FIX: the original decorated generate_video with @spaces.GPU() but never
# imported `spaces`, raising NameError at import time outside a Hugging Face
# Space. Import it when available and fall back to a no-op decorator locally.
try:
    import spaces

    _gpu_decorator = spaces.GPU()
except ImportError:
    def _gpu_decorator(fn):
        return fn


@_gpu_decorator
def generate_video(prompt: str, negative_prompt: str, seed: int = -1):
    """Generate a 2x-upscaled LTX-2 video with audio from a text prompt.

    Args:
        prompt: Text description of the desired video.
        negative_prompt: Concepts to steer the generation away from.
        seed: RNG seed; -1 picks a random seed.

    Returns:
        Tuple of (path to the encoded .mp4, seed actually used).
    """
    if seed == -1:
        seed = random.randint(0, 999999)
    generator = torch.Generator(DEVICE).manual_seed(seed)

    # Encode the prompt once; both diffusion stages reuse the embeddings.
    with torch.inference_mode():
        prompt_embeds, prompt_attention_mask, _, _ = first_stage_pipe.encode_prompt(
            prompt, negative_prompt, do_classifier_free_guidance=False
        )
    prompt_embeds = prompt_embeds.to(DEVICE)
    prompt_attention_mask = prompt_attention_mask.to(DEVICE)

    # First stage: base-resolution video + audio latents (distilled sigmas).
    video_latent, audio_latent = first_stage_pipe(
        prompt_embeds=prompt_embeds,
        prompt_attention_mask=prompt_attention_mask,
        width=768,
        height=512,
        num_frames=481,
        frame_rate=24.0,
        num_inference_steps=8,
        sigmas=DISTILLED_SIGMA_VALUES,
        guidance_scale=1.0,
        generator=generator,
        output_type="latent",
        return_dict=False,
    )
    # Park latents on the CPU so the upsampler has the GPU to itself.
    video_latent = video_latent.detach().cpu()
    audio_latent = audio_latent.detach().cpu()

    # Spatially upsample the video latents.
    upscaled_video_latent = upsample_pipe(
        latents=video_latent, output_type="latent", return_dict=False
    )[0].detach().cpu()

    # Second stage: refine the upscaled latents at 2x resolution.
    upscaled_video_latent = upscaled_video_latent.to(DEVICE)
    audio_latent = audio_latent.to(DEVICE)
    prompt_embeds = prompt_embeds.to(DEVICE)
    prompt_attention_mask = prompt_attention_mask.to(DEVICE)
    video_latent_stage2, audio_latent_stage2 = second_stage_pipe(
        latents=upscaled_video_latent,
        audio_latents=audio_latent,
        prompt_embeds=prompt_embeds,
        prompt_attention_mask=prompt_attention_mask,
        width=768 * 2,
        height=512 * 2,
        num_frames=481,
        num_inference_steps=3,
        noise_scale=STAGE_2_DISTILLED_SIGMA_VALUES[0],
        sigmas=STAGE_2_DISTILLED_SIGMA_VALUES,
        generator=generator,
        guidance_scale=1.0,
        output_type="latent",
        return_dict=False,
    )

    # Decode final video and audio, then mux into an .mp4.
    with torch.inference_mode():
        video_tensor = decode_pipe.vae.decode(
            video_latent_stage2.to(DEVICE, dtype=decode_pipe.vae.dtype), None, return_dict=False
        )[0]
        video_np = decode_pipe.video_processor.postprocess_video(video_tensor, output_type="np")
        audio_tensor = decode_pipe.audio_vae.decode(
            audio_latent_stage2.to(DEVICE, dtype=decode_pipe.audio_vae.dtype), return_dict=False
        )[0]
        audio_out = decode_pipe.vocoder(audio_tensor)

    # Convert [0, 1] float frames to uint8 for the encoder.
    video_np = (video_np * 255).round().astype("uint8")
    video_tensor = torch.from_numpy(video_np)

    output_path = os.path.join(OUTPUT_DIR, f"t2v_{seed}.mp4")
    encode_video(
        video_tensor[0],
        fps=24.0,
        audio=audio_out[0].float().cpu(),
        audio_sample_rate=decode_pipe.vocoder.config.output_sampling_rate,
        output_path=output_path,
    )
    return output_path, seed


# -----------------------------
# 3. Gradio UI
# -----------------------------
with gr.Blocks() as demo:
    gr.Markdown("## ⚡ LTX-2 Distilled Video Generator (Preloaded Models)")
    prompt_input = gr.Textbox(label="Prompt", lines=4, value="A person speaking in a backyard")
    negative_input = gr.Textbox(
        label="Negative Prompt", lines=2, value="worst quality, jittery, distorted"
    )
    seed_input = gr.Number(label="Seed (-1 random)", value=-1, precision=0)
    output_video = gr.Video(label="Generated Video")
    run_btn = gr.Button("Generate Video")
    run_btn.click(
        fn=generate_video,
        inputs=[prompt_input, negative_input, seed_input],
        outputs=[output_video, seed_input],
    )

if __name__ == "__main__":
    demo.queue().launch()