-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy path5_category.py
More file actions
61 lines (51 loc) · 2.42 KB
/
5_category.py
File metadata and controls
61 lines (51 loc) · 2.42 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
import torch
from transformers import BertForSequenceClassification, AutoConfig, AutoTokenizer, BertTokenizer, pipeline
import pandas as pd
# --- Stage 1: load the input spreadsheet and build the classifier inputs ---
print('Loading data...')
data = pd.read_excel('data/final_data_subject.xlsx')
print(len(data))
# Keep only the rows where preds_binary == 1 (rows that contain an object).
# .copy() makes this an independent frame, so the column assignments below
# do not write into a slice of `data` (avoids SettingWithCopyWarning).
data_claims = data[data['preds_binary'] == 1].copy()
print(len(data_claims))
# Each 'group' cell is evaluated as a Python expression and only its first
# element is kept.
# NOTE(review): eval() executes whatever the spreadsheet cell contains. If
# this file can ever come from an untrusted source, switch to
# ast.literal_eval. Cells that evaluate to an empty sequence raise IndexError.
data_claims['group'] = data_claims['group'].apply(eval).apply(lambda x: x[0])
# Classifier input: the message text and the group joined by ' SEP '.
data_claims['text-object'] = data_claims['Message_x'] + ' SEP ' + data_claims['group']
print('Data loaded.')
# --- Stage 2: load the category classifier and run it over the claims ---
print('Loading model...')
# Classifier weights are read from a local checkpoint directory; the
# tokenizer is loaded from the 'GroNLP/bert-base-dutch-cased' identifier.
model_path = 'code/BERTje_category'
model = BertForSequenceClassification.from_pretrained(model_path)
tokenizer = BertTokenizer.from_pretrained(
    'GroNLP/bert-base-dutch-cased',
    max_length=512,
    truncation=True,
)
print('Model loaded.')

print('Predicting...')
# Inputs are truncated to 512 tokens by the pipeline.
clf = pipeline(
    "text-classification",
    model=model,
    tokenizer=tokenizer,
    max_length=512,
    truncation=True,
)
texts = data_claims['text-object'].tolist()
preds = clf(texts)
# Each prediction's 'label' string is split on '_' and the second piece is
# kept as the category id (still a string at this point).
category_ids = []
for pred in preds:
    category_ids.append(pred['label'].split('_')[1])
data_claims['preds_category'] = category_ids
print('Predictions done.')
# --- Stage 3: translate model label ids back to the 'reverse_labels' values ---
print('Converting labels...')
# The exported test data carries both the label ids the model predicts
# ('labels') and the corresponding 'reverse_labels'; only the distinct pairs
# are needed to build the lookup.
test_BERTje = pd.read_csv('code/outputs_BERTje_category/test_data.csv', sep='\t')
test_BERTje = test_BERTje[['labels', 'reverse_labels']].drop_duplicates()
# Predictions were parsed from label strings, so cast them to integers
# before matching against the 'labels' column.
data_claims['preds_category'] = data_claims['preds_category'].astype(int)
# One entry per label id. If a label id appears with several reverse labels,
# the first occurrence wins, matching the earlier `.values[0]` lookup.
label_to_reversed = (
    test_BERTje.drop_duplicates(subset='labels')
    .set_index('labels')['reverse_labels']
)
# Vectorized lookup instead of a per-row scan; label ids without a mapping
# become NaN and the column is always created.
data_claims['preds_category_reversed'] = data_claims['preds_category'].map(label_to_reversed)
n_unmapped = int(data_claims['preds_category_reversed'].isna().sum())
if n_unmapped:
    # Surface silently-missing mappings instead of leaving unexplained NaNs.
    print(f'Warning: {n_unmapped} predictions have no reverse label.')
print(f'value counts: {data_claims["preds_category_reversed"].value_counts()}')
print('Labels converted.')
# --- Stage 4: re-attach the non-claim rows and write the result ---
print('Joining with non-claims...')
# Rows with preds_binary == 0 were never classified; both category columns
# are set to 0 for them. assign() returns a new frame, so nothing is written
# into a slice of `data` (avoids SettingWithCopyWarning).
non_claims = data[data['preds_binary'] == 0].assign(
    preds_category=0,
    preds_category_reversed=0,
)
# Claims come first, then non-claims; the original row index is preserved.
# NOTE(review): rows whose preds_binary is neither 0 nor 1 are in neither
# frame and are therefore dropped from the output. Confirm this is intended.
data = pd.concat([data_claims, non_claims])
print(len(data))
print('Joined with non-claims.')
print(f'value counts: {data["preds_category_reversed"].value_counts()}')
print('Saving data...')
data.to_excel('data/final_data_category.xlsx')
print('Data saved.')