memo: Datasets that load multiple data files

DataLoader

(2024-05-29)

Visualizing how the PyTorch DataLoader works; collate_fn
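collate_fn is the hook where the DataLoader turns the list of samples fetched from the Dataset into one batch. A minimal sketch of a custom collate_fn that pads variable-length sequences (the dataset name and shapes are placeholders, not from the linked article):

import torch
from torch.utils.data import DataLoader

def pad_collate(batch):
    # batch: list of (sequence, label) pairs, each sequence a 1-D tensor of varying length
    seqs, labels = zip(*batch)
    lengths = torch.tensor([len(s) for s in seqs])
    padded = torch.nn.utils.rnn.pad_sequence(seqs, batch_first=True)  # shape (B, T_max)
    return padded, torch.tensor(labels), lengths

loader = DataLoader(my_dataset, batch_size=32, collate_fn=pad_collate)  # my_dataset is hypothetical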


torch.utils.data.IterableDataset

PyTorch forum 2020-02-26; Docs
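When the data is naturally read as a stream (e.g. one huge text file per shard), an IterableDataset avoids random indexing altogether. A sketch following the worker-splitting pattern from the docs via get_worker_info(); the file names are placeholders:

from torch.utils.data import DataLoader, IterableDataset, get_worker_info

class MultiFileIterableDataset(IterableDataset):
    def __init__(self, file_paths):
        self.file_paths = file_paths

    def __iter__(self):
        info = get_worker_info()
        if info is None:
            paths = self.file_paths                              # single-process loading
        else:
            paths = self.file_paths[info.id::info.num_workers]   # each worker gets a disjoint subset
        for path in paths:
            with open(path) as f:
                for line in f:
                    yield line.rstrip("\n")

ds = MultiFileIterableDataset(["part0.txt", "part1.txt"])  # placeholder file names
loader = DataLoader(ds, batch_size=64, num_workers=2)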


ConcatDataset

Docs

  1. Example 1: PyTorch DataLoader multiple data source - SO

    import os
    import torch.utils.data as data

    class SingleJsonDataset(data.Dataset):
        # implement __init__, __getitem__ and __len__ for a single json file here...
        ...

    list_of_datasets = []
    for j in os.listdir(root_dir):  # root_dir: directory holding the json files
        if not j.endswith('.json'):
            continue  # skip non-json files
        list_of_datasets.append(SingleJsonDataset(json_file=j, root_dir=root_dir, transform=None))

    # once all single json datasets are created you can concat them into a single one:
    multiple_json_dataset = data.ConcatDataset(list_of_datasets)
    
  2. Example 2: PyTorch forum - Praateek

    import csv
    import linecache
    import os
    import subprocess

    from torch.utils.data import ConcatDataset, Dataset

    class LazyTextDataset(Dataset):
        def __init__(self, filename):
            self._filename = filename
            # count lines with `wc -l`; the -1 assumes a header row that is not a sample
            self._total_data = int(subprocess.check_output("wc -l " + filename, shell=True).split()[0]) - 1

        def __getitem__(self, idx):
            # read only the requested line; linecache caches file contents lazily
            line = linecache.getline(self._filename, idx + 1)
            csv_line = csv.reader([line])
            return next(csv_line)

        def __len__(self):
            return self._total_data

    path = "/where_csv_files_are_dumped/"  # placeholder directory
    files = list(map(lambda x: path + x, filter(lambda x: x.endswith("csv"), os.listdir(path))))
    datasets = list(map(lambda x: LazyTextDataset(x), files))
    dataset = ConcatDataset(datasets)
    

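Either way, the concatenated dataset is consumed like any other map-style dataset. A minimal usage sketch, assuming the `dataset` variable from Example 2 (batch size and worker count are arbitrary):

from torch.utils.data import DataLoader

loader = DataLoader(dataset, batch_size=32, shuffle=True, num_workers=4)
for batch in loader:
    ...  # batch is whatever __getitem__ returns, grouped by the default collate_fn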
Comment from Thomas Ahle:

The problem with ConcatDataset is that it doesn't work well with multiprocessing: it calls len(ds) on each dataset in its initializer, so you end up loading every dataset in the main process.
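One way around this, sketched below (not from the original thread), is to hand the wrapper factory functions plus precomputed lengths, so each sub-dataset is only constructed inside the worker that first touches it:

import bisect
from torch.utils.data import Dataset

class LazyConcatDataset(Dataset):
    """Concatenate datasets without constructing them up front."""
    def __init__(self, factories_and_lengths):
        # factories must be picklable (module-level functions / functools.partial) for num_workers > 0
        self.factories = [f for f, _ in factories_and_lengths]
        self.cumulative = []
        total = 0
        for _, n in factories_and_lengths:
            total += n
            self.cumulative.append(total)
        self.datasets = [None] * len(self.factories)

    def __len__(self):
        return self.cumulative[-1] if self.cumulative else 0

    def __getitem__(self, idx):
        ds_idx = bisect.bisect_right(self.cumulative, idx)
        if self.datasets[ds_idx] is None:        # built lazily, inside the worker process
            self.datasets[ds_idx] = self.factories[ds_idx]()
        offset = self.cumulative[ds_idx - 1] if ds_idx > 0 else 0
        return self.datasets[ds_idx][idx - offset]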


np.load(path, mmap_mode='r')

Load multiple .npy files (size > 10GB) in pytorch - SO

import numpy as np
import torch
from bisect import bisect
import os, psutil # used to monitor memory usage

class BigDataset(torch.utils.data.Dataset):
    def __init__(self, data_paths, target_paths):
        self.data_memmaps = [np.load(path, mmap_mode='r') for path in data_paths]
        self.target_memmaps = [np.load(path, mmap_mode='r') for path in target_paths]
        self.start_indices = [0] * len(data_paths)
        self.data_count = 0
        for index, memmap in enumerate(self.data_memmaps):
            self.start_indices[index] = self.data_count
            self.data_count += memmap.shape[0]

    def __len__(self):
        return self.data_count

    def __getitem__(self, index):
        memmap_index = bisect(self.start_indices, index) - 1
        index_in_memmap = index - self.start_indices[memmap_index]
        data = self.data_memmaps[memmap_index][index_in_memmap]
        target = self.target_memmaps[memmap_index][index_in_memmap]
        return index, torch.from_numpy(data), torch.from_numpy(target)

# Test Code
if __name__ == "__main__":
    data_paths = [f'data/d{index}.npy' for index in range(10)]
    target_paths = [f'data/s{index}.npy' for index in range(10)]

    process = psutil.Process(os.getpid())
    memory_before = process.memory_info().rss

    dataset = BigDataset(data_paths, target_paths)

    used_memory = process.memory_info().rss - memory_before
    print("Used memory:", used_memory, "bytes")

    dataset_size = len(dataset)
    print("Dataset size:", dataset_size)
    print("Samples:")
    for sample_index in [0, dataset_size//2, dataset_size-1]:
        print(dataset[sample_index])
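Since each __getitem__ only touches the memmapped slice it needs, the dataset combines naturally with a multi-worker DataLoader; a minimal usage sketch (batch size and worker count are arbitrary):

from torch.utils.data import DataLoader

loader = DataLoader(dataset, batch_size=64, shuffle=True, num_workers=4)
for indices, data_batch, target_batch in loader:
    pass  # each worker reads its samples directly from the memmapped .npy files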