1.什么是广播机制

根据线性代数的运算规则我们知道,矩阵运算往往都是在两个矩阵维度相同或者相匹配时才能运算。比如加减法需要两个矩阵的维度相同,乘法需要前一个矩阵的列数与后一个矩阵的行数相等。那么在 numpy、tensor 里也是同样的道理,但是在机器学习的某些算法中会出现两个维度不相同也不匹配的矩阵进行运算,那么这时候就需要用广播机制来解决,通过广播机制,其tensor参数可以自动扩展为相等大小(不需要复制数据)。下面我们以tensor为例来解释什么是广播机制。

2.广播机制的规则

先来说下广播机制的规则,只有遵循下面的规则两个张量才可以进行广播运算。

每个tensor至少有一个维度;

遍历tensor所有维度时,从末尾开始遍历(从右往左开始遍历),两个tensor存在下列情况

tensor维度相等。

tensor维度不等且其中一个维度为1或者不存在。

满足上面的条件才可以进行广播机制。

3.代码举例

相同维度,一定可以 broadcast:

import torch
x = torch.rand(1, 2, 3)
y = torch.rand(1, 2, 3)
z = x   y
print(x.shape)
print(y.shape)
print(z.shape)
print(x)
print(y)
print(z)

输出结果如下:

torch.Size([1, 2, 3])
torch.Size([1, 2, 3])
torch.Size([1, 2, 3])
tensor([[[0.0322, 0.2378, 0.4711],
         [0.9191, 0.0802, 0.4002]]])
tensor([[[0.5645, 0.9541, 0.3089],
         [0.7633, 0.7400, 0.7507]]])
tensor([[[0.5966, 1.1919, 0.7800],
         [1.6825, 0.8202, 1.1509]]])

有一个张量没有维度,一定不可以进行 broadcast:

import torch
x = torch.rand(0)
y = torch.rand(1, 2, 3)
print(x.shape)
print(y.shape)
z = x   y
print(z.shape)
print(x)
print(y)
print(z)

输出结果:

torch.Size([0])
torch.Size([1, 2, 3])
Traceback (most recent call last):
  File "D:/program/Test/broadcast/test.py", line 8, in <module>
    z = x y
RuntimeError: The size of tensor a (0) must match the size of tensor b (3) at non-singleton dimension 2

有一个张量缺少维度,一定可以进行 broadcast:

import torch
x = torch.rand(1, 2, 3, 4)
y = torch.rand(2, 3, 4)
print(x.shape)
print(y.shape)
z = x   y
print(z.shape)
print(x)
print(y)
print(z)

输出结果:

torch.Size([1, 2, 3, 4])
torch.Size([2, 3, 4])
torch.Size([1, 2, 3, 4])
tensor([[[[0.0094, 0.1863, 0.2657, 0.3782],
          [0.3296, 0.7454, 0.2080, 0.4156],
          [0.2092, 0.5414, 0.1053, 0.3872]],

         [[0.8161, 0.3554, 0.7352, 0.2116],
          [0.7459, 0.1662, 0.7555, 0.4548],
          [0.2611, 0.0353, 0.1862, 0.5948]]]])
tensor([[[0.4637, 0.3938, 0.2039, 0.3892],
         [0.4146, 0.8713, 0.3947, 0.5345],
         [0.2401, 0.3800, 0.3747, 0.8381]],

        [[0.0459, 0.1242, 0.3529, 0.1527],
         [0.2361, 0.2850, 0.8671, 0.8040],
         [0.6575, 0.4075, 0.8156, 0.2638]]])
tensor([[[[0.4730, 0.5801, 0.4695, 0.7674],
          [0.7442, 1.6167, 0.6027, 0.9501],
          [0.4493, 0.9214, 0.4800, 1.2253]],

         [[0.8620, 0.4796, 1.0881, 0.3643],
          [0.9820, 0.4512, 1.6227, 1.2588],
          [0.9186, 0.4428, 1.0018, 0.8586]]]])

上面的张量y跟张量x相比缺少一个维度,根据广播机制的规则我们从最后一个维度进行匹配,后面三个维度都一样,张量y的缺少一个维度,于是触发广播机制。

两个张量的维度不相等,其中有一个张量的对应维度为1或者缺失,一定可以进行 broadcast:

import torch
x = torch.rand(1, 2, 3, 4)
y = torch.rand(2, 1, 1)
print(x.shape)
print(y.shape)
z = x   y
print(z.shape)
print(x)
print(y)
print(z)

输出结果:

torch.Size([1, 2, 3, 4])
torch.Size([2, 1, 1])
torch.Size([1, 2, 3, 4])
tensor([[[[0.8670, 0.0134, 0.7929, 0.4109],
          [0.3595, 0.8457, 0.2819, 0.8470],
          [0.5040, 0.9281, 0.9161, 0.7305]],

         [[0.3798, 0.3866, 0.4680, 0.5744],
          [0.6984, 0.6501, 0.2235, 0.3099],
          [0.9861, 0.8598, 0.7635, 0.3238]]]])
tensor([[[0.3393]],

        [[0.1775]]])
tensor([[[[1.2062, 0.3527, 1.1322, 0.7501],
          [0.6987, 1.1850, 0.6212, 1.1863],
          [0.8433, 1.2674, 1.2554, 1.0698]],

         [[0.5574, 0.5641, 0.6455, 0.7519],
          [0.8759, 0.8276, 0.4010, 0.4875],
          [1.1636, 1.0373, 0.9410, 0.5013]]]])

以上就是广播机制的操作,只要记住几个规则就行了,注意tensor在进行运算的时候是从后往前匹配运算的。

4.原地操作

在进行广播机制的时候我们要注意一个原地操作运算,什么是原地操作运算?原地操作运算就是指改变一个tensor的值的时候,不经过复制操作,而是直接在原来的内存上改变它的值。在pytorch中经常加后缀“”来代表原地操作符,例:.add _()、.scatter(),原地操作不允许tensor使用广播机制那样来改变张量形状维度大小,如下例子所示。

import torch
x = torch.rand(1,3,1)
y = torch.rand(3,1,7)
print(x.shape)
print(y.shape)
z = x.add_(y)
print(z.shape)
print(x)
print(y)
print(z)

输出结果:

torch.Size([1, 3, 1])
torch.Size([3, 1, 7])
Traceback (most recent call last):
  File "D:/program/Test/broadcast/test.py", line 8, in <module>
    z = x.add_(y)
RuntimeError: output with shape [1, 3, 1] doesn't match the broadcast shape [3, 3, 7]

到此这篇关于Broadcast广播机制在Pytorch Tensor Numpy中的使用详解的文章就介绍到这了,更多相关Pytorch Broadcast内容请搜索Devmax以前的文章或继续浏览下面的相关文章希望大家以后多多支持Devmax!

Broadcast广播机制在Pytorch Tensor Numpy中的使用详解的更多相关文章

  1. 详解Python NumPy中矩阵和通用函数的使用

    在NumPy中,矩阵是ndarray的子类,与数学概念中的矩阵一样,NumPy中的矩阵也是二维的,可以使用 mat 、 matrix 以及 bmat 函数来创建矩阵。本文将详细讲解NumPy中矩阵和通用函数的使用,感兴趣的可以了解一下

  2. Python数据分析 Numpy 的使用方法

    这篇文章主要介绍了Python数据分析 Numpy 的使用方法,Numpy 是一个Python扩展库,专门做科学计算,也是大部分Python科学计算库的基础,关于其的使用方法,需要的小伙伴可以参考下面文章内容

  3. Python Numpy中数组的集合操作详解

    这篇文章主要为大家详细介绍了Python Numpy中数组的一些集合操作方法,文中的示例代码讲解详细,对我们学习Python有一定帮助,需要的可以参考一下

  4. Python使用pytorch动手实现LSTM模块

    这篇文章主要介绍了Python使用pytorch动手实现LSTM模块,LSTM是RNN中一个较为流行的网络模块。主要包括输入,输入门,输出门,遗忘门,激活函数,全连接层(Cell)和输出

  5. Pytorch搭建yolo3目标检测平台实现源码

    这篇文章主要为大家介绍了Pytorch搭建yolo3目标检测平台实现源码,有需要的朋友可以借鉴参考下,希望能够有所帮助,祝大家多多进步,早日升职加薪

  6. PyTorch搭建双向LSTM实现时间序列负荷预测

    这篇文章主要为大家介绍了PyTorch搭建双向LSTM实现时间序列负荷预测,有需要的朋友可以借鉴参考下,希望能够有所帮助,祝大家多多进步,早日升职加薪

  7. Numpy安装、升级与卸载的详细图文教程

    Python官网上的发行版是不包含 NumPy 模块的,下面这篇文章主要给大家介绍了关于Numpy安装、升级与卸载的相关资料,文中通过实例代码介绍的非常详细,需要的朋友可以参考下

  8. pytorch使用nn.Moudle实现逻辑回归

    这篇文章主要为大家详细介绍了pytorch使用nn.Moudle实现逻辑回归,文中示例代码介绍的非常详细,具有一定的参考价值,感兴趣的小伙伴们可以参考一下

  9. pytorch加载自己的图片数据集的2种方法详解

    数据预处理在解决深度学习问题的过程中,往往需要花费大量的时间和精力,下面这篇文章主要给大家介绍了关于pytorch加载自己的图片数据集的2种方法,文中通过示例代码介绍的非常详细,需要的朋友可以参考下

  10. python如何获取tensor()数据类型中的值

    这篇文章主要介绍了python如何获取tensor()数据类型中的值,具有很好的参考价值,希望对大家有所帮助。如有错误或未考虑完全的地方,望不吝赐教

随机推荐

  1. 10 个Python中Pip的使用技巧分享

    众所周知,pip 可以安装、更新、卸载 Python 的第三方库,非常方便。本文小编为大家总结了Python中Pip的使用技巧,需要的可以参考一下

  2. python数学建模之三大模型与十大常用算法详情

    这篇文章主要介绍了python数学建模之三大模型与十大常用算法详情,文章围绕主题展开详细的内容介绍,具有一定的参考价值,感想取得小伙伴可以参考一下

  3. Python爬取奶茶店数据分析哪家最好喝以及性价比

    这篇文章主要介绍了用Python告诉你奶茶哪家最好喝性价比最高,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习吧

  4. 使用pyinstaller打包.exe文件的详细教程

    PyInstaller是一个跨平台的Python应用打包工具,能够把 Python 脚本及其所在的 Python 解释器打包成可执行文件,下面这篇文章主要给大家介绍了关于使用pyinstaller打包.exe文件的相关资料,需要的朋友可以参考下

  5. 基于Python实现射击小游戏的制作

    这篇文章主要介绍了如何利用Python制作一个自己专属的第一人称射击小游戏,文中的示例代码讲解详细,感兴趣的小伙伴可以跟随小编一起动手试一试

  6. Python list append方法之给列表追加元素

    这篇文章主要介绍了Python list append方法如何给列表追加元素,具有很好的参考价值,希望对大家有所帮助。如有错误或未考虑完全的地方,望不吝赐教

  7. Pytest+Request+Allure+Jenkins实现接口自动化

    这篇文章介绍了Pytest+Request+Allure+Jenkins实现接口自动化的方法,文中通过示例代码介绍的非常详细。对大家的学习或工作具有一定的参考借鉴价值,需要的朋友可以参考下

  8. 利用python实现简单的情感分析实例教程

    商品评论挖掘、电影推荐、股市预测……情感分析大有用武之地,下面这篇文章主要给大家介绍了关于利用python实现简单的情感分析的相关资料,文中通过示例代码介绍的非常详细,需要的朋友可以参考下

  9. 利用Python上传日志并监控告警的方法详解

    这篇文章将详细为大家介绍如何通过阿里云日志服务搭建一套通过Python上传日志、配置日志告警的监控服务,感兴趣的小伙伴可以了解一下

  10. Pycharm中运行程序在Python console中执行,不是直接Run问题

    这篇文章主要介绍了Pycharm中运行程序在Python console中执行,不是直接Run问题,具有很好的参考价值,希望对大家有所帮助。如有错误或未考虑完全的地方,望不吝赐教

返回
顶部