PyTorch - conjuntos de dados

Neste capítulo, vamos nos concentrar mais em torchvision.datasetse seus vários tipos. PyTorch inclui os seguintes carregadores de conjunto de dados -

  • MNIST
  • COCO (legendagem e detecção)

O conjunto de dados inclui a maioria dos dois tipos de funções fornecidas abaixo -

  • Transform- uma função que obtém uma imagem e retorna uma versão modificada do material padrão. Eles podem ser compostos junto com as transformações.

  • Target_transform- uma função que pega o alvo e o transforma. Por exemplo, recebe a string de legenda e retorna um tensor de índices mundiais.

MNIST

A seguir está o código de amostra para o conjunto de dados MNIST -

dset.MNIST(root, train = TRUE, transform = NONE, 
target_transform = None, download = FALSE)

Os parâmetros são os seguintes -

  • root - diretório raiz do conjunto de dados onde existem os dados processados.

  • train - Verdadeiro = Conjunto de treinamento, Falso = Conjunto de teste

  • download - Verdadeiro = baixa o conjunto de dados da Internet e o coloca na raiz.

COCO

Isso requer que a API COCO seja instalada. O exemplo a seguir é usado para demonstrar a implementação COCO do conjunto de dados usando PyTorch -

import torchvision.dataset as dset
import torchvision.transforms as transforms
cap = dset.CocoCaptions(root = ‘ dir where images are’, 
annFile = ’json annotation file’,
transform = transforms.ToTensor())
print(‘Number of samples: ‘, len(cap))
print(target)

O resultado obtido é o seguinte -

Number of samples: 82783
Image Size: (3L, 427L, 640L)