Tensorflow在深度学习模型研究中起到了很大的促进作用,灵活的框架免去了研究人员、开发者大量的自动求导代码工作。本文总结一下常用的模型代码和工程化需要的代码。有需求的同学收藏一下,以便日后查阅。
Tensorflow常见模型
A. LSTM模型结构
import tensorflow as tf
import tensorflow.contrib as contrib
from tensorflow.python.ops import array_ops
class lstm(object):
? ? def __init__(self, in_data, hidden_dim, batch_seqlen=None, flag='concat'):
? ? ? ? self.in_data = in_data
? ? ? ? self.hidden_dim = hidden_dim
? ? ? ? self.batch_seqlen = batch_seqlen
? ? ? ? self.flag = flag
? ? ? ? lstm_cell = contrib.rnn.LSTMCell(self.hidden_dim)
? ? ? ? out, _ = tf.nn.dynamic_rnn(cell=lstm_cell, inputs=self.in_data, sequence_length=self.batch_seqlen,dtype=tf.float32)
? ? ? ? if flag=='all_ht':
? ? ? ? ? ? self.out = out
? ? ? ? if flag = 'first_ht':
? ? ? ? ? ? self.out = out[:,0,:]
? ? ? ? if flag = 'last_ht':
? ? ? ? ? ? self.out = out[:,-1,:]
? ? ? ? if flag = 'concat':
? ? ? ? ? ? self.out = tf.concat([out[:,0,:], out[:,-1,:]],1)
B. Bi-LSTM模型结构
import tensorflow as tf
import tensorflow.contrib as contrib
from tensorflow.python.ops import array_ops
from tensorflow.python.framework import dtypes
class bilstm(object):
? ? def __init__(self, in_data, hidden_dim, batch_seqlen=None, flag='concat'):
? ? ? ? self.in_data = in_data
? ? ? ? self.hidden_dim = hidden_dim
? ? ? ? self.batch_seqlen = batch_seqlen
? ? ? ? self.flag = flag
? ? ? ? lstm_cell_fw = contrib.rnn.LSTMCell(self.hidden_dim)
? ? ? ? lstm_cell_bw = contrib.rnn.LSTMCell(self.hidden_dim)
? ? ? ? out, state = tf.nn.bidirectional_dynamic_rnn(cell_fw=lstm_cell_fw,cell_bw=lstm_cell_bw,inputs=self.in_data, sequence_lenth=self.batch_seqlen,dtype=tf.float32)
? ? ? ? bi_out = tf.concat(out, 2)
? ? ? ? if flag=='all_ht':
? ? ? ? ? ? self.out = bi_out
? ? ? ? if flag=='first_ht':
? ? ? ? ? ? self.out = bi_out[:,0,:]
? ? ? ? if flag=='last_ht':
? ? ? ? ? ? self.out = tf.concat([state[0].h,state[1].h], 1)
? ? ? ? if flag=='concat':
? ? ? ? ? ? self.out = tf.concat([bi_out[:,0,:],tf.concat([state[0].h,state[1].h], 1)],1)
C multi-channel CNN
import tensorflow as tf
import tensorflow.contrib as contrib
from tensorflow.python.ops import array_ops
class lstm(object):
? ? def __init__(self, in_data, hidden_dim, batch_seqlen=None, flag='concat'):
? ? ? ? self.in_data = in_data
? ? ? ? self.hidden_dim = hidden_dim
? ? ? ? self.batch_seqlen = batch_seqlen
? ? ? ? self.flag = flag
? ? ? ? lstm_cell = contrib.rnn.LSTMCell(self.hidden_dim)
? ? ? ? out, _ = tf.nn.dynamic_rnn(cell=lstm_cell, inputs=self.in_data, sequence_length=self.batch_seqlen,dtype=tf.float32)
? ? ? ? if flag=='all_ht':
? ? ? ? ? ? self.out = out
? ? ? ? if flag = 'first_ht':
? ? ? ? ? ? self.out = out[:,0,:]
? ? ? ? if flag = 'last_ht':
? ? ? ? ? ? self.out = out[:,-1,:]
? ? ? ? if flag = 'concat':
? ? ? ? ? ? self.out = tf.concat([out[:,0,:], out[:,-1,:]],1)
D depth-wise cnn
import tensorflow as tf
def depth_wise_conv(in_data, scope, kernel_size, dim):
? ? with tf.variable_scope(scope):
? ? ? ? shapes = in_data.shape.as_list()
? ? ? ? depthwise_filter = tf.get_varibale("depthwise_conv.weight",
? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? (kernel_size[0], kernel_size[1], shapes[-1]
? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? dtype=tf.float32, )
? ? ? ? pointwise_filter = tf.get_variable("pointwise_conv.weight",
? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? (1,1, shapes[-1], dim),
? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? dtype=tf.float32, )
? ? ? ? outputs = tf.nn.separable_conv2d(in_data,?
? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ?depthwise_filter,
? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ?pointwise_filter,
? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ?strides=(1,1,1,1),
? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ?padding="SAME"
? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? )
? ? ? ? return outputs
D multi-layer depth-wise cnn
def multi_convs(input_x, dim, conv_number=2, k=5):
? ? # input_x: 输入数据,为batch * seq * dim
? ? # dim:对应的输入的维度
? ? # conv_number: 对应的卷积的层数,一般2,
? ? # k对应的是卷积核的窗口大小
? ? res = input_x
? ? for index in range(conv_number):
? ? ? ? out = norm(res)? # layer norm
? ? ? ? out = tf.expand_dims(out, 2)? # bach * seq * 1 * dim
? ? ? ? out = depth_wise_conv(out, kernel_size=(k, 1), dim=dim, scope="convs.%d" % index)
? ? ? ? out = tf.squeeze(out, 2)? # batch * seq * dim
? ? ? ? out = tf.nn.relu(out)
? ? ? ? out = out + res
? ? ? ? res = out
? ? out = norm(out)? ? ? ? ? ? ? ? ? ? ? ? # 输出为 batch * seq * 1 * dim
? ? out = tf.squeeze(out, squeeze_dims=2)? # 输出为 batch * seq * dim
? ? return out
模型参数查看
已知模型文件的ckpt文件,通过pywrap_tensorflow获取模型的各参数名。
import tensoflow as tf
from tensorflow.python import pywrap_tensorflow
model_dir = "./ckpt/"
ckpt = tf.train.get_checkpoint_state(model_dir)
ckpt_path = ckpt.model_checkpoint_path
reader = pywrap_tensorflow.NewCheckpointReader(ckpt_path)
param_dict = reader.get_variable_to_shape_map()
for key, val in param_dict.items():
? ? try:
? ? ? ? print key, val
? ? except:
? ? ? ? pass
工程化方法
A. tennsorflow模型文件打包成PB文件
import tensorflow as tf
from tensorflow.python.tools import freeze_graph
with tf.Graph().as_default():
? ? with tf.device("/cpu:0"):
? ? ? ? config = tf.ConfigProto(allow_soft_placement=True)
? ? ? ? with tf.Session(config=config).as_default() as sess:
? ? ? ? ? ? model = Your_Model_Name()
? ? ? ? ? ? model.build_graph()
? ? ? ? ? ? sess.run(tf.initialize_all_variables())
? ? ? ? ? ? saver = tf.train.Saver()
? ? ? ? ? ? ckpt_path = "/your/model/path"
? ? ? ? ? ? saver.restore(sess, ckpt_path)
? ? ? ? ? ? graphdef = tf.get_default_graph().as_graph_def()
? ? ? ? ? ? tf.train.write_graph(sess.graph_def,"/your/save/path/","save_name.pb",as_text=False)
? ? ? ? ? ? frozen_graph = tf.graph_util.convert_variables_to_constants(sess,graphdef,['output/node/name'])
? ? ? ? ? ? frozen_graph_trim = tf.graph_util.remove_training_nodes(frozen_graph)
? ? ? ? ? ? freeze_graph.freeze_graph('/your/save/path/save_name.pb','',True, ckpt_path,'output/node/name','save/restore_all','save/Const:0','frozen_name.pb',True,"")
output_graph_def = tf.GraphDef()
with open("your_name.pb","rb") as f:
? ? output_graph_def.ParseFromString(f.read())
? ? _ = tf.import_graph_def(output_graph_def, name="")
node_in = sess.graph.get_tensor_by_name("input_node_name")
model_out = sess.graph.get_tensor_by_name("out_node_name")
feed_dict = {node_in:in_data}
pred = sess.run(model_out, feed_dict)
注:本文代码均为笔者手敲留存,如代码有误可以咨询探讨。
原创声明:本文系作者授权腾讯云开发者社区发表,未经许可,不得转载。
如有侵权,请联系 cloudcommunity@tencent.com 删除。
原创声明:本文系作者授权腾讯云开发者社区发表,未经许可,不得转载。
如有侵权,请联系 cloudcommunity@tencent.com 删除。