| from classifier_utils import * |
|
|
|
|
# Passed as ``disable=`` to every tqdm progress bar created in this file;
# True turns the bars off.
TQDM_DISABLE = True
|
|
|
|
class BertSentimentClassifier(torch.nn.Module):
    """Sentiment classifier: a BERT encoder followed by dropout and a linear head.

    Args:
        config: Object exposing ``fine_tune_mode`` (``'last-linear-layer'`` or
            ``'full-model'``), ``num_labels``, ``hidden_dropout_prob`` and
            ``hidden_size``.
        custom_bert: Optional encoder used instead of loading the
            ``'bert-base-uncased'`` checkpoint. It must expose ``parameters()``
            and, when called with ``(input_ids, attention_mask)``, return a
            mapping containing ``'pooler_output'``.

    Raises:
        ValueError: If ``config.fine_tune_mode`` is not a supported mode.
    """

    def __init__(self, config, custom_bert=None):
        super().__init__()

        # Validate up front, before the encoder is loaded. An explicit raise
        # (rather than an assert) still fires when Python runs with -O.
        if config.fine_tune_mode not in ("last-linear-layer", "full-model"):
            raise ValueError(
                "fine_tune_mode must be 'last-linear-layer' or 'full-model', "
                f"got {config.fine_tune_mode!r}"
            )

        self.num_labels = config.num_labels

        # Test against None instead of truthiness: a supplied module that
        # defines __len__ and is empty would otherwise be silently replaced
        # by the pretrained checkpoint.
        if custom_bert is None:
            custom_bert = BertModel.from_pretrained('bert-base-uncased')
        self.bert: BertModel = custom_bert

        # 'last-linear-layer' freezes the encoder; 'full-model' trains it too.
        train_encoder = config.fine_tune_mode == 'full-model'
        for param in self.bert.parameters():
            param.requires_grad = train_encoder

        # Classification head applied to the encoder's pooled output.
        self.dropout = torch.nn.Dropout(config.hidden_dropout_prob)
        self.classifier = torch.nn.Linear(config.hidden_size, self.num_labels)

    def forward(self, input_ids, attention_mask):
        """Return the classifier logits for a batch.

        The encoder is called with ``(input_ids, attention_mask)``; its
        ``'pooler_output'`` entry goes through dropout and the linear head.
        """
        outputs = self.bert(input_ids, attention_mask)
        pooler_output = outputs['pooler_output']
        return self.classifier(self.dropout(pooler_output))
|
|
|
|
| |
def model_eval(dataloader, model: BertSentimentClassifier, device):
    """Evaluate ``model`` on a labelled dataloader.

    Each batch is read as a mapping with the keys ``'labels'``, ``'sents'``,
    ``'sent_ids'``, ``'token_ids'`` and ``'attention_mask'``. The predicted
    class of each example is the argmax of its logits.

    Args:
        dataloader: Iterable of batches as described above.
        model: Classifier called as ``model(token_ids, attention_mask)``.
        device: Device the token ids and attention mask are moved to.

    Returns:
        Tuple ``(acc, f1, y_pred, y_true, sents, sent_ids)`` where ``f1`` is
        the macro-averaged F1 score and the four lists are accumulated over
        all batches in iteration order.
    """
    model.eval()  # Switch to eval mode (turns off dropout in the classifier head).
    y_true = []
    y_pred = []
    sents = []
    sent_ids = []

    # Gradients are never used here; disabling autograd avoids building a
    # graph per batch when the caller (e.g. the training loop) has not
    # already done so.
    with torch.no_grad():
        for batch in tqdm(dataloader, desc='eval', leave=False, disable=TQDM_DISABLE):
            b_labels, b_sents, b_sent_ids = batch['labels'], batch['sents'], batch['sent_ids']

            b_ids = batch['token_ids'].to(device)
            b_mask = batch['attention_mask'].to(device)

            logits = model(b_ids, b_mask)
            logits = logits.detach().cpu().numpy()
            preds = np.argmax(logits, axis=1).flatten()

            b_labels = b_labels.flatten()
            y_true.extend(b_labels)
            y_pred.extend(preds)
            sents.extend(b_sents)
            sent_ids.extend(b_sent_ids)

    f1 = f1_score(y_true, y_pred, average='macro')
    acc = accuracy_score(y_true, y_pred)

    return acc, f1, y_pred, y_true, sents, sent_ids
|
|
|
|
| |
def model_test_eval(dataloader, model, device):
    """Run ``model`` over an unlabelled dataloader and collect predictions.

    Each batch is read as a mapping with the keys ``'sents'``, ``'sent_ids'``,
    ``'token_ids'`` and ``'attention_mask'``. The predicted class of each
    example is the argmax of its logits.

    Args:
        dataloader: Iterable of batches as described above.
        model: Classifier called as ``model(token_ids, attention_mask)``.
        device: Device the token ids and attention mask are moved to.

    Returns:
        Tuple ``(y_pred, sents, sent_ids)`` of lists accumulated over all
        batches in iteration order.
    """
    model.eval()  # Switch to eval mode before predicting.
    y_pred = []
    sents = []
    sent_ids = []

    # Pure inference: disable autograd so no graph is kept per batch.
    with torch.no_grad():
        for batch in tqdm(dataloader, desc='eval', leave=False, disable=TQDM_DISABLE):
            b_sents, b_sent_ids = batch['sents'], batch['sent_ids']

            b_ids = batch['token_ids'].to(device)
            b_mask = batch['attention_mask'].to(device)

            logits = model(b_ids, b_mask)
            logits = logits.detach().cpu().numpy()
            preds = np.argmax(logits, axis=1).flatten()

            y_pred.extend(preds)
            sents.extend(b_sents)
            sent_ids.extend(b_sent_ids)

    return y_pred, sents, sent_ids
|
|
|
|
def save_model(model, args, config, filepath):
    """Write a checkpoint for ``model`` to ``filepath`` with ``torch.save``.

    The checkpoint is a dict holding the model's ``state_dict()`` under
    ``'model'``, the given ``args`` and ``config`` under ``'args'`` and
    ``'model_config'``, and the current state of the ``random``, NumPy and
    torch random number generators.
    """
    rng_snapshot = {
        'system_rng': random.getstate(),
        'numpy_rng': np.random.get_state(),
        'torch_rng': torch.random.get_rng_state(),
    }
    checkpoint = {
        'model': model.state_dict(),
        'args': args,
        'model_config': config,
        **rng_snapshot,
    }

    torch.save(checkpoint, filepath)
    print(f"save the model to {filepath}")
|
|
|
|
def train(args, custom_bert=None):
    """Train a ``BertSentimentClassifier`` and checkpoint the best epoch.

    Reads ``args.train``, ``args.dev``, ``args.batch_size``,
    ``args.fine_tune_mode``, ``args.lr`` and ``args.filepath``. After each of
    the ``EPOCHS`` epochs the model is evaluated on the train and dev sets,
    and it is saved to ``args.filepath`` whenever dev accuracy beats the best
    value seen so far.

    Args:
        args: Run settings with the attributes listed above.
        custom_bert: Optional encoder forwarded to ``BertSentimentClassifier``.
    """
    device = torch.device('cuda' if USE_GPU else 'cpu')

    # Load the raw splits, then wrap them in datasets and dataloaders.
    train_data, num_labels = load_data(args.train, 'train')
    dev_data = load_data(args.dev, 'valid')

    train_dataset = SentimentDataset(train_data)
    dev_dataset = SentimentDataset(dev_data)

    loader_kwargs = {'batch_size': args.batch_size, 'num_workers': NUM_CPU_CORES}
    train_dataloader = DataLoader(
        train_dataset, shuffle=True, collate_fn=train_dataset.collate_fn, **loader_kwargs
    )
    dev_dataloader = DataLoader(
        dev_dataset, shuffle=False, collate_fn=dev_dataset.collate_fn, **loader_kwargs
    )

    # Model configuration; it is stored alongside the weights in the checkpoint.
    config = SimpleNamespace(
        hidden_dropout_prob=HIDDEN_DROPOUT_PROB,
        num_labels=num_labels,
        hidden_size=768,
        data_dir='.',
        fine_tune_mode=args.fine_tune_mode,
    )

    model = BertSentimentClassifier(config, custom_bert).to(device)
    optimizer = AdamW(model.parameters(), lr=args.lr)
    best_dev_acc = 0

    for epoch in range(EPOCHS):
        model.train()
        loss_total = 0
        batches_seen = 0
        progress = tqdm(train_dataloader, desc=f'train-{epoch}', leave=False, disable=TQDM_DISABLE)
        for batch in progress:
            token_ids = batch['token_ids'].to(device)
            mask = batch['attention_mask'].to(device)
            labels = batch['labels'].to(device)

            optimizer.zero_grad()
            # Summed cross-entropy divided by the configured batch size.
            summed = F.cross_entropy(model(token_ids, mask), labels.view(-1), reduction='sum')
            batch_loss = summed / args.batch_size
            batch_loss.backward()
            optimizer.step()

            loss_total += batch_loss.item()
            batches_seen += 1

        train_loss = loss_total / batches_seen

        # Only the accuracy (first element) of each evaluation is used.
        train_acc = model_eval(train_dataloader, model, device)[0]
        dev_acc = model_eval(dev_dataloader, model, device)[0]

        if dev_acc > best_dev_acc:
            best_dev_acc = dev_acc
            save_model(model, args, config, args.filepath)

        print(f"Epoch {epoch}: train loss :: {train_loss :.3f}, train acc :: {train_acc :.3f}, dev acc :: {dev_acc :.3f}")
|
|
|
|
def test(args):
    """Reload the checkpoint at ``args.filepath`` and report dev accuracy.

    Reads ``args.filepath``, ``args.dev`` and ``args.batch_size``. The model is
    rebuilt from the ``'model_config'`` entry of the checkpoint, its weights
    are restored from the ``'model'`` entry, and the result is evaluated on
    the dev split with ``model_eval``. The accuracy is printed; nothing is
    returned.
    """
    with torch.no_grad():
        device = torch.device('cuda') if USE_GPU else torch.device('cpu')

        # map_location='cpu' lets a checkpoint saved from a GPU run be loaded
        # on a machine without one; the model is moved to `device` below.
        # NOTE: weights_only=False unpickles arbitrary Python objects (the
        # checkpoint stores the args/config objects next to the weights), so
        # only load checkpoint files from a trusted source.
        saved = torch.load(args.filepath, map_location='cpu', weights_only=False)
        config = saved['model_config']
        model = BertSentimentClassifier(config)
        model.load_state_dict(saved['model'])
        model = model.to(device)
        print(f"load model from {args.filepath}")

        dev_data = load_data(args.dev, 'valid')
        dev_dataset = SentimentDataset(dev_data)
        dev_dataloader = DataLoader(dev_dataset, shuffle=False, batch_size=args.batch_size,
                                    num_workers=NUM_CPU_CORES, collate_fn=dev_dataset.collate_fn)

        # Only the accuracy is reported; the remaining outputs are discarded.
        dev_acc, *_ = model_eval(dev_dataloader, model, device)
        print('DONE DEV')
        print(f"dev acc :: {dev_acc :.3f}")
|
|
|
|
def classifier_run(args, custom_bert=None):
    """Seed the run, train the sentiment classifier, then evaluate it.

    Reads ``args.dataset``, ``args.lr``, ``args.batch_size``,
    ``args.fine_tune_mode``, ``args.train``, ``args.dev`` and ``args.test``,
    derives the checkpoint and prediction paths from them, and passes the
    resulting settings to ``train`` and then ``test``.

    Args:
        args: Run settings with the attributes listed above.
        custom_bert: Optional encoder forwarded to ``train``.
    """
    seed_everything(SEED)
    torch.set_num_threads(NUM_CPU_CORES)

    print(f'Training Sentiment Classifier on {args.dataset}...')

    # Settings shared by the training and evaluation steps.
    settings = {
        'filepath': f'{args.dataset}-classifier.pt',
        'lr': args.lr,
        'batch_size': args.batch_size,
        'fine_tune_mode': args.fine_tune_mode,
        'train': args.train,
        'dev': args.dev,
        'test': args.test,
        'dev_out': f'/predictions/{args.fine_tune_mode}-{args.dataset}-dev-out.csv',
        'test_out': f'/predictions/{args.fine_tune_mode}-{args.dataset}-test-out.csv',
    }
    config = SimpleNamespace(**settings)

    train(config, custom_bert)

    print(f'Evaluating on {args.dataset}...')
    test(config)