基本用法
- 使用图 (graph) 来表示计算任务.
- 在被称之为
会话 (Session)
的上下文 (context) 中执行图. - 使用 tensor 表示数据.
- 通过
变量 (Variable)
维护状态. - 使用 feed 和 fetch 可以为任意的操作(arbitrary operation) 赋值或者从其中获取数据.
构建图
1 | import tensorflow as tf |
启动图
1 | # 启动默认图. |
ession
对象在使用完后需要关闭以释放资源. 除了显式调用 close 外, 也可以使用 “with” 代码块 来自动完成关闭动作.
1 | with tf.Session() as sess: |
如果机器上有超过一个可用的 GPU, 除第一个外的其它 GPU 默认是不参与计算的. 为了让 TensorFlow 使用这些 GPU, 你必须将 op 明确指派给它们执行. with...Device
语句用来指派特定的 CPU 或 GPU 执行操作:
1 | with tf.Session() as sess: |
设备用字符串进行标识. 目前支持的设备包括:
"/cpu:0"
: 机器的 CPU."/gpu:0"
: 机器的第一个 GPU, 如果有的话."/gpu:1"
: 机器的第二个 GPU, 以此类推.
交互式使用
文档中的 Python 示例使用一个会话 Session
来 启动图, 并调用 Session.run()
方法执行操作.
为了便于使用诸如 IPython 之类的 Python 交互环境, 可以使用 InteractiveSession
代替 Session
类, 使用 Tensor.eval()
和 Operation.run()
方法代替 Session.run()
. 这样可以避免使用一个变量来持有会话.
1 | # 进入一个交互式 TensorFlow 会话. |
Variable
1 | # 创建一个变量, 初始化为标量 0. |
Fetch
为了取回操作的输出内容, 可以在使用 Session
对象的 run()
调用 执行图时, 传入一些 tensor, 这些 tensor 会帮助你取回结果. 在之前的例子里, 我们只取回了单个节点 state
, 但是你也可以取回多个 tensor:
1 | input1 = tf.constant(3.0) |
Feed
上述示例在计算图中引入了 tensor, 以常量或变量的形式存储. TensorFlow 还提供了 feed 机制, 该机制 可以临时替代图中的任意操作中的 tensor 可以对图中任何操作提交补丁, 直接插入一个 tensor.
feed 使用一个 tensor 值临时替换一个操作的输出结果. 你可以提供 feed 数据作为 run()
调用的参数. feed 只在调用它的方法内有效, 方法结束, feed 就会消失. 最常见的用例是将某些特殊的操作指定为 “feed” 操作, 标记的方法是使用 tf.placeholder() 为这些操作创建占位符.
1 | input1 = tf.placeholder(tf.types.float32) |
基础教程
MNIST
1 | import tensorflow as tf |
构建Softmax 回归模型
1 | x = tf.placeholder("float", shape=[None, 784]) |
这里的x
和y
并不是特定的值,相反,他们都只是一个占位符
,可以在TensorFlow运行某一计算时根据该占位符输入具体的值。
变量
需要通过seesion初始化后,才能在session中使用。这一初始化步骤为,为初始值指定具体值(本例当中是全为零),并将其分配给每个变量
,可以一次性为所有变量
完成此操作。
1 | sess.run(tf.initialize_all_variables()) |
类别预测与损失函数:
1 | y = tf.nn.softmax(tf.matmul(x,W) + b) |
训练模型:
1 | train_step = tf.train.GradientDescentOptimizer(0.01).minimize(cross_entropy) |
这一行代码实际上是用来往计算图上添加一个新操作,其中包括计算梯度,计算每个参数的步长变化,并且计算出新的参数值。
返回的train_step
操作对象,在运行时会使用梯度下降来更新参数。因此,整个模型的训练可以通过反复地运行train_step
来完成。
1 | for i in range(1000): |
每一步迭代,我们都会加载50个训练样本,然后执行一次train_step
,并通过feed_dict
将x
和 y_
张量占位符
用训练训练数据替代。
注意,在计算图中,你可以用feed_dict
来替代任何张量,并不仅限于替换占位符
。
评估模型
tf.argmax
是一个非常有用的函数,它能给出某个tensor对象在某一维上的其数据最大值所在的索引值。
1 | correct_prediction = tf.equal(tf.argmax(y,1), tf.argmax(y_,1)) |
这里返回一个布尔数组。为了计算我们分类的准确率,我们将布尔值转换为浮点数来代表对、错,然后取平均值。例如:[True, False, True, True]
变为[1,0,1,1]
,计算出平均值为0.75
。
1 | accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float")) |
最后,我们可以计算出在测试数据上的准确率
1 | print accuracy.eval(feed_dict={x: mnist.test.images, y_: mnist.test.labels}) |
构建一个多层卷积网络
权重初始化
这个模型中的权重在初始化时应该加入少量的噪声来打破对称性以及避免0梯度。由于我们使用的是ReLU神经元,因此比较好的做法是用一个较小的正数来初始化偏置项,以避免神经元节点输出恒为0的问题(dead neurons)
1 | def weight_variable(shape): |
卷积和池化
1 | def conv2d(x, W): |
1 | W_conv1 = weight_variable([5, 5, 1, 32]) |
全连接层
1 | W_fc1 = weight_variable([7 * 7 * 64, 1024]) |
Dropout
我们用一个placeholder
来代表一个神经元的输出在dropout中保持不变的概率。这样我们可以在训练过程中启用dropout,在测试过程中关闭dropout。 TensorFlow的tf.nn.dropout
操作除了可以屏蔽神经元的输出外,还会自动处理神经元输出值的scale。所以用dropout的时候可以不用考虑scale。
1 | keep_prob = tf.placeholder("float") |
输出
1 | W_fc2 = weight_variable([1024, 10]) |
训练和评估模型
1 | cross_entropy = -tf.reduce_sum(y_*tf.log(y_conv)) |
Cifar
进阶
API
Build Model
1 | tf.get_variable(name, |
创建或返回给定名称的变量
tf.variable_scope()
https://www.cnblogs.com/MY0213/p/9208503.html
用来指定变量的作用域,作为变量名的前缀,支持嵌套
tf.nn.conv2d(input, filter, strides, padding, use_cudnn_on_gpu=None, name=None)
tf.nn.bias_add
tf.contrib.layers.batch_norm
tf.conv2d_transpose(value, filter, output_shape, strides, padding="SAME", data_format="NHWC", name=None)
tf.nn.tanh()
tf.reduce_mean(inputs, [1, 2], name='global_average_pooling', keepdims=True)
tf.image.resize_bilinear(image_level_features, inputs_size, name='upsample')
Operate
tf.slice(inputs,begin,size,name='')
inputs:可以是list,array,tensor
begin:n维列表,begin[i] 表示从inputs中第i维抽取数据时,相对0的起始偏移量,也就是从第i维的begin[i]开始抽取数据
size:n维列表,size[i]表示要抽取的第i维元素的数目
tf.concat([tensor1, tensor2, tensor3,...], axis)
tf.logging
tf.logging.set_verbosity (tf.logging.INFO)
设计日志级别
tf.logging.info(msg, *args, **kwargs)
记录INFO级别的日志.
tf.gfile
https://blog.csdn.net/pursuit_zhangyu/article/details/80557958
tf.contrib.slim
https://www.2cto.com/kf/201706/649266.html