Segmentación semántica
Usando FastAI para crear un modelo de segmentación semántica.
- Librerías
- Dataset
- Definiciones previas
- Partición del dataset
- Data augmentation
- Dataloader
- Definición de modelo
- Evaluando el modelo
- Inferencia
En este notebook se muestra cómo crear un modelo de segmentación semántica usando la arquitectura U-net incluida en la librería FastAI.
En esta práctica vamos a hacer un uso intensivo de la GPU, así que es importante activar su uso desde la opción Configuración del cuaderno del menú Editar (esta opción debería estar habilitada por defecto, pero es recomendable que lo compruebes).
!pip install fastai -Uq
Cargamos a continuación las librerías que necesitaremos en esta práctica.
from fastai.basics import *
from fastai.vision import models
from fastai.vision.all import *
from fastai.metrics import *
from fastai.data.all import *
from fastai.callback import *
from pathlib import Path
import random
Dataset
Para esta práctica vamos a usar como dataset el proporcionado en el trabajo Deep neural networks for grape bunch segmentation in natural images from a consumer‑grade camera. Este dataset dedicado a la segmentación de racimos de uva consta de 66 imágenes de entrenamiento y 14 de test con 5 categorías: background, leaves, wood, pole, y grape. Los siguientes comandos descargan y descomprimen dicho dataset. En este notebook vamos a usar solo dos clases: background y grape.
%%capture
!wget https://www.dropbox.com/s/uknzc914w311web/dataset.zip?dl=1 -O dataset.zip
!unzip dataset.zip
Vamos a explorar el contenido de este dataset. Para ello vamos a crear un objeto Path que apunta al directorio que acabamos de crear.
path=Path('dataset/')
Como en la práctica anterior, podemos ver el contenido de este directorio usando el comando ls()
.
path.ls()
Si exploráis el directorio podréis ver que hay dos carpetas llamadas Images y Labels. La carpeta Images contiene las imágenes del dataset, y la carpeta Labels contiene las en forma de máscara. Para cada imagen, hay un fichero de anotación siguiendo la siguiente nomenclatura: si la imagen se llama color_xxx.jpg, su fichero de anotación es gt_xxx.png. El dataset está partido en entrenamiento y test como puede verse en las carpetas Images y Labels. Además, se proporcionan dos ficheros txt que van a contener las clases de los objetos que utilizaremos en esta práctica. El fichero codes.txt contiene solo dos clases (background y grape), mientras que el fichero codesAll.txt contiene todas las posibles clases.
(path/'Images').ls()
(path/'Images/train').ls()
(path/'Labels/train').ls()
Definiciones previas
El proceso para entrenar nuestro modelo va a ser similar al visto en la práctica 1 para crear un modelo de clasificación. Sin embargo, para cargar nuestro dataset será necesario dar unas definiciones previas. Estas definiciones son necesarias para ajustar la carga del datos a la estructura de nuestro dataset.
En primer lugar vamos a definir los paths donde se van a encontrar nuestras imágenes y sus etiquetas.
path_images = path/"Images"
path_labels = path/"Labels"
A continuación definimos el nombre que va a tener nuestra carpeta de test.
test_name = "test"
Seguidamente definimos una función que dado el path de una imagen nos devuelve el path de su anotación.
def get_y_fn (x):
return Path(str(x).replace("Images","Labels").replace("color","gt").replace(".jpg",".png"))
Seguidamente cargamos las clases que pueden tener los píxeles de nuestra imágenes y lo almacenamos en una lista codes
.
codes = np.loadtxt(path/'codes.txt', dtype=str)
codes
Podemos ahora ver alguna de las imágenes de nuestro dataset.
img_f = path_images/'train/color_206.jpg'
img = PILImage.create(img_f)
img.show(figsize=(5, 5))
Y también la anotación asociada.
mask = PILMask.create(get_y_fn(img_f))
mask.show(figsize=(5, 5), alpha=1)
Como podemos ver en la imagen anterior tenemos una máscara donde cada tipo de objeto de nuestra imagen tiene un color distinto.
def ParentSplitter(x):
return Path(x).parent.name==test_name
Data augmentation
Al igual que con los modelos definidos en prácticas anteriores podemos usar técnicas de aumento de datos, para lo que usaremos la librería Albumentations. Recordar que dichas transformaciones no deben aplicarse solo a la imagen sino también a su anotación. Para ello vamos a definir una clase que hereda de la clase ItemTransform
y que nos va a permitir realizar transformaciones sobre pares (imagen,máscara).
La clase ItemTransform
tiene un método encodes
que es el encargado de realizar la transformación sobre su entrada x
que en este caso será un par (imagen,máscara). Además el constructor de la clase que vamos a definir recibirá como parámetro las transformaciones a aplicar.
from albumentations import (
Compose,
OneOf,
ElasticTransform,
GridDistortion,
OpticalDistortion,
HorizontalFlip,
Rotate,
Transpose,
CLAHE,
ShiftScaleRotate
)
class SegmentationAlbumentationsTransform(ItemTransform):
split_idx = 0
def __init__(self, aug):
self.aug = aug
def encodes(self, x):
img,mask = x
aug = self.aug(image=np.array(img), mask=np.array(mask))
return PILImage.create(aug["image"]), PILMask.create(aug["mask"])
En nuestro caso vamos a utilizar solo flips horizontales, rotaciones, y una operación que aplica una pequeña distorsión a la imagen. Dichas transformaciones se aplicarán de manera secuencia y de manera aleatoria.
transforms=Compose([HorizontalFlip(p=0.5),
Rotate(p=0.40,limit=10),GridDistortion()
],p=1)
Por último construimos un objeto de la clase definida anteriormente.
transformPipeline=SegmentationAlbumentationsTransform(transforms)
También va a ser necesario realizar una transformación adicional sobre las máscaras. Las máscaras contienen píxeles con 7 valores distintos (255: grape, 150: leaves, 76: pole, 74: pole, 29: wood, 25: wood, 0: background). Como vamos a trabajar únicamente con las clases grape y background, los píxeles del resto de clases deberán estar a 0 (es decir los vamos a considerar como background). Además, los números de las clases deben ser 0,1,2,... Es por esto que es necesario cambiar todos los píxeles con valor 255 a valor 1. Para realizar estas transformaciones definimos la siguiente clase.
class TargetMaskConvertTransform(ItemTransform):
def __init__(self):
pass
def encodes(self, x):
img,mask = x
#Convert to array
mask = np.array(mask)
mask[mask!=255]=0
# Change 255 for 1
mask[mask==255]=1
# Back to PILMask
mask = PILMask.create(mask)
return img, mask
trainDB = DataBlock(blocks=(ImageBlock, MaskBlock(codes)),
get_items=partial(get_image_files,folders=['train']),
get_y=get_y_fn,
splitter=RandomSplitter(valid_pct=0.2),
item_tfms=[Resize((480,640)), TargetMaskConvertTransform(), transformPipeline],
batch_tfms=Normalize.from_stats(*imagenet_stats)
)
Vamos a explicar cada una de las componentes anteriores:
-
blocks=(ImageBlock, MaskBlock(codes))
. En este caso tenemos que la entrada de nuestro modelo va a ser una imagen (representada mediante unImageBlock
) y su salida es una máscara (representado medianteMaskBlock
) cuyos posibles valores son aquellos proporcionados por la lista de clases almacenada en la variablecodes
. -
get_items=partial(get_image_files,folders=['train'])
. El parámetroget_items
sirve para indicar cómo cargar los datos de nuestro dataset. Para esto vamos a usar la funciónget_image_files
que devuelve los paths de las imágenes que se encuentran dentro de la carpetafolders
(en nuestro caso la carpetatrain
). -
get_y=get_y_fn
. El parámetroget_y
sirve para indicar cómo obtener la anotación asociada con una entrada (recordar que una entrada va a ser una imagen definida a partir de su path). Para esto tenemos la funciónget_y_fn
definida anteriormente. -
splitter=RandomSplitter(valid_pct=0.2)
. Como siempre debemos partir nuestro dataset para tener un conjunto de validación de cara a seleccionar nuestros hiperparámetros. En este caso partimos el conjunto de entrenamiento usando un porcentaje 80/20. -
item_tfms=[Resize((480,640)), TargetMaskConvertTransform(), transformPipeline]
. En el parámetroitem_tfms
indicamos las transformaciones que vamos a aplicar a nuestras imágenes y sus correspondientes máscaras. Además de las explicadas anteriormente vamos a reescalar las imágenes al tamaño 480x640. -
batch_tfms=Normalize.from_stats(*imagenet_stats)
. En el parámetrobatch_tfms
indicamos las transformaciones que se realizan a nivel de batch. En este caso como en nuestro modelo utilizaremos un backbone preentrenado en ImageNet debemos normalizar las imágenes para que tengan la escala de esas imágenes.
Con las explicaciones anteriores en sencillo comprender como definimos el siguiente DataBlock
que nos servirá para evaluar nuestros modelos en el conjunto de test.
testDB = DataBlock(blocks=(ImageBlock, MaskBlock(codes)),
get_items=partial(get_image_files,folders=['train','test']),
get_y=get_y_fn,
splitter=FuncSplitter(ParentSplitter),
item_tfms=[Resize((480,640)), TargetMaskConvertTransform(), transformPipeline],
batch_tfms=Normalize.from_stats(*imagenet_stats)
)
Ahora ya podemos definir nuestros Dataloaders
indicando el path donde se encuentran las imágenes y el batch size que vamos a utilizar.
bs = 4
trainDLS = trainDB.dataloaders(path_images,bs=bs)
testDLS = testDB.dataloaders(path_images,bs=bs)
Como siempre es conveniente mostrar un batch para comprobar que se están cargando los datos correctamente.
trainDLS.show_batch(vmin=0,vmax=1,figsize=(12, 9))
Definición de modelo
Ya podemos definir nuestro modelo y entrenarlo como hemos hecho en prácticas anteriores. Para ello vamos a crear un Learner
mediante la función unet_learner
a la cual le tenemos que proporcionar el DataLoader
el backbone que vamos a utilizar (en este caso usaremos un modelo Resnet-18) y las métricas Dice y Jaccard que emplearemos para evaluar nuestro modelo.
learn = unet_learner(trainDLS,resnet18,metrics=[Dice(),JaccardCoeff()]).to_fp16()
Por último entrenamos nuestro modelo.
learn.fit_one_cycle(20,3e-3)
Una vez entrenado el modelo lo vamos a guardar para usarlo posteriormente. Lo primero que hacemos es extraer el modelo del Learner
y caragarlo en la CPU.
aux=learn.model
aux=aux.cpu()
Ahora vamos a guardarlo, para lo cual es necesario cargar una imagen que le servirá como referencia para realizar las transformaciones necesarias. Para ello es necesario normalizar la imagen para que sigan el estándar de ImageNet.
import torchvision.transforms as transforms
img = PILImage.create(path_images/'train/color_206.jpg')
transformer=transforms.Compose([transforms.Resize((480,640)),
transforms.ToTensor(),
transforms.Normalize(
[0.485, 0.456, 0.406],
[0.229, 0.224, 0.225])])
img=transformer(img).unsqueeze(0)
img=img.cpu()
traced_cell=torch.jit.trace(aux, (img))
traced_cell.save("unet.pth")
Para ello debemos modificar el dataloader del objeto Learn
que hemos entrenado anteriormente.
learn.dls = testDLS
Por último evaluamos nuestro modelo usando el método validate()
. En este caso el método validate()
devuelve tres valores, el valor de la pérdida, y el valor de las métricas definidas anteriormente con respecto al conjunto de test.
learn.validate()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = torch.jit.load("unet.pth")
model = model.cpu()
model.eval()
El siguiente paso es cargar la imagen, para lo que usaremos la librería PIL
.
import PIL
img = PIL.Image.open('dataset/Images/test/color_154.jpg')
La siguiente instrucción permite mostrar la imagen que acabamos de cargar.
img
Ya estaríamos listos para relizar las predicciones sobre la imagen. Sin embargo, cabe recordar que primero debemos reescalar las imágenes y normalizarlas.
import torchvision.transforms as transforms
def transform_image(image):
my_transforms = transforms.Compose([transforms.ToTensor(),
transforms.Normalize(
[0.485, 0.456, 0.406],
[0.229, 0.224, 0.225])])
image_aux = image
return my_transforms(image_aux).unsqueeze(0).to(device)
El siguiente paso consiste en transformar la imagen.
image = transforms.Resize((480,640))(img)
tensor = transform_image(image=image)
Ahora ya podemos realizar pasarle el objeto construido anteriormente al modelo para realizar la predicción.
model.to(device)
with torch.no_grad():
outputs = model(tensor)
outputs = torch.argmax(outputs,1)
Ahora almacenamos el resultado en un array y convertimos el índice asociado con la clase grape (que era 1) al valor 255.
mask = np.array(outputs.cpu())
mask[mask==1]=255
La predicción devuelta por el modelo es un vector de tamaño 480x640 por lo que tendremos que ponerla en forma de matriz.
mask=np.reshape(mask,(480,640))
Con esto ya podemos mostrar la máscara generada.
Image.fromarray(mask.astype('uint8'))
Podemos compararla con la máscara real.
PIL.Image.open('dataset/Labels/test/gt_154.png')
Como vemos el modelo se aproxima bastante, pero la segmentación no es excesivamente buena. En la práctica veremos cómo crear mejores modelos.