App下載

pytorch中怎么抽取一個tensor的行?x[...,0]使用介紹

猿友 2021-07-26 09:32:50 瀏覽數(shù) (5218)
反饋

很多小伙伴在學(xué)習(xí)pytorch的時候會遇到x[...,0]這樣的寫法,但并不了解這樣的寫法有什么用,小編經(jīng)過實驗得出了這個寫法的功能為pytorch抽取tensor的行。接下來就來看看x[...,0]怎么使用吧。

實驗代碼如下:

b=torch.Tensor([[[[10,2],[4,5],[7,8]],[[1,2],[4,5],[7,8]]]])
print(b.size())
(1, 2, 3, 2)
print(b[…,0])
tensor([[[10., 4., 7.],
[ 1., 4., 7.]]])
print(b[…,0].size())
(1, 2, 3)
print(b[…,2])
Traceback (most recent call last):
File “”, line 1, in
IndexError: index 2 is out of bounds for dimension 3 with size 2
print(b[0,…])
tensor([[[10., 2.],
[ 4., 5.],
[ 7., 8.]],
[[ 1., 2.],
[ 4., 5.],
[ 7., 8.]]])
print(b[0,…].size())
(2, 3, 2)
print(b[0,…,0].size())
(2, 3)
print(b[0,…,0])
tensor([[10., 4., 7.],
[ 1., 4., 7.]])

[…, 0]表示抽取tensor b的第4根軸上的第一列數(shù)字組成tensor,[0, …]表示抽取tensor b的第一根軸上的第一列數(shù)字組成tensor,[0, …, 0]表示抽取b的第一根和第四根軸上的第一列數(shù)字組成tensor。

還發(fā)現(xiàn)一個現(xiàn)象

print(b[…,0:])
tensor([[[[10., 2.],
[ 4., 5.],
[ 7., 8.]],
[[ 1., 2.],
[ 4., 5.],
[ 7., 8.]]]])
print(b[…,1:])
tensor([[[[2.],
[5.],
[8.]],
[[2.],
[5.],
[8.]]]])
print(b[…,2:])
tensor([], size=(1, 2, 3, 0))

補(bǔ)充:PyTorch中[..., 0]的使用案例

1. 示例1

import torch
a = torch.rand((17, 24, 8))
b = a[..., 0]
c = a[:, :, 0]
print(b == c)

結(jié)果為True

2. 示例2

import torch
a = torch.rand((64, 17, 24, 8))
b = a[..., 0]
c = a[:, :, :, 0]
print(b == c)

結(jié)果為True

3. 結(jié)論

可以看出[…, 0]相當(dāng)于[:, :, … :, 0]

以上就是pytorch抽取tensor的行的全部內(nèi)容,希望能給大家一個參考,也希望大家多多支持W3Cschool。



0 人點(diǎn)贊