TensorFlow函數(shù)教程:tf.nn.ctc_loss

2019-01-31 13:45 更新

tf.nn.ctc_loss函數(shù)

tf.nn.ctc_loss(
    labels,
    inputs,
    sequence_length,
    preprocess_collapse_repeated=False,
    ctc_merge_repeated=True,
    ignore_longer_outputs_than_inputs=False,
    time_major=True
)

定義在:tensorflow/python/ops/ctc_ops.py.

參見指南:神經網絡>連接時間分類(CTC)

計算CTC(連接時間分類)loss.

輸入要求:

sequence_length(b) <= time for all b

max(labels.indices(labels.indices[:, 1] == b, 2))
  <= sequence_length(b) for all b.

筆記:

此類為您執(zhí)行softmax操作,因此輸入應該是例如LSTM對輸出的線性預測.

inputs張量的最內層的維度大小,num_classes,代表num_labels + 1類別,其中num_labels是實際的標簽的數(shù)量,而最大的值(num_classes - 1)是為空白標簽保留的.

例如,對于包含3個標簽[a, b, c]的詞匯表,num_classes = 4,并且標簽索引是{a: 0, b: 1, c: 2, blank: 3}.

關于參數(shù)preprocess_collapse_repeatedctc_merge_repeated

如果preprocess_collapse_repeated為True,則在loss計算之前運行預處理步驟,其中傳遞給loss的重復標簽會合并為單個標簽.如果訓練標簽來自,例如強制對齊,并因此具有不必要的重復,則這是有用的.

如果ctc_merge_repeated設置為False,則在CTC計算的深處,重復的非空白標簽將不會合并,并被解釋為單個標簽.這是CTC的簡化(非標準)版本.

以下是(大致)預期的第一順序行為表:

  • preprocess_collapse_repeated=Falsectc_merge_repeated=True

典型的CTC行為:輸出實際的重復類,其間有空白,還可以輸出中間沒有空白的重復類,這需要由解碼器折疊.

  • preprocess_collapse_repeated=Truectc_merge_repeated=False

不要得知輸出重復的類,因為它們在訓練之前在輸入標簽中折疊.

  • preprocess_collapse_repeated=Falsectc_merge_repeated=False

輸出中間有空白的重復類,但通常不需要解碼器折疊/合并重復的類.

  • preprocess_collapse_repeated=Truectc_merge_repeated=True

未經測試,很可能不會得知輸出重復的類.

ignore_longer_outputs_than_inputs選項允許在處理輸出長于輸入的序列時指定CTCLoss的行為.如果為true,則CTCLoss將僅為這些項返回零梯度,否則返回InvalidArgument錯誤,停止訓練.

參數(shù):

  • labels:一個int32SparseTensor;labels.indices[i, :] == [b, t]表示labels.values[i]存儲(batch b, time t)的id;labels.values[i]必須采用[0, num_labels)中的值.
  • inputs:3-D float Tensor;如果time_major == False,這將是一個Tensor,形狀:[batch_size, max_time, num_classes]如果time_major == True(默認值),這將是一個Tensor,形狀:[max_time, batch_size, num_classes];是logits.
  • sequence_length:1-Dint32向量,大小為[batch_size];序列長度.
  • preprocess_collapse_repeatedBoolean,默認值:False;如果為True,則在CTC計算之前折疊重復的標簽.
  • ctc_merge_repeatedBoolean,默認值:True.
  • ignore_longer_outputs_than_inputs:Boolean,默認值:False;如果為True,則輸出比輸入長的序列將被忽略.
  • time_majorinputs張量的形狀格式;如果是True,那些Tensors必須具有形狀[max_time, batch_size, num_classes];如果為False,則Tensors必須具有形狀[batch_size, max_time, num_classes];使用time_major = True(默認)更有效,因為它避免了在ctc_loss計算開始時的轉置.但是,大多數(shù)TensorFlow數(shù)據(jù)都是批處理為主的,因此通過此函數(shù)還可以接受以批處理為主的形式的輸入.

返回:

1-DfloatTensor,大小為[batch]包含負對數(shù)概率.

可能引發(fā)的異常:

  • TypeError:如果標簽不是SparseTensor.
以上內容是否對您有幫助:
在線筆記
App下載
App下載

掃描二維碼

下載編程獅App

公眾號
微信公眾號

編程獅公眾號