| from everything import * |
| from bert import BertModel |
| from optimizer import AdamW |
| from tokenizer import BertTokenizer |
|
|
|
|
# Module-level tokenizer shared by every dataset's pad_data below.
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
|
|
|
|
class SentimentDataset(Dataset):
    """Labeled sentiment examples; each item is a (sent, label, sent_id) tuple."""

    def __init__(self, dataset):
        self.dataset = dataset

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        return self.dataset[idx]

    def pad_data(self, data):
        """Tokenize and pad one batch of (sent, label, sent_id) triples.

        Returns (token_ids, attention_mask, labels) tensors plus the raw
        sentence strings and ids so callers can keep them alongside.
        """
        sents = [example[0] for example in data]
        raw_labels = [example[1] for example in data]
        sent_ids = [example[2] for example in data]

        enc = tokenizer(sents, return_tensors='pt', padding=True, truncation=True)
        token_ids = torch.LongTensor(enc['input_ids'])
        attention_mask = torch.LongTensor(enc['attention_mask'])
        labels = torch.LongTensor(raw_labels)

        return token_ids, attention_mask, labels, sents, sent_ids

    def collate_fn(self, all_data):
        """DataLoader collate hook: pack the padded batch into a dict."""
        token_ids, attention_mask, labels, sents, sent_ids = self.pad_data(all_data)
        return {
            'token_ids': token_ids,
            'attention_mask': attention_mask,
            'labels': labels,
            'sents': sents,
            'sent_ids': sent_ids,
        }
|
|
|
|
class SentimentTestDataset(Dataset):
    """Unlabeled sentiment examples for inference; each item is (sent, sent_id)."""

    def __init__(self, dataset):
        self.dataset = dataset

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        return self.dataset[idx]

    def pad_data(self, data):
        """Tokenize and pad one batch of (sent, sent_id) pairs.

        Returns padded id/mask tensors plus the raw sentences and ids.
        """
        sents = [pair[0] for pair in data]
        sent_ids = [pair[1] for pair in data]

        enc = tokenizer(sents, return_tensors='pt', padding=True, truncation=True)
        token_ids = torch.LongTensor(enc['input_ids'])
        attention_mask = torch.LongTensor(enc['attention_mask'])

        return token_ids, attention_mask, sents, sent_ids

    def collate_fn(self, all_data):
        """DataLoader collate hook: pack the padded batch into a dict."""
        token_ids, attention_mask, sents, sent_ids = self.pad_data(all_data)
        return {
            'token_ids': token_ids,
            'attention_mask': attention_mask,
            'sents': sents,
            'sent_ids': sent_ids,
        }
|
|
|
|
class AmazonDataset(Dataset):
    """Unlabeled Amazon review sentences; each item is a (sent, sent_id) pair."""

    def __init__(self, dataset):
        self.dataset = dataset

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        return self.dataset[idx]

    def pad_data(self, data):
        """Tokenize and pad one batch of (sent, sent_id) pairs.

        Returns padded id/mask tensors plus the original sentence ids.
        """
        sents = [x[0] for x in data]
        sent_ids = [x[1] for x in data]
        encoding = tokenizer(sents, return_tensors='pt', padding=True, truncation=True)
        token_ids = torch.LongTensor(encoding['input_ids'])
        # Renamed from misspelled 'attension_mask' to match the sibling datasets.
        attention_mask = torch.LongTensor(encoding['attention_mask'])

        return token_ids, attention_mask, sent_ids

    def collate_fn(self, data):
        """DataLoader collate hook: pack the padded batch into a dict."""
        token_ids, attention_mask, sent_ids = self.pad_data(data)

        batched_data = {
            'token_ids': token_ids,
            'attention_mask': attention_mask,
            'sent_ids': sent_ids
        }

        return batched_data
|
|
|
|
class SemanticDataset(Dataset):
    """Sentence-pair similarity examples; each item is
    (sentence1, sentence2, score, sent_id)."""

    def __init__(self, dataset):
        self.dataset = dataset

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        return self.dataset[idx]

    def pad_data(self, data):
        """Tokenize both sides of each pair in a single call.

        The two sentence lists are concatenated so they share one pad
        length: rows 0..n-1 encode sentence1, rows n..2n-1 sentence2.
        """
        sents1 = [x[0] for x in data]
        sents2 = [x[1] for x in data]
        score = [x[2] for x in data]
        sent_ids = [x[3] for x in data]
        encoding = tokenizer(sents1 + sents2, return_tensors='pt', padding=True, truncation=True)
        token_ids = torch.LongTensor(encoding['input_ids'])
        # Renamed from misspelled 'attension_mask' to match the sibling datasets.
        attention_mask = torch.LongTensor(encoding['attention_mask'])

        return token_ids, attention_mask, score, sent_ids

    def collate_fn(self, data):
        """DataLoader collate hook: split the joint encoding back into
        per-side tensors (first half vs second half of the rows)."""
        token_ids, attention_mask, score, sent_ids = self.pad_data(data)
        n = len(sent_ids)  # batch size; pad_data produced 2n encoded rows

        batched_data = {
            'token_ids_1': token_ids[:n],
            'token_ids_2': token_ids[n:],
            'attention_mask_1': attention_mask[:n],
            'attention_mask_2': attention_mask[n:],
            # NOTE(review): score stays a plain Python list, unlike the
            # tensorized labels elsewhere — presumably converted downstream.
            'score': score,
            'sent_ids': sent_ids
        }

        return batched_data
|
|
| |
class InferenceDataset(Dataset):
    """NLI-style triplet examples; each item is
    (anchor, positive, negative, sent_id)."""

    def __init__(self, dataset):
        self.dataset = dataset

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        return self.dataset[idx]

    def pad_data(self, data):
        """Tokenize all three triplet roles in a single call.

        The three sentence lists are concatenated so they share one pad
        length: rows 0..n-1 are anchors, n..2n-1 positives, 2n..3n-1
        negatives.
        """
        anchor = [x[0] for x in data]
        positive = [x[1] for x in data]
        negative = [x[2] for x in data]
        sent_ids = [x[3] for x in data]
        encoding = tokenizer(anchor + positive + negative, return_tensors='pt', padding=True, truncation=True)
        token_ids = torch.LongTensor(encoding['input_ids'])
        # Renamed from misspelled 'attension_mask' to match the sibling datasets.
        attention_mask = torch.LongTensor(encoding['attention_mask'])

        return token_ids, attention_mask, sent_ids

    def collate_fn(self, data):
        """DataLoader collate hook: split the joint encoding back into the
        anchor/positive/negative thirds."""
        token_ids, attention_mask, sent_ids = self.pad_data(data)
        n = len(sent_ids)  # batch size; pad_data produced 3n encoded rows

        batched_data = {
            'anchor_ids': token_ids[:n],
            'positive_ids': token_ids[n:2*n],
            'negative_ids': token_ids[2*n:],
            'anchor_masks': attention_mask[:n],
            'positive_masks': attention_mask[n:2*n],
            'negative_masks': attention_mask[2*n:],
            'sent_ids': sent_ids
        }

        return batched_data
|
|
|
|
def load_data(filename, flag='train'):
    '''
    Load examples from `filename`; the tuple layout depends on `flag`:

    - 'amazon': list of (sent, id) from a parquet file
    - 'nli':    list of (anchor, positive, negative, id) from a parquet file
    - 'stsb':   list of (sentence1, sentence2, score, id) from a parquet file
    - 'test':   list of (sent, id) from a tab-separated file
    - anything else (e.g. 'train'): list of (sent, label, id) from a
      tab-separated file; for 'train' also returns the number of distinct
      labels seen.
    '''
    if flag == 'amazon':
        df = pd.read_parquet(filename)
        data = list(zip(df['content'], df.index))
    elif flag == 'nli':
        df = pd.read_parquet(filename)
        data = list(zip(df['anchor'], df['positive'], df['negative'], df.index))
    elif flag == 'stsb':
        df = pd.read_parquet(filename)
        data = list(zip(df['sentence1'], df['sentence2'], df['score'], df.index))
    else:
        data, num_labels = [], set()

        with open(filename, 'r') as fp:
            if flag == 'test':
                for record in csv.DictReader(fp, delimiter='\t'):
                    sent = record['sentence'].lower().strip()
                    sent_id = record['id'].lower().strip()
                    data.append((sent, sent_id))
            else:
                for record in csv.DictReader(fp, delimiter='\t'):
                    sent = record['sentence'].lower().strip()
                    sent_id = record['id'].lower().strip()
                    label = int(record['sentiment'].strip())
                    num_labels.add(label)
                    data.append((sent, label, sent_id))

    # Bug fix: the message previously printed the literal placeholder
    # "(unknown)" instead of the file that was actually read.
    print(f"load {len(data)} data from {filename}")
    if flag == "train":
        return data, len(num_labels)
    else:
        return data
|
|
|
|
def seed_everything(seed=11711):
    """Seed every RNG in use (python, numpy, torch CPU and CUDA) so runs
    are reproducible, and force deterministic cuDNN kernels."""
    for seed_fn in (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    ):
        seed_fn(seed)
    # Disable the cuDNN auto-tuner and pin deterministic algorithms.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True