PyTorch 數(shù)據(jù)加載實(shí)用程序的核心是 torch.utils.data.DataLoader
類。 它表示可在數(shù)據(jù)集上迭代的 Python,并支持
這些選項(xiàng)由 DataLoader
的構(gòu)造函數(shù)參數(shù)配置,該參數(shù)具有簽名:
DataLoader(dataset, batch_size=1, shuffle=False, sampler=None,
batch_sampler=None, num_workers=0, collate_fn=None,
pin_memory=False, drop_last=False, timeout=0,
worker_init_fn=None)
以下各節(jié)詳細(xì)介紹了這些選項(xiàng)的效果和用法。
DataLoader
構(gòu)造函數(shù)的最重要參數(shù)是dataset
,它指示要從中加載數(shù)據(jù)的數(shù)據(jù)集對(duì)象。 PyTorch 支持兩種不同類型的數(shù)據(jù)集:
映射樣式數(shù)據(jù)集是一種實(shí)現(xiàn)__getitem__()
和__len__()
協(xié)議的數(shù)據(jù)集,它表示從(可能是非整數(shù))索引/關(guān)鍵字到數(shù)據(jù)樣本的映射。
例如,當(dāng)使用dataset[idx]
訪問時(shí),此類數(shù)據(jù)集可以從磁盤上的文件夾中讀取第idx
張圖像及其對(duì)應(yīng)的標(biāo)簽。
可迭代樣式的數(shù)據(jù)集是 IterableDataset
子類的實(shí)例,該子類實(shí)現(xiàn)了__iter__()
協(xié)議,并表示數(shù)據(jù)樣本上的可迭代。 這種類型的數(shù)據(jù)集特別適用于隨機(jī)讀取價(jià)格昂貴甚至不大可能,并且批處理大小取決于所獲取數(shù)據(jù)的情況。
例如,這種數(shù)據(jù)集稱為iter(dataset)
時(shí),可以返回從數(shù)據(jù)庫(kù),遠(yuǎn)程服務(wù)器甚至實(shí)時(shí)生成的日志中讀取的數(shù)據(jù)流。
注意
當(dāng)將 IterableDataset
與一起使用時(shí),多進(jìn)程數(shù)據(jù)加載。 在每個(gè)工作進(jìn)程上都復(fù)制相同的數(shù)據(jù)集對(duì)象,因此必須對(duì)副本進(jìn)行不同的配置,以避免重復(fù)的數(shù)據(jù)。 有關(guān)如何實(shí)現(xiàn)此功能的信息,請(qǐng)參見 IterableDataset
文檔。
Sampler
對(duì)于迭代式數(shù)據(jù)集,數(shù)據(jù)加載順序完全由用戶定義的迭代器控制。 這樣可以更輕松地實(shí)現(xiàn)塊讀取和動(dòng)態(tài)批次大小的實(shí)現(xiàn)(例如,通過每次生成一個(gè)批次的樣本)。
本節(jié)的其余部分涉及地圖樣式數(shù)據(jù)集的情況。 torch.utils.data.Sampler
類用于指定數(shù)據(jù)加載中使用的索引/鍵的順序。 它們代表數(shù)據(jù)集索引上的可迭代對(duì)象。 例如,在具有隨機(jī)梯度體面(SGD)的常見情況下, Sampler
可以隨機(jī)排列一列索引,一次生成每個(gè)索引,或者為小批量生成少量索引 新幣。
基于 DataLoader
的shuffle
參數(shù),將自動(dòng)構(gòu)建順序采樣或混洗的采樣器。 或者,用戶可以使用sampler
參數(shù)指定一個(gè)自定義 Sampler
對(duì)象,該對(duì)象每次都會(huì)產(chǎn)生要提取的下一個(gè)索引/關(guān)鍵字。
可以一次生成批量索引列表的自定義 Sampler
作為batch_sampler
參數(shù)傳遞。 也可以通過batch_size
和drop_last
參數(shù)啟用自動(dòng)批處理。 有關(guān)更多詳細(xì)信息,請(qǐng)參見下一部分的。
Note
sampler
和batch_sampler
都不與可迭代樣式的數(shù)據(jù)集兼容,因?yàn)榇祟悢?shù)據(jù)集沒有鍵或索引的概念。
DataLoader
支持通過參數(shù)batch_size
,drop_last
和batch_sampler
將各個(gè)提取的數(shù)據(jù)樣本自動(dòng)整理為批次。
這是最常見的情況,對(duì)應(yīng)于獲取一小批數(shù)據(jù)并將其整理為批處理的樣本,即包含張量,其中一維為批處理維度(通常是第一維)。
當(dāng)batch_size
(默認(rèn)1
)不是None
時(shí),數(shù)據(jù)加載器將生成批處理的樣本,而不是單個(gè)樣本。 batch_size
和drop_last
參數(shù)用于指定數(shù)據(jù)加載器如何獲取數(shù)據(jù)集密鑰的批處理。 對(duì)于地圖樣式的數(shù)據(jù)集,用戶可以選擇指定batch_sampler
,它一次生成一個(gè)鍵列表。
Note
batch_size
和drop_last
自變量本質(zhì)上用于從sampler
構(gòu)造batch_sampler
。 對(duì)于地圖樣式的數(shù)據(jù)集,sampler
由用戶提供或基于shuffle
參數(shù)構(gòu)造。 對(duì)于可迭代樣式的數(shù)據(jù)集,sampler
是一個(gè)虛擬的無限數(shù)據(jù)集。
Note
當(dāng)從可重復(fù)樣式數(shù)據(jù)集進(jìn)行多重處理提取時(shí),drop_last
參數(shù)會(huì)刪除每個(gè)工作人員數(shù)據(jù)集副本的最后一個(gè)非完整批次。
使用來自采樣器的索引獲取樣本列表后,作為collate_fn
參數(shù)傳遞的函數(shù)用于將樣本列表整理為批次。
在這種情況下,從地圖樣式數(shù)據(jù)集加載大致等效于:
for indices in batch_sampler:
yield collate_fn([dataset[i] for i in indices])
從可迭代樣式的數(shù)據(jù)集加載大致等效于:
dataset_iter = iter(dataset)
for indices in batch_sampler:
yield collate_fn([next(dataset_iter) for _ in indices])
自定義collate_fn
可用于自定義排序規(guī)則,例如,將順序數(shù)據(jù)填充到批處理的最大長(zhǎng)度。
在某些情況下,用戶可能希望以數(shù)據(jù)集代碼手動(dòng)處理批處理,或僅加載單個(gè)樣本。 例如,直接加載批處理的數(shù)據(jù)(例如,從數(shù)據(jù)庫(kù)中批量讀取或讀取連續(xù)的內(nèi)存塊)可能更便宜,或者批處理大小取決于數(shù)據(jù),或者該程序設(shè)計(jì)為可處理單個(gè)樣本。 在這種情況下,最好不要使用自動(dòng)批處理(其中collate_fn
用于整理樣本),而應(yīng)讓數(shù)據(jù)加載器直接返回dataset
對(duì)象的每個(gè)成員。
當(dāng)batch_size
和batch_sampler
均為None
時(shí)(batch_sampler
的默認(rèn)值已為None
),自動(dòng)批處理被禁用。 從dataset
獲得的每個(gè)樣本都將作為collate_fn
參數(shù)傳遞的函數(shù)進(jìn)行處理。
禁用自動(dòng)批處理時(shí),默認(rèn)值collate_fn
僅將 NumPy 數(shù)組轉(zhuǎn)換為 PyTorch 張量,而其他所有內(nèi)容均保持不變。
In this case, loading from a map-style dataset is roughly equivalent with:
for index in sampler:
yield collate_fn(dataset[index])
and loading from an iterable-style dataset is roughly equivalent with:
for data in iter(dataset):
yield collate_fn(data)
collate_fn
啟用或禁用自動(dòng)批處理時(shí),collate_fn
的使用略有不同。
禁用自動(dòng)批處理時(shí),將對(duì)每個(gè)單獨(dú)的數(shù)據(jù)樣本調(diào)用collate_fn
,并且從數(shù)據(jù)加載器迭代器產(chǎn)生輸出。 在這種情況下,默認(rèn)的collate_fn
僅轉(zhuǎn)換 PyTorch 張量中的 NumPy 數(shù)組。
啟用自動(dòng)批處理時(shí),會(huì)每次調(diào)用collate_fn
并帶有數(shù)據(jù)樣本列表。 期望將輸入樣本整理為一批,以便從數(shù)據(jù)加載器迭代器中獲得收益。 本節(jié)的其余部分描述了這種情況下默認(rèn)collate_fn
的行為。
例如,如果每個(gè)數(shù)據(jù)樣本都包含一個(gè) 3 通道圖像和一個(gè)整體類標(biāo)簽,即數(shù)據(jù)集的每個(gè)元素返回一個(gè)元組(image, class_index)
,則默認(rèn)值collate_fn
將此類元組的列表整理為一個(gè)元組 批處理圖像張量和批處理類標(biāo)簽 Tensor。 特別是,默認(rèn)collate_fn
具有以下屬性:
list
,tuple
,namedtuple
等相同。
用戶可以使用自定義的collate_fn
來實(shí)現(xiàn)自定義批處理,例如,沿除第一個(gè)維度之外的其他維度進(jìn)行校對(duì),各種長(zhǎng)度的填充序列或添加對(duì)自定義數(shù)據(jù)類型的支持。
默認(rèn)情況下, DataLoader
使用單進(jìn)程數(shù)據(jù)加載。
在 Python 進(jìn)程中,全局解釋器鎖定(GIL)阻止了跨線程真正的完全并行化 Python 代碼。 為了避免在加載數(shù)據(jù)時(shí)阻塞計(jì)算代碼,PyTorch 提供了一個(gè)簡(jiǎn)單的開關(guān),只需將參數(shù)num_workers
設(shè)置為正整數(shù)即可執(zhí)行多進(jìn)程數(shù)據(jù)加載。
在此模式下,以與初始化 DataLoader
相同的過程完成數(shù)據(jù)提取。 因此,數(shù)據(jù)加載可能會(huì)阻止計(jì)算。 然而,當(dāng)用于在進(jìn)程之間共享數(shù)據(jù)的資源(例如,共享存儲(chǔ)器,文件描述符)受到限制時(shí),或者當(dāng)整個(gè)數(shù)據(jù)集很小并且可以完全加載到存儲(chǔ)器中時(shí),該模式可能是優(yōu)選的。 此外,單進(jìn)程加載通常顯示更多可讀的錯(cuò)誤跟蹤,因此對(duì)于調(diào)試很有用。
將參數(shù)num_workers
設(shè)置為正整數(shù)將打開具有指定數(shù)量的加載程序工作進(jìn)程的多進(jìn)程數(shù)據(jù)加載。
在此模式下,每次創(chuàng)建 DataLoader
的迭代器時(shí)(例如,當(dāng)您調(diào)用enumerate(dataloader)
時(shí)),都會(huì)創(chuàng)建num_workers
工作進(jìn)程。 此時(shí),dataset
,collate_fn
和worker_init_fn
被傳遞給每個(gè)工作程序,在這里它們被用來初始化和獲取數(shù)據(jù)。 這意味著數(shù)據(jù)集訪問及其內(nèi)部 IO 轉(zhuǎn)換(包括collate_fn
)在工作進(jìn)程中運(yùn)行。
torch.utils.data.get_worker_info()
在工作進(jìn)程中返回各種有用的信息(包括工作 ID,數(shù)據(jù)集副本,初始種子等),并在主進(jìn)程中返回None
。 用戶可以在數(shù)據(jù)集代碼和/或worker_init_fn
中使用此功能來分別配置每個(gè)數(shù)據(jù)集副本,并確定代碼是否正在工作進(jìn)程中運(yùn)行。 例如,這在分片數(shù)據(jù)集時(shí)特別有用。
對(duì)于地圖樣式的數(shù)據(jù)集,主過程使用sampler
生成索引并將其發(fā)送給工作人員。 因此,任何隨機(jī)播放都是在主過程中完成的,該過程通過為索引分配索引來引導(dǎo)加載。
對(duì)于可迭代樣式的數(shù)據(jù)集,由于每個(gè)工作進(jìn)程都獲得dataset
對(duì)象的副本,因此幼稚的多進(jìn)程加載通常會(huì)導(dǎo)致數(shù)據(jù)重復(fù)。 用戶可以使用 torch.utils.data.get_worker_info()
和/或worker_init_fn
獨(dú)立配置每個(gè)副本。 (有關(guān)如何實(shí)現(xiàn)此操作的信息,請(qǐng)參見 IterableDataset
文檔。)出于類似的原因,在多進(jìn)程加載中,drop_last
參數(shù)刪除每個(gè)工作程序的可迭代樣式數(shù)據(jù)集副本的最后一個(gè)非完整批次。
一旦迭代結(jié)束或迭代器被垃圾回收,工作器將關(guān)閉。
警告
通常不建議在多進(jìn)程加載中返回 CUDA 張量,因?yàn)樵谑褂?CUDA 和在并行處理中共享 CUDA 張量時(shí)存在很多微妙之處(請(qǐng)參見在并行處理中的 CUDA)。 相反,我們建議使用自動(dòng)內(nèi)存固定(即,設(shè)置pin_memory=True
),該功能可以將數(shù)據(jù)快速傳輸?shù)街С?CUDA 的 GPU。
由于工作程序依賴于 Python multiprocessing
,因此與 Unix 相比,Windows 上的工作程序啟動(dòng)行為有所不同。
fork()
是默認(rèn)的multiprocessing
啟動(dòng)方法。 使用fork()
,童工通??梢灾苯油ㄟ^克隆的地址空間訪問dataset
和 Python 參數(shù)函數(shù)。spawn()
是默認(rèn)的multiprocessing
啟動(dòng)方法。 使用spawn()
啟動(dòng)另一個(gè)解釋器,該解釋器運(yùn)行您的主腳本,然后運(yùn)行內(nèi)部工作程序函數(shù),該函數(shù)通過序列化pickle
接收dataset
,collate_fn
和其他參數(shù)。這種獨(dú)立的序列化意味著您應(yīng)該采取兩個(gè)步驟來確保在使用多進(jìn)程數(shù)據(jù)加載時(shí)與 Windows 兼容:
if __name__ == '__main__':
塊中,以確保在啟動(dòng)每個(gè)工作進(jìn)程時(shí),該腳本不會(huì)再次運(yùn)行(很可能會(huì)產(chǎn)生錯(cuò)誤)。 您可以在此處放置數(shù)據(jù)集和 DataLoader
實(shí)例創(chuàng)建邏輯,因?yàn)樗恍枰?worker 中重新執(zhí)行。__main__
檢查之外將任何自定義collate_fn
,worker_init_fn
或dataset
代碼聲明為頂級(jí)定義。 這樣可以確保它們?cè)诠ぷ鬟M(jìn)程中可用。 (這是必需的,因?yàn)閷⒑瘮?shù)僅作為引用而不是bytecode
進(jìn)行腌制。)
默認(rèn)情況下,每個(gè)工作人員的 PyTorch 種子將設(shè)置為base_seed + worker_id
,其中base_seed
是主進(jìn)程使用其 RNG 生成的長(zhǎng)整數(shù)(因此,強(qiáng)制使用 RNG 狀態(tài))。 但是,初始化工作程序(例如 NumPy)時(shí),可能會(huì)復(fù)制其他庫(kù)的種子,導(dǎo)致每個(gè)工作程序返回相同的隨機(jī)數(shù)。
在worker_init_fn
中,您可以使用 torch.utils.data.get_worker_info().seed
或 torch.initial_seed()
訪問每個(gè)工作人員的 PyTorch 種子集,并在加載數(shù)據(jù)之前使用它為其他庫(kù)提供種子。
主機(jī)到 GPU 副本源自固定(頁面鎖定)內(nèi)存時(shí),速度要快得多。 有關(guān)通常何時(shí)以及如何使用固定內(nèi)存的更多詳細(xì)信息,請(qǐng)參見使用固定內(nèi)存緩沖區(qū)。
對(duì)于數(shù)據(jù)加載,將pin_memory=True
傳遞到 DataLoader
將自動(dòng)將獲取的數(shù)據(jù)張量放置在固定內(nèi)存中,從而更快地將數(shù)據(jù)傳輸?shù)絾⒂?CUDA 的 GPU。
默認(rèn)的內(nèi)存固定邏輯僅識(shí)別張量以及包含張量的映射和可迭代對(duì)象。 默認(rèn)情況下,如果固定邏輯看到一個(gè)自定義類型的批處理(如果您有一個(gè)collate_fn
返回自定義批處理類型,則會(huì)發(fā)生),或者如果該批處理的每個(gè)元素都是自定義類型,則固定邏輯將 無法識(shí)別它們,它將返回該批處理(或那些元素)而不固定內(nèi)存。 要為自定義批處理或數(shù)據(jù)類型啟用內(nèi)存固定,請(qǐng)?jiān)谧远x類型上定義pin_memory()
方法。
請(qǐng)參見下面的示例。
例:
class SimpleCustomBatch:
def __init__(self, data):
transposed_data = list(zip(*data))
self.inp = torch.stack(transposed_data[0], 0)
self.tgt = torch.stack(transposed_data[1], 0)
# custom memory pinning method on custom type
def pin_memory(self):
self.inp = self.inp.pin_memory()
self.tgt = self.tgt.pin_memory()
return self
def collate_wrapper(batch):
return SimpleCustomBatch(batch)
inps = torch.arange(10 * 5, dtype=torch.float32).view(10, 5)
tgts = torch.arange(10 * 5, dtype=torch.float32).view(10, 5)
dataset = TensorDataset(inps, tgts)
loader = DataLoader(dataset, batch_size=2, collate_fn=collate_wrapper,
pin_memory=True)
for batch_ndx, sample in enumerate(loader):
print(sample.inp.is_pinned())
print(sample.tgt.is_pinned())
class torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None, num_workers=0, collate_fn=None, pin_memory=False, drop_last=False, timeout=0, worker_init_fn=None, multiprocessing_context=None)?
數(shù)據(jù)加載器。 組合數(shù)據(jù)集和采樣器,并在給定的數(shù)據(jù)集上提供可迭代的。
DataLoader
支持地圖樣式和可迭代樣式的數(shù)據(jù)集,具有單進(jìn)程或多進(jìn)程加載,自定義加載順序以及可選的自動(dòng)批處理(歸類)和內(nèi)存固定。
有關(guān)更多詳細(xì)信息,請(qǐng)參見 torch.utils.data
文檔頁面。
參數(shù)
1
)。True
以使數(shù)據(jù)在每個(gè)時(shí)間段都重新隨機(jī)播放(默認(rèn)值:False
)。shuffle
必須為False
。sampler
,但在 時(shí)間。 與batch_size
,shuffle
,sampler
和drop_last
互斥。0
表示將在主進(jìn)程中加載數(shù)據(jù)。 (默認(rèn):0
)True
,則數(shù)據(jù)加載器將張量復(fù)制到 CUDA 固定的內(nèi)存中,然后返回。 如果您的數(shù)據(jù)元素是自定義類型,或者您的collate_fn
返回的是自定義類型的批次,請(qǐng)參見下面的示例。True
以刪除最后不完整的批次,如果數(shù)據(jù)集大小不可分割 按批次大小。 如果False
并且數(shù)據(jù)集的大小不能被批次大小整除,那么最后一批將較小。 (默認(rèn):False
)0
)None
,則將在每個(gè)具有工作人員 ID (在播種之后和數(shù)據(jù)加載之前,將[0, num_workers - 1]
中的 int 作為輸入。 (默認(rèn):None
)Warning
如果使用spawn
啟動(dòng)方法,則worker_init_fn
不能是不可拾取的對(duì)象,例如 lambda 函數(shù)。 有關(guān) PyTorch 中與并行處理有關(guān)的更多詳細(xì)信息,請(qǐng)參見并行處理最佳實(shí)踐。
Note
len(dataloader)
啟發(fā)式方法基于所用采樣器的長(zhǎng)度。 當(dāng)dataset
是 IterableDataset
時(shí),無論多進(jìn)程加載配置如何,都將返回len(dataset)
(如果實(shí)現(xiàn)),因?yàn)?PyTorch 信任用戶dataset
代碼可以正確處理多進(jìn)程加載 避免重復(fù)數(shù)據(jù)。 有關(guān)這兩種類型的數(shù)據(jù)集以及 IterableDataset
如何與多進(jìn)程數(shù)據(jù)加載交互的更多詳細(xì)信息,請(qǐng)參見數(shù)據(jù)集類型。
class torch.utils.data.Dataset?
表示 Dataset
的抽象類。
代表從鍵到數(shù)據(jù)樣本的映射的所有數(shù)據(jù)集都應(yīng)將其子類化。 所有子類都應(yīng)該覆蓋__getitem__()
,支持為給定鍵獲取數(shù)據(jù)樣本。 子類還可以選擇覆蓋__len__()
,它有望通過許多 Sampler
實(shí)現(xiàn)以及 DataLoader
的默認(rèn)選項(xiàng)返回?cái)?shù)據(jù)集的大小。
Note
默認(rèn)情況下, DataLoader
構(gòu)造一個(gè)索引采樣器,該采樣器產(chǎn)生整數(shù)索引。 要使其與具有非整數(shù)索引/鍵的地圖樣式數(shù)據(jù)集一起使用,必須提供自定義采樣器。
class torch.utils.data.IterableDataset?
可迭代的數(shù)據(jù)集。
代表可迭代數(shù)據(jù)樣本的所有數(shù)據(jù)集都應(yīng)將其子類化。 當(dāng)數(shù)據(jù)來自流時(shí),這種形式的數(shù)據(jù)集特別有用。
所有子類都應(yīng)覆蓋__iter__()
,這將返回此數(shù)據(jù)集中的樣本迭代器。
當(dāng)子類與 DataLoader
一起使用時(shí),數(shù)據(jù)集中的每個(gè)項(xiàng)目都將由 DataLoader
迭代器產(chǎn)生。 當(dāng)num_workers > 0
時(shí),每個(gè)工作進(jìn)程將具有數(shù)據(jù)集對(duì)象的不同副本,因此通常需要獨(dú)立配置每個(gè)副本,以避免從工作進(jìn)程返回重復(fù)的數(shù)據(jù)。 get_worker_info()
在工作程序進(jìn)程中調(diào)用時(shí),返回有關(guān)工作程序的信息。 可以在數(shù)據(jù)集的__iter__()
方法或 DataLoader
的worker_init_fn
選項(xiàng)中使用它來修改每個(gè)副本的行為。
示例 1:在__iter__()
中將工作負(fù)載分配給所有工作人員:
>>> class MyIterableDataset(torch.utils.data.IterableDataset):
... def __init__(self, start, end):
... super(MyIterableDataset).__init__()
... assert end > start, "this example code only works with end >= start"
... self.start = start
... self.end = end
...
... def __iter__(self):
... worker_info = torch.utils.data.get_worker_info()
... if worker_info is None: # single-process data loading, return the full iterator
... iter_start = self.start
... iter_end = self.end
... else: # in a worker process
... # split workload
... per_worker = int(math.ceil((self.end - self.start) / float(worker_info.num_workers)))
... worker_id = worker_info.id
... iter_start = self.start + worker_id * per_worker
... iter_end = min(iter_start + per_worker, self.end)
... return iter(range(iter_start, iter_end))
...
>>> # should give same set of data as range(3, 7), i.e., [3, 4, 5, 6].
>>> ds = MyIterableDataset(start=3, end=7)
>>> # Single-process loading
>>> print(list(torch.utils.data.DataLoader(ds, num_workers=0)))
[3, 4, 5, 6]
>>> # Mult-process loading with two worker processes
>>> # Worker 0 fetched [3, 4]. Worker 1 fetched [5, 6].
>>> print(list(torch.utils.data.DataLoader(ds, num_workers=2)))
[3, 5, 4, 6]
>>> # With even more workers
>>> print(list(torch.utils.data.DataLoader(ds, num_workers=20)))
[3, 4, 5, 6]
示例 2:使用worker_init_fn
在所有工作人員之間分配工作量:
>>> class MyIterableDataset(torch.utils.data.IterableDataset):
... def __init__(self, start, end):
... super(MyIterableDataset).__init__()
... assert end > start, "this example code only works with end >= start"
... self.start = start
... self.end = end
...
... def __iter__(self):
... return iter(range(self.start, self.end))
...
>>> # should give same set of data as range(3, 7), i.e., [3, 4, 5, 6].
>>> ds = MyIterableDataset(start=3, end=7)
>>> # Single-process loading
>>> print(list(torch.utils.data.DataLoader(ds, num_workers=0)))
[3, 4, 5, 6]
>>>
>>> # Directly doing multi-process loading yields duplicate data
>>> print(list(torch.utils.data.DataLoader(ds, num_workers=2)))
[3, 3, 4, 4, 5, 5, 6, 6]
>>> # Define a `worker_init_fn` that configures each dataset differently
>>> def worker_init_fn(worker_id):
... worker_info = torch.utils.data.get_worker_info()
... dataset = worker_info.dataset # the dataset copy in this worker process
... overall_start = dataset.start
... overall_end = dataset.end
... # configure the dataset to only process the split workload
... per_worker = int(math.ceil((overall_end - overall_start) / float(worker_info.num_workers)))
... worker_id = worker_info.id
... dataset.start = overall_start + worker_id * per_worker
... dataset.end = min(dataset.start + per_worker, overall_end)
...
>>> # Mult-process loading with the custom `worker_init_fn`
>>> # Worker 0 fetched [3, 4]. Worker 1 fetched [5, 6].
>>> print(list(torch.utils.data.DataLoader(ds, num_workers=2, worker_init_fn=worker_init_fn)))
[3, 5, 4, 6]
>>> # With even more workers
>>> print(list(torch.utils.data.DataLoader(ds, num_workers=20, worker_init_fn=worker_init_fn)))
[3, 4, 5, 6]
class torch.utils.data.TensorDataset(*tensors)?
數(shù)據(jù)集包裝張量。
每個(gè)樣本將通過沿第一維索引張量來檢索。
Parameters
張量 (tensor*)–具有與第一維相同大小的張量。
class torch.utils.data.ConcatDataset(datasets)?
數(shù)據(jù)集是多個(gè)數(shù)據(jù)集的串聯(lián)。
此類對(duì)于組裝不同的現(xiàn)有數(shù)據(jù)集很有用。
Parameters
數(shù)據(jù)集(序列)–要連接的數(shù)據(jù)集列表
class torch.utils.data.ChainDataset(datasets)?
用于鏈接多個(gè) IterableDataset
的數(shù)據(jù)集。
此類對(duì)于組裝不同的現(xiàn)有數(shù)據(jù)集流很有用。 鏈接操作是即時(shí)完成的,因此將大型數(shù)據(jù)集與此類連接起來將非常有效。
Parameters
數(shù)據(jù)集(IterableDataset 的可迭代)–鏈接在一起的數(shù)據(jù)集
class torch.utils.data.Subset(dataset, indices)?
指定索引處的數(shù)據(jù)集子集。
Parameters
torch.utils.data.get_worker_info()?
返回有關(guān)當(dāng)前 DataLoader
迭代器工作進(jìn)程的信息。
在工作線程中調(diào)用時(shí),此方法返回一個(gè)保證具有以下屬性的對(duì)象:
id
:當(dāng)前工作人員 ID。num_workers
:工人總數(shù)。seed
:當(dāng)前工作程序的隨機(jī)種子集。 該值由主進(jìn)程 RNG 和工作程序 ID 確定。 有關(guān)更多詳細(xì)信息,請(qǐng)參見 DataLoader
的文檔。dataset
:此流程在中的數(shù)據(jù)集對(duì)象的副本。 請(qǐng)注意,在不同的過程中,這將是與主過程中的對(duì)象不同的對(duì)象。
在主進(jìn)程中調(diào)用時(shí),將返回None
。
Note
在傳遞給 DataLoader
的worker_init_fn
中使用時(shí),此方法可用于不同地設(shè)置每個(gè)工作進(jìn)程,例如,使用worker_id
將dataset
對(duì)象配置為僅讀取 分片數(shù)據(jù)集的特定部分,或使用seed
播種數(shù)據(jù)集代碼中使用的其他庫(kù)(例如 NumPy)。
torch.utils.data.random_split(dataset, lengths)?
將數(shù)據(jù)集隨機(jī)拆分為給定長(zhǎng)度的不重疊的新數(shù)據(jù)集。
Parameters
class torch.utils.data.Sampler(data_source)?
所有采樣器的基類。
每個(gè) Sampler 子類都必須提供__iter__()
方法(提供一種對(duì)數(shù)據(jù)集元素的索引進(jìn)行迭代的方法)和__len__()
方法,該方法返回返回的迭代器的長(zhǎng)度。
Note
DataLoader
并非嚴(yán)格要求__len__()
方法,但在涉及 DataLoader
長(zhǎng)度的任何計(jì)算中都應(yīng)采用。
class torch.utils.data.SequentialSampler(data_source)?
始終以相同順序順序采樣元素。
Parameters
data_source (數(shù)據(jù)集)–要從中采樣的數(shù)據(jù)集
class torch.utils.data.RandomSampler(data_source, replacement=False, num_samples=None)?
隨機(jī)采樣元素。 如果不進(jìn)行替換,則從經(jīng)過改組的數(shù)據(jù)集中采樣。 如果要更換,則用戶可以指定num_samples
進(jìn)行繪制。
Parameters
True
為默認(rèn)值,則替換為True
True
時(shí)才應(yīng)指定此參數(shù)。class torch.utils.data.SubsetRandomSampler(indices)?
從給定的索引列表中隨機(jī)抽樣元素,而無需替換。
Parameters
索引(序列)–索引序列
class torch.utils.data.WeightedRandomSampler(weights, num_samples, replacement=True)?
以給定的概率(權(quán)重)從[0,..,len(weights)-1]
中采樣元素。
Parameters
True
,則抽取替代品抽取樣品。 如果沒有,則它們將被替換而不會(huì)被繪制,這意味著當(dāng)為一行繪制樣本索引時(shí),無法為該行再次繪制它。例
>>> list(WeightedRandomSampler([0.1, 0.9, 0.4, 0.7, 3.0, 0.6], 5, replacement=True))
[0, 0, 0, 1, 0]
>>> list(WeightedRandomSampler([0.9, 0.4, 0.05, 0.2, 0.3, 0.1], 5, replacement=False))
[0, 1, 4, 3, 2]
class torch.utils.data.BatchSampler(sampler, batch_size, drop_last)?
包裝另一個(gè)采樣器以產(chǎn)生一個(gè)小批量的索引。
Parameters
True
,則采樣器將丟棄最后一批,如果其大小小于batch_size
Example
>>> list(BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=False))
[[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]
>>> list(BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=True))
[[0, 1, 2], [3, 4, 5], [6, 7, 8]]
class torch.utils.data.distributed.DistributedSampler(dataset, num_replicas=None, rank=None, shuffle=True)?
將數(shù)據(jù)加載限制為數(shù)據(jù)集子集的采樣器。
與 torch.nn.parallel.DistributedDataParallel
結(jié)合使用時(shí)特別有用。 在這種情況下,每個(gè)進(jìn)程都可以將 DistributedSampler 實(shí)例作為 DataLoader 采樣器傳遞,并加載原始數(shù)據(jù)集的專有子集。
Note
假定數(shù)據(jù)集大小恒定。
Parameters
更多建議: