Applicazioni pratiche di deep learning/Classificazione dei tweets
Mentre col sentiment analysis dato un testo si può stabilire la sua polarità cioè se si tratta di un'opinione positiva o negativa, è possibile classificare un testo in base a più categorie. Nel caso specifico dal sito Zindi si preleva un dataset contenente 5287 tweets contenuti nella variabile text e classificati con la variabile target che vale 1 se il tweet riguarda il Covid, altrimenti vale 0. Si prelevano dal dataset 1000 tweets e tramite il modello digitalepidemiologylab/covid-twitter-bert-v2-mnli di Hugging Face si classificano in 1 se riguardano la categoria health (salute), altrimenti 0 e poi si calcola l'Accuracy relativa a questi 1000 tweets.
Caricamento librerie
[modifica | modifica sorgente]!pip install -q transformers
Caricamento dati da Zindi
[modifica | modifica sorgente]Per caricare su Colab il dataset updated_train.csv da Zindi occorre iscriversi alla competizione sulla classificazione dei Tweets in Covid si, Covid no che si trova qui https://zindi.africa/competitions/zindiweekendz-learning-covid-19-tweet-classification-challenge e utilizzare l'analizzatore del browser per conoscere l'url e il token del file, quindi copiarli nel seguente codice:
# Import libraries
import requests
from tqdm.auto import tqdm
# Function to download data
def zindi_data_downloader(url, token, file_name):
# Get the competition data
competition_data = requests.post(url = data_url, data= token, stream=True)
# Progress bar monitor download
pbar = tqdm(desc=file_name, total=int(competition_data.headers.get('content-length', 0)), unit='B', unit_scale=True, unit_divisor=512)
# Create and Write the data to colab drive in chunks
handle = open(file_name, "wb")
for chunk in competition_data.iter_content(chunk_size=512): # Download the data in chunks
if chunk: # filter out keep-alive new chunks
handle.write(chunk)
pbar.update(len(chunk))
handle.close()
pbar.close()
# Data url, token and file_name
data_url = "https://api.zindi.africa/v1/competitions/zindiweekendz-learning-covid-19-tweet-classification-challenge/files/updated_train.csv" # url
token = {'auth_token': ''} # Use your own token
file_name = 'updated_train.csv'
# Download data
zindi_data_downloader(url = data_url, token = token, file_name = file_name)
Tramite la libreria pandas si carica il dataset e si visualizzano le sue variabili e la sua dimensione:
import pandas as pd
df = pd.read_csv('updated_train.csv')
print(df.to_string())
list(df)
len(df)
['ID', 'text', 'target'] 5287
Si prelevano 1000 tweets dalla variabile text e 1000 labels dalla variabile target:
test_cases = df["text"][:1000]
test_labels = df["target"][:1000]
Si importa il modello da Hugging Face per la classificazione di testi relativi al Covid:
from transformers import pipeline
classifier = pipeline("zero-shot-classification", model="digitalepidemiologylab/covid-twitter-bert-v2-mnli"
Si prova il modello su di un testo da classificare con le categorie 'health', 'sport', 'vaccine' e si vede che il testo riguarda la categoria vaccine con un'accuracy del 58%:
sequence_to_classify = 'He say we want to classify COVID-tweets as vaccine-related and not vaccine-related.'
candidate_labels = ['health', 'sport', 'vaccine']
classifier(sequence_to_classify, candidate_labels)
{'labels': ['vaccine', 'health', 'sport'], 'scores': [0.5795532464981079, 0.22699832916259766, 0.19344842433929443], 'sequence': 'He say we want to classify COVID-tweets as vaccine-related and not vaccine-related.'}
Si crea il vettore predictions usando il modello sui 1000 tweets per valutare se riguarda l'argomento health e tramite la libreria evaluate si calcola l'accuracy che risulta del 61%:
predictions = []
for i in range(1000):
output = classifier(test_cases[i], candidate_labels)
if output['labels'][0] == 'health':
predictions.append(1)
else:
predictions.append(0)
if(i%100==0): print(i)
!pip install evaluate
import evaluate
accuracy_metric = evaluate.load("accuracy")
accuracy_output = accuracy_metric.compute(references=test_labels, predictions=predictions)
print(accuracy_output)
{'accuracy': 0.61}