当前位置:主页 > python教程 > pytorch nn.Flatten()函数详解

pytorch中nn.Flatten()函数详解及示例

发布:2023-03-03 10:00:02 59


给大家整理了相关的编程文章,网友辛天华根据主题投稿了本篇教程内容,涉及到pytorch、nn.flatten()函数、pytorch、nn.flatten()、nn.Flatten、pytorch nn.Flatten()函数详解相关内容,已被643网友关注,如果对知识点想更进一步了解可以在下方电子资料中获取。

pytorch nn.Flatten()函数详解

torch.nn.Flatten(start_dim=1, end_dim=- 1)

作用:将连续的维度范围展平为张量。 经常在nn.Sequential()中出现,一般写在某个神经网络模型之后,用于对神经网络模型的输出进行处理,得到tensor类型的数据。

有俩个参数,start_dim和end_dim,分别表示开始的维度和终止的维度,默认值分别是1和-1,其中1表示第一维度,-1表示最后的维度。结合起来看意思就是从第一维度到最后一个维度全部给展平为张量。(注意:数据的维度是从0开始的,也就是存在第0维度,第一维度并不是真正意义上的第一个)

同理,如果我这么写:

self.flat = nn.Flatten(start_dim=2, end_dim=3)

那么意思就是从第二维度开始,到第三维度全部给展平,也就是将2、3两个维度展平。

官网给出的示例:

input = torch.randn(32, 1, 5, 5)
# With default parameters
m = nn.Flatten()
output = m(input)
output.size()
#torch.Size([32, 25])
# With non-default parameters
m = nn.Flatten(0, 2)
output = m(input)
output.size()
#torch.Size([160, 5])

#开头的代码是注释

整段代码的意思是:给定一个维度为(32,1,5,5)的随机数据。

1.先使用一次nn.Flatten(),使用默认参数:

m = nn.Flatten()

也就是说从第一维度展平到最后一个维度,数据的维度是从0开始的,第一维度实际上是数据的第二个位置代表的维度,也就是样例中的1。

因此进行展平后的结果也就是[32,1×5×5]➡[32,25]

2.接着再使用一次指定参数的nn.Flatten(),即

m = nn.Flatten(0, 2)

也就是说从第0维度展平到第2维度,0~2,对应的也就是前三个维度。

因此结果就是[32×1×5,5]➡[160,5]

因此进行展平后的结果也就是[32,1*5*5]➡[32,25]

示例1

卷积公式

import torch
import torch.nn as nn
input = torch.randn(32, 1, 5, 5)
m = nn.Sequential(
    nn.Conv2d(1, 32, 5, 1, 1),  # 通过卷积,得到torch.size([32, 32, 3, 3]
    nn.Flatten())

output = m(input)
print(output.size())

>> torch.Size([32, 288])

示例2

import torch
import torch.nn as nn
input = torch.randn(32, 1, 5, 5)
m = nn.Sequential(
    nn.Conv2d(1, 32, 5, 1, 1),  # 通过卷积,得到torch.size([32, 32, 3, 3]
    nn.Flatten(start_dim=0))

output = m(input)
print(output.size())

>>torch.Size([9216])

总结

到此这篇关于pytorch中nn.Flatten()函数详解的文章就介绍到这了,更多相关pytorch nn.Flatten()函数详解内容请搜索码农之家以前的文章或继续浏览下面的相关文章希望大家以后多多支持码农之家!


参考资料

相关文章

  • pytorch/transformers 最后一层不加激活函数的原因分析

    发布:2023-03-02

    这里给大家解释一下为什么bert模型最后都不加激活函数,是因为损失函数选择的原因,本文通过示例代码给大家介绍的非常详细,对大家的学习或工作具有一定的参考借鉴价值,需要的朋友参考下吧


  • 如何从PyTorch中获取过程特征图实例详解

    发布:2023-03-03

    特征提取是图像处理过程中常需要用到的一种方法,其效果好坏对模型的泛化能力有至关重要的影响,下面这篇文章主要给大家介绍了关于如何从PyTorch中获取过程特征图的相关资料,需要的朋友可以参考下


  • Pytorch中的 torch.distributions库详解

    发布:2023-03-25

    这篇文章主要介绍了Pytorch中的 torch.distributions库,本文通过实例代码给大家介绍的非常详细,对大家的学习或工作具有一定的参考借鉴价值,需要的朋友可以参考下


  • pytorch多GPU训练实例与性能对比

    发布:2020-01-16

    今天小编就为大家分享一篇关于pytorch多GPU训练实例与性能对比分析,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧


  • GPU版本安装Pytorch的最新方法步骤

    发布:2023-04-11

    最近深度学习需要用GPU版本的pytorch来加速运算,所以下面这篇文章主要给大家介绍了关于GPU版本安装Pytorch的最新方法步骤,文中通过实例代码介绍的非常详细,需要的朋友可以参考下


  • PyTorch 迁移学习实战

    发布:2023-03-06

    本文主要介绍了PyTorch 迁移学习实战,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧


  • Pytorch中的数据转换Transforms与DataLoader方式

    发布:2023-04-23

    这篇文章主要介绍了Pytorch中的数据转换Transforms与DataLoader方式,具有很好的参考价值,希望对大家有所帮助。如有错误或未考虑完全的地方,望不吝赐教


  • pytorch的Backward过程用时太长问题及解决

    发布:2023-04-04

    这篇文章主要介绍了pytorch的Backward过程用时太长问题及解决方案,具有很好的参考价值,希望对大家有所帮助。如有错误或未考虑完全的地方,望不吝赐教


网友讨论