本文介绍 TensorFlow模型冷冻以及为什么tensor名字要加:0?

TensorFlow模型冷冻以及为什么tensor名字要加:0?

This article was original written by Jin Tian, welcome re-post, first come with https://jinfagang.github.io . but please keep this copyright info, thanks, any question could be asked via wechat: jintianiloveu

这个其实是很多人会遇到的问题,先来记录一下本文的干货,然后在抛出一个问题。首先是大家把模型训练好了,此时应该如何把模型冷冻为pb文件?在解答这个问题之前,需要先回答,为什么要转为pb文件?

模型部署为什么要转成pb文件

你训练了模型,把模型给别人验收,你需要把一坨python代码给别人吗?完全没有必要啊!你训练拿到了权重,我们叫做checkpoint, 这个里面存放的全部都是每一个tensor的权重,同时也包含了许多的varaible的值,这些权重你可以直接用,但是其实大多数情况下没有必要,而且显得很臃肿。

其实问题是可以很完美、简洁的解决,重点是,解决之后,你可以直接丢给对方一个pb文件,预测的代码你用python可以,C++也行,也就没有关系了,甚至可以转onnx,转ncnn,转二进制,都是没有问题的。

首先,保存网络结构

那么要达到上面那个目的,首先你需要把python代码构建的网络模型保存下来,这可以做到吗?可以!而且很简单:

with tf.Session() as sess:
	tf.train.write_graph(sess.graph_def, os.path.dirname(args.checkpoint), 'model.pbtxt', as_text=True)

一行代码,就可以把网络结构保存为model.pbtxt 请注意,这里我使用了as_text=True ,这个好处就是保存的为pbtxt的文本,可以直接用文本编辑器查看,当然默认是false,此时是二进制的。

你保存好了,如果要查看应该怎么办?也很简单:

import tensorflow as tf
from google.protobuf import text_format

graph_def = tf.GraphDef()
text_format.Merge(open('model.pbtxt').read(), graph_def)
print(graph_def)

十分简单,这是标准的网络结构读入与写入的过程。

但是, 假如你保存的为二进制的方式,就应该用另一端代码读取了,具体不赘述。

如何freeze模型

上面其实保存的仅仅是网络结构,我需要吧网路结构和权重冷冻到一起,这样部署的时候就不至于一坨python代码了,怎么做?这是时候就得使用官方唯一指定工具了:freeze_graph, 通常情况下,你安装tensorflow的时候,这个工具会自动安装,直接使用:

freeze_graph --input_graph ../model.pbtxt --input_checkpoint model-120009 --output_node_names Openpose/concat_stage7 --output_graph graph_freeze.pb

这些参数你不要看多,你不得不添加这么参数,因为要冷冻模型,你至少需要这些东西:

  • 刚才我们保存的网络结构;
  • 权重文件(这里指定第几次迭代即可无需指定特定文件);
  • 输入的nodenames;
  • 保存文件名

但是大家仔细分析一下,其他的都有了,好像这个输出的nodenames不知道啊。。。这个呢,其实如果你对网络结构熟悉,你应该就知道输出是哪个了。

但是,各位,tensorflow坑跌的地方就在这里,这个时候你已经成功的保存了pb模型,但是要进行推理的时候,你可能会遇到这样的错误:

image' looks like an (invalid) Operation name, not a Tensor. Tensor names must be of the form "<op_name>:<output_index>".

这是啥意思?意思是你在feed 一个image或者其他数据的时候,给一个tensor名字叫做 image的feed了,而不是 image:0 。。。

大家可能要问了,不是image吗,我的输入tensor就是image啊,这个:0是卵意思嘛???

为什么要在tensor后面加:0?

本质上, tensorflow保存的图里面,是一个一个的node链接,一个node是什么?node是什么都不知道?想象一下小朋友手拉手,小朋友就是一个node,左手进右手出,那就是输入和输出。在想一想,有的小朋友可能有三只手,两个左手,一个右手,那么你的输入就是两个入口,输出一个入口,通常情况下,你取一个tensor,如果你给它输入一个值,你的需要在tensor后面加上index,来告诉网络,我从哪个手出去,明白了吗?

想必大家都明白了,那么为什么有时候在输出节点我就不需要添加:0的操作了呢?原来也很简单,因为tensorflow对于多个输出的节点,自动把输出用一个tuple或者list包含起来了,你直接对tuple进行index也行。

好了,以后大家可以直接扔一个pb文件给对方,让他去推理,而真正的网络定义和训练则可以屏蔽起来。