前言

在实现Per-FedAvg的代码时,遇到如下问题:

在这里插入图片描述


可以发现,我们需要求损失函数对模型参数的Hessian矩阵。

模型定义

我们定义一个比较简单的模型:

class ANN(nn.Module):
    def __init__(self):
        super(ANN, self).__init__()
        self.sigmoid = nn.Sigmoid()
        self.fc1 = nn.Linear(3, 4)
        self.fc2 = nn.Linear(4, 5)

    def forward(self, data):
        x = self.fc1(data)
        x = self.fc2(x)

        return x

输出一下模型的参数:

model = ANN()
for param in model.parameters():
    print(param.size())

输出如下:

torch.Size([4, 3])
torch.Size([4])
torch.Size([5, 4])
torch.Size([5])

求解Hessian矩阵

我们首先定义数据:

data = torch.tensor([1, 2, 3], dtype=torch.float)
label = torch.tensor([1, 1, 5, 7, 8], dtype=torch.float)
pred = model(data)
loss_fn = nn.MSELoss()
loss = loss_fn(pred, label)

然后求解一阶梯度:

grads = torch.autograd.grad(loss, model.parameters(), retain_graph=True, create_graph=True)

输出一下grads:

(tensor([[-1.0530, -2.1059, -3.1589],
        [ 2.3615,  4.7229,  7.0844],
        [-1.5046, -3.0093, -4.5139],
        [-2.0272, -4.0543, -6.0815]], grad_fn=<TBackward0>), tensor([-1.0530,  2.3615, -1.5046, -2.0272], grad_fn=<SqueezeBackward1>), tensor([[ 0.2945, -0.2725, -0.8159, -0.6720],
        [ 0.1936, -0.1791, -0.5362, -0.4416],
        [ 1.0800, -0.9993, -2.9918, -2.4641],
        [ 1.3448, -1.2444, -3.7255, -3.0683],
        [ 1.2436, -1.1507, -3.4450, -2.8373]], grad_fn=<TBackward0>), tensor([-0.6045, -0.3972, -2.2165, -2.7600, -2.5522],
       grad_fn=<MseLossBackwardBackward0>))

可以发现一共4个Tensor,分别为损失函数对四个参数Tensor(两层,每层都有权重和偏置)的梯度。

然后针对每一个Tensor求解二阶梯度:

hessian_params = []
    for k in range(len(grads)):
        hess_params = torch.zeros_like(grads[k])
        for i in range(grads[k].size(0)):
            # 判断是w还是b
            if len(grads[k].size()) == 2:
                # w
                for j in range(grads[k].size(1)):
                    hess_params[i, j] = torch.autograd.grad(grads[k][i][j], model.parameters(), retain_graph=True)[k][i, j]
            else:
                # b
                hess_params[i] = torch.autograd.grad(grads[k][i], model.parameters(), retain_graph=True)[k][i]
        hessian_params.append(hess_params)

这里需要注意:由于模型一共两层,每一层都有权重和偏置,其中权重参数为二维,偏置参数为一维,在进行具体的二阶梯度求导时,需要进行判断。

最终得到的hessian_params是一个列表,列表中包含四个Tensor,对应损失函数对两层网络权重和偏置的二阶梯度。

以上就是PyTorch计算损失函数对模型参数的Hessian矩阵示例的详细内容,更多关于PyTorch计算损失函数Hessian矩阵的资料请关注Devmax其它相关文章!

PyTorch计算损失函数对模型参数的Hessian矩阵示例的更多相关文章

  1. HTML利用九宫格原理进行网页布局

    这篇文章主要介绍了HTML利用九宫格原理进行网页布局,小编觉得挺不错的,现在分享给大家,也给大家做个参考。一起跟随小编过来看看吧

  2. ios – 围绕x轴旋转AVAssetWriter的输出180度

    我正在使用AVAssetWriter创建一个Quicktime电影文件.目前输出视频是“倒置”.理论上,我可以通过围绕水平轴旋转180度来纠正这个问题.最好的方法是什么?Appledocs和wikipedia都没有明确说明仿射变换矩阵是如何工作的.并且可能有更好的方式.解决方法如果要围绕z轴旋转视频180度,或者如果你想在x轴上反射

  3. Swift 2.0学习笔记Day 35——会使用下标吗?

    下标Swift中的下标相当于Java中的索引属性和C#中的索引器。getter访问器是一个方法,在最后使用return语句将计算结果返回。setter访问器“新属性值”是要赋值给属性值。参数的声明可以省略,系统会分配一个默认的参数newValue。可以自定义一个二维数组类型,然后通过两个下标参数访问它的元素,形式上类似于C语言的二维数组。

  4. Swift - 动画效果的实现方法总结附样例

    在iOS中,实现动画有两种方法。这三个方法都是类方法。里面可以设置动画的效果。

  5. 《从零开始学Swift》学习笔记Day 35――会使用下标吗?

    下标Swift中的下标相当于Java中的索引属性和C#中的索引器。getter访问器是一个方法,在最后使用return语句将计算结果返回。setter访问器“新属性值”是要赋值给属性值。参数的声明可以省略,系统会分配一个默认的参数newValue。可以自定义一个二维数组类型,然后通过两个下标参数访问它的元素,形式上类似于C语言的二维数组。

  6. 用Swift3实现n*n阶矩阵逆时针输出

  7. 用Swift3实现n*n阶矩阵顺时针输出

  8. 数组 – Swift中的二维数组

    我对于Swift中的2D数组感到困惑。如果我错了,请你纠正我。首先;声明一个空数组:其次填充数组。最后,编辑数组中的元素这可能是noob问题,但目标C后,我真的很困惑..定义可变数组要么:OR如果你需要一个预定义大小的数组:在位置更改元素要么更改子数组要么要么如果你有3×2数组或0(零),现在你有:因此,请注意,子数组是可变的,您可以重新定义表示矩阵的初始数组。在访问前检查大小/边界备注:3和N维数组的相同标记规则。

  9. 2.9 多维数组的创建和遍历 [Swift原创教程]

    它由两个数组元素组成。课程配套素材下载地址:资料下载

  10. Swift泛型:需要一种类型的加法和乘法能力

    现在,我想确保类型T可以比较,所以我可以写这个:这可能是有用的,以防我想比较2个矩阵,这意味着比较它们的值。我也想提供两个矩阵求和的能力,所以我也应该添加一个协议,要求可以添加由矩阵用户给出的类型“T”同样,我也想说:问题1:可以使用什么协议名称?更具体地说,这个)–>矩阵{产生错误:使用未声明的类型“T”。我的意思是结果将是一个具有相同类型的两个输入矩阵的矩阵,但我可能完全弄乱了语法。

随机推荐

  1. 10 个Python中Pip的使用技巧分享

    众所周知,pip 可以安装、更新、卸载 Python 的第三方库,非常方便。本文小编为大家总结了Python中Pip的使用技巧,需要的可以参考一下

  2. python数学建模之三大模型与十大常用算法详情

    这篇文章主要介绍了python数学建模之三大模型与十大常用算法详情,文章围绕主题展开详细的内容介绍,具有一定的参考价值,感想取得小伙伴可以参考一下

  3. Python爬取奶茶店数据分析哪家最好喝以及性价比

    这篇文章主要介绍了用Python告诉你奶茶哪家最好喝性价比最高,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习吧

  4. 使用pyinstaller打包.exe文件的详细教程

    PyInstaller是一个跨平台的Python应用打包工具,能够把 Python 脚本及其所在的 Python 解释器打包成可执行文件,下面这篇文章主要给大家介绍了关于使用pyinstaller打包.exe文件的相关资料,需要的朋友可以参考下

  5. 基于Python实现射击小游戏的制作

    这篇文章主要介绍了如何利用Python制作一个自己专属的第一人称射击小游戏,文中的示例代码讲解详细,感兴趣的小伙伴可以跟随小编一起动手试一试

  6. Python list append方法之给列表追加元素

    这篇文章主要介绍了Python list append方法如何给列表追加元素,具有很好的参考价值,希望对大家有所帮助。如有错误或未考虑完全的地方,望不吝赐教

  7. Pytest+Request+Allure+Jenkins实现接口自动化

    这篇文章介绍了Pytest+Request+Allure+Jenkins实现接口自动化的方法,文中通过示例代码介绍的非常详细。对大家的学习或工作具有一定的参考借鉴价值,需要的朋友可以参考下

  8. 利用python实现简单的情感分析实例教程

    商品评论挖掘、电影推荐、股市预测……情感分析大有用武之地,下面这篇文章主要给大家介绍了关于利用python实现简单的情感分析的相关资料,文中通过示例代码介绍的非常详细,需要的朋友可以参考下

  9. 利用Python上传日志并监控告警的方法详解

    这篇文章将详细为大家介绍如何通过阿里云日志服务搭建一套通过Python上传日志、配置日志告警的监控服务,感兴趣的小伙伴可以了解一下

  10. Pycharm中运行程序在Python console中执行,不是直接Run问题

    这篇文章主要介绍了Pycharm中运行程序在Python console中执行,不是直接Run问题,具有很好的参考价值,希望对大家有所帮助。如有错误或未考虑完全的地方,望不吝赐教

返回
顶部