当前位置: 首页 > news >正文

【MindSpore产品】【数据处理功能】加入数据增强之后,报出卷积输入类型不同的问题

【功能模块】

# 图像增强
trans = [
    transforms.RandomCrop((32, 32), (4, 4, 4, 4), fill_value=(255,255,255)), # 对图像进行自动裁剪
    transforms.RandomHorizontalFlip(prob=0.5), # 对图像进行随机水平翻转
    transforms.RandomRotation(degrees=20, fill_value=(255,255,255)),
    # transforms.HWC2CHW(), # (h, w, c)转换为(c, h, w)
]
# 下载解压并加载CIFAR-10训练数据集
dataset_train = Cifar10(path=data_dir, split='train', batch_size=6, shuffle=True, resize=32, download=True, transform=trans)
ds_train = dataset_train.run()
model.train(num_epochs, ds_train, callbacks=[ValAccMonitor(model, ds_val, num_epochs)])

【操作步骤&问题现象】

Traceback (most recent call last):
  File "F:/8.Learning Task/MindSpore/ResNet/train.py", line 49, in <module>
    model.train(num_epochs, ds_train, callbacks=[ValAccMonitor(model, ds_val, num_epochs)])
  File "D:\Anaconda1\lib\site-packages\mindspore\train\model.py", line 906, in train
    sink_size=sink_size)
  File "D:\Anaconda1\lib\site-packages\mindspore\train\model.py", line 87, in wrapper
    func(self, *args, **kwargs)
  File "D:\Anaconda1\lib\site-packages\mindspore\train\model.py", line 546, in _train
    self._train_process(epoch, train_dataset, list_callback, cb_params)
  File "D:\Anaconda1\lib\site-packages\mindspore\train\model.py", line 794, in _train_process
    outputs = self._train_network(*next_element)
  File "D:\Anaconda1\lib\site-packages\mindspore\nn\cell.py", line 586, in __call__
    out = self.compile_and_run(*args)
  File "D:\Anaconda1\lib\site-packages\mindspore\nn\cell.py", line 964, in compile_and_run
    self.compile(*inputs)
  File "D:\Anaconda1\lib\site-packages\mindspore\nn\cell.py", line 937, in compile
    _cell_graph_executor.compile(self, *inputs, phase=self.phase, auto_parallel_mode=self._auto_parallel_mode)
  File "D:\Anaconda1\lib\site-packages\mindspore\common\api.py", line 1006, in compile
    result = self._graph_executor.compile(obj, args_list, phase, self._use_vm_mode())
TypeError: mindspore\core\utils\check_convert_utils.cc:701 _CheckTypeSame] For primitive[Conv2D], the input type must be same.
name:[w]:Ref[Tensor(F32)].
name:[x]:Tensor[UInt8].

WARNING: Logging before InitGoogleLogging() is written to STDERR
[CRITICAL] CORE(22848,1,?):2022-6-6 12:59:53 [mindspore\core\utils\check_convert_utils.cc:701] _CheckTypeSame] For primitive[Conv2D], the input type must be same.
name:[w]:Ref[Tensor(F32)].
name:[x]:Tensor[UInt8].

【日志信息】(可选,上传日志内容或者附件)

不知该如何让input的类型相同,求大佬们能看看,给个办法,谢谢!!!

总体代码如下:

# train.py

from mindvision.dataset import Cifar10
import mindspore.dataset.vision.c_transforms as transforms

# 数据集根目录
data_dir = "./datasets"
# 图像增强
# 图像增强
trans = [
    transforms.RandomCrop((32, 32), (4, 4, 4, 4), fill_value=(255,255,255)), # 对图像进行自动裁剪
    transforms.RandomHorizontalFlip(prob=0.5), # 对图像进行随机水平翻转
    transforms.RandomRotation(degrees=20, fill_value=(255,255,255)),
    # transforms.HWC2CHW(), # (h, w, c)转换为(c, h, w)
]

# 下载解压并加载CIFAR-10训练数据集
dataset_train = Cifar10(path=data_dir, split='train', batch_size=6, shuffle=True, resize=32, download=True, transform=trans)
ds_train = dataset_train.run()
step_size = ds_train.get_dataset_size()
# 下载解压并加载CIFAR-10测试数据集
dataset_val = Cifar10(path=data_dir, split='test', batch_size=6, resize=32, download=True)
ds_val = dataset_val.run()


from mindspore.train import Model
from mindvision.engine.callback import ValAccMonitor
from mindvision.classification.models.head import DenseHead
from mindspore import nn
from ResNet.resnet import resnet50

# 定义ResNet50网络
network = resnet50(pretrained=True)

# 全连接层输入层的大小
in_channel = network.head.dense.in_channels
head = DenseHead(input_channel=in_channel, num_classes=10)
# 重置全连接层
network.head = head
# 设置学习率
num_epochs = 40
lr = nn.cosine_decay_lr(min_lr=0.00001, max_lr=0.001, total_step=step_size * num_epochs,
                        step_per_epoch=step_size, decay_epoch=num_epochs)
# 定义优化器和损失函数
opt = nn.Momentum(params=network.trainable_params(), learning_rate=lr, momentum=0.9)
loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
# 实例化模型
model = Model(network, loss, opt, metrics={"Accuracy": nn.Accuracy()})
# 模型训练
model.train(num_epochs, ds_train, callbacks=[ValAccMonitor(model, ds_val, num_epochs)])

 从报错来看出问题代码应该在你自定义网络resnet里,所以我在本地尝试未复现成功;然后从这个报错来看应该是你的卷积算子CONV2D输入不一致,有试过在CONV2D算子前将所有的输入转为float32格式然后继续呢,如果不行的话麻烦您再提供一下自定义resnet网络的脚本。

相关文章:

  • 基于Nexus搭建docker镜像源仓库
  • Estimating High-Dimensional Directed Acyclic Graphs with the PC-Algorithm
  • Linux文件查找find
  • Vue--》Vue中实现数据代理
  • 深度学习入门(十) 模型选择、过拟合和欠拟合
  • RK3399驱动开发 | 12 - AP6255 SDIO WiFi 调试(基于linux4.4.194内核)
  • 牛客网-《刷C语言百题》第二期
  • 测试开发需要掌握哪些技能?
  • 巴什博弈——范围拿物品问题
  • 【Mybatisplus】初识Mybatisplus+SpringBoot整合
  • 【编程碎笔】-Java中关于next(),nextInt(),nextLine()的深度解剖
  • 2023年荆州市高新技术企业申报条件以及奖励补贴政策(附申报时间)汇总!
  • macOS Ventura 正式版你确定不更新,好用到爆的功能你不想尝试一下?
  • 云存储架构框架设计 | 最佳实践
  • 阿里巴巴面试题- - -多线程并发篇(三十)
  • 计算机网络【UDP与TCP协议(三次握手、四次挥手)】
  • Linux进程控制
  • Unity 分享 功能 用Unity Native Share Plugin 实现链接、图片、视频等文件的分享+ 安卓 Ios 都可以,代码图文详解
  • 基于javaweb的嘟嘟二手书商城系统(java+jsp+springboot+mysql+thymeleaf+ftp)
  • 2.1.1 操作系统之进程的定义、特征、组成、组织