YAML Metadata Warning: empty or missing YAML metadata in repo card

Check out the documentation for more information.

import pandas as pd
import torch
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm
# NOTE(review): the transformers package does not ship BigGANModel or
# BigGANTokenizer (BigGAN is an image GAN with no text tokenizer) —
# confirm the intended model classes before running this script.
from transformers import BigGANModel, BigGANTokenizer

# Define the custom dataset class
class TextToImageDataset(Dataset):
    """PyTorch Dataset over a DataFrame with a ``text`` column.

    Each item is the tokenizer's output for one row's text, padded and
    truncated to a fixed length so batches collate to uniform shapes.
    """

    # BUG FIX: the scraped source had ``def init`` (underscores lost), so the
    # constructor was never called and no attributes were set.
    def __init__(self, data, tokenizer, max_length=128):
        """
        Args:
            data: DataFrame providing a ``text`` column, indexed positionally.
            tokenizer: callable invoked as
                ``tokenizer(text, return_tensors="pt", max_length=..., truncation=True, padding="max_length")``.
            max_length: fixed token length every item is padded/truncated to.
        """
        self.data = data
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        # One dataset item per DataFrame row.
        return len(self.data)

    def __getitem__(self, idx):
        # Tokenize the idx-th row's text to a fixed-length tensor dict.
        text = self.data.iloc[idx]['text']
        inputs = self.tokenizer(text, return_tensors="pt",
                                max_length=self.max_length,
                                truncation=True, padding="max_length")
        return inputs

# Load and preprocess the dataset
dataset_path = "text_to_image_dataset.csv"
data = pd.read_csv(dataset_path)
# Hold out 20% of rows for evaluation; fixed seed keeps the split reproducible.
train_data, test_data = train_test_split(data, test_size=0.2, random_state=42)

# Define the tokenizer and model
# NOTE(review): transformers provides no BigGANTokenizer/BigGANModel, and
# "biggan-deep-128" is an image-GAN checkpoint with no tokenizer files —
# these calls will fail as written; confirm the intended classes/checkpoint.
tokenizer = BigGANTokenizer.from_pretrained("biggan-deep-128")
model = BigGANModel.from_pretrained("biggan-deep-128")

# Define the dataset and dataloader
train_dataset = TextToImageDataset(train_data, tokenizer)
test_dataset = TextToImageDataset(test_data, tokenizer)
# Shuffle only the training split; evaluation order does not matter.
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=32)

# Fine-tune the model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)
# Token-level cross-entropy: logits over the vocabulary vs. integer token ids.
loss_fn = torch.nn.CrossEntropyLoss()

epochs = 5
for epoch in range(epochs):
    model.train()
    total_loss = 0.0
    for batch in tqdm(train_loader, desc=f"Epoch {epoch + 1}/{epochs}"):
        # NOTE(review): tokenizer outputs with return_tensors="pt" carry a
        # leading length-1 dim, so after default collation each tensor is
        # (batch, 1, seq) — confirm the model accepts this shape.
        inputs = {k: v.to(device) for k, v in batch.items()}
        outputs = model(**inputs)
        # NOTE(review): this trains the model to reproduce its own input_ids
        # (LM-style reconstruction) — confirm this objective is intended.
        loss = loss_fn(outputs.logits.view(-1, outputs.logits.size(-1)),
                       inputs['input_ids'].view(-1))
        # Clear stale gradients before backprop (conventional ordering;
        # equivalent to the original's zeroing after step).
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    print(f"Epoch {epoch + 1}/{epochs}, Loss: {total_loss / len(train_loader)}")

# Evaluate the model
model.eval()
total_correct = 0
total_samples = 0
with torch.no_grad():
    for batch in tqdm(test_loader, desc="Evaluating"):
        inputs = {k: v.to(device) for k, v in batch.items()}
        outputs = model(**inputs)
        # Greedy token prediction: highest-scoring vocabulary entry per position.
        predicted_ids = torch.argmax(outputs.logits, dim=-1)
        # NOTE(review): padding positions are counted too, which inflates
        # accuracy — consider masking with attention_mask if available.
        total_correct += (predicted_ids == inputs['input_ids']).sum().item()
        total_samples += inputs['input_ids'].numel()

# Report token-level accuracy over the test set.
# Guard against an empty test split (division by zero).
accuracy = total_correct / total_samples if total_samples else 0.0
print(f"Accuracy: {accuracy}")

Downloads last month

-

Downloads are not tracked for this model. How to track
Inference Providers NEW
This model isn't deployed by any Inference Provider. 🙋 Ask for provider support