App下載

pytorch中的kl散度計(jì)算問(wèn)題

猿友 2021-07-17 11:53:50 瀏覽數(shù) (5104)
反饋

在我們使用pytorch的時(shí)候會(huì)涉及到相對(duì)熵的使用,而有用過(guò)TensorFlow的小伙伴可能會(huì)發(fā)覺(jué)到pytorch的相對(duì)熵在使用上好像有一點(diǎn)奇怪,今天小編就來(lái)講講pytorch怎么計(jì)算相對(duì)熵,以及他為什么會(huì)有這些使用問(wèn)題吧!

起因

小編偶然從pytorch討論論壇中看到的一個(gè)問(wèn)題,kl divergence 在TensorFlow中和pytorch中計(jì)算結(jié)果不同,平時(shí)沒(méi)有注意到,記錄下。

kl divergence 介紹

KL散度( Kullback–Leibler divergence),又稱相對(duì)熵,是描述兩個(gè)概率分布 P 和 Q 差異的一種方法。計(jì)算公式:

相對(duì)熵計(jì)算公式

可以發(fā)現(xiàn),P 和 Q 中元素的個(gè)數(shù)不用相等,只需要兩個(gè)分布中的離散元素一致。

舉個(gè)簡(jiǎn)單例子:

兩個(gè)離散分布分布分別為 P 和 Q

P 的分布為:{1,1,2,2,3}

Q 的分布為:{1,1,1,1,1,2,3,3,3,3}

我們發(fā)現(xiàn),雖然兩個(gè)分布中元素個(gè)數(shù)不相同,P 的元素個(gè)數(shù)為 5,Q 的元素個(gè)數(shù)為 10。但里面的元素都有 “1”,“2”,“3” 這三個(gè)元素。

當(dāng) x = 1時(shí),在 P 分布中,“1” 這個(gè)元素的個(gè)數(shù)為 2,故 P(x = 1) = 2/5 = 0.4,在 Q 分布中,“1” 這個(gè)元素的個(gè)數(shù)為 5,故 Q(x = 1) = 5/10 = 0.5

同理,

當(dāng) x = 2 時(shí),P(x = 2) = 2/5 = 0.4 ,Q(x = 2) = 1/10 = 0.1

當(dāng) x = 3 時(shí),P(x = 3) = 1/5 = 0.2 ,Q(x = 3) = 4/10 = 0.4

把上述概率帶入公式:

公式計(jì)算

至此,就計(jì)算完成了兩個(gè)離散變量分布的KL散度。

pytorch 中的 kl_div 函數(shù)

pytorch中有用于計(jì)算kl散度的函數(shù) kl_div

torch.nn.functional.kl_div(input, target, size_average=None, reduce=None, reduction='mean')

代碼運(yùn)算

計(jì)算 D (p||q)

1、不用這個(gè)函數(shù)的計(jì)算結(jié)果為:

運(yùn)算結(jié)果

與手算結(jié)果相同

2、使用函數(shù):

(這是計(jì)算正確的,結(jié)果有差異是因?yàn)閜ytorch這個(gè)函數(shù)中默認(rèn)的是以e為底)

另外一種使用結(jié)果

注意:

1、函數(shù)中的 p q 位置相反(也就是想要計(jì)算D(p||q),要寫成kl_div(q.log(),p)的形式),而且q要先取 log

2、reduction 是選擇對(duì)各部分結(jié)果做什么操作,默認(rèn)為取平均數(shù),這里選擇求和

好別扭的用法,不知道為啥官方把它設(shè)計(jì)成這樣

補(bǔ)充:pytorch 的KL divergence的實(shí)現(xiàn)

看代碼吧~

import torch.nn.functional as F
# p_logit: [batch, class_num]
# q_logit: [batch, class_num]
def kl_categorical(p_logit, q_logit):
    p = F.softmax(p_logit, dim=-1)
    _kl = torch.sum(p * (F.log_softmax(p_logit, dim=-1)
                                  - F.log_softmax(q_logit, dim=-1)), 1)
    return torch.mean(_kl)

以上就是pytorch怎么計(jì)算相對(duì)熵的全部?jī)?nèi)容了,希望能給大家一個(gè)參考,也希望大家多多支持W3Cschool。



0 人點(diǎn)贊