```python
import pandas as pd
import torch
from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
# NOTE: transformers does not ship BigGANModel/BigGANTokenizer; these names
# are kept from the original card and stand in for a text-conditional model
# and tokenizer with a transformers-style interface.
from transformers import BigGANModel, BigGANTokenizer
```
Define the custom dataset class
```python
class TextToImageDataset(Dataset):
    def __init__(self, data, tokenizer, max_length=128):
        self.data = data
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        text = self.data.iloc[idx]['text']
        inputs = self.tokenizer(
            text,
            return_tensors="pt",
            max_length=self.max_length,
            truncation=True,
            padding="max_length",
        )
        # Squeeze the batch dimension added by return_tensors="pt" so the
        # DataLoader can collate samples into [batch_size, max_length] tensors.
        return {k: v.squeeze(0) for k, v in inputs.items()}
```
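A quick way to sanity-check the dataset class is to run it against a toy DataFrame. The tokenizer below is a stand-in chosen only for illustration (bert-base-uncased is an assumption, not part of this card):

```python
# Minimal sketch: exercise TextToImageDataset with a real tokenizer as a
# stand-in (bert-base-uncased is an assumption for illustration only).
from transformers import AutoTokenizer

demo_df = pd.DataFrame({"text": ["a red bird on a branch", "a city at night"]})
demo_tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
demo_dataset = TextToImageDataset(demo_df, demo_tokenizer, max_length=16)

sample = demo_dataset[0]
print(sample["input_ids"].shape)       # torch.Size([16])
print(sample["attention_mask"].shape)  # torch.Size([16])
```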
Load and preprocess the dataset
```python
dataset_path = "text_to_image_dataset.csv"
data = pd.read_csv(dataset_path)
train_data, test_data = train_test_split(data, test_size=0.2, random_state=42)
```
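The card does not document the CSV layout; the dataset class only reads a `text` column. A hypothetical file in the expected shape could be generated like this (the `image` column is a plausible companion in a text-to-image dataset, but it is never read by this script):

```python
# Hypothetical data in the layout TextToImageDataset expects: only "text"
# is consumed; the "image" paths are illustrative and unused by this script.
pd.DataFrame({
    "text": ["a dog playing fetch", "sunset over the ocean"],
    "image": ["images/0001.png", "images/0002.png"],
}).to_csv("text_to_image_dataset.csv", index=False)
```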
Define the tokenizer and model
```python
tokenizer = BigGANTokenizer.from_pretrained("biggan-deep-128")
model = BigGANModel.from_pretrained("biggan-deep-128")
```
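As noted above, transformers does not provide BigGAN classes, so the two lines above will not import as written. One runnable substitution, purely an assumption and not this card's method, is a masked-LM text encoder whose per-token logits line up with the cross-entropy objective used below:

```python
# Runnable stand-in (assumption): a masked-LM encoder produces per-token
# logits over the vocabulary, matching the training objective below.
from transformers import AutoTokenizer, AutoModelForMaskedLM

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
model = AutoModelForMaskedLM.from_pretrained("bert-base-uncased")
```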
Define the dataset and dataloader
```python
train_dataset = TextToImageDataset(train_data, tokenizer)
test_dataset = TextToImageDataset(test_data, tokenizer)
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=32)
```
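A one-batch sanity check (a sketch, assuming the runnable stand-in above):

```python
# One collated batch should be [batch_size, max_length].
batch = next(iter(train_loader))
print(batch["input_ids"].shape)  # torch.Size([32, 128]); the last batch may be smaller
```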
Fine-tune the model
```python
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)
loss_fn = torch.nn.CrossEntropyLoss()

epochs = 5
for epoch in range(epochs):
    model.train()
    total_loss = 0.0
    for batch in tqdm(train_loader, desc=f"Epoch {epoch + 1}/{epochs}"):
        inputs = {k: v.to(device) for k, v in batch.items()}
        outputs = model(**inputs)
        # Token-level cross-entropy: train the model to reproduce the
        # input ids from its logits.
        loss = loss_fn(
            outputs.logits.view(-1, outputs.logits.size(-1)),
            inputs['input_ids'].view(-1),
        )
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    print(f"Epoch {epoch + 1}/{epochs}, Loss: {total_loss / len(train_loader):.4f}")
```
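The card stops at training and leaves out persisting the fine-tuned weights. A common pattern, using the standard `save_pretrained` API (the output directory name is hypothetical), would be:

```python
# Hypothetical output path; save_pretrained is the standard transformers API
# for writing model weights/config and tokenizer files after fine-tuning.
output_dir = "biggan-finetuned"
model.save_pretrained(output_dir)
tokenizer.save_pretrained(output_dir)
```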
Evaluate the model
```python
model.eval()
total_correct = 0
total_samples = 0
with torch.no_grad():
    for batch in tqdm(test_loader, desc="Evaluating"):
        inputs = {k: v.to(device) for k, v in batch.items()}
        outputs = model(**inputs)
        # Treat the argmax over the vocabulary as the predicted token id.
        predicted_ids = torch.argmax(outputs.logits, dim=-1)
        total_correct += (predicted_ids == inputs['input_ids']).sum().item()
        total_samples += inputs['input_ids'].numel()

accuracy = total_correct / total_samples
print(f"Accuracy: {accuracy:.4f}")
```
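Note that this accuracy counts padded positions, which inflates the score on short texts. A padding-aware variant, a sketch rather than the card's method, masks with `attention_mask` so only real tokens are scored:

```python
# Padding-aware token accuracy (sketch): score only positions where
# attention_mask is 1, since padded positions otherwise inflate the metric.
masked_correct, masked_total = 0, 0
model.eval()
with torch.no_grad():
    for batch in test_loader:
        inputs = {k: v.to(device) for k, v in batch.items()}
        predicted_ids = torch.argmax(model(**inputs).logits, dim=-1)
        mask = inputs["attention_mask"].bool()
        masked_correct += (predicted_ids[mask] == inputs["input_ids"][mask]).sum().item()
        masked_total += mask.sum().item()
print(f"Token accuracy (non-padding): {masked_correct / max(masked_total, 1):.4f}")
```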