This works if you have an image dataset in a .tar file.
You also need a .csv or .txt file that lists the name of each image in the dataset, one name per line.
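If you don't already have such a file, you can generate one from the archive itself. Here is a minimal sketch that uses only the standard library's `tarfile` module. The function name `write_filelist` and the list of accepted extensions are illustrative assumptions. The names it writes are the member names that `TarFile.extractfile` expects. Because the loader below splits lines on spaces, this assumes image names contain no spaces.

```python
import tarfile


def write_filelist(tar_path, txt_path="filelist.txt",
                   extensions=(".jpg", ".jpeg", ".png")):
    """Write one archive member name per line for every image file in tar_path.

    Hypothetical helper: the extension filter is an assumption, adjust as needed.
    """
    with tarfile.open(tar_path) as tf:
        # Keep only regular files whose name ends with an image extension.
        names = [
            m.name for m in tf.getmembers()
            if m.isfile() and m.name.lower().endswith(extensions)
        ]
    with open(txt_path, "w") as f:
        for name in names:
            f.write(name + "\n")
    return names
```

For example, `write_filelist('data.tar')` creates `filelist.txt` next to the script, matching the default `txt_path` used below.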

from PIL import Image
from torchvision.transforms import ToTensor, ToPILImage
import tarfile
import io
import pandas as pd
from torch.utils.data import Dataset


class YourDataset(Dataset):
    def __init__(self, txt_path='filelist.txt', img_dir='data.tar', transform=None):
        """
        Initialize the dataset as a list of IDs corresponding to each item of the dataset.
        :param img_dir: path to the image files as an uncompressed tar archive
        :param txt_path: a text file containing the names of all images, one per line
        :param transform: transforms to apply to the input image (cropping, rotating, etc.)
        """
        # header=None: the file list has no header row, so the first name is kept
        df = pd.read_csv(txt_path, sep=' ', header=None, index_col=0)
        self.img_names = df.index.values
        self.txt_path = txt_path
        self.img_dir = img_dir
        self.transform = transform
        self.to_tensor = ToTensor()
        self.to_pil = ToPILImage()
        self.tf = tarfile.open(self.img_dir)

    def get_image_from_tar(self, name):
        """
        Get an image by a name taken from the file list.
        :param name: name of the targeted image
        :return: a PIL image
        """
        image = self.tf.extractfile(name)
        image = image.read()
        image = Image.open(io.BytesIO(image))
        return image

    def __len__(self):
        """
        Return the length of the dataset using the list of IDs.
        :return: number of samples in the dataset
        """
        return len(self.img_names)

    def __getitem__(self, index):
        """
        Generate one item of the dataset.
        :param index: index of the item in the IDs list
        :return: a sample of data as a dict
        """
        # Close the tarfile opened in __init__. Note that after this, the dataset
        # cannot be iterated again, and with shuffle=True it may close mid-epoch.
        if index == (self.__len__() - 1):
            self.tf.close()
        image = self.get_image_from_tar(self.img_names[index])
        if self.transform is not None:
            image = self.transform(image)
        sample = {'X': image}
        return sample

Hello, thank you for your solution. I have an issue when creating a DataLoader with a dataset from a tar archive.
If I iterate over the DataLoader more than once, I get the error below. Can you help me with this problem?

OSError: TarFile is closed
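The error most likely comes from `__getitem__` closing `self.tf` when it reaches the last index, so the next pass over the DataLoader reads from a closed archive. Here is a minimal sketch of one workaround. It never closes the archive inside `__getitem__`. Instead, it opens the archive lazily and reopens it whenever the process id changes, so each DataLoader worker gets its own handle. The class name `TarImageDataset` and its helpers are hypothetical. To stay dependency-free, the sketch returns raw bytes, which you can decode with `Image.open(io.BytesIO(...))` as in the original code. It does not subclass `torch.utils.data.Dataset`, because DataLoader accepts any object with `__len__` and `__getitem__` as a map-style dataset. Subclassing works just as well.

```python
import os
import tarfile


class TarImageDataset:
    """Map-style dataset reading raw image bytes from an uncompressed tar archive.

    Hypothetical sketch: the archive is opened on first access in each process
    and is never closed inside __getitem__, so it can be iterated repeatedly.
    """

    def __init__(self, tar_path, names):
        self.tar_path = tar_path
        self.names = list(names)
        self._tf = None
        self._pid = None

    def _archive(self):
        # Open (or reopen after a fork) a handle owned by the current process.
        if self._tf is None or self._pid != os.getpid():
            self._tf = tarfile.open(self.tar_path)
            self._pid = os.getpid()
        return self._tf

    def __getstate__(self):
        # Don't pickle the open handle (e.g. when workers are spawned).
        state = self.__dict__.copy()
        state["_tf"] = None
        state["_pid"] = None
        return state

    def __len__(self):
        return len(self.names)

    def __getitem__(self, index):
        member = self._archive().extractfile(self.names[index])
        return member.read()  # raw bytes of the image file

    def close(self):
        # Call explicitly once you are done with the dataset.
        if self._tf is not None:
            self._tf.close()
            self._tf = None
```

The design choice is to tie each handle to a process rather than share one `TarFile` between workers, because a single file object's read position is not safe to share across processes.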

When I tried this method of extracting individual images inside `__getitem__` for the TinyImageNet dataset, I ran into zlib errors during DataLoader collation. The error disappeared once I extracted the images in `__init__` and stored them in a dict, so that `__getitem__` only applied the transformation and returned the image and target.
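This approach can be sketched as follows: read every listed member into a dict once in `__init__`, close the archive, and let `__getitem__` do only a lookup plus the optional transform. The class name `PreloadedTarDataset` and the `labels` mapping (member name to target) are hypothetical. The sketch also assumes the whole dataset fits in memory, which is the main cost of this approach.

```python
import tarfile


class PreloadedTarDataset:
    """Reads the listed tar members into memory once, in __init__.

    Hypothetical sketch: `labels` maps archive member names to targets.
    __getitem__ only looks up bytes in a dict, so no file handle is used
    while the DataLoader is running.
    """

    def __init__(self, tar_path, labels, transform=None):
        self.transform = transform
        self.labels = labels
        self.names = sorted(labels)  # fixed order for index lookups
        with tarfile.open(tar_path) as tf:
            self.data = {name: tf.extractfile(name).read() for name in self.names}

    def __len__(self):
        return len(self.names)

    def __getitem__(self, index):
        name = self.names[index]
        sample = self.data[name]
        if self.transform is not None:
            sample = self.transform(sample)
        return sample, self.labels[name]
```

In practice, `transform` would decode the bytes (for example with `Image.open(io.BytesIO(b))`) before applying torchvision transforms.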

That is possible. I only used this implementation for a large tar file on a single machine, so it may not work in some other situations.
Still, the best approach, even for large datasets, is to extract the entire dataset and loop over the extracted files. Here is the original post that led to this code: