#!/usr/bin/env python2
# -*- coding: utf-8 -*-
"""
Created on Tue Sep 19 09:42:22 2017

@author: myhaspl
"""
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data

INPUT_NODE=784
OUTPUT_NODE=10

LAYER1_NODE=500
BATCH_SIZE=100

LEARNING_RATE_BASE=0.8
LEARNING_RATE_DECAY=0.99

REGULARIZATION_RATE=0.0001
TRANING_STEPS=30000
MOVING_AVERAGE_DECAY=0.99

def inference(input_tensor,avg_class,weights1,biases1,weights2,biases2):
    if avg_class==None:#非滑动平均
        layer1=tf.nn.relu(tf.matmul(input_tensor,weights1)+biases1)
        return tf.matmul(layer1,weights2)+biases2
    else:#滑动平均
        layer1=tf.nn.relu(tf.matmul(input_tensor,avg_class.average(weights1))+avg_class.average(biases1))
        return tf.matmul(layer1,avg_class.average(weights2))+avg_class.average(biases2)

def train(mnist):
    #样本数据与样本标签
    x_=tf.placeholder(tf.float32,[None,INPUT_NODE],name='x_-input')
    y_=tf.placeholder(tf.float32,OUTPUT_NODE],name='y_-input')
    #参数初始值
    weights1=tf.Variable(tf.truncated_normal([INPUT_NODE,LAYER1_NODE],stddev=0.1))
    biases1=tf.Variable(tf.constant(0.1,shape=[LAYER1_NODE]))
    weights2=tf.Variable(tf.truncated_normal([LAYER1_NODE,stddev=0.1))
    biases2=tf.Variable(tf.constant(0.1,shape=[OUTPUT_NODE]))
    global_step=tf.Variable(0,trainable=False)
    
    #非滑动平均    
    y_nohd=inference(x_,None,biases2)

    #滑动平均
    variable_averages=tf.train.ExponentialMovingAverage(MOVING_AVERAGE_DECAY,global_step)
    #滑动平均更新变量的操作
    variable_averages_op=variable_averages.apply(tf.trainable_variables())
    y_hd=inference(x_,variable_averages,biases2)

    #交叉嫡损失函数,使用softmax归一化
    cross_entropy=tf.nn.sparse_softmax_cross_entropy_with_logits(logits=y_nohd,labels=tf.arg_max(y_,1))
    cross_entropy_mean=tf.reduce_mean(cross_entropy)
    #加入L2正则化损失
    regularizer=tf.contrib.layers.l2_regularizer(REGULARIZATION_RATE)
    regularization=regularizer(weights1)+regularizer(weights2)
    loss=cross_entropy_mean+regularization
    
    #设置指数衰减的学习率
    learning_rate=tf.train.exponential_decay(
            LEARNING_RATE_BASE,global_step,mnist.train.num_examples/BATCH_SIZE,LEARNING_RATE_DECAY)
    
    train_step=tf.train.GradientDescentOptimizer(learning_rate).minimize(loss,global_step=global_step)
    #训练与更新参数的滑动平均值
    #将2大步操作打包在train_op中,第1大步操作是使用正则化和指数衰减更新参数值
    #第2大步操作是使用滑动平均再次更新参数值。
    #每次训练都完成这2大步操作。
    train_op=tf.group(train_step,variable_averages_op)
    #检验滑动平均平均模型的神经网络前向传播结果是否正确
    correct_predection=tf.equal(tf.argmax(y_hd,1),tf.argmax(y_,1))
    accuracy=tf.reduce_mean(tf.cast(correct_predection,tf.float32))
    
    #开始训练过程
    with tf.Session() as sess:
        tf.initialize_all_variables().run()
        #训练样本集
        validate_Feed={x_:mnist.validation.images,y_:mnist.validation.labels
                      }
        #测试集
        test_Feed={x_:mnist.test.images,y_:mnist.test.labels
                   }
        for i  in range(TRANING_STEPS):
            if i%1000==0:
                #每1000轮计算当前训练的结果
                validate_acc=sess.run(accuracy,Feed_dict=validate_Feed)
                print("%d次后=>正确率%g"%(i,validate_acc))
            #每一轮使用的样本,然后开始训练
            xs,ys=mnist.train.next_batch(BATCH_SIZE)
            sess.run(train_op,Feed_dict={x_:xs,y_:ys})
            
        
        #TRANING_STEPS次训练结束,对测试数据进行检测,检验神经网络准确度
        test_acc=sess.run(accuracy,Feed_dict=test_Feed)
        print("正确率:%g"%test_acc)

def main(argv=None):
    mnist=input_data.read_data_sets("/tmp/data",one_hot=True)
    train(mnist)
    
if __name__=='__main__':
    tf.app.run()
使用了非线性激活函数relu,防止梯度消失。

tf随笔-15 正则化+指数衰减+滑动平均的更多相关文章

  1. PyTorch实现MNIST数据集手写数字识别详情

    这篇文章主要介绍了PyTorch实现MNIST数据集手写数字识别详情,文章围绕主题展开详细的内容戒杀,具有一定的参考价值,需要的朋友可以参考一下

  2. 正则化DropPath/drop_path用法示例(Python实现)

    DropPath 类似于Dropout,不同的是 Drop将深度学习模型中的多分支结构随机"失效",而Dropout是对神经元随机"失效"这篇文章主要给大家介绍了关于正则化DropPath/drop_path用法的相关资料,需要的朋友可以参考下

  3. pytorch实现mnist手写彩色数字识别

    这篇文章主要介绍了pytorch-实现mnist手写彩色数字识别,文章围绕主题展开详细的内容姐介绍,具有一定的参考价值,需要的小伙伴可以参考一下

  4. caffe的python接口之手写数字识别mnist实例

    这篇文章主要为大家介绍了caffe的python接口之手写数字识别mnist实例详解,有需要的朋友可以借鉴参考下,希望能够有所帮助,祝大家多多进步,早日升职加薪

  5. ubuntu16.04+gtx1080ti+caffe安装记录+ N显卡驱动版本过高导致error==cudaSuccess(30 vs. 0) unknown error.

    ubuntu16.04+gtx1080ti+caffe安装记录麦克白攻城狮这几天安装cuda出现了很多问题,特意记录并分享给需要的人。反思是不是驱动版本太高。第四次直接用命令行apt-getinstall安装的375驱动,刚好是375.66,然后安装cuda8.0run文件,没有安装其自带的驱动,结果成功了,但是也遇到重启黑屏。第一部分:安装nvidia驱动,CUDA8.0和cudnn5.1备注:开机通过BIOS更改初始显卡可以选择集显或独显,独显进入ubuntu需要从启动菜单"高级模式"->"Recov

  6. ubuntu 16.04 gtx1060 显卡安装

    首先说明,这是在台式机上的安装测试经历,首先安装的win10,然后安装ubuntu16.04双系统,显卡为GTX1060台式机显示器接的是GTX1060HDMI口一、首先安装nvidia显卡驱动打开终端,先删除旧的驱动:sudoapt-getpurgenvidia*禁用自带的驱动(很重要!的错误,这是由于

  7. Ubuntu下Caffe安装记录

    总体过程Ubuntu16.04Caffe安装步骤记录(超详尽)手动安装Nvidia显卡驱动,安装Cuda时只需要安装ToolKit,也不要安装opengl问题解决修改Ubuntu16.04源为清华大学——解决依赖包无法安装问题解决APT-GET更新源报错:W:UNKNOWNMULTI-ARCHTYPE‘NO’FORPACKAGE‘COMPIZ-GNOME’ctrl+alt+F1~6进入不了字符界面

  8. 机器学习 – 为什么需要在机器学习问题中使用正则化?

    为什么在这种情况下我们更喜欢较小的重量?

  9. 机器学习 – TensorFlow – L2丢失的正则化,如何应用于所有权重,而不仅仅是最后一个?

    我有一个任务,涉及到使用L2丢失的一个隐藏的ReLU层向网络引入泛化。我不知道如何正确引入它,以便所有权重都受到惩罚,不仅仅是输出层的权重。代码网络没有泛化是在底部的帖子。引入L2的明显方法是用这样的代替损失计算:但在这种情况下,它将考虑到输出层权重的值。是否需要或引入惩罚的输出层将以某种方式保持隐藏的权重也在检查?hidden_weights,hidden_biases,out_weights和out_biases都是您正在创建的模型参数。您可以按如下所示对所有这些参数添加L2正则化:

  10. 机器学习 – TensorFlow – 将L2正则化和退出引入网络.有什么意义吗?

    如果是这样,怎么办?任何关于此事的参考将是有用的,我还没有找到任何信息.为了防止你有兴趣,我的代码为ANN与L2正则化在下面:好的,经过一些额外的努力,我设法解决它,并将L2和辍学引入我的网络,代码如下.在同一个网络中,我没有辍学略有改善.我仍然不确定是否真的很值得介绍他们两个,L2和辍学的努力,但至少它的作品,并略微提高了结果.

随机推荐

  1. 法国电话号码的正则表达式

    我正在尝试实施一个正则表达式,允许我检查一个号码是否是一个有效的法国电话号码.一定是这样的:要么:这是我实施的但是错了……

  2. 正则表达式 – perl分裂奇怪的行为

    PSperl是5.18.0问题是量词*允许零空间,你必须使用,这意味着1或更多.请注意,F和O之间的空间正好为零.

  3. 正则表达式 – 正则表达式大于和小于

    我想匹配以下任何一个字符:或=或=.这个似乎不起作用:[/]试试这个:它匹配可选地后跟=,或者只是=自身.

  4. 如何使用正则表达式用空格替换字符之间的短划线

    我想用正则表达式替换出现在带空格的字母之间的短划线.例如,用abcd替换ab-cd以下匹配字符–字符序列,但也替换字符[即ab-cd导致d,而不是abcd,因为我希望]我如何适应以上只能取代–部分?

  5. 正则表达式 – /bb | [^ b] {2} /它是如何工作的?

    有人可以解释一下吗?我在t-shirt上看到了这个:它似乎在说:“成为或不成为”怎么样?我好像没找到’e’?

  6. 正则表达式 – 在Scala中验证电子邮件一行

    在我的代码中添加简单的电子邮件验证,我创建了以下函数:这将传递像bob@testmymail.com这样的电子邮件和bobtestmymail.com之类的失败邮件,但是带有空格字符的邮件会漏掉,就像bob@testmymail也会返回true.我可能在这里很傻……当我测试你的正则表达式并且它正在捕捉简单的电子邮件时,我检查了你的代码并看到你正在使用findFirstIn.我相信这是你的问题.findFirstIn将跳转所有空格,直到它匹配字符串中任何位置的某个序列.我相信在你的情况下,最好使用unapp

  7. 正则表达式对小字符串的暴力

    在测试小字符串时,使用正则表达式会带来性能上的好处,还是会强制它们更快?不会通过检查给定字符串的字符是否在指定范围内比使用正则表达式更快来强制它们吗?

  8. 正则表达式 – 为什么`stoutest`不是有效的正则表达式?

    isthedelimiter,thenthematch-only-onceruleof?PATTERN?

  9. 正则表达式 – 替换..与.在R

    我怎样才能替换..我尝试过类似的东西:但它并不像我希望的那样有效.尝试添加fixed=T.

  10. 正则表达式 – 如何在字符串中的特定位置添加字符?

    我正在使用记事本,并希望使用正则表达式替换在字符串中的特定位置插入一个字符.例如,在每行的第6位插入一个逗号是什么意思?如果要在第六个字符后添加字符,请使用搜索和更换从技术上讲,这将用MatchGroup1替换每行的前6个字符,后跟逗号.

返回
顶部