PyTorch torch.utils.data

2020-09-15 11:52 更新

原文: PyTorch torch.utils.datal

PyTorch 數(shù)據(jù)加載實(shí)用程序的核心是 torch.utils.data.DataLoader 類。 它表示可在數(shù)據(jù)集上迭代的 Python,并支持

這些選項(xiàng)由 DataLoader 的構(gòu)造函數(shù)參數(shù)配置,該參數(shù)具有簽名:

DataLoader(dataset, batch_size=1, shuffle=False, sampler=None,
           batch_sampler=None, num_workers=0, collate_fn=None,
           pin_memory=False, drop_last=False, timeout=0,
           worker_init_fn=None)

以下各節(jié)詳細(xì)介紹了這些選項(xiàng)的效果和用法。

數(shù)據(jù)集類型

DataLoader 構(gòu)造函數(shù)的最重要參數(shù)是dataset,它指示要從中加載數(shù)據(jù)的數(shù)據(jù)集對(duì)象。 PyTorch 支持兩種不同類型的數(shù)據(jù)集:

地圖樣式數(shù)據(jù)集

映射樣式數(shù)據(jù)集是一種實(shí)現(xiàn)__getitem__()__len__()協(xié)議的數(shù)據(jù)集,它表示從(可能是非整數(shù))索引/關(guān)鍵字到數(shù)據(jù)樣本的映射。

例如,當(dāng)使用dataset[idx]訪問時(shí),此類數(shù)據(jù)集可以從磁盤上的文件夾中讀取第idx張圖像及其對(duì)應(yīng)的標(biāo)簽。

迭代式數(shù)據(jù)集

可迭代樣式的數(shù)據(jù)集是 IterableDataset 子類的實(shí)例,該子類實(shí)現(xiàn)了__iter__()協(xié)議,并表示數(shù)據(jù)樣本上的可迭代。 這種類型的數(shù)據(jù)集特別適用于隨機(jī)讀取價(jià)格昂貴甚至不大可能,并且批處理大小取決于所獲取數(shù)據(jù)的情況。

例如,這種數(shù)據(jù)集稱為iter(dataset)時(shí),可以返回從數(shù)據(jù)庫(kù),遠(yuǎn)程服務(wù)器甚至實(shí)時(shí)生成的日志中讀取的數(shù)據(jù)流。

注意

當(dāng)將 IterableDataset 與一起使用時(shí),多進(jìn)程數(shù)據(jù)加載。 在每個(gè)工作進(jìn)程上都復(fù)制相同的數(shù)據(jù)集對(duì)象,因此必須對(duì)副本進(jìn)行不同的配置,以避免重復(fù)的數(shù)據(jù)。 有關(guān)如何實(shí)現(xiàn)此功能的信息,請(qǐng)參見 IterableDataset 文檔。

數(shù)據(jù)加載順序和 Sampler

對(duì)于迭代式數(shù)據(jù)集,數(shù)據(jù)加載順序完全由用戶定義的迭代器控制。 這樣可以更輕松地實(shí)現(xiàn)塊讀取和動(dòng)態(tài)批次大小的實(shí)現(xiàn)(例如,通過每次生成一個(gè)批次的樣本)。

本節(jié)的其余部分涉及地圖樣式數(shù)據(jù)集的情況。 torch.utils.data.Sampler 類用于指定數(shù)據(jù)加載中使用的索引/鍵的順序。 它們代表數(shù)據(jù)集索引上的可迭代對(duì)象。 例如,在具有隨機(jī)梯度體面(SGD)的常見情況下, Sampler 可以隨機(jī)排列一列索引,一次生成每個(gè)索引,或者為小批量生成少量索引 新幣。

基于 DataLoadershuffle參數(shù),將自動(dòng)構(gòu)建順序采樣或混洗的采樣器。 或者,用戶可以使用sampler參數(shù)指定一個(gè)自定義 Sampler 對(duì)象,該對(duì)象每次都會(huì)產(chǎn)生要提取的下一個(gè)索引/關(guān)鍵字。

可以一次生成批量索引列表的自定義 Sampler 作為batch_sampler參數(shù)傳遞。 也可以通過batch_sizedrop_last參數(shù)啟用自動(dòng)批處理。 有關(guān)更多詳細(xì)信息,請(qǐng)參見下一部分的

Note

samplerbatch_sampler都不與可迭代樣式的數(shù)據(jù)集兼容,因?yàn)榇祟悢?shù)據(jù)集沒有鍵或索引的概念。

加載批處理和非批處理數(shù)據(jù)

DataLoader 支持通過參數(shù)batch_size,drop_lastbatch_sampler將各個(gè)提取的數(shù)據(jù)樣本自動(dòng)整理為批次。

自動(dòng)批處理(默認(rèn))

這是最常見的情況,對(duì)應(yīng)于獲取一小批數(shù)據(jù)并將其整理為批處理的樣本,即包含張量,其中一維為批處理維度(通常是第一維)。

當(dāng)batch_size(默認(rèn)1)不是None時(shí),數(shù)據(jù)加載器將生成批處理的樣本,而不是單個(gè)樣本。 batch_sizedrop_last參數(shù)用于指定數(shù)據(jù)加載器如何獲取數(shù)據(jù)集密鑰的批處理。 對(duì)于地圖樣式的數(shù)據(jù)集,用戶可以選擇指定batch_sampler,它一次生成一個(gè)鍵列表。

Note

batch_sizedrop_last自變量本質(zhì)上用于從sampler構(gòu)造batch_sampler。 對(duì)于地圖樣式的數(shù)據(jù)集,sampler由用戶提供或基于shuffle參數(shù)構(gòu)造。 對(duì)于可迭代樣式的數(shù)據(jù)集,sampler是一個(gè)虛擬的無限數(shù)據(jù)集。

Note

當(dāng)從可重復(fù)樣式數(shù)據(jù)集進(jìn)行多重處理提取時(shí),drop_last參數(shù)會(huì)刪除每個(gè)工作人員數(shù)據(jù)集副本的最后一個(gè)非完整批次。

使用來自采樣器的索引獲取樣本列表后,作為collate_fn參數(shù)傳遞的函數(shù)用于將樣本列表整理為批次。

在這種情況下,從地圖樣式數(shù)據(jù)集加載大致等效于:

for indices in batch_sampler:
    yield collate_fn([dataset[i] for i in indices])

從可迭代樣式的數(shù)據(jù)集加載大致等效于:

dataset_iter = iter(dataset)
for indices in batch_sampler:
    yield collate_fn([next(dataset_iter) for _ in indices])

自定義collate_fn可用于自定義排序規(guī)則,例如,將順序數(shù)據(jù)填充到批處理的最大長(zhǎng)度。

禁用自動(dòng)批處理

在某些情況下,用戶可能希望以數(shù)據(jù)集代碼手動(dòng)處理批處理,或僅加載單個(gè)樣本。 例如,直接加載批處理的數(shù)據(jù)(例如,從數(shù)據(jù)庫(kù)中批量讀取或讀取連續(xù)的內(nèi)存塊)可能更便宜,或者批處理大小取決于數(shù)據(jù),或者該程序設(shè)計(jì)為可處理單個(gè)樣本。 在這種情況下,最好不要使用自動(dòng)批處理(其中collate_fn用于整理樣本),而應(yīng)讓數(shù)據(jù)加載器直接返回dataset對(duì)象的每個(gè)成員。

當(dāng)batch_sizebatch_sampler均為None時(shí)(batch_sampler的默認(rèn)值已為None),自動(dòng)批處理被禁用。 從dataset獲得的每個(gè)樣本都將作為collate_fn參數(shù)傳遞的函數(shù)進(jìn)行處理。

禁用自動(dòng)批處理時(shí),默認(rèn)值collate_fn僅將 NumPy 數(shù)組轉(zhuǎn)換為 PyTorch 張量,而其他所有內(nèi)容均保持不變。

In this case, loading from a map-style dataset is roughly equivalent with:

for index in sampler:
    yield collate_fn(dataset[index])

and loading from an iterable-style dataset is roughly equivalent with:

for data in iter(dataset):
    yield collate_fn(data)

使用collate_fn

啟用或禁用自動(dòng)批處理時(shí),collate_fn的使用略有不同。

禁用自動(dòng)批處理時(shí),將對(duì)每個(gè)單獨(dú)的數(shù)據(jù)樣本調(diào)用collate_fn,并且從數(shù)據(jù)加載器迭代器產(chǎn)生輸出。 在這種情況下,默認(rèn)的collate_fn僅轉(zhuǎn)換 PyTorch 張量中的 NumPy 數(shù)組。

啟用自動(dòng)批處理時(shí),會(huì)每次調(diào)用collate_fn并帶有數(shù)據(jù)樣本列表。 期望將輸入樣本整理為一批,以便從數(shù)據(jù)加載器迭代器中獲得收益。 本節(jié)的其余部分描述了這種情況下默認(rèn)collate_fn的行為。

例如,如果每個(gè)數(shù)據(jù)樣本都包含一個(gè) 3 通道圖像和一個(gè)整體類標(biāo)簽,即數(shù)據(jù)集的每個(gè)元素返回一個(gè)元組(image, class_index),則默認(rèn)值collate_fn將此類元組的列表整理為一個(gè)元組 批處理圖像張量和批處理類標(biāo)簽 Tensor。 特別是,默認(rèn)collate_fn具有以下屬性:

  • 它始終將新維度添加為批次維度。
  • 它會(huì)自動(dòng)將 NumPy 數(shù)組和 Python 數(shù)值轉(zhuǎn)換為 PyTorch 張量。
  • 它保留了數(shù)據(jù)結(jié)構(gòu),例如,如果每個(gè)樣本都是一個(gè)字典,它將輸出一個(gè)具有相同鍵集但將批處理 Tensors 作為值的字典(如果無法將這些值轉(zhuǎn)換為 Tensors,則將其列出)。 與listtuple,namedtuple等相同。

用戶可以使用自定義的collate_fn來實(shí)現(xiàn)自定義批處理,例如,沿除第一個(gè)維度之外的其他維度進(jìn)行校對(duì),各種長(zhǎng)度的填充序列或添加對(duì)自定義數(shù)據(jù)類型的支持。

單進(jìn)程和多進(jìn)程數(shù)據(jù)加載

默認(rèn)情況下, DataLoader 使用單進(jìn)程數(shù)據(jù)加載。

在 Python 進(jìn)程中,全局解釋器鎖定(GIL)阻止了跨線程真正的完全并行化 Python 代碼。 為了避免在加載數(shù)據(jù)時(shí)阻塞計(jì)算代碼,PyTorch 提供了一個(gè)簡(jiǎn)單的開關(guān),只需將參數(shù)num_workers設(shè)置為正整數(shù)即可執(zhí)行多進(jìn)程數(shù)據(jù)加載。

單進(jìn)程數(shù)據(jù)加載(默認(rèn))

在此模式下,以與初始化 DataLoader 相同的過程完成數(shù)據(jù)提取。 因此,數(shù)據(jù)加載可能會(huì)阻止計(jì)算。 然而,當(dāng)用于在進(jìn)程之間共享數(shù)據(jù)的資源(例如,共享存儲(chǔ)器,文件描述符)受到限制時(shí),或者當(dāng)整個(gè)數(shù)據(jù)集很小并且可以完全加載到存儲(chǔ)器中時(shí),該模式可能是優(yōu)選的。 此外,單進(jìn)程加載通常顯示更多可讀的錯(cuò)誤跟蹤,因此對(duì)于調(diào)試很有用。

多進(jìn)程數(shù)據(jù)加載

將參數(shù)num_workers設(shè)置為正整數(shù)將打開具有指定數(shù)量的加載程序工作進(jìn)程的多進(jìn)程數(shù)據(jù)加載。

在此模式下,每次創(chuàng)建 DataLoader 的迭代器時(shí)(例如,當(dāng)您調(diào)用enumerate(dataloader)時(shí)),都會(huì)創(chuàng)建num_workers工作進(jìn)程。 此時(shí),dataset,collate_fnworker_init_fn被傳遞給每個(gè)工作程序,在這里它們被用來初始化和獲取數(shù)據(jù)。 這意味著數(shù)據(jù)集訪問及其內(nèi)部 IO 轉(zhuǎn)換(包括collate_fn)在工作進(jìn)程中運(yùn)行。

torch.utils.data.get_worker_info() 在工作進(jìn)程中返回各種有用的信息(包括工作 ID,數(shù)據(jù)集副本,初始種子等),并在主進(jìn)程中返回None。 用戶可以在數(shù)據(jù)集代碼和/或worker_init_fn中使用此功能來分別配置每個(gè)數(shù)據(jù)集副本,并確定代碼是否正在工作進(jìn)程中運(yùn)行。 例如,這在分片數(shù)據(jù)集時(shí)特別有用。

對(duì)于地圖樣式的數(shù)據(jù)集,主過程使用sampler生成索引并將其發(fā)送給工作人員。 因此,任何隨機(jī)播放都是在主過程中完成的,該過程通過為索引分配索引來引導(dǎo)加載。

對(duì)于可迭代樣式的數(shù)據(jù)集,由于每個(gè)工作進(jìn)程都獲得dataset對(duì)象的副本,因此幼稚的多進(jìn)程加載通常會(huì)導(dǎo)致數(shù)據(jù)重復(fù)。 用戶可以使用 torch.utils.data.get_worker_info() 和/或worker_init_fn獨(dú)立配置每個(gè)副本。 (有關(guān)如何實(shí)現(xiàn)此操作的信息,請(qǐng)參見 IterableDataset 文檔。)出于類似的原因,在多進(jìn)程加載中,drop_last參數(shù)刪除每個(gè)工作程序的可迭代樣式數(shù)據(jù)集副本的最后一個(gè)非完整批次。

一旦迭代結(jié)束或迭代器被垃圾回收,工作器將關(guān)閉。

警告

通常不建議在多進(jìn)程加載中返回 CUDA 張量,因?yàn)樵谑褂?CUDA 和在并行處理中共享 CUDA 張量時(shí)存在很多微妙之處(請(qǐng)參見在并行處理中的 CUDA)。 相反,我們建議使用自動(dòng)內(nèi)存固定(即,設(shè)置pin_memory=True),該功能可以將數(shù)據(jù)快速傳輸?shù)街С?CUDA 的 GPU。

平臺(tái)特定的行為

由于工作程序依賴于 Python multiprocessing,因此與 Unix 相比,Windows 上的工作程序啟動(dòng)行為有所不同。

  • 在 Unix 上,fork()是默認(rèn)的multiprocessing啟動(dòng)方法。 使用fork(),童工通??梢灾苯油ㄟ^克隆的地址空間訪問dataset和 Python 參數(shù)函數(shù)。
  • 在 Windows 上,spawn()是默認(rèn)的multiprocessing啟動(dòng)方法。 使用spawn()啟動(dòng)另一個(gè)解釋器,該解釋器運(yùn)行您的主腳本,然后運(yùn)行內(nèi)部工作程序函數(shù),該函數(shù)通過序列化pickle接收dataset,collate_fn和其他參數(shù)。

這種獨(dú)立的序列化意味著您應(yīng)該采取兩個(gè)步驟來確保在使用多進(jìn)程數(shù)據(jù)加載時(shí)與 Windows 兼容:

  • 將您的大部分主腳本代碼包裝在if __name__ == '__main__':塊中,以確保在啟動(dòng)每個(gè)工作進(jìn)程時(shí),該腳本不會(huì)再次運(yùn)行(很可能會(huì)產(chǎn)生錯(cuò)誤)。 您可以在此處放置數(shù)據(jù)集和 DataLoader 實(shí)例創(chuàng)建邏輯,因?yàn)樗恍枰?worker 中重新執(zhí)行。
  • 確保在__main__檢查之外將任何自定義collate_fn,worker_init_fndataset代碼聲明為頂級(jí)定義。 這樣可以確保它們?cè)诠ぷ鬟M(jìn)程中可用。 (這是必需的,因?yàn)閷⒑瘮?shù)僅作為引用而不是bytecode進(jìn)行腌制。)

多進(jìn)程數(shù)據(jù)加載中的隨機(jī)性

默認(rèn)情況下,每個(gè)工作人員的 PyTorch 種子將設(shè)置為base_seed + worker_id,其中base_seed是主進(jìn)程使用其 RNG 生成的長(zhǎng)整數(shù)(因此,強(qiáng)制使用 RNG 狀態(tài))。 但是,初始化工作程序(例如 NumPy)時(shí),可能會(huì)復(fù)制其他庫(kù)的種子,導(dǎo)致每個(gè)工作程序返回相同的隨機(jī)數(shù)。

worker_init_fn中,您可以使用 torch.utils.data.get_worker_info().seedtorch.initial_seed() 訪問每個(gè)工作人員的 PyTorch 種子集,并在加載數(shù)據(jù)之前使用它為其他庫(kù)提供種子。

內(nèi)存固定

主機(jī)到 GPU 副本源自固定(頁面鎖定)內(nèi)存時(shí),速度要快得多。 有關(guān)通常何時(shí)以及如何使用固定內(nèi)存的更多詳細(xì)信息,請(qǐng)參見使用固定內(nèi)存緩沖區(qū)。

對(duì)于數(shù)據(jù)加載,將pin_memory=True傳遞到 DataLoader 將自動(dòng)將獲取的數(shù)據(jù)張量放置在固定內(nèi)存中,從而更快地將數(shù)據(jù)傳輸?shù)絾⒂?CUDA 的 GPU。

默認(rèn)的內(nèi)存固定邏輯僅識(shí)別張量以及包含張量的映射和可迭代對(duì)象。 默認(rèn)情況下,如果固定邏輯看到一個(gè)自定義類型的批處理(如果您有一個(gè)collate_fn返回自定義批處理類型,則會(huì)發(fā)生),或者如果該批處理的每個(gè)元素都是自定義類型,則固定邏輯將 無法識(shí)別它們,它將返回該批處理(或那些元素)而不固定內(nèi)存。 要為自定義批處理或數(shù)據(jù)類型啟用內(nèi)存固定,請(qǐng)?jiān)谧远x類型上定義pin_memory()方法。

請(qǐng)參見下面的示例。

例:

class SimpleCustomBatch:
    def __init__(self, data):
        transposed_data = list(zip(*data))
        self.inp = torch.stack(transposed_data[0], 0)
        self.tgt = torch.stack(transposed_data[1], 0)


    # custom memory pinning method on custom type
    def pin_memory(self):
        self.inp = self.inp.pin_memory()
        self.tgt = self.tgt.pin_memory()
        return self


def collate_wrapper(batch):
    return SimpleCustomBatch(batch)


inps = torch.arange(10 * 5, dtype=torch.float32).view(10, 5)
tgts = torch.arange(10 * 5, dtype=torch.float32).view(10, 5)
dataset = TensorDataset(inps, tgts)


loader = DataLoader(dataset, batch_size=2, collate_fn=collate_wrapper,
                    pin_memory=True)


for batch_ndx, sample in enumerate(loader):
    print(sample.inp.is_pinned())
    print(sample.tgt.is_pinned())

class torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None, num_workers=0, collate_fn=None, pin_memory=False, drop_last=False, timeout=0, worker_init_fn=None, multiprocessing_context=None)?

數(shù)據(jù)加載器。 組合數(shù)據(jù)集和采樣器,并在給定的數(shù)據(jù)集上提供可迭代的。

DataLoader 支持地圖樣式和可迭代樣式的數(shù)據(jù)集,具有單進(jìn)程或多進(jìn)程加載,自定義加載順序以及可選的自動(dòng)批處理(歸類)和內(nèi)存固定。

有關(guān)更多詳細(xì)信息,請(qǐng)參見 torch.utils.data 文檔頁面。

參數(shù)

  • 數(shù)據(jù)集 (數(shù)據(jù)集)–要從中加載數(shù)據(jù)的數(shù)據(jù)集。
  • batch_size (python:int , 可選)–每批次要加載多少個(gè)樣本(默認(rèn)值:1)。
  • 隨機(jī)播放 (bool , 可選)–設(shè)置為True以使數(shù)據(jù)在每個(gè)時(shí)間段都重新隨機(jī)播放(默認(rèn)值:False )。
  • 采樣器 (采樣器 , 可選)–定義了從數(shù)據(jù)集中抽取樣本的策略。 如果指定,則shuffle必須為False。
  • batch_sampler (采樣器 , 可選)–類似sampler,但在 時(shí)間。 與batch_sizeshuffle,samplerdrop_last互斥。
  • num_workers (python:int 可選)–多少個(gè)子進(jìn)程用于數(shù)據(jù)加載。 0表示將在主進(jìn)程中加載數(shù)據(jù)。 (默認(rèn):0
  • collate_fn (可調(diào)用的, 可選)–合并樣本列表以形成張量的小批量。 在從地圖樣式數(shù)據(jù)集中使用批量加載時(shí)使用。
  • pin_memory (bool , 可選)–如果True,則數(shù)據(jù)加載器將張量復(fù)制到 CUDA 固定的內(nèi)存中,然后返回。 如果您的數(shù)據(jù)元素是自定義類型,或者您的collate_fn返回的是自定義類型的批次,請(qǐng)參見下面的示例。
  • drop_last (布爾 , 可選)–設(shè)置為True以刪除最后不完整的批次,如果數(shù)據(jù)集大小不可分割 按批次大小。 如果False并且數(shù)據(jù)集的大小不能被批次大小整除,那么最后一批將較小。 (默認(rèn):False
  • 超時(shí)(數(shù)字 , 可選)–如果為正,則表示從工作人員處收集批次的超時(shí)值。 應(yīng)始終為非負(fù)數(shù)。 (默認(rèn):0
  • worker_init_fn (可調(diào)用 , 可選)–如果不是None,則將在每個(gè)具有工作人員 ID (在播種之后和數(shù)據(jù)加載之前,將[0, num_workers - 1]中的 int 作為輸入。 (默認(rèn):None

Warning

如果使用spawn啟動(dòng)方法,則worker_init_fn不能是不可拾取的對(duì)象,例如 lambda 函數(shù)。 有關(guān) PyTorch 中與并行處理有關(guān)的更多詳細(xì)信息,請(qǐng)參見并行處理最佳實(shí)踐。

Note

len(dataloader)啟發(fā)式方法基于所用采樣器的長(zhǎng)度。 當(dāng)datasetIterableDataset 時(shí),無論多進(jìn)程加載配置如何,都將返回len(dataset)(如果實(shí)現(xiàn)),因?yàn)?PyTorch 信任用戶dataset代碼可以正確處理多進(jìn)程加載 避免重復(fù)數(shù)據(jù)。 有關(guān)這兩種類型的數(shù)據(jù)集以及 IterableDataset 如何與多進(jìn)程數(shù)據(jù)加載交互的更多詳細(xì)信息,請(qǐng)參見數(shù)據(jù)集類型。

class torch.utils.data.Dataset?

表示 Dataset 的抽象類。

代表從鍵到數(shù)據(jù)樣本的映射的所有數(shù)據(jù)集都應(yīng)將其子類化。 所有子類都應(yīng)該覆蓋__getitem__(),支持為給定鍵獲取數(shù)據(jù)樣本。 子類還可以選擇覆蓋__len__(),它有望通過許多 Sampler 實(shí)現(xiàn)以及 DataLoader 的默認(rèn)選項(xiàng)返回?cái)?shù)據(jù)集的大小。

Note

默認(rèn)情況下, DataLoader 構(gòu)造一個(gè)索引采樣器,該采樣器產(chǎn)生整數(shù)索引。 要使其與具有非整數(shù)索引/鍵的地圖樣式數(shù)據(jù)集一起使用,必須提供自定義采樣器。

class torch.utils.data.IterableDataset?

可迭代的數(shù)據(jù)集。

代表可迭代數(shù)據(jù)樣本的所有數(shù)據(jù)集都應(yīng)將其子類化。 當(dāng)數(shù)據(jù)來自流時(shí),這種形式的數(shù)據(jù)集特別有用。

所有子類都應(yīng)覆蓋__iter__(),這將返回此數(shù)據(jù)集中的樣本迭代器。

當(dāng)子類與 DataLoader 一起使用時(shí),數(shù)據(jù)集中的每個(gè)項(xiàng)目都將由 DataLoader 迭代器產(chǎn)生。 當(dāng)num_workers > 0時(shí),每個(gè)工作進(jìn)程將具有數(shù)據(jù)集對(duì)象的不同副本,因此通常需要獨(dú)立配置每個(gè)副本,以避免從工作進(jìn)程返回重復(fù)的數(shù)據(jù)。 get_worker_info() 在工作程序進(jìn)程中調(diào)用時(shí),返回有關(guān)工作程序的信息。 可以在數(shù)據(jù)集的__iter__()方法或 DataLoaderworker_init_fn選項(xiàng)中使用它來修改每個(gè)副本的行為。

示例 1:在__iter__()中將工作負(fù)載分配給所有工作人員:

>>> class MyIterableDataset(torch.utils.data.IterableDataset):
...     def __init__(self, start, end):
...         super(MyIterableDataset).__init__()
...         assert end > start, "this example code only works with end >= start"
...         self.start = start
...         self.end = end
...
...     def __iter__(self):
...         worker_info = torch.utils.data.get_worker_info()
...         if worker_info is None:  # single-process data loading, return the full iterator
...             iter_start = self.start
...             iter_end = self.end
...         else:  # in a worker process
...             # split workload
...             per_worker = int(math.ceil((self.end - self.start) / float(worker_info.num_workers)))
...             worker_id = worker_info.id
...             iter_start = self.start + worker_id * per_worker
...             iter_end = min(iter_start + per_worker, self.end)
...         return iter(range(iter_start, iter_end))
...
>>> # should give same set of data as range(3, 7), i.e., [3, 4, 5, 6].
>>> ds = MyIterableDataset(start=3, end=7)


>>> # Single-process loading
>>> print(list(torch.utils.data.DataLoader(ds, num_workers=0)))
[3, 4, 5, 6]


>>> # Mult-process loading with two worker processes
>>> # Worker 0 fetched [3, 4].  Worker 1 fetched [5, 6].
>>> print(list(torch.utils.data.DataLoader(ds, num_workers=2)))
[3, 5, 4, 6]


>>> # With even more workers
>>> print(list(torch.utils.data.DataLoader(ds, num_workers=20)))
[3, 4, 5, 6]

示例 2:使用worker_init_fn在所有工作人員之間分配工作量:

>>> class MyIterableDataset(torch.utils.data.IterableDataset):
...     def __init__(self, start, end):
...         super(MyIterableDataset).__init__()
...         assert end > start, "this example code only works with end >= start"
...         self.start = start
...         self.end = end
...
...     def __iter__(self):
...         return iter(range(self.start, self.end))
...
>>> # should give same set of data as range(3, 7), i.e., [3, 4, 5, 6].
>>> ds = MyIterableDataset(start=3, end=7)


>>> # Single-process loading
>>> print(list(torch.utils.data.DataLoader(ds, num_workers=0)))
[3, 4, 5, 6]
>>>
>>> # Directly doing multi-process loading yields duplicate data
>>> print(list(torch.utils.data.DataLoader(ds, num_workers=2)))
[3, 3, 4, 4, 5, 5, 6, 6]


>>> # Define a `worker_init_fn` that configures each dataset  differently
>>> def worker_init_fn(worker_id):
...     worker_info = torch.utils.data.get_worker_info()
...     dataset = worker_info.dataset  # the dataset copy in this worker process
...     overall_start = dataset.start
...     overall_end = dataset.end
...     # configure the dataset to only process the split workload
...     per_worker = int(math.ceil((overall_end - overall_start) / float(worker_info.num_workers)))
...     worker_id = worker_info.id
...     dataset.start = overall_start + worker_id * per_worker
...     dataset.end = min(dataset.start + per_worker, overall_end)
...


>>> # Mult-process loading with the custom `worker_init_fn`
>>> # Worker 0 fetched [3, 4].  Worker 1 fetched [5, 6].
>>> print(list(torch.utils.data.DataLoader(ds, num_workers=2, worker_init_fn=worker_init_fn)))
[3, 5, 4, 6]


>>> # With even more workers
>>> print(list(torch.utils.data.DataLoader(ds, num_workers=20, worker_init_fn=worker_init_fn)))
[3, 4, 5, 6]

class torch.utils.data.TensorDataset(*tensors)?

數(shù)據(jù)集包裝張量。

每個(gè)樣本將通過沿第一維索引張量來檢索。

Parameters

張量 (tensor*)–具有與第一維相同大小的張量。

class torch.utils.data.ConcatDataset(datasets)?

數(shù)據(jù)集是多個(gè)數(shù)據(jù)集的串聯(lián)。

此類對(duì)于組裝不同的現(xiàn)有數(shù)據(jù)集很有用。

Parameters

數(shù)據(jù)集(序列)–要連接的數(shù)據(jù)集列表

class torch.utils.data.ChainDataset(datasets)?

用于鏈接多個(gè) IterableDataset 的數(shù)據(jù)集。

此類對(duì)于組裝不同的現(xiàn)有數(shù)據(jù)集流很有用。 鏈接操作是即時(shí)完成的,因此將大型數(shù)據(jù)集與此類連接起來將非常有效。

Parameters

數(shù)據(jù)集(IterableDataset 的可迭代)–鏈接在一起的數(shù)據(jù)集

class torch.utils.data.Subset(dataset, indices)?

指定索引處的數(shù)據(jù)集子集。

Parameters

  • 數(shù)據(jù)集 (數(shù)據(jù)集)–整個(gè)數(shù)據(jù)集
  • 索引(序列)–為子集選擇的整個(gè)集合中的索引

torch.utils.data.get_worker_info()?

返回有關(guān)當(dāng)前 DataLoader 迭代器工作進(jìn)程的信息。

在工作線程中調(diào)用時(shí),此方法返回一個(gè)保證具有以下屬性的對(duì)象:

  • id:當(dāng)前工作人員 ID。
  • num_workers:工人總數(shù)。
  • seed:當(dāng)前工作程序的隨機(jī)種子集。 該值由主進(jìn)程 RNG 和工作程序 ID 確定。 有關(guān)更多詳細(xì)信息,請(qǐng)參見 DataLoader 的文檔。
  • dataset:此流程在中的數(shù)據(jù)集對(duì)象的副本。 請(qǐng)注意,在不同的過程中,這將是與主過程中的對(duì)象不同的對(duì)象。

在主進(jìn)程中調(diào)用時(shí),將返回None

Note

在傳遞給 DataLoaderworker_init_fn中使用時(shí),此方法可用于不同地設(shè)置每個(gè)工作進(jìn)程,例如,使用worker_iddataset對(duì)象配置為僅讀取 分片數(shù)據(jù)集的特定部分,或使用seed播種數(shù)據(jù)集代碼中使用的其他庫(kù)(例如 NumPy)。

torch.utils.data.random_split(dataset, lengths)?

將數(shù)據(jù)集隨機(jī)拆分為給定長(zhǎng)度的不重疊的新數(shù)據(jù)集。

Parameters

  • 數(shù)據(jù)集 (數(shù)據(jù)集)–要拆分的數(shù)據(jù)集
  • 長(zhǎng)度(序列)–要產(chǎn)生的分割的長(zhǎng)度

class torch.utils.data.Sampler(data_source)?

所有采樣器的基類。

每個(gè) Sampler 子類都必須提供__iter__()方法(提供一種對(duì)數(shù)據(jù)集元素的索引進(jìn)行迭代的方法)和__len__()方法,該方法返回返回的迭代器的長(zhǎng)度。

Note

DataLoader 并非嚴(yán)格要求__len__()方法,但在涉及 DataLoader 長(zhǎng)度的任何計(jì)算中都應(yīng)采用。

class torch.utils.data.SequentialSampler(data_source)?

始終以相同順序順序采樣元素。

Parameters

data_source (數(shù)據(jù)集)–要從中采樣的數(shù)據(jù)集

class torch.utils.data.RandomSampler(data_source, replacement=False, num_samples=None)?

隨機(jī)采樣元素。 如果不進(jìn)行替換,則從經(jīng)過改組的數(shù)據(jù)集中采樣。 如果要更換,則用戶可以指定num_samples進(jìn)行繪制。

Parameters

  • data_source (Dataset) – dataset to sample from
  • 替換 (bool )–如果True為默認(rèn)值,則替換為True
  • num_samples (python:int )–要繪制的樣本數(shù),默認(rèn)為 len(dataset)。 僅當(dāng)<cite>替換</cite>為True時(shí)才應(yīng)指定此參數(shù)。

class torch.utils.data.SubsetRandomSampler(indices)?

從給定的索引列表中隨機(jī)抽樣元素,而無需替換。

Parameters

索引(序列)–索引序列

class torch.utils.data.WeightedRandomSampler(weights, num_samples, replacement=True)?

以給定的概率(權(quán)重)從[0,..,len(weights)-1]中采樣元素。

Parameters

  • 權(quán)重(序列)–權(quán)重序列,不必累加一個(gè)
  • num_samples (python:int )–要繪制的樣本數(shù)
  • 替代品 (bool )–如果True,則抽取替代品抽取樣品。 如果沒有,則它們將被替換而不會(huì)被繪制,這意味著當(dāng)為一行繪制樣本索引時(shí),無法為該行再次繪制它。

>>> list(WeightedRandomSampler([0.1, 0.9, 0.4, 0.7, 3.0, 0.6], 5, replacement=True))
[0, 0, 0, 1, 0]
>>> list(WeightedRandomSampler([0.9, 0.4, 0.05, 0.2, 0.3, 0.1], 5, replacement=False))
[0, 1, 4, 3, 2]

class torch.utils.data.BatchSampler(sampler, batch_size, drop_last)?

包裝另一個(gè)采樣器以產(chǎn)生一個(gè)小批量的索引。

Parameters

  • 采樣器 (采樣器)–基本采樣器。
  • batch_size (python:int )–迷你批量的大小。
  • drop_last (bool )–如果為True,則采樣器將丟棄最后一批,如果其大小小于batch_size

Example

>>> list(BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=False))
[[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]
>>> list(BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=True))
[[0, 1, 2], [3, 4, 5], [6, 7, 8]]

class torch.utils.data.distributed.DistributedSampler(dataset, num_replicas=None, rank=None, shuffle=True)?

將數(shù)據(jù)加載限制為數(shù)據(jù)集子集的采樣器。

torch.nn.parallel.DistributedDataParallel 結(jié)合使用時(shí)特別有用。 在這種情況下,每個(gè)進(jìn)程都可以將 DistributedSampler 實(shí)例作為 DataLoader 采樣器傳遞,并加載原始數(shù)據(jù)集的專有子集。

Note

假定數(shù)據(jù)集大小恒定。

Parameters

  • 數(shù)據(jù)集 –用于采樣的數(shù)據(jù)集。
  • num_replicas (可選)–參與分布式訓(xùn)練的進(jìn)程數(shù)。
  • 等級(jí)(可選)–當(dāng)前進(jìn)程在 num_replicas 中的等級(jí)。
  • 隨機(jī)播放(可選)–如果為 true(默認(rèn)值),采樣器將隨機(jī)播放索引
以上內(nèi)容是否對(duì)您有幫助:
在線筆記
App下載
App下載

掃描二維碼

下載編程獅App

公眾號(hào)
微信公眾號(hào)

編程獅公眾號(hào)