当前位置:主页 > python教程 > pytorch的Backward过程用时太长

pytorch的Backward过程用时太长问题及解决

发布:2023-04-04 09:55:02 59


给网友朋友们带来一篇相关的编程文章,网友刘笑天根据主题投稿了本篇教程内容,涉及到pytorch、Backward、Backward过程用时太长、Backward过程、pytorch的Backward过程用时太长相关内容,已被237网友关注,涉猎到的知识点内容可以在下方电子书获得。

pytorch的Backward过程用时太长

pytorch Backward过程用时太长

问题描述

使用pytorch对网络进行训练的时候遇到一个问题,forward阶段很快(只需要几毫秒),backward阶段却用时很长(需要十多秒)。

导致这个问题的原因很容易被大家忽视,而且网上基本上没有直接的解决方案,经过一天的折腾,总算把导致这个问题的原因搞清楚了。

解决方案

导致这个问题的原因在于训练数据的浅拷贝,由于backward过程中的梯度是和模型推理过程中的张量相关的,如果这些张量在被模型使用之前没有被深拷贝,意味着backward过程的会重复从这些张量的原始内存地址中取值,这个过程非常耗时。所以为了避免这个问题,需要养成一个好习惯,就是将张量数据输入模型之前进行深拷贝

pytorch的深拷贝方式如下:

tensor_a = tensor_b.clone().detach()

Pytorch backward()简单理解

backward()是反向传播求梯度,具体实现过程如下

import torch
 
x=torch.tensor([1,2,3],requires_grad=True,dtype=torch.double)
y=x**2
z=y.mean()
z.backward()
print(x.grad)

结果

tensor([0.6667, 1.3333, 2.0000], dtype=torch.float64)

有几个重要的点

1.必须要加上requires_grad=True才能求

2. 一般来说,需要标量才能求梯度。

3.具体过程如下:

z是一个标量(1*1矩阵)分别对x1,x2,x3求偏导, 再代入x1,x2,x3的数值,就是如上程序输出的结果

总结

以上为个人经验,希望能给大家一个参考,也希望大家多多支持码农之家。


参考资料

相关文章

  • Jupyter notebook中如何添加Pytorch运行环境

    发布:2023-04-01

    这篇文章主要介绍了Jupyter notebook中如何添加Pytorch运行环境,具有很好的参考价值,希望对大家有所帮助。如有错误或未考虑完全的地方,望不吝赐教


  • pytorch 简介及常用工具包展示

    发布:2023-03-24

    Pytorch是torch的python版本,是由Facebook开源的神经网络框架,专门针对 GPU 加速的深度神经网络(DNN)编程,这篇文章主要介绍了pytorch 简介及常用工具包展示,需要的朋友可以参考下


  • pytorch transform数据处理转c++问题

    发布:2023-04-21

    这篇文章主要介绍了pytorch transform数据处理转c++问题,具有很好的参考价值,希望对大家有所帮助。如有错误或未考虑完全的地方,望不吝赐教


  • pytorch 实现情感分类问题小结

    发布:2023-04-11

    本文主要介绍了pytorch 实现情感分类问题,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧


  • Pytorch Mac GPU 训练与测评实例

    发布:2023-03-03

    这篇文章主要为大家介绍了Pytorch Mac GPU 训练与测评实例详解,有需要的朋友可以借鉴参考下,希望能够有所帮助,祝大家多多进步,早日升职加薪


  • pytorch建立mobilenetV3-ssd网络并进行训练与预测方式

    发布:2023-04-05

    这篇文章主要介绍了pytorch建立mobilenetV3-ssd网络并进行训练与预测方式,具有很好的参考价值,希望对大家有所帮助。如有错误或未考虑完全的地方,望不吝赐教


  • Pytorch统计参数网络参数数量方式

    发布:2023-04-03

    这篇文章主要介绍了Pytorch统计参数网络参数数量方式,具有很好的参考价值,希望对大家有所帮助。如有错误或未考虑完全的地方,望不吝赐教


  • 关于pytorch相关部分矩阵变换函数的问题分析

    发布:2022-04-24

    这篇文章主要介绍了pytorch相关部分矩阵变换函数,包括tensor维度顺序变换BCHW顺序的调整,矩阵乘法相关函数,矩阵乘,点乘,求取矩阵对角线元素或非对角线元素的问题,本文给大家介绍的非常详细,需要的朋友可以参考下


网友讨论