Torch笔记之如何读取txt并在字母层面上将内容按照字典替换成相应的数字

版权声明:本文为博主原创文章,如未特别声明,均默认使用CC BY-SA 3.0许可。 https://blog.csdn.net/Geek_of_CSDN/article/details/82081743

这个转换的问题其实是贫僧在尝试将某个.txt文件转换成Tensor来喂给训练好的神经网络模型时遇到的(训练的神经网络是char level的,具体看贫僧之前的博文)时遇到的。实现的步骤分成以下几个部分:

  1. 读取txt文件内容
  2. 将txt内容按照字典转化成对应的数字
  3. 将文件保存为.t7格式,方便神经网络读取

在正式开始之前先说下字典,字典是用这种方式生成的:

alphabet = "abcdefghijklmnopqrstuvwxyz0123456789-,;.!?:'\"/\\|_@#$%^&*~`+-=<>()[]{} "
dict = {}
for i = 1,#alphabet do
    dict[alphabet:sub(i,i)] = i
end
alphabet_size = #alphabet

生成后的字典里面每一个字符都对应这一个数字。

读取txt文件内容

这部分的实现很简单,因为Lua已经实现了这个功能,直接调用就行了:

--[[
    函数名:read_files
    输入:txt文件路径
    输出:txt文件内容,string类型
]]
function read_files(filename)
    local f = assert(io.open(filename, 'r'))
    local content = f:read('*all')
    f:close()
    return content
end

上面这个函数最后会返回txt文件的内容,因为比较简单所以就不细讲了。

将txt内容按照字典转化成对应的数字

这个地方有两个难点:

  1. Lua不像Python可以直接用for来遍历字符串,所以要用另一种方法来遍历字符串
  2. 按照字典转换

先上代码:

--[[
    函数名:translate_char
    输入:文件名,是否保存(true或者false)
    输出:按照字典翻译之后的tensor
    备注:其实可以改一改就可以变成word型的,默认是只接受一个文件,然后重复60次内容,如果需要接收60个不同文件的内容要另外改
]]
function translate_char(filename, save_translation)
    local m = torch.Tensor(60, 201, 10):zero()
    local content = read_files(filename)
    local tmp_i = 0
    local j = 1
    for i = 1, #content do
        if content:sub(i, i) == '\n' then  -- 分行
            j = j + 1
            tmp_i = i
        else
            if j > 10 then  -- 为了保证后面输入进矩阵的时候不会越界,而且也有助于看哪个文档里面的caption多于10个
                print('error! j is %d, file name is %s', j, filename)
                break
            end
            if dict[content:sub(i, i)] ~= nil then  -- 避免文档中出现了字典中没有的特殊字符
                if i - tmp_i < 201 then
                    for k = 1, 60 do
                        m[{k, i - tmp_i, j}] = dict[content:sub(i, i)]
                    end
                    -- print(j)  -- 这是用来调试的代码,用来检查行数是不是正确的
                end
            end
        end
    end
    if save_translation then
        torch.save(string.gsub(filename, '.txt', '.t7'), m)
    end
    return m
end

上面多了一些次要的东西,例如local m = torch.Tensor(60, 201, 10):zero(),这行是设置零矩阵,这是根据论文的要求来做的,因为默认内容长度不够的地方会用0来补充,而内容长度超过201的时候才会忽略掉后面的内容(所以有这句if i - tmp_i < 201 then)。核心的替换代码其实是m[{k, i - tmp_i, j}] = dict[content:sub(i, i)]。而遍历字符串用到的语句就是content:sub(i, i),要搭配for循环来用。

先细讲一下遍历字符串,sub(j, k)这个函数其实就是截取第jk位的字符(注意,Lua里面字符从1开始计数),所以如果是sub(i, i)的话就会提取第i位的字符,因此Lua里面遍历字符串要这样做:

for i = 1, #content do
    character = content:sub(i, i)
    -- 补充对单个字母的操作
end

而查字典的话其实就是将字母输入到字典的索引里面(其实Lua就只有table一个类型,但是贫僧更加熟悉python,所以用了“字典”这个词)。

m[{k, i - tmp_i, j}] = dict[content:sub(i, i)]

结合上面的遍历字符串操作就可以很容易理解这里是在做什么了。

猜你喜欢

转载自blog.csdn.net/Geek_of_CSDN/article/details/82081743