借kaldi来理解softmax

nnet-forward --feature-transform=final.feature_transform --no-softmax=false final.nnet scp:1.scp ark,t:output.ark.txt

final.feature_transform, final.nnet均由训练而来。final.feature_transform是对输入特征做转换,作为final.nnet的输入。

final.nnet是最后一层是softmax层,可以由–no-softmax这个参数来指定。

1.scp是输入特征文件的索引,为简单起见,只用一个特征文件。

ark,t:output.ark.txt将输出转成文本文件并输出到output.ark.txt中。这是一个m x n的矩阵,m是输入特征的帧数,n是final.nnet的输出状态数(这里是3000)。

选其中一帧的输出来理解softmax

sed -n '10p' output.ark.txt > 1.txt

(1)–no-softmax=false,做softmax

nnet-forward --feature-transform=final.feature_transform --no-softmax=false final.nnet scp:1.scp ark,t:output.ark.txt
sed -n '10p' output.ark.txt > 1.txt 
awk '{print NF}' 1.txt

显示是应该就是用nnet-info final.nnet看到的最后一层的值:3000

awk 'BEGIN{total=0.0}{for(i=1;i<=NF;i++)total += $i}END{print total}' 1.txt

这时应该看到就是1,说明softmax是起作用的。再看1.txt里应该是3000个特别小的数。

(2). –no-softmax=true,不做softmax

nnet-forward --feature-transform=final.feature_transform --no-softmax=true final.nnet scp:1.scp ark,t:output.ark.txt
sed -n '10p' output.ark.txt > 2.txt

再看2.txt中的内容,应该是非常大的不规整的值。大于1数有很多。

(3). 运行softmax算法将2.txt转成1.txt

算法描述:

    1. 找出3000个数中的最大值qMax。

    2. 将每一值转成Xi=exp(Xi-qMax), 并累加到qSum

    3. 将每一值转成Xi/qSum

用python实现:

import sys
import math

qList = open(sys.argv[1],'r').readline().split()
print len(qList)

qMax = 0.0
for i in range(len(qList)):
    cur = float(qList[i])
    if cur>qMax:
        qMax = cur

print qMax
qSum = 0.0
for i in range(len(qList)):
    cur = float(qList[i])
    qList[i] = math.exp(cur-qMax)
    qSum += qList[i]

print qSum
fout = open(sys.argv[2],'w')
for i in range(len(qList)):
    qList[i] = qList[i]/qSum
    fout.write('%g\t' % qList[i])

fout.write('\n')
fout.close()

以上只是为了描述算法,效率不做考虑。

python softmax.py 2.txt 2_sm.txt

(4)比较1.txt, 2_sm.txt的值,日常接近,小有差异可能是python的数据转换带来的,对于算法的正确性没有影响。


kaldi中的代码实现:

./src/matrix/kaldi-matrix.cc

template<typename Real>
Real MatrixBase<Real>::ApplySoftMax() {
  Real max = this->Max(), sum = 0.0;
  // the 'max' helps to get in good numeric range.
  for (MatrixIndexT i = 0; i < num_rows_; i++)
    for (MatrixIndexT j = 0; j < num_cols_; j++)
      sum += ((*this)(i, j) = Exp((*this)(i, j) - max));
  this->Scale(1.0 / sum);
  return max + Log(sum);
}

cannot find -lgcc_s

问题:

/usr/bin/ld: skipping incompatible /usr/lib/gcc/x86_64-redhat-linux/4.4.7/libgcc_s.so when searching for -lgcc_s
/usr/bin/ld: skipping incompatible /usr/lib/gcc/x86_64-redhat-linux/4.4.7/libgcc_s.so when searching for -lgcc_s
/usr/bin/ld: cannot find -lgcc_s

解决办法:

yum install -y libgcc.i686