En este notebook se muestra cómo crear un modelo de clasificación de texto usando modelos basados en mecanismos de atención mediante la librería HuggingFace, y en concreto mediante la librería blurr que integra HuggingFace con FastAI.

Para esta práctica es necesario el uso de GPU, así que recuerda activar esta opción en Colab.

Librerías

Comenzamos actualizando la librería FastAI, e instalando las librerías de HuggingFace (llamadas transformers y datasets) y la librería blurr. Al finalizar la instalación deberás reiniciar el kernel (menú Entorno de ejecución -> Reiniciar Entorno de ejecución).

!pip install fastai -Uqq
!pip install datasets -Uqq
!pip install transformers[sentencepiece] -Uqq
!pip install git+https://github.com/ohmeow/blurr.git@dev-2.0.0 -Uqq
     |████████████████████████████████| 189 kB 4.2 MB/s 
     |████████████████████████████████| 55 kB 3.6 MB/s 
     |████████████████████████████████| 325 kB 3.5 MB/s 
     |████████████████████████████████| 1.1 MB 38.3 MB/s 
     |████████████████████████████████| 134 kB 31.0 MB/s 
     |████████████████████████████████| 212 kB 35.2 MB/s 
     |████████████████████████████████| 67 kB 4.7 MB/s 
     |████████████████████████████████| 127 kB 40.2 MB/s 
     |████████████████████████████████| 94 kB 1.6 MB/s 
     |████████████████████████████████| 271 kB 34.0 MB/s 
     |████████████████████████████████| 144 kB 49.3 MB/s 
ERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
datascience 0.10.6 requires folium==0.2.1, but you have folium 0.8.3 which is incompatible.
     |████████████████████████████████| 3.8 MB 4.0 MB/s 
     |████████████████████████████████| 596 kB 45.2 MB/s 
     |████████████████████████████████| 895 kB 48.6 MB/s 
     |████████████████████████████████| 6.5 MB 36.5 MB/s 
     |████████████████████████████████| 1.2 MB 42.8 MB/s 
     |████████████████████████████████| 43 kB 1.4 MB/s 
     |████████████████████████████████| 96 kB 3.8 MB/s 
  Building wheel for ohmeow-blurr (setup.py) ... done
  Building wheel for seqeval (setup.py) ... done

Cargamos a continuación las librerías que necesitaremos en esta práctica.

from fastai.data.all import *
from fastai.learner import *
from fastai.losses import CrossEntropyLossFlat
from fastai.optimizer import Adam, OptimWrapper, params
from fastai.metrics import accuracy, F1Score
from fastai.torch_core import *
from fastai.torch_imports import *
from transformers import AutoModelForSequenceClassification

from blurr.data.core import *
from blurr.modeling.core import *
from blurr.utils import BLURR
from datasets import load_dataset,concatenate_datasets

Dataset

Para este ejemplo vamos a usar el dataset Gutenberg Poem Dataset, un dataset para detectar sentimientos en poemas (negativos, positivos, sin impacto, mezcla de positivo y negativo).

Descarga el dataset usando el siguiente comando.

poem_sentiment_dataset = load_dataset("poem_sentiment")
Using custom data configuration default
Reusing dataset poem_sentiment (/root/.cache/huggingface/datasets/poem_sentiment/default/1.0.0/4e44428256d42cdde0be6b3db1baa587195e91847adabf976e4f9454f6a82099)

Vamos a añadir a nuestro dataset una columna que nos indique si estamos trabajando con el conjunto de entrenamiento o el de validación. Para lo cuál debemos definir la siguiente función.

def add_is_valid_batch_friendly(examples, is_valid=False):
  return {"is_valid": [is_valid for txt in examples["verse_text"]]}

Y ahora añadimos esa información.

poem_sentiment_dataset["train"] = poem_sentiment_dataset["train"].map(partial(add_is_valid_batch_friendly,is_valid=False),batched=True)
poem_sentiment_dataset["validation"] = poem_sentiment_dataset["validation"].map(partial(add_is_valid_batch_friendly,is_valid=True),batched=True)

print(poem_sentiment_dataset)
print(poem_sentiment_dataset["train"][0])
Loading cached processed dataset at /root/.cache/huggingface/datasets/poem_sentiment/default/1.0.0/4e44428256d42cdde0be6b3db1baa587195e91847adabf976e4f9454f6a82099/cache-a5043da3f45ff884.arrow
Loading cached processed dataset at /root/.cache/huggingface/datasets/poem_sentiment/default/1.0.0/4e44428256d42cdde0be6b3db1baa587195e91847adabf976e4f9454f6a82099/cache-078c48838227da06.arrow
DatasetDict({
    train: Dataset({
        features: ['id', 'verse_text', 'label', 'is_valid'],
        num_rows: 892
    })
    validation: Dataset({
        features: ['id', 'verse_text', 'label', 'is_valid'],
        num_rows: 105
    })
    test: Dataset({
        features: ['id', 'verse_text', 'label'],
        num_rows: 104
    })
})
{'id': 0, 'verse_text': 'with pale blue berries. in these peaceful shades--', 'label': 1, 'is_valid': False}

Antes de empezar a entrenar nuestro modelo combinamos los conjuntos de entrenamiento y test.

proc_ds = concatenate_datasets([poem_sentiment_dataset["train"],poem_sentiment_dataset["validation"]])
proc_ds
Dataset({
    features: ['id', 'verse_text', 'label', 'is_valid'],
    num_rows: 997
})

Entrenando un modelo de clasificación

El proceso a seguir para hacer fine-tuning sobre el modelo de lenguaje de FastAI es análogo al visto en prácticas anteriores. Comenzamos creando un DataBlock a partir de nuestro dataframe. Sin embargo, para llevar a cabo esta tarea tenemos que definirlo a partir de los constructores de la librería Blurr.

En primer lugar definimos la tarea con la que vamos a trabajar que es la clasificación de secuencias de texto.

model_cls = AutoModelForSequenceClassification

A continuación debemos indicar el modelo que vamos a utilizar, ya que la creación de los datablocks es dependiente de esta elección. En nuestro caso usaremo el modelo Bert.

pretrained_model_name = "bert-base-uncased"

Seguidamente debemos crear la configuración de nuesto datablock, para ello vamos a usar la configuración por defecto del modelo que hemos seleccionado y cambiaremos el número de posibles clases de nuestro problema (en este caso 4).

labels = proc_ds.features["label"].names
labels
n_labels = len(labels)

Por último generamos varias componentes que son necesarias para entrenar nuestro modelo y que son dependientes del modelo que elijamos:

- La arquitectura del modelo.
- La configuración del modelo.
- El tokenizer.
- El modelo de huggingface. 
hf_arch, hf_config, hf_tokenizer,hf_model = BLURR.get_hf_objects(
    pretrained_model_name,model_cls= model_cls, config_kwargs={"num_labels": n_labels}
)

hf_arch, type(hf_config), type(hf_tokenizer),type(hf_model)
('bert',
 transformers.models.bert.configuration_bert.BertConfig,
 transformers.models.bert.tokenization_bert_fast.BertTokenizerFast,
 transformers.models.bert.modeling_bert.BertForSequenceClassification)

Ahora ya podemos construir nuestro datablock para lo cual primero debemos preprocesarlo con el modelo que vamos a usar. Notar que debemos indicar el atributo donde se encuentran las frases de nuestro dataset (en este caso verse_text).

preprocessor = ClassificationPreprocessor(hf_tokenizer,label_mapping=labels,text_attr='verse_text')
ds = preprocessor.process_hf_dataset(proc_ds)
ds
Dataset({
    features: ['proc_verse_text', 'id', 'verse_text', 'label', 'is_valid', 'label_name', 'verse_text_start_char_idx', 'verse_text_end_char_idx'],
    num_rows: 997
})

También debemos identificar los índices de nuestro dataset que pertenecen al conjunto de validación.

val_idxs = [idx for idx,el in enumerate(ds) if el["is_valid"]==True]
min(val_idxs),max(val_idxs)
(892, 996)

Y por fin construimos nuestro datablock.

blocks = (
    TextBlock(
        hf_arch,
        hf_config,
        hf_tokenizer,
        hf_model,
        is_pretokenized=True,
        before_batch_kwargs={"labels":labels},
        tok_kwargs={"add_special_tokens":False}
    ),
    CategoryBlock
)

dblock = DataBlock(
    blocks=blocks,
    get_x=ItemGetter('verse_text'), 
    get_y=ItemGetter('label'), 
    splitter=IndexSplitter(val_idxs)
)

Creamos ahora nuestro dataloader.

dls = dblock.dataloaders(ds,bs=16)

Podemos ahora mostrar un batch de este dataloader.

dls.show_batch(dataloaders=dls, max_n=2)
text target
0 for'twas e'en as a great god's slaying, and they feared the wrath of the sky ; 0
1 " what hope wouldst thou hope, o sigurd, ere we kiss, we twain, and depart? " 3

Creamos ahora nuestro Learner.

model = BaseModelWrapper(hf_model)

learn = Learner(dls, 
                model,
                metrics=[accuracy],
                cbs=[BaseModelCallback],
                splitter=blurr_splitter
                )

Y por último entrenamos el modelo.

learn.fine_tune(10,1e-3)
epoch train_loss valid_loss accuracy time
0 1.128518 0.917785 0.657143 00:05
epoch train_loss valid_loss accuracy time
0 1.050082 0.924744 0.657143 00:10
1 0.986408 0.964888 0.657143 00:10
2 0.830305 0.625145 0.780952 00:10
3 0.544306 0.568373 0.809524 00:10
4 0.310422 0.601602 0.847619 00:10
5 0.181153 0.590411 0.857143 00:10
6 0.102145 0.958344 0.752381 00:10
7 0.057657 0.965625 0.761905 00:10
8 0.031127 0.774979 0.847619 00:10
9 0.021683 0.779923 0.847619 00:10

Una vez entrenado el modelo podemos guardarlo para el futuro.

export_fname = 'seq_class_learn_export'
learn.export(fname=f'{export_fname}.pkl')

Para realizar predicciones debemos usar el método blurr_predict.

learn.blurr_predict('with pale blue berries. in these peaceful shades--.')
[{'class_index': 1,
  'class_labels': [0, 1, 2, 3],
  'label': '1',
  'probs': [0.000199360991246067,
   0.9650265574455261,
   0.0011597874108701944,
   0.03361422196030617],
  'score': 0.9650265574455261}]

Por último, validamos en nuestro conjunto de test.

poem_sentiment_dataset["test"] = poem_sentiment_dataset["test"].map(partial(add_is_valid_batch_friendly,is_valid=True),batched=True)
proc_ds_test = concatenate_datasets([poem_sentiment_dataset["train"],poem_sentiment_dataset["test"]])
ds_test = preprocessor.process_hf_dataset(proc_ds_test)

test_idxs = [idx for idx,el in enumerate(ds_test) if el["is_valid"]==True]

dblock_test = DataBlock(
    blocks=blocks,
    get_x=ItemGetter('verse_text'), 
    get_y=ItemGetter('label'), 
    splitter=IndexSplitter(test_idxs)
)

dls_test = dblock_test.dataloaders(ds_test,bs=16)
learn.dls = dls_test
learn.validate()
(#2) [0.7783738374710083,0.8269230723381042]

Con esto hemos logrado un modelo con una accuracy del 82% (muy superior al 69% logrado en la práctica anterior).