| # Training Guide |
| This guide provides simple snippets to train diffnext models. |
|
|
| # 1. Build VQVAE cache |
To streamline the training workflow, we preprocess images and videos into VQVAE latents ahead of time.
|
|
## Requirements
| ```bash |
| pip install protobuf==3.20.3 codewithgpu decord |
| ``` |
|
|
| ## Build T2I cache |
The following snippet caches image latents:
|
|
| ```python |
| import os, codewithgpu, torch, PIL.Image, numpy as np |
| from diffnext.models.autoencoders.autoencoder_vq import AutoencoderVQ |
| |
| device, dtype = torch.device("cuda"), torch.float16 |
| vae = AutoencoderVQ.from_pretrained("/path/to/BAAI/URSA-1.7B-IBQ1024/vae") |
| vae = vae.to(device=device, dtype=dtype).eval() |
| |
| features = {"codes": "bytes", "caption": "string", "text": "string", "shape": ["int64"]} |
| os.makedirs("./datasets/ibq1024_dataset", exist_ok=True) |
| writer = codewithgpu.RecordWriter("./datasets/ibq1024_dataset", features) |
| |
# Load the sample image as a (1, C, H, W) tensor.
img = PIL.Image.open("./assets/sample_image.jpg").convert("RGB")
x = torch.as_tensor(np.array(img)[None, ...].transpose(0, 3, 1, 2)).to(device).to(dtype)
with torch.no_grad():
    # Normalize pixels to [-1, 1], encode, and insert a frame axis for T2I latents.
    x = vae.encode(x.sub(127.5).div(127.5)).latent_dist.parameters.unsqueeze(1).cpu().numpy()[0]
example = {"caption": "long caption", "text": "short text"}
# Duplicate the example to ensure enough records for the codewithgpu distributed dataset.
for _ in range(16):
    writer.write({"shape": x.shape, "codes": x.tobytes(), **example})
| writer.close() |
| ``` |
|
|
| ## Build T2V cache |
The following snippet caches video latents:
|
|
| ```python |
| import os, codewithgpu, torch, decord, numpy as np |
| from diffnext.models.autoencoders.autoencoder_vq_cosmos3d import AutoencoderVQCosmos3D |
| |
| device, dtype = torch.device("cuda"), torch.float16 |
| vae = AutoencoderVQCosmos3D.from_pretrained("/path/to/URSA-1.7B-FSQ320/vae") |
| vae = vae.to(device=device, dtype=dtype).eval() |
| |
| features = {"codes": "bytes", "caption": "string", "text": "string", "shape": ["int64"], "flow": "float64"} |
| os.makedirs("./datasets/fsq320_dataset", exist_ok=True) |
| writer = codewithgpu.RecordWriter("./datasets/fsq320_dataset", features) |
| |
# Sample 49 frames at stride 2, resize the short side to 320, then center-crop to 320x512.
resize, crop_size, frame_ids = 320, (320, 512), list(range(0, 97, 2))
vid = decord.VideoReader("./assets/sample_video.mp4")
h, w = vid[0].shape[:2]
scale = float(resize) / float(min(h, w))
size = int(h * scale + 0.5), int(w * scale + 0.5)
y0, x0 = (size[0] - crop_size[0]) // 2, (size[1] - crop_size[1]) // 2
vid = decord.VideoReader("./assets/sample_video.mp4", height=size[0], width=size[1])
vid = vid.get_batch(frame_ids).asnumpy()
vid = vid[:, y0 : y0 + crop_size[0], x0 : x0 + crop_size[1]]
# Pack frames into a (1, C, T, H, W) tensor, normalize to [-1, 1], and encode.
x = torch.as_tensor(vid[None, ...].transpose((0, 4, 1, 2, 3))).to(device).to(dtype)
with torch.no_grad():
    x = vae.encode(x.sub(127.5).div(127.5)).latent_dist.parameters.cpu().numpy()[0]
example = {"caption": "long caption", "text": "short text", "flow": 9.0}
# Duplicate the example to ensure enough records for the codewithgpu distributed dataset.
for _ in range(16):
    writer.write({"shape": x.shape, "codes": x.tobytes(), **example})
| writer.close() |
| ``` |
|
|
| # 2. Train models |
|
|
| ## Train T2I model |
The following snippet provides simple T2I training arguments:
|
|
| ```bash |
| accelerate launch --config_file accelerate_configs/deepspeed_zero2.yaml \ |
| --machine_rank 0 --num_machines 1 --num_processes 8 \ |
| scripts/train.py \ |
| config="./configs/ursa_1.7b_ibq1024.yaml" \ |
| experiment.name="ursa_1.7b_ibq1024" \ |
| experiment.output_dir="./experiments/ursa_1.7b_ibq1024" \ |
| pipeline.paths.pretrained_path="/path/to/URSA-1.7B-IBQ1024" \ |
| train_dataloader.params.dataset="./datasets/ibq1024_dataset" \ |
| model.gradient_checkpointing=3 \ |
| training.batch_size=4 \ |
training.gradient_accumulation_steps=16
| ``` |
|
|
| ## Train T2V model |
The following snippet provides simple T2V training arguments:
|
|
| ```bash |
| accelerate launch --config_file accelerate_configs/deepspeed_zero2.yaml \ |
| --machine_rank 0 --num_machines 1 --num_processes 8 \ |
| scripts/train.py \ |
| config="./configs/ursa_1.7b_fsq320.yaml" \ |
| experiment.name="ursa_1.7b_fsq320" \ |
| experiment.output_dir="./experiments/ursa_1.7b_fsq320" \ |
| pipeline.paths.pretrained_path="/path/to/URSA-1.7B-FSQ320" \ |
| train_dataloader.params.dataset="./datasets/fsq320_dataset" \ |
| model.gradient_checkpointing=3 \ |
| training.batch_size=1 \ |
training.gradient_accumulation_steps=32
| ``` |
|
|