# Spaces: Paused — Hugging Face Space status banner captured along with the
# source when this file was copied; kept here as a comment so the file parses.
import gc
import os
import random

import gradio as gr
import torch
from diffusers import (
    AutoencoderKLLTX2Video,
    LTX2LatentUpsamplePipeline,
    LTX2LatentUpsamplerModel,
    LTX2Pipeline,
    LTX2VideoTransformer3DModel,
)
from diffusers.pipelines.ltx2.export_utils import encode_video
from diffusers.pipelines.ltx2.utils import DISTILLED_SIGMA_VALUES, STAGE_2_DISTILLED_SIGMA_VALUES
from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
from transformers import Gemma3ForConditionalGeneration
# ---------------------------------------------------------------------------
# Global configuration.
# ---------------------------------------------------------------------------
torch_dtype = torch.bfloat16       # compute dtype for all pipelines
DEVICE = torch.device("cuda")      # onload/compute device
OFFLOAD_DEVICE = torch.device("cpu")  # where group-offloaded weights rest

MODEL_PATH = "Lightricks/LTX-2"
OUTPUT_DIR = "./outputs/ltx2"
# exist_ok=True is race-free, unlike the check-then-create idiom.
os.makedirs(OUTPUT_DIR, exist_ok=True)

# ---------------------------------------------------------------------------
# 1. Preload all models once at import time so each request only runs
#    inference. The three-stage layout is: first-stage generation at base
#    resolution -> latent 2x upsample -> second-stage refinement -> decode.
# ---------------------------------------------------------------------------
print("Loading text encoder...")
# 8-bit quantized Gemma-3 text encoder (community repack of the LTX-2 encoder).
text_encoder = Gemma3ForConditionalGeneration.from_pretrained(
    "OzzyGT/LTX-2-bnb-8bit-text-encoder", dtype=torch_dtype
)

print("Loading first stage transformer...")
# 4-bit quantized distilled transformer; loaded on CPU, then group-offloaded
# so only the active leaves live on the GPU.
transformer1 = LTX2VideoTransformer3DModel.from_pretrained(
    "OzzyGT/LTX-2-bnb-4bit-transformer-distilled", torch_dtype=torch_dtype, device_map="cpu"
)
# vocoder=None: audio is vocoded later by decode_pipe, not by this stage.
first_stage_pipe = LTX2Pipeline.from_pretrained(
    MODEL_PATH, transformer=transformer1, text_encoder=text_encoder, vocoder=None, torch_dtype=torch_dtype
)
first_stage_pipe.enable_group_offload(
    onload_device=DEVICE, offload_device=OFFLOAD_DEVICE, offload_type="leaf_level", low_cpu_mem_usage=True
)

print("Loading latent upsampler and VAE...")
latent_upsampler = LTX2LatentUpsamplerModel.from_pretrained(
    "rootonchair/LTX-2-19b-distilled", subfolder="latent_upsampler", torch_dtype=torch_dtype
)
vae = AutoencoderKLLTX2Video.from_pretrained(MODEL_PATH, subfolder="vae", torch_dtype=torch_dtype)
upsample_pipe = LTX2LatentUpsamplePipeline(vae=vae, latent_upsampler=latent_upsampler)
upsample_pipe.enable_model_cpu_offload(device=DEVICE)

print("Loading second stage transformer...")
# Separate instance of the same distilled transformer for the refinement pass.
transformer2 = LTX2VideoTransformer3DModel.from_pretrained(
    "OzzyGT/LTX-2-bnb-4bit-transformer-distilled", torch_dtype=torch_dtype, device_map="cpu"
)
# text_encoder=None: stage 2 reuses the prompt embeddings computed in stage 1.
second_stage_pipe = LTX2Pipeline.from_pretrained(
    MODEL_PATH, transformer=transformer2, text_encoder=None, torch_dtype=torch_dtype
)
# Stage 2 runs on fixed distilled sigmas, so dynamic shifting is disabled.
stage2_scheduler = FlowMatchEulerDiscreteScheduler.from_config(
    second_stage_pipe.scheduler.config, use_dynamic_shifting=False
)
second_stage_pipe.scheduler = stage2_scheduler
second_stage_pipe.enable_group_offload(
    onload_device=DEVICE, offload_device=OFFLOAD_DEVICE, offload_type="leaf_level", low_cpu_mem_usage=True
)

print("Loading decoder pipeline...")
# Decode-only pipeline: just the VAEs + vocoder (no transformer/text encoder).
decode_pipe = LTX2Pipeline.from_pretrained(
    MODEL_PATH, text_encoder=None, transformer=None, scheduler=None, connectors=None, torch_dtype=torch_dtype
)
decode_pipe.to(DEVICE)
# Tiled + framewise VAE decoding keeps peak VRAM bounded at 1536x1024x481.
decode_pipe.vae.enable_tiling(
    tile_sample_min_height=256,
    tile_sample_min_width=256,
    tile_sample_min_num_frames=16,
    tile_sample_stride_height=192,
    tile_sample_stride_width=192,
    tile_sample_stride_num_frames=8,
)
decode_pipe.vae.use_framewise_encoding = True
decode_pipe.vae.use_framewise_decoding = True
decode_pipe.enable_model_cpu_offload()
print("All models loaded ✅")
| # ----------------------------- | |
| # 2. Video generation function | |
| # ----------------------------- | |
| # Replace with @spaces.GPU() in Spaces | |
# Replace with @spaces.GPU() in Spaces
def generate_video(prompt: str, negative_prompt: str, seed: int = -1):
    """Run the full three-stage LTX-2 pipeline and write an mp4 with audio.

    Args:
        prompt: Text description of the desired video.
        negative_prompt: Concepts to steer away from.
        seed: RNG seed; -1 (or None, which gr.Number sends for a cleared
            field) picks a random seed in [0, 999999].

    Returns:
        Tuple of (path to the encoded mp4 under OUTPUT_DIR, seed actually used).
    """
    # gr.Number may deliver None (cleared field) or a float; normalize first.
    if seed is None or int(seed) == -1:
        seed = random.randint(0, 999999)
    seed = int(seed)
    generator = torch.Generator(DEVICE).manual_seed(seed)

    # Encode prompt once; stage 2 reuses these embeddings (it has no encoder).
    with torch.inference_mode():
        prompt_embeds, prompt_attention_mask, _, _ = first_stage_pipe.encode_prompt(
            prompt, negative_prompt, do_classifier_free_guidance=False
        )
        prompt_embeds = prompt_embeds.to(DEVICE)
        prompt_attention_mask = prompt_attention_mask.to(DEVICE)

        # First stage: base-resolution (768x512) video + audio latents.
        video_latent, audio_latent = first_stage_pipe(
            prompt_embeds=prompt_embeds,
            prompt_attention_mask=prompt_attention_mask,
            width=768,
            height=512,
            num_frames=481,
            frame_rate=24.0,
            num_inference_steps=8,
            sigmas=DISTILLED_SIGMA_VALUES,
            guidance_scale=1.0,
            generator=generator,
            output_type="latent",
            return_dict=False,
        )
        # Park latents on CPU so the upsampler has the GPU to itself.
        video_latent = video_latent.detach().cpu()
        audio_latent = audio_latent.detach().cpu()

        # Latent 2x spatial upsample (video only; audio passes through).
        upscaled_video_latent = (
            upsample_pipe(latents=video_latent, output_type="latent", return_dict=False)[0]
            .detach()
            .cpu()
        )

        # Second stage: short refinement pass at 1536x1024 on distilled sigmas.
        upscaled_video_latent = upscaled_video_latent.to(DEVICE)
        audio_latent = audio_latent.to(DEVICE)
        prompt_embeds = prompt_embeds.to(DEVICE)
        prompt_attention_mask = prompt_attention_mask.to(DEVICE)
        video_latent_stage2, audio_latent_stage2 = second_stage_pipe(
            latents=upscaled_video_latent,
            audio_latents=audio_latent,
            prompt_embeds=prompt_embeds,
            prompt_attention_mask=prompt_attention_mask,
            width=768 * 2,
            height=512 * 2,
            num_frames=481,
            num_inference_steps=3,
            # Re-noise to the first distilled sigma before refining.
            noise_scale=STAGE_2_DISTILLED_SIGMA_VALUES[0],
            sigmas=STAGE_2_DISTILLED_SIGMA_VALUES,
            generator=generator,
            guidance_scale=1.0,
            output_type="latent",
            return_dict=False,
        )

    # Decode final video and audio, then mux with the vocoder output.
    with torch.inference_mode():
        video_tensor = decode_pipe.vae.decode(
            video_latent_stage2.to(DEVICE, dtype=decode_pipe.vae.dtype), None, return_dict=False
        )[0]
        video_np = decode_pipe.video_processor.postprocess_video(video_tensor, output_type="np")
        audio_tensor = decode_pipe.audio_vae.decode(
            audio_latent_stage2.to(DEVICE, dtype=decode_pipe.audio_vae.dtype), return_dict=False
        )[0]
        audio_out = decode_pipe.vocoder(audio_tensor)

    # postprocess_video(output_type="np") yields floats in [0, 1]; scale to u8.
    video_np = (video_np * 255).round().astype("uint8")
    video_tensor = torch.from_numpy(video_np)
    output_path = os.path.join(OUTPUT_DIR, f"t2v_{seed}.mp4")
    encode_video(
        video_tensor[0],
        fps=24.0,
        audio=audio_out[0].float().cpu(),
        audio_sample_rate=decode_pipe.vocoder.config.output_sampling_rate,
        output_path=output_path,
    )

    # Drop the large intermediates so VRAM/RAM don't accumulate across runs.
    del video_latent, audio_latent, upscaled_video_latent
    del video_latent_stage2, audio_latent_stage2, video_tensor, audio_tensor, video_np
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    return output_path, seed
| # ----------------------------- | |
| # 3. Gradio UI | |
| # ----------------------------- | |
# ---------------------------------------------------------------------------
# 3. Gradio front-end: wires the preloaded pipelines to a minimal UI.
# ---------------------------------------------------------------------------
with gr.Blocks() as demo:
    gr.Markdown("## ⚡ LTX-2 Distilled Video Generator (Preloaded Models)")

    prompt_box = gr.Textbox(label="Prompt", lines=4, value="A person speaking in a backyard")
    negative_box = gr.Textbox(label="Negative Prompt", lines=2, value="worst quality, jittery, distorted")
    seed_box = gr.Number(label="Seed (-1 random)", value=-1, precision=0)
    video_out = gr.Video(label="Generated Video")
    generate_btn = gr.Button("Generate Video")

    # NOTE(review): the resolved seed is written back into the seed field, so
    # after one random (-1) run the next run reuses that seed unless the user
    # resets it — presumably intentional "show the seed used" behavior.
    generate_btn.click(
        fn=generate_video,
        inputs=[prompt_box, negative_box, seed_box],
        outputs=[video_out, seed_box],
    )

if __name__ == "__main__":
    demo.queue().launch()