【MindSpore产品】【数据处理功能】加入数据增强之后,报出卷积输入类型不同的问题
【功能模块】
# 图像增强 trans = [ transforms.RandomCrop((32, 32), (4, 4, 4, 4), fill_value=(255,255,255)), # 对图像进行自动裁剪 transforms.RandomHorizontalFlip(prob=0.5), # 对图像进行随机水平翻转 transforms.RandomRotation(degrees=20, fill_value=(255,255,255)), # transforms.HWC2CHW(), # (h, w, c)转换为(c, h, w) ]
# 下载解压并加载CIFAR-10训练数据集 dataset_train = Cifar10(path=data_dir, split='train', batch_size=6, shuffle=True, resize=32, download=True, transform=trans) ds_train = dataset_train.run()
model.train(num_epochs, ds_train, callbacks=[ValAccMonitor(model, ds_val, num_epochs)])
【操作步骤&问题现象】
Traceback (most recent call last):
File "F:/8.Learning Task/MindSpore/ResNet/train.py", line 49, in <module>
model.train(num_epochs, ds_train, callbacks=[ValAccMonitor(model, ds_val, num_epochs)])
File "D:\Anaconda1\lib\site-packages\mindspore\train\model.py", line 906, in train
sink_size=sink_size)
File "D:\Anaconda1\lib\site-packages\mindspore\train\model.py", line 87, in wrapper
func(self, *args, **kwargs)
File "D:\Anaconda1\lib\site-packages\mindspore\train\model.py", line 546, in _train
self._train_process(epoch, train_dataset, list_callback, cb_params)
File "D:\Anaconda1\lib\site-packages\mindspore\train\model.py", line 794, in _train_process
outputs = self._train_network(*next_element)
File "D:\Anaconda1\lib\site-packages\mindspore\nn\cell.py", line 586, in __call__
out = self.compile_and_run(*args)
File "D:\Anaconda1\lib\site-packages\mindspore\nn\cell.py", line 964, in compile_and_run
self.compile(*inputs)
File "D:\Anaconda1\lib\site-packages\mindspore\nn\cell.py", line 937, in compile
_cell_graph_executor.compile(self, *inputs, phase=self.phase, auto_parallel_mode=self._auto_parallel_mode)
File "D:\Anaconda1\lib\site-packages\mindspore\common\api.py", line 1006, in compile
result = self._graph_executor.compile(obj, args_list, phase, self._use_vm_mode())
TypeError: mindspore\core\utils\check_convert_utils.cc:701 _CheckTypeSame] For primitive[Conv2D], the input type must be same.
name:[w]:Ref[Tensor(F32)].
name:[x]:Tensor[UInt8].
WARNING: Logging before InitGoogleLogging() is written to STDERR
[CRITICAL] CORE(22848,1,?):2022-6-6 12:59:53 [mindspore\core\utils\check_convert_utils.cc:701] _CheckTypeSame] For primitive[Conv2D], the input type must be same.
name:[w]:Ref[Tensor(F32)].
name:[x]:Tensor[UInt8].
【日志信息】(可选,上传日志内容或者附件)
不知该如何让input的类型相同,求大佬们能看看,给个办法,谢谢!!!
总体代码如下:
# train.py
from mindvision.dataset import Cifar10 import mindspore.dataset.vision.c_transforms as transforms # 数据集根目录 data_dir = "./datasets" # 图像增强 # 图像增强 trans = [ transforms.RandomCrop((32, 32), (4, 4, 4, 4), fill_value=(255,255,255)), # 对图像进行自动裁剪 transforms.RandomHorizontalFlip(prob=0.5), # 对图像进行随机水平翻转 transforms.RandomRotation(degrees=20, fill_value=(255,255,255)), # transforms.HWC2CHW(), # (h, w, c)转换为(c, h, w) ] # 下载解压并加载CIFAR-10训练数据集 dataset_train = Cifar10(path=data_dir, split='train', batch_size=6, shuffle=True, resize=32, download=True, transform=trans) ds_train = dataset_train.run() step_size = ds_train.get_dataset_size() # 下载解压并加载CIFAR-10测试数据集 dataset_val = Cifar10(path=data_dir, split='test', batch_size=6, resize=32, download=True) ds_val = dataset_val.run() from mindspore.train import Model from mindvision.engine.callback import ValAccMonitor from mindvision.classification.models.head import DenseHead from mindspore import nn from ResNet.resnet import resnet50 # 定义ResNet50网络 network = resnet50(pretrained=True) # 全连接层输入层的大小 in_channel = network.head.dense.in_channels head = DenseHead(input_channel=in_channel, num_classes=10) # 重置全连接层 network.head = head # 设置学习率 num_epochs = 40 lr = nn.cosine_decay_lr(min_lr=0.00001, max_lr=0.001, total_step=step_size * num_epochs, step_per_epoch=step_size, decay_epoch=num_epochs) # 定义优化器和损失函数 opt = nn.Momentum(params=network.trainable_params(), learning_rate=lr, momentum=0.9) loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean') # 实例化模型 model = Model(network, loss, opt, metrics={"Accuracy": nn.Accuracy()}) # 模型训练 model.train(num_epochs, ds_train, callbacks=[ValAccMonitor(model, ds_val, num_epochs)])
从报错来看出问题代码应该在你自定义网络resnet里,所以我在本地尝试未复现成功;然后从这个报错来看应该是你的卷积算子CONV2D输入不一致,有试过在CONV2D算子前将所有的输入转为float32格式然后继续呢,如果不行的话麻烦您再提供一下自定义resnet网络的脚本。