tensorlayer学习日志15_chapter5_5.4

标签: tensorlayer  tensorflow  nlp

本来还想开开心心地结束第五章,谁知5.4节遇上大坑了。。

import tensorflow as tf
import tensorlayer as tl
from stringclean import *
import numpy as np


vocabulary_size = 50000
embedding_size = 128

model_file_name = "model_word2vec_50k_128"
batch_size = None
_UNK = "_UNK"

sess = tf.InteractiveSession()

all_var = tl.files.load_npy_to_any(name=model_file_name + '.npy')
data = all_var['data']
count = all_var['count']
dictionary = all_var['dictionary']
reverse_dictionary = all_var['reverse_dictionary']

print("~~~~~~~Loading npy successfully~~~~~~~~~~~~")

tl.nlp.save_vocab(count, name='vocab_' + model_file_name + '.txt') 
del all_var, data, count 

# load_params= tl.files.load_npz(name=model_file_name + '.npz')
load_params= tl.files.load_npz(name='53model.npz')

print("~~~~~~~Loading npz successfully~~~~~~~~~~~~")

x = tf.placeholder(tf.int32, shape=[batch_size])


emb_net = tl.layers.EmbeddingInputlayer(inputs=x, vocabulary_size=vocabulary_size, embedding_size=embedding_size, name='embedding_layer')

tl.files.assign_params(sess, load_params, emb_net)

tl.layers.initialize_global_variables(sess)

emb_net.print_params()
emb_net.print_layers()


print('~~~~~~~~~~单词~~~~~~~~~~~~')

word = 'hello'
word_id = dictionary[word]
print('word_id::::', word_id)


print('~~~~~~~~~~拆词~~~~~~~~~~~~')

word = 'by'
word_id = tl.nlp.words_to_word_ids(word, dictionary, _UNK)
print('word_id::::', word_id)

context = tl.nlp.word_ids_to_words(word_id, reverse_dictionary)
print('context::::', context)

print('~~~~~~~~多词~~~~~~~~~~~')

words = ['i', 'am', 'tensor', 'layer']
word_ids = tl.nlp.words_to_word_ids(words, dictionary, _UNK)
print('word_ids::::', word_ids)
context = tl.nlp.word_ids_to_words(word_ids, reverse_dictionary)
print('context::::', context)

vectors = sess.run(emb_net.outputs, feed_dict={x: word_ids})
print('vectors::::', vectors.shape)

 输出如下:

~~~~~~~Loading npy successfully~~~~~~~~~~~~
[TL] 50000 vocab saved to vocab_model_word2vec_50k_128.txt in C:\bbbb\学习\python教材\jfj\一起玩转Tensorlayer
~~~~~~~Loading npz successfully~~~~~~~~~~~~
[TL] EmbeddingInputlayer embedding_layer: (50000, 128)
[TL]   param   0: embedding_layer/embeddings:0 (50000, 128)       float32_ref (mean: -3.736475628102198e-05, median: -4.611164331436157e-05, std: 0.057736434042453766)   
[TL]   num of params: 6400000
[TL]   layer   0: embedding_layer/embedding_lookup:0 (?, 128)           float32
~~~~~~~~~~单词~~~~~~~~~~~~
word_id:::: 6436
~~~~~~~~~~拆词~~~~~~~~~~~~
word_id:::: [73, 495]
context:::: ['b', 'y']
~~~~~~~~多词~~~~~~~~~~~
word_ids:::: [72, 1226, 13297, 1987]
context:::: ['i', 'am', 'tensor', 'layer']
vectors:::: (4, 128)

我一开始运行时是有报错的如下:

~~~~~~~Loading npy successfully~~~~~~~~~~~~
[TL] 50000 vocab saved to vocab_model_word2vec_50k_128.txt in C:\bbbb\学习\python教材\jfj\一起玩转Tensorlayer
Traceback (most recent call last):
  File "C:\bbbb\学习\python教材\jfj\一起玩转Tensorlayer\5.4.py", line 22, in <module>
    load_params = tl.files.load_npz(name=model_file_name + '.npz')
  File "C:\ProgramData\Anaconda3\lib\site-packages\tensorlayer\files.py", line 1207, in load_npz
    return d['params']
  File "C:\ProgramData\Anaconda3\lib\site-packages\numpy\lib\npyio.py", line 239, in __getitem__
    raise KeyError("%s is not a file in the archive" % key)
KeyError: 'params is not a file in the archive'
[Finished in 4.4s]

拆腾了一天,发现主要问题是第5.3节上有个差错,

load_params= tl.files.load_npz(name=model_file_name + '.npz')

这行是导入不了npz的,会报上面的错。
load_params= tl.files.load_npz(name='53model.npz')

这行是我改的,53model.npz 是我后来生成的,我把5.3的代码这里改了,注意一下:

if (step % (print_freq * 5) == 0) and (step != 0):
        print("******Save model, data and dictionaries***" + "!" * 10)
        # Save to ckpt or npz file
        # saver = tf.train.Saver()
        # save_path = saver.save(sess, model_file_name+'.ckpt')
        tl.files.save_npz_dict(emb_net.all_params, name=model_file_name + '.npz', sess=sess)
        tl.files.save_npz(save_list=None, name='53model.npz', sess=sess)
        tl.files.save_any_to_npy(
            save_dict={
                'data': data,
                'count': count,
                'dictionary': dictionary,
                'reverse_dictionary': reverse_dictionary
            }, name=model_file_name + '.npy'
        )
    step += 1

对的,我新加了 tl.files.save_npz(save_list=None, name='53model.npz', sess=sess)这行。因为tl.files.save_npz_dict(emb_net.all_params, name=model_file_name + '.npz', sess=sess)这行生成的npz是不符合load npz规则的,因为npz起始文件名要params才行。就是因为这个害我重新run了一遍5.3的程序。。。。。

为什么,因为源码是这么要求的啊~~下面是源码

def load_npz(path='', name='model.npz'):
    """Load the parameters of a Model saved by tl.files.save_npz().

    Parameters
    ----------
    path : str
        Folder path to `.npz` file.
    name : str
        The name of the `.npz` file.

    Returns
    --------
    list of array
        A list of parameters in order.

    Examples
    --------
    - See ``tl.files.save_npz``

    References
    ----------
    - `Saving dictionary using numpy <http://stackoverflow.com/questions/22315595/saving-dictionary-of-header-information-using-numpy-savez>`__

    """
    d = np.load(path + name)
    return d['params'] 

 

 

 

版权声明:本文为weixin_42025210原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。
本文链接:https://blog.csdn.net/weixin_42025210/article/details/81906595

智能推荐

phpstudy的mysql版本升级至5.7

phpstudy安装的mysql版本一般都是5.5或5.4的,但是有时候做项目又必须用到mysql5.7版本,所以我们现在来看一下如何在phpstudy的环境下将mysql版本升级至5.7   温馨提醒: 先删掉所有环境变量,如果是之前有的话,不然怎么安装cmd上指向的还是原来的版本。安装完再设新的环境变量。 并且卸载掉mysqld服务mysqld remove。如果不先删除的话,可能会...

RIP/DHCP/ACL综合实验

组播: 加入组的组成员才会接受到消息,只需要将流量发送一次到组播地址 减少控制面流量,减少头部复制, RIP1  广播   有类  不支持认证 RIP2  组播   无类  (支持VLAN)、支持认证 所有距离矢量路由协议:具有距离矢量特征的协议,都会在边界自动汇总 控制平面  路由的产生是控制平面的流量 数据平面  ...

【Sublime】使用 Sublime 工具时运行python文件

使用 Sublime 工具时报Decode error - output not utf-8解决办法   在菜单中tools中第四项编译系统 内最后一项增添新的编译系统 自动新建 Python.sublime-build文件,并添加"encoding":"cp936"这一行,保存即可 使用python2 则注释encoding改为utf-8 ctr...

java乐观锁和悲观锁最底层的实现

1. CAS实现的乐观锁 CAS(Compare And Swap 比较并且替换)是乐观锁的一种实现方式,是一种轻量级锁,JUC 中很多工具类的实现就是基于 CAS 的,也可以理解为自旋锁 JUC是指import java.util.concurrent下面的包, 比如:import java.util.concurrent.atomic.AtomicInteger; 最终实现是汇编指令:lock...

Python 中各种imread函数的区别与联系

  原博客:https://blog.csdn.net/renelian1572/article/details/78761278 最近一直在用python做图像处理相关的东西,被各种imread函数搞得很头疼,因此今天决定将这些imread总结一下,以免以后因此犯些愚蠢的错误。如果你正好也对此感到困惑可以看下这篇总结。当然,要了解具体的细节,还是应该 read the fuc...

猜你喜欢

用栈判断一个字符串是否平衡

注: (1)本文定义:左符号:‘(’、‘[’、‘{’…… 右符号:‘)’、‘]’、‘}’……. (2)所谓的字符串的符号平衡,是指字符串中的左符号与右符号对应且相等,如字符串中的如‘(&r...

JAVA环境变量配置

位置 计算机->属性->高级系统设置->环境变量 方式一 用户变量新建path 系统变量新建classpath 方式二 系统变量 新建JAVA_HOME,值为JDK路径 编辑path,前加 方式三 用户变量新建JAVA_HOME 此路径含lib、bin、jre等文件夹。后运行tomcat,eclipse等需此变量,故最好设。 用户变量编辑Path,前加 系统可在任何路径识别jav...

常用的伪类选择器

CSS选择器众多 CSS选择器及权重计算 最常用的莫过于类选择器,其它的相对用的就不会那么多了,当然属性选择器和为类选择器用的也会比较多,这里我们就常用的伪类选择器来讲一讲。 什么是伪类选择器? CSS伪类是用来添加一些选择器的特殊效果。 常用的为类选择器 状态伪类 我们中最常见的为类选择器就是a标签(链接)上的为类选择器。 当我们使用它们的时候,需要遵循一定的顺序问题,否则将可能出现bug 注意...

ButterKnife的使用介绍及原理探究(六)

前面分析了ButterKnife的源码,了解其实现原理,那么就将原理运用于实践吧。 github地址:       点击打开链接 一、自定义注解 这里为了便于理解,只提供BindView注解。 二、添加注解处理器 添加ViewInjectProcessor注解处理器,看代码, 这里分别实现了init、getSupportedAnnotationTypes、g...

1.写一个程序,提示输入两个字符串,然后进行比较,输出较小的字符串。考试复习题库1|要求:只能使用单字符比较操作。

1.写一个程序,提示输入两个字符串,然后进行比较,输出较小的字符串。 要求只能使用单字符比较操作。 参考代码: 实验结果截图:...