当前位置:主页 > python教程 > pytorch和numpy默认浮点类型位数

pytorch和numpy默认浮点类型位数详解

发布:2023-04-22 14:05:01 59


给大家整理一篇相关的编程文章,网友史英睿根据主题投稿了本篇教程内容,涉及到pytorch numpy、numpy默认浮点类型位数、pytorch默认浮点类型、pytorch和numpy默认浮点类型位数相关内容,已被978网友关注,相关难点技巧可以阅读下方的电子资料。

pytorch和numpy默认浮点类型位数

pytorch和numpy默认浮点类型位数

numpy中默认浮点类型为64位,pytorch中默认浮点类型位32位

测试代码如下

  • numpy版本:1.19.2
  • pytorch版本:1.2.0
In [1]: import torch
In [2]: import numpy as np
# 版本信息
In [3]: "pytorch version: {}, numpy version: {}".format(torch.__version__, np.__version__)
Out[3]: 'pytorch version: 1.2.0, numpy version: 1.19.2'

# numpy
In [4]: dat_np = np.array([1,2,3], dtype="float")
In [5]: dat_np.dtype
Out[5]: dtype('float64')

# pytorch
In [6]: dat_torch = torch.tensor([1,2,3])
In [7]: dat_torch = dat_torch.float()
In [8]: dat_torch.dtype
Out[8]: torch.float32

pytorch和numpy的默认类型与转换问题

pytorch对于浮点类型默认为float32,而numpy的默认类型是float64,转换的代码:

torch.from_numpy(a).type(torch.FloatTensor)
torch.from_numpy(np.float32(a))

总结

以上为个人经验,希望能给大家一个参考,也希望大家多多支持码农之家。


参考资料

相关文章

网友讨论